Skip to content

Commit

Permalink
Merge pull request #51 from dapryor/main
Browse files Browse the repository at this point in the history
Make celery log filters IDs configurable
  • Loading branch information
sondrelg authored Sep 29, 2022
2 parents 4a30bc6 + 808122c commit 42cf859
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ jobs:
with:
file: ./coverage.xml
fail_ci_if_error: true
if: matrix.python-version == '3.10'
if: matrix.python-version == '3.10.6'
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,15 @@ load_correlation_ids()
+ load_celery_current_and_parent_ids()
```

If you wish to correlate Celery task IDs through the IDs found in your broker (i.e., the Celery `task_id`), use the `use_internal_celery_task_id` argument on `load_celery_current_and_parent_ids`:
```diff
from asgi_correlation_id.extensions.celery import load_correlation_ids, load_celery_current_and_parent_ids

load_correlation_ids()
+ load_celery_current_and_parent_ids(use_internal_celery_task_id=True)
```
Note: `load_celery_current_and_parent_ids` will ignore the `generator` argument when `use_internal_celery_task_id` is set to `True`.

To set up the additional log filters, update your log config like this:

```diff
Expand Down
20 changes: 13 additions & 7 deletions asgi_correlation_id/extensions/celery.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Callable, Dict
from uuid import uuid4

from celery.signals import before_task_publish, task_postrun, task_prerun
Expand All @@ -8,8 +8,10 @@
if TYPE_CHECKING:
from celery import Task

uuid_hex_generator: Callable[[], str] = lambda: uuid4().hex

def load_correlation_ids() -> None:

def load_correlation_ids(header_key: str = 'CORRELATION_ID', generator: Callable[[], str] = uuid_hex_generator) -> None:
"""
Transfer correlation IDs from a HTTP request to a Celery worker,
when spawned from a request.
Expand All @@ -18,7 +20,6 @@ def load_correlation_ids() -> None:
"""
from asgi_correlation_id.context import correlation_id

header_key = 'CORRELATION_ID'
sentry_extension = get_sentry_extension()

@before_task_publish.connect(weak=False)
Expand Down Expand Up @@ -46,7 +47,7 @@ def load_correlation_id(task: 'Task', **kwargs: Any) -> None:
correlation_id.set(id_value)
sentry_extension(id_value)
else:
generated_correlation_id = uuid4().hex
generated_correlation_id = generator()
correlation_id.set(generated_correlation_id)
sentry_extension(generated_correlation_id)

Expand All @@ -61,7 +62,11 @@ def cleanup(**kwargs: Any) -> None:
correlation_id.set(None)


def load_celery_current_and_parent_ids(header_key: str = 'CELERY_PARENT_ID') -> None:
def load_celery_current_and_parent_ids(
header_key: str = 'CELERY_PARENT_ID',
generator: Callable[[], str] = uuid_hex_generator,
use_internal_celery_task_id: bool = False,
) -> None:
"""
Configure Celery event hooks for generating tracing IDs with depth.
Expand All @@ -83,15 +88,16 @@ def publish_task_from_worker_or_request(headers: Dict[str, str], **kwargs: Any)
headers[header_key] = current

@task_prerun.connect(weak=False)
def worker_prerun(task: 'Task', **kwargs: Any) -> None:
def worker_prerun(task_id: str, task: 'Task', **kwargs: Any) -> None:
"""
Set current ID, and parent ID if it exists.
"""
parent_id = task.request.get(header_key)
if parent_id:
celery_parent_id.set(parent_id)

celery_current_id.set(uuid4().hex)
celery_id = task_id if use_internal_celery_task_id else generator()
celery_current_id.set(celery_id)

@task_postrun.connect(weak=False)
def clean_up(**kwargs: Any) -> None:
Expand Down
15 changes: 8 additions & 7 deletions asgi_correlation_id/log_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from logging import LogRecord


def _trim_string(string: Optional[str], string_length: Optional[int]) -> Optional[str]:
return string[:string_length] if string_length is not None and string else string


# Middleware


Expand All @@ -27,18 +31,15 @@ def filter(self, record: 'LogRecord') -> bool:
metadata.
"""
cid = correlation_id.get()
if self.uuid_length is not None and cid:
record.correlation_id = cid[: self.uuid_length]
else:
record.correlation_id = cid
record.correlation_id = _trim_string(cid, self.uuid_length)
return True


# Celery extension


class CeleryTracingIdsFilter(Filter):
def __init__(self, name: str = '', uuid_length: int = 32):
def __init__(self, name: str = '', uuid_length: Optional[int] = None):
super().__init__(name=name)
self.uuid_length = uuid_length

Expand All @@ -52,7 +53,7 @@ def filter(self, record: 'LogRecord') -> bool:
or from an endpoint, the parent ID will be None.
"""
pid = celery_parent_id.get()
record.celery_parent_id = pid[: self.uuid_length] if pid else pid
record.celery_parent_id = _trim_string(pid, self.uuid_length)
cid = celery_current_id.get()
record.celery_current_id = cid[: self.uuid_length] if cid else cid
record.celery_current_id = _trim_string(cid, self.uuid_length)
return True
61 changes: 55 additions & 6 deletions tests/test_log_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,23 @@ def cid():
@pytest.fixture()
def log_record():
"""Create and return an INFO-level log record"""
record = LogRecord(name='', level=INFO, pathname='', lineno=0, msg='Hello, world!', args=(), exc_info=None)
return record
return LogRecord(name='', level=INFO, pathname='', lineno=0, msg='Hello, world!', args=(), exc_info=None)


def test_filter_has_uuid_length_attributes():
    """The ``uuid_length`` given to the constructor is readable on the filter."""
    assert CorrelationIdFilter(uuid_length=8).uuid_length == 8


def test_filter_adds_correlation_id(cid, log_record):
def test_filter_adds_correlation_id(cid: str, log_record: LogRecord):
filter_ = CorrelationIdFilter()

assert not hasattr(log_record, 'correlation_id')
filter_.filter(log_record)
assert log_record.correlation_id == cid


def test_filter_truncates_correlation_id(cid, log_record):
def test_filter_truncates_correlation_id(cid: str, log_record: LogRecord):
filter_ = CorrelationIdFilter(uuid_length=8)

assert not hasattr(log_record, 'correlation_id')
Expand All @@ -49,7 +48,7 @@ def test_celery_filter_has_uuid_length_attributes():
assert filter_.uuid_length == 8


def test_celery_filter_adds_parent_id(cid, log_record):
def test_celery_filter_adds_parent_id(cid: str, log_record: LogRecord):
filter_ = CeleryTracingIdsFilter()
celery_parent_id.set('a')

Expand All @@ -58,10 +57,60 @@ def test_celery_filter_adds_parent_id(cid, log_record):
assert log_record.celery_parent_id == 'a'


def test_celery_filter_adds_current_id(cid, log_record):
def test_celery_filter_adds_current_id(cid: str, log_record: LogRecord):
filter_ = CeleryTracingIdsFilter()
celery_current_id.set('b')

assert not hasattr(log_record, 'celery_current_id')
filter_.filter(log_record)
assert log_record.celery_current_id == 'b'


@pytest.mark.parametrize(
    ('uuid_length', 'expected'),
    [
        (6, 6),
        (16, 16),
        (None, 36),
        (38, 36),
    ],
)
def test_celery_filter_truncates_current_id_correctly(cid: str, log_record: LogRecord, uuid_length, expected):
    """
    The current ID is cut to ``uuid_length`` characters when a length is given.

    With no length, or with one longer than the 36-character ``str(uuid4())``
    value, the whole ID is expected on the record.
    """
    full_id = str(uuid4())
    celery_current_id.set(full_id)
    tracing_filter = CeleryTracingIdsFilter(uuid_length=uuid_length)

    assert not hasattr(log_record, 'celery_current_id')
    tracing_filter.filter(log_record)
    assert log_record.celery_current_id == full_id[:expected]


def test_celery_filter_maintains_current_behavior(cid: str, log_record: LogRecord):
    """
    Check that a hex uuid comes through whole under the default settings.

    A filter built with no arguments and one built with ``uuid_length=32``
    must both leave the same, full ID on the record.
    """
    hex_id = uuid4().hex
    celery_current_id.set(hex_id)

    # Filter built with the default arguments.
    default_filter = CeleryTracingIdsFilter()
    assert not hasattr(log_record, 'celery_current_id')
    default_filter.filter(log_record)
    assert log_record.celery_current_id == hex_id
    id_from_default = log_record.celery_current_id

    # Clear the attribute so the second filter starts from an unmarked record.
    del log_record.celery_current_id

    # Filter built with an explicit length of 32.
    explicit_filter = CeleryTracingIdsFilter(uuid_length=32)
    assert not hasattr(log_record, 'celery_current_id')
    explicit_filter.filter(log_record)
    assert log_record.celery_current_id == hex_id
    id_from_explicit = log_record.celery_current_id

    assert id_from_explicit == id_from_default

0 comments on commit 42cf859

Please sign in to comment.