diff --git a/pre_commit_hooks/check_builtin_literals.py b/pre_commit_hooks/check_builtin_literals.py index 16d59b52..e128eeaa 100644 --- a/pre_commit_hooks/check_builtin_literals.py +++ b/pre_commit_hooks/check_builtin_literals.py @@ -26,36 +26,37 @@ class Call(NamedTuple): class Visitor(ast.NodeVisitor): def __init__( self, - ignore: Sequence[str] | None = None, + ignore: set[str], allow_dict_kwargs: bool = True, ) -> None: self.builtin_type_calls: list[Call] = [] - self.ignore = set(ignore) if ignore else set() self.allow_dict_kwargs = allow_dict_kwargs + self._disallowed = BUILTIN_TYPES.keys() - ignore def _check_dict_call(self, node: ast.Call) -> bool: return self.allow_dict_kwargs and bool(node.keywords) def visit_Call(self, node: ast.Call) -> None: - if not isinstance(node.func, ast.Name): + if ( # Ignore functions that are object attributes (`foo.bar()`). # Assume that if the user calls `builtins.list()`, they know what # they're doing. - return - if node.func.id not in set(BUILTIN_TYPES).difference(self.ignore): - return - if node.func.id == 'dict' and self._check_dict_call(node): - return - elif node.args: - return - self.builtin_type_calls.append( - Call(node.func.id, node.lineno, node.col_offset), - ) + isinstance(node.func, ast.Name) and + node.func.id in self._disallowed and + (node.func.id != 'dict' or not self._check_dict_call(node)) and + not node.args + ): + self.builtin_type_calls.append( + Call(node.func.id, node.lineno, node.col_offset), + ) + + self.generic_visit(node) def check_file( filename: str, - ignore: Sequence[str] | None = None, + *, + ignore: set[str], allow_dict_kwargs: bool = True, ) -> list[Call]: with open(filename, 'rb') as f: diff --git a/tests/check_builtin_literals_test.py b/tests/check_builtin_literals_test.py index 1b182573..de29063f 100644 --- a/tests/check_builtin_literals_test.py +++ b/tests/check_builtin_literals_test.py @@ -38,11 +38,6 @@ ''' -@pytest.fixture -def visitor(): - return Visitor() - - @pytest.mark.parametrize( ('expression', 'calls'), [ @@ -85,7 +80,8 @@ def visitor(): ('builtins.tuple()', []), ], ) -def test_non_dict_exprs(visitor, expression, calls): +def test_non_dict_exprs(expression, calls): + visitor = Visitor(ignore=set()) visitor.visit(ast.parse(expression)) assert visitor.builtin_type_calls == calls @@ -102,7 +98,8 @@ def test_non_dict_exprs(visitor, expression, calls): ('builtins.dict()', []), ], ) -def test_dict_allow_kwargs_exprs(visitor, expression, calls): +def test_dict_allow_kwargs_exprs(expression, calls): + visitor = Visitor(ignore=set()) visitor.visit(ast.parse(expression)) assert visitor.builtin_type_calls == calls @@ -114,17 +111,18 @@ def test_dict_allow_kwargs_exprs(visitor, expression, calls): ('dict(a=1, b=2, c=3)', [Call('dict', 1, 0)]), ("dict(**{'a': 1, 'b': 2, 'c': 3})", [Call('dict', 1, 0)]), ('builtins.dict()', []), + pytest.param('f(dict())', [Call('dict', 1, 2)], id='nested'), ], ) def test_dict_no_allow_kwargs_exprs(expression, calls): - visitor = Visitor(allow_dict_kwargs=False) + visitor = Visitor(ignore=set(), allow_dict_kwargs=False) visitor.visit(ast.parse(expression)) assert visitor.builtin_type_calls == calls def test_ignore_constructors(): visitor = Visitor( - ignore=('complex', 'dict', 'float', 'int', 'list', 'str', 'tuple'), + ignore={'complex', 'dict', 'float', 'int', 'list', 'str', 'tuple'}, ) visitor.visit(ast.parse(BUILTIN_CONSTRUCTORS)) assert visitor.builtin_type_calls == []