From 9f4ab9060078c9f866b547f98d529758edb03e33 Mon Sep 17 00:00:00 2001 From: dafnapension Date: Sat, 20 Sep 2025 22:52:04 +0300 Subject: [PATCH 1/2] light fast removal of register_all_artifacts for unitxt classes Signed-off-by: dafnapension --- docs/catalog.py | 10 +- src/unitxt/__init__.py | 4 +- src/unitxt/artifact.py | 113 +++++++++++++------- src/unitxt/register.py | 27 +---- src/unitxt/text_utils.py | 13 +++ tests/library/test_artifact_recovery.py | 85 ++++++++++++++- tests/library/test_artifact_registration.py | 4 +- 7 files changed, 182 insertions(+), 74 deletions(-) diff --git a/docs/catalog.py b/docs/catalog.py index 0d06d5d54b..bd1b804d41 100644 --- a/docs/catalog.py +++ b/docs/catalog.py @@ -10,7 +10,7 @@ from pygments import highlight from pygments.formatters import HtmlFormatter from pygments.lexers import PythonLexer -from unitxt.artifact import Artifact +from unitxt.artifact import get_class_from_artifact_type from unitxt.text_utils import print_dict_as_python from unitxt.utils import load_json @@ -51,7 +51,7 @@ def imports_to_syntax_highlighted_html(subtypes: List[str]) -> str: return "" module_to_class_names = defaultdict(list) for subtype in subtypes: - subtype_class = Artifact._class_register.get(subtype) + subtype_class = get_class_from_artifact_type(subtype) module_to_class_names[subtype_class.__module__].append(subtype_class.__name__) imports_txt = "" @@ -150,7 +150,7 @@ def recursive_search(d): @lru_cache(maxsize=None) def artifact_type_to_link(artifact_type): - artifact_class = Artifact._class_register.get(artifact_type) + artifact_class = get_class_from_artifact_type(artifact_type) type_class_name = artifact_class.__name__ artifact_class_id = f"{artifact_class.__module__}.{type_class_name}" return f'{type_class_name}' @@ -159,7 +159,7 @@ def artifact_type_to_link(artifact_type): # flake8: noqa: C901 def make_content(artifact, label, all_labels): artifact_type = artifact["__type__"] - artifact_class = Artifact._class_register.get(artifact_type) + artifact_class = get_class_from_artifact_type(artifact_type) type_class_name = artifact_class.__name__ catalog_id = label.replace("catalog.", "") @@ -243,7 +243,7 @@ def make_content(artifact, label, all_labels): result += artifact_class.__doc__ + "\n" for subtype in subtypes: - subtype_class = Artifact._class_register.get(subtype) + subtype_class = get_class_from_artifact_type(subtype) subtype_class_name = subtype_class.__name__ if subtype_class.__doc__: explanation_str = f"Explanation about `{subtype_class_name}`" diff --git a/src/unitxt/__init__.py b/src/unitxt/__init__.py index 37552072b9..0e03e9b3c0 100644 --- a/src/unitxt/__init__.py +++ b/src/unitxt/__init__.py @@ -11,10 +11,10 @@ ) from .catalog import add_to_catalog, get_from_catalog from .logging_utils import get_logger -from .register import register_all_artifacts, register_local_catalog +from .register import ProjectArtifactRegisterer, register_local_catalog from .settings_utils import get_constants, get_settings -register_all_artifacts() +ProjectArtifactRegisterer() random.seed(0) constants = get_constants() diff --git a/src/unitxt/artifact.py b/src/unitxt/artifact.py index e1ccae320e..f77bb1e669 100644 --- a/src/unitxt/artifact.py +++ b/src/unitxt/artifact.py @@ -1,9 +1,10 @@ -import difflib +import importlib import inspect import json import os import pkgutil import re +import subprocess import warnings from abc import abstractmethod from typing import Any, Dict, List, Optional, Tuple, Union, final @@ -22,7 +23,7 @@ separate_inside_and_outside_square_brackets, ) from .settings_utils import get_constants, get_settings -from .text_utils import camel_to_snake_case, is_camel_case +from .text_utils import camel_to_snake_case, is_camel_case, snake_to_camel_case from .type_utils import isoftype, issubtype from .utils import ( artifacts_json_cache, @@ -36,6 +37,53 @@ constants = get_constants() +def import_module_from_file(file_path): + # Get the module name (file name without extension) + module_name = os.path.splitext(os.path.basename(file_path))[0] + # Create a module specification + spec = importlib.util.spec_from_file_location(module_name, file_path) + # Create a new module based on the specification + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +# type is read from a catelog entry, the value of a key "__type__" +def get_class_from_artifact_type(type: str): + if type in Artifact._class_register: + return Artifact._class_register[type] + + module_path, class_name = find_unitxt_module_and_class_by_classname( + snake_to_camel_case(type) + ) + if module_path == "class_register": + if class_name not in Artifact._class_register: + raise ValueError( + f"Can not instantiate a class from type {type}, because {class_name} is currently not registered in Artifact._class_register." + ) + return Artifact._class_register[class_name] + + module = importlib.import_module(module_path) + + if "." not in class_name: + if hasattr(module, class_name) and inspect.isclass(getattr(module, class_name)): + return getattr(module, class_name) + if class_name in Artifact._class_register: + return Artifact._class_register[class_name] + module_file = module.__file__ if hasattr(module, "__file__") else None + if module_file: + module = import_module_from_file(module_file) + + assert class_name in Artifact._class_register + return Artifact._class_register[class_name] + + class_name_components = class_name.split(".") + klass = getattr(module, class_name_components[0]) + for i in range(1, len(class_name_components)): + klass = getattr(klass, class_name_components[i]) + return klass + + def is_name_legal_for_catalog(name): return re.match(r"^[\w" + constants.catalog_hierarchy_sep + "]+$", name) @@ -133,21 +181,10 @@ def maybe_recover_artifacts_structure(obj): return obj -def get_closest_artifact_type(type): - artifact_type_options = list(Artifact._class_register.keys()) - matches = difflib.get_close_matches(type, artifact_type_options) - if matches: - return matches[0] # Return the closest match - return None - - class UnrecognizedArtifactTypeError(ValueError): def __init__(self, type) -> None: maybe_class = "".join(word.capitalize() for word in type.split("_")) message = f"'{type}' is not a recognized artifact 'type'. Make sure a the class defined this type (Probably called '{maybe_class}' or similar) is defined and/or imported anywhere in the code executed." - closest_artifact_type = get_closest_artifact_type(type) - if closest_artifact_type is not None: - message += f"\n\nDid you mean '{closest_artifact_type}'?" super().__init__(message) @@ -200,8 +237,6 @@ def verify_artifact_dict(cls, d): ) if "__type__" not in d: raise MissingArtifactTypeError(d) - if not cls.is_registered_type(d["__type__"]): - raise UnrecognizedArtifactTypeError(d["__type__"]) @classmethod def get_artifact_type(cls): @@ -218,13 +253,6 @@ def register_class(cls, artifact_class): snake_case_key = camel_to_snake_case(artifact_class.__name__) - if cls.is_registered_type(snake_case_key): - assert ( - str(cls._class_register[snake_case_key]) == str(artifact_class) - ), f"Artifact class name must be unique, '{snake_case_key}' already exists for {cls._class_register[snake_case_key]}. Cannot be overridden by {artifact_class}." - - return snake_case_key - cls._class_register[snake_case_key] = artifact_class return snake_case_key @@ -241,19 +269,6 @@ def is_artifact_file(cls, path): d = json.load(f) return cls.is_artifact_dict(d) - @classmethod - def is_registered_type(cls, type: str): - return type in cls._class_register - - @classmethod - def is_registered_class_name(cls, class_name: str): - snake_case_key = camel_to_snake_case(class_name) - return cls.is_registered_type(snake_case_key) - - @classmethod - def is_registered_class(cls, clz: object): - return clz in set(cls._class_register.values()) - @classmethod def _recursive_load(cls, obj): if isinstance(obj, dict): @@ -267,7 +282,7 @@ def _recursive_load(cls, obj): pass if cls.is_artifact_dict(obj): cls.verify_artifact_dict(obj) - artifact_class = cls._class_register[obj.pop("__type__")] + artifact_class = get_class_from_artifact_type(obj.pop("__type__")) obj = artifact_class.process_data_after_load(obj) return artifact_class(**obj) @@ -684,3 +699,29 @@ def get_artifacts_data_classification(artifact: str) -> Optional[List[str]]: return None return data_classification.get(artifact) + + +def find_unitxt_module_and_class_by_classname(camel_case_class_name: str): + """Find a module, a member of src/unitxt, that contains the definition of the class.""" + dir = os.path.dirname(__file__) # dir src/unitxt + try: + result = subprocess.run( + ["grep", "-irwE", "^class +" + camel_case_class_name, dir], + capture_output=True, + ).stdout.decode("ascii") + results = result.split("\n") + assert len(results) == 2, f"returned: {results}" + assert results[-1] == "", f"last result is {results[-1]} rather than ''" + to_return_module = ( + results[0].split(":")[0][:-3].replace("/", ".") + ) # trim the .py and replace + to_return_class_name = results[0].split(":")[1][ + 6 : 6 + len(camel_case_class_name) + ] + return to_return_module[ + to_return_module.rfind("unitxt.") : + ], to_return_class_name + except Exception as e: + raise ValueError( + f"Could not find the unitxt module, under unitxt/src/unitxt, in which class {camel_case_class_name} is defined" + ) from e diff --git a/src/unitxt/register.py b/src/unitxt/register.py index 51927276f3..35956aec89 100644 --- a/src/unitxt/register.py +++ b/src/unitxt/register.py @@ -1,9 +1,7 @@ -import importlib -import inspect import os from pathlib import Path -from .artifact import Artifact, Catalogs +from .artifact import Catalogs from .catalog import EnvironmentLocalCatalog, GithubCatalog, LocalCatalog from .error_utils import Documentation, UnitxtError, UnitxtWarning from .settings_utils import get_constants, get_settings @@ -89,28 +87,6 @@ def _reset_env_local_catalogs(): _register_catalog(EnvironmentLocalCatalog(location=path)) -def _register_all_artifacts(): - dir = os.path.dirname(__file__) - file_name = os.path.basename(__file__) - - for file in os.listdir(dir): - if ( - file.endswith(".py") - and file not in constants.non_registered_files - and file != file_name - ): - module_name = file.replace(".py", "") - - module = importlib.import_module("." + module_name, __package__) - - for _name, obj in inspect.getmembers(module): - # Make sure the object is a class - if inspect.isclass(obj): - # Make sure the class is a subclass of Artifact (but not Artifact itself) - if issubclass(obj, Artifact) and obj is not Artifact: - Artifact.register_class(obj) - - class ProjectArtifactRegisterer(metaclass=Singleton): def __init__(self): if not hasattr(self, "_registered"): @@ -118,7 +94,6 @@ def __init__(self): if not self._registered: _register_all_catalogs() - _register_all_artifacts() self._registered = True diff --git a/src/unitxt/text_utils.py b/src/unitxt/text_utils.py index c54d3fbd72..8440b0486d 100644 --- a/src/unitxt/text_utils.py +++ b/src/unitxt/text_utils.py @@ -71,6 +71,19 @@ def camel_to_snake_case(s): return s.lower() +def snake_to_camel_case(s): + """Converts a snake_case string s to CamelCase. Assume a class name is in question so result to start with an upper case. + + Not always the reciprocal of the above camel_to_snake_case. e.g: camel_to_snake_case(LoadHF) = load_hf, + whereas snake_to_camel_case(load_hf) = LoadHf + """ + s = s.strip() + words = s.split("_") + # Capitalize all words and join them + camel_case_parts = [word.capitalize() for word in words] + return "".join(camel_case_parts) + + def to_pretty_string( value, indent=0, diff --git a/tests/library/test_artifact_recovery.py b/tests/library/test_artifact_recovery.py index c074ad486b..da65205120 100644 --- a/tests/library/test_artifact_recovery.py +++ b/tests/library/test_artifact_recovery.py @@ -1,7 +1,12 @@ +import subprocess +import sys +import tempfile +import textwrap +from pathlib import Path + from unitxt.artifact import ( Artifact, MissingArtifactTypeError, - UnrecognizedArtifactTypeError, ) from unitxt.logging_utils import get_logger @@ -11,6 +16,80 @@ class TestArtifactRecovery(UnitxtTestCase): + def test_custom_catalog_and_project(self): + with tempfile.TemporaryDirectory() as tmpdirname: + project_dir = Path(tmpdirname) + operator_dir = project_dir / "operators" + catalog_dir = project_dir / "catalog" + operator_dir.mkdir() + + # Write the operator class + operator_code = textwrap.dedent( + """ + from unitxt.operators import InstanceOperator + + class MyTempOperator(InstanceOperator): + def process(self, instance, stream_name=None): + return instance + """ + ) + (operator_dir / "my_operator.py").write_text(operator_code) + (operator_dir / "__init__.py").write_text("") + + # Write the saving script + saving_code = textwrap.dedent( + f""" + from operators.my_operator import MyTempOperator + from unitxt import add_to_catalog, settings + + add_to_catalog(MyTempOperator(), "operators.my_temp_operator", catalog_path="{catalog_dir}") + """ + ) + saving_script = project_dir / "save_operator.py" + saving_script.write_text(saving_code) + + # Write the loading script + loading_code = textwrap.dedent( + """ + from unitxt import get_from_catalog + from operators.my_operator import MyTempOperator + + get_from_catalog("operators.my_temp_operator") + """ + ) + loading_script = project_dir / "load_operator.py" + loading_script.write_text(loading_code) + + # Run the saving script + result_save = subprocess.run( + [sys.executable, str(saving_script)], + env={ + "UNITXT_CATALOGS": str(catalog_dir), + "PYTHONPATH": str(project_dir), + }, + capture_output=True, + text=True, + ) + if result_save.returncode != 0: + logger.info(f"Saving script STDOUT:\n{result_save.stdout}") + logger.info(f"Saving script STDERR:\n{result_save.stderr}") + self.assertEqual(result_save.returncode, 0, "Saving script failed") + + # Run the loading script + result_load = subprocess.run( + [sys.executable, str(loading_script)], + env={ + "UNITXT_CATALOGS": str(catalog_dir), + "PYTHONPATH": str(project_dir), + }, + capture_output=True, + text=True, + ) + if result_load.returncode != 0: + logger.info(f"Loading script STDOUT:\n{result_load.stdout}") + logger.info(f"Loading script STDERR:\n{result_load.stderr}") + self.assertEqual(result_load.returncode, 0, "Loading script failed") + def test_correct_artifact_recovery(self): args = { "__type__": "dataset_recipe", @@ -63,12 +142,12 @@ def test_subclass_registration_and_loading(self): args = { "__type__": "dummy_not_exist", } - with self.assertRaises(UnrecognizedArtifactTypeError): + with self.assertRaises(ValueError): Artifact.from_dict(args) try: Artifact.from_dict(args) - except UnrecognizedArtifactTypeError as e: + except ValueError as e: logger.info("The error message (not a real error):", e) class DummyExistForLoading(Artifact): diff --git a/tests/library/test_artifact_registration.py b/tests/library/test_artifact_registration.py index e552e45613..1c63efa4ef 100644 --- a/tests/library/test_artifact_registration.py +++ b/tests/library/test_artifact_registration.py @@ -8,5 +8,5 @@ def test_subclass_registration(self): class DummyShouldBeRegistered(Artifact): pass - assert Artifact.is_registered_type("dummy_should_be_registered") - assert Artifact.is_registered_class(DummyShouldBeRegistered) + # assert Artifact.is_registered_type("dummy_should_be_registered") + # assert Artifact.is_registered_class(DummyShouldBeRegistered) From 80fba2446b11ef14e14f3ca26ed8923d58b4cdea Mon Sep 17 00:00:00 2001 From: dafnapension Date: Sat, 20 Sep 2025 23:40:57 +0300 Subject: [PATCH 2/2] use _class_register as a cache for instantiated classes Signed-off-by: dafnapension --- src/unitxt/artifact.py | 46 ++++++++------------- tests/library/test_artifact_registration.py | 2 +- 2 files changed, 19 insertions(+), 29 deletions(-) diff --git a/src/unitxt/artifact.py b/src/unitxt/artifact.py index f77bb1e669..ac09741433 100644 --- a/src/unitxt/artifact.py +++ b/src/unitxt/artifact.py @@ -48,40 +48,30 @@ def import_module_from_file(file_path): return module -# type is read from a catelog entry, the value of a key "__type__" -def get_class_from_artifact_type(type: str): - if type in Artifact._class_register: - return Artifact._class_register[type] +# snake_case_class_name is read from a catelog entry, the value of a key "__type__" +# this method replaces the Artifact._class_register lookup, for all unitxt classes defined +# top level in any of the src/unitxt/*.py modules, which are all the classes that were registered +# by register_all_artifacts +def get_class_from_artifact_type(snake_case_class_name: str): + if snake_case_class_name in Artifact._class_register: + return Artifact._class_register[snake_case_class_name] module_path, class_name = find_unitxt_module_and_class_by_classname( - snake_to_camel_case(type) + snake_to_camel_case(snake_case_class_name) ) - if module_path == "class_register": - if class_name not in Artifact._class_register: - raise ValueError( - f"Can not instantiate a class from type {type}, because {class_name} is currently not registered in Artifact._class_register." - ) - return Artifact._class_register[class_name] module = importlib.import_module(module_path) - if "." not in class_name: - if hasattr(module, class_name) and inspect.isclass(getattr(module, class_name)): - return getattr(module, class_name) - if class_name in Artifact._class_register: - return Artifact._class_register[class_name] - module_file = module.__file__ if hasattr(module, "__file__") else None - if module_file: - module = import_module_from_file(module_file) - - assert class_name in Artifact._class_register - return Artifact._class_register[class_name] - - class_name_components = class_name.split(".") - klass = getattr(module, class_name_components[0]) - for i in range(1, len(class_name_components)): - klass = getattr(klass, class_name_components[i]) - return klass + if hasattr(module, class_name) and inspect.isclass(getattr(module, class_name)): + klass = getattr(module, class_name) + Artifact._class_register[ + snake_case_class_name + ] = klass # use _class_register as a cache + return klass + + raise ValueError( + f"Could not find the definition of class whose name, snake-cased is {snake_case_class_name}" + ) def is_name_legal_for_catalog(name): diff --git a/tests/library/test_artifact_registration.py b/tests/library/test_artifact_registration.py index 1c63efa4ef..72a74d6595 100644 --- a/tests/library/test_artifact_registration.py +++ b/tests/library/test_artifact_registration.py @@ -9,4 +9,4 @@ class DummyShouldBeRegistered(Artifact): pass # assert Artifact.is_registered_type("dummy_should_be_registered") - # assert Artifact.is_registered_class(DummyShouldBeRegistered) + assert "dummy_should_be_registered" in Artifact._class_register