Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions pre_commit_hooks/check_builtin_literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,36 +26,37 @@ class Call(NamedTuple):
class Visitor(ast.NodeVisitor):
def __init__(
self,
ignore: Sequence[str] | None = None,
ignore: set[str],
allow_dict_kwargs: bool = True,
) -> None:
self.builtin_type_calls: list[Call] = []
self.ignore = set(ignore) if ignore else set()
self.allow_dict_kwargs = allow_dict_kwargs
self._disallowed = BUILTIN_TYPES.keys() - ignore

def _check_dict_call(self, node: ast.Call) -> bool:
return self.allow_dict_kwargs and bool(node.keywords)

def visit_Call(self, node: ast.Call) -> None:
if not isinstance(node.func, ast.Name):
if (
# Ignore functions that are object attributes (`foo.bar()`).
# Assume that if the user calls `builtins.list()`, they know what
# they're doing.
return
if node.func.id not in set(BUILTIN_TYPES).difference(self.ignore):
return
if node.func.id == 'dict' and self._check_dict_call(node):
return
elif node.args:
return
self.builtin_type_calls.append(
Call(node.func.id, node.lineno, node.col_offset),
)
isinstance(node.func, ast.Name) and
node.func.id in self._disallowed and
(node.func.id != 'dict' or not self._check_dict_call(node)) and
not node.args
):
self.builtin_type_calls.append(
Call(node.func.id, node.lineno, node.col_offset),
)

self.generic_visit(node)


def check_file(
filename: str,
ignore: Sequence[str] | None = None,
*,
ignore: set[str],
allow_dict_kwargs: bool = True,
) -> list[Call]:
with open(filename, 'rb') as f:
Expand Down
16 changes: 7 additions & 9 deletions tests/check_builtin_literals_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,6 @@
'''


@pytest.fixture
def visitor():
return Visitor()


@pytest.mark.parametrize(
('expression', 'calls'),
[
Expand Down Expand Up @@ -85,7 +80,8 @@ def visitor():
('builtins.tuple()', []),
],
)
def test_non_dict_exprs(visitor, expression, calls):
def test_non_dict_exprs(expression, calls):
visitor = Visitor(ignore=set())
visitor.visit(ast.parse(expression))
assert visitor.builtin_type_calls == calls

Expand All @@ -102,7 +98,8 @@ def test_non_dict_exprs(visitor, expression, calls):
('builtins.dict()', []),
],
)
def test_dict_allow_kwargs_exprs(visitor, expression, calls):
def test_dict_allow_kwargs_exprs(expression, calls):
visitor = Visitor(ignore=set())
visitor.visit(ast.parse(expression))
assert visitor.builtin_type_calls == calls

Expand All @@ -114,17 +111,18 @@ def test_dict_allow_kwargs_exprs(visitor, expression, calls):
('dict(a=1, b=2, c=3)', [Call('dict', 1, 0)]),
("dict(**{'a': 1, 'b': 2, 'c': 3})", [Call('dict', 1, 0)]),
('builtins.dict()', []),
pytest.param('f(dict())', [Call('dict', 1, 2)], id='nested'),
],
)
def test_dict_no_allow_kwargs_exprs(expression, calls):
visitor = Visitor(allow_dict_kwargs=False)
visitor = Visitor(ignore=set(), allow_dict_kwargs=False)
visitor.visit(ast.parse(expression))
assert visitor.builtin_type_calls == calls


def test_ignore_constructors():
visitor = Visitor(
ignore=('complex', 'dict', 'float', 'int', 'list', 'str', 'tuple'),
ignore={'complex', 'dict', 'float', 'int', 'list', 'str', 'tuple'},
)
visitor.visit(ast.parse(BUILTIN_CONSTRUCTORS))
assert visitor.builtin_type_calls == []
Expand Down