Initial commit

This commit is contained in:
yichuan520030910320
2025-06-30 09:05:05 +00:00
commit 46f6cc100b
1231 changed files with 278432 additions and 0 deletions

View File

@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
cmake_minimum_required(VERSION 3.17 FATAL_ERROR)
project (ROCKSDB_IVF)
set(CMAKE_BUILD_TYPE Debug)
find_package(faiss REQUIRED)
find_package(RocksDB REQUIRED)
add_executable(demo_rocksdb_ivf demo_rocksdb_ivf.cpp RocksDBInvertedLists.cpp)
target_link_libraries(demo_rocksdb_ivf faiss RocksDB::rocksdb)

View File

@@ -0,0 +1,23 @@
# Storing Faiss inverted lists in RocksDB
Demo of storing the inverted lists of any IVF index in RocksDB or any similar key-value store which supports the prefix scan operation.
# How to build
We use conda to create the build environment for simplicity. Only tested on Linux x86.
```
conda create -n rocksdb_ivf
conda activate rocksdb_ivf
conda install pytorch::faiss-cpu conda-forge::rocksdb cmake make gxx_linux-64 sysroot_linux-64
cd ~/faiss/demos/rocksdb_ivf
cmake -B build .
make -C build -j$(nproc)
```
# Run the example
```
cd ~/faiss/demos/rocksdb_ivf/build
./rocksdb_ivf test_db
```

View File

@@ -0,0 +1,114 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "RocksDBInvertedLists.h"
#include <faiss/impl/FaissAssert.h>
using namespace faiss;
namespace faiss_rocksdb {
RocksDBInvertedListsIterator::RocksDBInvertedListsIterator(
rocksdb::DB* db,
size_t list_no,
size_t code_size)
: InvertedListsIterator(),
it(db->NewIterator(rocksdb::ReadOptions())),
list_no(list_no),
code_size(code_size),
codes(code_size) {
it->Seek(rocksdb::Slice(
reinterpret_cast<const char*>(&list_no), sizeof(size_t)));
}
bool RocksDBInvertedListsIterator::is_available() const {
return it->Valid() &&
it->key().starts_with(rocksdb::Slice(
reinterpret_cast<const char*>(&list_no), sizeof(size_t)));
}
void RocksDBInvertedListsIterator::next() {
it->Next();
}
std::pair<idx_t, const uint8_t*> RocksDBInvertedListsIterator::
get_id_and_codes() {
idx_t id =
*reinterpret_cast<const idx_t*>(&it->key().data()[sizeof(size_t)]);
assert(code_size == it->value().size());
return {id, reinterpret_cast<const uint8_t*>(it->value().data())};
}
RocksDBInvertedLists::RocksDBInvertedLists(
const char* db_directory,
size_t nlist,
size_t code_size)
: InvertedLists(nlist, code_size) {
use_iterator = true;
rocksdb::Options options;
options.create_if_missing = true;
rocksdb::DB* db;
rocksdb::Status status = rocksdb::DB::Open(options, db_directory, &db);
db_ = std::unique_ptr<rocksdb::DB>(db);
assert(status.ok());
}
size_t RocksDBInvertedLists::list_size(size_t /*list_no*/) const {
FAISS_THROW_MSG("list_size is not supported");
}
const uint8_t* RocksDBInvertedLists::get_codes(size_t /*list_no*/) const {
FAISS_THROW_MSG("get_codes is not supported");
}
const idx_t* RocksDBInvertedLists::get_ids(size_t /*list_no*/) const {
FAISS_THROW_MSG("get_ids is not supported");
}
size_t RocksDBInvertedLists::add_entries(
size_t list_no,
size_t n_entry,
const idx_t* ids,
const uint8_t* code) {
rocksdb::WriteOptions wo;
std::vector<char> key(sizeof(size_t) + sizeof(idx_t));
memcpy(key.data(), &list_no, sizeof(size_t));
for (size_t i = 0; i < n_entry; i++) {
memcpy(key.data() + sizeof(size_t), ids + i, sizeof(idx_t));
rocksdb::Status status = db_->Put(
wo,
rocksdb::Slice(key.data(), key.size()),
rocksdb::Slice(
reinterpret_cast<const char*>(code + i * code_size),
code_size));
assert(status.ok());
}
return 0; // ignored
}
void RocksDBInvertedLists::update_entries(
size_t /*list_no*/,
size_t /*offset*/,
size_t /*n_entry*/,
const idx_t* /*ids*/,
const uint8_t* /*code*/) {
FAISS_THROW_MSG("update_entries is not supported");
}
void RocksDBInvertedLists::resize(size_t /*list_no*/, size_t /*new_size*/) {
FAISS_THROW_MSG("resize is not supported");
}
InvertedListsIterator* RocksDBInvertedLists::get_iterator(
size_t list_no,
void* inverted_list_context) const {
return new RocksDBInvertedListsIterator(db_.get(), list_no, code_size);
}
} // namespace faiss_rocksdb

View File

@@ -0,0 +1,67 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#pragma once
#include <faiss/invlists/InvertedLists.h>
#include <rocksdb/db.h>
namespace faiss_rocksdb {
struct RocksDBInvertedListsIterator : faiss::InvertedListsIterator {
RocksDBInvertedListsIterator(
rocksdb::DB* db,
size_t list_no,
size_t code_size);
virtual bool is_available() const override;
virtual void next() override;
virtual std::pair<faiss::idx_t, const uint8_t*> get_id_and_codes() override;
private:
std::unique_ptr<rocksdb::Iterator> it;
size_t list_no;
size_t code_size;
std::vector<uint8_t> codes; // buffer for returning codes in next()
};
struct RocksDBInvertedLists : faiss::InvertedLists {
RocksDBInvertedLists(
const char* db_directory,
size_t nlist,
size_t code_size);
size_t list_size(size_t list_no) const override;
const uint8_t* get_codes(size_t list_no) const override;
const faiss::idx_t* get_ids(size_t list_no) const override;
size_t add_entries(
size_t list_no,
size_t n_entry,
const faiss::idx_t* ids,
const uint8_t* code) override;
void update_entries(
size_t list_no,
size_t offset,
size_t n_entry,
const faiss::idx_t* ids,
const uint8_t* code) override;
void resize(size_t list_no, size_t new_size) override;
faiss::InvertedListsIterator* get_iterator(
size_t list_no,
void* inverted_list_context) const override;
private:
std::unique_ptr<rocksdb::DB> db_;
};
} // namespace faiss_rocksdb

View File

@@ -0,0 +1,88 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <exception>
#include <iostream>
#include <memory>
#include "RocksDBInvertedLists.h"
#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFFlat.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissException.h>
#include <faiss/utils/random.h>
using namespace faiss;
int main(int argc, char* argv[]) {
try {
if (argc != 2) {
std::cerr << "missing db directory argument" << std::endl;
return -1;
}
size_t d = 128;
size_t nlist = 100;
IndexFlatL2 quantizer(d);
IndexIVFFlat index(&quantizer, d, nlist);
faiss_rocksdb::RocksDBInvertedLists ril(
argv[1], nlist, index.code_size);
index.replace_invlists(&ril, false);
idx_t nb = 10000;
std::vector<float> xb(d * nb);
float_rand(xb.data(), d * nb, 12345);
std::vector<idx_t> xids(nb);
std::iota(xids.begin(), xids.end(), 0);
index.train(nb, xb.data());
index.add_with_ids(nb, xb.data(), xids.data());
idx_t nq = 20; // nb;
index.nprobe = 2;
std::cout << "search" << std::endl;
idx_t k = 5;
std::vector<float> distances(nq * k);
std::vector<idx_t> labels(nq * k, -1);
index.search(
nq, xb.data(), k, distances.data(), labels.data(), nullptr);
for (idx_t iq = 0; iq < nq; iq++) {
std::cout << iq << ": ";
for (auto j = 0; j < k; j++) {
std::cout << labels[iq * k + j] << " " << distances[iq * k + j]
<< " | ";
}
std::cout << std::endl;
}
std::cout << std::endl << "range search" << std::endl;
float range = 15.0f;
RangeSearchResult result(nq);
index.range_search(nq, xb.data(), range, &result);
for (idx_t iq = 0; iq < nq; iq++) {
std::cout << iq << ": ";
for (auto j = result.lims[iq]; j < result.lims[iq + 1]; j++) {
std::cout << result.labels[j] << " " << result.distances[j]
<< " | ";
}
std::cout << std::endl;
}
} catch (FaissException& e) {
std::cerr << e.what() << '\n';
} catch (std::exception& e) {
std::cerr << e.what() << '\n';
} catch (...) {
std::cerr << "Unrecognized exception!\n";
}
return 0;
}