Initial commit
This commit is contained in:
205
packages/leann-backend-hnsw/third_party/faiss/tests/test_sliding_ivf.cpp
vendored
Normal file
205
packages/leann-backend-hnsw/third_party/faiss/tests/test_sliding_ivf.cpp
vendored
Normal file
@@ -0,0 +1,205 @@
|
||||
/*
|
||||
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <faiss/AutoTune.h>
|
||||
#include <faiss/IVFlib.h>
|
||||
#include <faiss/IndexIVF.h>
|
||||
#include <faiss/clone_index.h>
|
||||
#include <faiss/index_factory.h>
|
||||
|
||||
using namespace faiss;
|
||||
|
||||
// dimension of the vectors to index
|
||||
int d = 32;
|
||||
|
||||
// nb of training vectors
|
||||
size_t nt = 5000;
|
||||
|
||||
// size of the database points per window step
|
||||
size_t nb = 1000;
|
||||
|
||||
// nb of queries
|
||||
size_t nq = 200;
|
||||
|
||||
int total_size = 40;
|
||||
int window_size = 10;
|
||||
|
||||
std::vector<float> make_data(size_t n) {
|
||||
std::vector<float> database(n * d);
|
||||
std::mt19937 rng;
|
||||
std::uniform_real_distribution<> distrib;
|
||||
for (size_t i = 0; i < n * d; i++) {
|
||||
database[i] = distrib(rng);
|
||||
}
|
||||
return database;
|
||||
}
|
||||
|
||||
std::unique_ptr<Index> make_trained_index(const char* index_type) {
|
||||
auto index = std::unique_ptr<Index>(index_factory(d, index_type));
|
||||
auto xt = make_data(nt * d);
|
||||
index->train(nt, xt.data());
|
||||
ParameterSpace().set_index_parameter(index.get(), "nprobe", 4);
|
||||
return index;
|
||||
}
|
||||
|
||||
std::vector<idx_t> search_index(Index* index, const float* xq) {
|
||||
int k = 10;
|
||||
std::vector<idx_t> I(k * nq);
|
||||
std::vector<float> D(k * nq);
|
||||
index->search(nq, xq, k, D.data(), I.data());
|
||||
return I;
|
||||
}
|
||||
|
||||
/*************************************************************
|
||||
* Test functions for a given index type
|
||||
*************************************************************/
|
||||
|
||||
// make a few slices of indexes that can be merged
|
||||
void make_index_slices(
|
||||
const Index* trained_index,
|
||||
std::vector<std::unique_ptr<Index>>& sub_indexes) {
|
||||
for (int i = 0; i < total_size; i++) {
|
||||
sub_indexes.emplace_back(clone_index(trained_index));
|
||||
|
||||
Index* index = sub_indexes.back().get();
|
||||
|
||||
auto xb = make_data(nb * d);
|
||||
std::vector<faiss::idx_t> ids(nb);
|
||||
std::mt19937 rng;
|
||||
std::uniform_int_distribution<> distrib;
|
||||
for (int j = 0; j < nb; j++) {
|
||||
ids[j] = distrib(rng);
|
||||
}
|
||||
index->add_with_ids(nb, xb.data(), ids.data());
|
||||
}
|
||||
}
|
||||
|
||||
// build merged index explicitly at sliding window position i
|
||||
Index* make_merged_index(
|
||||
const Index* trained_index,
|
||||
const std::vector<std::unique_ptr<Index>>& sub_indexes,
|
||||
int i) {
|
||||
Index* merged_index = clone_index(trained_index);
|
||||
for (int j = i - window_size + 1; j <= i; j++) {
|
||||
if (j < 0 || j >= total_size)
|
||||
continue;
|
||||
std::unique_ptr<Index> sub_index(clone_index(sub_indexes[j].get()));
|
||||
IndexIVF* ivf0 = ivflib::extract_index_ivf(merged_index);
|
||||
IndexIVF* ivf1 = ivflib::extract_index_ivf(sub_index.get());
|
||||
ivf0->merge_from(*ivf1, 0);
|
||||
merged_index->ntotal = ivf0->ntotal;
|
||||
}
|
||||
return merged_index;
|
||||
}
|
||||
|
||||
int test_sliding_window(const char* index_key) {
|
||||
std::unique_ptr<Index> trained_index = make_trained_index(index_key);
|
||||
|
||||
// make the index slices
|
||||
std::vector<std::unique_ptr<Index>> sub_indexes;
|
||||
|
||||
make_index_slices(trained_index.get(), sub_indexes);
|
||||
|
||||
// now slide over the windows
|
||||
std::unique_ptr<Index> index(clone_index(trained_index.get()));
|
||||
ivflib::SlidingIndexWindow window(index.get());
|
||||
|
||||
auto xq = make_data(nq * d);
|
||||
|
||||
for (int i = 0; i < total_size + window_size; i++) {
|
||||
// update the index
|
||||
window.step(
|
||||
i < total_size ? sub_indexes[i].get() : nullptr,
|
||||
i >= window_size);
|
||||
|
||||
auto new_res = search_index(index.get(), xq.data());
|
||||
|
||||
std::unique_ptr<Index> merged_index(
|
||||
make_merged_index(trained_index.get(), sub_indexes, i));
|
||||
|
||||
auto ref_res = search_index(merged_index.get(), xq.data());
|
||||
|
||||
EXPECT_EQ(ref_res.size(), new_res.size());
|
||||
|
||||
EXPECT_EQ(ref_res, new_res);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int test_sliding_invlists(const char* index_key) {
|
||||
std::unique_ptr<Index> trained_index = make_trained_index(index_key);
|
||||
|
||||
// make the index slices
|
||||
std::vector<std::unique_ptr<Index>> sub_indexes;
|
||||
|
||||
make_index_slices(trained_index.get(), sub_indexes);
|
||||
|
||||
// now slide over the windows
|
||||
std::unique_ptr<Index> index(clone_index(trained_index.get()));
|
||||
IndexIVF* index_ivf = ivflib::extract_index_ivf(index.get());
|
||||
|
||||
auto xq = make_data(nq * d);
|
||||
|
||||
for (int i = 0; i < total_size + window_size; i++) {
|
||||
// update the index
|
||||
std::vector<const InvertedLists*> ils;
|
||||
for (int j = i - window_size + 1; j <= i; j++) {
|
||||
if (j < 0 || j >= total_size)
|
||||
continue;
|
||||
ils.push_back(
|
||||
ivflib::extract_index_ivf(sub_indexes[j].get())->invlists);
|
||||
}
|
||||
if (ils.size() == 0)
|
||||
continue;
|
||||
|
||||
ConcatenatedInvertedLists* ci =
|
||||
new ConcatenatedInvertedLists(ils.size(), ils.data());
|
||||
|
||||
// will be deleted by the index
|
||||
index_ivf->replace_invlists(ci, true);
|
||||
|
||||
auto new_res = search_index(index.get(), xq.data());
|
||||
|
||||
std::unique_ptr<Index> merged_index(
|
||||
make_merged_index(trained_index.get(), sub_indexes, i));
|
||||
|
||||
auto ref_res = search_index(merged_index.get(), xq.data());
|
||||
|
||||
EXPECT_EQ(ref_res.size(), new_res.size());
|
||||
EXPECT_EQ(ref_res, new_res);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*************************************************************
|
||||
* Test entry points
|
||||
*************************************************************/
|
||||
|
||||
TEST(SlidingWindow, IVFFlat) {
|
||||
test_sliding_window("IVF32,Flat");
|
||||
}
|
||||
|
||||
TEST(SlidingWindow, PCAIVFFlat) {
|
||||
test_sliding_window("PCA24,IVF32,Flat");
|
||||
}
|
||||
|
||||
TEST(SlidingInvlists, IVFFlat) {
|
||||
test_sliding_invlists("IVF32,Flat");
|
||||
}
|
||||
|
||||
TEST(SlidingInvlists, PCAIVFFlat) {
|
||||
test_sliding_invlists("PCA24,IVF32,Flat");
|
||||
}
|
||||
Reference in New Issue
Block a user