Initial commit
This commit is contained in:
260
packages/leann-backend-hnsw/third_party/faiss/tests/test_threaded_index.cpp
vendored
Normal file
260
packages/leann-backend-hnsw/third_party/faiss/tests/test_threaded_index.cpp
vendored
Normal file
@@ -0,0 +1,260 @@
|
||||
/*
|
||||
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <faiss/IndexReplicas.h>
|
||||
#include <faiss/IndexShards.h>
|
||||
#include <faiss/impl/ThreadedIndex.h>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <chrono>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
namespace {
|
||||
|
||||
struct TestException : public std::exception {};
|
||||
|
||||
using idx_t = faiss::idx_t;
|
||||
|
||||
struct MockIndex : public faiss::Index {
|
||||
explicit MockIndex(idx_t d) : faiss::Index(d) {
|
||||
resetMock();
|
||||
}
|
||||
|
||||
void resetMock() {
|
||||
flag = false;
|
||||
nCalled = 0;
|
||||
xCalled = nullptr;
|
||||
kCalled = 0;
|
||||
distancesCalled = nullptr;
|
||||
labelsCalled = nullptr;
|
||||
}
|
||||
|
||||
void add(idx_t n, const float* x) override {
|
||||
nCalled = n;
|
||||
xCalled = x;
|
||||
}
|
||||
|
||||
void search(
|
||||
idx_t n,
|
||||
const float* x,
|
||||
idx_t k,
|
||||
float* distances,
|
||||
idx_t* labels,
|
||||
const faiss::SearchParameters* params) const override {
|
||||
FAISS_THROW_IF_NOT(!params);
|
||||
nCalled = n;
|
||||
xCalled = x;
|
||||
kCalled = k;
|
||||
distancesCalled = distances;
|
||||
labelsCalled = labels;
|
||||
}
|
||||
|
||||
void reset() override {}
|
||||
|
||||
bool flag;
|
||||
|
||||
mutable idx_t nCalled;
|
||||
mutable const float* xCalled;
|
||||
mutable idx_t kCalled;
|
||||
mutable float* distancesCalled;
|
||||
mutable idx_t* labelsCalled;
|
||||
};
|
||||
|
||||
template <typename IndexT>
|
||||
struct MockThreadedIndex : public faiss::ThreadedIndex<IndexT> {
|
||||
using idx_t = faiss::idx_t;
|
||||
|
||||
explicit MockThreadedIndex(bool threaded)
|
||||
: faiss::ThreadedIndex<IndexT>(threaded) {}
|
||||
|
||||
void add(idx_t, const float*) override {}
|
||||
void search(
|
||||
idx_t,
|
||||
const float*,
|
||||
idx_t,
|
||||
float*,
|
||||
idx_t*,
|
||||
const faiss::SearchParameters*) const override {}
|
||||
void reset() override {}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST(ThreadedIndex, SingleException) {
|
||||
std::vector<std::unique_ptr<MockIndex>> idxs;
|
||||
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
idxs.emplace_back(new MockIndex(1));
|
||||
}
|
||||
|
||||
auto fn = [](int i, MockIndex* index) {
|
||||
if (i == 1) {
|
||||
throw TestException();
|
||||
} else {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(i * 250));
|
||||
|
||||
index->flag = true;
|
||||
}
|
||||
};
|
||||
|
||||
// Try with threading and without
|
||||
for (bool threaded : {true, false}) {
|
||||
// clear flags
|
||||
for (auto& idx : idxs) {
|
||||
idx->resetMock();
|
||||
}
|
||||
|
||||
MockThreadedIndex<MockIndex> ti(threaded);
|
||||
for (auto& idx : idxs) {
|
||||
ti.addIndex(idx.get());
|
||||
}
|
||||
|
||||
// The second index should throw
|
||||
EXPECT_THROW(ti.runOnIndex(fn), TestException);
|
||||
|
||||
// Index 0 and 2 should have processed
|
||||
EXPECT_TRUE(idxs[0]->flag);
|
||||
EXPECT_TRUE(idxs[2]->flag);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ThreadedIndex, MultipleException) {
|
||||
std::vector<std::unique_ptr<MockIndex>> idxs;
|
||||
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
idxs.emplace_back(new MockIndex(1));
|
||||
}
|
||||
|
||||
auto fn = [](int i, MockIndex* index) {
|
||||
if (i < 2) {
|
||||
throw TestException();
|
||||
} else {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(i * 250));
|
||||
|
||||
index->flag = true;
|
||||
}
|
||||
};
|
||||
|
||||
// Try with threading and without
|
||||
for (bool threaded : {true, false}) {
|
||||
// clear flags
|
||||
for (auto& idx : idxs) {
|
||||
idx->resetMock();
|
||||
}
|
||||
|
||||
MockThreadedIndex<MockIndex> ti(threaded);
|
||||
for (auto& idx : idxs) {
|
||||
ti.addIndex(idx.get());
|
||||
}
|
||||
|
||||
// Multiple indices threw an exception that was aggregated into a
|
||||
// FaissException
|
||||
EXPECT_THROW(ti.runOnIndex(fn), faiss::FaissException);
|
||||
|
||||
// Index 2 should have processed
|
||||
EXPECT_TRUE(idxs[2]->flag);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ThreadedIndex, TestReplica) {
|
||||
int numReplicas = 5;
|
||||
int n = 10 * numReplicas;
|
||||
int d = 3;
|
||||
int k = 6;
|
||||
|
||||
// Try with threading and without
|
||||
for ([[maybe_unused]] const bool threaded : {true, false}) {
|
||||
std::vector<std::unique_ptr<MockIndex>> idxs;
|
||||
faiss::IndexReplicas replica(d);
|
||||
|
||||
for (int i = 0; i < numReplicas; ++i) {
|
||||
idxs.emplace_back(new MockIndex(d));
|
||||
replica.addIndex(idxs.back().get());
|
||||
}
|
||||
|
||||
std::vector<float> x(n * d);
|
||||
std::vector<float> distances(n * k);
|
||||
std::vector<faiss::idx_t> labels(n * k);
|
||||
|
||||
replica.add(n, x.data());
|
||||
|
||||
for (int i = 0; i < idxs.size(); ++i) {
|
||||
EXPECT_EQ(idxs[i]->nCalled, n);
|
||||
EXPECT_EQ(idxs[i]->xCalled, x.data());
|
||||
}
|
||||
|
||||
for (auto& idx : idxs) {
|
||||
idx->resetMock();
|
||||
}
|
||||
|
||||
replica.search(n, x.data(), k, distances.data(), labels.data());
|
||||
|
||||
for (int i = 0; i < idxs.size(); ++i) {
|
||||
auto perReplica = n / idxs.size();
|
||||
|
||||
EXPECT_EQ(idxs[i]->nCalled, perReplica);
|
||||
EXPECT_EQ(idxs[i]->xCalled, x.data() + i * perReplica * d);
|
||||
EXPECT_EQ(idxs[i]->kCalled, k);
|
||||
EXPECT_EQ(
|
||||
idxs[i]->distancesCalled,
|
||||
distances.data() + (i * perReplica) * k);
|
||||
EXPECT_EQ(
|
||||
idxs[i]->labelsCalled,
|
||||
labels.data() + (i * perReplica) * k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ThreadedIndex, TestShards) {
|
||||
int numShards = 7;
|
||||
int d = 3;
|
||||
int n = 10 * numShards;
|
||||
int k = 6;
|
||||
|
||||
// Try with threading and without
|
||||
for (bool threaded : {true, false}) {
|
||||
std::vector<std::unique_ptr<MockIndex>> idxs;
|
||||
faiss::IndexShards shards(d, threaded);
|
||||
|
||||
for (int i = 0; i < numShards; ++i) {
|
||||
idxs.emplace_back(new MockIndex(d));
|
||||
shards.addIndex(idxs.back().get());
|
||||
}
|
||||
|
||||
std::vector<float> x(n * d);
|
||||
std::vector<float> distances(n * k);
|
||||
std::vector<faiss::idx_t> labels(n * k);
|
||||
|
||||
shards.add(n, x.data());
|
||||
|
||||
for (int i = 0; i < idxs.size(); ++i) {
|
||||
auto perShard = n / idxs.size();
|
||||
|
||||
EXPECT_EQ(idxs[i]->nCalled, perShard);
|
||||
EXPECT_EQ(idxs[i]->xCalled, x.data() + i * perShard * d);
|
||||
}
|
||||
|
||||
for (auto& idx : idxs) {
|
||||
idx->resetMock();
|
||||
}
|
||||
|
||||
shards.search(n, x.data(), k, distances.data(), labels.data());
|
||||
|
||||
for (int i = 0; i < idxs.size(); ++i) {
|
||||
EXPECT_EQ(idxs[i]->nCalled, n);
|
||||
EXPECT_EQ(idxs[i]->xCalled, x.data());
|
||||
EXPECT_EQ(idxs[i]->kCalled, k);
|
||||
// There is a temporary buffer used for shards
|
||||
EXPECT_EQ(
|
||||
idxs[i]->distancesCalled,
|
||||
idxs[0]->distancesCalled + i * k * n);
|
||||
EXPECT_EQ(idxs[i]->labelsCalled, idxs[0]->labelsCalled + i * k * n);
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user