diff --git a/refactor/context.py b/refactor/context.py index 82de645..f1f1108 100644 --- a/refactor/context.py +++ b/refactor/context.py @@ -7,7 +7,7 @@ from enum import Enum, auto from functools import cached_property from pathlib import Path -from typing import Any, ClassVar, DefaultDict, Protocol, cast +from typing import Any, ClassVar, DefaultDict, Protocol, cast, Type, List, Generator, Deque, Set import refactor.common as common from refactor.ast import UNPARSER_BACKENDS, BaseUnparser @@ -32,18 +32,30 @@ def __init__(self, context: Context) -> None: ... +def _deque_expand( + iterable: Iterable[type[_Dependable] | Iterable[type[_Dependable]]], +) -> Generator[type[_Dependable]]: + q: Deque[type[_Dependable] | Iterable[type[_Dependable]]] = deque(iterable) + while q: + item: type[_Dependable] | Iterable[type[_Dependable]] = q.popleft() + if hasattr(item, '__iter__') or isinstance(item, Iterable): + q.extendleft(iter(item)) + else: + yield item + + def _resolve_dependencies( - dependables: Iterable[type[_Dependable]], -) -> set[type[Representative]]: - dependencies: set[type[Representative]] = set() + dependables: Iterable[type[_Dependable] | Iterable[type[_Dependable]]], +) -> Set[Type[Representative]]: + dependencies: Set[type[Representative]] = set() - pool = deque(dependables) + pool = deque(_deque_expand(dependables)) while pool: - dependable = pool.pop() + dependable: type[_Dependable] = pool.pop() pool.extendleft( - dependency - for dependency in dependable.context_providers - if dependency not in dependencies + (cast(type[_Dependable], dependency) + for dependency in dependable.context_providers + if dependency not in dependencies) ) if issubclass(dependable, Representative): diff --git a/refactor/core.py b/refactor/core.py index 037ac94..f76dda7 100644 --- a/refactor/core.py +++ b/refactor/core.py @@ -5,9 +5,9 @@ import tokenize from collections.abc import Iterator from contextlib import suppress -from dataclasses import dataclass, field +from 
def _unparsable_source_code(source: str, exc: SyntaxError) -> NoReturn:
    """Raise ``ValueError`` for generated source that no longer parses.

    In debug mode the offending source is dumped to a temporary file so the
    failure can be inspected; the file name is appended to the message.

    NOTE(review): this reads ``Session.c_current_config``, a ClassVar that is
    only assigned inside ``Session._initialize_rules`` — calling this helper
    before any session has run raises ``AttributeError``.  Confirm no code
    path can reach it earlier.

    :raises ValueError: always, chained from ``exc``.
    """
    error_message = "Generated source is unparsable."

    if Session.c_current_config.debug_mode:
        fd, file_name = tempfile.mkstemp(prefix="refactor", text=True)
        with open(fd, "w") as stream:
            stream.write(source)
        error_message += f"\nSee {file_name} for the generated source."

    raise ValueError(error_message) from exc


def _match_from_rule_or_collection(
    r_or_c: "Rule | RuleCollection",
    node: ast.AST,
) -> "Iterator[tuple[Rule, BaseAction | None | Iterator[BaseAction]]]":
    """Yield ``(rule, match_result)`` pairs for a single rule or a collection.

    A failed ``match`` signals via ``AssertionError``; that is normal control
    flow here, so it is suppressed rather than propagated.

    Fixes: the previous ``Generator[...]`` annotation was subscripted with a
    single argument, which ``typing.Generator`` rejects before Python 3.13,
    and it omitted the ``None`` a plain ``Rule.match`` may return.
    """
    with suppress(AssertionError):
        if isinstance(r_or_c, RuleCollection):
            # Collections yield their own (rule, action) pairs.
            yield from r_or_c.match(node)
        else:
            yield r_or_c, r_or_c.match(node)
class _IsIterable(type):
    """Metaclass that makes a ``RuleCollection`` *class* iterable over its
    declared ``rules``, so a collection class can take part in dependency
    resolution alongside plain rules."""

    context_providers: ClassVar[Tuple[Type[Representative], ...]] = ()

    def __iter__(self) -> "Iterator[RuleCollection]":
        return iter(self.rules)


@dataclass
class RuleCollection(metaclass=_IsIterable):
    """Groups ``Rule`` and nested ``RuleCollection`` types into one unit.

    The idea is to allow cleaner complex chained rules (which may raise
    ``MaybeOverlappingActions`` when too large) while letting ``Session``
    keep a short list of rules and collections.
    """

    # Maps from the declared types to their live instances; populated by
    # _validate_collection (collections) and initialize_rules (rules).
    rule_instances: Dict[Type[Rule], Rule] = field(default_factory=dict)
    collection_instances: Dict["Type[RuleCollection]", "RuleCollection"] = field(default_factory=dict)

    _validated: bool = field(default=False, repr=False)
    _initialized: bool = field(default=False, repr=False)

    def _validate_collection(self) -> bool:
        """Check that ``rules`` exists and is well formed.

        Always returns True on a valid collection, otherwise raises.

        Fixes over the previous revision:
          * the old code deleted entries from ``self.rules`` *while iterating
            it*, which skips elements; we now build a filtered list instead;
          * de-duplication used ``list(set(...))``, which scrambles the
            declaration order that ``match`` explicitly relies on; we now use
            ``dict.fromkeys`` to dedupe order-preservingly;
          * the redundant ``all(...)`` pre-pass before the per-item check is
            gone.

        :raises AttributeError: if the ``rules`` attribute is not defined
        :raises TypeError: if ``rules`` is not a list, or contains something
            other than Rule / RuleCollection subclasses
        :return: True if the collection is valid
        """
        if not isinstance(self.rules, list):
            raise TypeError("RuleCollection.rules must be a list")
        for rule in self.rules:
            if not issubclass(rule, (Rule, RuleCollection)):
                raise TypeError(f"RuleCollection.rules must contain only Rules or RuleCollections, not {rule}")

        # Remove duplicates while preserving declaration order.
        setattr(self, "rules", list(dict.fromkeys(self.rules)))

        # Nested collections are instantiated and validated eagerly; this is
        # different from rules, which need a Context and are instantiated in
        # initialize_rules.  Invalid nested collections raise out of their own
        # _validate_collection call.
        kept: list = []
        for member in self.rules:
            if issubclass(member, RuleCollection):
                collection = member()
                if collection._validate_collection():
                    self.collection_instances[member] = collection
                    kept.append(member)
            else:
                kept.append(member)
        setattr(self, "rules", kept)

        self._validated = True
        return self._validated

    def initialize_rules(self, context: Context, path: Path | None = None) -> None:
        """Instantiate all rules in the collection.  Called by the Session."""
        if not self._validated:
            # Validation returns True or raises.
            self._validate_collection()

        for rule in self.rules:
            # Plain rules are instantiated here and file-filtered immediately.
            if issubclass(rule, Rule) and (instance := rule(context)).check_file(path):
                self.rule_instances[rule] = instance

            # Nested collections recurse into their own initialization.
            if issubclass(rule, RuleCollection):
                self.collection_instances[rule].initialize_rules(context, path)

        self._initialized = True

    def check_file(self, path: Path | None) -> bool:
        """Top-level gate used by Session; True once initialized."""
        return self._initialized

    def match(self, node: ast.AST) -> "Iterator[Tuple[Rule, BaseAction | None | Iterator[BaseAction]]]":
        """Match ``node`` against every rule in the collection, in order.

        Yields ``(rule, match_result)`` tuples.  Fix: annotated as
        ``Iterator`` (the old single-argument ``Generator[...]`` subscription
        is rejected before Python 3.13); the dead trailing ``return`` is gone.

        :raises RuntimeError: if the collection is not initialized
        """
        if not self._initialized:
            raise RuntimeError("RuleCollection must be initialized before matching")

        # Declaration order of ``rules`` is preserved deliberately.
        for rule_type in self.rules:
            # Only rules that passed check_file are present here.
            if rule_type in self.rule_instances:
                # AssertionError is just a failed match, not an error.
                with suppress(AssertionError):
                    rule = self.rule_instances[rule_type]
                    yield rule, rule.match(node)
            if rule_type in self.collection_instances:
                # Nested collections recurse through their own match.
                yield from self.collection_instances[rule_type].match(node)
+ + :param rule: The rule that matched + :param action: The action or action iterator that was returned by the rule + :param source_code: source to which apply the actions + :return: A updated source + """ + + rule: Rule + action: BaseAction | Iterator[BaseAction] + source_code: str + enable_optimizations: bool = field(default=True) + + def __post_init__(self): + if not isinstance(self.source_code, str) or not self.source_code: + raise TypeError("source_code must be a non-empty string") + if not self.action: + raise TypeError("action cannot be None") + + def source(self) -> str: + """Compute the path of the current node (against the starting tree). + + Adjust this path with the knowledge from the previously applied + actions. + + Use the path to find the correct node in the new tree.""" from refactor.internal.graph_access import AccessFailure, GraphPath - shifts: list[tuple[GraphPath, int]] = [] - previous_tree = rule.context.tree - for action in actions: + shifts: List[Tuple[GraphPath, int]] = [] + updated_source = self.source_code + previous_tree = self.rule.context.tree + for action in self.action: input_node, stack_effect = action._stack_effect() # We compute each path against the initial revision of the tree # since the rule who is producing them doesn't have access to the # temporary trees we generate on the fly. - path = GraphPath.backtrack_from(rule.context, input_node) + path = GraphPath.backtrack_from(self.rule.context, input_node) # And due to this, some actions might have altered the tree in a # way that makes the path as is invalid. For ensuring that the path @@ -136,44 +238,124 @@ def _apply_multiple( raise MaybeOverlappingActions( "When using chained actions, individual actions should not" " overlap with each other." 
@dataclass(frozen=True)
class _SourceFromAction(_SourceFromIterator):
    """Applies a single action to a source string; iterator actions are
    delegated to the chained-application logic in the parent class."""

    # ``None`` is only a construction-time sentinel; __post_init__ rejects it.
    context: Context = field(default=None)

    def __post_init__(self):
        # Fix: run the parent's source_code/action validation as well — the
        # previous override silently skipped it.
        super().__post_init__()
        if not isinstance(self.context, Context):
            raise TypeError("context must be a Context instance")

    def source(self) -> str:
        """Apply a single action to the source."""
        if isinstance(action := self.action, Iterator):
            # Chained actions are handled by the parent implementation.
            return super().source()

        if self.enable_optimizations:
            action = optimize(self.action, self.context)
        return action.apply(self.context, self.source_code)
@dataclass(frozen=True)
class _SourceFromRuleOrCollection:
    """Runs one rule or collection against a node, yielding each updated
    source string it produces."""

    rule_or_collection: "Rule | RuleCollection"

    _indent: str = field(default="")

    def source(self, node, source) -> "Iterator[str]":
        """Yield successive source updates for ``node``.

        Mixes the BaseAction and Iterator[BaseAction] update paths behind a
        single interface.  Fix: this method contains ``yield`` and is
        therefore a generator — the previous ``-> str`` annotation was wrong;
        callers iterate the result.
        """
        new_source = source
        for rule, action in _match_from_rule_or_collection(self.rule_or_collection, node):
            if action is None:
                # The rule matched but produced nothing; try the next one.
                continue
            builder = _SourceFromAction(rule, action, new_source, context=rule.context)

            # AssertionError mid-application signals a failed match; treat it
            # as "no change" rather than an error.
            with suppress(AssertionError):
                new_source = builder.source()

            # Yield only when we still have a non-empty source to offer.
            if new_source:
                yield new_source
            else:
                return


@dataclass
class Session:
    """A refactoring session that consists of a set of rules and a configuration."""

    # NOTE(review): class-level state assigned on every run (read by
    # _unparsable_source_code); this makes concurrently running sessions
    # unsafe — confirm single-threaded use.
    c_current_config: ClassVar[Configuration]

    rules: list[type[Rule] | RuleCollection] = field(default_factory=list)
    config: Configuration = field(default_factory=Configuration)

    def _initialize_rules(
        self,
        tree: ast.Module,
        source: str,
        file_info: _FileInfo,
    ) -> list[Rule]:
        """Instantiate every rule/collection that accepts ``file_info``'s path."""
        Session.c_current_config = self.config
        context = Context._from_dependencies(
            _resolve_dependencies(self.rules),  # type: ignore
            tree=tree,
            source=source,
            file_info=file_info,
            config=self.config,
        )
        instances: list[Rule | RuleCollection] = []
        for rule_or_collection in self.rules:
            if issubclass(rule_or_collection, RuleCollection):
                # Collections initialize (and file-filter) their own rules but
                # stay grouped so intermediate tree updates remain scoped.
                collection = rule_or_collection()
                collection.initialize_rules(context, file_info.path)
                # Fix: truthiness already covers the old redundant
                # ``len(...) > 0`` check.
                if collection.rule_instances:
                    instances.append(collection)
            else:
                instances.append(rule_or_collection(context))
        return [i for i in instances if i.check_file(file_info.path)]
self._run( new_source, @@ -206,17 +376,6 @@ def _run( return source, _changed - def _unparsable_source_code(self, source: str, exc: SyntaxError) -> NoReturn: - error_message = "Generated source is unparsable." - - if self.config.debug_mode: - fd, file_name = tempfile.mkstemp(prefix="refactor", text=True) - with open(fd, "w") as stream: - stream.write(source) - error_message += f"\nSee {file_name} for the generated source." - - raise ValueError(error_message) from exc - def run(self, source: str) -> str: """Apply all the rules from this session to the given ``source`` and return the transformed version. diff --git a/tests/test_context.py b/tests/test_context.py index 40ad602..f9f3c8d 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -14,7 +14,7 @@ Scope, _resolve_dependencies, ) -from refactor.core import Rule +from refactor.core import Rule, RuleCollection def get_context(source, *representatives, **kwargs): @@ -309,6 +309,69 @@ class Rule6(Rule): } +def test_dependency_resolver_with_collection(): + class Rep1(Representative): + pass + + class Rep2(Representative): + context_providers = (Rep1,) + + class Rep3(Representative): + context_providers = (Rep2,) + + class Rule1(Rule): + pass + + class Rule2(Rule): + context_providers = (Rep1,) + + class Rule3(Rule): + context_providers = (Rep1,) + + class Rule4(Rule): + context_providers = (Rep1, Rep2) + + class Rule5(Rule): + context_providers = (Rep2,) + + class Rule6(Rule): + context_providers = (Rep3,) + + class Collection2(RuleCollection): + rules = [Rule1, Rule2] + + class Collection4(RuleCollection): + rules = [Rule1, Rule2, Rule4] + + class Collection6(RuleCollection): + rules = [Rule6] + + assert _resolve_dependencies([Rule1]) == set() + assert _resolve_dependencies([Rule2]) == {Rep1} + assert _resolve_dependencies([Rule3]) == {Rep1} + assert _resolve_dependencies([Rule4]) == {Rep1, Rep2} + assert _resolve_dependencies([Rule5]) == {Rep1, Rep2} + assert _resolve_dependencies([Rule6]) == {Rep1, 
class MultToMinusRule(Rule):
    """Rewrites ``a * b`` into ``a - b`` via SimpleAction."""

    def match(self, node):
        assert isinstance(node, ast.BinOp)
        assert isinstance(node.op, ast.Mult)
        return SimpleAction(node)


class DivToMinusRule(Rule):
    """Rewrites ``a / b`` into ``a - b`` via SimpleAction."""

    def match(self, node):
        assert isinstance(node, ast.BinOp)
        assert isinstance(node.op, ast.Div)
        return SimpleAction(node)


class ModuloToMinusRule(Rule):
    """Rewrites ``a % b`` into ``a - b`` via SimpleAction."""

    def match(self, node):
        assert isinstance(node, ast.BinOp)
        assert isinstance(node.op, ast.Mod)
        return SimpleAction(node)


class PowToMinusRule(Rule):
    """Rewrites ``a ** b`` into ``a - b`` via SimpleAction."""

    def match(self, node):
        assert isinstance(node, ast.BinOp)
        assert isinstance(node.op, ast.Pow)
        return SimpleAction(node)


class CollectPlusToMinusMultToMinusRule(RuleCollection):
    rules = [PlusToMinusRule, MultToMinusRule]


class CollectMultToMinusPlusToMinusRule(RuleCollection):
    rules = [MultToMinusRule, PlusToMinusRule]


@pytest.mark.parametrize(
    "source, expected, rules",
    [
        ("1+1*2", "1 - 1 - 2", CollectPlusToMinusMultToMinusRule),
        ("1+1*2", "1 - 1 - 2", CollectMultToMinusPlusToMinusRule),
        ("1*1+2", "1 - 1 - 2", CollectPlusToMinusMultToMinusRule),
        ("1*1+2", "1 - 1 - 2", CollectMultToMinusPlusToMinusRule),
    ],
)
def test_session_collection(source, expected, rules):
    # Fix: the signature now lists the parameters in the same order as the
    # parametrize id string; pytest binds by name either way, but the old
    # (source, rules, expected) order read as if the tuples were mismatched.
    if isinstance(rules, type):
        rules = [rules]

    source = textwrap.dedent(source)
    expected = textwrap.dedent(expected)

    session = Session(rules)
    assert session.run(source) == expected


class CollectCollectRule(RuleCollection):
    """A collection that nests another collection between standalone rules."""

    rules = [DivToMinusRule, CollectPlusToMinusMultToMinusRule, ModuloToMinusRule]


@pytest.mark.parametrize(
    "source, expected, rules",
    [
        ("1+1*2/3%4", "1 - 1 - 2 - 3 - 4", CollectCollectRule),
    ],
)
def test_session_multicollection(source, expected, rules):
    if isinstance(rules, type):
        rules = [rules]

    source = textwrap.dedent(source)
    expected = textwrap.dedent(expected)

    session = Session(rules)
    assert session.run(source) == expected