Source code for RAI.Analysis.analysis

# Copyright 2022 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0


from RAI.AISystem import AISystem
from abc import ABC, abstractmethod
from .analysis_registry import register_class
from RAI.utils import compare_runtimes
import json


class Analysis(ABC):
    """Abstract base class for analyses run against an ``AISystem``.

    Every subclass is declared with a ``class_location`` keyword (see
    :meth:`__init_subclass__`); at class-creation time the JSON file next to
    that location is loaded into ``cls.config`` and the subclass is registered
    through ``register_class``.

    Concrete subclasses must implement :meth:`initialize`, :meth:`to_string`
    and :meth:`to_html`.
    """

    def __init__(self, ai_system: "AISystem", dataset: str, tag: str = None):
        """
        :param ai_system: the AISystem the analysis is run against
        :param dataset: name of the dataset the analysis applies to
        :param tag: optional tag stored on the instance
        """
        self.ai_system = ai_system
        self.analysis = {}
        self.dataset = dataset
        # Keep the caller-supplied tag; it was previously discarded in favour
        # of a hard-coded None.
        self.tag = tag
        self.max_progress_tick = 5
        self.current_tick = 0
        # connection is an optional function passed to share progress with dashboard
        self.connection = None
        print("Analysis created")

    def __init_subclass__(cls, class_location=None, **kwargs):
        """Load the subclass's JSON config and register the subclass.

        :param class_location: path whose last two characters are replaced by
            ``"json"`` to obtain the config file path (e.g. ``foo.py`` ->
            ``foo.json``)
        :raises TypeError: if ``class_location`` is not provided
        """
        super().__init_subclass__(**kwargs)
        if class_location is None:
            # Slicing None raised an opaque TypeError before; keep the
            # exception type but explain what is missing.
            raise TypeError(
                "%s must be declared with class_location set to the path of its "
                "defining module so that its .json config can be loaded" % cls.__name__
            )
        config_file = class_location[:-2] + "json"
        # Context manager so the file handle is closed deterministically.
        with open(config_file, encoding="utf-8") as config_handle:
            cls.config = json.load(config_handle)
        register_class(cls.__name__, cls)

    @classmethod
    def is_compatible(cls, ai_system, dataset: str):
        """
        :param ai_system: input the ai_system object
        :param dataset: name of the dataset whose metric groups are checked
        :return: truthy if this analysis can run on ``ai_system``/``dataset``

        Checks ``ai_system`` against the ``"compatibility"`` section of
        ``cls.config`` (task type, data types, output/dataset/data
        requirements, required metric groups) and against
        ``cls.config["complexity_class"]`` via ``compare_runtimes``.
        Checks are evaluated in order and stop at the first failure.
        """
        requirements = cls.config["compatibility"]

        # Task type: no restriction, an exact match, or a generic
        # "classification" analysis applied to a binary classification task.
        compatible = requirements["task_type"] is None \
            or requirements["task_type"] == [] \
            or ai_system.task in requirements["task_type"] \
            or ("classification" in requirements["task_type"]
                and ai_system.task == "binary_classification")

        # Every required data type must be present in the data format.
        compatible = compatible and (
            requirements["data_type"] is None
            or requirements["data_type"] == []
            or all(item in ai_system.meta_database.data_format
                   for item in requirements["data_type"]))

        # Every required output must be present in ai_system.data_dict.
        compatible = compatible and (
            requirements["output_requirements"] is None
            or all(item in ai_system.data_dict
                   for item in requirements["output_requirements"]))

        # Every required dataset entry must be present in the stored data.
        compatible = compatible and (
            requirements["dataset_requirements"] is None
            or all(item in ai_system.meta_database.stored_data
                   for item in requirements["dataset_requirements"]))

        # Every object in the dataset's data_dict must be of an accepted type
        # (matched by type name). None is treated as "no requirement", in line
        # with the other checks.
        compatible = compatible and (
            requirements["data_requirements"] is None
            or requirements["data_requirements"] == []
            or all(type(item).__name__ in requirements["data_requirements"]
                   for item in ai_system.dataset.data_dict.values()))

        # User-configured time complexity versus this analysis's complexity class.
        compatible = compatible and compare_runtimes(
            ai_system.metric_manager.user_config.get("time_complexity"),
            cls.config["complexity_class"])

        # Every required metric group must be available for the dataset.
        compatible = compatible and all(
            group in ai_system.get_metric_values()[dataset]
            for group in requirements["required_groups"])
        return compatible

    def progress_percent(self, percentage_complete):
        """
        :param percentage_complete: progress value, truncated to an int
        :return: None

        Sends the progress value, as a string, to the connection if one is set
        """
        percentage_complete = int(percentage_complete)
        # Fixed attribute name (was misspelled "conncetion", which raised
        # AttributeError on every call).
        if self.connection is not None:
            self.connection(str(percentage_complete))

    def progress_tick(self):
        """
        :parameter: None
        :return: None

        Increments current_tick and sends the resulting percentage of
        max_progress_tick (capped at 100) to the connection if one is set
        """
        self.current_tick += 1
        percentage_complete = min(100, int(100 * self.current_tick / self.max_progress_tick))
        if self.connection is not None:
            self.connection(str(percentage_complete))

    # connection is a function that accepts progress, and pings the dashboard
    def set_connection(self, connection):
        """
        :param connection: function called with the progress value as a string
        :return: None

        Connection is a function that accepts progress, and pings the dashboard
        """
        self.connection = connection

    @abstractmethod
    def initialize(self):
        """Abstract hook; must be implemented by every concrete analysis."""
        pass

    @abstractmethod
    def to_string(self):
        """Abstract hook; must be implemented by every concrete analysis."""
        pass

    @abstractmethod
    def to_html(self):
        """Abstract hook; must be implemented by every concrete analysis."""
        pass