diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f64b8b9..c1ba380 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -68,4 +68,4 @@ jobs: with: file: ./coverage.xml fail_ci_if_error: true - if: matrix.python-version == '3.10' + if: matrix.python-version == '3.10.6' diff --git a/README.md b/README.md index 1aff9d1..fb20e3e 100644 --- a/README.md +++ b/README.md @@ -484,6 +484,15 @@ load_correlation_ids() + load_celery_current_and_parent_ids() ``` +If you wish to correlate celery task IDs through the IDs found in your broker (i.e., the celery `task_id`), use the `use_internal_celery_task_id` argument on `load_celery_current_and_parent_ids` +```diff +from asgi_correlation_id.extensions.celery import load_correlation_ids, load_celery_current_and_parent_ids + +load_correlation_ids() ++ load_celery_current_and_parent_ids(use_internal_celery_task_id=True) +``` +Note: `load_celery_current_and_parent_ids` will ignore the `generator` argument when `use_internal_celery_task_id` is set to `True` + To set up the additional log filters, update your log config like this: ```diff diff --git a/asgi_correlation_id/extensions/celery.py b/asgi_correlation_id/extensions/celery.py index a44b9f0..d0e6e85 100644 --- a/asgi_correlation_id/extensions/celery.py +++ b/asgi_correlation_id/extensions/celery.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Callable, Dict from uuid import uuid4 from celery.signals import before_task_publish, task_postrun, task_prerun @@ -8,8 +8,10 @@ if TYPE_CHECKING: from celery import Task +uuid_hex_generator: Callable[[], str] = lambda: uuid4().hex -def load_correlation_ids() -> None: + +def load_correlation_ids(header_key: str = 'CORRELATION_ID', generator: Callable[[], str] = uuid_hex_generator) -> None: """ Transfer correlation IDs from a HTTP request to a Celery worker, when spawned from a request. @@ -18,7 +20,6 @@ def load_correlation_ids() -> None: """ from asgi_correlation_id.context import correlation_id - header_key = 'CORRELATION_ID' sentry_extension = get_sentry_extension() @before_task_publish.connect(weak=False) @@ -46,7 +47,7 @@ def load_correlation_id(task: 'Task', **kwargs: Any) -> None: correlation_id.set(id_value) sentry_extension(id_value) else: - generated_correlation_id = uuid4().hex + generated_correlation_id = generator() correlation_id.set(generated_correlation_id) sentry_extension(generated_correlation_id) @@ -61,7 +62,11 @@ def cleanup(**kwargs: Any) -> None: correlation_id.set(None) -def load_celery_current_and_parent_ids(header_key: str = 'CELERY_PARENT_ID') -> None: +def load_celery_current_and_parent_ids( + header_key: str = 'CELERY_PARENT_ID', + generator: Callable[[], str] = uuid_hex_generator, + use_internal_celery_task_id: bool = False, +) -> None: """ Configure Celery event hooks for generating tracing IDs with depth. @@ -83,7 +88,7 @@ def publish_task_from_worker_or_request(headers: Dict[str, str], **kwargs: Any) headers[header_key] = current @task_prerun.connect(weak=False) - def worker_prerun(task: 'Task', **kwargs: Any) -> None: + def worker_prerun(task_id: str, task: 'Task', **kwargs: Any) -> None: """ Set current ID, and parent ID if it exists. """ @@ -91,7 +96,8 @@ def worker_prerun(task: 'Task', **kwargs: Any) -> None: if parent_id: celery_parent_id.set(parent_id) - celery_current_id.set(uuid4().hex) + celery_id = task_id if use_internal_celery_task_id else generator() + celery_current_id.set(celery_id) @task_postrun.connect(weak=False) def clean_up(**kwargs: Any) -> None: diff --git a/asgi_correlation_id/log_filters.py b/asgi_correlation_id/log_filters.py index ac60413..199de96 100644 --- a/asgi_correlation_id/log_filters.py +++ b/asgi_correlation_id/log_filters.py @@ -7,6 +7,10 @@ from logging import LogRecord +def _trim_string(string: Optional[str], string_length: Optional[int]) -> Optional[str]: + return string[:string_length] if string_length is not None and string else string + + # Middleware @@ -27,10 +31,7 @@ def filter(self, record: 'LogRecord') -> bool: metadata. """ cid = correlation_id.get() - if self.uuid_length is not None and cid: - record.correlation_id = cid[: self.uuid_length] - else: - record.correlation_id = cid + record.correlation_id = _trim_string(cid, self.uuid_length) return True @@ -38,7 +39,7 @@ def filter(self, record: 'LogRecord') -> bool: class CeleryTracingIdsFilter(Filter): - def __init__(self, name: str = '', uuid_length: int = 32): + def __init__(self, name: str = '', uuid_length: Optional[int] = None): super().__init__(name=name) self.uuid_length = uuid_length @@ -52,7 +53,7 @@ def filter(self, record: 'LogRecord') -> bool: or from an endpoint, the parent ID will be None. """ pid = celery_parent_id.get() - record.celery_parent_id = pid[: self.uuid_length] if pid else pid + record.celery_parent_id = _trim_string(pid, self.uuid_length) cid = celery_current_id.get() - record.celery_current_id = cid[: self.uuid_length] if cid else cid + record.celery_current_id = _trim_string(cid, self.uuid_length) return True diff --git a/tests/test_log_filter.py b/tests/test_log_filter.py index cd09485..bbaee79 100644 --- a/tests/test_log_filter.py +++ b/tests/test_log_filter.py @@ -18,8 +18,7 @@ def cid(): @pytest.fixture() def log_record(): """Create and return an INFO-level log record""" - record = LogRecord(name='', level=INFO, pathname='', lineno=0, msg='Hello, world!', args=(), exc_info=None) - return record + return LogRecord(name='', level=INFO, pathname='', lineno=0, msg='Hello, world!', args=(), exc_info=None) def test_filter_has_uuid_length_attributes(): @@ -27,7 +26,7 @@ def test_filter_has_uuid_length_attributes(): assert filter_.uuid_length == 8 -def test_filter_adds_correlation_id(cid, log_record): +def test_filter_adds_correlation_id(cid: str, log_record: LogRecord): filter_ = CorrelationIdFilter() assert not hasattr(log_record, 'correlation_id') @@ -35,7 +34,7 @@ def test_filter_adds_correlation_id(cid, log_record): assert log_record.correlation_id == cid -def test_filter_truncates_correlation_id(cid, log_record): +def test_filter_truncates_correlation_id(cid: str, log_record: LogRecord): filter_ = CorrelationIdFilter(uuid_length=8) assert not hasattr(log_record, 'correlation_id') @@ -49,7 +48,7 @@ def test_celery_filter_has_uuid_length_attributes(): assert filter_.uuid_length == 8 -def test_celery_filter_adds_parent_id(cid, log_record): +def test_celery_filter_adds_parent_id(cid: str, log_record: LogRecord): filter_ = CeleryTracingIdsFilter() celery_parent_id.set('a') @@ -58,10 +57,60 @@ def test_celery_filter_adds_parent_id(cid, log_record): assert log_record.celery_parent_id == 'a' -def test_celery_filter_adds_current_id(cid, log_record): +def test_celery_filter_adds_current_id(cid: str, log_record: LogRecord): filter_ = CeleryTracingIdsFilter() celery_current_id.set('b') assert not hasattr(log_record, 'celery_current_id') filter_.filter(log_record) assert log_record.celery_current_id == 'b' + + +@pytest.mark.parametrize( + ('uuid_length', 'expected'), + [ + (6, 6), + (16, 16), + (None, 36), + (38, 36), + ], +) +def test_celery_filter_truncates_current_id_correctly(cid: str, log_record: LogRecord, uuid_length, expected): + """ + If uuid is unspecified, the default should be 36. + + Otherwise, the id should be truncated to the specified length. + """ + filter_ = CeleryTracingIdsFilter(uuid_length=uuid_length) + celery_id = str(uuid4()) + celery_current_id.set(celery_id) + + assert not hasattr(log_record, 'celery_current_id') + filter_.filter(log_record) + assert log_record.celery_current_id == celery_id[:expected] + + +def test_celery_filter_maintains_current_behavior(cid: str, log_record: LogRecord): + """Maintain default behavior with signature change + + Since the default values of CeleryTracingIdsFilter are being changed, + the new default values should also not trim a hex uuid. + """ + celery_id = uuid4().hex + celery_current_id.set(celery_id) + new_filter = CeleryTracingIdsFilter() + + assert not hasattr(log_record, 'celery_current_id') + new_filter.filter(log_record) + assert log_record.celery_current_id == celery_id + new_filter_record_id = log_record.celery_current_id + + del log_record.celery_current_id + + original_filter = CeleryTracingIdsFilter(uuid_length=32) + assert not hasattr(log_record, 'celery_current_id') + original_filter.filter(log_record) + assert log_record.celery_current_id == celery_id + original_filter_record_id = log_record.celery_current_id + + assert original_filter_record_id == new_filter_record_id