Skip to content

Commit

Permalink
Extend type checking to tuple and type annotations.
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewwardrop committed Jun 7, 2024
1 parent a1ebf69 commit fe0ea6f
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 9 deletions.
36 changes: 28 additions & 8 deletions spec_classes/utils/type_checking.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import inspect
import numbers
import sys
import types
from collections.abc import Sequence as SequenceMutator
from collections.abc import Set as SetMutator
from typing import (
Expand Down Expand Up @@ -37,36 +39,54 @@ def check_type(value: Any, attr_type: Type) -> bool:
"""
Check whether a given object `value` matches the provided `attr_type`.
"""
if attr_type is Any:
if attr_type is Any or isinstance(attr_type, TypeVar):
return True

if attr_type is float:
attr_type = numbers.Real

if sys.version_info >= (3, 10) and isinstance(attr_type, types.UnionType):
return any(check_type(value, type_) for type_ in attr_type.__args__)

if hasattr(attr_type, "__origin__"): # we are dealing with a `typing` object.
if attr_type.__origin__ is Union:
return any(check_type(value, type_) for type_ in attr_type.__args__)

if attr_type.__origin__ in (Literal, LiteralExtension):
return value in attr_type.__args__

if isinstance(attr_type, _GenericAlias):
if (
isinstance(attr_type, _GenericAlias)
or sys.version_info >= (3, 9)
and isinstance(attr_type, types.GenericAlias)
):
if not isinstance(value, attr_type.__origin__):
return False
if attr_type._name in ("List", "Set") and not isinstance(
attr_type.__args__[0], TypeVar
): # pylint: disable=protected-access
if attr_type.__origin__ in (list, set):
for item in value:
if not check_type(item, attr_type.__args__[0]):
return False
elif attr_type._name == "Dict" and not isinstance(
attr_type.__args__[0], TypeVar
): # pylint: disable=protected-access
elif attr_type.__origin__ == dict:
for k, v in value.items():
if not check_type(k, attr_type.__args__[0]):
return False
if not check_type(v, attr_type.__args__[1]):
return False
elif attr_type.__origin__ == tuple:
if len(attr_type.__args__) == 2 and attr_type.__args__[1] is Ellipsis:
for item in value:
if not check_type(item, attr_type.__args__[0]):
return False
else:
if len(value) != len(attr_type.__args__):
return False
for i, item in enumerate(value):
if not check_type(item, attr_type.__args__[i]):
return False
elif attr_type.__origin__ == type:
if not issubclass(value, attr_type.__args__[0]):
return False

return True

return isinstance(
Expand Down
29 changes: 28 additions & 1 deletion tests/utils/test_type_checking.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Callable, Dict, List, Set, TypeVar, Union
import sys
from typing import Any, Callable, Dict, List, Set, Tuple, Type, TypeVar, Union

from typing_extensions import Literal

Expand Down Expand Up @@ -38,6 +39,11 @@ def test_type_checking(self):
assert check_type(["a", "b"], List[str])
assert not check_type([1, 2], List[str])

assert check_type((), Tuple)
assert check_type((1, "a"), Tuple[int, str])
assert not check_type((1,), Tuple[str])
assert not check_type(("1", 2), Tuple[str, ...])

assert check_type({}, Dict)
assert not check_type("a", Dict)
assert check_type({"a": 1, "b": 2}, Dict[str, int])
Expand All @@ -56,9 +62,30 @@ def test_type_checking(self):
assert check_type("hi", Literal["hi"])
assert not check_type(1, Literal["hi"])

class MyType:
pass

class SubType(MyType):
pass

assert check_type(MyType, Type[MyType])
assert check_type(SubType, Type[MyType])

assert check_type(1, float)
assert not check_type(1.0, int)

if sys.version_info >= (3, 9):
assert check_type(["a", "b"], list[str])
assert check_type((1, 2, 3, 4), tuple[int, ...])
assert check_type((1, 2, 3), tuple[int, int, float])
assert not check_type((1, 2, 3), tuple[int, int])
assert not check_type({1: "a", 2: "b"}, dict[str, int])
assert not check_type({1, 2}, set[str])
assert not check_type(str, type[MyType])

if sys.version_info >= (3, 10):
assert check_type([1, "a"], list[str | int])

def test_get_collection_item_type(self):
assert get_collection_item_type(list) is Any
assert get_collection_item_type(List) is Any
Expand Down

0 comments on commit fe0ea6f

Please sign in to comment.