Skip to content

Commit

Permalink
Support ParamSpec for TypeAliasType (#449)
Browse files Browse the repository at this point in the history

Co-authored-by: Alex Waygood <[email protected]>
  • Loading branch information
Daraan and AlexWaygood authored Nov 26, 2024
1 parent b7d6353 commit f2d0667
Show file tree
Hide file tree
Showing 3 changed files with 229 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ aliases that have a `Concatenate` special form as their argument.
- Backport CPython PR [#124795](https://github.com/python/cpython/pull/124795):
fix `TypeAliasType` not raising an error on non-tuple inputs for `type_params`.
Patch by [Daraan](https://github.com/Daraan).
- Fix that lists and ... could not be used for parameter expressions for `TypeAliasType`
instances before Python 3.11.
Patch by [Daraan](https://github.com/Daraan).

# Release 4.12.2 (June 7, 2024)

Expand Down
193 changes: 191 additions & 2 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7267,6 +7267,80 @@ def test_attributes(self):
self.assertEqual(Variadic.__type_params__, (Ts,))
self.assertEqual(Variadic.__parameters__, tuple(iter(Ts)))

P = ParamSpec('P')
CallableP = TypeAliasType("CallableP", Callable[P, Any], type_params=(P, ))
self.assertEqual(CallableP.__name__, "CallableP")
self.assertEqual(CallableP.__value__, Callable[P, Any])
self.assertEqual(CallableP.__type_params__, (P,))
self.assertEqual(CallableP.__parameters__, (P,))

def test_alias_types_and_substitutions(self):
T = TypeVar('T')
T2 = TypeVar('T2')
T_default = TypeVar("T_default", default=int)
Ts = TypeVarTuple("Ts")
P = ParamSpec('P')

test_argument_cases = {
# arguments : expected parameters
int : (),
... : (),
None : (),
T2 : (T2,),
Union[int, List[T2]] : (T2,),
Tuple[int, str] : (),
Tuple[T, T_default, T2] : (T, T_default, T2),
Tuple[Unpack[Ts]] : (Ts,),
Callable[[Unpack[Ts]], T2] : (Ts, T2),
Callable[P, T2] : (P, T2),
Callable[Concatenate[T2, P], T_default] : (T2, P, T_default),
TypeAliasType("NestedAlias", List[T], type_params=(T,))[T2] : (T2,),
Unpack[Ts] : (Ts,),
Unpack[Tuple[int, T2]] : (T2,),
Concatenate[int, P] : (P,),
# Not tested usage of bare TypeVarTuple, would need 3.11+
# Ts : (Ts,), # invalid case
}

test_alias_cases = [
# Simple cases
TypeAliasType("ListT", List[T], type_params=(T,)),
TypeAliasType("UnionT", Union[int, List[T]], type_params=(T,)),
# Value has no parameter but in type_param
TypeAliasType("ValueWithoutT", int, type_params=(T,)),
# Callable
TypeAliasType("CallableP", Callable[P, Any], type_params=(P, )),
TypeAliasType("CallableT", Callable[..., T], type_params=(T, )),
TypeAliasType("CallableTs", Callable[[Unpack[Ts]], Any], type_params=(Ts, )),
# TypeVarTuple
TypeAliasType("Variadic", Tuple[int, Unpack[Ts]], type_params=(Ts,)),
# TypeVar with default
TypeAliasType("TupleT_default", Tuple[T_default, T], type_params=(T, T_default)),
TypeAliasType("CallableT_default", Callable[[T], T_default], type_params=(T, T_default)),
]

for alias in test_alias_cases:
with self.subTest(alias=alias, args=[]):
subscripted = alias[[]]
self.assertEqual(get_args(subscripted), ([],))
self.assertEqual(subscripted.__parameters__, ())
with self.subTest(alias=alias, args=()):
subscripted = alias[()]
self.assertEqual(get_args(subscripted), ())
self.assertEqual(subscripted.__parameters__, ())
with self.subTest(alias=alias, args=(int, float)):
subscripted = alias[int, float]
self.assertEqual(get_args(subscripted), (int, float))
self.assertEqual(subscripted.__parameters__, ())
with self.subTest(alias=alias, args=[int, float]):
subscripted = alias[[int, float]]
self.assertEqual(get_args(subscripted), ([int, float],))
self.assertEqual(subscripted.__parameters__, ())
for expected_args, expected_parameters in test_argument_cases.items():
with self.subTest(alias=alias, args=expected_args):
self.assertEqual(get_args(alias[expected_args]), (expected_args,))
self.assertEqual(alias[expected_args].__parameters__, expected_parameters)

def test_cannot_set_attributes(self):
Simple = TypeAliasType("Simple", int)
with self.assertRaisesRegex(AttributeError, "readonly attribute"):
Expand Down Expand Up @@ -7327,12 +7401,19 @@ def test_or(self):
Alias | "Ref"

def test_getitem(self):
T = TypeVar('T')
ListOrSetT = TypeAliasType("ListOrSetT", Union[List[T], Set[T]], type_params=(T,))
subscripted = ListOrSetT[int]
self.assertEqual(get_args(subscripted), (int,))
self.assertIs(get_origin(subscripted), ListOrSetT)
with self.assertRaises(TypeError):
subscripted[str]
with self.assertRaisesRegex(TypeError,
"not a generic class"
# types.GenericAlias raises a different error in 3.10
if sys.version_info[:2] != (3, 10)
else "There are no type variables left in ListOrSetT"
):
subscripted[int]


still_generic = ListOrSetT[Iterable[T]]
self.assertEqual(get_args(still_generic), (Iterable[T],))
Expand All @@ -7341,6 +7422,114 @@ def test_getitem(self):
self.assertEqual(get_args(fully_subscripted), (Iterable[float],))
self.assertIs(get_origin(fully_subscripted), ListOrSetT)

ValueWithoutTypeVar = TypeAliasType("ValueWithoutTypeVar", int, type_params=(T,))
still_subscripted = ValueWithoutTypeVar[str]
self.assertEqual(get_args(still_subscripted), (str,))

def test_callable_without_concatenate(self):
P = ParamSpec('P')
CallableP = TypeAliasType("CallableP", Callable[P, Any], type_params=(P,))
get_args_test_cases = [
# List of (alias, expected_args)
# () -> Any
(CallableP[()], ()),
(CallableP[[]], ([],)),
# (int) -> Any
(CallableP[int], (int,)),
(CallableP[[int]], ([int],)),
# (int, int) -> Any
(CallableP[int, int], (int, int)),
(CallableP[[int, int]], ([int, int],)),
# (...) -> Any
(CallableP[...], (...,)),
# (int, ...) -> Any
(CallableP[[int, ...]], ([int, ...],)),
]

for index, (expression, expected_args) in enumerate(get_args_test_cases):
with self.subTest(index=index, expression=expression):
self.assertEqual(get_args(expression), expected_args)

self.assertEqual(CallableP[...], CallableP[(...,)])
# (T) -> Any
CallableT = CallableP[T]
self.assertEqual(get_args(CallableT), (T,))
self.assertEqual(CallableT.__parameters__, (T,))

def test_callable_with_concatenate(self):
P = ParamSpec('P')
P2 = ParamSpec('P2')
CallableP = TypeAliasType("CallableP", Callable[P, Any], type_params=(P,))

callable_concat = CallableP[Concatenate[int, P2]]
self.assertEqual(callable_concat.__parameters__, (P2,))
concat_usage = callable_concat[str]
with self.subTest("get_args of Concatenate in TypeAliasType"):
if not TYPING_3_9_0:
# args are: ([<class 'int'>, ~P2],)
self.skipTest("Nested ParamSpec is not substituted")
if sys.version_info < (3, 10, 2):
self.skipTest("GenericAlias keeps Concatenate in __args__ prior to 3.10.2")
self.assertEqual(get_args(concat_usage), ((int, str),))
with self.subTest("Equality of parameter_expression without []"):
if not TYPING_3_10_0:
self.skipTest("Nested list is invalid type form")
self.assertEqual(concat_usage, callable_concat[[str]])

def test_substitution(self):
T = TypeVar('T')
Ts = TypeVarTuple("Ts")

CallableTs = TypeAliasType("CallableTs", Callable[[Unpack[Ts]], Any], type_params=(Ts, ))
unpack_callable = CallableTs[Unpack[Tuple[int, T]]]
self.assertEqual(get_args(unpack_callable), (Unpack[Tuple[int, T]],))

P = ParamSpec('P')
CallableP = TypeAliasType("CallableP", Callable[P, T], type_params=(P, T))
callable_concat = CallableP[Concatenate[int, P], Any]
self.assertEqual(get_args(callable_concat), (Concatenate[int, P], Any))

def test_wrong_amount_of_parameters(self):
T = TypeVar('T')
T2 = TypeVar("T2")
P = ParamSpec('P')
ListOrSetT = TypeAliasType("ListOrSetT", Union[List[T], Set[T]], type_params=(T,))
TwoT = TypeAliasType("TwoT", Union[List[T], Set[T2]], type_params=(T, T2))
CallablePT = TypeAliasType("CallablePT", Callable[P, T], type_params=(P, T))

# Not enough parameters
test_cases = [
# not_enough
(TwoT[int], [(int,), ()]),
(TwoT[T], [(T,), (T,)]),
# callable and not enough
(CallablePT[int], [(int,), ()]),
# too many
(ListOrSetT[int, bool], [(int, bool), ()]),
# callable and too many
(CallablePT[str, float, int], [(str, float, int), ()]),
# Check if TypeVar is still present even if over substituted
(ListOrSetT[int, T], [(int, T), (T,)]),
# With and without list for ParamSpec
(CallablePT[str, float, T], [(str, float, T), (T,)]),
(CallablePT[[str], float, int, T2], [([str], float, int, T2), (T2,)]),
]

for index, (alias, [expected_args, expected_params]) in enumerate(test_cases):
with self.subTest(index=index, alias=alias):
self.assertEqual(get_args(alias), expected_args)
self.assertEqual(alias.__parameters__, expected_params)

# The condition should align with the version of GeneriAlias usage in __getitem__ or be 3.11+
@skipIf(TYPING_3_10_0, "Most arguments are allowed in 3.11+ or with GenericAlias")
def test_invalid_cases_before_3_10(self):
T = TypeVar('T')
ListOrSetT = TypeAliasType("ListOrSetT", Union[List[T], Set[T]], type_params=(T,))
with self.assertRaises(TypeError):
ListOrSetT[Generic[T]]
with self.assertRaises(TypeError):
ListOrSetT[(Generic[T], )]

def test_unpack_parameter_collection(self):
Ts = TypeVarTuple("Ts")

Expand Down
42 changes: 35 additions & 7 deletions src/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3662,6 +3662,33 @@ def _raise_attribute_error(self, name: str) -> Never:
def __repr__(self) -> str:
return self.__name__

if sys.version_info < (3, 11):
def _check_single_param(self, param, recursion=0):
# Allow [], [int], [int, str], [int, ...], [int, T]
if param is ...:
return ...
if param is None:
return None
# Note in <= 3.9 _ConcatenateGenericAlias inherits from list
if isinstance(param, list) and recursion == 0:
return [self._check_single_param(arg, recursion+1)
for arg in param]
return typing._type_check(
param, f'Subscripting {self.__name__} requires a type.'
)

def _check_parameters(self, parameters):
if sys.version_info < (3, 11):
return tuple(
self._check_single_param(item)
for item in parameters
)
return tuple(typing._type_check(
item, f'Subscripting {self.__name__} requires a type.'
)
for item in parameters
)

def __getitem__(self, parameters):
if not self.__type_params__:
raise TypeError("Only generic type aliases are subscriptable")
Expand All @@ -3670,13 +3697,14 @@ def __getitem__(self, parameters):
# Using 3.9 here will create problems with Concatenate
if sys.version_info >= (3, 10):
return _types.GenericAlias(self, parameters)
parameters = tuple(
typing._type_check(
item, f'Subscripting {self.__name__} requires a type.'
)
for item in parameters
)
return _TypeAliasGenericAlias(self, parameters)
type_vars = _collect_type_vars(parameters)
parameters = self._check_parameters(parameters)
alias = _TypeAliasGenericAlias(self, parameters)
# alias.__parameters__ is not complete if Concatenate is present
# as it is converted to a list from which no parameters are extracted.
if alias.__parameters__ != type_vars:
alias.__parameters__ = type_vars
return alias

def __reduce__(self):
return self.__name__
Expand Down

0 comments on commit f2d0667

Please sign in to comment.