Skip to content

Commit

Permalink
improve test coverage (#95)
Browse files Browse the repository at this point in the history
Co-authored-by: blaise-muhirwa <[email protected]>
  • Loading branch information
BlaiseMuhirwa and blaise-muhirwa authored Feb 5, 2025
1 parent 31b4e2b commit 6c2114a
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 15 deletions.
6 changes: 6 additions & 0 deletions include/flatnav/distances/DistanceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
#include <cstddef> // for size_t
#include <fstream> // for ifstream, ofstream
#include <iostream>
#include <flatnav/util/Datatype.h>


using flatnav::util::DataType;

namespace flatnav::distances {

Expand Down Expand Up @@ -36,6 +40,8 @@ class DistanceInterface {
// Prints the parameters of the distance function.
void getSummary() { static_cast<T*>(this)->getSummaryImpl(); }

DataType getDataType() { return static_cast<T*>(this)->getDataTypeImpl(); }

// This transforms the data located at src into a form that is writeable
// to disk / storable in RAM. For distance functions that don't
// compress the input, this just passses through a copy from src to
Expand Down
2 changes: 2 additions & 0 deletions include/flatnav/distances/InnerProductDistance.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class InnerProductDistance : public DistanceInterface<InnerProductDistance<data_
_dimension);
}

DataType getDataTypeImpl() const { return data_type; }

private:
size_t _dimension;
size_t _data_size_bytes;
Expand Down
2 changes: 2 additions & 0 deletions include/flatnav/distances/SquaredL2Distance.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class SquaredL2Distance : public DistanceInterface<SquaredL2Distance<data_type>>
_dimension);
}

inline DataType getDataTypeImpl() const { return data_type; }

private:
size_t _dimension;
size_t _data_size_bytes;
Expand Down
102 changes: 87 additions & 15 deletions include/flatnav/tests/test_serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ using flatnav::Index;
using flatnav::distances::DistanceInterface;
using flatnav::distances::InnerProductDistance;
using flatnav::distances::SquaredL2Distance;
using flatnav::util::DataType;


namespace flatnav::testing {

Expand All @@ -22,26 +24,28 @@ static const uint32_t EF_SEARCH = 50;

// TODO: This is duplicated a couple times. Move it to a common testing
// utils file.
std::vector<float> generateRandomVectors(uint32_t num_vectors, uint32_t dim) {
std::vector<float> vectors(num_vectors * dim);
template<typename T>
std::vector<T> generateRandomVectors(uint32_t num_vectors, uint32_t dim) {
std::vector<T> vectors(num_vectors * dim);
for (uint32_t i = 0; i < num_vectors * dim; i++) {
vectors[i] = (float)rand() / RAND_MAX;
vectors[i] = static_cast<T>(rand()) / RAND_MAX;
}
return vectors;
}

template <typename dist_t, typename label_t>
void runTest(float* data, std::unique_ptr<DistanceInterface<dist_t>>&& distance, int N, int M, int dim,
template <typename T, typename dist_t, typename label_t>
void runTest(T* data, std::unique_ptr<DistanceInterface<dist_t>>&& distance, int N, int M, int dim,
int ef_construction, const std::string& save_file) {
auto data_size = distance->dataSize();
auto data_type = distance->getDataType();

std::unique_ptr<Index<dist_t, label_t>> index = std::make_unique<Index<dist_t, label_t>>(
/* dist = */ std::move(distance), /* dataset_size = */ N,
/* max_edges = */ M);
/* max_edges = */ M, /* collect_stats= */ false, /* data_type = */ data_type);

std::vector<int> labels(N);
std::iota(labels.begin(), labels.end(), 0);
index->template addBatch<float>(data, labels, ef_construction);
index->template addBatch<T>(data, labels, ef_construction);
index->saveIndex(/* filename = */ save_file);

auto new_index = Index<dist_t, label_t>::loadIndex(/* filename = */ save_file);
Expand All @@ -51,13 +55,14 @@ void runTest(float* data, std::unique_ptr<DistanceInterface<dist_t>>&& distance,
// 4 bytes for each node_id, 4 bytes for the label
ASSERT_EQ(new_index->nodeSizeBytes(), data_size + (4 * M) + 4);
ASSERT_EQ(new_index->maxNodeCount(), N);
ASSERT_EQ(new_index->getDataType(), data_type);

uint64_t total_index_size = new_index->nodeSizeBytes() * new_index->maxNodeCount();

std::vector<float> queries = generateRandomVectors(QUERY_VECTORS, dim);
std::vector<T> queries = generateRandomVectors<T>(QUERY_VECTORS, dim);

for (uint32_t i = 0; i < QUERY_VECTORS; i++) {
float* q = queries.data() + (dim * i);
T* q = queries.data() + (dim * i);

std::vector<std::pair<float, int>> query_result = index->search(q, K, EF_SEARCH);

Expand All @@ -70,15 +75,49 @@ void runTest(float* data, std::unique_ptr<DistanceInterface<dist_t>>&& distance,
}
}

TEST(FlatnavSerializationTest, TestL2IndexSerialization) {
auto vectors = generateRandomVectors(INDEXED_VECTORS, VEC_DIM);
TEST(FlatnavSerializationTest, TestL2FloatIndexSerialization) {
auto vectors = generateRandomVectors<float>(INDEXED_VECTORS, VEC_DIM);
auto distance = std::make_unique<SquaredL2Distance<>>(VEC_DIM);
std::string save_file = "l2_index.bin";

uint32_t ef_construction = 100;
uint32_t M = 16;

runTest<SquaredL2Distance<flatnav::util::DataType::float32>, int>(
runTest<float, SquaredL2Distance<DataType::float32>, int>(
/* data = */ vectors.data(), /* distance */ std::move(distance),
/* N = */ INDEXED_VECTORS, /* M = */ M, /* dim = */ VEC_DIM,
/* ef_construction = */ ef_construction,
/* save_file = */ save_file);

EXPECT_EQ(std::remove(save_file.c_str()), 0);
}

TEST(FlatnavSerializationTest, TestL2Uint8IndexSerialization) {
auto vectors = generateRandomVectors<uint8_t>(INDEXED_VECTORS, VEC_DIM);
auto distance = std::make_unique<SquaredL2Distance<DataType::uint8>>(VEC_DIM);
std::string save_file = "l2_index.bin";

uint32_t ef_construction = 100;
uint32_t M = 16;

runTest<uint8_t, SquaredL2Distance<DataType::uint8>, int>(
/* data = */ vectors.data(), /* distance */ std::move(distance),
/* N = */ INDEXED_VECTORS, /* M = */ M, /* dim = */ VEC_DIM,
/* ef_construction = */ ef_construction,
/* save_file = */ save_file);

EXPECT_EQ(std::remove(save_file.c_str()), 0);
}

TEST(FlatnavSerializationTest, TestL2Int8IndexSerialization) {
auto vectors = generateRandomVectors<int8_t>(INDEXED_VECTORS, VEC_DIM);
auto distance = std::make_unique<SquaredL2Distance<DataType::int8>>(VEC_DIM);
std::string save_file = "l2_index.bin";

uint32_t ef_construction = 100;
uint32_t M = 16;

runTest<int8_t, SquaredL2Distance<DataType::int8>, int>(
/* data = */ vectors.data(), /* distance */ std::move(distance),
/* N = */ INDEXED_VECTORS, /* M = */ M, /* dim = */ VEC_DIM,
/* ef_construction = */ ef_construction,
Expand All @@ -87,15 +126,48 @@ TEST(FlatnavSerializationTest, TestL2IndexSerialization) {
EXPECT_EQ(std::remove(save_file.c_str()), 0);
}

TEST(FlatnavSerializationTest, TestInnerProductIndexSerialization) {
auto vectors = generateRandomVectors(INDEXED_VECTORS, VEC_DIM);

TEST(FlatnavSerializationTest, TestInnerProductFloatIndexSerialization) {
auto vectors = generateRandomVectors<float>(INDEXED_VECTORS, VEC_DIM);
auto distance = std::make_unique<InnerProductDistance<>>(VEC_DIM);
std::string save_file = "ip_index.bin";

uint32_t ef_construction = 100;
uint32_t M = 16;

runTest<InnerProductDistance<flatnav::util::DataType::float32>, int>(
runTest<float, InnerProductDistance<DataType::float32>, int>(
/* data = */ vectors.data(), /* distance = */ std::move(distance),
/* N = */ INDEXED_VECTORS, /* M = */ M, /* dim = */ VEC_DIM,
/* ef_construction */ ef_construction, /* save_file = */ save_file);

EXPECT_EQ(std::remove(save_file.c_str()), 0);
}

TEST(FlatnavSerializationTest, TestInnerProductUint8IndexSerialization) {
auto vectors = generateRandomVectors<uint8_t>(INDEXED_VECTORS, VEC_DIM);
auto distance = std::make_unique<InnerProductDistance<DataType::uint8>>(VEC_DIM);
std::string save_file = "ip_index.bin";

uint32_t ef_construction = 100;
uint32_t M = 16;

runTest<uint8_t, InnerProductDistance<DataType::uint8>, int>(
/* data = */ vectors.data(), /* distance = */ std::move(distance),
/* N = */ INDEXED_VECTORS, /* M = */ M, /* dim = */ VEC_DIM,
/* ef_construction */ ef_construction, /* save_file = */ save_file);

EXPECT_EQ(std::remove(save_file.c_str()), 0);
}

TEST(FlatnavSerializationTest, TestInnerProductInt8IndexSerialization) {
auto vectors = generateRandomVectors<int8_t>(INDEXED_VECTORS, VEC_DIM);
auto distance = std::make_unique<InnerProductDistance<DataType::int8>>(VEC_DIM);
std::string save_file = "ip_index.bin";

uint32_t ef_construction = 100;
uint32_t M = 16;

runTest<int8_t, InnerProductDistance<DataType::int8>, int>(
/* data = */ vectors.data(), /* distance = */ std::move(distance),
/* N = */ INDEXED_VECTORS, /* M = */ M, /* dim = */ VEC_DIM,
/* ef_construction */ ef_construction, /* save_file = */ save_file);
Expand Down

0 comments on commit 6c2114a

Please sign in to comment.