Skip to content

Commit

Permalink
implement jwk fingerprinting for dpop
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidBuchanan314 committed Jan 20, 2025
1 parent 32ea4a6 commit 7e5de4c
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 29 deletions.
5 changes: 3 additions & 2 deletions src/millipds/auth_bearer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ async def authentication_handler(request: web.Request, *args, **kwargs):
text="authentication required (this may be a bug, I'm erring on the side of caution for now)"
)
authtype, _, token = auth.partition(" ")
if authtype not in ["Bearer", "DPoP"]:
authtype = authtype.lower()
if authtype not in ["bearer", "dpop"]:
raise web.HTTPUnauthorized(text="invalid auth type")

# validate it TODO: this needs rigorous testing, I'm not 100% sure I'm
Expand All @@ -85,7 +86,7 @@ async def authentication_handler(request: web.Request, *args, **kwargs):
token, options={"verify_signature": False}
)
# logger.info(unverified)
if authtype == "DPoP":
if authtype == "dpop":
# TODO: dpop stuff!!!!!
pass

Expand Down
59 changes: 33 additions & 26 deletions src/millipds/auth_oauth.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Tuple
import logging

import jwt
Expand All @@ -18,6 +19,7 @@
from . import static_config
from . import util
from .util import definitely, NoneError
from . import crypto

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -143,17 +145,18 @@ def pretty_error_page(msg: str) -> web.HTTPBadRequest:
)


def get_auth_request(request: web.Request) -> dict:
def get_auth_request(request: web.Request) -> Tuple[dict, bytes]:
"""
pull a previously PAR'd auth request from db, and check it isn't expired.
reads request_uri query parameter.
Also returns the DPoP jwk thumbprint it was created with.
"""

try:
value, expires_at = definitely(
value, dpop_jkt, expires_at = definitely(
get_db(request)
.con.execute(
"SELECT value, expires_at FROM oauth_par WHERE uri=?",
"SELECT value, dpop_jkt, expires_at FROM oauth_par WHERE uri=?",
(request.query.get("request_uri"),),
)
.fetchone()
Expand All @@ -165,7 +168,7 @@ def get_auth_request(request: web.Request) -> dict:
if expires_at < time.time():
raise pretty_error_page("authorization request expired. try again?")

return cbrrr.decode_dag_cbor(value)
return cbrrr.decode_dag_cbor(value), dpop_jkt


def get_or_initiate_oauth_session(request: web.Request, login_hint: str) -> int:
Expand Down Expand Up @@ -210,7 +213,7 @@ async def oauth_authorize_get(request: web.Request):

client_id_param = request.query.get("client_id")

authorization_request = get_auth_request(request)
authorization_request, dpop_jkt = get_auth_request(request)
logger.info(authorization_request)

login_hint = authorization_request.get("login_hint", "")
Expand Down Expand Up @@ -273,30 +276,31 @@ async def oauth_authorize_get(request: web.Request):
# else, everything checks out.
# generate the auth tokens, encrypt them into the auth code, and redirect the user back to the app!

unix_seconds_now = int(time.time())
# use the same jti for both tokens, so revoking one revokes both
jti = str(uuid.uuid4())
payload_common = {
"aud": db.config["pds_did"],
"sub": did,
"iat": int(time.time()),
"jti": str(uuid.uuid4()),
"cnf": {
"jkt": dpop_jkt,
},
}
access_jwt = jwt.encode(
{
"scope": "com.atproto.access",
"aud": db.config["pds_did"],
"sub": did,
"iat": unix_seconds_now,
"exp": unix_seconds_now + static_config.ACCESS_EXP,
"jti": jti,
payload_common
| {
"scope": authorization_request["scope"],
"exp": payload_common["iat"] + static_config.ACCESS_EXP,
},
db.config["jwt_access_secret"],
"HS256",
)

refresh_jwt = jwt.encode(
{
"scope": "com.atproto.refresh",
"aud": db.config["pds_did"],
"sub": did,
"iat": unix_seconds_now,
"exp": unix_seconds_now + static_config.REFRESH_EXP,
"jti": jti,
payload_common
| {
"scope": "TODO make refresh work",
"exp": payload_common["iat"] + static_config.REFRESH_EXP,
},
db.config["jwt_access_secret"],
"HS256",
Expand All @@ -309,6 +313,7 @@ async def oauth_authorize_get(request: web.Request):
"code_challenge_method": authorization_request[
"code_challenge_method"
],
"dpop_jkt": dpop_jkt,
"token_response": {
"access_token": access_jwt,
"token_type": "DPoP",
Expand Down Expand Up @@ -426,6 +431,7 @@ async def dpop_handler(request: web.Request):
)
jwk_data = unverified["header"]["jwk"]
jwk = jwt.PyJWK.from_dict(jwk_data)
jkt = crypto.jwk_thumbprint(jwk)

# actual signature verification happens here:
decoded: dict = jwt.decode(dpop, key=jwk)
Expand Down Expand Up @@ -458,9 +464,7 @@ async def dpop_handler(request: web.Request):
}, # if we don't put it here, the client will never see it
)

request["dpop_jwk"] = cbrrr.encode_dag_cbor(
jwk_data
) # for easy comparison in db etc.
request["dpop_jkt"] = jkt
request["dpop_jti"] = decoded[
"jti"
] # XXX: should replay prevention happen here?
Expand Down Expand Up @@ -488,6 +492,9 @@ async def oauth_token_post(request: web.Request):
)
logger.info(code_payload)

if request.get("dpop_jkt") != code_payload["dpop_jkt"]:
return web.HTTPBadRequest(text="dpop required")

if code_payload["code_challenge_method"] != "S256":
return web.HTTPBadRequest(text="bad code_challenge_method")

Expand Down Expand Up @@ -540,12 +547,12 @@ async def oauth_pushed_authorization_request(request: web.Request):
get_db(request).con.execute(
"""
INSERT INTO oauth_par (
uri, dpop_jwk, value, created_at, expires_at
uri, dpop_jkt, value, created_at, expires_at
) VALUES (?, ?, ?, ?, ?)
""",
(
par_uri,
request["dpop_jwk"],
request["dpop_jkt"],
cbrrr.encode_dag_cbor(request_json),
now,
now + static_config.OAUTH_PAR_EXP,
Expand Down
41 changes: 41 additions & 0 deletions src/millipds/crypto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import Literal
import base64
import json
import hashlib

import jwt

from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import hashes
Expand Down Expand Up @@ -125,3 +129,40 @@ def plc_sign(privkey: ec.EllipticCurvePrivateKey, op: dict) -> str:
raise ValueError("op is already signed!")
rawsig = raw_sign(privkey, cbrrr.encode_dag_cbor(op))
return base64.urlsafe_b64encode(rawsig).decode().rstrip("=")


# in lexicographic order as described in rfc7638
JWK_REQUIRED_MEMBERS = {
"EC": ("crv", "kty", "x", "y"),
"RSA": ("e", "kty", "n"),
"oct": ("k", "kty"),
}


def jwk_thumbprint(jwk: jwt.PyJWK) -> str:
jwk_dict = jwk.Algorithm.to_jwk(jwk.key, as_dict=True)
members = JWK_REQUIRED_MEMBERS.get(jwk.key_type)
if members is None:
raise jwt.exceptions.PyJWKError(
f"I don't know how to canonicalize key type {jwk.key_type}"
)
json_bytes = json.dumps(
{k: jwk_dict[k] for k in members},
separators=(",", ":"),
).encode()
json_hash = hashlib.sha256(json_bytes).digest()
return base64.urlsafe_b64encode(json_hash).rstrip(b"=").decode()


if __name__ == "__main__":
# rfc7638 test vector
test_key = jwt.PyJWK.from_dict(
{
"kty": "RSA",
"n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
"e": "AQAB",
"alg": "RS256",
"kid": "2011-04-29",
}
)
print(jwk_thumbprint(test_key))
3 changes: 2 additions & 1 deletion src/millipds/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def _init_tables(self):
"""
CREATE TABLE oauth_par(
uri TEXT PRIMARY KEY NOT NULL,
dpop_jwk BLOB NOT NULL,
dpop_jkt TEXT NOT NULL,
value BLOB NOT NULL,
created_at INTEGER NOT NULL,
expires_at INTEGER NOT NULL
Expand All @@ -297,6 +297,7 @@ def _init_tables(self):
user_id INTEGER NOT NULL,
client_id TEXT NOT NULL,
scope TEXT NOT NULL,
granted_at INTEGER NOT NULL,
FOREIGN KEY (user_id) REFERENCES user(id),
PRIMARY KEY (user_id, client_id, scope)
) STRICT, WITHOUT ROWID
Expand Down

0 comments on commit 7e5de4c

Please sign in to comment.