Skip to content

Commit

Permalink
Added customizable generator options to celery extension to better ma…
Browse files Browse the repository at this point in the history
…tch capabilities of the correlation id middleware

* Added optional "header_key" param to asgi_correlation_id.extensions.celery.load_correlation_id
* Added optional "generator" param to asgi_correlation_id.extensions.celery.load_correlation_id
* Added optional "generator" param to asgi_correlation_id.extensions.celery.load_celery_current_and_parent_ids
  • Loading branch information
David Pryor (dapryor) committed Sep 9, 2022
1 parent e9879ce commit 75d6587
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions asgi_correlation_id/extensions/celery.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Callable, Dict
from uuid import uuid4

from celery.signals import before_task_publish, task_postrun, task_prerun
Expand All @@ -8,8 +8,12 @@
if TYPE_CHECKING:
from celery import Task

uuid_hex_generator_fn: Callable[[], str] = lambda: uuid4().hex

def load_correlation_ids() -> None:

def load_correlation_ids(
header_key: str = 'CORRELATION_ID', generator: Callable[[], str] = uuid_hex_generator_fn
) -> None:
"""
Transfer correlation IDs from a HTTP request to a Celery worker,
when spawned from a request.
Expand All @@ -18,7 +22,6 @@ def load_correlation_ids() -> None:
"""
from asgi_correlation_id.context import correlation_id

header_key = 'CORRELATION_ID'
sentry_extension = get_sentry_extension()

@before_task_publish.connect(weak=False)
Expand Down Expand Up @@ -46,7 +49,7 @@ def load_correlation_id(task: 'Task', **kwargs: Any) -> None:
correlation_id.set(id_value)
sentry_extension(id_value)
else:
generated_correlation_id = uuid4().hex
generated_correlation_id = generator()
correlation_id.set(generated_correlation_id)
sentry_extension(generated_correlation_id)

Expand All @@ -61,7 +64,9 @@ def cleanup(**kwargs: Any) -> None:
correlation_id.set(None)


def load_celery_current_and_parent_ids(header_key: str = 'CELERY_PARENT_ID') -> None:
def load_celery_current_and_parent_ids(
header_key: str = 'CELERY_PARENT_ID', generator: Callable[[], str] = uuid_hex_generator_fn
) -> None:
"""
Configure Celery event hooks for generating tracing IDs with depth.
Expand Down Expand Up @@ -91,7 +96,7 @@ def worker_prerun(task: 'Task', **kwargs: Any) -> None:
if parent_id:
celery_parent_id.set(parent_id)

celery_current_id.set(uuid4().hex)
celery_current_id.set(generator())

@task_postrun.connect(weak=False)
def clean_up(**kwargs: Any) -> None:
Expand Down

0 comments on commit 75d6587

Please sign in to comment.