Merge pull request #67 from sminnee/feature/progress
s3rius authored Sep 26, 2024
2 parents e507271 + 8a97ced commit 254fae4
Showing 3 changed files with 272 additions and 2 deletions.
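The PR adds set_progress and get_progress to all three Redis result backends, storing a serialized TaskProgress under the task id plus a "__progress" suffix. A minimal round-trip sketch against the single-node backend (the Redis URL and task id below are placeholders, not part of this commit):

```python
import asyncio
import uuid

from taskiq.depends.progress_tracker import TaskProgress, TaskState
from taskiq_redis import RedisAsyncResultBackend


async def main() -> None:
    # Placeholder Redis URL; replace with your own instance.
    backend = RedisAsyncResultBackend(redis_url="redis://localhost:6379")
    task_id = uuid.uuid4().hex

    # Write progress: the backend serializes the TaskProgress instance
    # and stores it under "<task_id>__progress".
    await backend.set_progress(
        task_id=task_id,
        progress=TaskProgress(state=TaskState.STARTED, meta={"pct": 25}),
    )

    # Read it back; returns None when no progress has been reported yet.
    progress = await backend.get_progress(task_id=task_id)
    print(progress)

    await backend.shutdown()


asyncio.run(main())
```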
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -18,7 +18,7 @@ repos:
hooks:
- id: black
name: Format with Black
entry: black
entry: poetry run black
language: system
types: [python]

@@ -36,6 +36,6 @@ repos:

- id: mypy
name: Validate types with MyPy
entry: mypy
entry: poetry run mypy
language: system
types: [ python ]
148 changes: 148 additions & 0 deletions taskiq_redis/redis_backend.py
@@ -19,6 +19,7 @@
from taskiq.abc.result_backend import TaskiqResult
from taskiq.abc.serializer import TaskiqSerializer
from taskiq.compat import model_dump, model_validate
from taskiq.depends.progress_tracker import TaskProgress
from taskiq.serializers import PickleSerializer

from taskiq_redis.exceptions import (
@@ -41,6 +42,8 @@

_ReturnType = TypeVar("_ReturnType")

PROGRESS_KEY_SUFFIX = "__progress"


class RedisAsyncResultBackend(AsyncResultBackend[_ReturnType]):
"""Async result based on redis."""
@@ -174,6 +177,55 @@ async def get_result(

return taskiq_result

async def set_progress(
self,
task_id: str,
progress: TaskProgress[_ReturnType],
) -> None:
"""
Sets task progress in redis.
Dumps TaskProgress instance into the bytes and writes
it to redis with a standard suffix on the task_id as the key
:param task_id: ID of the task.
:param result: task's TaskProgress instance.
"""
redis_set_params: Dict[str, Union[str, int, bytes]] = {
"name": task_id + PROGRESS_KEY_SUFFIX,
"value": self.serializer.dumpb(model_dump(progress)),
}
if self.result_ex_time:
redis_set_params["ex"] = self.result_ex_time
elif self.result_px_time:
redis_set_params["px"] = self.result_px_time

async with Redis(connection_pool=self.redis_pool) as redis:
await redis.set(**redis_set_params) # type: ignore

async def get_progress(
self,
task_id: str,
) -> Union[TaskProgress[_ReturnType], None]:
"""
Gets progress results from the task.
:param task_id: task's id.
:return: task's TaskProgress instance.
"""
async with Redis(connection_pool=self.redis_pool) as redis:
result_value = await redis.get(
name=task_id + PROGRESS_KEY_SUFFIX,
)

if result_value is None:
return None

return model_validate(
TaskProgress[_ReturnType],
self.serializer.loadb(result_value),
)


class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]):
"""Async result backend based on redis cluster."""
@@ -301,6 +353,53 @@ async def get_result(

return taskiq_result

async def set_progress(
self,
task_id: str,
progress: TaskProgress[_ReturnType],
) -> None:
"""
Sets task progress in redis.
Dumps TaskProgress instance into the bytes and writes
it to redis with a standard suffix on the task_id as the key
:param task_id: ID of the task.
:param result: task's TaskProgress instance.
"""
redis_set_params: Dict[str, Union[str, int, bytes]] = {
"name": task_id + PROGRESS_KEY_SUFFIX,
"value": self.serializer.dumpb(model_dump(progress)),
}
if self.result_ex_time:
redis_set_params["ex"] = self.result_ex_time
elif self.result_px_time:
redis_set_params["px"] = self.result_px_time

await self.redis.set(**redis_set_params) # type: ignore

async def get_progress(
self,
task_id: str,
) -> Union[TaskProgress[_ReturnType], None]:
"""
Gets progress results from the task.
:param task_id: task's id.
:return: task's TaskProgress instance.
"""
result_value = await self.redis.get( # type: ignore[attr-defined]
name=task_id + PROGRESS_KEY_SUFFIX,
)

if result_value is None:
return None

return model_validate(
TaskProgress[_ReturnType],
self.serializer.loadb(result_value),
)


class RedisAsyncSentinelResultBackend(AsyncResultBackend[_ReturnType]):
"""Async result based on redis sentinel."""
@@ -439,6 +538,55 @@ async def get_result(

return taskiq_result

async def set_progress(
self,
task_id: str,
progress: TaskProgress[_ReturnType],
) -> None:
"""
Sets task progress in redis.
Dumps TaskProgress instance into the bytes and writes
it to redis with a standard suffix on the task_id as the key
:param task_id: ID of the task.
:param result: task's TaskProgress instance.
"""
redis_set_params: Dict[str, Union[str, int, bytes]] = {
"name": task_id + PROGRESS_KEY_SUFFIX,
"value": self.serializer.dumpb(model_dump(progress)),
}
if self.result_ex_time:
redis_set_params["ex"] = self.result_ex_time
elif self.result_px_time:
redis_set_params["px"] = self.result_px_time

async with self._acquire_master_conn() as redis:
await redis.set(**redis_set_params) # type: ignore

async def get_progress(
self,
task_id: str,
) -> Union[TaskProgress[_ReturnType], None]:
"""
Gets progress results from the task.
:param task_id: task's id.
:return: task's TaskProgress instance.
"""
async with self._acquire_master_conn() as redis:
result_value = await redis.get(
name=task_id + PROGRESS_KEY_SUFFIX,
)

if result_value is None:
return None

return model_validate(
TaskProgress[_ReturnType],
self.serializer.loadb(result_value),
)

async def shutdown(self) -> None:
"""Shutdown sentinel connections."""
for sentinel in self.sentinel.sentinels:
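In each backend the progress entry is written with the same expiration options as task results, so progress records expire alongside them. A small sketch of that behavior, assuming the backend's existing result_ex_time constructor option and a placeholder Redis URL:

```python
from taskiq_redis import RedisAsyncResultBackend

# Progress entries inherit the result TTL: `ex` (seconds) or `px`
# (milliseconds) is passed to redis.set for "<task_id>__progress"
# exactly as it is for the result key itself.
backend = RedisAsyncResultBackend(
    redis_url="redis://localhost:6379",  # placeholder URL
    result_ex_time=60,  # progress written via set_progress also expires in 60s
)
```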
122 changes: 122 additions & 0 deletions tests/test_result_backend.py
@@ -4,6 +4,7 @@

import pytest
from taskiq import TaskiqResult
from taskiq.depends.progress_tracker import TaskProgress, TaskState

from taskiq_redis import (
RedisAsyncClusterResultBackend,
@@ -438,3 +439,124 @@ async def test_keep_results_after_reading_sentinel(
res2 = await result_backend.get_result(task_id=task_id)
assert res1 == res2
await result_backend.shutdown()


@pytest.mark.anyio
async def test_set_progress(redis_url: str) -> None:
"""
Test that set_progress and get_progress work.
:param redis_url: redis URL.
"""
result_backend = RedisAsyncResultBackend( # type: ignore
redis_url=redis_url,
)
task_id = uuid.uuid4().hex

test_progress_1 = TaskProgress(
state=TaskState.STARTED,
meta={"message": "quarter way", "pct": 25},
)
test_progress_2 = TaskProgress(
state=TaskState.STARTED,
meta={"message": "half way", "pct": 50},
)

# Progress starts as None
assert await result_backend.get_progress(task_id=task_id) is None

# Setting the first time persists
await result_backend.set_progress(task_id=task_id, progress=test_progress_1)

fetched_result = await result_backend.get_progress(task_id=task_id)
assert fetched_result == test_progress_1

# Setting the second time replaces the first
await result_backend.set_progress(task_id=task_id, progress=test_progress_2)

fetched_result = await result_backend.get_progress(task_id=task_id)
assert fetched_result == test_progress_2

await result_backend.shutdown()


@pytest.mark.anyio
async def test_set_progress_cluster(redis_cluster_url: str) -> None:
"""
Test that set_progress and get_progress work in cluster mode.
:param redis_cluster_url: redis cluster URL.
"""
result_backend = RedisAsyncClusterResultBackend( # type: ignore
redis_url=redis_cluster_url,
)
task_id = uuid.uuid4().hex

test_progress_1 = TaskProgress(
state=TaskState.STARTED,
meta={"message": "quarter way", "pct": 25},
)
test_progress_2 = TaskProgress(
state=TaskState.STARTED,
meta={"message": "half way", "pct": 50},
)

# Progress starts as None
assert await result_backend.get_progress(task_id=task_id) is None

# Setting the first time persists
await result_backend.set_progress(task_id=task_id, progress=test_progress_1)

fetched_result = await result_backend.get_progress(task_id=task_id)
assert fetched_result == test_progress_1

# Setting the second time replaces the first
await result_backend.set_progress(task_id=task_id, progress=test_progress_2)

fetched_result = await result_backend.get_progress(task_id=task_id)
assert fetched_result == test_progress_2

await result_backend.shutdown()


@pytest.mark.anyio
async def test_set_progress_sentinel(
redis_sentinels: List[Tuple[str, int]],
redis_sentinel_master_name: str,
) -> None:
"""
Test that set_progress and get_progress work with sentinel.
:param redis_sentinels: list of sentinel host/port pairs.
:param redis_sentinel_master_name: name of the sentinel master.
"""
result_backend = RedisAsyncSentinelResultBackend( # type: ignore
sentinels=redis_sentinels,
master_name=redis_sentinel_master_name,
)
task_id = uuid.uuid4().hex

test_progress_1 = TaskProgress(
state=TaskState.STARTED,
meta={"message": "quarter way", "pct": 25},
)
test_progress_2 = TaskProgress(
state=TaskState.STARTED,
meta={"message": "half way", "pct": 50},
)

# Progress starts as None
assert await result_backend.get_progress(task_id=task_id) is None

# Setting the first time persists
await result_backend.set_progress(task_id=task_id, progress=test_progress_1)

fetched_result = await result_backend.get_progress(task_id=task_id)
assert fetched_result == test_progress_1

# Setting the second time replaces the first
await result_backend.set_progress(task_id=task_id, progress=test_progress_2)

fetched_result = await result_backend.get_progress(task_id=task_id)
assert fetched_result == test_progress_2

await result_backend.shutdown()

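The three new tests differ only in how the backend is constructed. If more progress tests get added, a parametrized fixture could collapse the duplication — a hypothetical sketch, not part of this PR, reusing the fixture names already present in the test suite:

```python
import uuid
from typing import List, Tuple

import pytest
from taskiq.depends.progress_tracker import TaskProgress, TaskState

from taskiq_redis import (
    RedisAsyncClusterResultBackend,
    RedisAsyncResultBackend,
    RedisAsyncSentinelResultBackend,
)


@pytest.fixture(params=["plain", "cluster", "sentinel"])
def progress_backend(
    request: pytest.FixtureRequest,
    redis_url: str,
    redis_cluster_url: str,
    redis_sentinels: List[Tuple[str, int]],
    redis_sentinel_master_name: str,
):
    # Build whichever backend the current param asks for.
    if request.param == "plain":
        return RedisAsyncResultBackend(redis_url=redis_url)
    if request.param == "cluster":
        return RedisAsyncClusterResultBackend(redis_url=redis_cluster_url)
    return RedisAsyncSentinelResultBackend(
        sentinels=redis_sentinels,
        master_name=redis_sentinel_master_name,
    )


@pytest.mark.anyio
async def test_set_progress_any_backend(progress_backend) -> None:
    """Progress round-trips through any of the three backends."""
    task_id = uuid.uuid4().hex
    progress = TaskProgress(state=TaskState.STARTED, meta={"pct": 50})

    # No progress has been reported yet.
    assert await progress_backend.get_progress(task_id=task_id) is None

    # Setting progress persists it and it can be read back.
    await progress_backend.set_progress(task_id=task_id, progress=progress)
    assert await progress_backend.get_progress(task_id=task_id) == progress

    await progress_backend.shutdown()
```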