Skip to content

Commit

Permalink
Use MlKem instead of Kyber.
Browse files Browse the repository at this point in the history
  • Loading branch information
dajiaji committed Sep 15, 2024
1 parent 70061b7 commit 2de550b
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 60 deletions.
6 changes: 3 additions & 3 deletions mod.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export { MlKemError } from "./src/errors.ts";
export { MlKem512 } from "./src/kyber512.ts";
export { MlKem768 } from "./src/kyber768.ts";
export { MlKem1024 } from "./src/kyber1024.ts";
export { MlKem512 } from "./src/mlKem512.ts";
export { MlKem768 } from "./src/mlKem768.ts";
export { MlKem1024 } from "./src/mlKem1024.ts";
22 changes: 14 additions & 8 deletions src/kyber1024.ts → src/mlKem1024.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,37 @@
* https://github.com/antontutoveanu/crystals-kyber-javascript/blob/main/LICENSE
*/
import { N, Q } from "./consts.ts";
import { KyberBase } from "./kyberBase.ts";
import { MlKemBase } from "./mlKemBase.ts";
import { byte, int16, uint16, uint32 } from "./utils.ts";

/**
* Represents the MlKem1024 class.
* Represents the MlKem1024 class, which extends the MlKemBase class.
*
* MlKem1024 is a subclass of KyberBase and implements specific methods for the Kyber-1024 parameter set.
* This class extends the MlKemBase class and provides specific implementation for MlKem1024.
*
* @remarks
*
* MlKem1024 is a specific implementation of the ML-KEM key encapsulation mechanism.
*
* @example
*
* ```ts
* // import { MlKem1024 } from "crystals-kyber-js"; // Node.js
* import { MlKem1024 } from "http://deno.land/x/crystals_kyber/mod.ts"; // Deno
* // Using jsr:
* import { MlKem1024 } from "@dajiaji/mlkem";
* // Using npm:
* // import { MlKem1024 } from "mlkem"; // or "crystals-kyber-js"
*
* const recipient = new MlKem1024();
* const [pkR, skR] = await recipient.generateKeyPair();
*
* const sender = new MlKem1024();
* const [ct, ssS] = await sender.encap(pkR);
*
* const ssR = await recipient.decap(ct, skR);
* // ssS === ssR
* ```
*/
export class MlKem1024 extends KyberBase {
export class MlKem1024 extends MlKemBase {
protected _k = 4;
protected _du = 11;
protected _dv = 5;
Expand Down
14 changes: 8 additions & 6 deletions src/kyber512.ts → src/mlKem512.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,25 @@
* https://github.com/antontutoveanu/crystals-kyber-javascript/blob/main/LICENSE
*/
import { N } from "./consts.ts";
import { KyberBase } from "./kyberBase.ts";
import { MlKemBase } from "./mlKemBase.ts";
import { byteopsLoad24, int16, prf } from "./utils.ts";

/**
* Represents the MlKem512 class.
*
* This class extends the KyberBase class and provides specific implementation for MlKem512.
* This class extends the MlKemBase class and provides specific implementation for MlKem512.
*
* @remarks
*
* MlKem512 is a specific implementation of the Kyber key encapsulation mechanism.
* MlKem512 is a specific implementation of the ML-KEM key encapsulation mechanism.
*
* @example
*
* ```ts
* // import { MlKem512 } from "crystals-kyber-js"; // Node.js
* import { MlKem512 } from "http://deno.land/x/crystals_kyber/mod.ts"; // Deno
* // Using jsr:
* import { MlKem512 } from "@dajiaji/mlkem";
* // Using npm:
* // import { MlKem512 } from "mlkem"; // or "crystals-kyber-js"
*
* const recipient = new MlKem512();
* const [pkR, skR] = await recipient.generateKeyPair();
Expand All @@ -32,7 +34,7 @@ import { byteopsLoad24, int16, prf } from "./utils.ts";
* // ssS === ssR
* ```
*/
export class MlKem512 extends KyberBase {
export class MlKem512 extends MlKemBase {
protected _k = 2;
protected _du = 10;
protected _dv = 4;
Expand Down
16 changes: 9 additions & 7 deletions src/kyber768.ts → src/mlKem768.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,24 @@
* https://github.com/antontutoveanu/crystals-kyber-javascript/blob/main/LICENSE
*/
import { N } from "./consts.ts";
import { KyberBase } from "./kyberBase.ts";
import { MlKemBase } from "./mlKemBase.ts";

/**
* Represents the MlKem768 class, which extends the KyberBase class.
* Represents the MlKem768 class, which extends the MlKemBase class.
*
* MlKem768 is a specific implementation of the Kyber key encapsulation mechanism.
* This class extends the MlKemBase class and provides specific implementation for MlKem768.
*
* @remarks
*
* This class extends the KyberBase class and provides specific implementation for MlKem768.
* MlKem768 is a specific implementation of the ML-KEM key encapsulation mechanism.
*
* @example
*
* ```ts
* // import { MlKem768 } from "crystals-kyber-js"; // Node.js
* import { MlKem768 } from "http://deno.land/x/crystals_kyber/mod.ts"; // Deno
* // Using jsr:
* import { MlKem768 } from "@dajiaji/mlkem";
* // Using npm:
* // import { MlKem768 } from "mlkem"; // or "crystals-kyber-js"
*
* const recipient = new MlKem768();
* const [pkR, skR] = await recipient.generateKeyPair();
Expand All @@ -31,7 +33,7 @@ import { KyberBase } from "./kyberBase.ts";
* // ssS === ssR
* ```
*/
export class MlKem768 extends KyberBase {
export class MlKem768 extends MlKemBase {
protected _k = 3;
protected _du = 10;
protected _dv = 4;
Expand Down
26 changes: 14 additions & 12 deletions src/kyberBase.ts → src/mlKemBase.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ import {
} from "./utils.ts";

/**
* Represents the base class for the Kyber key encapsulation mechanism.
* Represents the base class for the ML-KEM key encapsulation mechanism.
*
* This class provides the base implementation for the Kyber key encapsulation mechanism.
* This class provides the base implementation for the ML-KEM key encapsulation mechanism.
*
* @remarks
*
Expand All @@ -32,10 +32,12 @@ import {
* @example
*
* ```ts
* // import { KyberBase } from "crystals-kyber-js"; // Node.js
* import { KyberBase } from "http://deno.land/x/crystals_kyber/mod.ts"; // Deno
* // Using jsr:
* import { MlKemBase } from "@dajiaji/mlkem";
* // Using npm:
* // import { MlKemBase } from "mlkem"; // or "crystals-kyber-js"
*
* class MlKem768 extends KyberBase {
* class MlKem768 extends MlKemBase {
* protected _k = 3;
* protected _du = 10;
* protected _dv = 4;
Expand All @@ -54,7 +56,7 @@ import {
* const kyber = new MlKem768();
* ```
*/
export class KyberBase {
export class MlKemBase {
private _api: Crypto | undefined = undefined;
protected _k = 0;
protected _du = 0;
Expand All @@ -67,7 +69,7 @@ export class KyberBase {
protected _compressedVSize = 0;

/**
* Creates a new instance of the KyberBase class.
* Creates a new instance of the MlKemBase class.
*/
constructor() {}

Expand Down Expand Up @@ -235,7 +237,7 @@ export class KyberBase {
}

/**
* Sets up the KyberBase instance by loading the necessary crypto library.
* Sets up the MlKemBase instance by loading the necessary crypto library.
* If the crypto library is already loaded, this method does nothing.
* @returns {Promise<void>} A promise that resolves when the setup is complete.
*/
Expand Down Expand Up @@ -289,7 +291,7 @@ export class KyberBase {
}

// indcpaKeyGen generates public and private keys for the CPA-secure
// public-key encryption scheme underlying Kyber.
// public-key encryption scheme underlying ML-KEM.

/**
* Derives a CPA key pair using the provided CPA seed.
Expand Down Expand Up @@ -338,10 +340,10 @@ export class KyberBase {
}

// _encap is the encapsulation function of the CPA-secure
// public-key encryption scheme underlying Kyber.
// public-key encryption scheme underlying ML-KEM.

/**
* Encapsulates a message using the Kyber encryption scheme.
* Encapsulates a message using the ML-KEM encryption scheme.
*
* @param pk - The public key.
* @param msg - The message to be encapsulated.
Expand Down Expand Up @@ -399,7 +401,7 @@ export class KyberBase {
}

// indcpaDecrypt is the decryption function of the CPA-secure
// public-key encryption scheme underlying Kyber.
// public-key encryption scheme underlying ML-KEM.

/**
* Decapsulates the ciphertext using the provided secret key.
Expand Down
10 changes: 5 additions & 5 deletions test/drng.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import type { KyberBase } from "../src/kyberBase.ts";
import type { MlKemBase } from "../src/mlKemBase.ts";
import { shake128 } from "../src/deps.ts";

type GetRandomValuesInputType = Parameters<
typeof Crypto.prototype.getRandomValues
>[0];

export function getDeterministicKyberClass<T extends typeof KyberBase>(
KyberClass: T,
): typeof KyberBase {
export function getDeterministicMlKemClass<T extends typeof MlKemBase>(
MlKemClass: T,
): typeof MlKemBase {
// @ts-ignore mixing constructor error expecting any[] as argument
return class DeterministicKyber extends KyberClass {
return class DeterministicMlKem extends MlKemClass {
// deno-lint-ignore require-await
async _setup() {
// @ts-ignore private accessor
Expand Down
38 changes: 19 additions & 19 deletions test/kyber.test.ts → test/mlkem.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ import { MlKem1024, MlKem512, MlKem768, MlKemError } from "../mod.ts";
import { loadCrypto } from "../src/utils.ts";
import { parseKAT, testVectorPath } from "./utils.ts";
import { bytesToHex, hexToBytes } from "./utils.ts";
import { getDeterministicKyberClass } from "./drng.ts";
import { getDeterministicMlKemClass } from "./drng.ts";

[MlKem512, MlKem768, MlKem1024].forEach((KyberClass) =>
describe(KyberClass.name, () => {
const size = KyberClass.name.substring(5);
const DeterministicKyberClass = getDeterministicKyberClass(KyberClass);
[MlKem512, MlKem768, MlKem1024].forEach((MlKemClass) =>
describe(MlKemClass.name, () => {
const size = MlKemClass.name.substring(5);
const DeterministicMlKemClass = getDeterministicMlKemClass(MlKemClass);

describe("KAT vectors", () => {
it("should match expected values", async () => {
const kyber = new KyberClass();
const kyber = new MlKemClass();
const katData = await Deno.readTextFile(
`${testVectorPath()}/kat/kat_MLKEM_${size}.rsp`,
);
Expand All @@ -35,10 +35,10 @@ import { getDeterministicKyberClass } from "./drng.ts";

describe("A sample code in README.", () => {
it("should work normally", async () => {
const recipient = new KyberClass();
const recipient = new MlKemClass();
const [pkR, skR] = await recipient.generateKeyPair();

const sender = new KyberClass();
const sender = new MlKemClass();
const [ct, ssS] = await sender.encap(pkR);

const ssR = await recipient.decap(ct, skR);
Expand All @@ -47,7 +47,7 @@ import { getDeterministicKyberClass } from "./drng.ts";
});

it("should work normally with deriveKeyPair", async () => {
const recipient = new KyberClass();
const recipient = new MlKemClass();
const api = await loadCrypto();
const seed = new Uint8Array(64);
api.getRandomValues(seed);
Expand All @@ -56,7 +56,7 @@ import { getDeterministicKyberClass } from "./drng.ts";
assertEquals(pkR, pkR2);
assertEquals(skR, skR2);

const sender = new KyberClass();
const sender = new MlKemClass();
const [ct, ssS] = await sender.encap(pkR);

const ssR = await recipient.decap(ct, skR);
Expand All @@ -67,7 +67,7 @@ import { getDeterministicKyberClass } from "./drng.ts";

describe("Advanced testing", () => {
it("Invalid encapsulation keys", async () => {
const sender = new KyberClass();
const sender = new MlKemClass();
const testData = await Deno.readTextFile(
`${testVectorPath()}/modulus/ML-KEM-${size}.txt`,
);
Expand All @@ -80,7 +80,7 @@ import { getDeterministicKyberClass } from "./drng.ts";
});

it("'Unlucky' vectors that require an unusually large number of XOF reads", async () => {
const kyber = new KyberClass();
const kyber = new MlKemClass();
const testData = await Deno.readTextFile(
`${testVectorPath()}/unluckysample/ML-KEM-${size}.txt`,
);
Expand All @@ -90,7 +90,7 @@ import { getDeterministicKyberClass } from "./drng.ts";
});

it("Accumulated vectors", async () => { // See https://github.com/C2SP/CCTV/blob/main/ML-KEM/README.md#accumulated-pq-crystals-vectors
const deterministicKyber = new DeterministicKyberClass();
const deterministicMlKem = new DeterministicMlKemClass();
const shakeInstance = shake128.create({ dkLen: 32 });
/**
* For each test, the following values are drawn from the RNG in order:
Expand Down Expand Up @@ -118,16 +118,16 @@ import { getDeterministicKyberClass } from "./drng.ts";
};

for (let i = 0; i < 10000; i++) {
const [ek, dk] = await deterministicKyber.generateKeyPair();
const [ct, k] = await deterministicKyber.encap(ek);
const kActual = await deterministicKyber.decap(ct, dk);
const [ek, dk] = await deterministicMlKem.generateKeyPair();
const [ct, k] = await deterministicMlKem.encap(ek);
const kActual = await deterministicMlKem.decap(ct, dk);
assertEquals(kActual, k);
// sample random, invalid ct
// @ts-ignore private accessor
const ctRandom = deterministicKyber._api!.getRandomValues(
const ctRandom = deterministicMlKem._api!.getRandomValues(
new Uint8Array(ct.length),
);
const kRandom = await deterministicKyber.decap(ctRandom, dk);
const kRandom = await deterministicMlKem.decap(ctRandom, dk);
// hash results
shakeInstance.update(ek)
.update(dk)
Expand All @@ -137,7 +137,7 @@ import { getDeterministicKyberClass } from "./drng.ts";
}

const actualHash = shakeInstance.digest();
assertEquals(bytesToHex(actualHash), expectedHashes[KyberClass.name]);
assertEquals(bytesToHex(actualHash), expectedHashes[MlKemClass.name]);
});
});
})
Expand Down

0 comments on commit 2de550b

Please sign in to comment.