From 40ac3d5359eb6a0b78fb447df260bf64335b9637 Mon Sep 17 00:00:00 2001 From: therealyingtong Date: Fri, 4 Oct 2024 15:15:48 +0800 Subject: [PATCH] plaintext_store_db --- .gitattributes | 0 Cargo.lock | 414 ++++++++++++--- iris-mpc-common/Cargo.toml | 1 + iris-mpc-common/src/iris_db/iris.rs | 6 +- iris-mpc-cpu/Cargo.toml | 9 +- iris-mpc-cpu/benches/assets/.gitattributes | 17 + .../assets/processed_masked_irises_chunk_0 | 3 + .../assets/processed_masked_irises_chunk_1 | 3 + .../assets/processed_masked_irises_chunk_2 | 3 + .../assets/processed_masked_irises_chunk_3 | 3 + .../assets/processed_masked_irises_chunk_4 | 3 + .../assets/processed_masked_irises_chunk_5 | 3 + .../assets/processed_masked_irises_chunk_6 | 3 + .../assets/processed_masked_irises_chunk_7 | 3 + .../assets/processed_masked_irises_chunk_8 | 3 + .../assets/processed_masked_irises_chunk_9 | 3 + iris-mpc-cpu/benches/hnsw.rs | 6 +- .../migrations/20240909105323_init.down.sql | 1 + .../migrations/20240909105323_init.up.sql | 5 + iris-mpc-cpu/src/hawkers/galois_store.rs | 53 +- iris-mpc-cpu/src/hawkers/mod.rs | 2 + iris-mpc-cpu/src/hawkers/ng_aby3_store.rs | 487 ++++++++++++++++++ iris-mpc-cpu/src/hawkers/plaintext_store.rs | 70 +-- .../src/hawkers/plaintext_store_db.rs | 469 +++++++++++++++++ iris-mpc-cpu/src/shares/vecshare.rs | 2 +- 25 files changed, 1446 insertions(+), 126 deletions(-) create mode 100644 .gitattributes create mode 100644 iris-mpc-cpu/benches/assets/.gitattributes create mode 100644 iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_0 create mode 100644 iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_1 create mode 100644 iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_2 create mode 100644 iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_3 create mode 100644 iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_4 create mode 100644 iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_5 create mode 100644 iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_6 create mode 100644 iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_7 create mode 100644 iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_8 create mode 100644 iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_9 create mode 100644 iris-mpc-cpu/migrations/20240909105323_init.down.sql create mode 100644 iris-mpc-cpu/migrations/20240909105323_init.up.sql create mode 100644 iris-mpc-cpu/src/hawkers/ng_aby3_store.rs create mode 100644 iris-mpc-cpu/src/hawkers/plaintext_store_db.rs diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..e69de29bb diff --git a/Cargo.lock b/Cargo.lock index 0166994e7..898fd1052 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + [[package]] name = "aes" version = "0.8.4" @@ -147,6 +153,15 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "arbitrary" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110" +dependencies = [ + "derive_arbitrary", +] + [[package]] name = "arraydeque" version = "0.5.1" @@ -183,7 +198,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -304,9 +319,9 @@ dependencies = [ [[package]] name = "aws-sdk-kms" -version = "1.48.0" +version = "1.49.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2afbd208dabc6785946d4ef2444eb1f54fe0aaf0f62f2a4f9a9e9c303aeff0be" +checksum = "1f4c89f1d2e0df99ccd21f98598c1e587ad78bd87ae22a74aba392b5566bb038" dependencies = [ "aws-credential-types", "aws-runtime", @@ -326,9 +341,9 @@ dependencies = [ [[package]] name = "aws-sdk-s3" -version = "1.58.0" +version = "1.59.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0656a79cf5e6ab0d4bb2465cd750a7a2fd7ea26c062183ed94225f5782e22365" +checksum = "9f883bb1e349fa8343dc46336c252c0f32ceb6e81acb146aeef2e0f8afc9183e" dependencies = [ "aws-credential-types", "aws-runtime", @@ -360,9 +375,9 @@ dependencies = [ [[package]] name = "aws-sdk-secretsmanager" -version = "1.51.0" +version = "1.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdb70468666df09a91f7aeb2231f9c567f916c908b858f309e98aa4150c11c52" +checksum = "36da16e2e14f415bd779964a80e406cc924a5f2b5b382a5c71ef955909fdcb19" dependencies = [ "aws-credential-types", "aws-runtime", @@ -383,9 +398,9 @@ dependencies = [ [[package]] name = "aws-sdk-sns" -version = "1.48.0" +version = "1.49.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b6fb2200e8fa934344edc64a94883e6c7daaeab9e496a8fa610a46704605a98" +checksum = "080fda28c775c607f6add6b0290633604269c0c54aba6901c97e374fdc1c498d" dependencies = [ "aws-credential-types", "aws-runtime", @@ -406,9 +421,9 @@ dependencies = [ [[package]] name = "aws-sdk-sqs" -version = "1.47.0" +version = "1.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88aef7bb5f3be3302ccbda94ac4109522925c5050c05db388f188bc8fb5d90e9" +checksum = "2964ce1905daddb62a7e6ceb2865c2e27058b74385fb87e9525037bd1805ca3a" dependencies = [ "aws-credential-types", "aws-runtime", @@ -428,9 +443,9 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.47.0" +version = "1.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8776850becacbd3a82a4737a9375ddb5c6832a51379f24443a98e61513f852c" +checksum = "ded855583fa1d22e88fe39fd6062b062376e50a8211989e07cf5e38d52eb3453" dependencies = [ "aws-credential-types", "aws-runtime", @@ -450,9 +465,9 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.48.0" +version = "1.49.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0007b5b8004547133319b6c4e87193eee2a0bcb3e4c18c75d09febe9dab7b383" +checksum = "9177ea1192e6601ae16c7273385690d88a7ed386a00b74a6bc894d12103cd933" dependencies = [ "aws-credential-types", "aws-runtime", @@ -472,9 +487,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.47.0" +version = "1.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fffaa356e7f1c725908b75136d53207fa714e348f365671df14e95a60530ad3" +checksum = "823ef553cf36713c97453e2ddff1eb8f62be7f4523544e2a5db64caf80100f0a" dependencies = [ "aws-credential-types", "aws-runtime", @@ -763,7 +778,7 @@ dependencies = [ "cc", "cfg-if", "libc", - "miniz_oxide", + "miniz_oxide 0.7.4", "object", "rustc-demangle", ] @@ -836,7 +851,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.85", + "syn 2.0.86", "which", ] @@ -886,7 +901,7 @@ dependencies = [ "base64 0.13.1", "bitvec", "hex", - "indexmap", + "indexmap 2.6.0", "js-sys", "once_cell", "rand", @@ -920,7 +935,7 @@ checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -945,6 +960,27 @@ dependencies = [ "either", ] +[[package]] +name = "bzip2" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.11+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "cadence" version = "1.5.0" @@ -996,6 +1032,7 @@ dependencies = [ "iana-time-zone", "js-sys", "num-traits", + "serde", "wasm-bindgen", "windows-targets 0.52.6", ] @@ -1079,7 +1116,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -1197,6 +1234,12 @@ dependencies = [ "tiny-keccak", ] +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + [[package]] name = "convert_case" version = "0.4.0" @@ -1414,8 +1457,18 @@ version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a01d95850c592940db9b8194bc39f4bc0e89dee5c4265e4b1807c34a9aba453c" dependencies = [ - "darling_core", - "darling_macro", + "darling_core 0.13.4", + "darling_macro 0.13.4", +] + +[[package]] +name = "darling" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +dependencies = [ + "darling_core 0.20.10", + "darling_macro 0.20.10", ] [[package]] @@ -1432,17 +1485,42 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "darling_core" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.11.1", + "syn 2.0.86", +] + [[package]] name = "darling_macro" version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c972679f83bdf9c42bd905396b6c3588a843a17f0f16dfcfa3e2c5d57441835" dependencies = [ - "darling_core", + "darling_core 0.13.4", "quote", "syn 1.0.109", ] +[[package]] +name = "darling_macro" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +dependencies = [ + "darling_core 0.20.10", + "quote", + "syn 2.0.86", +] + [[package]] name = "dashmap" version = "6.1.0" @@ -1481,6 +1559,12 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" +[[package]] +name = "deflate64" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da692b8d1080ea3045efaab14434d40468c3d8657e42abddfffca87b428f4c1b" + [[package]] name = "der" version = "0.6.1" @@ -1509,6 +1593,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" dependencies = [ "powerfmt", + "serde", ] [[package]] @@ -1522,6 +1607,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "derive_arbitrary" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.86", +] + [[package]] name = "derive_more" version = "0.99.18" @@ -1532,7 +1628,7 @@ dependencies = [ "proc-macro2", "quote", "rustc_version 0.4.1", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -1574,6 +1670,17 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.86", +] + [[package]] name = "dlv-list" version = "0.5.2" @@ -1761,6 +1868,16 @@ version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" +[[package]] +name = "flate2" +version = "1.0.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0" +dependencies = [ + "crc32fast", + "miniz_oxide 0.8.0", +] + [[package]] name = "float_eq" version = "1.0.1" @@ -1935,7 +2052,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -2043,7 +2160,7 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap", + "indexmap 2.6.0", "slab", "tokio", "tokio-util", @@ -2062,7 +2179,7 @@ dependencies = [ "futures-core", "futures-sink", "http 1.1.0", - "indexmap", + "indexmap 2.6.0", "slab", "tokio", "tokio-util", @@ -2079,6 +2196,12 @@ dependencies = [ "crunchy", ] +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + [[package]] name = "hashbrown" version = "0.14.5" @@ -2121,7 +2244,7 @@ dependencies = [ [[package]] name = "hawk-pack" version = "0.1.0" -source = "git+https://github.com/Inversed-Tech/hawk-pack.git?rev=400db13#400db13435acbdb6500ed3c8584c66281bf9e000" +source = "git+https://github.com/therealyingtong/hawk-pack.git?branch=new-with-params#f2071872b9032f4808d2b8716d6a95df41695a53" dependencies = [ "aes-prng 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", "criterion", @@ -2476,6 +2599,17 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", + "serde", +] + [[package]] name = "indexmap" version = "2.6.0" @@ -2484,6 +2618,7 @@ checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", "hashbrown 0.15.0", + "serde", ] [[package]] @@ -2606,6 +2741,7 @@ dependencies = [ "serde", "serde-big-array", "serde_json", + "serde_with 3.11.0", "sha2", "sodiumoxide", "telemetry-batteries", @@ -2641,10 +2777,12 @@ dependencies = [ "rand", "rstest", "serde", + "sqlx", "static_assertions", "tokio", "tracing", "tracing-test", + "zip", ] [[package]] @@ -2907,6 +3045,12 @@ dependencies = [ "scopeguard", ] +[[package]] +name = "lockfree-object-pool" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9374ef4228402d4b7e403e5838cb880d9ee663314b0a900d5a6aabf0c213552e" + [[package]] name = "log" version = "0.4.22" @@ -2944,6 +3088,16 @@ dependencies = [ "linked-hash-map", ] +[[package]] +name = "lzma-rs" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "297e814c836ae64db86b36cf2a557ba54368d03f6afcd7d947c266692f71115e" +dependencies = [ + "byteorder", + "crc", +] + [[package]] name = "match_cfg" version = "0.1.0" @@ -3028,7 +3182,7 @@ dependencies = [ "hyper 1.5.0", "hyper-rustls 0.27.3", "hyper-util", - "indexmap", + "indexmap 2.6.0", "ipnet", "metrics 0.23.0", "metrics-util", @@ -3096,6 +3250,15 @@ dependencies = [ "adler", ] +[[package]] +name = "miniz_oxide" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +dependencies = [ + "adler2", +] + [[package]] name = "mio" version = "1.0.2" @@ -3135,7 +3298,7 @@ dependencies = [ "hmac", "lazy_static", "md-5", - "pbkdf2", + "pbkdf2 0.11.0", "percent-encoding", "rand", "rustc_version_runtime", @@ -3143,7 +3306,7 @@ dependencies = [ "rustls-pemfile 1.0.4", "serde", "serde_bytes", - "serde_with", + "serde_with 1.14.0", "sha-1", "sha2", "socket2 0.4.10", @@ -3378,7 +3541,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -3422,7 +3585,7 @@ dependencies = [ "ahash", "futures-core", "http 1.1.0", - "indexmap", + "indexmap 2.6.0", "itertools 0.11.0", "itoa", "once_cell", @@ -3581,6 +3744,16 @@ dependencies = [ "digest", ] +[[package]] +name = "pbkdf2" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ed6a7761f76e3b9f92dfb0a60a6a6477c61024b775147ff0973a02653abaf2" +dependencies = [ + "digest", + "hmac", +] + [[package]] name = "pem" version = "3.0.4" @@ -3637,7 +3810,7 @@ dependencies = [ "pest_meta", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -3668,7 +3841,7 @@ checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -3785,7 +3958,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033" dependencies = [ "proc-macro2", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -4185,7 +4358,7 @@ dependencies = [ "regex", "relative-path", "rustc_version 0.4.1", - "syn 2.0.85", + "syn 2.0.86", "unicode-ident", ] @@ -4495,7 +4668,7 @@ checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -4504,7 +4677,7 @@ version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ - "indexmap", + "indexmap 2.6.0", "itoa", "memchr", "ryu", @@ -4549,7 +4722,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "678b5a069e50bf00ecd22d0cd8ddf7c236f68581b03db652061ed5eb13a312ff" dependencies = [ "serde", - "serde_with_macros", + "serde_with_macros 1.5.2", +] + +[[package]] +name = "serde_with" +version = "3.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e28bdad6db2b8340e449f7108f020b3b092e8583a9e3fb82713e1d4e71fe817" +dependencies = [ + "base64 0.22.1", + "chrono", + "hex", + "indexmap 1.9.3", + "indexmap 2.6.0", + "serde", + "serde_derive", + "serde_json", + "serde_with_macros 3.11.0", + "time", ] [[package]] @@ -4558,12 +4749,24 @@ version = "1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e182d6ec6f05393cc0e5ed1bf81ad6db3a8feedf8ee515ecdd369809bcce8082" dependencies = [ - "darling", + "darling 0.13.4", "proc-macro2", "quote", "syn 1.0.109", ] +[[package]] +name = "serde_with_macros" +version = "3.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d846214a9854ef724f3da161b426242d8de7c1fc7de2f89bb1efcb154dca79d" +dependencies = [ + "darling 0.20.10", + "proc-macro2", + "quote", + "syn 2.0.86", +] + [[package]] name = "sha-1" version = "0.10.1" @@ -4641,6 +4844,12 @@ dependencies = [ "rand_core", ] +[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + [[package]] name = "sketches-ddsketch" version = "0.2.2" @@ -4771,7 +4980,7 @@ dependencies = [ "hashbrown 0.14.5", "hashlink 0.9.1", "hex", - "indexmap", + "indexmap 2.6.0", "log", "memchr", "native-tls", @@ -4800,7 +5009,7 @@ dependencies = [ "quote", "sqlx-core", "sqlx-macros-core", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -4823,7 +5032,7 @@ dependencies = [ "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", - "syn 2.0.85", + "syn 2.0.86", "tempfile", "tokio", "url", @@ -4983,9 +5192,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.85" +version = "2.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5023162dfcd14ef8f32034d8bcd4cc5ddc61ef7a247c024a33e24e1f24d21b56" +checksum = "e89275301d38033efb81a6e60e3497e734dfcc62571f2854bf4b16690398824c" dependencies = [ "proc-macro2", "quote", @@ -5104,22 +5313,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.65" +version = "1.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" +checksum = "5d171f59dbaa811dbbb1aee1e73db92ec2b122911a48e1390dfe327a821ddede" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.65" +version = "1.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" +checksum = "b08be0f17bd307950653ce45db00cd31200d82b624b36e181337d9c7d92765b5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -5223,7 +5432,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -5320,7 +5529,7 @@ version = "0.22.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ - "indexmap", + "indexmap 2.6.0", "serde", "serde_spanned", "toml_datetime", @@ -5387,7 +5596,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -5488,7 +5697,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04659ddb06c87d233c566112c1c9c5b9e98256d9af50ec3bc9c8327f873a7568" dependencies = [ "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -5732,7 +5941,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", "wasm-bindgen-shared", ] @@ -5766,7 +5975,7 @@ checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -6154,7 +6363,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -6162,3 +6371,88 @@ name = "zeroize" version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.86", +] + +[[package]] +name = "zip" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc5e4288ea4057ae23afc69a4472434a87a2495cafce6632fd1c4ec9f5cf3494" +dependencies = [ + "aes", + "arbitrary", + "bzip2", + "constant_time_eq", + "crc32fast", + "crossbeam-utils", + "deflate64", + "displaydoc", + "flate2", + "hmac", + "indexmap 2.6.0", + "lzma-rs", + "memchr", + "pbkdf2 0.12.2", + "rand", + "sha1", + "thiserror", + "time", + "zeroize", + "zopfli", + "zstd", +] + +[[package]] +name = "zopfli" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5019f391bac5cf252e93bbcc53d039ffd62c7bfb7c150414d61369afe57e946" +dependencies = [ + "bumpalo", + "crc32fast", + "lockfree-object-pool", + "log", + "once_cell", + "simd-adler32", +] + +[[package]] +name = "zstd" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.13+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/iris-mpc-common/Cargo.toml b/iris-mpc-common/Cargo.toml index a658dba54..335a3fd50 100644 --- a/iris-mpc-common/Cargo.toml +++ b/iris-mpc-common/Cargo.toml @@ -47,6 +47,7 @@ ring = "0.17.8" data-encoding = "2.6.0" bincode = "1.3.3" serde-big-array = "0.5.1" +serde_with = "3.11.0" [dev-dependencies] float_eq = "1" diff --git a/iris-mpc-common/src/iris_db/iris.rs b/iris-mpc-common/src/iris_db/iris.rs index b8acc9e88..74ceb6359 100644 --- a/iris-mpc-common/src/iris_db/iris.rs +++ b/iris-mpc-common/src/iris_db/iris.rs @@ -4,12 +4,14 @@ use rand::{ distributions::{Bernoulli, Distribution}, Rng, }; +use serde::{Deserialize, Serialize}; +use serde_big_array::BigArray; pub const MATCH_THRESHOLD_RATIO: f64 = 0.375; #[repr(transparent)] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct IrisCodeArray(pub [u64; Self::IRIS_CODE_SIZE_U64]); +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash)] +pub struct IrisCodeArray(#[serde(with = "BigArray")] pub [u64; Self::IRIS_CODE_SIZE_U64]); impl Default for IrisCodeArray { fn default() -> Self { Self::ZERO diff --git a/iris-mpc-cpu/Cargo.toml b/iris-mpc-cpu/Cargo.toml index 46d2baf6b..39a98ee76 100644 --- a/iris-mpc-cpu/Cargo.toml +++ b/iris-mpc-cpu/Cargo.toml @@ -17,13 +17,14 @@ bytemuck.workspace = true dashmap = "6.1.0" eyre.workspace = true futures.workspace = true -hawk-pack = { git = "https://github.com/Inversed-Tech/hawk-pack.git", rev = "400db13" } +hawk-pack = { git = "https://github.com/therealyingtong/hawk-pack.git", branch = "new-with-params" } iris-mpc-common = { path = "../iris-mpc-common" } itertools.workspace = true num-traits.workspace = true rand.workspace = true rstest = "0.23.0" serde.workspace = true +sqlx.workspace = true static_assertions.workspace = true tokio.workspace = true tracing.workspace = true @@ -31,7 +32,11 @@ tracing-test = "0.2.5" [dev-dependencies] criterion = { version = "0.5.1", features = ["async_tokio"] } +zip = "2.2.0" + +[features] +db_dependent = [] [[bench]] name = "hnsw" -harness = false \ No newline at end of file +harness = false diff --git a/iris-mpc-cpu/benches/assets/.gitattributes b/iris-mpc-cpu/benches/assets/.gitattributes new file mode 100644 index 000000000..1b82ce12d --- /dev/null +++ b/iris-mpc-cpu/benches/assets/.gitattributes @@ -0,0 +1,17 @@ +hnsw_db_1000000_hawk_graph_links.csv.zip filter=lfs diff=lfs merge=lfs -text +hnsw_db_1000000_hawk_vectors.csv.zip filter=lfs diff=lfs merge=lfs -text +hnsw_db_100000_hawk_graph_links.csv.zip filter=lfs diff=lfs merge=lfs -text +hnsw_db_100000_hawk_vectors.csv.zip filter=lfs diff=lfs merge=lfs -text +hnsw_db_200000_hawk_graph_links.csv.zip filter=lfs diff=lfs merge=lfs -text +hnsw_db_200000_hawk_vectors.csv.zip filter=lfs diff=lfs merge=lfs -text +100K_rust_format_synthetic_data.dat.zip filter=lfs diff=lfs merge=lfs -text +processed_masked_irises_chunk_8 filter=lfs diff=lfs merge=lfs -text +processed_masked_irises_chunk_9 filter=lfs diff=lfs merge=lfs -text +processed_masked_irises_chunk_3 filter=lfs diff=lfs merge=lfs -text +processed_masked_irises_chunk_4 filter=lfs diff=lfs merge=lfs -text +processed_masked_irises_chunk_2 filter=lfs diff=lfs merge=lfs -text +processed_masked_irises_chunk_5 filter=lfs diff=lfs merge=lfs -text +processed_masked_irises_chunk_6 filter=lfs diff=lfs merge=lfs -text +processed_masked_irises_chunk_7 filter=lfs diff=lfs merge=lfs -text +processed_masked_irises_chunk_0 filter=lfs diff=lfs merge=lfs -text +processed_masked_irises_chunk_1 filter=lfs diff=lfs merge=lfs -text diff --git a/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_0 b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_0 new file mode 100644 index 000000000..9923f72ad --- /dev/null +++ b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7cac4b339df2cde28e2de1776837cc8824fe3b40138c2b479f1220e12ecfda7f +size 1280000000 diff --git a/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_1 b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_1 new file mode 100644 index 000000000..9d84c8f3e --- /dev/null +++ b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_1 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa8ba22a085755bacd30bbacfb2d8d90738e586dd3d5ffc165a310b4e8ed6433 +size 1280000000 diff --git a/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_2 b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_2 new file mode 100644 index 000000000..35aa26205 --- /dev/null +++ b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:030a4eea8c5e066fedc6f3496a79b6b73fb9892d4fe9b5e4852800cb87c43b62 +size 1280000000 diff --git a/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_3 b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_3 new file mode 100644 index 000000000..2b2cfa757 --- /dev/null +++ b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e7ab16e2f0d9a629693ff08b6263d6d1f157dd6f2ccfa2446628dbb09ced75b8 +size 1280000000 diff --git a/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_4 b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_4 new file mode 100644 index 000000000..6d8649aea --- /dev/null +++ b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d97d03932c9e34c496880eec1e5387ecd4f9fa18ba4663ee9e613c5fb07f3b4 +size 1280000000 diff --git a/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_5 b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_5 new file mode 100644 index 000000000..44c78d36d --- /dev/null +++ b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0d598e1baca7d0493497d7a6078c43ceee90ee47759a28144724114896f352a2 +size 1280000000 diff --git a/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_6 b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_6 new file mode 100644 index 000000000..7a5e913f9 --- /dev/null +++ b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_6 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0146318477420af7aa3dd805630e6c16b89c96cd2f73b66d1ba613b413626b93 +size 1280000000 diff --git a/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_7 b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_7 new file mode 100644 index 000000000..2d99063e9 --- /dev/null +++ b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_7 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e2441f599517d41af205373ab4333df8eb36512039a32344dff7734c6aff59b +size 1280000000 diff --git a/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_8 b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_8 new file mode 100644 index 000000000..4a8079b80 --- /dev/null +++ b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_8 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ace6c6681ac8dea80788a3534ba0759b294f9b39ef50111d315667093543abeb +size 1280000000 diff --git a/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_9 b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_9 new file mode 100644 index 000000000..429a7594d --- /dev/null +++ b/iris-mpc-cpu/benches/assets/processed_masked_irises_chunk_9 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5ff6e33ad767d7bbf71564ccfdcf928ba9dc46e15bc7470e05347ac96b6af4e0 +size 1280000000 diff --git a/iris-mpc-cpu/benches/hnsw.rs b/iris-mpc-cpu/benches/hnsw.rs index 30574db5a..42b97a3ba 100644 --- a/iris-mpc-cpu/benches/hnsw.rs +++ b/iris-mpc-cpu/benches/hnsw.rs @@ -16,7 +16,7 @@ fn bench_plaintext_hnsw(c: &mut Criterion) { group.sample_size(10); group.sampling_mode(SamplingMode::Flat); - for database_size in [100_usize, 1000, 10000] { + for database_size in [10000] { let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() @@ -30,7 +30,7 @@ fn bench_plaintext_hnsw(c: &mut Criterion) { for _ in 0..database_size { let raw_query = IrisCode::random_rng(&mut rng); - let query = vector.prepare_query(raw_query.clone()); + let query = vector.prepare_query(raw_query.clone().into()); let neighbors = searcher .search_to_insert(&mut vector, &mut graph, &query) .await; @@ -55,7 +55,7 @@ fn bench_plaintext_hnsw(c: &mut Criterion) { let searcher = HawkSearcher::default(); let mut rng = AesRng::seed_from_u64(0_u64); let on_the_fly_query = IrisDB::new_random_rng(1, &mut rng).db[0].clone(); - let query = db_vectors.prepare_query(on_the_fly_query); + let query = db_vectors.prepare_query(on_the_fly_query.into()); let neighbors = searcher .search_to_insert(&mut db_vectors, &mut graph, &query) .await; diff --git a/iris-mpc-cpu/migrations/20240909105323_init.down.sql b/iris-mpc-cpu/migrations/20240909105323_init.down.sql new file mode 100644 index 000000000..99af295b2 --- /dev/null +++ b/iris-mpc-cpu/migrations/20240909105323_init.down.sql @@ -0,0 +1 @@ +DROP TABLE hawk_vectors; diff --git a/iris-mpc-cpu/migrations/20240909105323_init.up.sql b/iris-mpc-cpu/migrations/20240909105323_init.up.sql new file mode 100644 index 000000000..163e2e86a --- /dev/null +++ b/iris-mpc-cpu/migrations/20240909105323_init.up.sql @@ -0,0 +1,5 @@ +CREATE TABLE IF NOT EXISTS hawk_vectors ( + id integer NOT NULL, + point jsonb NOT NULL, + CONSTRAINT hawk_vectors_pkey PRIMARY KEY (id) +); diff --git a/iris-mpc-cpu/src/hawkers/galois_store.rs b/iris-mpc-cpu/src/hawkers/galois_store.rs index be238366a..f27317c24 100644 --- a/iris-mpc-cpu/src/hawkers/galois_store.rs +++ b/iris-mpc-cpu/src/hawkers/galois_store.rs @@ -114,31 +114,6 @@ pub fn setup_local_store_aby3_players() -> eyre::Result) -> PointId { - assert_eq!(code.len(), 3); - assert_eq!(self.players.len(), 3); - let pid0 = self - .players - .get_mut(&Identity::from("alice")) - .unwrap() - .prepare_query(code[0].clone()); - let pid1 = self - .players - .get_mut(&Identity::from("bob")) - .unwrap() - .prepare_query(code[1].clone()); - let pid2 = self - .players - .get_mut(&Identity::from("charlie")) - .unwrap() - .prepare_query(code[2].clone()); - assert_eq!(pid0, pid1); - assert_eq!(pid1, pid2); - pid0 - } -} - #[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] #[serde(bound = "")] pub struct DistanceShare { @@ -167,6 +142,30 @@ impl VectorStore for LocalNetAby3NgStoreProtocol { type QueryRef = PointId; // Vector ID, pending insertion. type VectorRef = PointId; // Vector ID, inserted. type DistanceRef = Vec>; // Distance represented as shares. + type Data = Vec; + + fn prepare_query(&mut self, code: Vec) -> PointId { + assert_eq!(code.len(), 3); + assert_eq!(self.players.len(), 3); + let pid0 = self + .players + .get_mut(&Identity::from("alice")) + .unwrap() + .prepare_query(code[0].clone()); + let pid1 = self + .players + .get_mut(&Identity::from("bob")) + .unwrap() + .prepare_query(code[1].clone()); + let pid2 = self + .players + .get_mut(&Identity::from("charlie")) + .unwrap() + .prepare_query(code[2].clone()); + assert_eq!(pid0, pid1); + assert_eq!(pid1, pid2); + pid0 + } async fn insert(&mut self, query: &Self::QueryRef) -> Self::VectorRef { // The query is now accepted in the store. It keeps the same ID. @@ -353,7 +352,7 @@ pub async fn gr_create_ready_made_hawk_searcher( let searcher = HawkSearcher::default(); for raw_query in cleartext_database.iter() { - let query = plaintext_vector_store.prepare_query(raw_query.clone()); + let query = plaintext_vector_store.prepare_query(raw_query.clone().into()); let neighbors = searcher .search_to_insert( &mut plaintext_vector_store, @@ -594,7 +593,7 @@ mod tests { // Now do the work for the plaintext store let mut plaintext_store = PlaintextStore::default(); let plaintext_preps: Vec<_> = (0..db_dim) - .map(|id| plaintext_store.prepare_query(cleartext_database[id].clone())) + .map(|id| plaintext_store.prepare_query(cleartext_database[id].clone().into())) .collect(); let mut plaintext_inserts = Vec::new(); for p in plaintext_preps.iter() { diff --git a/iris-mpc-cpu/src/hawkers/mod.rs b/iris-mpc-cpu/src/hawkers/mod.rs index e2ec49a26..558696508 100644 --- a/iris-mpc-cpu/src/hawkers/mod.rs +++ b/iris-mpc-cpu/src/hawkers/mod.rs @@ -1,2 +1,4 @@ pub mod galois_store; pub mod plaintext_store; +// #[cfg(feature = "db_dependent")] +pub mod plaintext_store_db; diff --git a/iris-mpc-cpu/src/hawkers/ng_aby3_store.rs b/iris-mpc-cpu/src/hawkers/ng_aby3_store.rs new file mode 100644 index 000000000..7bf61d0a1 --- /dev/null +++ b/iris-mpc-cpu/src/hawkers/ng_aby3_store.rs @@ -0,0 +1,487 @@ +use super::plaintext_store::PlaintextStore; +use crate::{ + database_generators::{ng_generate_iris_shares, NgSharedIris}, + execution::player::Identity, + hawkers::plaintext_store::PointId, + next_gen_protocol::ng_worker::{ + ng_cross_compare, ng_replicated_is_match, ng_replicated_pairwise_distance, LocalRuntime, + }, +}; +use aes_prng::AesRng; +use hawk_pack::{graph_store::GraphMem, hnsw_db::HawkSearcher, VectorStore}; +use iris_mpc_common::iris_db::{db::IrisDB, iris::IrisCode}; +use rand::{RngCore, SeedableRng}; +use std::collections::HashMap; +use tokio::task::JoinSet; + +#[derive(Default, Clone)] +pub struct Aby3NgStorePlayer { + points: Vec, +} + +impl std::fmt::Debug for Aby3NgStorePlayer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.points.fmt(f) + } +} + +#[derive(Eq, PartialEq, Clone, Debug)] +struct NgPoint { + /// Whatever encoding of a vector. + data: NgSharedIris, +} + +impl Aby3NgStorePlayer { + pub fn new_with_shared_db(data: Vec) -> Self { + let points: Vec = data.into_iter().map(|d| NgPoint { data: d }).collect(); + Aby3NgStorePlayer { points } + } + + pub fn prepare_query(&mut self, raw_query: NgSharedIris) -> PointId { + self.points.push(NgPoint { data: raw_query }); + + let point_id = self.points.len() - 1; + PointId(point_id) + } +} + +impl Aby3NgStorePlayer { + fn insert(&mut self, query: &PointId) -> PointId { + // The query is now accepted in the store. It keeps the same ID. + *query + } +} + +pub fn setup_local_player_preloaded_db( + database: Vec, +) -> eyre::Result { + let aby3_store = Aby3NgStorePlayer::new_with_shared_db(database); + Ok(aby3_store) +} + +pub fn setup_local_aby3_players_with_preloaded_db( + rng: &mut R, + database: Vec, +) -> eyre::Result { + let mut p0 = Vec::new(); + let mut p1 = Vec::new(); + let mut p2 = Vec::new(); + + for iris in database { + let all_shares = ng_generate_iris_shares(rng, iris); + p0.push(all_shares[0].clone()); + p1.push(all_shares[1].clone()); + p2.push(all_shares[2].clone()); + } + + let player_0 = setup_local_player_preloaded_db(p0)?; + let player_1 = setup_local_player_preloaded_db(p1)?; + let player_2 = setup_local_player_preloaded_db(p2)?; + let players = HashMap::from([ + (Identity::from("alice"), player_0), + (Identity::from("bob"), player_1), + (Identity::from("charlie"), player_2), + ]); + let runtime = LocalRuntime::replicated_test_config(); + Ok(LocalNetAby3NgStoreProtocol { runtime, players }) +} + +#[derive(Debug, Clone)] +pub struct LocalNetAby3NgStoreProtocol { + pub players: HashMap, + pub runtime: LocalRuntime, +} + +pub fn setup_local_store_aby3_players() -> eyre::Result { + let player_0 = Aby3NgStorePlayer::default(); + let player_1 = Aby3NgStorePlayer::default(); + let player_2 = Aby3NgStorePlayer::default(); + let runtime = LocalRuntime::replicated_test_config(); + let players = HashMap::from([ + (Identity::from("alice"), player_0), + (Identity::from("bob"), player_1), + (Identity::from("charlie"), player_2), + ]); + Ok(LocalNetAby3NgStoreProtocol { runtime, players }) +} + +impl VectorStore for LocalNetAby3NgStoreProtocol { + type QueryRef = PointId; // Vector ID, pending insertion. + type VectorRef = PointId; // Vector ID, inserted. + type DistanceRef = (PointId, PointId); // Lazy distance representation. + type Data = Vec; + + fn prepare_query(&mut self, code: Vec) -> PointId { + assert_eq!(code.len(), 3); + assert_eq!(self.players.len(), 3); + let pid0 = self + .players + .get_mut(&Identity::from("alice")) + .unwrap() + .prepare_query(code[0].clone()); + let pid1 = self + .players + .get_mut(&Identity::from("bob")) + .unwrap() + .prepare_query(code[1].clone()); + let pid2 = self + .players + .get_mut(&Identity::from("charlie")) + .unwrap() + .prepare_query(code[2].clone()); + assert_eq!(pid0, pid1); + assert_eq!(pid1, pid2); + pid0 + } + + async fn insert(&mut self, query: &Self::QueryRef) -> Self::VectorRef { + // The query is now accepted in the store. It keeps the same ID. + for (_id, storage) in self.players.iter_mut() { + storage.insert(query); + } + *query + } + + async fn eval_distance( + &self, + query: &Self::QueryRef, + vector: &Self::VectorRef, + ) -> Self::DistanceRef { + // Do not compute the distance yet, just forward the IDs. + (*query, *vector) + } + + async fn is_match(&self, distance: &Self::DistanceRef) -> bool { + let ready_sessions = self.runtime.create_player_sessions().await.unwrap(); + let mut jobs = JoinSet::new(); + for player in self.runtime.identities.clone() { + let mut player_session = ready_sessions.get(&player).unwrap().clone(); + let storage = self.players.get(&player).unwrap(); + let x = storage.points[distance.0.val()].clone(); + let y = storage.points[distance.1.val()].clone(); + jobs.spawn(async move { + ng_replicated_is_match(&mut player_session, &[(x.data, y.data)]) + .await + .unwrap() + }); + } + let r0 = jobs.join_next().await.unwrap().unwrap(); + let r1 = jobs.join_next().await.unwrap().unwrap(); + let r2 = jobs.join_next().await.unwrap().unwrap(); + assert_eq!(r0, r1); + assert_eq!(r1, r2); + r0 + } + + async fn less_than( + &self, + distance1: &Self::DistanceRef, + distance2: &Self::DistanceRef, + ) -> bool { + let d1 = *distance1; + let d2 = *distance2; + let ready_sessions = self.runtime.create_player_sessions().await.unwrap(); + let mut jobs = JoinSet::new(); + for player in self.runtime.identities.clone() { + let mut player_session = ready_sessions.get(&player).unwrap().clone(); + let storage = self.players.get(&player).unwrap(); + let (x1, y1) = ( + storage.points[d1.0.val()].clone(), + storage.points[d1.1.val()].clone(), + ); + let (x2, y2) = ( + storage.points[d2.0.val()].clone(), + storage.points[d2.1.val()].clone(), + ); + jobs.spawn(async move { + let ds_and_ts = ng_replicated_pairwise_distance(&mut player_session, &[ + (x1.data, y1.data), + (x2.data, y2.data), + ]) + .await + .unwrap(); + + ng_cross_compare( + &mut player_session, + ds_and_ts[0].clone(), + ds_and_ts[1].clone(), + ds_and_ts[2].clone(), + ds_and_ts[3].clone(), + ) + .await + .unwrap() + }); + } + + let r0 = jobs.join_next().await.unwrap().unwrap(); + let r1 = jobs.join_next().await.unwrap().unwrap(); + let r2 = jobs.join_next().await.unwrap().unwrap(); + assert_eq!(r0, r1); + assert_eq!(r1, r2); + r0 + } +} + +pub async fn ng_create_ready_made_hawk_searcher( + rng: &mut R, + database_size: usize, +) -> eyre::Result<( + HawkSearcher>, + HawkSearcher>, +)> { + // makes sure the searcher produces same graph structure by having the same rng + let mut rng_searcher1 = AesRng::from_rng(rng.clone())?; + let mut rng_searcher2 = rng_searcher1.clone(); + + let cleartext_database = IrisDB::new_random_rng(database_size, rng).db; + + let vector_store = PlaintextStore::default(); + let graph_store = GraphMem::new(); + let mut cleartext_searcher = HawkSearcher::new(vector_store, graph_store, &mut rng_searcher1); + + for raw_query in cleartext_database.iter() { + let query = cleartext_searcher + .vector_store + .prepare_query(raw_query.clone().into()); + let neighbors = cleartext_searcher.search_to_insert(&query).await; + let inserted = cleartext_searcher.vector_store.insert(&query).await; + cleartext_searcher + .insert_from_search_results(inserted, neighbors) + .await; + } + + let protocol_store = setup_local_aby3_players_with_preloaded_db(rng, cleartext_database)?; + let protocol_graph = GraphMem::::from_another( + cleartext_searcher.graph_store.clone(), + ); + let secret_searcher = HawkSearcher::new(protocol_store, protocol_graph, &mut rng_searcher2); + + Ok((cleartext_searcher, secret_searcher)) +} + +pub async fn ng_create_from_scratch_hawk_searcher( + rng: &mut R, + database_size: usize, +) -> eyre::Result>> +{ + let mut rng_searcher = AesRng::from_rng(rng.clone())?; + let cleartext_database = IrisDB::new_random_rng(database_size, rng).db; + let shared_irises: Vec<_> = (0..database_size) + .map(|id| ng_generate_iris_shares(rng, cleartext_database[id].clone())) + .collect(); + let aby3_store_protocol = setup_local_store_aby3_players().unwrap(); + + let graph_store = GraphMem::new(); + + let mut searcher = HawkSearcher::new(aby3_store_protocol, graph_store, &mut rng_searcher); + let queries = (0..database_size) + .map(|id| { + searcher + .vector_store + .prepare_query(shared_irises[id].clone()) + }) + .collect::>(); + + // insert queries + for query in queries.iter() { + let neighbors = searcher.search_to_insert(query).await; + searcher.insert_from_search_results(*query, neighbors).await; + } + + Ok(searcher) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::database_generators::ng_generate_iris_shares; + use aes_prng::AesRng; + use hawk_pack::{graph_store::GraphMem, hnsw_db::HawkSearcher}; + use iris_mpc_common::iris_db::db::IrisDB; + use itertools::Itertools; + use rand::SeedableRng; + use tracing_test::traced_test; + + #[tokio::test(flavor = "multi_thread")] + async fn test_ng_hnsw() { + let mut rng = AesRng::seed_from_u64(0_u64); + let database_size = 10; + let cleartext_database = IrisDB::new_random_rng(database_size, &mut rng).db; + + let aby3_store_protocol = setup_local_store_aby3_players().unwrap(); + let graph_store = GraphMem::new(); + let mut db = HawkSearcher::new(aby3_store_protocol, graph_store, &mut rng); + + let queries = (0..database_size) + .map(|id| { + db.vector_store.prepare_query( + ng_generate_iris_shares(&mut rng, cleartext_database[id].clone()).into(), + ) + }) + .collect::>(); + + // insert queries + for query in queries.iter() { + let neighbors = db.search_to_insert(query).await; + db.insert_from_search_results(*query, neighbors).await; + } + println!("FINISHED INSERTING"); + // Search for the same codes and find matches. + for (index, query) in queries.iter().enumerate() { + let neighbors = db.search_to_insert(query).await; + // assert_eq!(false, true); + tracing::debug!("Finished query"); + assert!(db.is_match(&neighbors).await, "failed at index {:?}", index); + } + } + + #[tokio::test(flavor = "multi_thread")] + #[traced_test] + async fn test_ng_premade_hnsw() { + let mut rng = AesRng::seed_from_u64(0_u64); + let database_size = 10; + let (cleartext_searcher, secret_searcher) = + ng_create_ready_made_hawk_searcher(&mut rng, database_size) + .await + .unwrap(); + + let mut rng = AesRng::seed_from_u64(0_u64); + let scratch_secret_searcher = ng_create_from_scratch_hawk_searcher(&mut rng, database_size) + .await + .unwrap(); + + assert_eq!( + scratch_secret_searcher + .vector_store + .players + .get(&Identity::from("alice")) + .unwrap() + .points, + secret_searcher + .vector_store + .players + .get(&Identity::from("alice")) + .unwrap() + .points + ); + assert_eq!( + scratch_secret_searcher + .vector_store + .players + .get(&Identity::from("bob")) + .unwrap() + .points, + secret_searcher + .vector_store + .players + .get(&Identity::from("bob")) + .unwrap() + .points + ); + assert_eq!( + scratch_secret_searcher + .vector_store + .players + .get(&Identity::from("charlie")) + .unwrap() + .points, + secret_searcher + .vector_store + .players + .get(&Identity::from("charlie")) + .unwrap() + .points + ); + + for i in 0..database_size { + let cleartext_neighbors = cleartext_searcher.search_to_insert(&PointId(i)).await; + assert!(cleartext_searcher.is_match(&cleartext_neighbors).await,); + + let secret_neighbors = secret_searcher.search_to_insert(&PointId(i)).await; + assert!(secret_searcher.is_match(&secret_neighbors).await); + + let scratch_secret_neighbors = + scratch_secret_searcher.search_to_insert(&PointId(i)).await; + assert!( + scratch_secret_searcher + .is_match(&scratch_secret_neighbors) + .await, + ); + } + } + + #[tokio::test(flavor = "multi_thread")] + #[traced_test] + async fn test_ng_aby3_store_plaintext() { + let mut rng = AesRng::seed_from_u64(0_u64); + let cleartext_database = IrisDB::new_random_rng(10, &mut rng).db; + + let mut aby3_store_protocol = setup_local_store_aby3_players().unwrap(); + let db_dim = 4; + + let aby3_preps: Vec<_> = (0..db_dim) + .map(|id| { + aby3_store_protocol.prepare_query(ng_generate_iris_shares( + &mut rng, + cleartext_database[id].clone(), + )) + }) + .collect(); + let mut aby3_inserts = Vec::new(); + for p in aby3_preps.iter() { + aby3_inserts.push(aby3_store_protocol.insert(p).await); + } + + // Now do the work for the plaintext store + let mut plaintext_store = PlaintextStore::default(); + let plaintext_preps: Vec<_> = (0..db_dim) + .map(|id| plaintext_store.prepare_query(cleartext_database[id].clone().into())) + .collect(); + let mut plaintext_inserts = Vec::new(); + for p in plaintext_preps.iter() { + plaintext_inserts.push(plaintext_store.insert(p).await); + } + let it1 = (0..db_dim).combinations(2); + let it2 = (0..db_dim).combinations(2); + for comb1 in it1 { + for comb2 in it2.clone() { + assert_eq!( + aby3_store_protocol + .less_than( + &(aby3_inserts[comb1[0]], aby3_inserts[comb1[1]]), + &(aby3_inserts[comb2[0]], aby3_inserts[comb2[1]]) + ) + .await, + plaintext_store + .less_than( + &(plaintext_inserts[comb1[0]], plaintext_inserts[comb1[1]]), + &(plaintext_inserts[comb2[0]], plaintext_inserts[comb2[1]]) + ) + .await, + "Failed at combo: {:?}, {:?}", + comb1, + comb2 + ) + } + } + } + + #[tokio::test(flavor = "multi_thread")] + #[traced_test] + async fn test_ng_scratch_hnsw() { + let mut rng = AesRng::seed_from_u64(0_u64); + let database_size = 2; + let secret_searcher = ng_create_from_scratch_hawk_searcher(&mut rng, database_size) + .await + .unwrap(); + + for i in 0..database_size { + let secret_neighbors = secret_searcher.search_to_insert(&PointId(i)).await; + assert!( + secret_searcher.is_match(&secret_neighbors).await, + "Failed at index {:?}", + i + ); + } + } +} diff --git a/iris-mpc-cpu/src/hawkers/plaintext_store.rs b/iris-mpc-cpu/src/hawkers/plaintext_store.rs index 1d788c319..5ac6086dc 100644 --- a/iris-mpc-cpu/src/hawkers/plaintext_store.rs +++ b/iris-mpc-cpu/src/hawkers/plaintext_store.rs @@ -1,16 +1,17 @@ use hawk_pack::VectorStore; use iris_mpc_common::iris_db::iris::{IrisCode, IrisCodeArray, MATCH_THRESHOLD_RATIO}; use serde::{Deserialize, Serialize}; +use std::collections::BTreeMap; #[derive(Default, Debug, Clone)] pub struct PlaintextStore { - pub points: Vec, + pub points: BTreeMap, } -#[derive(Default, Debug, Clone)] +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub struct FormattedIris { - data: Vec, - mask: IrisCodeArray, + pub data: Vec, + pub mask: IrisCodeArray, } impl From for FormattedIris { @@ -39,13 +40,22 @@ impl FormattedIris { } } -#[derive(Clone, Default, Debug)] +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub struct PlaintextPoint { /// Whatever encoding of a vector. - data: FormattedIris, + pub data: FormattedIris, /// Distinguish between queries that are pending, and those that were /// ultimately accepted into the vector store. - is_persistent: bool, + pub is_persistent: bool, +} + +impl From for PlaintextPoint { + fn from(value: IrisCode) -> Self { + Self { + data: FormattedIris::from(value), + is_persistent: false, + } + } } impl FormattedIris { @@ -63,11 +73,11 @@ impl FormattedIris { } impl PlaintextPoint { - fn compute_distance(&self, other: &PlaintextPoint) -> (i16, usize) { + pub fn compute_distance(&self, other: &PlaintextPoint) -> (i16, usize) { self.data.compute_distance(&other.data) } - fn is_close(&self, other: &PlaintextPoint) -> bool { + pub fn is_close(&self, other: &PlaintextPoint) -> bool { let hd = self.data.dot_on_code(&other.data); let mask_ones = (self.data.mask & other.data.mask).count_ones(); let threshold = (mask_ones as f64) * (1. - 2. * MATCH_THRESHOLD_RATIO); @@ -75,7 +85,7 @@ impl PlaintextPoint { } } -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] pub struct PointId(pub usize); impl PointId { @@ -85,28 +95,18 @@ impl PointId { } impl PlaintextStore { - pub fn prepare_query(&mut self, raw_query: IrisCode) -> ::QueryRef { - self.points.push(PlaintextPoint { - data: FormattedIris::from(raw_query), - is_persistent: false, - }); - - let point_id = self.points.len() - 1; - PointId(point_id) - } - pub fn distance_computation( &self, distance1: &(PointId, PointId), distance2: &(PointId, PointId), ) -> (i32, i32) { let (x1, y1) = ( - &self.points[distance1.0.val()], - &self.points[distance1.1.val()], + &self.points.get(&distance1.0).unwrap(), + &self.points.get(&distance1.1).unwrap(), ); let (x2, y2) = ( - &self.points[distance2.0.val()], - &self.points[distance2.1.val()], + &self.points.get(&distance2.0).unwrap(), + &self.points.get(&distance2.1).unwrap(), ); let (d1, t1) = x1.compute_distance(y1); let (d2, t2) = x2.compute_distance(y2); @@ -120,10 +120,18 @@ impl VectorStore for PlaintextStore { type QueryRef = PointId; // Vector ID, pending insertion. type VectorRef = PointId; // Vector ID, inserted. type DistanceRef = (PointId, PointId); // Lazy distance representation. + type Data = PlaintextPoint; + + fn prepare_query(&mut self, raw_query: PlaintextPoint) -> PointId { + let point_id = PointId(self.points.len()); + self.points.insert(point_id, raw_query); + + point_id + } async fn insert(&mut self, query: &Self::QueryRef) -> Self::VectorRef { // The query is now accepted in the store. It keeps the same ID. - self.points[query.0].is_persistent = true; + self.points.get_mut(query).unwrap().is_persistent = true; *query } @@ -137,8 +145,8 @@ impl VectorStore for PlaintextStore { } async fn is_match(&self, distance: &Self::DistanceRef) -> bool { - let x = &self.points[distance.0 .0]; - let y = &self.points[distance.1 .0]; + let x = &self.points.get(&distance.0).unwrap(); + let y = &self.points.get(&distance.1).unwrap(); x.is_close(y) } @@ -173,10 +181,10 @@ mod tests { .collect(); let mut plaintext_store = PlaintextStore::default(); - let pid0 = plaintext_store.prepare_query(cleartext_database[0].clone()); - let pid1 = plaintext_store.prepare_query(cleartext_database[1].clone()); - let pid2 = plaintext_store.prepare_query(cleartext_database[2].clone()); - let pid3 = plaintext_store.prepare_query(cleartext_database[3].clone()); + let pid0 = plaintext_store.prepare_query(cleartext_database[0].clone().into()); + let pid1 = plaintext_store.prepare_query(cleartext_database[1].clone().into()); + let pid2 = plaintext_store.prepare_query(cleartext_database[2].clone().into()); + let pid3 = plaintext_store.prepare_query(cleartext_database[3].clone().into()); let q0 = plaintext_store.insert(&pid0).await; let q1 = plaintext_store.insert(&pid1).await; diff --git a/iris-mpc-cpu/src/hawkers/plaintext_store_db.rs b/iris-mpc-cpu/src/hawkers/plaintext_store_db.rs new file mode 100644 index 000000000..e20e26c17 --- /dev/null +++ b/iris-mpc-cpu/src/hawkers/plaintext_store_db.rs @@ -0,0 +1,469 @@ +use super::plaintext_store::{PlaintextPoint, PlaintextStore, PointId}; +use eyre::{eyre, Result}; +use futures::stream::TryStreamExt; +use hawk_pack::{DbStore, VectorStore}; +use sqlx::{ + migrate::Migrator, + postgres::{PgPoolOptions, PgRow}, + Executor, PgPool, Row, +}; +use std::{collections::BTreeMap, path}; +use tokio::io::AsyncWriteExt; + +const MAX_CONNECTIONS: u32 = 5; + +static MIGRATOR: Migrator = sqlx::migrate!("./migrations"); + +#[derive(Debug, Clone)] +pub struct PlaintextStoreDb { + cache: PlaintextStore, + schema_name: String, + pool: sqlx::PgPool, +} + +impl VectorStore for PlaintextStoreDb { + type QueryRef = PointId; // Vector ID, pending insertion. + type VectorRef = PointId; // Vector ID, inserted. + type DistanceRef = (PointId, PointId); // Lazy distance representation. + type Data = PlaintextPoint; + + fn prepare_query(&mut self, raw_query: PlaintextPoint) -> PointId { + self.cache.prepare_query(raw_query) + } + + async fn insert(&mut self, query: &Self::QueryRef) -> Self::VectorRef { + let point = self.get_point(*query).await.unwrap(); + + sqlx::query( + " + INSERT INTO hawk_vectors (id, point) + VALUES ($1, $2) + ", + ) + .bind(query.val() as i32) + .bind(sqlx::types::Json(point)) + .execute(&self.pool) + .await + .expect(&format!( + "Failed to insert query {} into vector store", + query.val() + )); + + *query + } + + async fn eval_distance( + &self, + query: &Self::QueryRef, + vector: &Self::VectorRef, + ) -> Self::DistanceRef { + // Do not compute the distance yet, just forward the IDs. + (*query, *vector) + } + + async fn is_match(&self, distance: &Self::DistanceRef) -> bool { + let x = &self.get_point(distance.0).await.unwrap(); + let y = &self.get_point(distance.1).await.unwrap(); + x.is_close(y) + } + + async fn less_than( + &self, + distance1: &Self::DistanceRef, + distance2: &Self::DistanceRef, + ) -> bool { + let (d2t1, d1t2) = self.distance_computation(distance1, distance2).await; + (d2t1 - d1t2) < 0 + } +} + +impl DbStore for PlaintextStoreDb { + async fn new(url: &str, schema_name: &str) -> Result { + let connect_sql = sql_switch_schema(schema_name)?; + + let pool = PgPoolOptions::new() + .max_connections(MAX_CONNECTIONS) + .after_connect(move |conn, _meta| { + // Switch to the given schema in every connection. + let connect_sql = connect_sql.clone(); + Box::pin(async move { + conn.execute(connect_sql.as_ref()).await.inspect_err(|e| { + eprintln!("error in after_connect: {:?}", e); + })?; + Ok(()) + }) + }) + .connect(url) + .await?; + + // Create the schema on the first startup. + MIGRATOR.run(&pool).await?; + + Ok(PlaintextStoreDb { + cache: PlaintextStore { + points: BTreeMap::new(), + }, + schema_name: schema_name.to_owned(), + pool, + }) + } + + fn pool(&self) -> &PgPool { + &self.pool + } + + fn schema_name(&self) -> String { + self.schema_name.to_string() + } + + async fn copy_out(&self) -> Result> { + let file_name = format!("{}_vectors.csv", self.schema_name.clone()); + self.copy_out_with_filename(file_name).await + } +} + +impl PlaintextStoreDb { + pub async fn to_plaintext_store(&self) -> PlaintextStore { + let points = sqlx::query( + " + SELECT * FROM hawk_vectors + ", + ) + .fetch_all(&self.pool) + .await + .unwrap() + .iter() + .map(|row| { + let id: i32 = row.get("id"); + let point: sqlx::types::Json = row.get("point"); + (PointId(id as usize), point.as_ref().clone()) + }) + .collect(); + + PlaintextStore { points } + } + + pub async fn get_point(&self, point: PointId) -> Option { + let mut res = self.cache.points.get(&point).map(|p| p.clone()); + if res.is_none() { + res = sqlx::query( + " + SELECT point FROM hawk_vectors WHERE id = $1 + ", + ) + .bind(point.0 as i32) + .fetch_optional(&self.pool) + .await + .expect(&format!("Failed to fetch point {}", point.0)) + .map(|row: PgRow| { + let x: sqlx::types::Json = row.get("point"); + x.as_ref().clone() + }); + } + res + } + + pub async fn distance_computation( + &self, + distance1: &(PointId, PointId), + distance2: &(PointId, PointId), + ) -> (i32, i32) { + let (x1, y1) = ( + &self.get_point(distance1.0).await.unwrap(), + &self.get_point(distance1.1).await.unwrap(), + ); + let (x2, y2) = ( + &self.get_point(distance2.0).await.unwrap(), + &self.get_point(distance2.1).await.unwrap(), + ); + let (d1, t1) = x1.compute_distance(y1); + let (d2, t2) = x2.compute_distance(y2); + + let cross_1 = d2 as i32 * t1 as i32; + let cross_2 = d1 as i32 * t2 as i32; + (cross_1, cross_2) + } + + async fn copy_out_with_filename(&self, file_name: String) -> Result> { + let table_name = "hawk_vectors"; + + let path = path::absolute(file_name.clone())? + .as_os_str() + .to_str() + .unwrap() + .to_owned(); + + let mut file = tokio::fs::File::create(path.clone()).await?; + let mut conn = self.pool.acquire().await?; + + let mut copy_stream = conn + .copy_out_raw(&format!( + "COPY {} TO STDOUT (FORMAT CSV, HEADER)", + table_name + )) + .await?; + + while let Some(chunk) = copy_stream.try_next().await? { + file.write_all(&chunk).await?; + } + + Ok(vec![(table_name.to_string(), path)]) + } +} + +fn sql_switch_schema(schema_name: &str) -> Result { + sanitize_identifier(schema_name)?; + Ok(format!( + " + CREATE SCHEMA IF NOT EXISTS \"{}\"; + SET search_path TO \"{}\"; + ", + schema_name, schema_name + )) +} + +fn sanitize_identifier(input: &str) -> Result<()> { + if input.chars().all(|c| c.is_alphanumeric() || c == '_') { + Ok(()) + } else { + Err(eyre!("Invalid SQL identifier")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::hawkers::plaintext_store::FormattedIris; + use aes_prng::AesRng; + use hawk_pack::{graph_store::GraphPg, hnsw_db::HawkSearcher}; + use iris_mpc_common::iris_db::iris::{IrisCode, IrisCodeArray}; + use rand::SeedableRng; + + struct PlaintextPointReader { + inner: R, + } + + impl PlaintextPointReader { + pub fn new(inner: R) -> Self { + Self { inner } + } + } + + impl Iterator for PlaintextPointReader { + type Item = Vec; + + fn next(&mut self) -> Option { + let mut buf: [u8; 12800] = [0; 12800]; + self.inner.read_exact(&mut buf).ok()?; + + Some( + buf.iter() + .map(|&byte| match byte { + 2 => -1, + _ => byte as i8, + }) + .collect(), + ) + } + } + + fn plaintext_point_from_data(data: Vec) -> PlaintextPoint { + let mut mask = IrisCodeArray::ZERO; + + for (bit, &val) in data.iter().enumerate() { + if val != 0 { + mask.set_bit(bit, true); + } + } + let iris_code = FormattedIris { data, mask }; + PlaintextPoint { + data: iris_code, + is_persistent: false, + } + } + + #[tokio::test] + async fn hawk_searcher_from_db() { + let database_size = 100; + let schema_name = format!("hnsw_db_{}", database_size.to_string()); + let temporary_name = || format!("{}_{}", schema_name, rand::random::()); + let hawk_database_url: &str = "postgres://postgres:postgres@localhost/postgres"; + + let mut rng = AesRng::seed_from_u64(0_u64); + let mut graph_store = + GraphPg::::new(hawk_database_url, &temporary_name()) + .await + .unwrap(); + let mut vector_store = PlaintextStoreDb::new(hawk_database_url, &temporary_name()) + .await + .unwrap(); + let plain_searcher = HawkSearcher::default(); + + let queries = (0..database_size) + .map(|_| { + let raw_query = IrisCode::random_rng(&mut rng); + vector_store.prepare_query(raw_query.into()) + }) + .collect::>(); + + for query in queries.iter() { + let neighbors = plain_searcher + .search_to_insert(&mut vector_store, &mut graph_store, query) + .await; + let inserted = vector_store.insert(query).await; + plain_searcher + .insert_from_search_results( + &mut vector_store, + &mut graph_store, + &mut rng, + inserted, + neighbors, + ) + .await; + } + let graph_path = graph_store.copy_out().await.unwrap(); + let vectors_path = vector_store.copy_out().await.unwrap(); + + graph_store.cleanup().await.unwrap(); + vector_store.cleanup().await.unwrap(); + + // Copy in to memory + { + let graph_store = + GraphPg::::new(hawk_database_url, &temporary_name()) + .await + .unwrap(); + let vector_store = PlaintextStoreDb::new(hawk_database_url, &temporary_name()) + .await + .unwrap(); + + graph_store.copy_in(graph_path).await.unwrap(); + let mut graph_mem = graph_store.to_graph_mem().await; + + vector_store.copy_in(vectors_path).await.unwrap(); + let mut vector_mem = vector_store.to_plaintext_store().await; + vector_store.cleanup().await.unwrap(); + + let plain_searcher = HawkSearcher::default(); + + for query in queries.iter() { + let neighbors = plain_searcher + .search_to_insert(&mut vector_mem, &mut graph_mem, query) + .await; + assert!(plain_searcher.is_match(&mut vector_mem, &neighbors).await); + } + graph_store.cleanup().await.unwrap(); + } + } + + #[tokio::test] + async fn checkpoint_from_data() { + use std::io::BufReader; + + let step_size = 50000; + let database_size = 1000000; + let m_values = [128]; + + let mut rng = AesRng::seed_from_u64(0_u64); + let hawk_database_url: &str = "postgres://postgres:postgres@localhost/postgres"; + + let mut queries = vec![]; + for chunk in 0..10 { + let dat_filename = format!("benches/assets/processed_masked_irises_chunk_{}", chunk); + let dat_path = path::absolute(dat_filename).unwrap(); + + let input = + BufReader::new(std::fs::File::open(dat_path.clone()).expect("Failed to open file")); + let values: Vec> = PlaintextPointReader::new(input).collect(); + let mut values: Vec<_> = values + .into_iter() + .map(|data| plaintext_point_from_data(data)) + .collect(); + queries.append(&mut values); + } + + for m in m_values.iter() { + let mut prev_checkpoint_name = None; + for checkpoint in 1..=(database_size / step_size) { + let checkpoint_time = std::time::Instant::now(); + println!("M: {:?}, checkpoint: {:?}", m, checkpoint * step_size); + + // Copy in vectors + let mut vector_mem = PlaintextStore { + points: queries + .iter() + .enumerate() + .map(|(id, val)| (PointId(id), val.clone())) + .collect(), + }; + + println!("vector_mem.points.len(): {:?}", vector_mem.points.len()); + + // Copy in graph + let mut graph_mem = { + let graph_store = + GraphPg::::new(hawk_database_url, &"hnsw_1M") + .await + .unwrap(); + if let Some(prev_checkpoint_name) = prev_checkpoint_name { + graph_store.copy_in(prev_checkpoint_name).await.unwrap(); + } + let graph_mem = graph_store.to_graph_mem().await; + graph_store.cleanup().await.unwrap(); + graph_mem + }; + + println!( + "graph_mem.layers: {:?}", + graph_mem + .get_layers() + .iter() + .map(|layer| layer.get_links_map().iter()) + .fold(0, |acc, link| { acc + link.len() }) + ); + + let hawk_searcher = HawkSearcher::new_with_m(*m); + + let queries = vector_mem.points.clone(); + for (query, _) in queries.range(PointId((checkpoint - 1) * step_size)..) { + if query.val() >= 1000 && (query.val() % 1000) == 0 { + println!( + "{:?}s: Inserting vector {}", + checkpoint_time.elapsed().as_secs(), + query.val() + ); + } + let neighbors = hawk_searcher + .search_to_insert(&mut vector_mem, &mut graph_mem, query) + .await; + hawk_searcher + .insert_from_search_results( + &mut vector_mem, + &mut graph_mem, + &mut rng, + *query, + neighbors, + ) + .await; + } + println!( + "{:?}s: Done searching and inserting vectors", + checkpoint_time.elapsed().as_secs() + ); + + let checkpoint_name = format!( + "1M_{}_M{}_checkpoint", + (checkpoint * step_size).to_string(), + m.to_string() + ); + prev_checkpoint_name = Some( + graph_mem + // .copy_out_with_filename(checkpoint_name) + .write_to_db(hawk_database_url, &checkpoint_name) + .await + .unwrap(), + ); + } + } + } +} diff --git a/iris-mpc-cpu/src/shares/vecshare.rs b/iris-mpc-cpu/src/shares/vecshare.rs index 9076bd23f..04727199a 100644 --- a/iris-mpc-cpu/src/shares/vecshare.rs +++ b/iris-mpc-cpu/src/shares/vecshare.rs @@ -92,7 +92,7 @@ impl<'a, T: IntRing2k> SliceShareMut<'a, T> { } } -#[derive(Clone, Debug, PartialEq, Default, Eq, PartialOrd, Ord, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Default, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash)] #[serde(bound = "")] #[repr(transparent)] pub struct VecShare {