Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions refactor/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ def apply(self, context: Context, source: str) -> str:
replacement = split_lines(context.unparse(self.build()))
replacement.apply_indentation(indentation, start_prefix=start_prefix)

if hasattr(self, "separator") and self.separator:
# Adding extra separating line
replacement.insert(0, lines._newline_type)

original_node_end = cast(int, self.node.end_lineno) - 1
if lines[original_node_end].endswith(lines._newline_type):
replacement[-1] += lines._newline_type
Expand Down Expand Up @@ -220,6 +224,10 @@ def apply(self, context: Context, source: str) -> str:
replacement.apply_indentation(indentation, start_prefix=start_prefix)
replacement[-1] += lines._newline_type

if hasattr(self, "separator") and self.separator:
# Adding extra separating line
replacement.append(lines._newline_type)

original_node_start = cast(int, self.node.lineno)
for line in reversed(replacement):
lines.insert(original_node_start - 1, line)
Expand Down Expand Up @@ -253,6 +261,7 @@ class InsertAfter(LazyInsertAfter):
"""

target: ast.stmt
separator: bool = False

def build(self) -> ast.stmt:
return self.target
Expand All @@ -268,6 +277,7 @@ class InsertBefore(LazyInsertBefore):
"""

target: ast.stmt
separator: bool = False

def build(self) -> ast.stmt:
return self.target
Expand Down
128 changes: 127 additions & 1 deletion tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import ast
import textwrap
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator, cast

import pytest
from refactor.ast import DEFAULT_ENCODING

from refactor import Session, common
from refactor.actions import Erase, InvalidActionError, InsertAfter, Replace, InsertBefore
from refactor.actions import Erase, InvalidActionError, InsertAfter, Replace, InsertBefore, LazyInsertAfter, LazyInsertBefore
from refactor.context import Context
from refactor.core import Rule

Expand Down Expand Up @@ -48,6 +49,14 @@ def foo():
INVALID_ERASES_TREE = ast.parse(INVALID_ERASES)


@dataclass
class BuildInsertAfterBottom(LazyInsertAfter):
separator: bool
def build(self) -> ast.Await:
await_st = ast.parse("await async_test()")
return await_st


class TestInsertAfterBottom(Rule):
INPUT_SOURCE = """
try:
Expand Down Expand Up @@ -76,6 +85,90 @@ def match(self, node: ast.AST) -> Iterator[InsertAfter]:
yield Replace(node, cast(ast.AST, new_try))


class TestInsertAfterBottomWithBuild(Rule):
INPUT_SOURCE = """
try:
base_tree = get_tree(base_file, module_name)
first_tree = get_tree(first_tree, module_name)
second_tree = get_tree(second_tree, module_name)
third_tree = get_tree(third_tree, module_name)
except (SyntaxError, FileNotFoundError):
continue"""

EXPECTED_SOURCE = """
try:
base_tree = get_tree(base_file, module_name)
except (SyntaxError, FileNotFoundError):
continue
await async_test()"""

def match(self, node: ast.AST) -> Iterator[InsertAfter]:
assert isinstance(node, ast.Try)
assert len(node.body) >= 2

yield BuildInsertAfterBottom(node, separator=False)
new_try = common.clone(node)
new_try.body = [node.body[0]]
yield Replace(node, cast(ast.AST, new_try))


class TestInsertAfterBottomWithSeparator(Rule):
INPUT_SOURCE = """
try:
base_tree = get_tree(base_file, module_name)
first_tree = get_tree(first_tree, module_name)
second_tree = get_tree(second_tree, module_name)
third_tree = get_tree(third_tree, module_name)
except (SyntaxError, FileNotFoundError):
continue"""

EXPECTED_SOURCE = """
try:
base_tree = get_tree(base_file, module_name)
except (SyntaxError, FileNotFoundError):
continue

await async_test()"""

def match(self, node: ast.AST) -> Iterator[InsertAfter]:
assert isinstance(node, ast.Try)
assert len(node.body) >= 2

await_st = ast.parse("await async_test()")
yield InsertAfter(node, cast(ast.stmt, await_st), separator=True)
new_try = common.clone(node)
new_try.body = [node.body[0]]
yield Replace(node, cast(ast.AST, new_try))


class TestInsertAfterBottomWithSeparatorWithBuild(Rule):
INPUT_SOURCE = """
try:
base_tree = get_tree(base_file, module_name)
first_tree = get_tree(first_tree, module_name)
second_tree = get_tree(second_tree, module_name)
third_tree = get_tree(third_tree, module_name)
except (SyntaxError, FileNotFoundError):
continue"""

EXPECTED_SOURCE = """
try:
base_tree = get_tree(base_file, module_name)
except (SyntaxError, FileNotFoundError):
continue

await async_test()"""

def match(self, node: ast.AST) -> Iterator[InsertAfter]:
assert isinstance(node, ast.Try)
assert len(node.body) >= 2

yield BuildInsertAfterBottom(node, separator=True)
new_try = common.clone(node)
new_try.body = [node.body[0]]
yield Replace(node, cast(ast.AST, new_try))


class TestInsertBeforeTop(Rule):
INPUT_SOURCE = """
try:
Expand Down Expand Up @@ -104,6 +197,35 @@ def match(self, node: ast.AST) -> Iterator[InsertBefore]:
yield Replace(node, cast(ast.AST, new_try))


class TestInsertBeforeTopWithSeparator(Rule):
INPUT_SOURCE = """
try:
base_tree = get_tree(base_file, module_name)
first_tree = get_tree(first_tree, module_name)
second_tree = get_tree(second_tree, module_name)
third_tree = get_tree(third_tree, module_name)
except (SyntaxError, FileNotFoundError):
continue"""

EXPECTED_SOURCE = """
await async_test()

try:
base_tree = get_tree(base_file, module_name)
except (SyntaxError, FileNotFoundError):
continue"""

def match(self, node: ast.AST) -> Iterator[InsertBefore]:
assert isinstance(node, ast.Try)
assert len(node.body) >= 2

await_st = ast.parse("await async_test()")
yield InsertBefore(node, cast(ast.stmt, await_st), separator=True)
new_try = common.clone(node)
new_try.body = [node.body[0]]
yield Replace(node, cast(ast.AST, new_try))


class TestInsertAfter(Rule):
INPUT_SOURCE = """
def generate_index(base_path, active_path):
Expand Down Expand Up @@ -489,7 +611,11 @@ def test_erase_invalid(invalid_node):
"rule",
[
TestInsertAfterBottom,
TestInsertAfterBottomWithBuild,
TestInsertAfterBottomWithSeparator,
TestInsertAfterBottomWithSeparatorWithBuild,
TestInsertBeforeTop,
TestInsertBeforeTopWithSeparator,
TestInsertAfter,
TestInsertBefore,
TestInsertAfterThenBefore,
Expand Down
2 changes: 2 additions & 0 deletions tests/test_complete_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def collect(self, name, scope):
class AddNewImport(LazyInsertAfter):
module: str
names: list[str]
separator: bool = False

def build(self):
return ast.ImportFrom(
Expand All @@ -189,6 +190,7 @@ def build(self):
class AddNewImportBefore(LazyInsertBefore):
module: str
names: list[str]
separator: bool = False

def build(self):
return ast.ImportFrom(
Expand Down