From e4f87debc26f80de68150d8c734b252e0e1e7b91 Mon Sep 17 00:00:00 2001 From: Stephen Finucane Date: Thu, 24 Oct 2024 12:23:38 +0100 Subject: [PATCH 1/2] Track import aliases via State Signed-off-by: Stephen Finucane --- pyupgrade/_data.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/pyupgrade/_data.py b/pyupgrade/_data.py index 164d30b2..d5df0c68 100644 --- a/pyupgrade/_data.py +++ b/pyupgrade/_data.py @@ -27,6 +27,7 @@ class Settings(NamedTuple): class State(NamedTuple): settings: Settings from_imports: dict[str, set[str]] + as_imports: dict[str, str] in_annotation: bool = False @@ -34,7 +35,7 @@ class State(NamedTuple): TokenFunc = Callable[[int, list[Token]], None] ASTFunc = Callable[[State, AST_T, ast.AST], Iterable[tuple[Offset, TokenFunc]]] -RECORD_FROM_IMPORTS = frozenset(( +RECORDED_IMPORTS = frozenset(( '__future__', 'asyncio', 'collections', @@ -75,6 +76,7 @@ def visit( initial_state = State( settings=settings, from_imports=collections.defaultdict(set), + as_imports=collections.defaultdict(), ) nodes: list[tuple[State, ast.AST, ast.AST]] = [(initial_state, tree, tree)] @@ -91,11 +93,18 @@ def visit( if ( isinstance(node, ast.ImportFrom) and not node.level and - node.module in RECORD_FROM_IMPORTS + node.module in RECORDED_IMPORTS ): state.from_imports[node.module].update( name.name for name in node.names if not name.asname ) + elif ( + isinstance(node, ast.Import) + ): + state.as_imports.update({ + x.asname or x.name: x.name for x in node.names + if x.name in RECORDED_IMPORTS + }) for name in reversed(node._fields): value = getattr(node, name) From 3bc5f61d0800667fc6735d94905d069c7f962eeb Mon Sep 17 00:00:00 2001 From: Stephen Finucane Date: Thu, 24 Oct 2024 12:25:35 +0100 Subject: [PATCH 2/2] Support pep585 rewriting when using alias State.as_import probably has uses elsewhere, but this is sufficient to start with. We need to update some of the pep604 feature tests to ensure we actually do the import. Signed-off-by: Stephen Finucane --- pyupgrade/_plugins/typing_pep585.py | 2 +- tests/features/typing_pep585_test.py | 9 +++++++++ tests/features/typing_pep604_test.py | 16 ++++++++++++++++ 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/pyupgrade/_plugins/typing_pep585.py b/pyupgrade/_plugins/typing_pep585.py index 77dbd7e3..1b38cb48 100644 --- a/pyupgrade/_plugins/typing_pep585.py +++ b/pyupgrade/_plugins/typing_pep585.py @@ -36,7 +36,7 @@ def visit_Attribute( if ( _should_rewrite(state) and isinstance(node.value, ast.Name) and - node.value.id == 'typing' and + state.as_imports.get(node.value.id) == 'typing' and node.attr in PEP585_BUILTINS ): func = functools.partial( diff --git a/tests/features/typing_pep585_test.py b/tests/features/typing_pep585_test.py index 93ca79b2..90a76fe8 100644 --- a/tests/features/typing_pep585_test.py +++ b/tests/features/typing_pep585_test.py @@ -84,6 +84,15 @@ def f(x: list[str]) -> None: ... id='import of typing + typing.List', ), + pytest.param( + 'import typing as ty\n' + 'x: ty.List[int]\n', + + 'import typing as ty\n' + 'x: list[int]\n', + + id='aliased import of typing + typing.List', + ), pytest.param( 'from typing import List\n' 'SomeAlias = List[int]\n', diff --git a/tests/features/typing_pep604_test.py b/tests/features/typing_pep604_test.py index eff5bf31..4f91f431 100644 --- a/tests/features/typing_pep604_test.py +++ b/tests/features/typing_pep604_test.py @@ -105,46 +105,58 @@ def f(x: int | str) -> None: ... id='Union rewrite', ), pytest.param( + 'import typing\n' 'x: typing.Union[int]\n', + 'import typing\n' 'x: int\n', id='Union of only one value', ), pytest.param( + 'import typing\n' 'x: typing.Union[Foo[str, int], str]\n', + 'import typing\n' 'x: Foo[str, int] | str\n', id='Union containing a value with brackets', ), pytest.param( + 'import typing\n' 'x: typing.Union[typing.List[str], str]\n', + 'import typing\n' 'x: list[str] | str\n', id='Union containing pep585 rewritten type', ), pytest.param( + 'import typing\n' 'x: typing.Union[int, str,]\n', + 'import typing\n' 'x: int | str\n', id='Union trailing comma', ), pytest.param( + 'import typing\n' 'x: typing.Union[(int, str)]\n', + 'import typing\n' 'x: int | str\n', id='Union, parenthesized tuple', ), pytest.param( + 'import typing\n' 'x: typing.Union[\n' ' int,\n' ' str\n' ']\n', + 'import typing\n' 'x: (\n' ' int |\n' ' str\n' @@ -153,11 +165,13 @@ def f(x: int | str) -> None: ... id='Union multiple lines', ), pytest.param( + 'import typing\n' 'x: typing.Union[\n' ' int,\n' ' str,\n' ']\n', + 'import typing\n' 'x: (\n' ' int |\n' ' str\n' @@ -175,10 +189,12 @@ def f(x: int | str) -> None: ... id='Optional rewrite', ), pytest.param( + 'import typing\n' 'x: typing.Optional[\n' ' ComplicatedLongType[int]\n' ']\n', + 'import typing\n' 'x: None | (\n' ' ComplicatedLongType[int]\n' ')\n',