Skip to content

Commit

Permalink
Merge pull request #39 from snok/configuration
Browse files Browse the repository at this point in the history
Make middleware fully configurable
  • Loading branch information
sondrelg authored May 16, 2022
2 parents 0c6fd1e + 24b5530 commit d4d06da
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 89 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ repos:
]

- repo: https://github.com/sirosen/check-jsonschema
rev: 0.14.3
rev: 0.15.0
hooks:
- id: check-github-actions
- id: check-github-workflows

- repo: https://github.com/asottile/pyupgrade
rev: v2.32.0
rev: v2.32.1
hooks:
- id: pyupgrade
args: [ "--py36-plus" ]
Expand Down
66 changes: 50 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

# ASGI Correlation ID middleware

Middleware for loading or generating correlation IDs for each incoming request. Correlation IDs can be added to your
Middleware for reading or generating correlation IDs for each incoming request. Correlation IDs can then be added to your
logs, making it simple to retrieve all logs generated from a single HTTP request.

When the middleware detects a correlation ID HTTP header in an incoming request, the ID is stored. If no header is
Expand Down Expand Up @@ -64,21 +64,7 @@ app.add_middleware(CorrelationIdMiddleware)

or any other way your framework allows.

For [Starlette](https://github.com/encode/starlette) apps, just substitute `FastAPI` with `Starlette` in the example
above.

The middleware only has two settings, and can be configured like this:

```python
app.add_middleware(
CorrelationIdMiddleware,
# The HTTP header key to read IDs from.
header_name='X-Request-ID',
# Enforce UUID formatting to limit chance of collisions
# - Invalid header values are discarded, and an ID is generated in its place
validate_header_as_uuid=True
)
```
For [Starlette](https://github.com/encode/starlette) apps, just substitute `FastAPI` with `Starlette` in all examples.

## Configure logging

Expand Down Expand Up @@ -154,6 +140,54 @@ LOGGING = {

If you're using a json log-formatter, just add `correlation-id: %(correlation_id)s` to your list of properties.

## Middleware configuration

The middleware can be configured in a few ways, but there are no required arguments.

```python
app.add_middleware(
CorrelationIdMiddleware,
header_name='X-Request-ID',
generator=lambda: uuid4().hex,
validator=is_valid_uuid4,
transformer=lambda a: a,
)
```

Configurable middleware arguments include:

**header_name**

- Type: `str`
- Default: `X-Request-ID`
- Description: The header name decides which HTTP header value to read correlation IDs from. `X-Request-ID` and
`X-Correlation-ID` are common choices.

**generator**

- Type: `Callable[[], str]`
- Default: `lambda: uuid4().hex`
- Description: The generator function is responsible for generating new correlation IDs when no ID is received from an
incoming request's headers. We use UUIDs by default, but if you prefer, you could use libraries
like [nanoid](https://github.com/puyuan/py-nanoid) or your own custom function.

**validator**

- Type: `Callable[[str], bool]`
- Default: `is_valid_uuid` (
found [here](https://github.com/snok/asgi-correlation-id/blob/main/asgi_correlation_id/validators.py))
- Description: The validator function is used when reading incoming HTTP header values. By default, we discard non-UUID
formatted header values, to enforce correlation ID uniqueness. If you prefer to allow any header value, you can set
this setting to `None`, or pass your own validator.

**transformer**

- Type: `Callable[[str], str]`
- Default: `lambda a: a`
- Description: Most users won't need a transformer, and by default we do nothing.
The argument was added for cases where users might want to alter incoming or generated ID values in some way. It
provides a mechanism for transforming an incoming ID in a way you see fit. See the middleware code for more context.

## Exception handling

By default, the `X-Correlation-ID` and `Access-Control-Expose-Headers` response headers will be included in all
Expand Down
35 changes: 27 additions & 8 deletions asgi_correlation_id/middleware.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Optional
from uuid import UUID, uuid4

from starlette.datastructures import Headers, MutableHeaders
Expand All @@ -14,7 +14,7 @@
logger = logging.getLogger('asgi_correlation_id')


def is_valid_uuid(uuid_: str) -> bool:
def is_valid_uuid4(uuid_: str) -> bool:
"""
Check whether a string is a valid v4 uuid.
"""
Expand All @@ -24,11 +24,22 @@ def is_valid_uuid(uuid_: str) -> bool:
return False


FAILED_VALIDATION_MESSAGE = 'Generated new request ID (%s), since request header value failed validation'


@dataclass
class CorrelationIdMiddleware:
app: 'ASGIApp'
header_name: str = 'X-Request-ID'
validate_header_as_uuid: bool = True

# ID-generating callable
generator: Callable[[], str] = field(default=lambda: uuid4().hex)

# ID validator
validator: Optional[Callable[[str], bool]] = field(default=is_valid_uuid4)

# ID transformer - can be used to clean/mutate IDs
transformer: Optional[Callable[[str], str]] = field(default=lambda a: a)

async def __call__(self, scope: 'Scope', receive: 'Receive', send: 'Send') -> None:
"""
Expand All @@ -38,16 +49,24 @@ async def __call__(self, scope: 'Scope', receive: 'Receive', send: 'Send') -> No
await self.app(scope, receive, send)
return

# Try to load request ID from the request headers
header_value = Headers(scope=scope).get(self.header_name.lower())

if not header_value:
id_value = uuid4().hex
elif self.validate_header_as_uuid and not is_valid_uuid(header_value):
logger.warning('Generating new UUID, since header value \'%s\' is invalid', header_value)
id_value = uuid4().hex
# Generate request ID if none was found
id_value = self.generator()
elif self.validator and not self.validator(header_value):
# Also generate a request ID if one was found, but it was deemed invalid
id_value = self.generator()
logger.warning(FAILED_VALIDATION_MESSAGE, header_value)
else:
# Otherwise, use the found request ID
id_value = header_value

# Clean/change the ID if needed
if self.transformer:
id_value = self.transformer(id_value)

correlation_id.set(id_value)
self.sentry_extension(id_value)

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "asgi-correlation-id"
version = "2.0.0"
version = "3.0.0a1"
description = "Middleware correlating project logs to individual requests"
authors = ["Sondre Lillebø Gundersen <[email protected]>"]
maintainers = ["Jonas Krüger Svensson <[email protected]>"]
Expand Down Expand Up @@ -74,7 +74,7 @@ build-backend = "poetry.masonry.api"
quiet = true
line-length = 120
skip-string-normalization = true
experimental-string-processing = true
preview = true

[tool.isort]
profile = "black"
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[tool:pytest]
testpaths = tests
asyncio_mode = auto

[flake8]
max-line-length = 120
Expand Down
14 changes: 11 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from logging.config import dictConfig

import pytest
import pytest_asyncio
from fastapi import FastAPI
from httpx import AsyncClient
from starlette.middleware import Middleware
Expand Down Expand Up @@ -44,7 +45,14 @@ def _configure_logging():
dictConfig(LOGGING)


app = FastAPI(middleware=[Middleware(CorrelationIdMiddleware)])
TRANSFORMER_VALUE = 'some-id'

default_app = FastAPI(middleware=[Middleware(CorrelationIdMiddleware)])
no_validator_or_transformer_app = FastAPI(
middleware=[Middleware(CorrelationIdMiddleware, validator=None, transformer=None)]
)
transformer_app = FastAPI(middleware=[Middleware(CorrelationIdMiddleware, transformer=lambda a: a * 2)])
generator_app = FastAPI(middleware=[Middleware(CorrelationIdMiddleware, generator=lambda: TRANSFORMER_VALUE)])


@pytest.fixture(scope='session', autouse=True)
Expand All @@ -54,7 +62,7 @@ def event_loop():
loop.close()


@pytest.fixture(scope='module')
@pytest_asyncio.fixture(scope='module')
async def client() -> AsyncClient:
async with AsyncClient(app=app, base_url='http://test') as client:
async with AsyncClient(app=default_app, base_url='http://test') as client:
yield client
4 changes: 2 additions & 2 deletions tests/test_extension_celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from celery import shared_task

from asgi_correlation_id.extensions.celery import load_celery_current_and_parent_ids, load_correlation_ids
from tests.conftest import app
from tests.conftest import default_app

logger = logging.getLogger('asgi_correlation_id')

Expand Down Expand Up @@ -40,7 +40,7 @@ async def test_endpoint_to_worker_to_worker(client, caplog, celery_session_app,
- The current ID of the first worker to be added as the parent ID of the second worker
"""

@app.get('/celery-test', status_code=200)
@default_app.get('/celery-test', status_code=200)
async def test_view() -> dict:
logger.debug('Test view')
task1.delay().get(timeout=10)
Expand Down
Loading

0 comments on commit d4d06da

Please sign in to comment.