Skip to content

Commit

Permalink
Start working on SPAKE2+
Browse files Browse the repository at this point in the history
  • Loading branch information
tannewt committed Jul 24, 2024
1 parent aa1b3fb commit 5f14782
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 2 deletions.
181 changes: 180 additions & 1 deletion circuitmatter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@

import binascii
import enum
import hashlib
import hmac
import pathlib
import json
import os
import pathlib
import struct
import time
from ecdsa.ellipticcurve import Point
from ecdsa.curves import NIST256p

from typing import Optional

Expand Down Expand Up @@ -192,6 +197,28 @@ class PBKDFParamResponse(tlv.TLVStructure):
responderSessionParams = tlv.StructMember(5, SessionParameterStruct, optional=True)


CRYPTO_GROUP_SIZE_BITS = 256
CRYPTO_GROUP_SIZE_BYTES = 32
CRYPTO_PUBLIC_KEY_SIZE_BYTES = (2 * CRYPTO_GROUP_SIZE_BYTES) + 1

CRYPTO_HASH_LEN_BITS = 256
CRYPTO_HASH_LEN_BYTES = 32
CRYPTO_HASH_BLOCK_LEN_BYTES = 64


class PAKE1(tlv.TLVStructure):
pA = tlv.OctetStringMember(1, CRYPTO_PUBLIC_KEY_SIZE_BYTES)


class PAKE2(tlv.TLVStructure):
pB = tlv.OctetStringMember(1, CRYPTO_PUBLIC_KEY_SIZE_BYTES)
cB = tlv.OctetStringMember(2, CRYPTO_HASH_LEN_BYTES)


class PAKE3(tlv.TLVStructure):
cA = tlv.OctetStringMember(1, CRYPTO_HASH_LEN_BYTES)


class MessageReceptionState:
def __init__(self, starting_value, rollover=True, encrypted=False):
"""Implements 4.6.5.1"""
Expand Down Expand Up @@ -824,6 +851,135 @@ def process_exchange(self, message):
return exchange


M = Point.from_bytes(
NIST256p.curve,
b"\x02\x88\x6e\x2f\x97\xac\xe4\x6e\x55\xba\x9d\xd7\x24\x25\x79\xf2\x99\x3b\x64\xe1\x6e\xf3\xdc\xab\x95\xaf\xd4\x97\x33\x3d\x8f\xa1\x2f",
)
N = Point.from_bytes(
NIST256p.curve,
b"\x03\xd8\xbb\xd6\xc6\x39\xc6\x29\x37\xb0\x4d\x99\x7f\x38\xc3\x77\x07\x19\xc6\x29\xd7\x01\x4d\x49\xa2\x4b\x4f\x98\xba\xa1\x29\x2b\x49",
)
CRYPTO_W_SIZE_BYTES = CRYPTO_GROUP_SIZE_BYTES + 8


def _pbkdf2(passcode, salt, iterations):
ws = hashlib.pbkdf2_hmac(
"sha256", struct.pack("<I", passcode), salt, iterations, CRYPTO_W_SIZE_BYTES * 2
)
w0 = int.from_bytes(ws[:CRYPTO_W_SIZE_BYTES], byteorder="big") % NIST256p.order
w1 = int.from_bytes(ws[CRYPTO_W_SIZE_BYTES:], byteorder="big") % NIST256p.order
return w0, w1


def initiator_values(passcode, salt, iterations) -> tuple[bytes, bytes]:
w0, w1 = _pbkdf2(passcode, salt, iterations)
return w0.to_bytes(NIST256p.baselen, byteorder="big"), w1.to_bytes(
NIST256p.baselen, byteorder="big"
)


def verifier_values(passcode: int, salt: bytes, iterations: int) -> tuple[bytes, bytes]:
w0, w1 = _pbkdf2(passcode, salt, iterations)
L = NIST256p.generator * w1

return w0.to_bytes(NIST256p.baselen, byteorder="big"), L.to_bytes("uncompressed")


# w0 and w1 are big-endian encoded
def Crypto_pA(w0, w1) -> bytes:
return b""


def Crypto_pB(w0: bytes, L: bytes) -> bytes:
return b""


def Crypto_Transcript(context, pA, pB, Z, V, w0) -> bytes:
elements = [
context,
b"",
b"",
M.to_bytes("uncompressed"),
N.to_bytes("uncompressed"),
pA,
pB,
Z,
V,
w0,
]
total_length = 0
for e in elements:
total_length += len(e) + 8
tt = bytearray(total_length)
offset = 0
for e in elements:
struct.pack_into("<Q", tt, offset, len(e))
offset += 8

tt[offset : offset + len(e)] = e
offset += len(e)
return tt


def Crypto_Hash(message) -> bytes:
return hashlib.sha256(message).digest()


def Crypto_HMAC(key, message) -> bytes:
m = hmac.new(key, digestmod=hashlib.sha256)
m.update(message)
return m.digest()


def HKDF_Extract(salt, input_key) -> bytes:
return Crypto_HMAC(salt, input_key)


def HKDF_Expand(prk, info, length) -> bytes:
if length > 255:
raise ValueError("length must be less than 256")
last_hash = b""
bytes_generated = []
num_bytes_generated = 0
i = 1
while num_bytes_generated < length:
num_bytes_generated += CRYPTO_HASH_LEN_BYTES
# Do the hmac directly so we don't need to allocate a buffer for last_hash + info + i.
m = hmac.new(prk, digestmod=hashlib.sha256)
m.update(last_hash)
m.update(info)
m.update(struct.pack("b", i))
last_hash = m.digest()
bytes_generated.append(last_hash)
i += 1
return b"".join(bytes_generated)


def Crypto_KDF(input_key, salt, info, length):
if salt is None:
salt = b"\x00" * CRYPTO_HASH_LEN_BYTES
return HKDF_Expand(HKDF_Extract(salt, input_key), info, length / 8)


def KDF(salt, key, info):
# Section 3.10 defines the mapping from KDF to Crypto_KDF but it is wrong!
# The arg order is correct above.
return Crypto_KDF(key, salt, info, CRYPTO_HASH_LEN_BITS)


def Crypto_P2(tt, pA, pB) -> tuple[bytes, bytes, bytes]:
KaKe = Crypto_Hash(tt)
Ka = KaKe[: CRYPTO_HASH_LEN_BYTES // 2]
Ke = KaKe[CRYPTO_HASH_LEN_BYTES // 2 :]
# https://github.com/project-chip/connectedhomeip/blob/c88d5cf83cd3e3323ac196630acc34f196a2f405/src/crypto/CHIPCryptoPAL.cpp#L458-L468
KcAKcB = KDF(None, Ka, b"ConfirmationKeys")
KcA = KcAKcB[:CRYPTO_GROUP_SIZE_BYTES]
KcB = KcAKcB[CRYPTO_GROUP_SIZE_BYTES:]
cA = Crypto_HMAC(KcA, pB)
cB = Crypto_HMAC(KcB, pA)
return (cA, cB, Ke)


class CircuitMatter:
def __init__(self, socketpool, mdns_server, state_filename, record_to=None):
self.socketpool = socketpool
Expand All @@ -837,7 +993,7 @@ def __init__(self, socketpool, mdns_server, state_filename, record_to=None):
with open(state_filename, "r") as state_file:
self.nonvolatile = json.load(state_file)

for key in ["descriminator", "salt", "iteration-count"]:
for key in ["descriminator", "salt", "iteration-count", "verifier"]:
if key not in self.nonvolatile:
raise RuntimeError(f"Missing key {key} in state file")

Expand Down Expand Up @@ -946,6 +1102,10 @@ def process_packet(self, address, data):
print("Received PBKDF Parameter Request")
# This is Section 4.14.1.2
request = PBKDFParamRequest(message.application_payload[1:-1])
exchange.commissioning_hash = hashlib.sha256(
b"CHIP PAKE V1 Commissioning"
)
exchange.commissioning_hash.update(message.application_payload)
if request.passcodeId == 0:
pass
# Send back failure
Expand All @@ -965,6 +1125,7 @@ def process_packet(self, address, data):
params.iterations = self.nonvolatile["iteration-count"]
params.salt = binascii.a2b_base64(self.nonvolatile["salt"])
response.pbkdf_parameters = params
exchange.commissioning_hash.update(response.encode())
exchange.send(
ProtocolId.SECURE_CHANNEL,
SecureProtocolOpcode.PBKDF_PARAM_RESPONSE,
Expand All @@ -975,6 +1136,24 @@ def process_packet(self, address, data):
print("Received PBKDF Parameter Response")
elif protocol_opcode == SecureProtocolOpcode.PASE_PAKE1:
print("Received PASE PAKE1")
pake1 = PAKE1(message.application_payload[1:-1])
print(pake1)
pake2 = PAKE2()
verifier = binascii.a2b_base64(self.nonvolatile["verifier"])
w0 = memoryview(verifier)[:CRYPTO_GROUP_SIZE_BYTES]
L = memoryview(verifier)[CRYPTO_GROUP_SIZE_BYTES:]
pake2.pB = Crypto_pB(w0, L)
# TODO: Compute these
Z = b""
V = b""
tt = Crypto_Transcript(
exchange.commissioning_hash.digest(), pake1.pA, pake2.pB, Z, V, w0
)
cA, cB, Ke = Crypto_P2(tt, pake1.pA, pake2.pB)
pake2.cB = cB
exchange.send(
ProtocolId.SECURE_CHANNEL, SecureProtocolOpcode.PASE_PAKE2, pake2
)
elif protocol_opcode == SecureProtocolOpcode.PASE_PAKE2:
print("Received PASE PAKE2")
elif protocol_opcode == SecureProtocolOpcode.PASE_PAKE3:
Expand Down
3 changes: 2 additions & 1 deletion test_data/device_state.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"descriminator": 2207,
"iteration-count": 10000,
"salt": "5uCP0ITHYzI9qBEe6hfU4HfY3y7VopSk0qNvhvznhiQ="
"salt": "5uCP0ITHYzI9qBEe6hfU4HfY3y7VopSk0qNvhvznhiQ=",
"verifier": "wxAsyKc/NiJkxkXfi9zu8aVXfMR5zOTmTA2ssdg5B+wEFcTyXODL7NAqAgFIUUvIdgZL3lB7ZoHmQDTroBgAV4ZebS6l5jrklt97N418Wnypeoi9JED6aVVDpmTivkFFUw=="
}

0 comments on commit 5f14782

Please sign in to comment.