Source code for RAI.AISystem.model

# Copyright 2022 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

__all__ = ['Model']
import RAI.dataset


class Model:
    """RAI's abstraction for the ML model that performs inferences.

    A Model bundles the inference callables of the underlying model together
    with descriptive metadata. Every callable that is supplied is also
    registered in ``output_types`` under a fixed key, so the kinds of output
    the model can produce can be looked up from that dictionary:

    * ``"predict"``        -> ``predict_fun``
    * ``"predict_proba"``  -> ``predict_prob_fun``
    * ``"generate_text"``  -> ``generate_text_fun``
    * ``"generate_image"`` -> ``generate_image_fun``

    Attributes of the model are used to determine which metrics are relevant.
    """

    def __init__(self, output_features=None, predict_fun=None, predict_prob_fun=None, generate_text_fun=None,
                 generate_image_fun=None, name=None, display_name=None, agent=None, loss_function=None,
                 optimizer=None, model_class=None, description=None) -> None:
        """Create a model description.

        Args:
            output_features: A single ``RAI.dataset.Feature``, an iterable of
                them, or ``None``. Always stored as a list (empty for ``None``).
            predict_fun: Optional callable, registered under ``"predict"``.
            predict_prob_fun: Optional callable, registered under
                ``"predict_proba"``.
            generate_text_fun: Optional callable, registered under
                ``"generate_text"``.
            generate_image_fun: Optional callable, registered under
                ``"generate_image"``.
            name: Required model name.
            display_name: Optional display name; falls back to ``name``.
            agent: The underlying model object (stored as given).
            loss_function: The model's loss function (stored as given).
            optimizer: The model's optimizer (stored as given).
            model_class: The model's class (stored as given).
            description: Free-form description (stored as given).

        Raises:
            AssertionError: If ``name`` is ``None``.
            TypeError: If ``output_features`` is neither a Feature nor an
                iterable containing only Features.
        """
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is kept for callers.
        if name is None:
            raise AssertionError("Please provide a model name")

        self.output_features = self._normalize_output_features(output_features)

        self.predict_fun = predict_fun
        self.predict_prob_fun = predict_prob_fun
        self.generate_text_fun = generate_text_fun
        self.generate_image_fun = generate_image_fun

        # Only the callables that were actually provided are registered.
        # Insertion order matches the original implementation.
        candidates = (
            ("predict", predict_fun),
            ("predict_proba", predict_prob_fun),
            ("generate_text", generate_text_fun),
            ("generate_image", generate_image_fun),
        )
        self.output_types = {key: fun for key, fun in candidates if fun is not None}

        self.name = name
        self.display_name = display_name if display_name is not None else name
        self.agent = agent
        self.input_type = None  # not configurable through the constructor
        self.loss_function = loss_function
        self.optimizer = optimizer
        self.model_class = model_class
        self.description = description

    @staticmethod
    def _normalize_output_features(output_features):
        """Return ``output_features`` as a validated list of Features."""
        if output_features is None:
            return []

        feature_type = RAI.dataset.Feature
        if isinstance(output_features, feature_type):
            return [output_features]

        message = "output_features must be a Feature or array of Features"
        try:
            features = list(output_features)
        except TypeError:
            # Not iterable at all (e.g. an int).
            raise TypeError(message) from None
        if not all(isinstance(feature, feature_type) for feature in features):
            raise TypeError(message)
        return features