Skip to content

Commit

Permalink
Standardized the signatures for CorrelationIdFilter/CeleryTracingIdsF…
Browse files Browse the repository at this point in the history
…ilter and standardized the string trimming process for the filtering logic

* changed type from int to Optional[int] for optional parameter 'uuid_length' in CeleryTracingIdsFilter
* changed default value from 32 to None for optional parameter 'uuid_length' in CeleryTracingIdsFilter
* created and utilized function `_trim_string` to standardize string trimming logic in filters
* added test to test string trimming for CeleryTracingIdsFilter
* added test to ensure the default behavior of the new filter matched the behavior of the old filter.  This assumes default generators are used
  • Loading branch information
David Pryor (dapryor) committed Sep 28, 2022
1 parent 75d6587 commit ec96801
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 7 deletions.
20 changes: 13 additions & 7 deletions asgi_correlation_id/log_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@
from logging import LogRecord


def _trim_string(string: Optional[str], string_length: Optional[int]) -> Optional[str]:
trimmed_string: Optional[str]
if string_length is not None and string:
trimmed_string = string[:string_length]
else:
trimmed_string = string
return trimmed_string


# Middleware


Expand All @@ -27,18 +36,15 @@ def filter(self, record: 'LogRecord') -> bool:
metadata.
"""
cid = correlation_id.get()
if self.uuid_length is not None and cid:
record.correlation_id = cid[: self.uuid_length]
else:
record.correlation_id = cid
record.correlation_id = _trim_string(cid, self.uuid_length)
return True


# Celery extension


class CeleryTracingIdsFilter(Filter):
def __init__(self, name: str = '', uuid_length: int = 32):
def __init__(self, name: str = '', uuid_length: Optional[int] = None):
super().__init__(name=name)
self.uuid_length = uuid_length

Expand All @@ -52,7 +58,7 @@ def filter(self, record: 'LogRecord') -> bool:
or from an endpoint, the parent ID will be None.
"""
pid = celery_parent_id.get()
record.celery_parent_id = pid[: self.uuid_length] if pid else pid
record.celery_parent_id = _trim_string(pid, self.uuid_length)
cid = celery_current_id.get()
record.celery_current_id = cid[: self.uuid_length] if cid else cid
record.celery_current_id = _trim_string(cid, self.uuid_length)
return True
47 changes: 47 additions & 0 deletions tests/test_log_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,50 @@ def test_celery_filter_adds_current_id(cid, log_record):
assert not hasattr(log_record, 'celery_current_id')
filter_.filter(log_record)
assert log_record.celery_current_id == 'b'


def test_celery_filter_does_not_truncate_current_id(cid, log_record):
filter_ = CeleryTracingIdsFilter()
celery_id: str = str(uuid4())
celery_current_id.set(celery_id)

assert not hasattr(log_record, 'celery_current_id')
filter_.filter(log_record)
assert log_record.celery_current_id == celery_id


def test_celery_filter_maintains_current_behavior(cid, log_record):
"""Maintain default behavior with signature change
Since the default values of CeleryTracingIdsFilter are being changed,
the new default values should also not trim a hex uuid.
"""
celery_id: str = uuid4().hex
celery_current_id.set(celery_id)

new_filter = CeleryTracingIdsFilter()

assert not hasattr(log_record, 'celery_current_id')
new_filter.filter(log_record)
assert log_record.celery_current_id == celery_id
new_filter_record_id = log_record.celery_current_id

del log_record.celery_current_id

original_filter = CeleryTracingIdsFilter(uuid_length=32)
assert not hasattr(log_record, 'celery_current_id')
original_filter.filter(log_record)
assert log_record.celery_current_id == celery_id
original_filter_record_id = log_record.celery_current_id

assert original_filter_record_id == new_filter_record_id


def test_celery_filter_does_truncates_current_id(cid, log_record):
filter_ = CeleryTracingIdsFilter(uuid_length=16)
celery_id: str = uuid4().hex
celery_current_id.set(celery_id)

assert not hasattr(log_record, 'celery_current_id')
filter_.filter(log_record)
assert log_record.celery_current_id == celery_id[:16]

0 comments on commit ec96801

Please sign in to comment.