Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions docs/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pygments import highlight
from pygments.formatters import HtmlFormatter
from pygments.lexers import PythonLexer
from unitxt.artifact import Artifact
from unitxt.artifact import get_class_from_artifact_type
from unitxt.text_utils import print_dict_as_python
from unitxt.utils import load_json

Expand Down Expand Up @@ -51,7 +51,7 @@ def imports_to_syntax_highlighted_html(subtypes: List[str]) -> str:
return ""
module_to_class_names = defaultdict(list)
for subtype in subtypes:
subtype_class = Artifact._class_register.get(subtype)
subtype_class = get_class_from_artifact_type(subtype)
module_to_class_names[subtype_class.__module__].append(subtype_class.__name__)

imports_txt = ""
Expand Down Expand Up @@ -150,7 +150,7 @@ def recursive_search(d):

@lru_cache(maxsize=None)
def artifact_type_to_link(artifact_type):
artifact_class = Artifact._class_register.get(artifact_type)
artifact_class = get_class_from_artifact_type(artifact_type)
type_class_name = artifact_class.__name__
artifact_class_id = f"{artifact_class.__module__}.{type_class_name}"
return f'<a class="reference internal" href="../{artifact_class.__module__}.html#{artifact_class_id}" title="{artifact_class_id}"><code class="xref py py-class docutils literal notranslate"><span class="pre">{type_class_name}</span></code></a>'
Expand All @@ -159,7 +159,7 @@ def artifact_type_to_link(artifact_type):
# flake8: noqa: C901
def make_content(artifact, label, all_labels):
artifact_type = artifact["__type__"]
artifact_class = Artifact._class_register.get(artifact_type)
artifact_class = get_class_from_artifact_type(artifact_type)
type_class_name = artifact_class.__name__
catalog_id = label.replace("catalog.", "")

Expand Down Expand Up @@ -243,7 +243,7 @@ def make_content(artifact, label, all_labels):
result += artifact_class.__doc__ + "\n"

for subtype in subtypes:
subtype_class = Artifact._class_register.get(subtype)
subtype_class = get_class_from_artifact_type(subtype)
subtype_class_name = subtype_class.__name__
if subtype_class.__doc__:
explanation_str = f"Explanation about `{subtype_class_name}`"
Expand Down
4 changes: 2 additions & 2 deletions src/unitxt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
)
from .catalog import add_to_catalog, get_from_catalog
from .logging_utils import get_logger
from .register import register_all_artifacts, register_local_catalog
from .register import ProjectArtifactRegisterer, register_local_catalog
from .settings_utils import get_constants, get_settings

register_all_artifacts()
ProjectArtifactRegisterer()
random.seed(0)

constants = get_constants()
Expand Down
103 changes: 67 additions & 36 deletions src/unitxt/artifact.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import difflib
import importlib
import inspect
import json
import os
import pkgutil
import re
import subprocess
import warnings
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union, final
Expand All @@ -22,7 +23,7 @@
separate_inside_and_outside_square_brackets,
)
from .settings_utils import get_constants, get_settings
from .text_utils import camel_to_snake_case, is_camel_case
from .text_utils import camel_to_snake_case, is_camel_case, snake_to_camel_case
from .type_utils import isoftype, issubtype
from .utils import (
artifacts_json_cache,
Expand All @@ -36,6 +37,43 @@
constants = get_constants()


def import_module_from_file(file_path):
    """Load the Python source file at ``file_path`` and return it as a module.

    The module is named after the file (its base name without the extension).
    The module is executed, but this function does not add it to
    ``sys.modules``.

    Args:
        file_path: path of the Python file to load.

    Returns:
        The executed module object.

    Raises:
        ImportError: if no import spec (or no loader) can be built for
            ``file_path``, e.g. because its suffix is not a recognized
            Python source/extension suffix.
    """
    # A bare `import importlib` does not guarantee that the `importlib.util`
    # submodule has been loaded, so import it explicitly here.
    import importlib.util

    # Module name: file name without directory and without extension.
    module_name = os.path.splitext(os.path.basename(file_path))[0]
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    # spec_from_file_location returns None when no loader matches the file;
    # fail with a clear ImportError instead of an AttributeError on None.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for file '{file_path}'")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# snake_case_class_name is read from a catalog entry: it is the value of the entry's "__type__" key.
# This function replaces the direct Artifact._class_register lookup for all unitxt classes defined
# at the top level of any of the src/unitxt/*.py modules, i.e. all the classes that used to be
# registered by register_all_artifacts.
def get_class_from_artifact_type(snake_case_class_name: str):
    """Resolve the class named by an artifact's snake_case ``__type__`` value.

    A name already present in ``Artifact._class_register`` is answered from
    there. Otherwise the name is converted to CamelCase, the unitxt module
    defining that class is located and imported, and the class found in it is
    stored in ``Artifact._class_register`` (which doubles as a cache) before
    being returned.

    Raises:
        ValueError: if the imported module has no class attribute with the
            expected name.
    """
    if snake_case_class_name in Artifact._class_register:
        return Artifact._class_register[snake_case_class_name]

    camel_case_name = snake_to_camel_case(snake_case_class_name)
    module_path, class_name = find_unitxt_module_and_class_by_classname(
        camel_case_name
    )
    defining_module = importlib.import_module(module_path)

    # A missing attribute yields None, which is not a class either.
    candidate = getattr(defining_module, class_name, None)
    if not inspect.isclass(candidate):
        raise ValueError(
            f"Could not find the definition of class whose name, snake-cased is {snake_case_class_name}"
        )

    # use _class_register as a cache
    Artifact._class_register[snake_case_class_name] = candidate
    return candidate


def is_name_legal_for_catalog(name):
return re.match(r"^[\w" + constants.catalog_hierarchy_sep + "]+$", name)

Expand Down Expand Up @@ -133,21 +171,10 @@ def maybe_recover_artifacts_structure(obj):
return obj


def get_closest_artifact_type(type):
artifact_type_options = list(Artifact._class_register.keys())
matches = difflib.get_close_matches(type, artifact_type_options)
if matches:
return matches[0] # Return the closest match
return None


class UnrecognizedArtifactTypeError(ValueError):
def __init__(self, type) -> None:
maybe_class = "".join(word.capitalize() for word in type.split("_"))
message = f"'{type}' is not a recognized artifact 'type'. Make sure a the class defined this type (Probably called '{maybe_class}' or similar) is defined and/or imported anywhere in the code executed."
closest_artifact_type = get_closest_artifact_type(type)
if closest_artifact_type is not None:
message += f"\n\nDid you mean '{closest_artifact_type}'?"
super().__init__(message)


Expand Down Expand Up @@ -200,8 +227,6 @@ def verify_artifact_dict(cls, d):
)
if "__type__" not in d:
raise MissingArtifactTypeError(d)
if not cls.is_registered_type(d["__type__"]):
raise UnrecognizedArtifactTypeError(d["__type__"])

@classmethod
def get_artifact_type(cls):
Expand All @@ -218,13 +243,6 @@ def register_class(cls, artifact_class):

snake_case_key = camel_to_snake_case(artifact_class.__name__)

if cls.is_registered_type(snake_case_key):
assert (
str(cls._class_register[snake_case_key]) == str(artifact_class)
), f"Artifact class name must be unique, '{snake_case_key}' already exists for {cls._class_register[snake_case_key]}. Cannot be overridden by {artifact_class}."

return snake_case_key

cls._class_register[snake_case_key] = artifact_class

return snake_case_key
Expand All @@ -241,19 +259,6 @@ def is_artifact_file(cls, path):
d = json.load(f)
return cls.is_artifact_dict(d)

@classmethod
def is_registered_type(cls, type: str):
return type in cls._class_register

@classmethod
def is_registered_class_name(cls, class_name: str):
snake_case_key = camel_to_snake_case(class_name)
return cls.is_registered_type(snake_case_key)

@classmethod
def is_registered_class(cls, clz: object):
return clz in set(cls._class_register.values())

@classmethod
def _recursive_load(cls, obj):
if isinstance(obj, dict):
Expand All @@ -267,7 +272,7 @@ def _recursive_load(cls, obj):
pass
if cls.is_artifact_dict(obj):
cls.verify_artifact_dict(obj)
artifact_class = cls._class_register[obj.pop("__type__")]
artifact_class = get_class_from_artifact_type(obj.pop("__type__"))
obj = artifact_class.process_data_after_load(obj)
return artifact_class(**obj)

Expand Down Expand Up @@ -684,3 +689,29 @@ def get_artifacts_data_classification(artifact: str) -> Optional[List[str]]:
return None

return data_classification.get(artifact)


def find_unitxt_module_and_class_by_classname(
    camel_case_class_name: str, search_dir: Optional[str] = None
):
    """Find the unitxt module that contains the top-level definition of a class.

    Every ``*.py`` file under the package directory is scanned for a line that
    starts with ``class <name>``. The name is compared case-insensitively,
    because the CamelCase name handed in may differ in letter case from the
    name written in the source. The class name is therefore returned exactly
    as it is spelled in the file.

    Args:
        camel_case_class_name: the (approximate-case) name of the class.
        search_dir: directory to scan; defaults to the directory of this
            file (the package directory).

    Returns:
        A tuple ``(module_path, class_name)``, where ``module_path`` is a
        dotted path rooted at ``unitxt`` (e.g. ``"unitxt.loaders"``).

    Raises:
        ValueError: if the name is empty, a file cannot be read, or the class
            definition is not found exactly once.
    """
    not_found_message = f"Could not find the unitxt module, under unitxt/src/unitxt, in which class {camel_case_class_name} is defined"
    if not camel_case_class_name:
        raise ValueError(not_found_message)

    package_dir = os.path.dirname(__file__) if search_dir is None else search_dir
    # The name is escaped so it is matched literally, and \b keeps e.g. "Foo"
    # from matching the definition of "FooBar".
    class_definition = re.compile(
        r"^class +(" + re.escape(camel_case_class_name) + r")\b",
        re.IGNORECASE | re.MULTILINE,
    )

    found = []
    for root, _subdirs, file_names in os.walk(package_dir):
        for file_name in sorted(file_names):
            if not file_name.endswith(".py"):
                continue
            file_path = os.path.join(root, file_name)
            try:
                with open(file_path, encoding="utf-8", errors="replace") as f:
                    source = f.read()
            except OSError as e:
                raise ValueError(not_found_message) from e
            for match in class_definition.finditer(source):
                # path relative to the package, without the ".py" suffix
                relative = os.path.relpath(file_path, package_dir)[:-3]
                module_path = ".".join(["unitxt", *relative.split(os.sep)])
                found.append((module_path, match.group(1)))

    if not found:
        raise ValueError(not_found_message)
    if len(found) > 1:
        # An ambiguous name cannot be resolved safely; report all candidates.
        raise ValueError(f"{not_found_message} (ambiguous, found: {found})")
    return found[0]
27 changes: 1 addition & 26 deletions src/unitxt/register.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import importlib
import inspect
import os
from pathlib import Path

from .artifact import Artifact, Catalogs
from .artifact import Catalogs
from .catalog import EnvironmentLocalCatalog, GithubCatalog, LocalCatalog
from .error_utils import Documentation, UnitxtError, UnitxtWarning
from .settings_utils import get_constants, get_settings
Expand Down Expand Up @@ -89,36 +87,13 @@ def _reset_env_local_catalogs():
_register_catalog(EnvironmentLocalCatalog(location=path))


def _register_all_artifacts():
dir = os.path.dirname(__file__)
file_name = os.path.basename(__file__)

for file in os.listdir(dir):
if (
file.endswith(".py")
and file not in constants.non_registered_files
and file != file_name
):
module_name = file.replace(".py", "")

module = importlib.import_module("." + module_name, __package__)

for _name, obj in inspect.getmembers(module):
# Make sure the object is a class
if inspect.isclass(obj):
# Make sure the class is a subclass of Artifact (but not Artifact itself)
if issubclass(obj, Artifact) and obj is not Artifact:
Artifact.register_class(obj)


class ProjectArtifactRegisterer(metaclass=Singleton):
def __init__(self):
if not hasattr(self, "_registered"):
self._registered = False

if not self._registered:
_register_all_catalogs()
_register_all_artifacts()
self._registered = True


Expand Down
13 changes: 13 additions & 0 deletions src/unitxt/text_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,19 @@ def camel_to_snake_case(s):
return s.lower()


def snake_to_camel_case(s):
    """Return the CamelCase form of the snake_case string s.

    The input is assumed to name a class, so the result starts with an upper
    case letter: surrounding whitespace is stripped, the string is split on
    underscores, and every piece is capitalized before the pieces are joined.

    This is not always the inverse of camel_to_snake_case, e.g.
    camel_to_snake_case(LoadHF) = load_hf, whereas
    snake_to_camel_case(load_hf) = LoadHf.
    """
    return "".join(piece.capitalize() for piece in s.strip().split("_"))


def to_pretty_string(
value,
indent=0,
Expand Down
Loading