Skip to content

Commit

Permalink
Add missing overload for Task.__call__ (#16891)
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle authored Jan 29, 2025
1 parent 71fb263 commit 9514808
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 8 deletions.
14 changes: 8 additions & 6 deletions src/prefect/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ async def create_local_run(
def __call__(
self: "Task[P, NoReturn]",
*args: P.args,
return_state: Literal[False],
return_state: Literal[False] = False,
wait_for: Optional[OneOrManyFutureOrResult[Any]] = None,
**kwargs: P.kwargs,
) -> None:
Expand All @@ -977,20 +977,22 @@ def __call__(
def __call__(
self: "Task[P, R]",
*args: P.args,
return_state: Literal[True],
wait_for: Optional[OneOrManyFutureOrResult[Any]] = None,
**kwargs: P.kwargs,
) -> State[R]:
) -> R:
...

# Keyword parameters `return_state` and `wait_for` aren't allowed after the
# ParamSpec `*args` parameter, so we lose return type typing when either of
# those are provided.
# TODO: Find a way to expose this functionality without losing type information
@overload
def __call__(
self: "Task[P, R]",
*args: P.args,
return_state: Literal[False],
return_state: Literal[True] = True,
wait_for: Optional[OneOrManyFutureOrResult[Any]] = None,
**kwargs: P.kwargs,
) -> R:
) -> State[R]:
...

@overload
Expand Down
14 changes: 13 additions & 1 deletion tests/typesafety/test_flows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,16 @@
reveal_type(foo)
out: "main:5: note: Revealed type is \"\
prefect.flows.Flow[[bar: builtins.str], builtins.int]\
\""
\""

- case: prefect_flow_call
main: |
from prefect import flow
@flow
def foo(bar: str) -> int:
return 42
ret = foo(bar="baz")
reveal_type(ret)
out: "main:6: note: Revealed type is \"\
builtins.int\
\""
14 changes: 13 additions & 1 deletion tests/typesafety/test_tasks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,16 @@
reveal_type(foo)
out: "main:9: note: Revealed type is \"\
prefect.tasks.Task[[bar: builtins.str], builtins.int]\
\""
\""

- case: prefect_task_call
main: |
from prefect import task
@task
def foo(bar: str) -> int:
return 42
ret = foo(bar="baz")
reveal_type(ret)
out: "main:6: note: Revealed type is \"\
builtins.int\
\""

0 comments on commit 9514808

Please sign in to comment.