diff --git a/refactor/actions.py b/refactor/actions.py index 173a0e4..95c6189 100644 --- a/refactor/actions.py +++ b/refactor/actions.py @@ -177,6 +177,10 @@ def apply(self, context: Context, source: str) -> str: replacement = split_lines(context.unparse(self.build())) replacement.apply_indentation(indentation, start_prefix=start_prefix) + if hasattr(self, "separator") and self.separator: + # Adding extra separating line + replacement.insert(0, lines._newline_type) + original_node_end = cast(int, self.node.end_lineno) - 1 if lines[original_node_end].endswith(lines._newline_type): replacement[-1] += lines._newline_type @@ -220,6 +224,10 @@ def apply(self, context: Context, source: str) -> str: replacement.apply_indentation(indentation, start_prefix=start_prefix) replacement[-1] += lines._newline_type + if hasattr(self, "separator") and self.separator: + # Adding extra separating line + replacement.append(lines._newline_type) + original_node_start = cast(int, self.node.lineno) for line in reversed(replacement): lines.insert(original_node_start - 1, line) @@ -253,6 +261,7 @@ class InsertAfter(LazyInsertAfter): """ target: ast.stmt + separator: bool = False def build(self) -> ast.stmt: return self.target @@ -268,6 +277,7 @@ class InsertBefore(LazyInsertBefore): """ target: ast.stmt + separator: bool = False def build(self) -> ast.stmt: return self.target diff --git a/tests/test_actions.py b/tests/test_actions.py index a8040be..b9aa631 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -2,6 +2,7 @@ import ast import textwrap +from dataclasses import dataclass from pathlib import Path from typing import Iterator, cast @@ -9,7 +10,7 @@ from refactor.ast import DEFAULT_ENCODING from refactor import Session, common -from refactor.actions import Erase, InvalidActionError, InsertAfter, Replace, InsertBefore +from refactor.actions import Erase, InvalidActionError, InsertAfter, Replace, InsertBefore, LazyInsertAfter, LazyInsertBefore from refactor.context import Context from refactor.core import Rule @@ -48,6 +49,14 @@ def foo(): INVALID_ERASES_TREE = ast.parse(INVALID_ERASES) +@dataclass +class BuildInsertAfterBottom(LazyInsertAfter): + separator: bool + def build(self) -> ast.Await: + await_st = ast.parse("await async_test()") + return await_st + + class TestInsertAfterBottom(Rule): INPUT_SOURCE = """ try: @@ -76,6 +85,90 @@ def match(self, node: ast.AST) -> Iterator[InsertAfter]: yield Replace(node, cast(ast.AST, new_try)) +class TestInsertAfterBottomWithBuild(Rule): + INPUT_SOURCE = """ + try: + base_tree = get_tree(base_file, module_name) + first_tree = get_tree(first_tree, module_name) + second_tree = get_tree(second_tree, module_name) + third_tree = get_tree(third_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue""" + + EXPECTED_SOURCE = """ + try: + base_tree = get_tree(base_file, module_name) + except (SyntaxError, FileNotFoundError): + continue + await async_test()""" + + def match(self, node: ast.AST) -> Iterator[InsertAfter]: + assert isinstance(node, ast.Try) + assert len(node.body) >= 2 + + yield BuildInsertAfterBottom(node, separator=False) + new_try = common.clone(node) + new_try.body = [node.body[0]] + yield Replace(node, cast(ast.AST, new_try)) + + +class TestInsertAfterBottomWithSeparator(Rule): + INPUT_SOURCE = """ + try: + base_tree = get_tree(base_file, module_name) + first_tree = get_tree(first_tree, module_name) + second_tree = get_tree(second_tree, module_name) + third_tree = get_tree(third_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue""" + + EXPECTED_SOURCE = """ + try: + base_tree = get_tree(base_file, module_name) + except (SyntaxError, FileNotFoundError): + continue + + await async_test()""" + + def match(self, node: ast.AST) -> Iterator[InsertAfter]: + assert isinstance(node, ast.Try) + assert len(node.body) >= 2 + + await_st = ast.parse("await async_test()") + yield InsertAfter(node, cast(ast.stmt, await_st), separator=True) + new_try = common.clone(node) + new_try.body = [node.body[0]] + yield Replace(node, cast(ast.AST, new_try)) + + +class TestInsertAfterBottomWithSeparatorWithBuild(Rule): + INPUT_SOURCE = """ + try: + base_tree = get_tree(base_file, module_name) + first_tree = get_tree(first_tree, module_name) + second_tree = get_tree(second_tree, module_name) + third_tree = get_tree(third_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue""" + + EXPECTED_SOURCE = """ + try: + base_tree = get_tree(base_file, module_name) + except (SyntaxError, FileNotFoundError): + continue + + await async_test()""" + + def match(self, node: ast.AST) -> Iterator[InsertAfter]: + assert isinstance(node, ast.Try) + assert len(node.body) >= 2 + + yield BuildInsertAfterBottom(node, separator=True) + new_try = common.clone(node) + new_try.body = [node.body[0]] + yield Replace(node, cast(ast.AST, new_try)) + + class TestInsertBeforeTop(Rule): INPUT_SOURCE = """ try: @@ -104,6 +197,35 @@ def match(self, node: ast.AST) -> Iterator[InsertBefore]: yield Replace(node, cast(ast.AST, new_try)) +class TestInsertBeforeTopWithSeparator(Rule): + INPUT_SOURCE = """ + try: + base_tree = get_tree(base_file, module_name) + first_tree = get_tree(first_tree, module_name) + second_tree = get_tree(second_tree, module_name) + third_tree = get_tree(third_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue""" + + EXPECTED_SOURCE = """ + await async_test() + + try: + base_tree = get_tree(base_file, module_name) + except (SyntaxError, FileNotFoundError): + continue""" + + def match(self, node: ast.AST) -> Iterator[InsertBefore]: + assert isinstance(node, ast.Try) + assert len(node.body) >= 2 + + await_st = ast.parse("await async_test()") + yield InsertBefore(node, cast(ast.stmt, await_st), separator=True) + new_try = common.clone(node) + new_try.body = [node.body[0]] + yield Replace(node, cast(ast.AST, new_try)) + + class TestInsertAfter(Rule): INPUT_SOURCE = """ def generate_index(base_path, active_path): @@ -489,7 +611,11 @@ def test_erase_invalid(invalid_node): "rule", [ TestInsertAfterBottom, + TestInsertAfterBottomWithBuild, + TestInsertAfterBottomWithSeparator, + TestInsertAfterBottomWithSeparatorWithBuild, TestInsertBeforeTop, + TestInsertBeforeTopWithSeparator, TestInsertAfter, TestInsertBefore, TestInsertAfterThenBefore, diff --git a/tests/test_complete_rules.py b/tests/test_complete_rules.py index dfd4984..469a0a6 100644 --- a/tests/test_complete_rules.py +++ b/tests/test_complete_rules.py @@ -176,6 +176,7 @@ def collect(self, name, scope): class AddNewImport(LazyInsertAfter): module: str names: list[str] + separator: bool = False def build(self): return ast.ImportFrom( @@ -189,6 +190,7 @@ def build(self): class AddNewImportBefore(LazyInsertBefore): module: str names: list[str] + separator: bool = False def build(self): return ast.ImportFrom(