@@ -48,40 +48,30 @@ def import_module_from_file(file_path):
4848 return module
4949
5050
51- # type is read from a catelog entry, the value of a key "__type__"
52- def get_class_from_artifact_type (type : str ):
53- if type in Artifact ._class_register :
54- return Artifact ._class_register [type ]
51+ # snake_case_class_name is read from a catelog entry, the value of a key "__type__"
52+ # this method replaces the Artifact._class_register lookup, for all unitxt classes defined
53+ # top level in any of the src/unitxt/*.py modules, which are all the classes that were registered
54+ # by register_all_artifacts
55+ def get_class_from_artifact_type (snake_case_class_name : str ):
56+ if snake_case_class_name in Artifact ._class_register :
57+ return Artifact ._class_register [snake_case_class_name ]
5558
5659 module_path , class_name = find_unitxt_module_and_class_by_classname (
57- snake_to_camel_case (type )
60+ snake_to_camel_case (snake_case_class_name )
5861 )
59- if module_path == "class_register" :
60- if class_name not in Artifact ._class_register :
61- raise ValueError (
62- f"Can not instantiate a class from type { type } , because { class_name } is currently not registered in Artifact._class_register."
63- )
64- return Artifact ._class_register [class_name ]
6562
6663 module = importlib .import_module (module_path )
6764
68- if "." not in class_name :
69- if hasattr (module , class_name ) and inspect .isclass (getattr (module , class_name )):
70- return getattr (module , class_name )
71- if class_name in Artifact ._class_register :
72- return Artifact ._class_register [class_name ]
73- module_file = module .__file__ if hasattr (module , "__file__" ) else None
74- if module_file :
75- module = import_module_from_file (module_file )
76-
77- assert class_name in Artifact ._class_register
78- return Artifact ._class_register [class_name ]
79-
80- class_name_components = class_name .split ("." )
81- klass = getattr (module , class_name_components [0 ])
82- for i in range (1 , len (class_name_components )):
83- klass = getattr (klass , class_name_components [i ])
84- return klass
65+ if hasattr (module , class_name ) and inspect .isclass (getattr (module , class_name )):
66+ klass = getattr (module , class_name )
67+ Artifact ._class_register [
68+ snake_case_class_name
69+ ] = klass # use _class_register as a cache
70+ return klass
71+
72+ raise ValueError (
73+ f"Could not find the definition of class whose name, snake-cased is { snake_case_class_name } "
74+ )
8575
8676
8777def is_name_legal_for_catalog (name ):
0 commit comments