diff --git a/asgi_correlation_id/log_filters.py b/asgi_correlation_id/log_filters.py index ac60413..8e05dd2 100644 --- a/asgi_correlation_id/log_filters.py +++ b/asgi_correlation_id/log_filters.py @@ -7,6 +7,15 @@ from logging import LogRecord +def _trim_string(string: Optional[str], string_length: Optional[int]) -> Optional[str]: + trimmed_string: Optional[str] + if string_length is not None and string: + trimmed_string = string[:string_length] + else: + trimmed_string = string + return trimmed_string + + # Middleware @@ -27,10 +36,7 @@ def filter(self, record: 'LogRecord') -> bool: metadata. """ cid = correlation_id.get() - if self.uuid_length is not None and cid: - record.correlation_id = cid[: self.uuid_length] - else: - record.correlation_id = cid + record.correlation_id = _trim_string(cid, self.uuid_length) return True @@ -38,7 +44,7 @@ def filter(self, record: 'LogRecord') -> bool: class CeleryTracingIdsFilter(Filter): - def __init__(self, name: str = '', uuid_length: int = 32): + def __init__(self, name: str = '', uuid_length: Optional[int] = None): super().__init__(name=name) self.uuid_length = uuid_length @@ -52,7 +58,7 @@ def filter(self, record: 'LogRecord') -> bool: or from an endpoint, the parent ID will be None. """ pid = celery_parent_id.get() - record.celery_parent_id = pid[: self.uuid_length] if pid else pid + record.celery_parent_id = _trim_string(pid, self.uuid_length) cid = celery_current_id.get() - record.celery_current_id = cid[: self.uuid_length] if cid else cid + record.celery_current_id = _trim_string(cid, self.uuid_length) return True diff --git a/tests/test_log_filter.py b/tests/test_log_filter.py index cd09485..99190fc 100644 --- a/tests/test_log_filter.py +++ b/tests/test_log_filter.py @@ -65,3 +65,50 @@ def test_celery_filter_adds_current_id(cid, log_record): assert not hasattr(log_record, 'celery_current_id') filter_.filter(log_record) assert log_record.celery_current_id == 'b' + + +def test_celery_filter_does_not_truncate_current_id(cid, log_record): + filter_ = CeleryTracingIdsFilter() + celery_id: str = str(uuid4()) + celery_current_id.set(celery_id) + + assert not hasattr(log_record, 'celery_current_id') + filter_.filter(log_record) + assert log_record.celery_current_id == celery_id + + +def test_celery_filter_maintains_current_behavior(cid, log_record): + """Maintain default behavior with signature change + + Since the default values of CeleryTracingIdsFilter are being changed, + the new default values should also not trim a hex uuid. + """ + celery_id: str = uuid4().hex + celery_current_id.set(celery_id) + + new_filter = CeleryTracingIdsFilter() + + assert not hasattr(log_record, 'celery_current_id') + new_filter.filter(log_record) + assert log_record.celery_current_id == celery_id + new_filter_record_id = log_record.celery_current_id + + del log_record.celery_current_id + + original_filter = CeleryTracingIdsFilter(uuid_length=32) + assert not hasattr(log_record, 'celery_current_id') + original_filter.filter(log_record) + assert log_record.celery_current_id == celery_id + original_filter_record_id = log_record.celery_current_id + + assert original_filter_record_id == new_filter_record_id + + +def test_celery_filter_does_truncates_current_id(cid, log_record): + filter_ = CeleryTracingIdsFilter(uuid_length=16) + celery_id: str = uuid4().hex + celery_current_id.set(celery_id) + + assert not hasattr(log_record, 'celery_current_id') + filter_.filter(log_record) + assert log_record.celery_current_id == celery_id[:16]