Skip to content

Commit

Permalink
clean up dpop error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidBuchanan314 committed Feb 9, 2025
1 parent 325e386 commit b33ff57
Showing 1 changed file with 110 additions and 101 deletions.
211 changes: 110 additions & 101 deletions src/millipds/auth_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,19 @@
# ciphertexts have short TTL so the key does not need to persist.
code_fernet = Fernet(Fernet.generate_key())

DPOP_SIGNING_ALG_SUPPORTED = [
"RS256",
"RS384",
"RS512",
"PS256",
"PS384",
"PS512",
"ES256",
"ES256K",
"ES384",
"ES512",
]


# example: https://shiitake.us-east.host.bsky.network/.well-known/oauth-protected-resource
@routes.get("/.well-known/oauth-protected-resource")
Expand Down Expand Up @@ -121,18 +134,7 @@ async def oauth_authorization_server(request: web.Request):
"introspection_endpoint": pfx + "/oauth/introspect",
"pushed_authorization_request_endpoint": pfx + "/oauth/par",
"require_pushed_authorization_requests": True,
"dpop_signing_alg_values_supported": [
"RS256",
"RS384",
"RS512",
"PS256",
"PS384",
"PS512",
"ES256",
"ES256K",
"ES384",
"ES512",
],
"dpop_signing_alg_values_supported": DPOP_SIGNING_ALG_SUPPORTED,
"client_id_metadata_document_supported": True,
}
)
Expand Down Expand Up @@ -560,11 +562,11 @@ def validate_dpop_nonce_and_extract_jti_and_exp(
) -> Tuple[str, int]:
payload = jwt.decode(
nonce_jwt,
get_db(request).config["jwt_access_secret"] + ":dpop_nonce",
"HS256",
key=get_db(request).config["jwt_access_secret"] + ":dpop_nonce",
algorithms=["HS256"],
options={
"require": ["exp"] # pyjwt will verify if present
}
"require": ["exp"] # pyjwt will verify if present
},
)
return payload["jti"], payload["exp"]

Expand All @@ -575,105 +577,112 @@ async def dpop_middlware(request: web.Request, handler) -> web.Response:
if (dpop := request.headers.get("dpop")) is None:
return await handler(request)

# we're not verifying yet, we just want to pull out the jwk from the header
unverified = jwt.api_jwt.decode_complete(
dpop, options={"verify_signature": False}
)
jwk_data = unverified["header"]["jwk"]
jwk = jwt.PyJWK.from_dict(jwk_data)
jkt = crypto.jwk_thumbprint(jwk)
try:
# we're not verifying yet, we just want to pull out the jwk from the header
unverified = jwt.api_jwt.decode_complete(
dpop, options={"verify_signature": False}
)
jwk_data = unverified["header"]["jwk"]
jwk = jwt.PyJWK.from_dict(jwk_data)
jkt = crypto.jwk_thumbprint(jwk)

# actual signature verification happens here:
try:
# TODO: be explicit about what algorithms we support???
decoded: dict = jwt.decode(
dpop, key=jwk, algorithms=DPOP_SIGNING_ALG_SUPPORTED
)
except jwt.DecodeError:
raise util.atproto_json_http_error(
web.HTTPBadRequest,
"invalid_dpop_proof",
"failed to verify dpop proof signature",
)

# actual signature verification happens here:
decoded: dict = jwt.decode(dpop, key=jwk)
logger.info(decoded)
logger.info(request.url)

logger.info(decoded)
logger.info(request.url)
# TODO: verify iat?

# TODO: verify iat?
if request.method != decoded["htm"]:
raise web.HTTPBadRequest(text="dpop: bad htm")

if request.method != decoded["htm"]:
raise web.HTTPBadRequest(text="dpop: bad htm")
if str(request.url) != decoded["htu"]:
logger.info(f"{request.url!r} != {decoded['htu']!r}")
raise web.HTTPBadRequest(
text="dpop: bad htu (if your application is reverse-proxied, make sure the Host header is getting set properly)"
)

if str(request.url) != decoded["htu"]:
logger.info(f"{request.url!r} != {decoded['htu']!r}")
raise web.HTTPBadRequest(
text="dpop: bad htu (if your application is reverse-proxied, make sure the Host header is getting set properly)"
)
if "nonce" not in decoded:
raise util.atproto_json_http_error(
web.HTTPBadRequest,
"use_dpop_nonce",
"Authorization server requires nonce in DPoP proof",
)

if "nonce" not in decoded:
res = util.atproto_json_http_error(
web.HTTPBadRequest,
"use_dpop_nonce",
"Authorization server requires nonce in DPoP proof",
)
res.headers["DPoP-Nonce"] = generate_dpop_nonce(request)
raise res
try:
nonce_jti, nonce_exp = validate_dpop_nonce_and_extract_jti_and_exp(
request, decoded["nonce"]
)
except jwt.ExpiredSignatureError:
raise util.atproto_json_http_error(
web.HTTPBadRequest,
"invalid_dpop_proof",
"expired nonce",
)
except jwt.DecodeError:
raise util.atproto_json_http_error(
web.HTTPBadRequest,
"invalid_dpop_proof",
"invalid nonce",
)

try:
nonce_jti, nonce_exp = validate_dpop_nonce_and_extract_jti_and_exp(
request,
decoded["nonce"]
)
except jwt.ExpiredSignatureError:
res = util.atproto_json_http_error(
web.HTTPBadRequest,
"invalid_dpop_proof",
"expired nonce",
db = get_db(request)
# note: we don't need to check for nonce expiry in this query because
# validate_dpop_nonce_and_extract_jti already checked it
is_replay = db.con.execute(
"SELECT COUNT(*) FROM dpop_replay WHERE dpop_jti=? AND nonce_jti=?",
(decoded["jti"], nonce_jti),
).get
if is_replay:
# TODO: is this the right kind of error?
raise util.atproto_json_http_error(
web.HTTPBadRequest,
"invalid_dpop_proof",
"you used that one before?!",
)

# store this one so it can't be replayed later
db.con.execute(
"""
INSERT INTO dpop_replay (dpop_jti, nonce_jti, nonce_expires_at)
VALUES (?, ?, ?)
""",
(decoded["jti"], nonce_jti, nonce_exp),
)
res.headers["DPoP-Nonce"] = generate_dpop_nonce(request)
raise res
except jwt.DecodeError:
res = util.atproto_json_http_error(
web.HTTPBadRequest,
"invalid_dpop_proof",
"invalid nonce",

# drop expired nonces from the replay table
db.con.execute(
"DELETE FROM dpop_replay WHERE nonce_expires_at<?",
(int(time.time()),),
)
res.headers["DPoP-Nonce"] = generate_dpop_nonce(request)
raise res

db = get_db(request)
# note: we don't need to check for nonce expiry in this query because
# validate_dpop_nonce_and_extract_jti already checked it
is_replay = db.con.execute(
"SELECT COUNT(*) FROM dpop_replay WHERE dpop_jti=? AND nonce_jti=?",
(decoded["jti"], nonce_jti),
).get
if is_replay:
# TODO: is this the right kind of error?
res = util.atproto_json_http_error(
web.HTTPBadRequest,
"invalid_dpop_proof",
"you used that one before?!",
request["verified_dpop_jkt"] = (
jkt # certifies that the dpop is valid for this particular jkt
)
request["dpop_jti"] = decoded[
"jti"
] # do we really need to pass this thru?
request["dpop_iss"] = decoded["iss"]
# TODO: store dpop_ath and check it during authorization

res: web.Response = await handler(request)
res.headers["DPoP-Nonce"] = generate_dpop_nonce(request)
return res
except web.HTTPError as res:
res.headers["DPoP-Nonce"] = generate_dpop_nonce(request)
raise res

# store this one so it can't be replayed later
db.con.execute(
"""
INSERT INTO dpop_replay (dpop_jti, nonce_jti, nonce_expires_at)
VALUES (?, ?, ?)
""",
(decoded["jti"], nonce_jti, nonce_exp),
)

# drop expired nonces from the replay table
db.con.execute(
"DELETE FROM dpop_replay WHERE nonce_expires_at<?", (int(time.time()),)
)

request["verified_dpop_jkt"] = (
jkt # certifies that the dpop is valid for this particular jkt
)
request["dpop_jti"] = decoded["jti"] # do we really need to pass this thru?
request["dpop_iss"] = decoded["iss"]

res: web.Response = await handler(request)
# TODO: make sure this always gets set even under error conditions?
# do we need to try/catch and re-raise or does aiohttp do something like that for us?
res.headers["DPoP-Nonce"] = generate_dpop_nonce(request)
return res


@routes.get("/xrpc/com.atproto.server.listAppPasswords")
@auth_required({"transition:generic"})
Expand Down

0 comments on commit b33ff57

Please sign in to comment.