Initial commit
This commit is contained in:
42
packages/leann-backend-diskann/third_party/DiskANN/apps/CMakeLists.txt
vendored
Normal file
42
packages/leann-backend-diskann/third_party/DiskANN/apps/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Settings applied to every tool defined in this directory:
# build as C++17 and treat compiler warnings as errors.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_COMPILE_WARNING_AS_ERROR ON)

# Each tool is a single-source executable that links the ${PROJECT_NAME}
# target, the ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} value and
# Boost::program_options. Some tools additionally link ${DISKANN_ASYNC_LIB}.

add_executable(build_memory_index build_memory_index.cpp)
target_link_libraries(build_memory_index
    ${PROJECT_NAME}
    ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS}
    Boost::program_options)

add_executable(build_stitched_index build_stitched_index.cpp)
target_link_libraries(build_stitched_index
    ${PROJECT_NAME}
    ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS}
    Boost::program_options)

add_executable(search_memory_index search_memory_index.cpp)
target_link_libraries(search_memory_index
    ${PROJECT_NAME}
    ${DISKANN_ASYNC_LIB}
    ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS}
    Boost::program_options)

add_executable(build_disk_index build_disk_index.cpp)
target_link_libraries(build_disk_index
    ${PROJECT_NAME}
    ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS}
    ${DISKANN_ASYNC_LIB}
    Boost::program_options)

add_executable(search_disk_index search_disk_index.cpp)
target_link_libraries(search_disk_index
    ${PROJECT_NAME}
    ${DISKANN_ASYNC_LIB}
    ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS}
    Boost::program_options)

add_executable(range_search_disk_index range_search_disk_index.cpp)
target_link_libraries(range_search_disk_index
    ${PROJECT_NAME}
    ${DISKANN_ASYNC_LIB}
    ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS}
    Boost::program_options)

add_executable(test_streaming_scenario test_streaming_scenario.cpp)
target_link_libraries(test_streaming_scenario
    ${PROJECT_NAME}
    ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS}
    Boost::program_options)

add_executable(test_insert_deletes_consolidate test_insert_deletes_consolidate.cpp)
target_link_libraries(test_insert_deletes_consolidate
    ${PROJECT_NAME}
    ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS}
    Boost::program_options)

# Install rules for the tools are only generated for non-MSVC toolchains.
if (NOT MSVC)
    install(TARGETS
            build_memory_index
            build_stitched_index
            search_memory_index
            build_disk_index
            search_disk_index
            range_search_disk_index
            test_streaming_scenario
            test_insert_deletes_consolidate
            RUNTIME
    )
endif()
|
||||
191
packages/leann-backend-diskann/third_party/DiskANN/apps/build_disk_index.cpp
vendored
Normal file
191
packages/leann-backend-diskann/third_party/DiskANN/apps/build_disk_index.cpp
vendored
Normal file
@@ -0,0 +1,191 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <omp.h>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "utils.h"
|
||||
#include "disk_utils.h"
|
||||
#include "math_utils.h"
|
||||
#include "index.h"
|
||||
#include "partition.h"
|
||||
#include "program_options_utils.hpp"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, data_path, index_path_prefix, codebook_prefix, label_file, universal_label,
|
||||
label_type;
|
||||
uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold;
|
||||
float B, M;
|
||||
bool append_reorder_data = false;
|
||||
bool use_opq = false;
|
||||
|
||||
po::options_description desc{
|
||||
program_options_utils::make_program_description("build_disk_index", "Build a disk-based index.")};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
// Required parameters
|
||||
po::options_description required_configs("Required");
|
||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
||||
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
|
||||
program_options_utils::INPUT_DATA_PATH);
|
||||
required_configs.add_options()("search_DRAM_budget,B", po::value<float>(&B)->required(),
|
||||
"DRAM budget in GB for searching the index to set the "
|
||||
"compressed level for data while search happens");
|
||||
required_configs.add_options()("build_DRAM_budget,M", po::value<float>(&M)->required(),
|
||||
"DRAM budget in GB for building the index");
|
||||
|
||||
// Optional parameters
|
||||
po::options_description optional_configs("Optional");
|
||||
optional_configs.add_options()("num_threads,T",
|
||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
||||
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
|
||||
program_options_utils::MAX_BUILD_DEGREE);
|
||||
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
|
||||
program_options_utils::GRAPH_BUILD_COMPLEXITY);
|
||||
optional_configs.add_options()("QD", po::value<uint32_t>(&QD)->default_value(0),
|
||||
" Quantized Dimension for compression");
|
||||
optional_configs.add_options()("codebook_prefix", po::value<std::string>(&codebook_prefix)->default_value(""),
|
||||
"Path prefix for pre-trained codebook");
|
||||
optional_configs.add_options()("PQ_disk_bytes", po::value<uint32_t>(&disk_PQ)->default_value(0),
|
||||
"Number of bytes to which vectors should be compressed "
|
||||
"on SSD; 0 for no compression");
|
||||
optional_configs.add_options()("append_reorder_data", po::bool_switch()->default_value(false),
|
||||
"Include full precision data in the index. Use only in "
|
||||
"conjuction with compressed data on SSD.");
|
||||
optional_configs.add_options()("build_PQ_bytes", po::value<uint32_t>(&build_PQ)->default_value(0),
|
||||
program_options_utils::BUIlD_GRAPH_PQ_BYTES);
|
||||
optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false),
|
||||
program_options_utils::USE_OPQ);
|
||||
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
|
||||
program_options_utils::LABEL_FILE);
|
||||
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
|
||||
program_options_utils::UNIVERSAL_LABEL);
|
||||
optional_configs.add_options()("FilteredLbuild", po::value<uint32_t>(&Lf)->default_value(0),
|
||||
program_options_utils::FILTERED_LBUILD);
|
||||
optional_configs.add_options()("filter_threshold,F", po::value<uint32_t>(&filter_threshold)->default_value(0),
|
||||
"Threshold to break up the existing nodes to generate new graph "
|
||||
"internally where each node has a maximum F labels.");
|
||||
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
|
||||
program_options_utils::LABEL_TYPE_DESCRIPTION);
|
||||
|
||||
// Merge required and optional parameters
|
||||
desc.add(required_configs).add(optional_configs);
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
if (vm["append_reorder_data"].as<bool>())
|
||||
append_reorder_data = true;
|
||||
if (vm["use_opq"].as<bool>())
|
||||
use_opq = true;
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
bool use_filters = (label_file != "") ? true : false;
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
metric = diskann::Metric::L2;
|
||||
else if (dist_fn == std::string("mips"))
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
else if (dist_fn == std::string("cosine"))
|
||||
metric = diskann::Metric::COSINE;
|
||||
else
|
||||
{
|
||||
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (append_reorder_data)
|
||||
{
|
||||
if (disk_PQ == 0)
|
||||
{
|
||||
std::cout << "Error: It is not necessary to append data for reordering "
|
||||
"when vectors are not compressed on disk."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
if (data_type != std::string("float"))
|
||||
{
|
||||
std::cout << "Error: Appending data for reordering currently only "
|
||||
"supported for float data type."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
std::string params = std::string(std::to_string(R)) + " " + std::string(std::to_string(L)) + " " +
|
||||
std::string(std::to_string(B)) + " " + std::string(std::to_string(M)) + " " +
|
||||
std::string(std::to_string(num_threads)) + " " + std::string(std::to_string(disk_PQ)) + " " +
|
||||
std::string(std::to_string(append_reorder_data)) + " " +
|
||||
std::string(std::to_string(build_PQ)) + " " + std::string(std::to_string(QD));
|
||||
|
||||
try
|
||||
{
|
||||
if (label_file != "" && label_type == "ushort")
|
||||
{
|
||||
if (data_type == std::string("int8"))
|
||||
return diskann::build_disk_index<int8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
||||
metric, use_opq, codebook_prefix, use_filters, label_file,
|
||||
universal_label, filter_threshold, Lf);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return diskann::build_disk_index<uint8_t, uint16_t>(
|
||||
data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix,
|
||||
use_filters, label_file, universal_label, filter_threshold, Lf);
|
||||
else if (data_type == std::string("float"))
|
||||
return diskann::build_disk_index<float, uint16_t>(
|
||||
data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix,
|
||||
use_filters, label_file, universal_label, filter_threshold, Lf);
|
||||
else
|
||||
{
|
||||
diskann::cerr << "Error. Unsupported data type" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (data_type == std::string("int8"))
|
||||
return diskann::build_disk_index<int8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
||||
metric, use_opq, codebook_prefix, use_filters, label_file,
|
||||
universal_label, filter_threshold, Lf);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return diskann::build_disk_index<uint8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
||||
metric, use_opq, codebook_prefix, use_filters, label_file,
|
||||
universal_label, filter_threshold, Lf);
|
||||
else if (data_type == std::string("float"))
|
||||
return diskann::build_disk_index<float>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
||||
metric, use_opq, codebook_prefix, use_filters, label_file,
|
||||
universal_label, filter_threshold, Lf);
|
||||
else
|
||||
{
|
||||
diskann::cerr << "Error. Unsupported data type" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Index build failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
164
packages/leann-backend-diskann/third_party/DiskANN/apps/build_memory_index.cpp
vendored
Normal file
164
packages/leann-backend-diskann/third_party/DiskANN/apps/build_memory_index.cpp
vendored
Normal file
@@ -0,0 +1,164 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <omp.h>
|
||||
#include <cstring>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "index.h"
|
||||
#include "utils.h"
|
||||
#include "program_options_utils.hpp"
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <unistd.h>
|
||||
#else
|
||||
#include <Windows.h>
|
||||
#endif
|
||||
|
||||
#include "memory_mapper.h"
|
||||
#include "ann_exception.h"
|
||||
#include "index_factory.h"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type;
|
||||
uint32_t num_threads, R, L, Lf, build_PQ_bytes;
|
||||
float alpha;
|
||||
bool use_pq_build, use_opq;
|
||||
|
||||
po::options_description desc{
|
||||
program_options_utils::make_program_description("build_memory_index", "Build a memory-based DiskANN index.")};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
// Required parameters
|
||||
po::options_description required_configs("Required");
|
||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
||||
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
|
||||
program_options_utils::INPUT_DATA_PATH);
|
||||
|
||||
// Optional parameters
|
||||
po::options_description optional_configs("Optional");
|
||||
optional_configs.add_options()("num_threads,T",
|
||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
||||
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
|
||||
program_options_utils::MAX_BUILD_DEGREE);
|
||||
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
|
||||
program_options_utils::GRAPH_BUILD_COMPLEXITY);
|
||||
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
|
||||
program_options_utils::GRAPH_BUILD_ALPHA);
|
||||
optional_configs.add_options()("build_PQ_bytes", po::value<uint32_t>(&build_PQ_bytes)->default_value(0),
|
||||
program_options_utils::BUIlD_GRAPH_PQ_BYTES);
|
||||
optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false),
|
||||
program_options_utils::USE_OPQ);
|
||||
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
|
||||
program_options_utils::LABEL_FILE);
|
||||
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
|
||||
program_options_utils::UNIVERSAL_LABEL);
|
||||
|
||||
optional_configs.add_options()("FilteredLbuild", po::value<uint32_t>(&Lf)->default_value(0),
|
||||
program_options_utils::FILTERED_LBUILD);
|
||||
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
|
||||
program_options_utils::LABEL_TYPE_DESCRIPTION);
|
||||
|
||||
// Merge required and optional parameters
|
||||
desc.add(required_configs).add(optional_configs);
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
use_pq_build = (build_PQ_bytes > 0);
|
||||
use_opq = vm["use_opq"].as<bool>();
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("mips"))
|
||||
{
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
}
|
||||
else if (dist_fn == std::string("l2"))
|
||||
{
|
||||
metric = diskann::Metric::L2;
|
||||
}
|
||||
else if (dist_fn == std::string("cosine"))
|
||||
{
|
||||
metric = diskann::Metric::COSINE;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Unsupported distance function. Currently only L2/ Inner "
|
||||
"Product/Cosine are supported."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L << " alpha: " << alpha
|
||||
<< " #threads: " << num_threads << std::endl;
|
||||
|
||||
size_t data_num, data_dim;
|
||||
diskann::get_bin_metadata(data_path, data_num, data_dim);
|
||||
|
||||
auto index_build_params = diskann::IndexWriteParametersBuilder(L, R)
|
||||
.with_filter_list_size(Lf)
|
||||
.with_alpha(alpha)
|
||||
.with_saturate_graph(false)
|
||||
.with_num_threads(num_threads)
|
||||
.build();
|
||||
|
||||
auto filter_params = diskann::IndexFilterParamsBuilder()
|
||||
.with_universal_label(universal_label)
|
||||
.with_label_file(label_file)
|
||||
.with_save_path_prefix(index_path_prefix)
|
||||
.build();
|
||||
auto config = diskann::IndexConfigBuilder()
|
||||
.with_metric(metric)
|
||||
.with_dimension(data_dim)
|
||||
.with_max_points(data_num)
|
||||
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
|
||||
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
|
||||
.with_data_type(data_type)
|
||||
.with_label_type(label_type)
|
||||
.is_dynamic_index(false)
|
||||
.with_index_write_params(index_build_params)
|
||||
.is_enable_tags(false)
|
||||
.is_use_opq(use_opq)
|
||||
.is_pq_dist_build(use_pq_build)
|
||||
.with_num_pq_chunks(build_PQ_bytes)
|
||||
.build();
|
||||
|
||||
auto index_factory = diskann::IndexFactory(config);
|
||||
auto index = index_factory.create_instance();
|
||||
index->build(data_path, data_num, filter_params);
|
||||
index->save(index_path_prefix.c_str());
|
||||
index.reset();
|
||||
return 0;
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Index build failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
441
packages/leann-backend-diskann/third_party/DiskANN/apps/build_stitched_index.cpp
vendored
Normal file
441
packages/leann-backend-diskann/third_party/DiskANN/apps/build_stitched_index.cpp
vendored
Normal file
@@ -0,0 +1,441 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <boost/program_options.hpp>
#include <algorithm>
#include <chrono>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <random>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include "filter_utils.h"
#include <omp.h>
#ifndef _WINDOWS
#include <sys/uio.h>
#endif

#include "index.h"
#include "memory_mapper.h"
#include "parameters.h"
#include "utils.h"
#include "program_options_utils.hpp"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
typedef std::tuple<std::vector<std::vector<uint32_t>>, uint64_t> stitch_indices_return_values;
|
||||
|
||||
/*
|
||||
* Inline function to display progress bar.
|
||||
*/
|
||||
inline void print_progress(double percentage)
|
||||
{
|
||||
int val = (int)(percentage * 100);
|
||||
int lpad = (int)(percentage * PBWIDTH);
|
||||
int rpad = PBWIDTH - lpad;
|
||||
printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, "");
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
/*
|
||||
* Inline function to generate a random integer in a range.
|
||||
*/
|
||||
/*
 * Returns a uniformly distributed integer in the CLOSED interval
 * [range_from, range_to]; both bounds can be returned.
 *
 * Precondition: range_from <= range_to (required by
 * std::uniform_int_distribution).
 */
inline size_t random(size_t range_from, size_t range_to)
{
    // One engine per thread, seeded once from the random device, instead of
    // constructing a device and re-seeding a fresh engine on every call.
    thread_local std::mt19937 generator{std::random_device{}()};
    std::uniform_int_distribution<size_t> distr(range_from, range_to);
    return distr(generator);
}
|
||||
|
||||
/*
|
||||
* function to handle command line parsing.
|
||||
*
|
||||
* Arguments are merely the inputs from the command line.
|
||||
*/
|
||||
void handle_args(int argc, char **argv, std::string &data_type, path &input_data_path, path &final_index_path_prefix,
|
||||
path &label_data_path, std::string &universal_label, uint32_t &num_threads, uint32_t &R, uint32_t &L,
|
||||
uint32_t &stitched_R, float &alpha)
|
||||
{
|
||||
po::options_description desc{
|
||||
program_options_utils::make_program_description("build_stitched_index", "Build a stitched DiskANN index.")};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
// Required parameters
|
||||
po::options_description required_configs("Required");
|
||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
||||
required_configs.add_options()("index_path_prefix",
|
||||
po::value<std::string>(&final_index_path_prefix)->required(),
|
||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
||||
required_configs.add_options()("data_path", po::value<std::string>(&input_data_path)->required(),
|
||||
program_options_utils::INPUT_DATA_PATH);
|
||||
|
||||
// Optional parameters
|
||||
po::options_description optional_configs("Optional");
|
||||
optional_configs.add_options()("num_threads,T",
|
||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
||||
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
|
||||
program_options_utils::MAX_BUILD_DEGREE);
|
||||
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
|
||||
program_options_utils::GRAPH_BUILD_COMPLEXITY);
|
||||
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
|
||||
program_options_utils::GRAPH_BUILD_ALPHA);
|
||||
optional_configs.add_options()("label_file", po::value<std::string>(&label_data_path)->default_value(""),
|
||||
program_options_utils::LABEL_FILE);
|
||||
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
|
||||
program_options_utils::UNIVERSAL_LABEL);
|
||||
optional_configs.add_options()("stitched_R", po::value<uint32_t>(&stitched_R)->default_value(100),
|
||||
"Degree to prune final graph down to");
|
||||
|
||||
// Merge required and optional parameters
|
||||
desc.add(required_configs).add(optional_configs);
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
exit(0);
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Custom index save to write the in-memory index to disk.
|
||||
* Also writes required files for diskANN API -
|
||||
* 1. labels_to_medoids
|
||||
* 2. universal_label
|
||||
* 3. data (redundant for static indices)
|
||||
* 4. labels (redundant for static indices)
|
||||
*/
|
||||
/*
 * Writes the stitched graph and its auxiliary files next to
 * final_index_path_prefix:
 *   <prefix>_labels.txt             copy of label_data_path
 *   <prefix>.data                   copy of input_data_path
 *   <prefix>_labels_to_medoids.txt  one "label, entry_point" line per entry
 *   <prefix>_universal_label.txt    only when universal_label is non-empty
 *   <prefix>                        the graph itself
 *
 * Graph file layout: index size (uint64), max observed degree (uint32),
 * entry point (uint32, written as 0), frozen point count (uint64, written
 * as 0), then for every node its neighbour count (uint32) followed by the
 * neighbour ids (uint32 each).
 *
 * All streams throw std::ios_base::failure on I/O errors. Throws
 * std::runtime_error when the number of bytes written differs from
 * final_index_size.
 */
void save_full_index(path final_index_path_prefix, path input_data_path, uint64_t final_index_size,
                     std::vector<std::vector<uint32_t>> stitched_graph,
                     tsl::robin_map<std::string, uint32_t> entry_points, std::string universal_label,
                     path label_data_path)
{
    auto saving_index_timer = std::chrono::high_resolution_clock::now();

    // Byte-for-byte copy of one file into another, with stream exceptions enabled.
    const auto copy_binary_file = [](const path &source_path, const path &target_path) {
        std::ifstream source_stream;
        source_stream.exceptions(std::ios::badbit | std::ios::failbit);
        source_stream.open(source_path, std::ios::binary);
        std::ofstream target_stream;
        target_stream.exceptions(std::ios::badbit | std::ios::failbit);
        target_stream.open(target_path, std::ios::binary);
        target_stream << source_stream.rdbuf();
        source_stream.close();
        target_stream.close();
    };

    // aux. file 1: labels
    copy_binary_file(label_data_path, final_index_path_prefix + "_labels.txt");

    // aux. file 2: raw data
    copy_binary_file(input_data_path, final_index_path_prefix + ".data");

    // aux. file 3: label -> entry point mapping
    std::ofstream labels_to_medoids_writer;
    labels_to_medoids_writer.exceptions(std::ios::badbit | std::ios::failbit);
    labels_to_medoids_writer.open(final_index_path_prefix + "_labels_to_medoids.txt");
    for (const auto &entry : entry_points)
        labels_to_medoids_writer << entry.first << ", " << entry.second << '\n';
    labels_to_medoids_writer.close();

    // aux. file 4 (only if we're using a universal label)
    if (universal_label != "")
    {
        std::ofstream universal_label_writer;
        universal_label_writer.exceptions(std::ios::badbit | std::ios::failbit);
        universal_label_writer.open(final_index_path_prefix + "_universal_label.txt");
        universal_label_writer << universal_label << std::endl;
        universal_label_writer.close();
    }

    // main index
    uint64_t index_num_frozen_points = 0, index_num_edges = 0;
    uint32_t index_max_observed_degree = 0, index_entry_point = 0;
    const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t);
    for (const auto &point_neighbors : stitched_graph)
    {
        index_max_observed_degree = std::max(index_max_observed_degree, (uint32_t)point_neighbors.size());
    }

    std::ofstream stitched_graph_writer;
    stitched_graph_writer.exceptions(std::ios::badbit | std::ios::failbit);
    stitched_graph_writer.open(final_index_path_prefix, std::ios_base::binary);

    // Header; its size must equal METADATA.
    stitched_graph_writer.write(reinterpret_cast<const char *>(&final_index_size), sizeof(uint64_t));
    stitched_graph_writer.write(reinterpret_cast<const char *>(&index_max_observed_degree), sizeof(uint32_t));
    stitched_graph_writer.write(reinterpret_cast<const char *>(&index_entry_point), sizeof(uint32_t));
    stitched_graph_writer.write(reinterpret_cast<const char *>(&index_num_frozen_points), sizeof(uint64_t));

    size_t bytes_written = METADATA;
    for (const auto &node_neighbors : stitched_graph)
    {
        const uint32_t node_num_neighbors = static_cast<uint32_t>(node_neighbors.size());
        stitched_graph_writer.write(reinterpret_cast<const char *>(&node_num_neighbors), sizeof(uint32_t));
        // The ids are contiguous, so the whole adjacency list goes out in one write
        // without first copying it.
        if (node_num_neighbors != 0)
        {
            stitched_graph_writer.write(
                reinterpret_cast<const char *>(node_neighbors.data()),
                static_cast<std::streamsize>(static_cast<size_t>(node_num_neighbors) * sizeof(uint32_t)));
        }
        bytes_written += (static_cast<size_t>(node_num_neighbors) + 1) * sizeof(uint32_t);
        index_num_edges += node_num_neighbors;
    }

    if (bytes_written != final_index_size)
    {
        std::cerr << "Error: written bytes does not match allocated space" << std::endl;
        // A bare `throw;` here has no active exception and would call std::terminate.
        throw std::runtime_error("Error: written bytes does not match allocated space");
    }

    stitched_graph_writer.close();

    std::chrono::duration<double> saving_index_time = std::chrono::high_resolution_clock::now() - saving_index_timer;
    std::cout << "Stitched graph written in " << saving_index_time.count() << " seconds" << std::endl;
    std::cout << "Stitched graph average degree: " << ((float)index_num_edges) / ((float)(stitched_graph.size()))
              << std::endl;
    std::cout << "Stitched graph max degree: " << index_max_observed_degree << std::endl << std::endl;
}
|
||||
|
||||
/*
|
||||
* Unions the per-label graph indices together via the following policy:
|
||||
* - any two nodes can only have at most one edge between them -
|
||||
*
|
||||
* Returns the "stitched" graph and its expected file size.
|
||||
*/
|
||||
template <typename T>
|
||||
stitch_indices_return_values stitch_label_indices(
|
||||
path final_index_path_prefix, uint32_t total_number_of_points, label_set all_labels,
|
||||
tsl::robin_map<std::string, uint32_t> labels_to_number_of_points,
|
||||
tsl::robin_map<std::string, uint32_t> &label_entry_points,
|
||||
tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id_map)
|
||||
{
|
||||
size_t final_index_size = 0;
|
||||
std::vector<std::vector<uint32_t>> stitched_graph(total_number_of_points);
|
||||
|
||||
auto stitching_index_timer = std::chrono::high_resolution_clock::now();
|
||||
for (const auto &lbl : all_labels)
|
||||
{
|
||||
path curr_label_index_path(final_index_path_prefix + "_" + lbl);
|
||||
std::vector<std::vector<uint32_t>> curr_label_index;
|
||||
uint64_t curr_label_index_size;
|
||||
uint32_t curr_label_entry_point;
|
||||
|
||||
std::tie(curr_label_index, curr_label_index_size) =
|
||||
diskann::load_label_index(curr_label_index_path, labels_to_number_of_points[lbl]);
|
||||
curr_label_entry_point = (uint32_t)random(0, curr_label_index.size());
|
||||
label_entry_points[lbl] = label_id_to_orig_id_map[lbl][curr_label_entry_point];
|
||||
|
||||
for (uint32_t node_point = 0; node_point < curr_label_index.size(); node_point++)
|
||||
{
|
||||
uint32_t original_point_id = label_id_to_orig_id_map[lbl][node_point];
|
||||
for (auto &node_neighbor : curr_label_index[node_point])
|
||||
{
|
||||
uint32_t original_neighbor_id = label_id_to_orig_id_map[lbl][node_neighbor];
|
||||
std::vector<uint32_t> curr_point_neighbors = stitched_graph[original_point_id];
|
||||
if (std::find(curr_point_neighbors.begin(), curr_point_neighbors.end(), original_neighbor_id) ==
|
||||
curr_point_neighbors.end())
|
||||
{
|
||||
stitched_graph[original_point_id].push_back(original_neighbor_id);
|
||||
final_index_size += sizeof(uint32_t);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t);
|
||||
final_index_size += (total_number_of_points * sizeof(uint32_t) + METADATA);
|
||||
|
||||
std::chrono::duration<double> stitching_index_time =
|
||||
std::chrono::high_resolution_clock::now() - stitching_index_timer;
|
||||
std::cout << "stitched graph generated in memory in " << stitching_index_time.count() << " seconds" << std::endl;
|
||||
|
||||
return std::make_tuple(stitched_graph, final_index_size);
|
||||
}
|
||||
|
||||
/*
|
||||
* Applies the prune_neighbors function from src/index.cpp to
|
||||
* every node in the stitched graph.
|
||||
*
|
||||
* This is an optional step, hence the saving of both the full
|
||||
* and pruned graph.
|
||||
*/
|
||||
template <typename T>
|
||||
void prune_and_save(path final_index_path_prefix, path full_index_path_prefix, path input_data_path,
|
||||
std::vector<std::vector<uint32_t>> stitched_graph, uint32_t stitched_R,
|
||||
tsl::robin_map<std::string, uint32_t> label_entry_points, std::string universal_label,
|
||||
path label_data_path, uint32_t num_threads)
|
||||
{
|
||||
size_t dimension, number_of_label_points;
|
||||
auto diskann_cout_buffer = diskann::cout.rdbuf(nullptr);
|
||||
auto std_cout_buffer = std::cout.rdbuf(nullptr);
|
||||
auto pruning_index_timer = std::chrono::high_resolution_clock::now();
|
||||
|
||||
diskann::get_bin_metadata(input_data_path, number_of_label_points, dimension);
|
||||
|
||||
diskann::Index<T> index(diskann::Metric::L2, dimension, number_of_label_points, nullptr, nullptr, 0, false, false,
|
||||
false, false, 0, false);
|
||||
|
||||
// not searching this index, set search_l to 0
|
||||
index.load(full_index_path_prefix.c_str(), num_threads, 1);
|
||||
|
||||
std::cout << "parsing labels" << std::endl;
|
||||
|
||||
index.prune_all_neighbors(stitched_R, 750, 1.2);
|
||||
index.save((final_index_path_prefix).c_str());
|
||||
|
||||
diskann::cout.rdbuf(diskann_cout_buffer);
|
||||
std::cout.rdbuf(std_cout_buffer);
|
||||
std::chrono::duration<double> pruning_index_time = std::chrono::high_resolution_clock::now() - pruning_index_timer;
|
||||
std::cout << "pruning performed in " << pruning_index_time.count() << " seconds\n" << std::endl;
|
||||
}
|
||||
|
||||
/*
|
||||
* Delete all temporary artifacts.
|
||||
* In the process of creating the stitched index, some temporary artifacts are
|
||||
* created:
|
||||
* 1. the separate bin files for each labels' points
|
||||
* 2. the separate diskANN indices built for each label
|
||||
* 3. the '.data' file created while generating the indices
|
||||
*/
|
||||
void clean_up_artifacts(path input_data_path, path final_index_path_prefix, label_set all_labels)
|
||||
{
|
||||
for (const auto &lbl : all_labels)
|
||||
{
|
||||
path curr_label_input_data_path(input_data_path + "_" + lbl);
|
||||
path curr_label_index_path(final_index_path_prefix + "_" + lbl);
|
||||
path curr_label_index_path_data(curr_label_index_path + ".data");
|
||||
|
||||
if (std::remove(curr_label_index_path.c_str()) != 0)
|
||||
throw;
|
||||
if (std::remove(curr_label_input_data_path.c_str()) != 0)
|
||||
throw;
|
||||
if (std::remove(curr_label_index_path_data.c_str()) != 0)
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
// 1. handle cmdline inputs
|
||||
std::string data_type;
|
||||
path input_data_path, final_index_path_prefix, label_data_path;
|
||||
std::string universal_label;
|
||||
uint32_t num_threads, R, L, stitched_R;
|
||||
float alpha;
|
||||
|
||||
auto index_timer = std::chrono::high_resolution_clock::now();
|
||||
handle_args(argc, argv, data_type, input_data_path, final_index_path_prefix, label_data_path, universal_label,
|
||||
num_threads, R, L, stitched_R, alpha);
|
||||
|
||||
path labels_file_to_use = final_index_path_prefix + "_label_formatted.txt";
|
||||
path labels_map_file = final_index_path_prefix + "_labels_map.txt";
|
||||
|
||||
convert_labels_string_to_int(label_data_path, labels_file_to_use, labels_map_file, universal_label);
|
||||
|
||||
// 2. parse label file and create necessary data structures
|
||||
std::vector<label_set> point_ids_to_labels;
|
||||
tsl::robin_map<std::string, uint32_t> labels_to_number_of_points;
|
||||
label_set all_labels;
|
||||
|
||||
std::tie(point_ids_to_labels, labels_to_number_of_points, all_labels) =
|
||||
diskann::parse_label_file(labels_file_to_use, universal_label);
|
||||
|
||||
// 3. for each label, make a separate data file
|
||||
tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id_map;
|
||||
uint32_t total_number_of_points = (uint32_t)point_ids_to_labels.size();
|
||||
|
||||
#ifndef _WINDOWS
|
||||
if (data_type == "uint8")
|
||||
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<uint8_t>(
|
||||
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
else if (data_type == "int8")
|
||||
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<int8_t>(
|
||||
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
else if (data_type == "float")
|
||||
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<float>(
|
||||
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
else
|
||||
throw;
|
||||
#else
|
||||
if (data_type == "uint8")
|
||||
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<uint8_t>(
|
||||
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
else if (data_type == "int8")
|
||||
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<int8_t>(
|
||||
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
else if (data_type == "float")
|
||||
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<float>(
|
||||
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
else
|
||||
throw;
|
||||
#endif
|
||||
|
||||
// 4. for each created data file, create a vanilla diskANN index
|
||||
if (data_type == "uint8")
|
||||
diskann::generate_label_indices<uint8_t>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
|
||||
num_threads);
|
||||
else if (data_type == "int8")
|
||||
diskann::generate_label_indices<int8_t>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
|
||||
num_threads);
|
||||
else if (data_type == "float")
|
||||
diskann::generate_label_indices<float>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
|
||||
num_threads);
|
||||
else
|
||||
throw;
|
||||
|
||||
// 5. "stitch" the indices together
|
||||
std::vector<std::vector<uint32_t>> stitched_graph;
|
||||
tsl::robin_map<std::string, uint32_t> label_entry_points;
|
||||
uint64_t stitched_graph_size;
|
||||
|
||||
if (data_type == "uint8")
|
||||
std::tie(stitched_graph, stitched_graph_size) =
|
||||
stitch_label_indices<uint8_t>(final_index_path_prefix, total_number_of_points, all_labels,
|
||||
labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
|
||||
else if (data_type == "int8")
|
||||
std::tie(stitched_graph, stitched_graph_size) =
|
||||
stitch_label_indices<int8_t>(final_index_path_prefix, total_number_of_points, all_labels,
|
||||
labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
|
||||
else if (data_type == "float")
|
||||
std::tie(stitched_graph, stitched_graph_size) =
|
||||
stitch_label_indices<float>(final_index_path_prefix, total_number_of_points, all_labels,
|
||||
labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
|
||||
else
|
||||
throw;
|
||||
path full_index_path_prefix = final_index_path_prefix + "_full";
|
||||
// 5a. save the stitched graph to disk
|
||||
save_full_index(full_index_path_prefix, input_data_path, stitched_graph_size, stitched_graph, label_entry_points,
|
||||
universal_label, labels_file_to_use);
|
||||
|
||||
// 6. run a prune on the stitched index, and save to disk
|
||||
if (data_type == "uint8")
|
||||
prune_and_save<uint8_t>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
|
||||
stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
|
||||
else if (data_type == "int8")
|
||||
prune_and_save<int8_t>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
|
||||
stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
|
||||
else if (data_type == "float")
|
||||
prune_and_save<float>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
|
||||
stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
|
||||
else
|
||||
throw;
|
||||
|
||||
std::chrono::duration<double> index_time = std::chrono::high_resolution_clock::now() - index_timer;
|
||||
std::cout << "pruned/stitched graph generated in " << index_time.count() << " seconds" << std::endl;
|
||||
|
||||
clean_up_artifacts(input_data_path, final_index_path_prefix, all_labels);
|
||||
}
|
||||
46
packages/leann-backend-diskann/third_party/DiskANN/apps/python/README.md
vendored
Normal file
46
packages/leann-backend-diskann/third_party/DiskANN/apps/python/README.md
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
<!-- Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
Licensed under the MIT license. -->
|
||||
|
||||
# Integration Tests
|
||||
The following tests use Python to prepare, run, verify, and tear down the REST API services.
|
||||
|
||||
We do make use of the built-in `unittest` library, but only to take advantage of its test reporting.
|
||||
|
||||
These are decidedly **not** _unit_ tests. These are end to end integration tests.
|
||||
|
||||
## Caveats
|
||||
This has only been tested or built for Linux, though we have written platform agnostic Python for the smoke test
|
||||
(i.e. using `os.path.join`, etc)
|
||||
|
||||
It has been tested on Python 3.9 and 3.10, but should work on Python 3.6+.
|
||||
|
||||
## How to Run
|
||||
|
||||
First, build the DiskANN RestAPI code; see $REPOSITORY_ROOT/workflows/rest_api.md for detailed instructions.
|
||||
|
||||
```bash
|
||||
cd tests/python
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
|
||||
export DISKANN_BUILD_DIR=/path/to/your/diskann/build
|
||||
python -m unittest
|
||||
```
|
||||
|
||||
## Smoke Test Failed, Now What?
|
||||
The smoke test written takes advantage of temporary directories that are only valid during the
|
||||
lifetime of the test. The contents of these directories include:
|
||||
- Randomized vectors (first in tsv, then bin form) used to build the PQFlashIndex
|
||||
- The PQFlashIndex files
|
||||
|
||||
It is useful to keep these around. By setting some environment variables, you can control whether an ephemeral,
|
||||
temporary directory is used (and deleted on test completion), or left as an exercise for the developer to
|
||||
clean up.
|
||||
|
||||
The valid environment variables are:
|
||||
- `DISKANN_REST_TEST_WORKING_DIR` (example: `$USER/DiskANNRestTest`)
|
||||
- If this is specified, it **must exist** and **must be writeable**. Any existing files will be clobbered.
|
||||
- `DISKANN_REST_SERVER` (example: `http://127.0.0.1:10067`)
|
||||
- Note that if this is set, no data will be generated, nor will a server be started; it is presumed you have done
|
||||
all the work of creating and starting the REST server prior to running the test, which then just submits requests against it.
|
||||
0
packages/leann-backend-diskann/third_party/DiskANN/apps/python/restapi/__init__.py
vendored
Normal file
0
packages/leann-backend-diskann/third_party/DiskANN/apps/python/restapi/__init__.py
vendored
Normal file
67
packages/leann-backend-diskann/third_party/DiskANN/apps/python/restapi/disk_ann_util.py
vendored
Normal file
67
packages/leann-backend-diskann/third_party/DiskANN/apps/python/restapi/disk_ann_util.py
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
|
||||
def output_vectors(
    diskann_build_path: str,
    temporary_file_path: str,
    vectors: np.ndarray,
    timeout: int = 60
) -> str:
    """
    Write `vectors` to disk as a TSV file and convert it to the binary format
    with the `tsv_to_bin` utility found under the DiskANN build directory.

    :param diskann_build_path: build directory; the converter is expected at
        `<diskann_build_path>/apps/utils/tsv_to_bin`
    :param temporary_file_path: directory that receives `vectors.tsv` and
        `vectors.bin`
    :param vectors: two-dimensional array, one vector per row
    :param timeout: seconds the converter process may run
    :return: path of the generated `vectors.bin`
    :raises Exception: if the converter exits with a non-zero return code
    """
    tsv_path = os.path.join(temporary_file_path, "vectors.tsv")
    with open(tsv_path, "w") as tsv_file:
        for row in vectors:
            tsv_file.write("\t".join(map(str, row)) + "\n")

    # there is probably a clever way to have numpy write out C++ friendly floats, so feel free to remove this in
    # favor of something more sane later
    bin_path = os.path.join(temporary_file_path, "vectors.bin")
    converter_path = os.path.join(diskann_build_path, "apps", "utils", "tsv_to_bin")

    point_count, dimension_count = vectors.shape
    # tsv_to_bin takes the dimension before the number of points
    command = [
        converter_path,
        "float",
        tsv_path,
        bin_path,
        str(dimension_count),
        str(point_count),
    ]
    outcome = subprocess.run(command, timeout=timeout)
    if outcome.returncode != 0:
        raise Exception(f"Unable to convert tsv to binary using tsv_to_bin, completed_process: {outcome}")
    return bin_path
|
||||
|
||||
|
||||
def build_ssd_index(
    diskann_build_path: str,
    temporary_file_path: str,
    vectors: np.ndarray,
    per_process_timeout: int = 60  # this may not be long enough if you're doing something larger
):
    """
    Build a disk index for `vectors` inside `temporary_file_path`.

    The vectors are first written out with `output_vectors`, then
    `<diskann_build_path>/apps/build_disk_index` is run with the index path
    prefix `<temporary_file_path>/smoke_test`.

    The builder's stdout and stderr are captured rather than echoed, so that
    they can be included in the exception raised on failure. Previously the
    message interpolated `completed.stdout` / `completed.stderr` without
    capturing them, which always rendered as `None`.

    :param diskann_build_path: DiskANN build directory
    :param temporary_file_path: working directory for intermediate and index files
    :param vectors: two-dimensional array, one vector per row
    :param per_process_timeout: seconds each spawned process may run
    :raises Exception: if the index builder exits with a non-zero return code
    """
    vectors_as_bin_path = output_vectors(diskann_build_path, temporary_file_path, vectors, timeout=per_process_timeout)

    ssd_builder_path = os.path.join(diskann_build_path, "apps", "build_disk_index")
    args = [
        ssd_builder_path,
        "--data_type", "float",
        "--dist_fn", "l2",
        "--data_path", vectors_as_bin_path,
        "--index_path_prefix", os.path.join(temporary_file_path, "smoke_test"),
        "-R", "64",
        "-L", "100",
        "--search_DRAM_budget", "1",
        "--build_DRAM_budget", "1",
        "--num_threads", "1",
        "--PQ_disk_bytes", "0"
    ]
    # PIPE + universal_newlines (rather than capture_output/text) keeps this
    # usable on Python 3.6.
    completed = subprocess.run(
        args,
        timeout=per_process_timeout,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True,
    )

    if completed.returncode != 0:
        command_run = " ".join(args)
        raise Exception(f"Unable to build a disk index with the command: '{command_run}'\ncompleted_process: {completed}\nstdout: {completed.stdout}\nstderr: {completed.stderr}")
    # index is now built inside of temporary_file_path
|
||||
379
packages/leann-backend-diskann/third_party/DiskANN/apps/range_search_disk_index.cpp
vendored
Normal file
379
packages/leann-backend-diskann/third_party/DiskANN/apps/range_search_disk_index.cpp
vendored
Normal file
@@ -0,0 +1,379 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <atomic>
|
||||
#include <cstring>
|
||||
#include <iomanip>
|
||||
#include <omp.h>
|
||||
#include <set>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "index.h"
|
||||
#include "disk_utils.h"
|
||||
#include "math_utils.h"
|
||||
#include "memory_mapper.h"
|
||||
#include "pq_flash_index.h"
|
||||
#include "partition.h"
|
||||
#include "timer.h"
|
||||
#include "program_options_utils.hpp"
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
#include "linux_aligned_file_reader.h"
|
||||
#else
|
||||
#ifdef USE_BING_INFRA
|
||||
#include "bing_aligned_file_reader.h"
|
||||
#else
|
||||
#include "windows_aligned_file_reader.h"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
#define WARMUP false
|
||||
|
||||
void print_stats(std::string category, std::vector<float> percentiles, std::vector<float> results)
|
||||
{
|
||||
diskann::cout << std::setw(20) << category << ": " << std::flush;
|
||||
for (uint32_t s = 0; s < percentiles.size(); s++)
|
||||
{
|
||||
diskann::cout << std::setw(8) << percentiles[s] << "%";
|
||||
}
|
||||
diskann::cout << std::endl;
|
||||
diskann::cout << std::setw(22) << " " << std::flush;
|
||||
for (uint32_t s = 0; s < percentiles.size(); s++)
|
||||
{
|
||||
diskann::cout << std::setw(9) << results[s];
|
||||
}
|
||||
diskann::cout << std::endl;
|
||||
}
|
||||
|
||||
template <typename T, typename LabelT = uint32_t>
|
||||
int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix, const std::string &query_file,
|
||||
std::string >_file, const uint32_t num_threads, const float search_range,
|
||||
const uint32_t beamwidth, const uint32_t num_nodes_to_cache, const std::vector<uint32_t> &Lvec)
|
||||
{
|
||||
std::string pq_prefix = index_path_prefix + "_pq";
|
||||
std::string disk_index_file = index_path_prefix + "_disk.index";
|
||||
std::string warmup_query_file = index_path_prefix + "_sample_data.bin";
|
||||
|
||||
diskann::cout << "Search parameters: #threads: " << num_threads << ", ";
|
||||
if (beamwidth <= 0)
|
||||
diskann::cout << "beamwidth to be optimized for each L value" << std::endl;
|
||||
else
|
||||
diskann::cout << " beamwidth: " << beamwidth << std::endl;
|
||||
|
||||
// load query bin
|
||||
T *query = nullptr;
|
||||
std::vector<std::vector<uint32_t>> groundtruth_ids;
|
||||
size_t query_num, query_dim, query_aligned_dim, gt_num;
|
||||
diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);
|
||||
|
||||
bool calc_recall_flag = false;
|
||||
if (gt_file != std::string("null") && file_exists(gt_file))
|
||||
{
|
||||
diskann::load_range_truthset(gt_file, groundtruth_ids,
|
||||
gt_num); // use for range search type of truthset
|
||||
// diskann::prune_truthset_for_range(gt_file, search_range,
|
||||
// groundtruth_ids, gt_num); // use for traditional truthset
|
||||
if (gt_num != query_num)
|
||||
{
|
||||
diskann::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
calc_recall_flag = true;
|
||||
}
|
||||
|
||||
std::shared_ptr<AlignedFileReader> reader = nullptr;
|
||||
#ifdef _WINDOWS
|
||||
#ifndef USE_BING_INFRA
|
||||
reader.reset(new WindowsAlignedFileReader());
|
||||
#else
|
||||
reader.reset(new diskann::BingAlignedFileReader());
|
||||
#endif
|
||||
#else
|
||||
reader.reset(new LinuxAlignedFileReader());
|
||||
#endif
|
||||
|
||||
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> _pFlashIndex(
|
||||
new diskann::PQFlashIndex<T, LabelT>(reader, metric));
|
||||
|
||||
int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str());
|
||||
|
||||
if (res != 0)
|
||||
{
|
||||
return res;
|
||||
}
|
||||
// cache bfs levels
|
||||
std::vector<uint32_t> node_list;
|
||||
diskann::cout << "Caching " << num_nodes_to_cache << " BFS nodes around medoid(s)" << std::endl;
|
||||
_pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list);
|
||||
// _pFlashIndex->generate_cache_list_from_sample_queries(
|
||||
// warmup_query_file, 15, 6, num_nodes_to_cache, num_threads,
|
||||
// node_list);
|
||||
_pFlashIndex->load_cache_list(node_list);
|
||||
node_list.clear();
|
||||
node_list.shrink_to_fit();
|
||||
|
||||
omp_set_num_threads(num_threads);
|
||||
|
||||
uint64_t warmup_L = 20;
|
||||
uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0;
|
||||
T *warmup = nullptr;
|
||||
|
||||
if (WARMUP)
|
||||
{
|
||||
if (file_exists(warmup_query_file))
|
||||
{
|
||||
diskann::load_aligned_bin<T>(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim);
|
||||
}
|
||||
else
|
||||
{
|
||||
warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads);
|
||||
warmup_dim = query_dim;
|
||||
warmup_aligned_dim = query_aligned_dim;
|
||||
diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T));
|
||||
std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T));
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<> dis(-128, 127);
|
||||
for (uint32_t i = 0; i < warmup_num; i++)
|
||||
{
|
||||
for (uint32_t d = 0; d < warmup_dim; d++)
|
||||
{
|
||||
warmup[i * warmup_aligned_dim + d] = (T)dis(gen);
|
||||
}
|
||||
}
|
||||
}
|
||||
diskann::cout << "Warming up index... " << std::flush;
|
||||
std::vector<uint64_t> warmup_result_ids_64(warmup_num, 0);
|
||||
std::vector<float> warmup_result_dists(warmup_num, 0);
|
||||
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (int64_t i = 0; i < (int64_t)warmup_num; i++)
|
||||
{
|
||||
_pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L,
|
||||
warmup_result_ids_64.data() + (i * 1),
|
||||
warmup_result_dists.data() + (i * 1), 4);
|
||||
}
|
||||
diskann::cout << "..done" << std::endl;
|
||||
}
|
||||
|
||||
diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
|
||||
diskann::cout.precision(2);
|
||||
|
||||
std::string recall_string = "Recall@rng=" + std::to_string(search_range);
|
||||
diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" << std::setw(16) << "QPS" << std::setw(16)
|
||||
<< "Mean Latency" << std::setw(16) << "99.9 Latency" << std::setw(16) << "Mean IOs" << std::setw(16)
|
||||
<< "CPU (s)";
|
||||
if (calc_recall_flag)
|
||||
{
|
||||
diskann::cout << std::setw(16) << recall_string << std::endl;
|
||||
}
|
||||
else
|
||||
diskann::cout << std::endl;
|
||||
diskann::cout << "==============================================================="
|
||||
"==========================================="
|
||||
<< std::endl;
|
||||
|
||||
std::vector<std::vector<std::vector<uint32_t>>> query_result_ids(Lvec.size());
|
||||
|
||||
uint32_t optimized_beamwidth = 2;
|
||||
uint32_t max_list_size = 10000;
|
||||
|
||||
for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++)
|
||||
{
|
||||
uint32_t L = Lvec[test_id];
|
||||
|
||||
if (beamwidth <= 0)
|
||||
{
|
||||
optimized_beamwidth =
|
||||
optimize_beamwidth(_pFlashIndex, warmup, warmup_num, warmup_aligned_dim, L, optimized_beamwidth);
|
||||
}
|
||||
else
|
||||
optimized_beamwidth = beamwidth;
|
||||
|
||||
query_result_ids[test_id].clear();
|
||||
query_result_ids[test_id].resize(query_num);
|
||||
|
||||
diskann::QueryStats *stats = new diskann::QueryStats[query_num];
|
||||
|
||||
auto s = std::chrono::high_resolution_clock::now();
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (int64_t i = 0; i < (int64_t)query_num; i++)
|
||||
{
|
||||
std::vector<uint64_t> indices;
|
||||
std::vector<float> distances;
|
||||
uint32_t res_count =
|
||||
_pFlashIndex->range_search(query + (i * query_aligned_dim), search_range, L, max_list_size, indices,
|
||||
distances, optimized_beamwidth, stats + i);
|
||||
query_result_ids[test_id][i].reserve(res_count);
|
||||
query_result_ids[test_id][i].resize(res_count);
|
||||
for (uint32_t idx = 0; idx < res_count; idx++)
|
||||
query_result_ids[test_id][i][idx] = (uint32_t)indices[idx];
|
||||
}
|
||||
auto e = std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<double> diff = e - s;
|
||||
auto qps = (1.0 * query_num) / (1.0 * diff.count());
|
||||
|
||||
auto mean_latency = diskann::get_mean_stats<float>(
|
||||
stats, query_num, [](const diskann::QueryStats &stats) { return stats.total_us; });
|
||||
|
||||
auto latency_999 = diskann::get_percentile_stats<float>(
|
||||
stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; });
|
||||
|
||||
auto mean_ios = diskann::get_mean_stats<uint32_t>(stats, query_num,
|
||||
[](const diskann::QueryStats &stats) { return stats.n_ios; });
|
||||
|
||||
double mean_cpuus = diskann::get_mean_stats<float>(
|
||||
stats, query_num, [](const diskann::QueryStats &stats) { return stats.cpu_us; });
|
||||
|
||||
double recall = 0;
|
||||
double ratio_of_sums = 0;
|
||||
if (calc_recall_flag)
|
||||
{
|
||||
recall =
|
||||
diskann::calculate_range_search_recall((uint32_t)query_num, groundtruth_ids, query_result_ids[test_id]);
|
||||
|
||||
uint32_t total_true_positive = 0;
|
||||
uint32_t total_positive = 0;
|
||||
for (uint32_t i = 0; i < query_num; i++)
|
||||
{
|
||||
total_true_positive += (uint32_t)query_result_ids[test_id][i].size();
|
||||
total_positive += (uint32_t)groundtruth_ids[i].size();
|
||||
}
|
||||
|
||||
ratio_of_sums = (1.0 * total_true_positive) / (1.0 * total_positive);
|
||||
}
|
||||
|
||||
diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps
|
||||
<< std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios
|
||||
<< std::setw(16) << mean_cpuus;
|
||||
if (calc_recall_flag)
|
||||
{
|
||||
diskann::cout << std::setw(16) << recall << "," << ratio_of_sums << std::endl;
|
||||
}
|
||||
else
|
||||
diskann::cout << std::endl;
|
||||
}
|
||||
|
||||
diskann::cout << "Done searching. " << std::endl;
|
||||
|
||||
diskann::aligned_free(query);
|
||||
if (warmup != nullptr)
|
||||
diskann::aligned_free(warmup);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file;
|
||||
uint32_t num_threads, W, num_nodes_to_cache;
|
||||
std::vector<uint32_t> Lvec;
|
||||
float range;
|
||||
|
||||
po::options_description desc{program_options_utils::make_program_description(
|
||||
"range_search_disk_index", "Searches disk DiskANN indexes using ranges")};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
// Required parameters
|
||||
po::options_description required_configs("Required");
|
||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
||||
required_configs.add_options()("query_file", po::value<std::string>(&query_file)->required(),
|
||||
program_options_utils::QUERY_FILE_DESCRIPTION);
|
||||
required_configs.add_options()("search_list,L",
|
||||
po::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
|
||||
program_options_utils::SEARCH_LIST_DESCRIPTION);
|
||||
required_configs.add_options()("range_threshold,K", po::value<float>(&range)->required(),
|
||||
"Number of neighbors to be returned");
|
||||
|
||||
// Optional parameters
|
||||
po::options_description optional_configs("Optional");
|
||||
optional_configs.add_options()("num_threads,T",
|
||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
||||
optional_configs.add_options()("gt_file", po::value<std::string>(>_file)->default_value(std::string("null")),
|
||||
program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION);
|
||||
optional_configs.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
|
||||
program_options_utils::NUMBER_OF_NODES_TO_CACHE);
|
||||
optional_configs.add_options()("beamwidth,W", po::value<uint32_t>(&W)->default_value(2),
|
||||
program_options_utils::BEAMWIDTH);
|
||||
|
||||
// Merge required and optional parameters
|
||||
desc.add(required_configs).add(optional_configs);
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("mips"))
|
||||
{
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
}
|
||||
else if (dist_fn == std::string("l2"))
|
||||
{
|
||||
metric = diskann::Metric::L2;
|
||||
}
|
||||
else if (dist_fn == std::string("cosine"))
|
||||
{
|
||||
metric = diskann::Metric::COSINE;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Unsupported distance function. Currently only L2/ Inner "
|
||||
"Product/Cosine are supported."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if ((data_type != std::string("float")) && (metric == diskann::Metric::INNER_PRODUCT))
|
||||
{
|
||||
std::cout << "Currently support only floating point data for Inner Product." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
if (data_type == std::string("float"))
|
||||
return search_disk_index<float>(metric, index_path_prefix, query_file, gt_file, num_threads, range, W,
|
||||
num_nodes_to_cache, Lvec);
|
||||
else if (data_type == std::string("int8"))
|
||||
return search_disk_index<int8_t>(metric, index_path_prefix, query_file, gt_file, num_threads, range, W,
|
||||
num_nodes_to_cache, Lvec);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return search_disk_index<uint8_t>(metric, index_path_prefix, query_file, gt_file, num_threads, range, W,
|
||||
num_nodes_to_cache, Lvec);
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Index search failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
40
packages/leann-backend-diskann/third_party/DiskANN/apps/restapi/CMakeLists.txt
vendored
Normal file
40
packages/leann-backend-diskann/third_party/DiskANN/apps/restapi/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)

# The three REST servers are each built from a single <name>.cpp file and
# share the same link settings, so they are declared in one loop.
foreach(server_target inmem_server ssd_server multiple_ssdindex_server)
    add_executable(${server_target} ${server_target}.cpp)
    if(MSVC)
        target_link_options(${server_target} PRIVATE /MACHINE:x64)
        target_link_libraries(${server_target} debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
        target_link_libraries(${server_target} optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
    else()
        target_link_libraries(${server_target} ${PROJECT_NAME} aio -ltcmalloc -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
    endif()
endforeach()

# The client links neither aio nor tcmalloc on non-MSVC builds, so it is
# declared on its own.
add_executable(client client.cpp)
if(MSVC)
    target_link_options(client PRIVATE /MACHINE:x64)
    target_link_libraries(client debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
    target_link_libraries(client optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
else()
    target_link_libraries(client ${PROJECT_NAME} -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
endif()
|
||||
124
packages/leann-backend-diskann/third_party/DiskANN/apps/restapi/client.cpp
vendored
Normal file
124
packages/leann-backend-diskann/third_party/DiskANN/apps/restapi/client.cpp
vendored
Normal file
@@ -0,0 +1,124 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <ctime>
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <string>
|
||||
#include <cstdlib>
|
||||
#include <codecvt>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include <cpprest/http_client.h>
|
||||
#include <restapi/common.h>
|
||||
|
||||
using namespace web;
|
||||
using namespace web::http;
|
||||
using namespace web::http::client;
|
||||
|
||||
using namespace diskann;
|
||||
namespace po = boost::program_options;
|
||||
|
||||
template <typename T>
|
||||
void query_loop(const std::string &ip_addr_port, const std::string &query_file, const unsigned nq, const unsigned Ls,
|
||||
const unsigned k_value)
|
||||
{
|
||||
web::http::client::http_client client(U(ip_addr_port));
|
||||
|
||||
T *data;
|
||||
size_t npts = 1, ndims = 128, rounded_dim = 128;
|
||||
diskann::load_aligned_bin<T>(query_file, data, npts, ndims, rounded_dim);
|
||||
|
||||
for (unsigned i = 0; i < nq; ++i)
|
||||
{
|
||||
T *vec = data + i * rounded_dim;
|
||||
web::http::http_request http_query(methods::POST);
|
||||
web::json::value queryJson = web::json::value::object();
|
||||
queryJson[QUERY_ID_KEY] = i;
|
||||
queryJson[K_KEY] = k_value;
|
||||
queryJson[L_KEY] = Ls;
|
||||
for (size_t i = 0; i < ndims; ++i)
|
||||
{
|
||||
queryJson[VECTOR_KEY][i] = web::json::value::number(vec[i]);
|
||||
}
|
||||
http_query.set_body(queryJson);
|
||||
|
||||
client.request(http_query)
|
||||
.then([](web::http::http_response response) -> pplx::task<utility::string_t> {
|
||||
if (response.status_code() == status_codes::OK)
|
||||
{
|
||||
return response.extract_string();
|
||||
}
|
||||
std::cerr << "Query failed" << std::endl;
|
||||
return pplx::task_from_result(utility::string_t());
|
||||
})
|
||||
.then([](pplx::task<utility::string_t> previousTask) {
|
||||
try
|
||||
{
|
||||
std::cout << previousTask.get() << std::endl;
|
||||
}
|
||||
catch (http_exception const &e)
|
||||
{
|
||||
std::wcout << e.what() << std::endl;
|
||||
}
|
||||
})
|
||||
.wait();
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
std::string data_type, query_file, address;
|
||||
uint32_t num_queries;
|
||||
uint32_t l_search, k_value;
|
||||
|
||||
po::options_description desc{"Arguments"};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
||||
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
|
||||
desc.add_options()("query_file", po::value<std::string>(&query_file)->required(),
|
||||
"File containing the queries to search");
|
||||
desc.add_options()("num_queries,Q", po::value<uint32_t>(&num_queries)->required(),
|
||||
"Number of queries to search");
|
||||
desc.add_options()("l_search", po::value<uint32_t>(&l_search)->required(), "Value of L");
|
||||
desc.add_options()("k_value,K", po::value<uint32_t>(&k_value)->default_value(10), "Value of K (default 10)");
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type == std::string("float"))
|
||||
{
|
||||
query_loop<float>(address, query_file, num_queries, l_search, k_value);
|
||||
}
|
||||
else if (data_type == std::string("int8"))
|
||||
{
|
||||
query_loop<int8_t>(address, query_file, num_queries, l_search, k_value);
|
||||
}
|
||||
else if (data_type == std::string("uint8"))
|
||||
{
|
||||
query_loop<uint8_t>(address, query_file, num_queries, l_search, k_value);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported type " << argv[2] << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
138
packages/leann-backend-diskann/third_party/DiskANN/apps/restapi/inmem_server.cpp
vendored
Normal file
138
packages/leann-backend-diskann/third_party/DiskANN/apps/restapi/inmem_server.cpp
vendored
Normal file
@@ -0,0 +1,138 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <ctime>
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
#include <string>
|
||||
#include <cstdlib>
|
||||
#include <codecvt>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include <restapi/server.h>
|
||||
|
||||
using namespace diskann;
|
||||
namespace po = boost::program_options;
|
||||
|
||||
std::unique_ptr<Server> g_httpServer(nullptr);
|
||||
std::vector<std::unique_ptr<diskann::BaseSearch>> g_inMemorySearch;
|
||||
|
||||
/**
 * Creates the global HTTP server for `address`, backed by g_inMemorySearch,
 * opens it and blocks until the open has completed.
 *
 * @param address     address the server is created for
 * @param typestring  forwarded unchanged to the Server constructor
 */
void setup(const utility::string_t &address, const std::string &typestring)
{
    web::http::uri_builder uriBldr(address);
    auto uri = uriBldr.to_uri();

    std::cout << "Attempting to start server on " << uri.to_string() << std::endl;

    g_httpServer = std::make_unique<Server>(uri, g_inMemorySearch, typestring);
    std::cout << "Created a server object" << std::endl;

    g_httpServer->open().wait();
    // A plain literal is used on purpose: `U"..."` is a char32_t string literal,
    // which the stream prints as a pointer value instead of the text.
    ucout << "Listening for requests on: " << address << std::endl;
}
|
||||
|
||||
/**
 * Closes the global HTTP server and blocks until the close has completed.
 * Does nothing when no server object exists, which can be the case when this
 * is called from an error handler after setup() failed early.
 *
 * @param address  not used by this function
 */
void teardown(const utility::string_t &address)
{
    (void)address;
    if (g_httpServer)
    {
        g_httpServer->close().wait();
    }
}
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
std::string data_type, index_file, data_file, address, dist_fn, tags_file;
|
||||
uint32_t num_threads;
|
||||
uint32_t l_search;
|
||||
|
||||
po::options_description desc{"Arguments"};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
||||
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
|
||||
desc.add_options()("data_file", po::value<std::string>(&data_file)->required(),
|
||||
"File containing the data found in the index");
|
||||
desc.add_options()("index_path_prefix", po::value<std::string>(&index_file)->required(),
|
||||
"Path prefix for saving index file components");
|
||||
desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->required(),
|
||||
"Number of threads used for building index");
|
||||
desc.add_options()("l_search", po::value<uint32_t>(&l_search)->required(), "Value of L");
|
||||
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
|
||||
"distance function <l2/mips>");
|
||||
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
|
||||
"Tags file location");
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
metric = diskann::Metric::L2;
|
||||
else if (dist_fn == std::string("mips"))
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
else
|
||||
{
|
||||
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type == std::string("float"))
|
||||
{
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::InMemorySearch<float>(data_file, index_file, tags_file, metric, num_threads, l_search));
|
||||
g_inMemorySearch.push_back(std::move(searcher));
|
||||
}
|
||||
else if (data_type == std::string("int8"))
|
||||
{
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::InMemorySearch<int8_t>(data_file, index_file, tags_file, metric, num_threads, l_search));
|
||||
g_inMemorySearch.push_back(std::move(searcher));
|
||||
}
|
||||
else if (data_type == std::string("uint8"))
|
||||
{
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::InMemorySearch<uint8_t>(data_file, index_file, tags_file, metric, num_threads, l_search));
|
||||
g_inMemorySearch.push_back(std::move(searcher));
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported data type " << argv[2] << std::endl;
|
||||
}
|
||||
|
||||
while (1)
|
||||
{
|
||||
try
|
||||
{
|
||||
setup(address, data_type);
|
||||
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
|
||||
std::string line;
|
||||
std::getline(std::cin, line);
|
||||
if (line == "exit")
|
||||
{
|
||||
teardown(address);
|
||||
g_httpServer->close().wait();
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << "Exception occurred: " << ex.what() << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
std::cerr << "Unknown exception occurreed" << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
}
|
||||
}
|
||||
83
packages/leann-backend-diskann/third_party/DiskANN/apps/restapi/main.cpp
vendored
Normal file
83
packages/leann-backend-diskann/third_party/DiskANN/apps/restapi/main.cpp
vendored
Normal file
@@ -0,0 +1,83 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <restapi/server.h>
|
||||
#include <restapi/in_memory_search.h>
|
||||
#include <codecvt>
|
||||
#include <iostream>
|
||||
|
||||
std::unique_ptr<Server> g_httpServer(nullptr);
|
||||
std::unique_ptr<diskann::InMemorySearch> g_inMemorySearch(nullptr);
|
||||
|
||||
/**
 * Creates the global HTTP server for `address`, backed by g_inMemorySearch,
 * opens it and blocks until the open has completed.
 *
 * @param address  address the server is created for
 */
void setup(const utility::string_t &address)
{
    web::http::uri_builder uriBldr(address);
    auto uri = uriBldr.to_uri();

    std::wcout << L"Attempting to start server on " << uri.to_string() << std::endl;

    g_httpServer = std::make_unique<Server>(uri, g_inMemorySearch);
    g_httpServer->open().wait();

    // A plain literal is used on purpose: `U"..."` is a char32_t string literal,
    // which the stream prints as a pointer value instead of the text.
    ucout << "Listening for requests on: " << address << std::endl;
}
|
||||
|
||||
/**
 * Closes the global HTTP server and blocks until the close has completed.
 * Does nothing when no server object exists, which can be the case when this
 * is called from an error handler after setup() failed early.
 *
 * @param address  not used by this function
 */
void teardown(const utility::string_t &address)
{
    (void)address;
    if (g_httpServer)
    {
        g_httpServer->close().wait();
    }
}
|
||||
|
||||
void loadIndex(const char *indexFile, const char *baseFile, const char *idsFile)
|
||||
{
|
||||
auto nsgSearch = new diskann::InMemorySearch(baseFile, indexFile, idsFile, diskann::L2);
|
||||
g_inMemorySearch = std::unique_ptr<diskann::InMemorySearch>(nsgSearch);
|
||||
}
|
||||
|
||||
std::wstring getHostingAddress(const char *hostNameAndPort)
|
||||
{
|
||||
wchar_t buffer[4096];
|
||||
mbstowcs_s(nullptr, buffer, sizeof(buffer) / sizeof(buffer[0]), hostNameAndPort,
|
||||
sizeof(buffer) / sizeof(buffer[0]));
|
||||
return std::wstring(buffer);
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
if (argc != 5)
|
||||
{
|
||||
std::cout << "Usage: nsg_server <ip_addr_and_port> <index_file> "
|
||||
"<base_file> <ids_file> "
|
||||
<< std::endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
auto address = getHostingAddress(argv[1]);
|
||||
loadIndex(argv[2], argv[3], argv[4]);
|
||||
while (1)
|
||||
{
|
||||
try
|
||||
{
|
||||
setup(address);
|
||||
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
|
||||
std::string line;
|
||||
std::getline(std::cin, line);
|
||||
if (line == "exit")
|
||||
{
|
||||
teardown(address);
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << "Exception occurred: " << ex.what() << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
std::cerr << "Unknown exception occurreed" << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
}
|
||||
}
|
||||
182
packages/leann-backend-diskann/third_party/DiskANN/apps/restapi/multiple_ssdindex_server.cpp
vendored
Normal file
182
packages/leann-backend-diskann/third_party/DiskANN/apps/restapi/multiple_ssdindex_server.cpp
vendored
Normal file
@@ -0,0 +1,182 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <ctime>
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
#include <string>
|
||||
#include <cstdlib>
|
||||
#include <codecvt>
|
||||
#include <boost/program_options.hpp>
|
||||
#include <omp.h>
|
||||
|
||||
#include <restapi/server.h>
|
||||
|
||||
using namespace diskann;
|
||||
namespace po = boost::program_options;
|
||||
|
||||
std::unique_ptr<Server> g_httpServer(nullptr);
|
||||
std::vector<std::unique_ptr<diskann::BaseSearch>> g_ssdSearch;
|
||||
|
||||
/**
 * Creates the global HTTP server for `address`, backed by g_ssdSearch, opens
 * it and blocks until the open has completed.
 *
 * @param address     address the server is created for
 * @param typestring  forwarded unchanged to the Server constructor
 */
void setup(const utility::string_t &address, const std::string &typestring)
{
    web::http::uri_builder uriBldr(address);
    auto uri = uriBldr.to_uri();

    std::cout << "Attempting to start server on " << uri.to_string() << std::endl;

    g_httpServer = std::make_unique<Server>(uri, g_ssdSearch, typestring);
    std::cout << "Created a server object" << std::endl;

    g_httpServer->open().wait();
    // A plain literal is used on purpose: `U"..."` is a char32_t string literal,
    // which the stream prints as a pointer value instead of the text.
    ucout << "Listening for requests on: " << address << std::endl;
}
|
||||
|
||||
/**
 * Closes the global HTTP server and blocks until the close has completed.
 * Does nothing when no server object exists, which can be the case when this
 * is called from an error handler after setup() failed early.
 *
 * @param address  not used by this function
 */
void teardown(const utility::string_t &address)
{
    (void)address;
    if (g_httpServer)
    {
        g_httpServer->close().wait();
    }
}
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
std::string data_type, index_prefix_paths, address, dist_fn, tags_file;
|
||||
uint32_t num_nodes_to_cache;
|
||||
uint32_t num_threads;
|
||||
|
||||
po::options_description desc{"Arguments"};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
|
||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
||||
desc.add_options()("index_prefix_paths", po::value<std::string>(&index_prefix_paths)->required(),
|
||||
"Path prefix for loading index file components");
|
||||
desc.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
|
||||
"Number of nodes to cache during search");
|
||||
desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
"Number of threads used for building index (defaults to "
|
||||
"omp_get_num_procs())");
|
||||
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
|
||||
"distance function <l2/mips>");
|
||||
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
|
||||
"Tags file location");
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
metric = diskann::Metric::L2;
|
||||
else if (dist_fn == std::string("mips"))
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
else
|
||||
{
|
||||
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> index_tag_paths;
|
||||
std::ifstream index_in(index_prefix_paths);
|
||||
if (!index_in.is_open())
|
||||
{
|
||||
std::cerr << "Could not open " << index_prefix_paths << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::ifstream tags_in(tags_file);
|
||||
if (!tags_in.is_open())
|
||||
{
|
||||
std::cerr << "Could not open " << tags_file << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::string prefix, tagfile;
|
||||
while (std::getline(index_in, prefix))
|
||||
{
|
||||
if (std::getline(tags_in, tagfile))
|
||||
{
|
||||
index_tag_paths.push_back(std::make_pair(prefix, tagfile));
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "The number of tags specified does not match the number of "
|
||||
"indices specified"
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
index_in.close();
|
||||
tags_in.close();
|
||||
|
||||
if (data_type == std::string("float"))
|
||||
{
|
||||
for (auto &index_tag : index_tag_paths)
|
||||
{
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<float>(
|
||||
index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
}
|
||||
}
|
||||
else if (data_type == std::string("int8"))
|
||||
{
|
||||
for (auto &index_tag : index_tag_paths)
|
||||
{
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<int8_t>(
|
||||
index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
}
|
||||
}
|
||||
else if (data_type == std::string("uint8"))
|
||||
{
|
||||
for (auto &index_tag : index_tag_paths)
|
||||
{
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<uint8_t>(
|
||||
index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported data type " << data_type << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
while (1)
|
||||
{
|
||||
try
|
||||
{
|
||||
setup(address, data_type);
|
||||
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
|
||||
std::string line;
|
||||
std::getline(std::cin, line);
|
||||
if (line == "exit")
|
||||
{
|
||||
teardown(address);
|
||||
g_httpServer->close().wait();
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << "Exception occurred: " << ex.what() << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
std::cerr << "Unknown exception occurreed" << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
}
|
||||
}
|
||||
141
packages/leann-backend-diskann/third_party/DiskANN/apps/restapi/ssd_server.cpp
vendored
Normal file
141
packages/leann-backend-diskann/third_party/DiskANN/apps/restapi/ssd_server.cpp
vendored
Normal file
@@ -0,0 +1,141 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <ctime>
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
#include <string>
|
||||
#include <cstdlib>
|
||||
#include <codecvt>
|
||||
#include <boost/program_options.hpp>
|
||||
#include <omp.h>
|
||||
|
||||
#include <restapi/server.h>
|
||||
|
||||
using namespace diskann;
|
||||
namespace po = boost::program_options;
|
||||
|
||||
std::unique_ptr<Server> g_httpServer(nullptr);
|
||||
std::vector<std::unique_ptr<diskann::BaseSearch>> g_ssdSearch;
|
||||
|
||||
/**
 * Creates the global HTTP server for `address`, backed by g_ssdSearch, opens
 * it and blocks until the open has completed.
 *
 * @param address     address the server is created for
 * @param typestring  forwarded unchanged to the Server constructor
 */
void setup(const utility::string_t &address, const std::string &typestring)
{
    web::http::uri_builder uriBldr(address);
    auto uri = uriBldr.to_uri();

    std::cout << "Attempting to start server on " << uri.to_string() << std::endl;

    g_httpServer = std::make_unique<Server>(uri, g_ssdSearch, typestring);
    std::cout << "Created a server object" << std::endl;

    g_httpServer->open().wait();
    // A plain literal is used on purpose: `U"..."` is a char32_t string literal,
    // which the stream prints as a pointer value instead of the text.
    ucout << "Listening for requests on: " << address << std::endl;
}
|
||||
|
||||
/**
 * Closes the global HTTP server and blocks until the close has completed.
 * Does nothing when no server object exists, which can be the case when this
 * is called from an error handler after setup() failed early.
 *
 * @param address  not used by this function
 */
void teardown(const utility::string_t &address)
{
    (void)address;
    if (g_httpServer)
    {
        g_httpServer->close().wait();
    }
}
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
std::string data_type, index_path_prefix, address, dist_fn, tags_file;
|
||||
uint32_t num_nodes_to_cache;
|
||||
uint32_t num_threads;
|
||||
|
||||
po::options_description desc{"Arguments"};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
||||
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
|
||||
desc.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
||||
"Path prefix for loading index file components");
|
||||
desc.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
|
||||
"Number of nodes to cache during search");
|
||||
desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
"Number of threads used for building index (defaults to "
|
||||
"omp_get_num_procs())");
|
||||
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
|
||||
"distance function <l2/mips>");
|
||||
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
|
||||
"Tags file location");
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
metric = diskann::Metric::L2;
|
||||
else if (dist_fn == std::string("mips"))
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
else
|
||||
{
|
||||
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type == std::string("float"))
|
||||
{
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::PQFlashSearch<float>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
}
|
||||
else if (data_type == std::string("int8"))
|
||||
{
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::PQFlashSearch<int8_t>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
}
|
||||
else if (data_type == std::string("uint8"))
|
||||
{
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::PQFlashSearch<uint8_t>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported data type " << argv[2] << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
while (1)
|
||||
{
|
||||
try
|
||||
{
|
||||
setup(address, data_type);
|
||||
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
|
||||
std::string line;
|
||||
std::getline(std::cin, line);
|
||||
if (line == "exit")
|
||||
{
|
||||
teardown(address);
|
||||
g_httpServer->close().wait();
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << "Exception occurred: " << ex.what() << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
std::cerr << "Unknown exception occurreed" << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
}
|
||||
}
|
||||
499
packages/leann-backend-diskann/third_party/DiskANN/apps/search_disk_index.cpp
vendored
Normal file
499
packages/leann-backend-diskann/third_party/DiskANN/apps/search_disk_index.cpp
vendored
Normal file
@@ -0,0 +1,499 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "common_includes.h"
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "index.h"
|
||||
#include "disk_utils.h"
|
||||
#include "math_utils.h"
|
||||
#include "memory_mapper.h"
|
||||
#include "partition.h"
|
||||
#include "pq_flash_index.h"
|
||||
#include "timer.h"
|
||||
#include "percentile_stats.h"
|
||||
#include "program_options_utils.hpp"
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
#include "linux_aligned_file_reader.h"
|
||||
#else
|
||||
#ifdef USE_BING_INFRA
|
||||
#include "bing_aligned_file_reader.h"
|
||||
#else
|
||||
#include "windows_aligned_file_reader.h"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#define WARMUP false
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
void print_stats(std::string category, std::vector<float> percentiles, std::vector<float> results)
|
||||
{
|
||||
diskann::cout << std::setw(20) << category << ": " << std::flush;
|
||||
for (uint32_t s = 0; s < percentiles.size(); s++)
|
||||
{
|
||||
diskann::cout << std::setw(8) << percentiles[s] << "%";
|
||||
}
|
||||
diskann::cout << std::endl;
|
||||
diskann::cout << std::setw(22) << " " << std::flush;
|
||||
for (uint32_t s = 0; s < percentiles.size(); s++)
|
||||
{
|
||||
diskann::cout << std::setw(9) << results[s];
|
||||
}
|
||||
diskann::cout << std::endl;
|
||||
}
|
||||
|
||||
template <typename T, typename LabelT = uint32_t>
|
||||
int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix,
|
||||
const std::string &result_output_prefix, const std::string &query_file, std::string >_file,
|
||||
const uint32_t num_threads, const uint32_t recall_at, const uint32_t beamwidth,
|
||||
const uint32_t num_nodes_to_cache, const uint32_t search_io_limit,
|
||||
const std::vector<uint32_t> &Lvec, const float fail_if_recall_below,
|
||||
const std::vector<std::string> &query_filters, const bool use_reorder_data = false)
|
||||
{
|
||||
diskann::cout << "Search parameters: #threads: " << num_threads << ", ";
|
||||
if (beamwidth <= 0)
|
||||
diskann::cout << "beamwidth to be optimized for each L value" << std::flush;
|
||||
else
|
||||
diskann::cout << " beamwidth: " << beamwidth << std::flush;
|
||||
if (search_io_limit == std::numeric_limits<uint32_t>::max())
|
||||
diskann::cout << "." << std::endl;
|
||||
else
|
||||
diskann::cout << ", io_limit: " << search_io_limit << "." << std::endl;
|
||||
|
||||
std::string warmup_query_file = index_path_prefix + "_sample_data.bin";
|
||||
|
||||
// load query bin
|
||||
T *query = nullptr;
|
||||
uint32_t *gt_ids = nullptr;
|
||||
float *gt_dists = nullptr;
|
||||
size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim;
|
||||
diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);
|
||||
|
||||
bool filtered_search = false;
|
||||
if (!query_filters.empty())
|
||||
{
|
||||
filtered_search = true;
|
||||
if (query_filters.size() != 1 && query_filters.size() != query_num)
|
||||
{
|
||||
std::cout << "Error. Mismatch in number of queries and size of query "
|
||||
"filters file"
|
||||
<< std::endl;
|
||||
return -1; // To return -1 or some other error handling?
|
||||
}
|
||||
}
|
||||
|
||||
bool calc_recall_flag = false;
|
||||
if (gt_file != std::string("null") && gt_file != std::string("NULL") && file_exists(gt_file))
|
||||
{
|
||||
diskann::load_truthset(gt_file, gt_ids, gt_dists, gt_num, gt_dim);
|
||||
if (gt_num != query_num)
|
||||
{
|
||||
diskann::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl;
|
||||
}
|
||||
calc_recall_flag = true;
|
||||
}
|
||||
|
||||
std::shared_ptr<AlignedFileReader> reader = nullptr;
|
||||
#ifdef _WINDOWS
|
||||
#ifndef USE_BING_INFRA
|
||||
reader.reset(new WindowsAlignedFileReader());
|
||||
#else
|
||||
reader.reset(new diskann::BingAlignedFileReader());
|
||||
#endif
|
||||
#else
|
||||
reader.reset(new LinuxAlignedFileReader());
|
||||
#endif
|
||||
|
||||
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> _pFlashIndex(
|
||||
new diskann::PQFlashIndex<T, LabelT>(reader, metric));
|
||||
|
||||
int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str());
|
||||
|
||||
if (res != 0)
|
||||
{
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<uint32_t> node_list;
|
||||
diskann::cout << "Caching " << num_nodes_to_cache << " nodes around medoid(s)" << std::endl;
|
||||
_pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list);
|
||||
// if (num_nodes_to_cache > 0)
|
||||
// _pFlashIndex->generate_cache_list_from_sample_queries(warmup_query_file, 15, 6, num_nodes_to_cache,
|
||||
// num_threads, node_list);
|
||||
_pFlashIndex->load_cache_list(node_list);
|
||||
node_list.clear();
|
||||
node_list.shrink_to_fit();
|
||||
|
||||
omp_set_num_threads(num_threads);
|
||||
|
||||
uint64_t warmup_L = 20;
|
||||
uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0;
|
||||
T *warmup = nullptr;
|
||||
|
||||
if (WARMUP)
|
||||
{
|
||||
if (file_exists(warmup_query_file))
|
||||
{
|
||||
diskann::load_aligned_bin<T>(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim);
|
||||
}
|
||||
else
|
||||
{
|
||||
warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads);
|
||||
warmup_dim = query_dim;
|
||||
warmup_aligned_dim = query_aligned_dim;
|
||||
diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T));
|
||||
std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T));
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<> dis(-128, 127);
|
||||
for (uint32_t i = 0; i < warmup_num; i++)
|
||||
{
|
||||
for (uint32_t d = 0; d < warmup_dim; d++)
|
||||
{
|
||||
warmup[i * warmup_aligned_dim + d] = (T)dis(gen);
|
||||
}
|
||||
}
|
||||
}
|
||||
diskann::cout << "Warming up index... " << std::flush;
|
||||
std::vector<uint64_t> warmup_result_ids_64(warmup_num, 0);
|
||||
std::vector<float> warmup_result_dists(warmup_num, 0);
|
||||
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (int64_t i = 0; i < (int64_t)warmup_num; i++)
|
||||
{
|
||||
_pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L,
|
||||
warmup_result_ids_64.data() + (i * 1),
|
||||
warmup_result_dists.data() + (i * 1), 4);
|
||||
}
|
||||
diskann::cout << "..done" << std::endl;
|
||||
}
|
||||
|
||||
diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
|
||||
diskann::cout.precision(2);
|
||||
|
||||
std::string recall_string = "Recall@" + std::to_string(recall_at);
|
||||
diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" << std::setw(16) << "QPS" << std::setw(16)
|
||||
<< "Mean Latency" << std::setw(16) << "99.9 Latency" << std::setw(16) << "Mean IOs" << std::setw(16)
|
||||
<< "Mean IO (us)" << std::setw(16) << "CPU (s)";
|
||||
if (calc_recall_flag)
|
||||
{
|
||||
diskann::cout << std::setw(16) << recall_string << std::endl;
|
||||
}
|
||||
else
|
||||
diskann::cout << std::endl;
|
||||
diskann::cout << "=================================================================="
|
||||
"================================================================="
|
||||
<< std::endl;
|
||||
|
||||
std::vector<std::vector<uint32_t>> query_result_ids(Lvec.size());
|
||||
std::vector<std::vector<float>> query_result_dists(Lvec.size());
|
||||
|
||||
uint32_t optimized_beamwidth = 2;
|
||||
|
||||
double best_recall = 0.0;
|
||||
|
||||
for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++)
|
||||
{
|
||||
uint32_t L = Lvec[test_id];
|
||||
|
||||
if (L < recall_at)
|
||||
{
|
||||
diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (beamwidth <= 0)
|
||||
{
|
||||
diskann::cout << "Tuning beamwidth.." << std::endl;
|
||||
optimized_beamwidth =
|
||||
optimize_beamwidth(_pFlashIndex, warmup, warmup_num, warmup_aligned_dim, L, optimized_beamwidth);
|
||||
}
|
||||
else
|
||||
optimized_beamwidth = beamwidth;
|
||||
|
||||
query_result_ids[test_id].resize(recall_at * query_num);
|
||||
query_result_dists[test_id].resize(recall_at * query_num);
|
||||
|
||||
auto stats = new diskann::QueryStats[query_num];
|
||||
|
||||
std::vector<uint64_t> query_result_ids_64(recall_at * query_num);
|
||||
auto s = std::chrono::high_resolution_clock::now();
|
||||
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (int64_t i = 0; i < (int64_t)query_num; i++)
|
||||
{
|
||||
if (!filtered_search)
|
||||
{
|
||||
_pFlashIndex->cached_beam_search(query + (i * query_aligned_dim), recall_at, L,
|
||||
query_result_ids_64.data() + (i * recall_at),
|
||||
query_result_dists[test_id].data() + (i * recall_at),
|
||||
optimized_beamwidth, use_reorder_data, stats + i);
|
||||
}
|
||||
else
|
||||
{
|
||||
LabelT label_for_search;
|
||||
if (query_filters.size() == 1)
|
||||
{ // one label for all queries
|
||||
label_for_search = _pFlashIndex->get_converted_label(query_filters[0]);
|
||||
}
|
||||
else
|
||||
{ // one label for each query
|
||||
label_for_search = _pFlashIndex->get_converted_label(query_filters[i]);
|
||||
}
|
||||
_pFlashIndex->cached_beam_search(
|
||||
query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at),
|
||||
query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search,
|
||||
use_reorder_data, stats + i);
|
||||
}
|
||||
}
|
||||
auto e = std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<double> diff = e - s;
|
||||
double qps = (1.0 * query_num) / (1.0 * diff.count());
|
||||
|
||||
diskann::convert_types<uint64_t, uint32_t>(query_result_ids_64.data(), query_result_ids[test_id].data(),
|
||||
query_num, recall_at);
|
||||
|
||||
auto mean_latency = diskann::get_mean_stats<float>(
|
||||
stats, query_num, [](const diskann::QueryStats &stats) { return stats.total_us; });
|
||||
|
||||
auto latency_999 = diskann::get_percentile_stats<float>(
|
||||
stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; });
|
||||
|
||||
auto mean_ios = diskann::get_mean_stats<uint32_t>(stats, query_num,
|
||||
[](const diskann::QueryStats &stats) { return stats.n_ios; });
|
||||
|
||||
auto mean_cpuus = diskann::get_mean_stats<float>(stats, query_num,
|
||||
[](const diskann::QueryStats &stats) { return stats.cpu_us; });
|
||||
|
||||
auto mean_io_us = diskann::get_mean_stats<float>(stats, query_num,
|
||||
[](const diskann::QueryStats &stats) { return stats.io_us; });
|
||||
|
||||
double recall = 0;
|
||||
if (calc_recall_flag)
|
||||
{
|
||||
recall = diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim,
|
||||
query_result_ids[test_id].data(), recall_at, recall_at);
|
||||
best_recall = std::max(recall, best_recall);
|
||||
}
|
||||
|
||||
diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps
|
||||
<< std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios
|
||||
<< std::setw(16) << mean_io_us << std::setw(16) << mean_cpuus;
|
||||
if (calc_recall_flag)
|
||||
{
|
||||
diskann::cout << std::setw(16) << recall << std::endl;
|
||||
}
|
||||
else
|
||||
diskann::cout << std::endl;
|
||||
delete[] stats;
|
||||
}
|
||||
|
||||
diskann::cout << "Done searching. Now saving results " << std::endl;
|
||||
uint64_t test_id = 0;
|
||||
for (auto L : Lvec)
|
||||
{
|
||||
if (L < recall_at)
|
||||
continue;
|
||||
|
||||
std::string cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_idx_uint32.bin";
|
||||
diskann::save_bin<uint32_t>(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at);
|
||||
|
||||
cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_dists_float.bin";
|
||||
diskann::save_bin<float>(cur_result_path, query_result_dists[test_id++].data(), query_num, recall_at);
|
||||
}
|
||||
|
||||
diskann::aligned_free(query);
|
||||
if (warmup != nullptr)
|
||||
diskann::aligned_free(warmup);
|
||||
return best_recall >= fail_if_recall_below ? 0 : -1;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file, filter_label,
|
||||
label_type, query_filters_file;
|
||||
uint32_t num_threads, K, W, num_nodes_to_cache, search_io_limit;
|
||||
std::vector<uint32_t> Lvec;
|
||||
bool use_reorder_data = false;
|
||||
float fail_if_recall_below = 0.0f;
|
||||
|
||||
po::options_description desc{
|
||||
program_options_utils::make_program_description("search_disk_index", "Searches on-disk DiskANN indexes")};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
// Required parameters
|
||||
po::options_description required_configs("Required");
|
||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
||||
required_configs.add_options()("result_path", po::value<std::string>(&result_path_prefix)->required(),
|
||||
program_options_utils::RESULT_PATH_DESCRIPTION);
|
||||
required_configs.add_options()("query_file", po::value<std::string>(&query_file)->required(),
|
||||
program_options_utils::QUERY_FILE_DESCRIPTION);
|
||||
required_configs.add_options()("recall_at,K", po::value<uint32_t>(&K)->required(),
|
||||
program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION);
|
||||
required_configs.add_options()("search_list,L",
|
||||
po::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
|
||||
program_options_utils::SEARCH_LIST_DESCRIPTION);
|
||||
|
||||
// Optional parameters
|
||||
po::options_description optional_configs("Optional");
|
||||
optional_configs.add_options()("gt_file", po::value<std::string>(>_file)->default_value(std::string("null")),
|
||||
program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION);
|
||||
optional_configs.add_options()("beamwidth,W", po::value<uint32_t>(&W)->default_value(2),
|
||||
program_options_utils::BEAMWIDTH);
|
||||
optional_configs.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
|
||||
program_options_utils::NUMBER_OF_NODES_TO_CACHE);
|
||||
optional_configs.add_options()(
|
||||
"search_io_limit",
|
||||
po::value<uint32_t>(&search_io_limit)->default_value(std::numeric_limits<uint32_t>::max()),
|
||||
"Max #IOs for search. Default value: uint32::max()");
|
||||
optional_configs.add_options()("num_threads,T",
|
||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
||||
optional_configs.add_options()("use_reorder_data", po::bool_switch()->default_value(false),
|
||||
"Include full precision data in the index. Use only in "
|
||||
"conjuction with compressed data on SSD. Default value: false");
|
||||
optional_configs.add_options()("filter_label",
|
||||
po::value<std::string>(&filter_label)->default_value(std::string("")),
|
||||
program_options_utils::FILTER_LABEL_DESCRIPTION);
|
||||
optional_configs.add_options()("query_filters_file",
|
||||
po::value<std::string>(&query_filters_file)->default_value(std::string("")),
|
||||
program_options_utils::FILTERS_FILE_DESCRIPTION);
|
||||
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
|
||||
program_options_utils::LABEL_TYPE_DESCRIPTION);
|
||||
optional_configs.add_options()("fail_if_recall_below",
|
||||
po::value<float>(&fail_if_recall_below)->default_value(0.0f),
|
||||
program_options_utils::FAIL_IF_RECALL_BELOW);
|
||||
|
||||
// Merge required and optional parameters
|
||||
desc.add(required_configs).add(optional_configs);
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
if (vm["use_reorder_data"].as<bool>())
|
||||
use_reorder_data = true;
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("mips"))
|
||||
{
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
}
|
||||
else if (dist_fn == std::string("l2"))
|
||||
{
|
||||
metric = diskann::Metric::L2;
|
||||
}
|
||||
else if (dist_fn == std::string("cosine"))
|
||||
{
|
||||
metric = diskann::Metric::COSINE;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Unsupported distance function. Currently only L2/ Inner "
|
||||
"Product/Cosine are supported."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if ((data_type != std::string("float")) && (metric == diskann::Metric::INNER_PRODUCT))
|
||||
{
|
||||
std::cout << "Currently support only floating point data for Inner Product." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (use_reorder_data && data_type != std::string("float"))
|
||||
{
|
||||
std::cout << "Error: Reorder data for reordering currently only "
|
||||
"supported for float data type."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (filter_label != "" && query_filters_file != "")
|
||||
{
|
||||
std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::vector<std::string> query_filters;
|
||||
if (filter_label != "")
|
||||
{
|
||||
query_filters.push_back(filter_label);
|
||||
}
|
||||
else if (query_filters_file != "")
|
||||
{
|
||||
query_filters = read_file_to_vector_of_strings(query_filters_file);
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
if (!query_filters.empty() && label_type == "ushort")
|
||||
{
|
||||
if (data_type == std::string("float"))
|
||||
return search_disk_index<float, uint16_t>(
|
||||
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
|
||||
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
|
||||
else if (data_type == std::string("int8"))
|
||||
return search_disk_index<int8_t, uint16_t>(
|
||||
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
|
||||
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return search_disk_index<uint8_t, uint16_t>(
|
||||
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
|
||||
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (data_type == std::string("float"))
|
||||
return search_disk_index<float>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
|
||||
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
|
||||
fail_if_recall_below, query_filters, use_reorder_data);
|
||||
else if (data_type == std::string("int8"))
|
||||
return search_disk_index<int8_t>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
|
||||
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
|
||||
fail_if_recall_below, query_filters, use_reorder_data);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return search_disk_index<uint8_t>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
|
||||
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
|
||||
fail_if_recall_below, query_filters, use_reorder_data);
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Index search failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
477
packages/leann-backend-diskann/third_party/DiskANN/apps/search_memory_index.cpp
vendored
Normal file
477
packages/leann-backend-diskann/third_party/DiskANN/apps/search_memory_index.cpp
vendored
Normal file
@@ -0,0 +1,477 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <cstring>
|
||||
#include <iomanip>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <omp.h>
|
||||
#include <set>
|
||||
#include <string.h>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <time.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#include "index.h"
|
||||
#include "memory_mapper.h"
|
||||
#include "utils.h"
|
||||
#include "program_options_utils.hpp"
|
||||
#include "index_factory.h"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
template <typename T, typename LabelT = uint32_t>
|
||||
int search_memory_index(diskann::Metric &metric, const std::string &index_path, const std::string &result_path_prefix,
|
||||
const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads,
|
||||
const uint32_t recall_at, const bool print_all_recalls, const std::vector<uint32_t> &Lvec,
|
||||
const bool dynamic, const bool tags, const bool show_qps_per_thread,
|
||||
const std::vector<std::string> &query_filters, const float fail_if_recall_below)
|
||||
{
|
||||
using TagT = uint32_t;
|
||||
// Load the query file
|
||||
T *query = nullptr;
|
||||
uint32_t *gt_ids = nullptr;
|
||||
float *gt_dists = nullptr;
|
||||
size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim;
|
||||
diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);
|
||||
|
||||
bool calc_recall_flag = false;
|
||||
if (truthset_file != std::string("null") && file_exists(truthset_file))
|
||||
{
|
||||
diskann::load_truthset(truthset_file, gt_ids, gt_dists, gt_num, gt_dim);
|
||||
if (gt_num != query_num)
|
||||
{
|
||||
std::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl;
|
||||
}
|
||||
calc_recall_flag = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
diskann::cout << " Truthset file " << truthset_file << " not found. Not computing recall." << std::endl;
|
||||
}
|
||||
|
||||
bool filtered_search = false;
|
||||
if (!query_filters.empty())
|
||||
{
|
||||
filtered_search = true;
|
||||
if (query_filters.size() != 1 && query_filters.size() != query_num)
|
||||
{
|
||||
std::cout << "Error. Mismatch in number of queries and size of query "
|
||||
"filters file"
|
||||
<< std::endl;
|
||||
return -1; // To return -1 or some other error handling?
|
||||
}
|
||||
}
|
||||
|
||||
const size_t num_frozen_pts = diskann::get_graph_num_frozen_points(index_path);
|
||||
|
||||
auto config = diskann::IndexConfigBuilder()
|
||||
.with_metric(metric)
|
||||
.with_dimension(query_dim)
|
||||
.with_max_points(0)
|
||||
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
|
||||
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
|
||||
.with_data_type(diskann_type_to_name<T>())
|
||||
.with_label_type(diskann_type_to_name<LabelT>())
|
||||
.with_tag_type(diskann_type_to_name<TagT>())
|
||||
.is_dynamic_index(dynamic)
|
||||
.is_enable_tags(tags)
|
||||
.is_concurrent_consolidate(false)
|
||||
.is_pq_dist_build(false)
|
||||
.is_use_opq(false)
|
||||
.with_num_pq_chunks(0)
|
||||
.with_num_frozen_pts(num_frozen_pts)
|
||||
.build();
|
||||
|
||||
auto index_factory = diskann::IndexFactory(config);
|
||||
auto index = index_factory.create_instance();
|
||||
index->load(index_path.c_str(), num_threads, *(std::max_element(Lvec.begin(), Lvec.end())));
|
||||
std::cout << "Index loaded" << std::endl;
|
||||
|
||||
if (metric == diskann::FAST_L2)
|
||||
index->optimize_index_layout();
|
||||
|
||||
std::cout << "Using " << num_threads << " threads to search" << std::endl;
|
||||
std::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
|
||||
std::cout.precision(2);
|
||||
const std::string qps_title = show_qps_per_thread ? "QPS/thread" : "QPS";
|
||||
uint32_t table_width = 0;
|
||||
if (tags)
|
||||
{
|
||||
std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title << std::setw(20) << "Mean Latency (mus)"
|
||||
<< std::setw(15) << "99.9 Latency";
|
||||
table_width += 4 + 12 + 20 + 15;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title << std::setw(18) << "Avg dist cmps"
|
||||
<< std::setw(20) << "Mean Latency (mus)" << std::setw(15) << "99.9 Latency";
|
||||
table_width += 4 + 12 + 18 + 20 + 15;
|
||||
}
|
||||
uint32_t recalls_to_print = 0;
|
||||
const uint32_t first_recall = print_all_recalls ? 1 : recall_at;
|
||||
if (calc_recall_flag)
|
||||
{
|
||||
for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++)
|
||||
{
|
||||
std::cout << std::setw(12) << ("Recall@" + std::to_string(curr_recall));
|
||||
}
|
||||
recalls_to_print = recall_at + 1 - first_recall;
|
||||
table_width += recalls_to_print * 12;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
std::cout << std::string(table_width, '=') << std::endl;
|
||||
|
||||
std::vector<std::vector<uint32_t>> query_result_ids(Lvec.size());
|
||||
std::vector<std::vector<float>> query_result_dists(Lvec.size());
|
||||
std::vector<float> latency_stats(query_num, 0);
|
||||
std::vector<uint32_t> cmp_stats;
|
||||
if (not tags || filtered_search)
|
||||
{
|
||||
cmp_stats = std::vector<uint32_t>(query_num, 0);
|
||||
}
|
||||
|
||||
std::vector<TagT> query_result_tags;
|
||||
if (tags)
|
||||
{
|
||||
query_result_tags.resize(recall_at * query_num);
|
||||
}
|
||||
|
||||
double best_recall = 0.0;
|
||||
|
||||
for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++)
|
||||
{
|
||||
uint32_t L = Lvec[test_id];
|
||||
if (L < recall_at)
|
||||
{
|
||||
diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl;
|
||||
continue;
|
||||
}
|
||||
|
||||
query_result_ids[test_id].resize(recall_at * query_num);
|
||||
query_result_dists[test_id].resize(recall_at * query_num);
|
||||
std::vector<T *> res = std::vector<T *>();
|
||||
|
||||
auto s = std::chrono::high_resolution_clock::now();
|
||||
omp_set_num_threads(num_threads);
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (int64_t i = 0; i < (int64_t)query_num; i++)
|
||||
{
|
||||
auto qs = std::chrono::high_resolution_clock::now();
|
||||
if (filtered_search && !tags)
|
||||
{
|
||||
std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];
|
||||
|
||||
auto retval = index->search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L,
|
||||
query_result_ids[test_id].data() + i * recall_at,
|
||||
query_result_dists[test_id].data() + i * recall_at);
|
||||
cmp_stats[i] = retval.second;
|
||||
}
|
||||
else if (metric == diskann::FAST_L2)
|
||||
{
|
||||
index->search_with_optimized_layout(query + i * query_aligned_dim, recall_at, L,
|
||||
query_result_ids[test_id].data() + i * recall_at);
|
||||
}
|
||||
else if (tags)
|
||||
{
|
||||
if (!filtered_search)
|
||||
{
|
||||
index->search_with_tags(query + i * query_aligned_dim, recall_at, L,
|
||||
query_result_tags.data() + i * recall_at, nullptr, res);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];
|
||||
|
||||
index->search_with_tags(query + i * query_aligned_dim, recall_at, L,
|
||||
query_result_tags.data() + i * recall_at, nullptr, res, true, raw_filter);
|
||||
}
|
||||
|
||||
for (int64_t r = 0; r < (int64_t)recall_at; r++)
|
||||
{
|
||||
query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
cmp_stats[i] = index
|
||||
->search(query + i * query_aligned_dim, recall_at, L,
|
||||
query_result_ids[test_id].data() + i * recall_at)
|
||||
.second;
|
||||
}
|
||||
auto qe = std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<double> diff = qe - qs;
|
||||
latency_stats[i] = (float)(diff.count() * 1000000);
|
||||
}
|
||||
std::chrono::duration<double> diff = std::chrono::high_resolution_clock::now() - s;
|
||||
|
||||
double displayed_qps = query_num / diff.count();
|
||||
|
||||
if (show_qps_per_thread)
|
||||
displayed_qps /= num_threads;
|
||||
|
||||
std::vector<double> recalls;
|
||||
if (calc_recall_flag)
|
||||
{
|
||||
recalls.reserve(recalls_to_print);
|
||||
for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++)
|
||||
{
|
||||
recalls.push_back(diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim,
|
||||
query_result_ids[test_id].data(), recall_at, curr_recall));
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(latency_stats.begin(), latency_stats.end());
|
||||
double mean_latency =
|
||||
std::accumulate(latency_stats.begin(), latency_stats.end(), 0.0) / static_cast<float>(query_num);
|
||||
|
||||
float avg_cmps = (float)std::accumulate(cmp_stats.begin(), cmp_stats.end(), 0) / (float)query_num;
|
||||
|
||||
if (tags && !filtered_search)
|
||||
{
|
||||
std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(20) << (float)mean_latency
|
||||
<< std::setw(15) << (float)latency_stats[(uint64_t)(0.999 * query_num)];
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(18) << avg_cmps
|
||||
<< std::setw(20) << (float)mean_latency << std::setw(15)
|
||||
<< (float)latency_stats[(uint64_t)(0.999 * query_num)];
|
||||
}
|
||||
for (double recall : recalls)
|
||||
{
|
||||
std::cout << std::setw(12) << recall;
|
||||
best_recall = std::max(recall, best_recall);
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
std::cout << "Done searching. Now saving results " << std::endl;
|
||||
uint64_t test_id = 0;
|
||||
for (auto L : Lvec)
|
||||
{
|
||||
if (L < recall_at)
|
||||
{
|
||||
diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl;
|
||||
continue;
|
||||
}
|
||||
std::string cur_result_path_prefix = result_path_prefix + "_" + std::to_string(L);
|
||||
|
||||
std::string cur_result_path = cur_result_path_prefix + "_idx_uint32.bin";
|
||||
diskann::save_bin<uint32_t>(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at);
|
||||
|
||||
cur_result_path = cur_result_path_prefix + "_dists_float.bin";
|
||||
diskann::save_bin<float>(cur_result_path, query_result_dists[test_id].data(), query_num, recall_at);
|
||||
|
||||
test_id++;
|
||||
}
|
||||
|
||||
diskann::aligned_free(query);
|
||||
return best_recall >= fail_if_recall_below ? 0 : -1;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, index_path_prefix, result_path, query_file, gt_file, filter_label, label_type,
|
||||
query_filters_file;
|
||||
uint32_t num_threads, K;
|
||||
std::vector<uint32_t> Lvec;
|
||||
bool print_all_recalls, dynamic, tags, show_qps_per_thread;
|
||||
float fail_if_recall_below = 0.0f;
|
||||
|
||||
po::options_description desc{
|
||||
program_options_utils::make_program_description("search_memory_index", "Searches in-memory DiskANN indexes")};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print this information on arguments");
|
||||
|
||||
// Required parameters
|
||||
po::options_description required_configs("Required");
|
||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
||||
required_configs.add_options()("result_path", po::value<std::string>(&result_path)->required(),
|
||||
program_options_utils::RESULT_PATH_DESCRIPTION);
|
||||
required_configs.add_options()("query_file", po::value<std::string>(&query_file)->required(),
|
||||
program_options_utils::QUERY_FILE_DESCRIPTION);
|
||||
required_configs.add_options()("recall_at,K", po::value<uint32_t>(&K)->required(),
|
||||
program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION);
|
||||
required_configs.add_options()("search_list,L",
|
||||
po::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
|
||||
program_options_utils::SEARCH_LIST_DESCRIPTION);
|
||||
|
||||
// Optional parameters
|
||||
po::options_description optional_configs("Optional");
|
||||
optional_configs.add_options()("filter_label",
|
||||
po::value<std::string>(&filter_label)->default_value(std::string("")),
|
||||
program_options_utils::FILTER_LABEL_DESCRIPTION);
|
||||
optional_configs.add_options()("query_filters_file",
|
||||
po::value<std::string>(&query_filters_file)->default_value(std::string("")),
|
||||
program_options_utils::FILTERS_FILE_DESCRIPTION);
|
||||
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
|
||||
program_options_utils::LABEL_TYPE_DESCRIPTION);
|
||||
optional_configs.add_options()("gt_file", po::value<std::string>(>_file)->default_value(std::string("null")),
|
||||
program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION);
|
||||
optional_configs.add_options()("num_threads,T",
|
||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
||||
optional_configs.add_options()(
|
||||
"dynamic", po::value<bool>(&dynamic)->default_value(false),
|
||||
"Whether the index is dynamic. Dynamic indices must have associated tags. Default false.");
|
||||
optional_configs.add_options()("tags", po::value<bool>(&tags)->default_value(false),
|
||||
"Whether to search with external identifiers (tags). Default false.");
|
||||
optional_configs.add_options()("fail_if_recall_below",
|
||||
po::value<float>(&fail_if_recall_below)->default_value(0.0f),
|
||||
program_options_utils::FAIL_IF_RECALL_BELOW);
|
||||
|
||||
// Output controls
|
||||
po::options_description output_controls("Output controls");
|
||||
output_controls.add_options()("print_all_recalls", po::bool_switch(&print_all_recalls),
|
||||
"Print recalls at all positions, from 1 up to specified "
|
||||
"recall_at value");
|
||||
output_controls.add_options()("print_qps_per_thread", po::bool_switch(&show_qps_per_thread),
|
||||
"Print overall QPS divided by the number of threads in "
|
||||
"the output table");
|
||||
|
||||
// Merge required and optional parameters
|
||||
desc.add(required_configs).add(optional_configs).add(output_controls);
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if ((dist_fn == std::string("mips")) && (data_type == std::string("float")))
|
||||
{
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
}
|
||||
else if (dist_fn == std::string("l2"))
|
||||
{
|
||||
metric = diskann::Metric::L2;
|
||||
}
|
||||
else if (dist_fn == std::string("cosine"))
|
||||
{
|
||||
metric = diskann::Metric::COSINE;
|
||||
}
|
||||
else if ((dist_fn == std::string("fast_l2")) && (data_type == std::string("float")))
|
||||
{
|
||||
metric = diskann::Metric::FAST_L2;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Unsupported distance function. Currently only l2/ cosine are "
|
||||
"supported in general, and mips/fast_l2 only for floating "
|
||||
"point data."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (dynamic && not tags)
|
||||
{
|
||||
std::cerr << "Tags must be enabled while searching dynamically built indices" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (fail_if_recall_below < 0.0 || fail_if_recall_below >= 100.0)
|
||||
{
|
||||
std::cerr << "fail_if_recall_below parameter must be between 0 and 100%" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (filter_label != "" && query_filters_file != "")
|
||||
{
|
||||
std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::vector<std::string> query_filters;
|
||||
if (filter_label != "")
|
||||
{
|
||||
query_filters.push_back(filter_label);
|
||||
}
|
||||
else if (query_filters_file != "")
|
||||
{
|
||||
query_filters = read_file_to_vector_of_strings(query_filters_file);
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
if (!query_filters.empty() && label_type == "ushort")
|
||||
{
|
||||
if (data_type == std::string("int8"))
|
||||
{
|
||||
return search_memory_index<int8_t, uint16_t>(
|
||||
metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
|
||||
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below);
|
||||
}
|
||||
else if (data_type == std::string("uint8"))
|
||||
{
|
||||
return search_memory_index<uint8_t, uint16_t>(
|
||||
metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
|
||||
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below);
|
||||
}
|
||||
else if (data_type == std::string("float"))
|
||||
{
|
||||
return search_memory_index<float, uint16_t>(metric, index_path_prefix, result_path, query_file, gt_file,
|
||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
||||
show_qps_per_thread, query_filters, fail_if_recall_below);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (data_type == std::string("int8"))
|
||||
{
|
||||
return search_memory_index<int8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
|
||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
||||
show_qps_per_thread, query_filters, fail_if_recall_below);
|
||||
}
|
||||
else if (data_type == std::string("uint8"))
|
||||
{
|
||||
return search_memory_index<uint8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
|
||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
||||
show_qps_per_thread, query_filters, fail_if_recall_below);
|
||||
}
|
||||
else if (data_type == std::string("float"))
|
||||
{
|
||||
return search_memory_index<float>(metric, index_path_prefix, result_path, query_file, gt_file,
|
||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
||||
show_qps_per_thread, query_filters, fail_if_recall_below);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Index search failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
536
packages/leann-backend-diskann/third_party/DiskANN/apps/test_insert_deletes_consolidate.cpp
vendored
Normal file
536
packages/leann-backend-diskann/third_party/DiskANN/apps/test_insert_deletes_consolidate.cpp
vendored
Normal file
@@ -0,0 +1,536 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <index.h>
|
||||
#include <numeric>
|
||||
#include <omp.h>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
#include <timer.h>
|
||||
#include <boost/program_options.hpp>
|
||||
#include <future>
|
||||
|
||||
#include "utils.h"
|
||||
#include "filter_utils.h"
|
||||
#include "program_options_utils.hpp"
|
||||
#include "index_factory.h"
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#include "memory_mapper.h"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
// load_aligned_bin modified to read pieces of the file, but using ifstream
|
||||
// instead of cached_ifstream.
|
||||
template <typename T>
|
||||
inline void load_aligned_bin_part(const std::string &bin_file, T *data, size_t offset_points, size_t points_to_read)
|
||||
{
|
||||
diskann::Timer timer;
|
||||
std::ifstream reader;
|
||||
reader.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
reader.open(bin_file, std::ios::binary | std::ios::ate);
|
||||
size_t actual_file_size = reader.tellg();
|
||||
reader.seekg(0, std::ios::beg);
|
||||
|
||||
int npts_i32, dim_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&dim_i32, sizeof(int));
|
||||
size_t npts = (uint32_t)npts_i32;
|
||||
size_t dim = (uint32_t)dim_i32;
|
||||
|
||||
size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t);
|
||||
if (actual_file_size != expected_actual_file_size)
|
||||
{
|
||||
std::stringstream stream;
|
||||
stream << "Error. File size mismatch. Actual size is " << actual_file_size << " while expected size is "
|
||||
<< expected_actual_file_size << " npts = " << npts << " dim = " << dim << " size of <T>= " << sizeof(T)
|
||||
<< std::endl;
|
||||
std::cout << stream.str();
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
if (offset_points + points_to_read > npts)
|
||||
{
|
||||
std::stringstream stream;
|
||||
stream << "Error. Not enough points in file. Requested " << offset_points << " offset and " << points_to_read
|
||||
<< " points, but have only " << npts << " points" << std::endl;
|
||||
std::cout << stream.str();
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T));
|
||||
|
||||
const size_t rounded_dim = ROUND_UP(dim, 8);
|
||||
|
||||
for (size_t i = 0; i < points_to_read; i++)
|
||||
{
|
||||
reader.read((char *)(data + i * rounded_dim), dim * sizeof(T));
|
||||
memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T));
|
||||
}
|
||||
reader.close();
|
||||
|
||||
const double elapsedSeconds = timer.elapsed() / 1000000.0;
|
||||
std::cout << "Read " << points_to_read << " points using non-cached reads in " << elapsedSeconds << std::endl;
|
||||
}
|
||||
|
||||
// Builds the output path "<save_path>[skip<S>-]del<D>-<threshold>".
// The "skip<S>-" segment is emitted only when points_to_skip is non-zero.
std::string get_save_filename(const std::string &save_path, size_t points_to_skip, size_t points_deleted,
                              size_t last_point_threshold)
{
    std::string skip_segment;
    if (points_to_skip != 0)
    {
        skip_segment = "skip" + std::to_string(points_to_skip) + "-";
    }

    return save_path + skip_segment + "del" + std::to_string(points_deleted) + "-" +
           std::to_string(last_point_threshold);
}
|
||||
|
||||
template <typename T, typename TagT, typename LabelT>
|
||||
void insert_till_next_checkpoint(diskann::AbstractIndex &index, size_t start, size_t end, int32_t thread_count, T *data,
|
||||
size_t aligned_dim, std::vector<std::vector<LabelT>> &location_to_labels)
|
||||
{
|
||||
diskann::Timer insert_timer;
|
||||
#pragma omp parallel for num_threads(thread_count) schedule(dynamic)
|
||||
for (int64_t j = start; j < (int64_t)end; j++)
|
||||
{
|
||||
if (!location_to_labels.empty())
|
||||
{
|
||||
index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j),
|
||||
location_to_labels[j - start]);
|
||||
}
|
||||
else
|
||||
{
|
||||
index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j));
|
||||
}
|
||||
}
|
||||
const double elapsedSeconds = insert_timer.elapsed() / 1000000.0;
|
||||
std::cout << "Insertion time " << elapsedSeconds << " seconds (" << (end - start) / elapsedSeconds
|
||||
<< " points/second overall, " << (end - start) / elapsedSeconds / thread_count << " per thread)\n ";
|
||||
}
|
||||
|
||||
template <typename T, typename TagT>
|
||||
void delete_from_beginning(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params,
|
||||
size_t points_to_skip, size_t points_to_delete_from_beginning)
|
||||
{
|
||||
try
|
||||
{
|
||||
std::cout << std::endl
|
||||
<< "Lazy deleting points " << points_to_skip << " to "
|
||||
<< points_to_skip + points_to_delete_from_beginning << "... ";
|
||||
for (size_t i = points_to_skip; i < points_to_skip + points_to_delete_from_beginning; ++i)
|
||||
index.lazy_delete(static_cast<TagT>(i + 1)); // Since tags are data location + 1
|
||||
std::cout << "done." << std::endl;
|
||||
|
||||
auto report = index.consolidate_deletes(delete_params);
|
||||
std::cout << "#active points: " << report._active_points << std::endl
|
||||
<< "max points: " << report._max_points << std::endl
|
||||
<< "empty slots: " << report._empty_slots << std::endl
|
||||
<< "deletes processed: " << report._slots_released << std::endl
|
||||
<< "latest delete size: " << report._delete_set_size << std::endl
|
||||
<< "rate: (" << points_to_delete_from_beginning / report._time << " points/second overall, "
|
||||
<< points_to_delete_from_beginning / report._time / delete_params.num_threads << " per thread)"
|
||||
<< std::endl;
|
||||
}
|
||||
catch (std::system_error &e)
|
||||
{
|
||||
std::cout << "Exception caught in deletion thread: " << e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void build_incremental_index(const std::string &data_path, diskann::IndexWriteParameters ¶ms, size_t points_to_skip,
|
||||
size_t max_points_to_insert, size_t beginning_index_size, float start_point_norm,
|
||||
uint32_t num_start_pts, size_t points_per_checkpoint, size_t checkpoints_per_snapshot,
|
||||
const std::string &save_path, size_t points_to_delete_from_beginning,
|
||||
size_t start_deletes_after, bool concurrent, const std::string &label_file,
|
||||
const std::string &universal_label)
|
||||
{
|
||||
size_t dim, aligned_dim;
|
||||
size_t num_points;
|
||||
diskann::get_bin_metadata(data_path, num_points, dim);
|
||||
aligned_dim = ROUND_UP(dim, 8);
|
||||
bool has_labels = label_file != "";
|
||||
using TagT = uint32_t;
|
||||
using LabelT = uint32_t;
|
||||
|
||||
size_t current_point_offset = points_to_skip;
|
||||
const size_t last_point_threshold = points_to_skip + max_points_to_insert;
|
||||
|
||||
bool enable_tags = true;
|
||||
using TagT = uint32_t;
|
||||
auto index_search_params = diskann::IndexSearchParams(params.search_list_size, params.num_threads);
|
||||
diskann::IndexConfig index_config = diskann::IndexConfigBuilder()
|
||||
.with_metric(diskann::L2)
|
||||
.with_dimension(dim)
|
||||
.with_max_points(max_points_to_insert)
|
||||
.is_dynamic_index(true)
|
||||
.with_index_write_params(params)
|
||||
.with_index_search_params(index_search_params)
|
||||
.with_data_type(diskann_type_to_name<T>())
|
||||
.with_tag_type(diskann_type_to_name<TagT>())
|
||||
.with_label_type(diskann_type_to_name<LabelT>())
|
||||
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
|
||||
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
|
||||
.is_enable_tags(enable_tags)
|
||||
.is_filtered(has_labels)
|
||||
.with_num_frozen_pts(num_start_pts)
|
||||
.is_concurrent_consolidate(concurrent)
|
||||
.build();
|
||||
|
||||
diskann::IndexFactory index_factory = diskann::IndexFactory(index_config);
|
||||
auto index = index_factory.create_instance();
|
||||
|
||||
if (universal_label != "")
|
||||
{
|
||||
LabelT u_label = 0;
|
||||
index->set_universal_label(u_label);
|
||||
}
|
||||
|
||||
if (points_to_skip > num_points)
|
||||
{
|
||||
throw diskann::ANNException("Asked to skip more points than in data file", -1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
if (max_points_to_insert == 0)
|
||||
{
|
||||
max_points_to_insert = num_points;
|
||||
}
|
||||
|
||||
if (points_to_skip + max_points_to_insert > num_points)
|
||||
{
|
||||
max_points_to_insert = num_points - points_to_skip;
|
||||
std::cerr << "WARNING: Reducing max_points_to_insert to " << max_points_to_insert
|
||||
<< " points since the data file has only that many" << std::endl;
|
||||
}
|
||||
|
||||
if (beginning_index_size > max_points_to_insert)
|
||||
{
|
||||
beginning_index_size = max_points_to_insert;
|
||||
std::cerr << "WARNING: Reducing beginning index size to " << beginning_index_size
|
||||
<< " points since the data file has only that many" << std::endl;
|
||||
}
|
||||
if (checkpoints_per_snapshot > 0 && beginning_index_size > points_per_checkpoint)
|
||||
{
|
||||
beginning_index_size = points_per_checkpoint;
|
||||
std::cerr << "WARNING: Reducing beginning index size to " << beginning_index_size << std::endl;
|
||||
}
|
||||
|
||||
T *data = nullptr;
|
||||
diskann::alloc_aligned(
|
||||
(void **)&data, std::max(points_per_checkpoint, beginning_index_size) * aligned_dim * sizeof(T), 8 * sizeof(T));
|
||||
|
||||
std::vector<TagT> tags(beginning_index_size);
|
||||
std::iota(tags.begin(), tags.end(), 1 + static_cast<TagT>(current_point_offset));
|
||||
|
||||
load_aligned_bin_part(data_path, data, current_point_offset, beginning_index_size);
|
||||
std::cout << "load aligned bin succeeded" << std::endl;
|
||||
diskann::Timer timer;
|
||||
|
||||
if (beginning_index_size > 0)
|
||||
{
|
||||
index->build(data, beginning_index_size, tags);
|
||||
}
|
||||
else
|
||||
{
|
||||
index->set_start_points_at_random(static_cast<T>(start_point_norm));
|
||||
}
|
||||
|
||||
const double elapsedSeconds = timer.elapsed() / 1000000.0;
|
||||
std::cout << "Initial non-incremental index build time for " << beginning_index_size << " points took "
|
||||
<< elapsedSeconds << " seconds (" << beginning_index_size / elapsedSeconds << " points/second)\n ";
|
||||
|
||||
current_point_offset += beginning_index_size;
|
||||
|
||||
if (points_to_delete_from_beginning > max_points_to_insert)
|
||||
{
|
||||
points_to_delete_from_beginning = static_cast<uint32_t>(max_points_to_insert);
|
||||
std::cerr << "WARNING: Reducing points to delete from beginning to " << points_to_delete_from_beginning
|
||||
<< " points since the data file has only that many" << std::endl;
|
||||
}
|
||||
|
||||
std::vector<std::vector<LabelT>> location_to_labels;
|
||||
if (concurrent)
|
||||
{
|
||||
// handle labels
|
||||
const auto save_path_inc = get_save_filename(save_path + ".after-concurrent-delete-", points_to_skip,
|
||||
points_to_delete_from_beginning, last_point_threshold);
|
||||
std::string labels_file_to_use = save_path_inc + "_label_formatted.txt";
|
||||
std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt";
|
||||
if (has_labels)
|
||||
{
|
||||
convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
|
||||
auto parse_result = diskann::parse_formatted_label_file<LabelT>(labels_file_to_use);
|
||||
location_to_labels = std::get<0>(parse_result);
|
||||
}
|
||||
|
||||
int32_t sub_threads = (params.num_threads + 1) / 2;
|
||||
bool delete_launched = false;
|
||||
std::future<void> delete_task;
|
||||
|
||||
diskann::Timer timer;
|
||||
|
||||
for (size_t start = current_point_offset; start < last_point_threshold;
|
||||
start += points_per_checkpoint, current_point_offset += points_per_checkpoint)
|
||||
{
|
||||
const size_t end = std::min(start + points_per_checkpoint, last_point_threshold);
|
||||
std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl;
|
||||
|
||||
auto insert_task = std::async(std::launch::async, [&]() {
|
||||
load_aligned_bin_part(data_path, data, start, end - start);
|
||||
insert_till_next_checkpoint<T, TagT, LabelT>(*index, start, end, sub_threads, data, aligned_dim,
|
||||
location_to_labels);
|
||||
});
|
||||
insert_task.wait();
|
||||
|
||||
if (!delete_launched && end >= start_deletes_after &&
|
||||
end >= points_to_skip + points_to_delete_from_beginning)
|
||||
{
|
||||
delete_launched = true;
|
||||
diskann::IndexWriteParameters delete_params =
|
||||
diskann::IndexWriteParametersBuilder(params).with_num_threads(sub_threads).build();
|
||||
|
||||
delete_task = std::async(std::launch::async, [&]() {
|
||||
delete_from_beginning<T, TagT>(*index, delete_params, points_to_skip,
|
||||
points_to_delete_from_beginning);
|
||||
});
|
||||
}
|
||||
}
|
||||
delete_task.wait();
|
||||
|
||||
std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n";
|
||||
index->save(save_path_inc.c_str(), true);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto save_path_inc = get_save_filename(save_path + ".after-delete-", points_to_skip,
|
||||
points_to_delete_from_beginning, last_point_threshold);
|
||||
std::string labels_file_to_use = save_path_inc + "_label_formatted.txt";
|
||||
std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt";
|
||||
if (has_labels)
|
||||
{
|
||||
convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
|
||||
auto parse_result = diskann::parse_formatted_label_file<LabelT>(labels_file_to_use);
|
||||
location_to_labels = std::get<0>(parse_result);
|
||||
}
|
||||
|
||||
size_t last_snapshot_points_threshold = 0;
|
||||
size_t num_checkpoints_till_snapshot = checkpoints_per_snapshot;
|
||||
|
||||
for (size_t start = current_point_offset; start < last_point_threshold;
|
||||
start += points_per_checkpoint, current_point_offset += points_per_checkpoint)
|
||||
{
|
||||
const size_t end = std::min(start + points_per_checkpoint, last_point_threshold);
|
||||
std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl;
|
||||
|
||||
load_aligned_bin_part(data_path, data, start, end - start);
|
||||
insert_till_next_checkpoint<T, TagT, LabelT>(*index, start, end, (int32_t)params.num_threads, data,
|
||||
aligned_dim, location_to_labels);
|
||||
|
||||
if (checkpoints_per_snapshot > 0 && --num_checkpoints_till_snapshot == 0)
|
||||
{
|
||||
diskann::Timer save_timer;
|
||||
|
||||
const auto save_path_inc =
|
||||
get_save_filename(save_path + ".inc-", points_to_skip, points_to_delete_from_beginning, end);
|
||||
index->save(save_path_inc.c_str(), false);
|
||||
const double elapsedSeconds = save_timer.elapsed() / 1000000.0;
|
||||
const size_t points_saved = end - points_to_skip;
|
||||
|
||||
std::cout << "Saved " << points_saved << " points in " << elapsedSeconds << " seconds ("
|
||||
<< points_saved / elapsedSeconds << " points/second)\n";
|
||||
|
||||
num_checkpoints_till_snapshot = checkpoints_per_snapshot;
|
||||
last_snapshot_points_threshold = end;
|
||||
}
|
||||
|
||||
std::cout << "Number of points in the index post insertion " << end << std::endl;
|
||||
}
|
||||
|
||||
if (checkpoints_per_snapshot > 0 && last_snapshot_points_threshold != last_point_threshold)
|
||||
{
|
||||
const auto save_path_inc = get_save_filename(save_path + ".inc-", points_to_skip,
|
||||
points_to_delete_from_beginning, last_point_threshold);
|
||||
// index.save(save_path_inc.c_str(), false);
|
||||
}
|
||||
|
||||
if (points_to_delete_from_beginning > 0)
|
||||
{
|
||||
delete_from_beginning<T, TagT>(*index, params, points_to_skip, points_to_delete_from_beginning);
|
||||
}
|
||||
|
||||
index->save(save_path_inc.c_str(), true);
|
||||
}
|
||||
|
||||
diskann::aligned_free(data);
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, data_path, index_path_prefix;
|
||||
uint32_t num_threads, R, L, num_start_pts;
|
||||
float alpha, start_point_norm;
|
||||
size_t points_to_skip, max_points_to_insert, beginning_index_size, points_per_checkpoint, checkpoints_per_snapshot,
|
||||
points_to_delete_from_beginning, start_deletes_after;
|
||||
bool concurrent;
|
||||
|
||||
// label options
|
||||
std::string label_file, label_type, universal_label;
|
||||
std::uint32_t Lf, unique_labels_supported;
|
||||
|
||||
po::options_description desc{program_options_utils::make_program_description("test_insert_deletes_consolidate",
|
||||
"Test insert deletes & consolidate")};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
// Required parameters
|
||||
po::options_description required_configs("Required");
|
||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
||||
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
|
||||
program_options_utils::INPUT_DATA_PATH);
|
||||
required_configs.add_options()("points_to_skip", po::value<uint64_t>(&points_to_skip)->required(),
|
||||
"Skip these first set of points from file");
|
||||
required_configs.add_options()("beginning_index_size", po::value<uint64_t>(&beginning_index_size)->required(),
|
||||
"Batch build will be called on these set of points");
|
||||
required_configs.add_options()("points_per_checkpoint", po::value<uint64_t>(&points_per_checkpoint)->required(),
|
||||
"Insertions are done in batches of points_per_checkpoint");
|
||||
required_configs.add_options()("checkpoints_per_snapshot",
|
||||
po::value<uint64_t>(&checkpoints_per_snapshot)->required(),
|
||||
"Save the index to disk every few checkpoints");
|
||||
required_configs.add_options()("points_to_delete_from_beginning",
|
||||
po::value<uint64_t>(&points_to_delete_from_beginning)->required(), "");
|
||||
|
||||
// Optional parameters
|
||||
po::options_description optional_configs("Optional");
|
||||
optional_configs.add_options()("num_threads,T",
|
||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
||||
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
|
||||
program_options_utils::MAX_BUILD_DEGREE);
|
||||
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
|
||||
program_options_utils::GRAPH_BUILD_COMPLEXITY);
|
||||
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
|
||||
program_options_utils::GRAPH_BUILD_ALPHA);
|
||||
optional_configs.add_options()("max_points_to_insert",
|
||||
po::value<uint64_t>(&max_points_to_insert)->default_value(0),
|
||||
"These number of points from the file are inserted after "
|
||||
"points_to_skip");
|
||||
optional_configs.add_options()("do_concurrent", po::value<bool>(&concurrent)->default_value(false), "");
|
||||
optional_configs.add_options()("start_deletes_after",
|
||||
po::value<uint64_t>(&start_deletes_after)->default_value(0), "");
|
||||
optional_configs.add_options()("start_point_norm", po::value<float>(&start_point_norm)->default_value(0),
|
||||
"Set the start point to a random point on a sphere of this radius");
|
||||
|
||||
// optional params for filters
|
||||
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
|
||||
"Input label file in txt format for Filtered Index search. "
|
||||
"The file should contain comma separated filters for each node "
|
||||
"with each line corresponding to a graph node");
|
||||
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
|
||||
"Universal label, if using it, only in conjunction with labels_file");
|
||||
optional_configs.add_options()("FilteredLbuild,Lf", po::value<uint32_t>(&Lf)->default_value(0),
|
||||
"Build complexity for filtered points, higher value "
|
||||
"results in better graphs");
|
||||
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
|
||||
"Storage type of Labels <uint/ushort>, default value is uint which "
|
||||
"will consume memory 4 bytes per filter");
|
||||
optional_configs.add_options()("unique_labels_supported",
|
||||
po::value<uint32_t>(&unique_labels_supported)->default_value(0),
|
||||
"Number of unique labels supported by the dynamic index.");
|
||||
|
||||
optional_configs.add_options()(
|
||||
"num_start_points",
|
||||
po::value<uint32_t>(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC),
|
||||
"Set the number of random start (frozen) points to use when "
|
||||
"inserting and searching");
|
||||
|
||||
// Merge required and optional parameters
|
||||
desc.add(required_configs).add(optional_configs);
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
if (beginning_index_size == 0)
|
||||
if (start_point_norm == 0)
|
||||
{
|
||||
std::cout << "When beginning_index_size is 0, use a start "
|
||||
"point with "
|
||||
"appropriate norm"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
bool has_labels = false;
|
||||
if (!label_file.empty() || label_file != "")
|
||||
{
|
||||
has_labels = true;
|
||||
}
|
||||
|
||||
if (num_start_pts < unique_labels_supported)
|
||||
{
|
||||
num_start_pts = unique_labels_supported;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R)
|
||||
.with_max_occlusion_size(500)
|
||||
.with_alpha(alpha)
|
||||
.with_num_threads(num_threads)
|
||||
.with_filter_list_size(Lf)
|
||||
.build();
|
||||
|
||||
if (data_type == std::string("int8"))
|
||||
build_incremental_index<int8_t>(
|
||||
data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm,
|
||||
num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix,
|
||||
points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label);
|
||||
else if (data_type == std::string("uint8"))
|
||||
build_incremental_index<uint8_t>(
|
||||
data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm,
|
||||
num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix,
|
||||
points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label);
|
||||
else if (data_type == std::string("float"))
|
||||
build_incremental_index<float>(data_path, params, points_to_skip, max_points_to_insert,
|
||||
beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint,
|
||||
checkpoints_per_snapshot, index_path_prefix, points_to_delete_from_beginning,
|
||||
start_deletes_after, concurrent, label_file, universal_label);
|
||||
else
|
||||
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cerr << "Caught exception: " << e.what() << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
std::cerr << "Caught unknown exception" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
523
packages/leann-backend-diskann/third_party/DiskANN/apps/test_streaming_scenario.cpp
vendored
Normal file
523
packages/leann-backend-diskann/third_party/DiskANN/apps/test_streaming_scenario.cpp
vendored
Normal file
@@ -0,0 +1,523 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <index.h>
|
||||
#include <numeric>
|
||||
#include <omp.h>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
#include <timer.h>
|
||||
#include <boost/program_options.hpp>
|
||||
#include <future>
|
||||
#include <abstract_index.h>
|
||||
#include <index_factory.h>
|
||||
|
||||
#include "utils.h"
|
||||
#include "filter_utils.h"
|
||||
#include "program_options_utils.hpp"
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#include "memory_mapper.h"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
// load_aligned_bin modified to read pieces of the file, but using ifstream
|
||||
// instead of cached_ifstream.
|
||||
template <typename T>
|
||||
inline void load_aligned_bin_part(const std::string &bin_file, T *data, size_t offset_points, size_t points_to_read)
|
||||
{
|
||||
std::ifstream reader;
|
||||
reader.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
reader.open(bin_file, std::ios::binary | std::ios::ate);
|
||||
size_t actual_file_size = reader.tellg();
|
||||
reader.seekg(0, std::ios::beg);
|
||||
|
||||
int npts_i32, dim_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&dim_i32, sizeof(int));
|
||||
size_t npts = (uint32_t)npts_i32;
|
||||
size_t dim = (uint32_t)dim_i32;
|
||||
|
||||
size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t);
|
||||
if (actual_file_size != expected_actual_file_size)
|
||||
{
|
||||
std::stringstream stream;
|
||||
stream << "Error. File size mismatch. Actual size is " << actual_file_size << " while expected size is "
|
||||
<< expected_actual_file_size << " npts = " << npts << " dim = " << dim << " size of <T>= " << sizeof(T)
|
||||
<< std::endl;
|
||||
std::cout << stream.str();
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
if (offset_points + points_to_read > npts)
|
||||
{
|
||||
std::stringstream stream;
|
||||
stream << "Error. Not enough points in file. Requested " << offset_points << " offset and " << points_to_read
|
||||
<< " points, but have only " << npts << " points" << std::endl;
|
||||
std::cout << stream.str();
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T));
|
||||
|
||||
const size_t rounded_dim = ROUND_UP(dim, 8);
|
||||
|
||||
for (size_t i = 0; i < points_to_read; i++)
|
||||
{
|
||||
reader.read((char *)(data + i * rounded_dim), dim * sizeof(T));
|
||||
memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T));
|
||||
}
|
||||
reader.close();
|
||||
}
|
||||
|
||||
// Builds the output path "<save_path>act<A>-cons<C>-max<M>" from the streaming parameters.
std::string get_save_filename(const std::string &save_path, size_t active_window, size_t consolidate_interval,
                              size_t max_points_to_insert)
{
    const std::string window_part = "act" + std::to_string(active_window) + "-";
    const std::string interval_part = "cons" + std::to_string(consolidate_interval) + "-";
    const std::string limit_part = "max" + std::to_string(max_points_to_insert);
    return save_path + window_part + interval_part + limit_part;
}
|
||||
|
||||
template <typename T, typename TagT, typename LabelT>
|
||||
void insert_next_batch(diskann::AbstractIndex &index, size_t start, size_t end, size_t insert_threads, T *data,
|
||||
size_t aligned_dim, std::vector<std::vector<LabelT>> &pts_to_labels)
|
||||
{
|
||||
try
|
||||
{
|
||||
diskann::Timer insert_timer;
|
||||
std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl;
|
||||
|
||||
size_t num_failed = 0;
|
||||
#pragma omp parallel for num_threads((int32_t)insert_threads) schedule(dynamic) reduction(+ : num_failed)
|
||||
for (int64_t j = start; j < (int64_t)end; j++)
|
||||
{
|
||||
int insert_result = -1;
|
||||
if (pts_to_labels.size() > 0)
|
||||
{
|
||||
insert_result = index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j),
|
||||
pts_to_labels[j - start]);
|
||||
}
|
||||
else
|
||||
{
|
||||
insert_result = index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j));
|
||||
}
|
||||
|
||||
if (insert_result != 0)
|
||||
{
|
||||
std::cerr << "Insert failed " << j << std::endl;
|
||||
num_failed++;
|
||||
}
|
||||
}
|
||||
const double elapsedSeconds = insert_timer.elapsed() / 1000000.0;
|
||||
std::cout << "Insertion time " << elapsedSeconds << " seconds (" << (end - start) / elapsedSeconds
|
||||
<< " points/second overall, " << (end - start) / elapsedSeconds / insert_threads << " per thread)"
|
||||
<< std::endl;
|
||||
if (num_failed > 0)
|
||||
std::cout << num_failed << " of " << end - start << "inserts failed" << std::endl;
|
||||
}
|
||||
catch (std::system_error &e)
|
||||
{
|
||||
std::cout << "Exiting after catching exception in insertion task: " << e.what() << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t>
|
||||
void delete_and_consolidate(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params, size_t start,
|
||||
size_t end)
|
||||
{
|
||||
try
|
||||
{
|
||||
std::cout << std::endl << "Lazy deleting points " << start << " to " << end << "... ";
|
||||
for (size_t i = start; i < end; ++i)
|
||||
index.lazy_delete(static_cast<TagT>(1 + i));
|
||||
std::cout << "lazy delete done." << std::endl;
|
||||
|
||||
auto report = index.consolidate_deletes(delete_params);
|
||||
while (report._status != diskann::consolidation_report::status_code::SUCCESS)
|
||||
{
|
||||
int wait_time = 5;
|
||||
if (report._status == diskann::consolidation_report::status_code::LOCK_FAIL)
|
||||
{
|
||||
diskann::cerr << "Unable to acquire consolidate delete lock after "
|
||||
<< "deleting points " << start << " to " << end << ". Will retry in " << wait_time
|
||||
<< "seconds." << std::endl;
|
||||
}
|
||||
else if (report._status == diskann::consolidation_report::status_code::INCONSISTENT_COUNT_ERROR)
|
||||
{
|
||||
diskann::cerr << "Inconsistent counts in data structure. "
|
||||
<< "Will retry in " << wait_time << "seconds." << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Exiting after unknown error in consolidate delete" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::seconds(wait_time));
|
||||
report = index.consolidate_deletes(delete_params);
|
||||
}
|
||||
auto points_processed = report._active_points + report._slots_released;
|
||||
auto deletion_rate = points_processed / report._time;
|
||||
std::cout << "#active points: " << report._active_points << std::endl
|
||||
<< "max points: " << report._max_points << std::endl
|
||||
<< "empty slots: " << report._empty_slots << std::endl
|
||||
<< "deletes processed: " << report._slots_released << std::endl
|
||||
<< "latest delete size: " << report._delete_set_size << std::endl
|
||||
<< "Deletion rate: " << deletion_rate << "/sec "
|
||||
<< "Deletion rate: " << deletion_rate / delete_params.num_threads << "/thread/sec " << std::endl;
|
||||
}
|
||||
catch (std::system_error &e)
|
||||
{
|
||||
std::cerr << "Exiting after catching exception in deletion task: " << e.what() << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t>
|
||||
void build_incremental_index(const std::string &data_path, const uint32_t L, const uint32_t R, const float alpha,
|
||||
const uint32_t insert_threads, const uint32_t consolidate_threads,
|
||||
size_t max_points_to_insert, size_t active_window, size_t consolidate_interval,
|
||||
const float start_point_norm, uint32_t num_start_pts, const std::string &save_path,
|
||||
const std::string &label_file, const std::string &universal_label, const uint32_t Lf)
|
||||
{
|
||||
const uint32_t C = 500;
|
||||
const bool saturate_graph = false;
|
||||
bool has_labels = label_file != "";
|
||||
|
||||
diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R)
|
||||
.with_max_occlusion_size(C)
|
||||
.with_alpha(alpha)
|
||||
.with_saturate_graph(saturate_graph)
|
||||
.with_num_threads(insert_threads)
|
||||
.with_filter_list_size(Lf)
|
||||
.build();
|
||||
|
||||
auto index_search_params = diskann::IndexSearchParams(L, insert_threads);
|
||||
diskann::IndexWriteParameters delete_params = diskann::IndexWriteParametersBuilder(L, R)
|
||||
.with_max_occlusion_size(C)
|
||||
.with_alpha(alpha)
|
||||
.with_saturate_graph(saturate_graph)
|
||||
.with_num_threads(consolidate_threads)
|
||||
.with_filter_list_size(Lf)
|
||||
.build();
|
||||
|
||||
size_t dim, aligned_dim;
|
||||
size_t num_points;
|
||||
|
||||
std::vector<std::vector<LabelT>> pts_to_labels;
|
||||
|
||||
const auto save_path_inc =
|
||||
get_save_filename(save_path + ".after-streaming-", active_window, consolidate_interval, max_points_to_insert);
|
||||
std::string labels_file_to_use = save_path_inc + "_label_formatted.txt";
|
||||
std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt";
|
||||
if (has_labels)
|
||||
{
|
||||
convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
|
||||
auto parse_result = diskann::parse_formatted_label_file<LabelT>(labels_file_to_use);
|
||||
pts_to_labels = std::get<0>(parse_result);
|
||||
}
|
||||
|
||||
diskann::get_bin_metadata(data_path, num_points, dim);
|
||||
diskann::cout << "metadata: file " << data_path << " has " << num_points << " points in " << dim << " dims"
|
||||
<< std::endl;
|
||||
aligned_dim = ROUND_UP(dim, 8);
|
||||
auto index_config = diskann::IndexConfigBuilder()
|
||||
.with_metric(diskann::L2)
|
||||
.with_dimension(dim)
|
||||
.with_max_points(active_window + 4 * consolidate_interval)
|
||||
.is_dynamic_index(true)
|
||||
.is_enable_tags(true)
|
||||
.is_use_opq(false)
|
||||
.is_filtered(has_labels)
|
||||
.with_num_pq_chunks(0)
|
||||
.is_pq_dist_build(false)
|
||||
.with_num_frozen_pts(num_start_pts)
|
||||
.with_tag_type(diskann_type_to_name<TagT>())
|
||||
.with_label_type(diskann_type_to_name<LabelT>())
|
||||
.with_data_type(diskann_type_to_name<T>())
|
||||
.with_index_write_params(params)
|
||||
.with_index_search_params(index_search_params)
|
||||
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
|
||||
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
|
||||
.build();
|
||||
|
||||
diskann::IndexFactory index_factory = diskann::IndexFactory(index_config);
|
||||
auto index = index_factory.create_instance();
|
||||
|
||||
if (universal_label != "")
|
||||
{
|
||||
LabelT u_label = 0;
|
||||
index->set_universal_label(u_label);
|
||||
}
|
||||
|
||||
if (max_points_to_insert == 0)
|
||||
{
|
||||
max_points_to_insert = num_points;
|
||||
}
|
||||
|
||||
if (num_points < max_points_to_insert)
|
||||
throw diskann::ANNException(std::string("num_points(") + std::to_string(num_points) +
|
||||
") < max_points_to_insert(" + std::to_string(max_points_to_insert) + ")",
|
||||
-1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
|
||||
if (max_points_to_insert < active_window + consolidate_interval)
|
||||
throw diskann::ANNException("ERROR: max_points_to_insert < "
|
||||
"active_window + consolidate_interval",
|
||||
-1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
|
||||
if (consolidate_interval < max_points_to_insert / 1000)
|
||||
throw diskann::ANNException("ERROR: consolidate_interval is too small", -1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
|
||||
index->set_start_points_at_random(static_cast<T>(start_point_norm));
|
||||
|
||||
T *data = nullptr;
|
||||
diskann::alloc_aligned((void **)&data, std::max(consolidate_interval, active_window) * aligned_dim * sizeof(T),
|
||||
8 * sizeof(T));
|
||||
|
||||
std::vector<TagT> tags(max_points_to_insert);
|
||||
std::iota(tags.begin(), tags.end(), static_cast<TagT>(0));
|
||||
|
||||
diskann::Timer timer;
|
||||
|
||||
std::vector<std::future<void>> delete_tasks;
|
||||
|
||||
auto insert_task = std::async(std::launch::async, [&]() {
|
||||
load_aligned_bin_part(data_path, data, 0, active_window);
|
||||
insert_next_batch<T, TagT, LabelT>(*index, (size_t)0, active_window, params.num_threads, data, aligned_dim,
|
||||
pts_to_labels);
|
||||
});
|
||||
insert_task.wait();
|
||||
|
||||
for (size_t start = active_window; start + consolidate_interval <= max_points_to_insert;
|
||||
start += consolidate_interval)
|
||||
{
|
||||
auto end = std::min(start + consolidate_interval, max_points_to_insert);
|
||||
auto insert_task = std::async(std::launch::async, [&]() {
|
||||
load_aligned_bin_part(data_path, data, start, end - start);
|
||||
insert_next_batch<T, TagT, LabelT>(*index, start, end, params.num_threads, data, aligned_dim,
|
||||
pts_to_labels);
|
||||
});
|
||||
insert_task.wait();
|
||||
|
||||
if (delete_tasks.size() > 0)
|
||||
delete_tasks[delete_tasks.size() - 1].wait();
|
||||
if (start >= active_window + consolidate_interval)
|
||||
{
|
||||
auto start_del = start - active_window - consolidate_interval;
|
||||
auto end_del = start - active_window;
|
||||
|
||||
delete_tasks.emplace_back(std::async(std::launch::async, [&]() {
|
||||
delete_and_consolidate<T, TagT, LabelT>(*index, delete_params, (size_t)start_del, (size_t)end_del);
|
||||
}));
|
||||
}
|
||||
}
|
||||
if (delete_tasks.size() > 0)
|
||||
delete_tasks[delete_tasks.size() - 1].wait();
|
||||
|
||||
std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n";
|
||||
|
||||
index->save(save_path_inc.c_str(), true);
|
||||
|
||||
diskann::aligned_free(data);
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type;
|
||||
uint32_t insert_threads, consolidate_threads, R, L, num_start_pts, Lf, unique_labels_supported;
|
||||
float alpha, start_point_norm;
|
||||
size_t max_points_to_insert, active_window, consolidate_interval;
|
||||
|
||||
po::options_description desc{program_options_utils::make_program_description("test_streaming_scenario",
|
||||
"Test insert deletes & consolidate")};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
// Required parameters
|
||||
po::options_description required_configs("Required");
|
||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
||||
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
|
||||
program_options_utils::INPUT_DATA_PATH);
|
||||
required_configs.add_options()("active_window", po::value<uint64_t>(&active_window)->required(),
|
||||
"Program maintains an index over an active window of "
|
||||
"this size that slides through the data");
|
||||
required_configs.add_options()("consolidate_interval", po::value<uint64_t>(&consolidate_interval)->required(),
|
||||
"The program simultaneously adds this number of points to the "
|
||||
"right of "
|
||||
"the window while deleting the same number from the left");
|
||||
required_configs.add_options()("start_point_norm", po::value<float>(&start_point_norm)->required(),
|
||||
"Set the start point to a random point on a sphere of this radius");
|
||||
|
||||
// Optional parameters
|
||||
po::options_description optional_configs("Optional");
|
||||
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
|
||||
program_options_utils::MAX_BUILD_DEGREE);
|
||||
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
|
||||
program_options_utils::GRAPH_BUILD_COMPLEXITY);
|
||||
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
|
||||
program_options_utils::GRAPH_BUILD_ALPHA);
|
||||
optional_configs.add_options()("insert_threads",
|
||||
po::value<uint32_t>(&insert_threads)->default_value(omp_get_num_procs() / 2),
|
||||
"Number of threads used for inserting into the index (defaults to "
|
||||
"omp_get_num_procs()/2)");
|
||||
optional_configs.add_options()(
|
||||
"consolidate_threads", po::value<uint32_t>(&consolidate_threads)->default_value(omp_get_num_procs() / 2),
|
||||
"Number of threads used for consolidating deletes to "
|
||||
"the index (defaults to omp_get_num_procs()/2)");
|
||||
optional_configs.add_options()("max_points_to_insert",
|
||||
po::value<uint64_t>(&max_points_to_insert)->default_value(0),
|
||||
"The number of points from the file that the program streams "
|
||||
"over ");
|
||||
optional_configs.add_options()(
|
||||
"num_start_points",
|
||||
po::value<uint32_t>(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC),
|
||||
"Set the number of random start (frozen) points to use when "
|
||||
"inserting and searching");
|
||||
|
||||
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
|
||||
"Input label file in txt format for Filtered Index search. "
|
||||
"The file should contain comma separated filters for each node "
|
||||
"with each line corresponding to a graph node");
|
||||
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
|
||||
"Universal label, if using it, only in conjunction with labels_file");
|
||||
optional_configs.add_options()("FilteredLbuild,Lf", po::value<uint32_t>(&Lf)->default_value(0),
|
||||
"Build complexity for filtered points, higher value "
|
||||
"results in better graphs");
|
||||
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
|
||||
"Storage type of Labels <uint/ushort>, default value is uint which "
|
||||
"will consume memory 4 bytes per filter");
|
||||
optional_configs.add_options()("unique_labels_supported",
|
||||
po::value<uint32_t>(&unique_labels_supported)->default_value(0),
|
||||
"Number of unique labels supported by the dynamic index.");
|
||||
|
||||
// Merge required and optional parameters
|
||||
desc.add(required_configs).add(optional_configs);
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Validate arguments
|
||||
if (start_point_norm == 0)
|
||||
{
|
||||
std::cout << "When beginning_index_size is 0, use a start point with "
|
||||
"appropriate norm"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (label_type != std::string("ushort") && label_type != std::string("uint"))
|
||||
{
|
||||
std::cerr << "Invalid label type. Supported types are uint and ushort" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type != std::string("int8") && data_type != std::string("uint8") && data_type != std::string("float"))
|
||||
{
|
||||
std::cerr << "Invalid data type. Supported types are int8, uint8 and float" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// TODO: Are additional distance functions supported?
|
||||
if (dist_fn != std::string("l2") && dist_fn != std::string("mips"))
|
||||
{
|
||||
std::cerr << "Invalid distance function. Supported functions are l2 and mips" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (num_start_pts < unique_labels_supported)
|
||||
{
|
||||
num_start_pts = unique_labels_supported;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
if (data_type == std::string("uint8"))
|
||||
{
|
||||
if (label_type == std::string("ushort"))
|
||||
{
|
||||
build_incremental_index<uint8_t, uint32_t, uint16_t>(
|
||||
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
|
||||
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
|
||||
universal_label, Lf);
|
||||
}
|
||||
else if (label_type == std::string("uint"))
|
||||
{
|
||||
build_incremental_index<uint8_t, uint32_t, uint32_t>(
|
||||
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
|
||||
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
|
||||
universal_label, Lf);
|
||||
}
|
||||
}
|
||||
else if (data_type == std::string("int8"))
|
||||
{
|
||||
if (label_type == std::string("ushort"))
|
||||
{
|
||||
build_incremental_index<int8_t, uint32_t, uint16_t>(
|
||||
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
|
||||
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
|
||||
universal_label, Lf);
|
||||
}
|
||||
else if (label_type == std::string("uint"))
|
||||
{
|
||||
build_incremental_index<int8_t, uint32_t, uint32_t>(
|
||||
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
|
||||
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
|
||||
universal_label, Lf);
|
||||
}
|
||||
}
|
||||
else if (data_type == std::string("float"))
|
||||
{
|
||||
if (label_type == std::string("ushort"))
|
||||
{
|
||||
build_incremental_index<float, uint32_t, uint16_t>(
|
||||
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
|
||||
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
|
||||
universal_label, Lf);
|
||||
}
|
||||
else if (label_type == std::string("uint"))
|
||||
{
|
||||
build_incremental_index<float, uint32_t, uint32_t>(
|
||||
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
|
||||
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
|
||||
universal_label, Lf);
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cerr << "Caught exception: " << e.what() << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
std::cerr << "Caught unknown exception" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
110
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/CMakeLists.txt
vendored
Normal file
110
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,110 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_COMPILE_WARNING_AS_ERROR ON)

# ---------------------------------------------------------------------------
# Format converters and small data utilities
# ---------------------------------------------------------------------------
add_executable(fvecs_to_bin fvecs_to_bin.cpp)

add_executable(fvecs_to_bvecs fvecs_to_bvecs.cpp)

add_executable(rand_data_gen rand_data_gen.cpp)
target_link_libraries(rand_data_gen ${PROJECT_NAME} Boost::program_options)

add_executable(float_bin_to_int8 float_bin_to_int8.cpp)

add_executable(ivecs_to_bin ivecs_to_bin.cpp)

add_executable(count_bfs_levels count_bfs_levels.cpp)
target_link_libraries(count_bfs_levels ${PROJECT_NAME} Boost::program_options)

add_executable(tsv_to_bin tsv_to_bin.cpp)

add_executable(bin_to_tsv bin_to_tsv.cpp)

add_executable(int8_to_float int8_to_float.cpp)
target_link_libraries(int8_to_float ${PROJECT_NAME})

add_executable(int8_to_float_scale int8_to_float_scale.cpp)
target_link_libraries(int8_to_float_scale ${PROJECT_NAME})

add_executable(uint8_to_float uint8_to_float.cpp)
target_link_libraries(uint8_to_float ${PROJECT_NAME})

add_executable(uint32_to_uint8 uint32_to_uint8.cpp)
target_link_libraries(uint32_to_uint8 ${PROJECT_NAME})

# ---------------------------------------------------------------------------
# Analysis and recall tools
# ---------------------------------------------------------------------------
add_executable(vector_analysis vector_analysis.cpp)
target_link_libraries(vector_analysis ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

add_executable(gen_random_slice gen_random_slice.cpp)
target_link_libraries(gen_random_slice ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

add_executable(simulate_aggregate_recall simulate_aggregate_recall.cpp)

add_executable(calculate_recall calculate_recall.cpp)
target_link_libraries(calculate_recall ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

# Compute ground truth thing outside of DiskANN main source that depends on MKL.
add_executable(compute_groundtruth compute_groundtruth.cpp)
target_include_directories(compute_groundtruth PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES})
target_link_libraries(compute_groundtruth ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options)

add_executable(compute_groundtruth_for_filters compute_groundtruth_for_filters.cpp)
target_include_directories(compute_groundtruth_for_filters PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES})
target_link_libraries(compute_groundtruth_for_filters ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options)

# ---------------------------------------------------------------------------
# PQ generation, partitioning and disk layout
# ---------------------------------------------------------------------------
add_executable(generate_pq generate_pq.cpp)
target_link_libraries(generate_pq ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

add_executable(partition_data partition_data.cpp)
target_link_libraries(partition_data ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

add_executable(partition_with_ram_budget partition_with_ram_budget.cpp)
target_link_libraries(partition_with_ram_budget ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

add_executable(merge_shards merge_shards.cpp)
target_link_libraries(merge_shards ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} ${DISKANN_ASYNC_LIB})

add_executable(create_disk_layout create_disk_layout.cpp)
target_link_libraries(create_disk_layout ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

# ---------------------------------------------------------------------------
# Label utilities
# ---------------------------------------------------------------------------
add_executable(generate_synthetic_labels generate_synthetic_labels.cpp)
target_link_libraries(generate_synthetic_labels ${PROJECT_NAME} Boost::program_options)

add_executable(stats_label_data stats_label_data.cpp)
target_link_libraries(stats_label_data ${PROJECT_NAME} Boost::program_options)

# Install every utility built above (skipped for MSVC builds).
if (NOT MSVC)
    include(GNUInstallDirs)
    set(DISKANN_UTILS_INSTALL_TARGETS
        fvecs_to_bin
        fvecs_to_bvecs
        rand_data_gen
        float_bin_to_int8
        ivecs_to_bin
        count_bfs_levels
        tsv_to_bin
        bin_to_tsv
        int8_to_float
        int8_to_float_scale
        uint8_to_float
        uint32_to_uint8
        vector_analysis
        gen_random_slice
        simulate_aggregate_recall
        calculate_recall
        compute_groundtruth
        compute_groundtruth_for_filters
        generate_pq
        partition_data
        partition_with_ram_budget
        merge_shards
        create_disk_layout
        generate_synthetic_labels
        stats_label_data
    )
    install(TARGETS ${DISKANN_UTILS_INSTALL_TARGETS}
            RUNTIME
    )
endif()
|
||||
63
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/bin_to_fvecs.cpp
vendored
Normal file
63
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/bin_to_fvecs.cpp
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "util.h"
|
||||
|
||||
/**
 * Converts one block of `npts` vectors from .bin layout to .fvecs layout.
 *
 * Reads `npts * ndims` floats from `reader` into `read_buf`, builds the fvecs
 * records in `write_buf` (each record is a 4-byte dimension count followed by
 * `ndims` floats) and writes them to `writer`.
 *
 * @param reader     Input stream positioned at the first vector of the block.
 * @param writer     Output stream receiving the fvecs records.
 * @param read_buf   Scratch buffer holding at least npts * ndims floats.
 * @param write_buf  Scratch buffer holding at least npts * (ndims + 1) floats.
 * @param npts       Number of vectors in this block.
 * @param ndims      Dimensionality of each vector.
 */
void block_convert(std::ifstream &reader, std::ofstream &writer, float *read_buf, float *write_buf, uint64_t npts,
                   uint64_t ndims)
{
    reader.read((char *)read_buf, npts * ndims * sizeof(float));

    // fvecs stores the dimension as a 4-byte integer in front of every vector.
    const uint32_t dim_prefix = (uint32_t)ndims;
#pragma omp parallel for
    for (uint64_t i = 0; i < npts; i++)
    {
        float *record = write_buf + i * (ndims + 1);
        // memcpy the raw integer bits into the float-typed slot.
        memcpy(record, &dim_prefix, sizeof(uint32_t));
        memcpy(record + 1, read_buf + i * ndims, ndims * sizeof(float));
    }
    writer.write((char *)write_buf, npts * (ndims * sizeof(float) + sizeof(uint32_t)));
}

/**
 * Usage: bin_to_fvecs input_bin output_fvecs
 *
 * Reads a .bin file (int32 point count, int32 dimension, then the vectors as
 * floats) and writes the same vectors in .fvecs format, block by block.
 */
int main(int argc, char **argv)
{
    if (argc != 3)
    {
        std::cout << argv[0] << " input_bin output_fvecs" << std::endl;
        exit(-1);
    }

    std::ifstream reader(argv[1], std::ios::binary);
    if (!reader)
    {
        std::cerr << "Error: could not open input file " << argv[1] << std::endl;
        return -1;
    }

    int32_t npts_s32 = 0;
    int32_t ndims_s32 = 0;
    reader.read((char *)&npts_s32, sizeof(int32_t));
    reader.read((char *)&ndims_s32, sizeof(int32_t));
    if (!reader || npts_s32 < 0 || ndims_s32 <= 0)
    {
        std::cerr << "Error: invalid .bin header in " << argv[1] << std::endl;
        return -1;
    }
    // The stream is now positioned at the first vector; do not seek back.
    const uint64_t npts = (uint64_t)npts_s32;
    const uint64_t ndims = (uint64_t)ndims_s32;
    std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;

    const uint64_t blk_size = 131072;
    const uint64_t nblks = (npts + blk_size - 1) / blk_size;
    std::cout << "# blks: " << nblks << std::endl;

    std::ofstream writer(argv[2], std::ios::binary);
    if (!writer)
    {
        std::cerr << "Error: could not open output file " << argv[2] << std::endl;
        return -1;
    }

    // Buffers only need to hold one block, not the whole dataset.
    const uint64_t buf_pts = std::min(npts, blk_size);
    float *read_buf = new float[buf_pts * ndims];
    float *write_buf = new float[buf_pts * (ndims + 1)];
    int status = 0;
    for (uint64_t i = 0; i < nblks; i++)
    {
        uint64_t cblk_size = std::min(npts - i * blk_size, blk_size);
        block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims);
        if (!reader || !writer)
        {
            std::cerr << "Error: I/O failure while converting block #" << i << std::endl;
            status = -1;
            break;
        }
        std::cout << "Block #" << i << " written" << std::endl;
    }

    delete[] read_buf;
    delete[] write_buf;

    writer.close();
    reader.close();
    return status;
}
|
||||
69
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/bin_to_tsv.cpp
vendored
Normal file
69
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/bin_to_tsv.cpp
vendored
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
/**
 * Reads one block of `npts` vectors of `ndims` elements of type T from
 * `reader` and writes them to `writer` as tab-separated text, one vector per
 * line.
 *
 * @param writer    Text output stream.
 * @param reader    Binary input stream positioned at the start of the block.
 * @param read_buf  Scratch buffer holding at least npts * ndims elements of T.
 * @param npts      Number of vectors in the block.
 * @param ndims     Number of elements per vector.
 */
template <class T>
void block_convert(std::ofstream &writer, std::ifstream &reader, T *read_buf, size_t npts, size_t ndims)
{
    // Element size depends on T; using sizeof(float) here would over-read the
    // input for int8/uint8 data.
    reader.read((char *)read_buf, npts * ndims * sizeof(T));

    for (size_t i = 0; i < npts; i++)
    {
        for (size_t d = 0; d < ndims; d++)
        {
            // Unary plus promotes int8_t/uint8_t to int so the value is written
            // as a number rather than as a raw character; floats are unchanged.
            writer << +read_buf[d + i * ndims];
            if (d < ndims - 1)
                writer << "\t";
            else
                writer << "\n";
        }
    }
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 4)
|
||||
{
|
||||
std::cout << argv[0] << " <float/int8/uint8> input_bin output_tsv" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::string type_string(argv[1]);
|
||||
if ((type_string != std::string("float")) && (type_string != std::string("int8")) &&
|
||||
(type_string != std::string("uin8")))
|
||||
{
|
||||
std::cerr << "Error: type not supported. Use float/int8/uint8" << std::endl;
|
||||
}
|
||||
|
||||
std::ifstream reader(argv[2], std::ios::binary);
|
||||
uint32_t npts_u32;
|
||||
uint32_t ndims_u32;
|
||||
reader.read((char *)&npts_u32, sizeof(uint32_t));
|
||||
reader.read((char *)&ndims_u32, sizeof(uint32_t));
|
||||
size_t npts = npts_u32;
|
||||
size_t ndims = ndims_u32;
|
||||
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
|
||||
|
||||
size_t blk_size = 131072;
|
||||
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
|
||||
std::ofstream writer(argv[3]);
|
||||
char *read_buf = new char[blk_size * ndims * 4];
|
||||
for (size_t i = 0; i < nblks; i++)
|
||||
{
|
||||
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
if (type_string == std::string("float"))
|
||||
block_convert<float>(writer, reader, (float *)read_buf, cblk_size, ndims);
|
||||
else if (type_string == std::string("int8"))
|
||||
block_convert<int8_t>(writer, reader, (int8_t *)read_buf, cblk_size, ndims);
|
||||
else if (type_string == std::string("uint8"))
|
||||
block_convert<uint8_t>(writer, reader, (uint8_t *)read_buf, cblk_size, ndims);
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
}
|
||||
|
||||
delete[] read_buf;
|
||||
|
||||
writer.close();
|
||||
reader.close();
|
||||
}
|
||||
55
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/calculate_recall.cpp
vendored
Normal file
55
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/calculate_recall.cpp
vendored
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "utils.h"
|
||||
#include "disk_utils.h"
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 4)
|
||||
{
|
||||
std::cout << argv[0] << " <ground_truth_bin> <our_results_bin> <r> " << std::endl;
|
||||
return -1;
|
||||
}
|
||||
uint32_t *gold_std = NULL;
|
||||
float *gs_dist = nullptr;
|
||||
uint32_t *our_results = NULL;
|
||||
float *or_dist = nullptr;
|
||||
size_t points_num, points_num_gs, points_num_or;
|
||||
size_t dim_gs;
|
||||
size_t dim_or;
|
||||
diskann::load_truthset(argv[1], gold_std, gs_dist, points_num_gs, dim_gs);
|
||||
diskann::load_truthset(argv[2], our_results, or_dist, points_num_or, dim_or);
|
||||
|
||||
if (points_num_gs != points_num_or)
|
||||
{
|
||||
std::cout << "Error. Number of queries mismatch in ground truth and "
|
||||
"our results"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
points_num = points_num_gs;
|
||||
|
||||
uint32_t recall_at = std::atoi(argv[3]);
|
||||
|
||||
if ((dim_or < recall_at) || (recall_at > dim_gs))
|
||||
{
|
||||
std::cout << "ground truth has size " << dim_gs << "; our set has " << dim_or << " points. Asking for recall "
|
||||
<< recall_at << std::endl;
|
||||
return -1;
|
||||
}
|
||||
std::cout << "Calculating recall@" << recall_at << std::endl;
|
||||
double recall_val = diskann::calculate_recall((uint32_t)points_num, gold_std, gs_dist, (uint32_t)dim_gs,
|
||||
our_results, (uint32_t)dim_or, (uint32_t)recall_at);
|
||||
|
||||
// double avg_recall = (recall*1.0)/(points_num*1.0);
|
||||
std::cout << "Avg. recall@" << recall_at << " is " << recall_val << "\n";
|
||||
}
|
||||
574
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/compute_groundtruth.cpp
vendored
Normal file
574
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/compute_groundtruth.cpp
vendored
Normal file
@@ -0,0 +1,574 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <cassert>
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <random>
|
||||
#include <limits>
|
||||
#include <cstring>
|
||||
#include <queue>
|
||||
#include <omp.h>
|
||||
#include <mkl.h>
|
||||
#include <boost/program_options.hpp>
|
||||
#include <unordered_map>
|
||||
#include <tsl/robin_map.h>
|
||||
#include <tsl/robin_set.h>
|
||||
|
||||
#ifdef _WINDOWS
|
||||
#include <malloc.h>
|
||||
#else
|
||||
#include <stdlib.h>
|
||||
#endif
|
||||
#include "filter_utils.h"
|
||||
#include "utils.h"
|
||||
|
||||
// WORKS FOR UP TO 2 BILLION POINTS (as we use INT INSTEAD OF UNSIGNED)
|
||||
|
||||
#define PARTSIZE 10000000
|
||||
#define ALIGNMENT 512
|
||||
|
||||
// custom types (for readability)
|
||||
typedef tsl::robin_set<std::string> label_set;
|
||||
typedef std::string path;
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
// Integer division that rounds the quotient up by one whenever the division
// leaves a remainder.
template <class T> T div_round_up(const T numerator, const T denominator)
{
    const T quotient = numerator / denominator;
    if (numerator % denominator == 0)
    {
        return quotient;
    }
    return 1 + quotient;
}
|
||||
|
||||
// (index, distance) pair.
using pairIF = std::pair<size_t, float>;

// Orders pairs by their float component so that a std::priority_queue built
// with this comparator keeps the pair with the largest float on top.
struct cmpmaxstruct
{
    bool operator()(const pairIF &l, const pairIF &r)
    {
        return r.second > l.second;
    }
};

// Max-heap of (index, distance) pairs keyed on distance.
using maxPQIFCS = std::priority_queue<pairIF, std::vector<pairIF>, cmpmaxstruct>;
|
||||
|
||||
/**
 * Allocates uninitialized storage for `n` objects of type T aligned to
 * `alignment` bytes.
 *
 * On Windows the memory comes from _aligned_malloc; elsewhere it comes from
 * aligned_alloc. Returns whatever the underlying allocator returns (a null
 * pointer on failure).
 */
template <class T> T *aligned_malloc(const size_t n, const size_t alignment)
{
#ifdef _WINDOWS
    return (T *)_aligned_malloc(sizeof(T) * n, alignment);
#else
    // aligned_alloc is only required to succeed when the size is an integral
    // multiple of the alignment, so round the byte count up. Skip the rounding
    // when alignment is 0 to avoid dividing by zero; the allocator then sees
    // the same arguments as before.
    const size_t bytes = sizeof(T) * n;
    const size_t padded = (alignment == 0) ? bytes : ((bytes + alignment - 1) / alignment) * alignment;
    return static_cast<T *>(aligned_alloc(alignment, padded));
#endif
}
|
||||
|
||||
// Strict "less than" on the distance component of two (id, distance) pairs.
inline bool custom_dist(const std::pair<uint32_t, float> &a, const std::pair<uint32_t, float> &b)
{
    const float lhs_dist = a.second;
    const float rhs_dist = b.second;
    return rhs_dist > lhs_dist;
}
|
||||
|
||||
void compute_l2sq(float *const points_l2sq, const float *const matrix, const int64_t num_points, const uint64_t dim)
|
||||
{
|
||||
assert(points_l2sq != NULL);
|
||||
#pragma omp parallel for schedule(static, 65536)
|
||||
for (int64_t d = 0; d < num_points; ++d)
|
||||
points_l2sq[d] = cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1,
|
||||
matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1);
|
||||
}
|
||||
|
||||
void distsq_to_points(const size_t dim,
|
||||
float *dist_matrix, // Col Major, cols are queries, rows are points
|
||||
size_t npoints, const float *const points,
|
||||
const float *const points_l2sq, // points in Col major
|
||||
size_t nqueries, const float *const queries,
|
||||
const float *const queries_l2sq, // queries in Col major
|
||||
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
|
||||
{
|
||||
bool ones_vec_alloc = false;
|
||||
if (ones_vec == NULL)
|
||||
{
|
||||
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
|
||||
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
|
||||
ones_vec_alloc = true;
|
||||
}
|
||||
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-2.0, points, dim, queries, dim,
|
||||
(float)0.0, dist_matrix, npoints);
|
||||
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, points_l2sq, npoints,
|
||||
ones_vec, nqueries, (float)1.0, dist_matrix, npoints);
|
||||
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, ones_vec, npoints,
|
||||
queries_l2sq, nqueries, (float)1.0, dist_matrix, npoints);
|
||||
if (ones_vec_alloc)
|
||||
delete[] ones_vec;
|
||||
}
|
||||
|
||||
void inner_prod_to_points(const size_t dim,
|
||||
float *dist_matrix, // Col Major, cols are queries, rows are points
|
||||
size_t npoints, const float *const points, size_t nqueries, const float *const queries,
|
||||
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
|
||||
{
|
||||
bool ones_vec_alloc = false;
|
||||
if (ones_vec == NULL)
|
||||
{
|
||||
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
|
||||
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
|
||||
ones_vec_alloc = true;
|
||||
}
|
||||
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-1.0, points, dim, queries, dim,
|
||||
(float)0.0, dist_matrix, npoints);
|
||||
|
||||
if (ones_vec_alloc)
|
||||
delete[] ones_vec;
|
||||
}
|
||||
|
||||
void exact_knn(const size_t dim, const size_t k,
|
||||
size_t *const closest_points, // k * num_queries preallocated, col
|
||||
// major, queries columns
|
||||
float *const dist_closest_points, // k * num_queries
|
||||
// preallocated, Dist to
|
||||
// corresponding closes_points
|
||||
size_t npoints,
|
||||
float *points_in, // points in Col major
|
||||
size_t nqueries, float *queries_in,
|
||||
diskann::Metric metric = diskann::Metric::L2) // queries in Col major
|
||||
{
|
||||
float *points_l2sq = new float[npoints];
|
||||
float *queries_l2sq = new float[nqueries];
|
||||
compute_l2sq(points_l2sq, points_in, npoints, dim);
|
||||
compute_l2sq(queries_l2sq, queries_in, nqueries, dim);
|
||||
|
||||
float *points = points_in;
|
||||
float *queries = queries_in;
|
||||
|
||||
if (metric == diskann::Metric::COSINE)
|
||||
{ // we convert cosine distance as
|
||||
// normalized L2 distnace
|
||||
points = new float[npoints * dim];
|
||||
queries = new float[nqueries * dim];
|
||||
#pragma omp parallel for schedule(static, 4096)
|
||||
for (int64_t i = 0; i < (int64_t)npoints; i++)
|
||||
{
|
||||
float norm = std::sqrt(points_l2sq[i]);
|
||||
if (norm == 0)
|
||||
{
|
||||
norm = std::numeric_limits<float>::epsilon();
|
||||
}
|
||||
for (uint32_t j = 0; j < dim; j++)
|
||||
{
|
||||
points[i * dim + j] = points_in[i * dim + j] / norm;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma omp parallel for schedule(static, 4096)
|
||||
for (int64_t i = 0; i < (int64_t)nqueries; i++)
|
||||
{
|
||||
float norm = std::sqrt(queries_l2sq[i]);
|
||||
if (norm == 0)
|
||||
{
|
||||
norm = std::numeric_limits<float>::epsilon();
|
||||
}
|
||||
for (uint32_t j = 0; j < dim; j++)
|
||||
{
|
||||
queries[i * dim + j] = queries_in[i * dim + j] / norm;
|
||||
}
|
||||
}
|
||||
// recalculate norms after normalizing, they should all be one.
|
||||
compute_l2sq(points_l2sq, points, npoints, dim);
|
||||
compute_l2sq(queries_l2sq, queries, nqueries, dim);
|
||||
}
|
||||
|
||||
std::cout << "Going to compute " << k << " NNs for " << nqueries << " queries over " << npoints << " points in "
|
||||
<< dim << " dimensions using";
|
||||
if (metric == diskann::Metric::INNER_PRODUCT)
|
||||
std::cout << " MIPS ";
|
||||
else if (metric == diskann::Metric::COSINE)
|
||||
std::cout << " Cosine ";
|
||||
else
|
||||
std::cout << " L2 ";
|
||||
std::cout << "distance fn. " << std::endl;
|
||||
|
||||
size_t q_batch_size = (1 << 9);
|
||||
float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints];
|
||||
|
||||
for (size_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b)
|
||||
{
|
||||
int64_t q_b = b * q_batch_size;
|
||||
int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size;
|
||||
|
||||
if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE)
|
||||
{
|
||||
distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b,
|
||||
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, queries_l2sq + q_b);
|
||||
}
|
||||
else
|
||||
{
|
||||
inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b,
|
||||
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim);
|
||||
}
|
||||
std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" << std::endl;
|
||||
|
||||
#pragma omp parallel for schedule(dynamic, 16)
|
||||
for (long long q = q_b; q < q_e; q++)
|
||||
{
|
||||
maxPQIFCS point_dist;
|
||||
for (size_t p = 0; p < k; p++)
|
||||
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
|
||||
for (size_t p = k; p < npoints; p++)
|
||||
{
|
||||
if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints])
|
||||
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
|
||||
if (point_dist.size() > k)
|
||||
point_dist.pop();
|
||||
}
|
||||
for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l)
|
||||
{
|
||||
closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().first;
|
||||
dist_closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().second;
|
||||
point_dist.pop();
|
||||
}
|
||||
assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k,
|
||||
dist_closest_points + (ptrdiff_t)(q + 1) * (ptrdiff_t)k));
|
||||
}
|
||||
std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e << ")" << std::endl;
|
||||
}
|
||||
|
||||
delete[] dist_matrix;
|
||||
|
||||
delete[] points_l2sq;
|
||||
delete[] queries_l2sq;
|
||||
|
||||
if (metric == diskann::Metric::COSINE)
|
||||
{
|
||||
delete[] points;
|
||||
delete[] queries;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T> inline int get_num_parts(const char *filename)
|
||||
{
|
||||
std::ifstream reader;
|
||||
reader.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
reader.open(filename, std::ios::binary);
|
||||
std::cout << "Reading bin file " << filename << " ...\n";
|
||||
int npts_i32, ndims_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&ndims_i32, sizeof(int));
|
||||
std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl;
|
||||
reader.close();
|
||||
uint32_t num_parts =
|
||||
(npts_i32 % PARTSIZE) == 0 ? npts_i32 / PARTSIZE : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1;
|
||||
std::cout << "Number of parts: " << num_parts << std::endl;
|
||||
return num_parts;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void load_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims, int part_num)
|
||||
{
|
||||
std::ifstream reader;
|
||||
reader.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
reader.open(filename, std::ios::binary);
|
||||
std::cout << "Reading bin file " << filename << " ...\n";
|
||||
int npts_i32, ndims_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&ndims_i32, sizeof(int));
|
||||
uint64_t start_id = part_num * PARTSIZE;
|
||||
uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32);
|
||||
npts = end_id - start_id;
|
||||
ndims = (uint64_t)ndims_i32;
|
||||
std::cout << "#pts in part = " << npts << ", #dims = " << ndims << ", size = " << npts * ndims * sizeof(T) << "B"
|
||||
<< std::endl;
|
||||
|
||||
reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg);
|
||||
T *data_T = new T[npts * ndims];
|
||||
reader.read((char *)data_T, sizeof(T) * npts * ndims);
|
||||
std::cout << "Finished reading part of the bin file." << std::endl;
|
||||
reader.close();
|
||||
data = aligned_malloc<float>(npts * ndims, ALIGNMENT);
|
||||
#pragma omp parallel for schedule(dynamic, 32768)
|
||||
for (int64_t i = 0; i < (int64_t)npts; i++)
|
||||
{
|
||||
for (int64_t j = 0; j < (int64_t)ndims; j++)
|
||||
{
|
||||
float cur_val_float = (float)data_T[i * ndims + j];
|
||||
std::memcpy((char *)(data + i * ndims + j), (char *)&cur_val_float, sizeof(float));
|
||||
}
|
||||
}
|
||||
delete[] data_T;
|
||||
std::cout << "Finished converting part data to float." << std::endl;
|
||||
}
|
||||
|
||||
// Writes `data` (npts * ndims elements of type T) to `filename` in bin format:
// two int header fields (npts, ndims) followed by the raw elements.
// The stream is configured to throw on open/write failure.
template <typename T> inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims)
{
    std::ofstream out;
    out.exceptions(std::ios::failbit | std::ios::badbit);
    out.open(filename, std::ios::binary | std::ios::out);
    std::cout << "Writing bin: " << filename << "\n";

    // Header: both counts are stored as int.
    const int header_npts = (int)npts;
    const int header_ndims = (int)ndims;
    out.write((char *)&header_npts, sizeof(int));
    out.write((char *)&header_ndims, sizeof(int));

    const size_t payload_bytes = npts * ndims * sizeof(T);
    std::cout << "bin: #pts = " << npts << ", #dims = " << ndims << ", size = " << payload_bytes + 2 * sizeof(int)
              << "B" << std::endl;

    out.write((char *)data, payload_bytes);
    out.close();
    std::cout << "Finished writing bin" << std::endl;
}
|
||||
|
||||
// Writes a truthset file: two int header fields (npts, ndims), then the
// npts * ndims id matrix, then the npts * ndims distance matrix.
//
// The stream throws on open/write failure (same policy as save_bin), so a
// truthset that could not be written is reported instead of being silently
// followed by "Finished writing truthset".
inline void save_groundtruth_as_one_file(const std::string filename, int32_t *data, float *distances, size_t npts,
                                         size_t ndims)
{
    std::ofstream writer;
    writer.exceptions(std::ios::failbit | std::ios::badbit);
    writer.open(filename, std::ios::binary | std::ios::out);

    int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
    writer.write((char *)&npts_i32, sizeof(int));
    writer.write((char *)&ndims_i32, sizeof(int));
    std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, "
                 "npts*dim dist-matrix) with npts = "
              << npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int)
              << "B" << std::endl;

    const size_t num_entries = npts * ndims;
    writer.write((char *)data, num_entries * sizeof(int32_t));
    writer.write((char *)distances, num_entries * sizeof(float));
    writer.close();
    std::cout << "Finished writing truthset" << std::endl;
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> processUnfilteredParts(const std::string &base_file,
|
||||
size_t &nqueries, size_t &npoints,
|
||||
size_t &dim, size_t &k, float *query_data,
|
||||
const diskann::Metric &metric,
|
||||
std::vector<uint32_t> &location_to_tag)
|
||||
{
|
||||
float *base_data = nullptr;
|
||||
int num_parts = get_num_parts<T>(base_file.c_str());
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
|
||||
for (int p = 0; p < num_parts; p++)
|
||||
{
|
||||
size_t start_id = p * PARTSIZE;
|
||||
load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);
|
||||
|
||||
size_t *closest_points_part = new size_t[nqueries * k];
|
||||
float *dist_closest_points_part = new float[nqueries * k];
|
||||
|
||||
auto part_k = k < npoints ? k : npoints;
|
||||
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data,
|
||||
metric);
|
||||
|
||||
for (size_t i = 0; i < nqueries; i++)
|
||||
{
|
||||
for (size_t j = 0; j < part_k; j++)
|
||||
{
|
||||
if (!location_to_tag.empty())
|
||||
if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0)
|
||||
continue;
|
||||
|
||||
res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id),
|
||||
dist_closest_points_part[i * part_k + j]));
|
||||
}
|
||||
}
|
||||
|
||||
delete[] closest_points_part;
|
||||
delete[] dist_closest_points_part;
|
||||
|
||||
diskann::aligned_free(base_data);
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
int aux_main(const std::string &base_file, const std::string &query_file, const std::string >_file, size_t k,
|
||||
const diskann::Metric &metric, const std::string &tags_file = std::string(""))
|
||||
{
|
||||
size_t npoints, nqueries, dim;
|
||||
|
||||
float *query_data;
|
||||
|
||||
load_bin_as_float<T>(query_file.c_str(), query_data, nqueries, dim, 0);
|
||||
if (nqueries > PARTSIZE)
|
||||
std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE
|
||||
<< ". Computing GT only for the first " << PARTSIZE << " queries." << std::endl;
|
||||
|
||||
// load tags
|
||||
const bool tags_enabled = tags_file.empty() ? false : true;
|
||||
std::vector<uint32_t> location_to_tag = diskann::loadTags(tags_file, base_file);
|
||||
|
||||
int *closest_points = new int[nqueries * k];
|
||||
float *dist_closest_points = new float[nqueries * k];
|
||||
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> results =
|
||||
processUnfilteredParts<T>(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag);
|
||||
|
||||
for (size_t i = 0; i < nqueries; i++)
|
||||
{
|
||||
std::vector<std::pair<uint32_t, float>> &cur_res = results[i];
|
||||
std::sort(cur_res.begin(), cur_res.end(), custom_dist);
|
||||
size_t j = 0;
|
||||
for (auto iter : cur_res)
|
||||
{
|
||||
if (j == k)
|
||||
break;
|
||||
if (tags_enabled)
|
||||
{
|
||||
std::uint32_t index_with_tag = location_to_tag[iter.first];
|
||||
closest_points[i * k + j] = (int32_t)index_with_tag;
|
||||
}
|
||||
else
|
||||
{
|
||||
closest_points[i * k + j] = (int32_t)iter.first;
|
||||
}
|
||||
|
||||
if (metric == diskann::Metric::INNER_PRODUCT)
|
||||
dist_closest_points[i * k + j] = -iter.second;
|
||||
else
|
||||
dist_closest_points[i * k + j] = iter.second;
|
||||
|
||||
++j;
|
||||
}
|
||||
if (j < k)
|
||||
std::cout << "WARNING: found less than k GT entries for query " << i << std::endl;
|
||||
}
|
||||
|
||||
save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k);
|
||||
delete[] closest_points;
|
||||
delete[] dist_closest_points;
|
||||
diskann::aligned_free(query_data);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim)
|
||||
{
|
||||
size_t read_blk_size = 64 * 1024 * 1024;
|
||||
cached_ifstream reader(bin_file, read_blk_size);
|
||||
diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl;
|
||||
size_t actual_file_size = reader.get_file_size();
|
||||
|
||||
int npts_i32, dim_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&dim_i32, sizeof(int));
|
||||
npts = (uint32_t)npts_i32;
|
||||
dim = (uint32_t)dim_i32;
|
||||
|
||||
diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl;
|
||||
|
||||
int truthset_type = -1; // 1 means truthset has ids and distances, 2 means
|
||||
// only ids, -1 is error
|
||||
size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
|
||||
|
||||
if (actual_file_size == expected_file_size_with_dists)
|
||||
truthset_type = 1;
|
||||
|
||||
size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
|
||||
|
||||
if (actual_file_size == expected_file_size_just_ids)
|
||||
truthset_type = 2;
|
||||
|
||||
if (truthset_type == -1)
|
||||
{
|
||||
std::stringstream stream;
|
||||
stream << "Error. File size mismatch. File should have bin format, with "
|
||||
"npts followed by ngt followed by npts*ngt ids and optionally "
|
||||
"followed by npts*ngt distance values; actual size: "
|
||||
<< actual_file_size << ", expected: " << expected_file_size_with_dists << " or "
|
||||
<< expected_file_size_just_ids;
|
||||
diskann::cout << stream.str();
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
ids = new uint32_t[npts * dim];
|
||||
reader.read((char *)ids, npts * dim * sizeof(uint32_t));
|
||||
|
||||
if (truthset_type == 1)
|
||||
{
|
||||
dists = new float[npts * dim];
|
||||
reader.read((char *)dists, npts * dim * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file;
|
||||
uint64_t K;
|
||||
|
||||
try
|
||||
{
|
||||
po::options_description desc{"Arguments"};
|
||||
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
||||
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
||||
"distance function <l2/mips/cosine>");
|
||||
desc.add_options()("base_file", po::value<std::string>(&base_file)->required(),
|
||||
"File containing the base vectors in binary format");
|
||||
desc.add_options()("query_file", po::value<std::string>(&query_file)->required(),
|
||||
"File containing the query vectors in binary format");
|
||||
desc.add_options()("gt_file", po::value<std::string>(>_file)->required(),
|
||||
"File name for the writing ground truth in binary "
|
||||
"format, please don' append .bin at end if "
|
||||
"no filter_label or filter_label_file is provided it "
|
||||
"will save the file with '.bin' at end."
|
||||
"else it will save the file as filename_label.bin");
|
||||
desc.add_options()("K", po::value<uint64_t>(&K)->required(),
|
||||
"Number of ground truth nearest neighbors to compute");
|
||||
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
|
||||
"File containing the tags in binary format");
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8"))
|
||||
{
|
||||
std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
{
|
||||
metric = diskann::Metric::L2;
|
||||
}
|
||||
else if (dist_fn == std::string("mips"))
|
||||
{
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
}
|
||||
else if (dist_fn == std::string("cosine"))
|
||||
{
|
||||
metric = diskann::Metric::COSINE;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported distance function. Use l2/mips/cosine." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
if (data_type == std::string("float"))
|
||||
aux_main<float>(base_file, query_file, gt_file, K, metric, tags_file);
|
||||
if (data_type == std::string("int8"))
|
||||
aux_main<int8_t>(base_file, query_file, gt_file, K, metric, tags_file);
|
||||
if (data_type == std::string("uint8"))
|
||||
aux_main<uint8_t>(base_file, query_file, gt_file, K, metric, tags_file);
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Compute GT failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
919
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/compute_groundtruth_for_filters.cpp
vendored
Normal file
919
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/compute_groundtruth_for_filters.cpp
vendored
Normal file
@@ -0,0 +1,919 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <cassert>
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <random>
|
||||
#include <limits>
|
||||
#include <cstring>
|
||||
#include <queue>
|
||||
#include <omp.h>
|
||||
#include <mkl.h>
|
||||
#include <boost/program_options.hpp>
|
||||
#include <unordered_map>
|
||||
#include <tsl/robin_map.h>
|
||||
#include <tsl/robin_set.h>
|
||||
|
||||
#ifdef _WINDOWS
|
||||
#include <malloc.h>
|
||||
#else
|
||||
#include <stdlib.h>
|
||||
#endif
|
||||
|
||||
#include "filter_utils.h"
|
||||
#include "utils.h"
|
||||
|
||||
// WORKS FOR UP TO 2 BILLION POINTS (as we use INT INSTEAD OF UNSIGNED)
|
||||
|
||||
#define PARTSIZE 10000000
|
||||
#define ALIGNMENT 512
|
||||
|
||||
// custom types (for readability)
|
||||
typedef tsl::robin_set<std::string> label_set;
|
||||
typedef std::string path;
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
// Returns numerator / denominator, plus one whenever the division leaves a
// remainder. The denominator must be non-zero.
template <class T> T div_round_up(const T numerator, const T denominator)
{
    const T quotient = numerator / denominator;
    if (numerator % denominator == 0)
    {
        return quotient;
    }
    return 1 + quotient;
}
|
||||
|
||||
// (point id, distance) pair used by the k-NN heap.
using pairIF = std::pair<size_t, float>;

// Orders pairs by distance only, so a std::priority_queue built on it keeps
// the largest distance on top.
struct cmpmaxstruct
{
    // const-qualified so the comparator can also be invoked through const
    // objects and references.
    bool operator()(const pairIF &l, const pairIF &r) const
    {
        return l.second < r.second;
    }
};

// Max-heap of (id, distance) pairs keyed on distance.
using maxPQIFCS = std::priority_queue<pairIF, std::vector<pairIF>, cmpmaxstruct>;
|
||||
|
||||
// Allocates uninitialised storage for n objects of type T whose address is a
// multiple of `alignment` bytes. Returns a null pointer when the underlying
// allocator fails.
template <class T> T *aligned_malloc(const size_t n, const size_t alignment)
{
#ifdef _WINDOWS
    return (T *)_aligned_malloc(sizeof(T) * n, alignment);
#else
    if (alignment == 0)
        return nullptr;
    // aligned_alloc specifies the size as an integral multiple of the
    // alignment, so round the request up instead of relying on lenient
    // implementations.
    const size_t bytes = sizeof(T) * n;
    const size_t rounded_bytes = ((bytes + alignment - 1) / alignment) * alignment;
    return static_cast<T *>(aligned_alloc(alignment, rounded_bytes));
#endif
}
|
||||
|
||||
// Strict "closer than" ordering on (id, distance) pairs: compares distances
// only and ignores the ids.
inline bool custom_dist(const std::pair<uint32_t, float> &a, const std::pair<uint32_t, float> &b)
{
    const float lhs_dist = a.second;
    const float rhs_dist = b.second;
    return lhs_dist < rhs_dist;
}
|
||||
|
||||
void compute_l2sq(float *const points_l2sq, const float *const matrix, const int64_t num_points, const uint64_t dim)
|
||||
{
|
||||
assert(points_l2sq != NULL);
|
||||
#pragma omp parallel for schedule(static, 65536)
|
||||
for (int64_t d = 0; d < num_points; ++d)
|
||||
points_l2sq[d] = cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1,
|
||||
matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1);
|
||||
}
|
||||
|
||||
void distsq_to_points(const size_t dim,
|
||||
float *dist_matrix, // Col Major, cols are queries, rows are points
|
||||
size_t npoints, const float *const points,
|
||||
const float *const points_l2sq, // points in Col major
|
||||
size_t nqueries, const float *const queries,
|
||||
const float *const queries_l2sq, // queries in Col major
|
||||
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
|
||||
{
|
||||
bool ones_vec_alloc = false;
|
||||
if (ones_vec == NULL)
|
||||
{
|
||||
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
|
||||
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
|
||||
ones_vec_alloc = true;
|
||||
}
|
||||
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-2.0, points, dim, queries, dim,
|
||||
(float)0.0, dist_matrix, npoints);
|
||||
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, points_l2sq, npoints,
|
||||
ones_vec, nqueries, (float)1.0, dist_matrix, npoints);
|
||||
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, ones_vec, npoints,
|
||||
queries_l2sq, nqueries, (float)1.0, dist_matrix, npoints);
|
||||
if (ones_vec_alloc)
|
||||
delete[] ones_vec;
|
||||
}
|
||||
|
||||
void inner_prod_to_points(const size_t dim,
|
||||
float *dist_matrix, // Col Major, cols are queries, rows are points
|
||||
size_t npoints, const float *const points, size_t nqueries, const float *const queries,
|
||||
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
|
||||
{
|
||||
bool ones_vec_alloc = false;
|
||||
if (ones_vec == NULL)
|
||||
{
|
||||
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
|
||||
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
|
||||
ones_vec_alloc = true;
|
||||
}
|
||||
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-1.0, points, dim, queries, dim,
|
||||
(float)0.0, dist_matrix, npoints);
|
||||
|
||||
if (ones_vec_alloc)
|
||||
delete[] ones_vec;
|
||||
}
|
||||
|
||||
void exact_knn(const size_t dim, const size_t k,
|
||||
size_t *const closest_points, // k * num_queries preallocated, col
|
||||
// major, queries columns
|
||||
float *const dist_closest_points, // k * num_queries
|
||||
// preallocated, Dist to
|
||||
// corresponding closes_points
|
||||
size_t npoints,
|
||||
float *points_in, // points in Col major
|
||||
size_t nqueries, float *queries_in,
|
||||
diskann::Metric metric = diskann::Metric::L2) // queries in Col major
|
||||
{
|
||||
float *points_l2sq = new float[npoints];
|
||||
float *queries_l2sq = new float[nqueries];
|
||||
compute_l2sq(points_l2sq, points_in, npoints, dim);
|
||||
compute_l2sq(queries_l2sq, queries_in, nqueries, dim);
|
||||
|
||||
float *points = points_in;
|
||||
float *queries = queries_in;
|
||||
|
||||
if (metric == diskann::Metric::COSINE)
|
||||
{ // we convert cosine distance as
|
||||
// normalized L2 distnace
|
||||
points = new float[npoints * dim];
|
||||
queries = new float[nqueries * dim];
|
||||
#pragma omp parallel for schedule(static, 4096)
|
||||
for (int64_t i = 0; i < (int64_t)npoints; i++)
|
||||
{
|
||||
float norm = std::sqrt(points_l2sq[i]);
|
||||
if (norm == 0)
|
||||
{
|
||||
norm = std::numeric_limits<float>::epsilon();
|
||||
}
|
||||
for (uint32_t j = 0; j < dim; j++)
|
||||
{
|
||||
points[i * dim + j] = points_in[i * dim + j] / norm;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma omp parallel for schedule(static, 4096)
|
||||
for (int64_t i = 0; i < (int64_t)nqueries; i++)
|
||||
{
|
||||
float norm = std::sqrt(queries_l2sq[i]);
|
||||
if (norm == 0)
|
||||
{
|
||||
norm = std::numeric_limits<float>::epsilon();
|
||||
}
|
||||
for (uint32_t j = 0; j < dim; j++)
|
||||
{
|
||||
queries[i * dim + j] = queries_in[i * dim + j] / norm;
|
||||
}
|
||||
}
|
||||
// recalculate norms after normalizing, they should all be one.
|
||||
compute_l2sq(points_l2sq, points, npoints, dim);
|
||||
compute_l2sq(queries_l2sq, queries, nqueries, dim);
|
||||
}
|
||||
|
||||
std::cout << "Going to compute " << k << " NNs for " << nqueries << " queries over " << npoints << " points in "
|
||||
<< dim << " dimensions using";
|
||||
if (metric == diskann::Metric::INNER_PRODUCT)
|
||||
std::cout << " MIPS ";
|
||||
else if (metric == diskann::Metric::COSINE)
|
||||
std::cout << " Cosine ";
|
||||
else
|
||||
std::cout << " L2 ";
|
||||
std::cout << "distance fn. " << std::endl;
|
||||
|
||||
size_t q_batch_size = (1 << 9);
|
||||
float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints];
|
||||
|
||||
for (uint64_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b)
|
||||
{
|
||||
int64_t q_b = b * q_batch_size;
|
||||
int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size;
|
||||
|
||||
if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE)
|
||||
{
|
||||
distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b,
|
||||
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, queries_l2sq + q_b);
|
||||
}
|
||||
else
|
||||
{
|
||||
inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b,
|
||||
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim);
|
||||
}
|
||||
std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" << std::endl;
|
||||
|
||||
#pragma omp parallel for schedule(dynamic, 16)
|
||||
for (long long q = q_b; q < q_e; q++)
|
||||
{
|
||||
maxPQIFCS point_dist;
|
||||
for (size_t p = 0; p < k; p++)
|
||||
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
|
||||
for (size_t p = k; p < npoints; p++)
|
||||
{
|
||||
if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints])
|
||||
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
|
||||
if (point_dist.size() > k)
|
||||
point_dist.pop();
|
||||
}
|
||||
for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l)
|
||||
{
|
||||
closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().first;
|
||||
dist_closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().second;
|
||||
point_dist.pop();
|
||||
}
|
||||
assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k,
|
||||
dist_closest_points + (ptrdiff_t)(q + 1) * (ptrdiff_t)k));
|
||||
}
|
||||
std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e << ")" << std::endl;
|
||||
}
|
||||
|
||||
delete[] dist_matrix;
|
||||
|
||||
delete[] points_l2sq;
|
||||
delete[] queries_l2sq;
|
||||
|
||||
if (metric == diskann::Metric::COSINE)
|
||||
{
|
||||
delete[] points;
|
||||
delete[] queries;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T> inline int get_num_parts(const char *filename)
|
||||
{
|
||||
std::ifstream reader;
|
||||
reader.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
reader.open(filename, std::ios::binary);
|
||||
std::cout << "Reading bin file " << filename << " ...\n";
|
||||
int npts_i32, ndims_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&ndims_i32, sizeof(int));
|
||||
std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl;
|
||||
reader.close();
|
||||
int num_parts = (npts_i32 % PARTSIZE) == 0 ? npts_i32 / PARTSIZE : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1;
|
||||
std::cout << "Number of parts: " << num_parts << std::endl;
|
||||
return num_parts;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void load_bin_as_float(const char *filename, float *&data, size_t &npts_u64, size_t &ndims_u64, int part_num)
|
||||
{
|
||||
std::ifstream reader;
|
||||
reader.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
reader.open(filename, std::ios::binary);
|
||||
std::cout << "Reading bin file " << filename << " ...\n";
|
||||
int npts_i32, ndims_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&ndims_i32, sizeof(int));
|
||||
uint64_t start_id = part_num * PARTSIZE;
|
||||
uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32);
|
||||
npts_u64 = end_id - start_id;
|
||||
ndims_u64 = (uint64_t)ndims_i32;
|
||||
std::cout << "#pts in part = " << npts_u64 << ", #dims = " << ndims_u64
|
||||
<< ", size = " << npts_u64 * ndims_u64 * sizeof(T) << "B" << std::endl;
|
||||
|
||||
reader.seekg(start_id * ndims_u64 * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg);
|
||||
T *data_T = new T[npts_u64 * ndims_u64];
|
||||
reader.read((char *)data_T, sizeof(T) * npts_u64 * ndims_u64);
|
||||
std::cout << "Finished reading part of the bin file." << std::endl;
|
||||
reader.close();
|
||||
data = aligned_malloc<float>(npts_u64 * ndims_u64, ALIGNMENT);
|
||||
#pragma omp parallel for schedule(dynamic, 32768)
|
||||
for (int64_t i = 0; i < (int64_t)npts_u64; i++)
|
||||
{
|
||||
for (int64_t j = 0; j < (int64_t)ndims_u64; j++)
|
||||
{
|
||||
float cur_val_float = (float)data_T[i * ndims_u64 + j];
|
||||
std::memcpy((char *)(data + i * ndims_u64 + j), (char *)&cur_val_float, sizeof(float));
|
||||
}
|
||||
}
|
||||
delete[] data_T;
|
||||
std::cout << "Finished converting part data to float." << std::endl;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::vector<size_t> load_filtered_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims,
|
||||
int part_num, const char *label_file,
|
||||
const std::string &filter_label,
|
||||
const std::string &universal_label, size_t &npoints_filt,
|
||||
std::vector<std::vector<std::string>> &pts_to_labels)
|
||||
{
|
||||
std::ifstream reader(filename, std::ios::binary);
|
||||
if (reader.fail())
|
||||
{
|
||||
throw diskann::ANNException(std::string("Failed to open file ") + filename, -1);
|
||||
}
|
||||
|
||||
std::cout << "Reading bin file " << filename << " ...\n";
|
||||
int npts_i32, ndims_i32;
|
||||
std::vector<size_t> rev_map;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&ndims_i32, sizeof(int));
|
||||
uint64_t start_id = part_num * PARTSIZE;
|
||||
uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32);
|
||||
npts = end_id - start_id;
|
||||
ndims = (uint32_t)ndims_i32;
|
||||
uint64_t nptsuint64_t = (uint64_t)npts;
|
||||
uint64_t ndimsuint64_t = (uint64_t)ndims;
|
||||
npoints_filt = 0;
|
||||
std::cout << "#pts in part = " << npts << ", #dims = " << ndims
|
||||
<< ", size = " << nptsuint64_t * ndimsuint64_t * sizeof(T) << "B" << std::endl;
|
||||
std::cout << "start and end ids: " << start_id << ", " << end_id << std::endl;
|
||||
reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg);
|
||||
|
||||
T *data_T = new T[nptsuint64_t * ndimsuint64_t];
|
||||
reader.read((char *)data_T, sizeof(T) * nptsuint64_t * ndimsuint64_t);
|
||||
std::cout << "Finished reading part of the bin file." << std::endl;
|
||||
reader.close();
|
||||
|
||||
data = aligned_malloc<float>(nptsuint64_t * ndimsuint64_t, ALIGNMENT);
|
||||
|
||||
for (int64_t i = 0; i < (int64_t)nptsuint64_t; i++)
|
||||
{
|
||||
if (std::find(pts_to_labels[start_id + i].begin(), pts_to_labels[start_id + i].end(), filter_label) !=
|
||||
pts_to_labels[start_id + i].end() ||
|
||||
std::find(pts_to_labels[start_id + i].begin(), pts_to_labels[start_id + i].end(), universal_label) !=
|
||||
pts_to_labels[start_id + i].end())
|
||||
{
|
||||
rev_map.push_back(start_id + i);
|
||||
for (int64_t j = 0; j < (int64_t)ndimsuint64_t; j++)
|
||||
{
|
||||
float cur_val_float = (float)data_T[i * ndimsuint64_t + j];
|
||||
std::memcpy((char *)(data + npoints_filt * ndimsuint64_t + j), (char *)&cur_val_float, sizeof(float));
|
||||
}
|
||||
npoints_filt++;
|
||||
}
|
||||
}
|
||||
delete[] data_T;
|
||||
std::cout << "Finished converting part data to float.. identified " << npoints_filt
|
||||
<< " points matching the filter." << std::endl;
|
||||
return rev_map;
|
||||
}
|
||||
|
||||
// Writes `data` to `filename` in the "bin" layout: two int header fields
// (npts, ndims) followed by npts * ndims values of type T.
// The stream has failbit/badbit exceptions enabled, so a failed open or write
// surfaces as an exception rather than being ignored.
template <typename T> inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims)
{
    std::ofstream out;
    out.exceptions(std::ios::failbit | std::ios::badbit);
    out.open(filename, std::ios::binary | std::ios::out);
    std::cout << "Writing bin: " << filename << "\n";

    // The header stores both counts as int, whatever their size_t value was.
    const int header[2] = {(int)npts, (int)ndims};
    out.write((const char *)&header[0], sizeof(int));
    out.write((const char *)&header[1], sizeof(int));

    const size_t payload_bytes = npts * ndims * sizeof(T);
    std::cout << "bin: #pts = " << npts << ", #dims = " << ndims << ", size = " << payload_bytes + 2 * sizeof(int)
              << "B" << std::endl;

    out.write((char *)data, payload_bytes);
    out.close();
    std::cout << "Finished writing bin" << std::endl;
}
|
||||
|
||||
// Writes a ground-truth set to `filename` as one binary file:
//   int npts, int ndims, npts*ndims ids, npts*ndims distances.
//
// filename  : output path.
// data      : npts*ndims neighbor ids, laid out row by row per query.
// distances : npts*ndims distances, same layout as `data`.
// npts      : number of rows (queries).
// ndims     : number of entries per row.
//
// Throws std::ios_base::failure if the file cannot be opened or written, the
// same way save_bin does. Previously a failed open was silently ignored and
// the function still reported "Finished writing truthset".
inline void save_groundtruth_as_one_file(const std::string filename, int32_t *data, float *distances, size_t npts,
                                         size_t ndims)
{
    std::ofstream writer;
    writer.exceptions(std::ios::failbit | std::ios::badbit);
    writer.open(filename, std::ios::binary | std::ios::out);

    // Both header fields are stored as int.
    int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
    writer.write((char *)&npts_i32, sizeof(int));
    writer.write((char *)&ndims_i32, sizeof(int));
    std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, "
                 "npts*dim dist-matrix) with npts = "
              << npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int)
              << "B" << std::endl;

    writer.write((char *)data, npts * ndims * sizeof(uint32_t));
    writer.write((char *)distances, npts * ndims * sizeof(float));
    writer.close();
    std::cout << "Finished writing truthset" << std::endl;
}
|
||||
|
||||
// Parses a text label file with one line per point and appends each point's
// sorted label list to `pts_to_labels`.
//
// Only the text before the first tab on a line is used; it is split on ','.
// '\n' and '\r' characters are stripped from every label.
//
// line_cnt      : out - number of lines read from `map_file` by this call
//                 (previously this parameter was never assigned).
// map_file      : path of the label file. If it cannot be opened, no lines
//                 are read and `pts_to_labels` is left unchanged.
// pts_to_labels : in/out - one sorted label vector is appended per line.
inline void parse_label_file_into_vec(size_t &line_cnt, const std::string &map_file,
                                      std::vector<std::vector<std::string>> &pts_to_labels)
{
    std::ifstream infile(map_file);
    std::set<std::string> distinct_labels;
    std::string line;
    size_t lines_read = 0;

    while (std::getline(infile, line))
    {
        ++lines_read;

        // Anything after the first tab is ignored.
        std::istringstream line_stream(line);
        std::string label_field;
        std::getline(line_stream, label_field, '\t');

        std::vector<std::string> lbls;
        std::istringstream field_stream(label_field);
        std::string label;
        while (std::getline(field_stream, label, ','))
        {
            // Drop stray line-ending characters (e.g. from CRLF files).
            label.erase(std::remove(label.begin(), label.end(), '\n'), label.end());
            label.erase(std::remove(label.begin(), label.end(), '\r'), label.end());
            lbls.push_back(label);
            distinct_labels.insert(label);
        }
        std::sort(lbls.begin(), lbls.end());
        pts_to_labels.push_back(lbls);
    }

    line_cnt = lines_read;
    std::cout << "Identified " << distinct_labels.size() << " distinct label(s), and populated labels for "
              << pts_to_labels.size() << " points" << std::endl;
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> processUnfilteredParts(const std::string &base_file,
|
||||
size_t &nqueries, size_t &npoints,
|
||||
size_t &dim, size_t &k, float *query_data,
|
||||
const diskann::Metric &metric,
|
||||
std::vector<uint32_t> &location_to_tag)
|
||||
{
|
||||
float *base_data = nullptr;
|
||||
int num_parts = get_num_parts<T>(base_file.c_str());
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
|
||||
for (int p = 0; p < num_parts; p++)
|
||||
{
|
||||
size_t start_id = p * PARTSIZE;
|
||||
load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);
|
||||
|
||||
size_t *closest_points_part = new size_t[nqueries * k];
|
||||
float *dist_closest_points_part = new float[nqueries * k];
|
||||
|
||||
auto part_k = k < npoints ? k : npoints;
|
||||
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data,
|
||||
metric);
|
||||
|
||||
for (size_t i = 0; i < nqueries; i++)
|
||||
{
|
||||
for (uint64_t j = 0; j < part_k; j++)
|
||||
{
|
||||
if (!location_to_tag.empty())
|
||||
if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0)
|
||||
continue;
|
||||
|
||||
res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id),
|
||||
dist_closest_points_part[i * part_k + j]));
|
||||
}
|
||||
}
|
||||
|
||||
delete[] closest_points_part;
|
||||
delete[] dist_closest_points_part;
|
||||
|
||||
diskann::aligned_free(base_data);
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> processFilteredParts(
|
||||
const std::string &base_file, const std::string &label_file, const std::string &filter_label,
|
||||
const std::string &universal_label, size_t &nqueries, size_t &npoints, size_t &dim, size_t &k, float *query_data,
|
||||
const diskann::Metric &metric, std::vector<uint32_t> &location_to_tag)
|
||||
{
|
||||
size_t npoints_filt = 0;
|
||||
float *base_data = nullptr;
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
|
||||
int num_parts = get_num_parts<T>(base_file.c_str());
|
||||
|
||||
std::vector<std::vector<std::string>> pts_to_labels;
|
||||
if (filter_label != "")
|
||||
parse_label_file_into_vec(npoints, label_file, pts_to_labels);
|
||||
|
||||
for (int p = 0; p < num_parts; p++)
|
||||
{
|
||||
size_t start_id = p * PARTSIZE;
|
||||
std::vector<size_t> rev_map;
|
||||
if (filter_label != "")
|
||||
rev_map = load_filtered_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p, label_file.c_str(),
|
||||
filter_label, universal_label, npoints_filt, pts_to_labels);
|
||||
size_t *closest_points_part = new size_t[nqueries * k];
|
||||
float *dist_closest_points_part = new float[nqueries * k];
|
||||
|
||||
auto part_k = k < npoints_filt ? k : npoints_filt;
|
||||
if (npoints_filt > 0)
|
||||
{
|
||||
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints_filt, base_data, nqueries,
|
||||
query_data, metric);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < nqueries; i++)
|
||||
{
|
||||
for (uint64_t j = 0; j < part_k; j++)
|
||||
{
|
||||
if (!location_to_tag.empty())
|
||||
if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0)
|
||||
continue;
|
||||
|
||||
res[i].push_back(std::make_pair((uint32_t)(rev_map[closest_points_part[i * part_k + j]]),
|
||||
dist_closest_points_part[i * part_k + j]));
|
||||
}
|
||||
}
|
||||
|
||||
delete[] closest_points_part;
|
||||
delete[] dist_closest_points_part;
|
||||
|
||||
diskann::aligned_free(base_data);
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
int aux_main(const std::string &base_file, const std::string &label_file, const std::string &query_file,
|
||||
const std::string >_file, size_t k, const std::string &universal_label, const diskann::Metric &metric,
|
||||
const std::string &filter_label, const std::string &tags_file = std::string(""))
|
||||
{
|
||||
size_t npoints, nqueries, dim;
|
||||
|
||||
float *query_data = nullptr;
|
||||
|
||||
load_bin_as_float<T>(query_file.c_str(), query_data, nqueries, dim, 0);
|
||||
if (nqueries > PARTSIZE)
|
||||
std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE
|
||||
<< ". Computing GT only for the first " << PARTSIZE << " queries." << std::endl;
|
||||
|
||||
// load tags
|
||||
const bool tags_enabled = tags_file.empty() ? false : true;
|
||||
std::vector<uint32_t> location_to_tag = diskann::loadTags(tags_file, base_file);
|
||||
|
||||
int *closest_points = new int[nqueries * k];
|
||||
float *dist_closest_points = new float[nqueries * k];
|
||||
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> results;
|
||||
if (filter_label == "")
|
||||
{
|
||||
results = processUnfilteredParts<T>(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag);
|
||||
}
|
||||
else
|
||||
{
|
||||
results = processFilteredParts<T>(base_file, label_file, filter_label, universal_label, nqueries, npoints, dim,
|
||||
k, query_data, metric, location_to_tag);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < nqueries; i++)
|
||||
{
|
||||
std::vector<std::pair<uint32_t, float>> &cur_res = results[i];
|
||||
std::sort(cur_res.begin(), cur_res.end(), custom_dist);
|
||||
size_t j = 0;
|
||||
for (auto iter : cur_res)
|
||||
{
|
||||
if (j == k)
|
||||
break;
|
||||
if (tags_enabled)
|
||||
{
|
||||
std::uint32_t index_with_tag = location_to_tag[iter.first];
|
||||
closest_points[i * k + j] = (int32_t)index_with_tag;
|
||||
}
|
||||
else
|
||||
{
|
||||
closest_points[i * k + j] = (int32_t)iter.first;
|
||||
}
|
||||
|
||||
if (metric == diskann::Metric::INNER_PRODUCT)
|
||||
dist_closest_points[i * k + j] = -iter.second;
|
||||
else
|
||||
dist_closest_points[i * k + j] = iter.second;
|
||||
|
||||
++j;
|
||||
}
|
||||
if (j < k)
|
||||
std::cout << "WARNING: found less than k GT entries for query " << i << std::endl;
|
||||
}
|
||||
|
||||
save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k);
|
||||
delete[] closest_points;
|
||||
delete[] dist_closest_points;
|
||||
diskann::aligned_free(query_data);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim)
|
||||
{
|
||||
size_t read_blk_size = 64 * 1024 * 1024;
|
||||
cached_ifstream reader(bin_file, read_blk_size);
|
||||
diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl;
|
||||
size_t actual_file_size = reader.get_file_size();
|
||||
|
||||
int npts_i32, dim_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&dim_i32, sizeof(int));
|
||||
npts = (uint32_t)npts_i32;
|
||||
dim = (uint32_t)dim_i32;
|
||||
|
||||
diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl;
|
||||
|
||||
int truthset_type = -1; // 1 means truthset has ids and distances, 2 means
|
||||
// only ids, -1 is error
|
||||
size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
|
||||
|
||||
if (actual_file_size == expected_file_size_with_dists)
|
||||
truthset_type = 1;
|
||||
|
||||
size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
|
||||
|
||||
if (actual_file_size == expected_file_size_just_ids)
|
||||
truthset_type = 2;
|
||||
|
||||
if (truthset_type == -1)
|
||||
{
|
||||
std::stringstream stream;
|
||||
stream << "Error. File size mismatch. File should have bin format, with "
|
||||
"npts followed by ngt followed by npts*ngt ids and optionally "
|
||||
"followed by npts*ngt distance values; actual size: "
|
||||
<< actual_file_size << ", expected: " << expected_file_size_with_dists << " or "
|
||||
<< expected_file_size_just_ids;
|
||||
diskann::cout << stream.str();
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
ids = new uint32_t[npts * dim];
|
||||
reader.read((char *)ids, npts * dim * sizeof(uint32_t));
|
||||
|
||||
if (truthset_type == 1)
|
||||
{
|
||||
dists = new float[npts * dim];
|
||||
reader.read((char *)dists, npts * dim * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file, label_file, filter_label,
|
||||
universal_label, filter_label_file;
|
||||
uint64_t K;
|
||||
|
||||
try
|
||||
{
|
||||
po::options_description desc{"Arguments"};
|
||||
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
||||
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(), "distance function <l2/mips>");
|
||||
desc.add_options()("base_file", po::value<std::string>(&base_file)->required(),
|
||||
"File containing the base vectors in binary format");
|
||||
desc.add_options()("query_file", po::value<std::string>(&query_file)->required(),
|
||||
"File containing the query vectors in binary format");
|
||||
desc.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
|
||||
"Input labels file in txt format if present");
|
||||
desc.add_options()("filter_label", po::value<std::string>(&filter_label)->default_value(""),
|
||||
"Input filter label if doing filtered groundtruth");
|
||||
desc.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
|
||||
"Universal label, if using it, only in conjunction with label_file");
|
||||
desc.add_options()("gt_file", po::value<std::string>(>_file)->required(),
|
||||
"File name for the writing ground truth in binary "
|
||||
"format, please don' append .bin at end if "
|
||||
"no filter_label or filter_label_file is provided it "
|
||||
"will save the file with '.bin' at end."
|
||||
"else it will save the file as filename_label.bin");
|
||||
desc.add_options()("K", po::value<uint64_t>(&K)->required(),
|
||||
"Number of ground truth nearest neighbors to compute");
|
||||
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
|
||||
"File containing the tags in binary format");
|
||||
desc.add_options()("filter_label_file",
|
||||
po::value<std::string>(&filter_label_file)->default_value(std::string("")),
|
||||
"Filter file for Queries for Filtered Search ");
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8"))
|
||||
{
|
||||
std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (filter_label != "" && filter_label_file != "")
|
||||
{
|
||||
std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
{
|
||||
metric = diskann::Metric::L2;
|
||||
}
|
||||
else if (dist_fn == std::string("mips"))
|
||||
{
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
}
|
||||
else if (dist_fn == std::string("cosine"))
|
||||
{
|
||||
metric = diskann::Metric::COSINE;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported distance function. Use l2/mips/cosine." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::vector<std::string> filter_labels;
|
||||
if (filter_label != "")
|
||||
{
|
||||
filter_labels.push_back(filter_label);
|
||||
}
|
||||
else if (filter_label_file != "")
|
||||
{
|
||||
filter_labels = read_file_to_vector_of_strings(filter_label_file, false);
|
||||
}
|
||||
|
||||
// only when there is no filter label or 1 filter label for all queries
|
||||
if (filter_labels.size() == 1)
|
||||
{
|
||||
try
|
||||
{
|
||||
if (data_type == std::string("float"))
|
||||
aux_main<float>(base_file, label_file, query_file, gt_file, K, universal_label, metric,
|
||||
filter_labels[0], tags_file);
|
||||
if (data_type == std::string("int8"))
|
||||
aux_main<int8_t>(base_file, label_file, query_file, gt_file, K, universal_label, metric,
|
||||
filter_labels[0], tags_file);
|
||||
if (data_type == std::string("uint8"))
|
||||
aux_main<uint8_t>(base_file, label_file, query_file, gt_file, K, universal_label, metric,
|
||||
filter_labels[0], tags_file);
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Compute GT failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{ // Each query has its own filter label
|
||||
// Split up data and query bins into label specific ones
|
||||
tsl::robin_map<std::string, uint32_t> labels_to_number_of_points;
|
||||
tsl::robin_map<std::string, uint32_t> labels_to_number_of_queries;
|
||||
|
||||
label_set all_labels;
|
||||
for (size_t i = 0; i < filter_labels.size(); i++)
|
||||
{
|
||||
std::string label = filter_labels[i];
|
||||
all_labels.insert(label);
|
||||
|
||||
if (labels_to_number_of_queries.find(label) == labels_to_number_of_queries.end())
|
||||
{
|
||||
labels_to_number_of_queries[label] = 0;
|
||||
}
|
||||
labels_to_number_of_queries[label] += 1;
|
||||
}
|
||||
|
||||
size_t npoints;
|
||||
std::vector<std::vector<std::string>> point_to_labels;
|
||||
parse_label_file_into_vec(npoints, label_file, point_to_labels);
|
||||
std::vector<label_set> point_ids_to_labels(point_to_labels.size());
|
||||
std::vector<label_set> query_ids_to_labels(filter_labels.size());
|
||||
|
||||
for (size_t i = 0; i < point_to_labels.size(); i++)
|
||||
{
|
||||
for (size_t j = 0; j < point_to_labels[i].size(); j++)
|
||||
{
|
||||
std::string label = point_to_labels[i][j];
|
||||
if (all_labels.find(label) != all_labels.end())
|
||||
{
|
||||
point_ids_to_labels[i].insert(point_to_labels[i][j]);
|
||||
if (labels_to_number_of_points.find(label) == labels_to_number_of_points.end())
|
||||
{
|
||||
labels_to_number_of_points[label] = 0;
|
||||
}
|
||||
labels_to_number_of_points[label] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < filter_labels.size(); i++)
|
||||
{
|
||||
query_ids_to_labels[i].insert(filter_labels[i]);
|
||||
}
|
||||
|
||||
tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id;
|
||||
tsl::robin_map<std::string, std::vector<uint32_t>> label_query_id_to_orig_id;
|
||||
|
||||
if (data_type == std::string("float"))
|
||||
{
|
||||
label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<float>(
|
||||
base_file, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
|
||||
label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<float>(
|
||||
query_file, labels_to_number_of_queries, query_ids_to_labels,
|
||||
all_labels); // query_filters acts like query_ids_to_labels
|
||||
}
|
||||
else if (data_type == std::string("int8"))
|
||||
{
|
||||
label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<int8_t>(
|
||||
base_file, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
|
||||
label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<int8_t>(
|
||||
query_file, labels_to_number_of_queries, query_ids_to_labels,
|
||||
all_labels); // query_filters acts like query_ids_to_labels
|
||||
}
|
||||
else if (data_type == std::string("uint8"))
|
||||
{
|
||||
label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<uint8_t>(
|
||||
base_file, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
|
||||
label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<uint8_t>(
|
||||
query_file, labels_to_number_of_queries, query_ids_to_labels,
|
||||
all_labels); // query_filters acts like query_ids_to_labels
|
||||
}
|
||||
else
|
||||
{
|
||||
diskann::cerr << "Invalid data type" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Generate label specific ground truths
|
||||
|
||||
try
|
||||
{
|
||||
for (const auto &label : all_labels)
|
||||
{
|
||||
std::string filtered_base_file = base_file + "_" + label;
|
||||
std::string filtered_query_file = query_file + "_" + label;
|
||||
std::string filtered_gt_file = gt_file + "_" + label;
|
||||
if (data_type == std::string("float"))
|
||||
aux_main<float>(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, "");
|
||||
if (data_type == std::string("int8"))
|
||||
aux_main<int8_t>(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, "");
|
||||
if (data_type == std::string("uint8"))
|
||||
aux_main<uint8_t>(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, "");
|
||||
}
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Compute GT failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Combine the label specific ground truths to produce a single GT file
|
||||
|
||||
uint32_t *gt_ids = nullptr;
|
||||
float *gt_dists = nullptr;
|
||||
size_t gt_num, gt_dim;
|
||||
|
||||
std::vector<std::vector<int32_t>> final_gt_ids;
|
||||
std::vector<std::vector<float>> final_gt_dists;
|
||||
|
||||
uint32_t query_num = 0;
|
||||
for (const auto &lbl : all_labels)
|
||||
{
|
||||
query_num += labels_to_number_of_queries[lbl];
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < query_num; i++)
|
||||
{
|
||||
final_gt_ids.push_back(std::vector<int32_t>(K));
|
||||
final_gt_dists.push_back(std::vector<float>(K));
|
||||
}
|
||||
|
||||
for (const auto &lbl : all_labels)
|
||||
{
|
||||
std::string filtered_gt_file = gt_file + "_" + lbl;
|
||||
load_truthset(filtered_gt_file, gt_ids, gt_dists, gt_num, gt_dim);
|
||||
|
||||
for (uint32_t i = 0; i < labels_to_number_of_queries[lbl]; i++)
|
||||
{
|
||||
uint32_t orig_query_id = label_query_id_to_orig_id[lbl][i];
|
||||
for (uint64_t j = 0; j < K; j++)
|
||||
{
|
||||
final_gt_ids[orig_query_id][j] = label_id_to_orig_id[lbl][gt_ids[i * K + j]];
|
||||
final_gt_dists[orig_query_id][j] = gt_dists[i * K + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int32_t *closest_points = new int32_t[query_num * K];
|
||||
float *dist_closest_points = new float[query_num * K];
|
||||
|
||||
for (uint32_t i = 0; i < query_num; i++)
|
||||
{
|
||||
for (uint32_t j = 0; j < K; j++)
|
||||
{
|
||||
closest_points[i * K + j] = final_gt_ids[i][j];
|
||||
dist_closest_points[i * K + j] = final_gt_dists[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, query_num, K);
|
||||
|
||||
// cleanup artifacts
|
||||
std::cout << "Cleaning up artifacts..." << std::endl;
|
||||
tsl::robin_set<std::string> paths_to_clean{gt_file, base_file, query_file};
|
||||
clean_up_artifacts(paths_to_clean, all_labels);
|
||||
}
|
||||
}
|
||||
82
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/count_bfs_levels.cpp
vendored
Normal file
82
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/count_bfs_levels.cpp
vendored
Normal file
@@ -0,0 +1,82 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <cstring>
|
||||
#include <iomanip>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <omp.h>
|
||||
#include <set>
|
||||
#include <string.h>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <time.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#include "utils.h"
|
||||
#include "index.h"
|
||||
#include "memory_mapper.h"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
template <typename T> void bfs_count(const std::string &index_path, uint32_t data_dims)
|
||||
{
|
||||
using TagT = uint32_t;
|
||||
using LabelT = uint32_t;
|
||||
diskann::Index<T, TagT, LabelT> index(diskann::Metric::L2, data_dims, 0, nullptr, nullptr, 0, false, false, false,
|
||||
false, 0, false);
|
||||
std::cout << "Index class instantiated" << std::endl;
|
||||
index.load(index_path.c_str(), 1, 100);
|
||||
std::cout << "Index loaded" << std::endl;
|
||||
index.count_nodes_at_bfs_levels();
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, index_path_prefix;
|
||||
uint32_t data_dims;
|
||||
|
||||
po::options_description desc{"Arguments"};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
||||
desc.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
||||
"Path prefix to the index");
|
||||
desc.add_options()("data_dims", po::value<uint32_t>(&data_dims)->required(), "Dimensionality of the data");
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
if (data_type == std::string("int8"))
|
||||
bfs_count<int8_t>(index_path_prefix, data_dims);
|
||||
else if (data_type == std::string("uint8"))
|
||||
bfs_count<uint8_t>(index_path_prefix, data_dims);
|
||||
if (data_type == std::string("float"))
|
||||
bfs_count<float>(index_path_prefix, data_dims);
|
||||
}
|
||||
catch (std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Index BFS failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
48
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/create_disk_layout.cpp
vendored
Normal file
48
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/create_disk_layout.cpp
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "utils.h"
|
||||
#include "disk_utils.h"
|
||||
#include "cached_io.h"
|
||||
|
||||
template <typename T> int create_disk_layout(char **argv)
|
||||
{
|
||||
std::string base_file(argv[2]);
|
||||
std::string vamana_file(argv[3]);
|
||||
std::string output_file(argv[4]);
|
||||
diskann::create_disk_layout<T>(base_file, vamana_file, output_file);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 5)
|
||||
{
|
||||
std::cout << argv[0]
|
||||
<< " data_type <float/int8/uint8> data_bin "
|
||||
"vamana_index_file output_diskann_index_file"
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int ret_val = -1;
|
||||
if (std::string(argv[1]) == std::string("float"))
|
||||
ret_val = create_disk_layout<float>(argv);
|
||||
else if (std::string(argv[1]) == std::string("int8"))
|
||||
ret_val = create_disk_layout<int8_t>(argv);
|
||||
else if (std::string(argv[1]) == std::string("uint8"))
|
||||
ret_val = create_disk_layout<uint8_t>(argv);
|
||||
else
|
||||
{
|
||||
std::cout << "unsupported type. use int8/uint8/float " << std::endl;
|
||||
ret_val = -2;
|
||||
}
|
||||
return ret_val;
|
||||
}
|
||||
63
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/float_bin_to_int8.cpp
vendored
Normal file
63
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/float_bin_to_int8.cpp
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
// Reads npts * ndims floats from `reader`, quantizes each one to int8 as
//   (value - bias) * (254.0 / scale)
// truncated toward zero, and writes the npts * ndims resulting bytes to
// `writer`.
//
// writer    : destination stream for the int8 values.
// write_buf : scratch buffer holding at least npts * ndims int8 values.
// reader    : source stream positioned at the first float to convert.
// read_buf  : scratch buffer holding at least npts * ndims floats.
// npts      : number of points in this block.
// ndims     : number of dimensions per point.
// bias      : value subtracted from every input before scaling.
// scale     : divisor of the 254.0 scaling factor.
//
// Scaled values outside the int8 range (and NaN) are clamped to [-128, 127];
// converting such values directly is undefined behavior. Results for values
// already in range are unchanged.
void block_convert(std::ofstream &writer, int8_t *write_buf, std::ifstream &reader, float *read_buf, size_t npts,
                   size_t ndims, float bias, float scale)
{
    reader.read((char *)read_buf, npts * ndims * sizeof(float));

    // Loop-invariant; computed in double exactly as in the per-element form.
    const double factor = 254.0 / scale;
    const size_t total = npts * ndims;

    for (size_t idx = 0; idx < total; idx++)
    {
        double scaled = (read_buf[idx] - bias) * factor;
        if (!(scaled > -128.0))
            scaled = -128.0; // also catches NaN, for which every comparison is false
        else if (scaled > 127.0)
            scaled = 127.0;
        write_buf[idx] = (int8_t)scaled;
    }
    writer.write((char *)write_buf, npts * ndims);
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 5)
|
||||
{
|
||||
std::cout << "Usage: " << argv[0] << " input_bin output_tsv bias scale" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
std::ifstream reader(argv[1], std::ios::binary);
|
||||
uint32_t npts_u32;
|
||||
uint32_t ndims_u32;
|
||||
reader.read((char *)&npts_u32, sizeof(uint32_t));
|
||||
reader.read((char *)&ndims_u32, sizeof(uint32_t));
|
||||
size_t npts = npts_u32;
|
||||
size_t ndims = ndims_u32;
|
||||
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
|
||||
|
||||
size_t blk_size = 131072;
|
||||
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
|
||||
std::ofstream writer(argv[2], std::ios::binary);
|
||||
auto read_buf = new float[blk_size * ndims];
|
||||
auto write_buf = new int8_t[blk_size * ndims];
|
||||
float bias = (float)atof(argv[3]);
|
||||
float scale = (float)atof(argv[4]);
|
||||
|
||||
writer.write((char *)(&npts_u32), sizeof(uint32_t));
|
||||
writer.write((char *)(&ndims_u32), sizeof(uint32_t));
|
||||
|
||||
for (size_t i = 0; i < nblks; i++)
|
||||
{
|
||||
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
block_convert(writer, write_buf, reader, read_buf, cblk_size, ndims, bias, scale);
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
}
|
||||
|
||||
delete[] read_buf;
|
||||
delete[] write_buf;
|
||||
|
||||
writer.close();
|
||||
reader.close();
|
||||
}
|
||||
95
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/fvecs_to_bin.cpp
vendored
Normal file
95
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/fvecs_to_bin.cpp
vendored
Normal file
@@ -0,0 +1,95 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
// Convert float types
|
||||
// Converts one block of float records: reads `npts` records from `reader`,
// each being a 4-byte dimension field followed by `ndims` floats, strips the
// dimension field and writes the npts * ndims floats contiguously to `writer`.
// `read_buf` must hold npts * (ndims + 1) floats and `write_buf` npts * ndims.
void block_convert_float(std::ifstream &reader, std::ofstream &writer, float *read_buf, float *write_buf, size_t npts,
                         size_t ndims)
{
    const size_t payload_bytes = ndims * sizeof(float);
    const size_t record_bytes = payload_bytes + sizeof(uint32_t);
    reader.read((char *)read_buf, npts * record_bytes);

    // Offsets are in floats; the dimension field occupies one float-sized slot.
    size_t src_off = 1;
    size_t dst_off = 0;
    for (size_t pt = 0; pt < npts; ++pt)
    {
        memcpy(write_buf + dst_off, read_buf + src_off, payload_bytes);
        src_off += ndims + 1;
        dst_off += ndims;
    }

    writer.write((char *)write_buf, npts * payload_bytes);
}
|
||||
|
||||
// Convert byte types
|
||||
// Converts one block of byte records: reads `npts` records from `reader`,
// each being a 4-byte dimension field followed by `ndims` bytes, strips the
// dimension field and writes the npts * ndims bytes contiguously to `writer`.
// `read_buf` must hold npts * (ndims + 4) bytes and `write_buf` npts * ndims.
void block_convert_byte(std::ifstream &reader, std::ofstream &writer, uint8_t *read_buf, uint8_t *write_buf,
                        size_t npts, size_t ndims)
{
    const size_t header_bytes = sizeof(uint32_t);
    const size_t payload_bytes = ndims * sizeof(uint8_t);
    const size_t record_bytes = payload_bytes + header_bytes;
    reader.read((char *)read_buf, npts * record_bytes);

    // Byte offsets into the two buffers; each source record starts with its header.
    size_t src_off = header_bytes;
    size_t dst_off = 0;
    for (size_t pt = 0; pt < npts; ++pt)
    {
        memcpy(write_buf + dst_off, read_buf + src_off, payload_bytes);
        src_off += record_bytes;
        dst_off += ndims;
    }

    writer.write((char *)write_buf, npts * payload_bytes);
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 4)
|
||||
{
|
||||
std::cout << argv[0] << " <float/int8/uint8> input_vecs output_bin" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int datasize = sizeof(float);
|
||||
|
||||
if (strcmp(argv[1], "uint8") == 0 || strcmp(argv[1], "int8") == 0)
|
||||
{
|
||||
datasize = sizeof(uint8_t);
|
||||
}
|
||||
else if (strcmp(argv[1], "float") != 0)
|
||||
{
|
||||
std::cout << "Error: type not supported. Use float/int8/uint8" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
std::ifstream reader(argv[2], std::ios::binary | std::ios::ate);
|
||||
size_t fsize = reader.tellg();
|
||||
reader.seekg(0, std::ios::beg);
|
||||
|
||||
uint32_t ndims_u32;
|
||||
reader.read((char *)&ndims_u32, sizeof(uint32_t));
|
||||
reader.seekg(0, std::ios::beg);
|
||||
size_t ndims = (size_t)ndims_u32;
|
||||
size_t npts = fsize / ((ndims * datasize) + sizeof(uint32_t));
|
||||
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
|
||||
|
||||
size_t blk_size = 131072;
|
||||
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
std::cout << "# blks: " << nblks << std::endl;
|
||||
std::ofstream writer(argv[3], std::ios::binary);
|
||||
int32_t npts_s32 = (int32_t)npts;
|
||||
int32_t ndims_s32 = (int32_t)ndims;
|
||||
writer.write((char *)&npts_s32, sizeof(int32_t));
|
||||
writer.write((char *)&ndims_s32, sizeof(int32_t));
|
||||
|
||||
size_t chunknpts = std::min(npts, blk_size);
|
||||
uint8_t *read_buf = new uint8_t[chunknpts * ((ndims * datasize) + sizeof(uint32_t))];
|
||||
uint8_t *write_buf = new uint8_t[chunknpts * ndims * datasize];
|
||||
|
||||
for (size_t i = 0; i < nblks; i++)
|
||||
{
|
||||
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
if (datasize == sizeof(float))
|
||||
{
|
||||
block_convert_float(reader, writer, (float *)read_buf, (float *)write_buf, cblk_size, ndims);
|
||||
}
|
||||
else
|
||||
{
|
||||
block_convert_byte(reader, writer, read_buf, write_buf, cblk_size, ndims);
|
||||
}
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
}
|
||||
|
||||
delete[] read_buf;
|
||||
delete[] write_buf;
|
||||
|
||||
reader.close();
|
||||
writer.close();
|
||||
}
|
||||
56
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/fvecs_to_bvecs.cpp
vendored
Normal file
56
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/fvecs_to_bvecs.cpp
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
// Reads `npts` records of (4-byte header + `ndims` floats) from `reader` and
// writes them to `writer` as (same 4-byte header + `ndims` bytes), casting
// every float component to uint8_t.
void block_convert(std::ifstream &reader, std::ofstream &writer, float *read_buf, uint8_t *write_buf, size_t npts,
                   size_t ndims)
{
    reader.read((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(uint32_t)));

    const size_t in_stride = ndims + 1;  // float-sized slots per input record
    const size_t out_stride = ndims + 4; // bytes per output record

    for (size_t pt = 0; pt < npts; ++pt)
    {
        const float *in_rec = read_buf + pt * in_stride;
        uint8_t *out_rec = write_buf + pt * out_stride;

        // The header is carried over bit-for-bit.
        memcpy(out_rec, in_rec, sizeof(uint32_t));
        for (size_t d = 0; d < ndims; ++d)
        {
            out_rec[4 + d] = (uint8_t)in_rec[1 + d];
        }
    }

    writer.write((char *)write_buf, npts * out_stride);
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 3)
|
||||
{
|
||||
std::cout << argv[0] << " input_fvecs output_bvecs(uint8)" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::ifstream reader(argv[1], std::ios::binary | std::ios::ate);
|
||||
size_t fsize = reader.tellg();
|
||||
reader.seekg(0, std::ios::beg);
|
||||
|
||||
uint32_t ndims_u32;
|
||||
reader.read((char *)&ndims_u32, sizeof(uint32_t));
|
||||
reader.seekg(0, std::ios::beg);
|
||||
size_t ndims = (size_t)ndims_u32;
|
||||
size_t npts = fsize / ((ndims + 1) * sizeof(float));
|
||||
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
|
||||
|
||||
size_t blk_size = 131072;
|
||||
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
std::cout << "# blks: " << nblks << std::endl;
|
||||
std::ofstream writer(argv[2], std::ios::binary);
|
||||
auto read_buf = new float[npts * (ndims + 1)];
|
||||
auto write_buf = new uint8_t[npts * (ndims + 4)];
|
||||
for (size_t i = 0; i < nblks; i++)
|
||||
{
|
||||
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims);
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
}
|
||||
|
||||
delete[] read_buf;
|
||||
delete[] write_buf;
|
||||
|
||||
reader.close();
|
||||
writer.close();
|
||||
}
|
||||
58
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/gen_random_slice.cpp
vendored
Normal file
58
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/gen_random_slice.cpp
vendored
Normal file
@@ -0,0 +1,58 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <omp.h>
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <ctime>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include "partition.h"
|
||||
#include "utils.h"
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <sys/stat.h>
|
||||
#include <time.h>
|
||||
#include <typeinfo>
|
||||
|
||||
template <typename T> int aux_main(char **argv)
|
||||
{
|
||||
std::string base_file(argv[2]);
|
||||
std::string output_prefix(argv[3]);
|
||||
float sampling_rate = (float)(std::atof(argv[4]));
|
||||
gen_random_slice<T>(base_file, output_prefix, sampling_rate);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 5)
|
||||
{
|
||||
std::cout << argv[0]
|
||||
<< " data_type [float/int8/uint8] base_bin_file "
|
||||
"sample_output_prefix sampling_probability"
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (std::string(argv[1]) == std::string("float"))
|
||||
{
|
||||
aux_main<float>(argv);
|
||||
}
|
||||
else if (std::string(argv[1]) == std::string("int8"))
|
||||
{
|
||||
aux_main<int8_t>(argv);
|
||||
}
|
||||
else if (std::string(argv[1]) == std::string("uint8"))
|
||||
{
|
||||
aux_main<uint8_t>(argv);
|
||||
}
|
||||
else
|
||||
std::cout << "Unsupported type. Use float/int8/uint8." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
70
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/generate_pq.cpp
vendored
Normal file
70
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/generate_pq.cpp
vendored
Normal file
@@ -0,0 +1,70 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "math_utils.h"
|
||||
#include "pq.h"
|
||||
#include "partition.h"
|
||||
|
||||
#define KMEANS_ITERS_FOR_PQ 15
|
||||
|
||||
template <typename T>
|
||||
bool generate_pq(const std::string &data_path, const std::string &index_prefix_path, const size_t num_pq_centers,
|
||||
const size_t num_pq_chunks, const float sampling_rate, const bool opq)
|
||||
{
|
||||
std::string pq_pivots_path = index_prefix_path + "_pq_pivots.bin";
|
||||
std::string pq_compressed_vectors_path = index_prefix_path + "_pq_compressed.bin";
|
||||
|
||||
// generates random sample and sets it to train_data and updates train_size
|
||||
size_t train_size, train_dim;
|
||||
float *train_data;
|
||||
gen_random_slice<T>(data_path, sampling_rate, train_data, train_size, train_dim);
|
||||
std::cout << "For computing pivots, loaded sample data of size " << train_size << std::endl;
|
||||
|
||||
if (opq)
|
||||
{
|
||||
diskann::generate_opq_pivots(train_data, train_size, (uint32_t)train_dim, (uint32_t)num_pq_centers,
|
||||
(uint32_t)num_pq_chunks, pq_pivots_path, true);
|
||||
}
|
||||
else
|
||||
{
|
||||
diskann::generate_pq_pivots(train_data, train_size, (uint32_t)train_dim, (uint32_t)num_pq_centers,
|
||||
(uint32_t)num_pq_chunks, KMEANS_ITERS_FOR_PQ, pq_pivots_path);
|
||||
}
|
||||
diskann::generate_pq_data_from_pivots<T>(data_path, (uint32_t)num_pq_centers, (uint32_t)num_pq_chunks,
|
||||
pq_pivots_path, pq_compressed_vectors_path, true);
|
||||
|
||||
delete[] train_data;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 7)
|
||||
{
|
||||
std::cout << "Usage: \n"
|
||||
<< argv[0]
|
||||
<< " <data_type[float/uint8/int8]> <data_file[.bin]>"
|
||||
" <PQ_prefix_path> <target-bytes/data-point> "
|
||||
"<sampling_rate> <PQ(0)/OPQ(1)>"
|
||||
<< std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
const std::string data_path(argv[2]);
|
||||
const std::string index_prefix_path(argv[3]);
|
||||
const size_t num_pq_centers = 256;
|
||||
const size_t num_pq_chunks = (size_t)atoi(argv[4]);
|
||||
const float sampling_rate = (float)atof(argv[5]);
|
||||
const bool opq = atoi(argv[6]) == 0 ? false : true;
|
||||
|
||||
if (std::string(argv[1]) == std::string("float"))
|
||||
generate_pq<float>(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq);
|
||||
else if (std::string(argv[1]) == std::string("int8"))
|
||||
generate_pq<int8_t>(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq);
|
||||
else if (std::string(argv[1]) == std::string("uint8"))
|
||||
generate_pq<uint8_t>(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq);
|
||||
else
|
||||
std::cout << "Error. wrong file type" << std::endl;
|
||||
}
|
||||
}
|
||||
204
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/generate_synthetic_labels.cpp
vendored
Normal file
204
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/generate_synthetic_labels.cpp
vendored
Normal file
@@ -0,0 +1,204 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <boost/program_options.hpp>
|
||||
#include <math.h>
|
||||
#include <cmath>
|
||||
#include "utils.h"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
// Writes one line of comma-separated labels per point, with label usage
// skewed towards small label ids: label k (1-based) is offered to each point
// with probability distribution_factor / k and can be handed out at most
// ceil(primary_label_freq / k) times in total. A point that receives no label
// gets the single value 0.
class ZipfDistribution
{
  public:
    ZipfDistribution(uint64_t num_points, uint32_t num_labels)
        : num_labels(num_labels), num_points(num_points),
          uniform_zero_to_one(std::uniform_real_distribution<>(0.0, 1.0))
    {
    }

    // Returns, for every label in [1, num_labels], the maximum number of
    // points that label may be assigned to.
    std::unordered_map<uint32_t, uint32_t> createDistributionMap()
    {
        std::unordered_map<uint32_t, uint32_t> map;
        uint32_t primary_label_freq = (uint32_t)ceil(num_points * distribution_factor);
        for (uint32_t i{1}; i <= num_labels; i++)
        {
            // Integer ceiling division. The original applied ceil() to the
            // result of an integer division, which had already truncated.
            map[i] = (uint32_t)(((uint64_t)primary_label_freq + i - 1) / i);
        }
        return map;
    }

    // Writes the label lines for all points to `outfile`. Lines are separated
    // by '\n' with no trailing newline after the last point. Returns 0.
    int writeDistribution(std::ofstream &outfile)
    {
        auto distribution_map = createDistributionMap();
        // 64-bit counter: num_points is a uint64_t.
        for (uint64_t i{0}; i < num_points; i++)
        {
            bool label_written = false;
            for (auto &entry : distribution_map)
            {
                auto label_selection_probability = std::bernoulli_distribution(distribution_factor / (double)entry.first);
                // The draw happens before the budget check, so the random
                // stream advances once per (point, label) pair.
                if (label_selection_probability(rand_engine) && entry.second > 0)
                {
                    if (label_written)
                    {
                        outfile << ',';
                    }
                    outfile << entry.first;
                    label_written = true;
                    // consume one unit of this label's remaining budget
                    entry.second -= 1;
                }
            }
            if (!label_written)
            {
                outfile << 0;
            }
            if (i + 1 < num_points)
            {
                outfile << '\n';
            }
        }
        return 0;
    }

    // Opens `filename` and writes the label lines to it.
    // Returns -1 if the file cannot be opened, otherwise the status of the
    // stream overload.
    int writeDistribution(std::string filename)
    {
        std::ofstream outfile(filename);
        if (!outfile.is_open())
        {
            std::cerr << "Error: could not open output file " << filename << '\n';
            return -1;
        }
        const int status = writeDistribution(outfile);
        outfile.close();
        // The original fell off the end here without returning a value.
        return status;
    }

  private:
    const uint32_t num_labels;
    const uint64_t num_points;
    const double distribution_factor = 0.7;
    std::knuth_b rand_engine;
    const std::uniform_real_distribution<double> uniform_zero_to_one;
};
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string output_file, distribution_type;
|
||||
uint32_t num_labels;
|
||||
uint64_t num_points;
|
||||
|
||||
try
|
||||
{
|
||||
po::options_description desc{"Arguments"};
|
||||
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("output_file,O", po::value<std::string>(&output_file)->required(),
|
||||
"Filename for saving the label file");
|
||||
desc.add_options()("num_points,N", po::value<uint64_t>(&num_points)->required(), "Number of points in dataset");
|
||||
desc.add_options()("num_labels,L", po::value<uint32_t>(&num_labels)->required(),
|
||||
"Number of unique labels, up to 5000");
|
||||
desc.add_options()("distribution_type,DT", po::value<std::string>(&distribution_type)->default_value("random"),
|
||||
"Distribution function for labels <random/zipf/one_per_point> defaults "
|
||||
"to random");
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (num_labels > 5000)
|
||||
{
|
||||
std::cerr << "Error: num_labels must be 5000 or less" << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (num_points <= 0)
|
||||
{
|
||||
std::cerr << "Error: num_points must be greater than 0" << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::cout << "Generating synthetic labels for " << num_points << " points with " << num_labels << " unique labels"
|
||||
<< '\n';
|
||||
|
||||
try
|
||||
{
|
||||
std::ofstream outfile(output_file);
|
||||
if (!outfile.is_open())
|
||||
{
|
||||
std::cerr << "Error: could not open output file " << output_file << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (distribution_type == "zipf")
|
||||
{
|
||||
ZipfDistribution zipf(num_points, num_labels);
|
||||
zipf.writeDistribution(outfile);
|
||||
}
|
||||
else if (distribution_type == "random")
|
||||
{
|
||||
for (size_t i = 0; i < num_points; i++)
|
||||
{
|
||||
bool label_written = false;
|
||||
for (size_t j = 1; j <= num_labels; j++)
|
||||
{
|
||||
// 50% chance to assign each label
|
||||
if (rand() > (RAND_MAX / 2))
|
||||
{
|
||||
if (label_written)
|
||||
{
|
||||
outfile << ',';
|
||||
}
|
||||
outfile << j;
|
||||
label_written = true;
|
||||
}
|
||||
}
|
||||
if (!label_written)
|
||||
{
|
||||
outfile << 0;
|
||||
}
|
||||
if (i < num_points - 1)
|
||||
{
|
||||
outfile << '\n';
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (distribution_type == "one_per_point")
|
||||
{
|
||||
std::random_device rd; // obtain a random number from hardware
|
||||
std::mt19937 gen(rd()); // seed the generator
|
||||
std::uniform_int_distribution<> distr(0, num_labels); // define the range
|
||||
|
||||
for (size_t i = 0; i < num_points; i++)
|
||||
{
|
||||
outfile << distr(gen);
|
||||
if (i != num_points - 1)
|
||||
outfile << '\n';
|
||||
}
|
||||
}
|
||||
if (outfile.is_open())
|
||||
{
|
||||
outfile.close();
|
||||
}
|
||||
|
||||
std::cout << "Labels written to " << output_file << '\n';
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << "Label generation failed: " << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
23
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/int8_to_float.cpp
vendored
Normal file
23
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/int8_to_float.cpp
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 3)
|
||||
{
|
||||
std::cout << argv[0] << " input_int8_bin output_float_bin" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int8_t *input;
|
||||
size_t npts, nd;
|
||||
diskann::load_bin<int8_t>(argv[1], input, npts, nd);
|
||||
float *output = new float[npts * nd];
|
||||
diskann::convert_types<int8_t, float>(input, output, npts, nd);
|
||||
diskann::save_bin<float>(argv[2], output, npts, nd);
|
||||
delete[] output;
|
||||
delete[] input;
|
||||
}
|
||||
63
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/int8_to_float_scale.cpp
vendored
Normal file
63
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/int8_to_float_scale.cpp
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
// Reads npts * ndims int8 values from `reader`, maps each value v to
// (v - bias) * scale as a float, and writes the floats to `writer`.
void block_convert(std::ofstream &writer, float *write_buf, std::ifstream &reader, int8_t *read_buf, size_t npts,
                   size_t ndims, float bias, float scale)
{
    const size_t total = npts * ndims;
    reader.read((char *)read_buf, total * sizeof(int8_t));

    // Input and output share the same row-major layout, so a single flat
    // pass over all components is equivalent to the point/dimension nest.
    for (size_t idx = 0; idx < total; ++idx)
    {
        write_buf[idx] = (((float)read_buf[idx] - bias) * scale);
    }

    writer.write((char *)write_buf, total * sizeof(float));
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 5)
|
||||
{
|
||||
std::cout << "Usage: " << argv[0] << " input-int8.bin output-float.bin bias scale" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
std::ifstream reader(argv[1], std::ios::binary);
|
||||
uint32_t npts_u32;
|
||||
uint32_t ndims_u32;
|
||||
reader.read((char *)&npts_u32, sizeof(uint32_t));
|
||||
reader.read((char *)&ndims_u32, sizeof(uint32_t));
|
||||
size_t npts = npts_u32;
|
||||
size_t ndims = ndims_u32;
|
||||
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
|
||||
|
||||
size_t blk_size = 131072;
|
||||
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
|
||||
std::ofstream writer(argv[2], std::ios::binary);
|
||||
auto read_buf = new int8_t[blk_size * ndims];
|
||||
auto write_buf = new float[blk_size * ndims];
|
||||
float bias = (float)atof(argv[3]);
|
||||
float scale = (float)atof(argv[4]);
|
||||
|
||||
writer.write((char *)(&npts_u32), sizeof(uint32_t));
|
||||
writer.write((char *)(&ndims_u32), sizeof(uint32_t));
|
||||
|
||||
for (size_t i = 0; i < nblks; i++)
|
||||
{
|
||||
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
block_convert(writer, write_buf, reader, read_buf, cblk_size, ndims, bias, scale);
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
}
|
||||
|
||||
delete[] read_buf;
|
||||
delete[] write_buf;
|
||||
|
||||
writer.close();
|
||||
reader.close();
|
||||
}
|
||||
58
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/ivecs_to_bin.cpp
vendored
Normal file
58
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/ivecs_to_bin.cpp
vendored
Normal file
@@ -0,0 +1,58 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
// Reads `npts` records of (4-byte header + `ndims` uint32 values) from
// `reader`, discards each header, and writes the values contiguously to
// `writer`.
void block_convert(std::ifstream &reader, std::ofstream &writer, uint32_t *read_buf, uint32_t *write_buf, size_t npts,
                   size_t ndims)
{
    const size_t payload_bytes = ndims * sizeof(uint32_t);
    reader.read((char *)read_buf, npts * (payload_bytes + sizeof(uint32_t)));

    for (size_t pt = 0; pt < npts; ++pt)
    {
        // Each record is ndims + 1 words long; word 0 is the header.
        const uint32_t *payload = read_buf + pt * (ndims + 1) + 1;
        memcpy(write_buf + pt * ndims, payload, payload_bytes);
    }

    writer.write((char *)write_buf, npts * payload_bytes);
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 3)
|
||||
{
|
||||
std::cout << argv[0] << " input_ivecs output_bin" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::ifstream reader(argv[1], std::ios::binary | std::ios::ate);
|
||||
size_t fsize = reader.tellg();
|
||||
reader.seekg(0, std::ios::beg);
|
||||
|
||||
uint32_t ndims_u32;
|
||||
reader.read((char *)&ndims_u32, sizeof(uint32_t));
|
||||
reader.seekg(0, std::ios::beg);
|
||||
size_t ndims = (size_t)ndims_u32;
|
||||
size_t npts = fsize / ((ndims + 1) * sizeof(uint32_t));
|
||||
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
|
||||
|
||||
size_t blk_size = 131072;
|
||||
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
std::cout << "# blks: " << nblks << std::endl;
|
||||
std::ofstream writer(argv[2], std::ios::binary);
|
||||
int npts_s32 = (int)npts;
|
||||
int ndims_s32 = (int)ndims;
|
||||
writer.write((char *)&npts_s32, sizeof(int));
|
||||
writer.write((char *)&ndims_s32, sizeof(int));
|
||||
uint32_t *read_buf = new uint32_t[npts * (ndims + 1)];
|
||||
uint32_t *write_buf = new uint32_t[npts * ndims];
|
||||
for (size_t i = 0; i < nblks; i++)
|
||||
{
|
||||
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims);
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
}
|
||||
|
||||
delete[] read_buf;
|
||||
delete[] write_buf;
|
||||
|
||||
reader.close();
|
||||
writer.close();
|
||||
}
|
||||
42
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/merge_shards.cpp
vendored
Normal file
42
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/merge_shards.cpp
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "disk_utils.h"
|
||||
#include "cached_io.h"
|
||||
#include "utils.h"
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 9)
|
||||
{
|
||||
std::cout << argv[0]
|
||||
<< " vamana_index_prefix[1] vamana_index_suffix[2] "
|
||||
"idmaps_prefix[3] "
|
||||
"idmaps_suffix[4] n_shards[5] max_degree[6] "
|
||||
"output_vamana_path[7] "
|
||||
"output_medoids_path[8]"
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
std::string vamana_prefix(argv[1]);
|
||||
std::string vamana_suffix(argv[2]);
|
||||
std::string idmaps_prefix(argv[3]);
|
||||
std::string idmaps_suffix(argv[4]);
|
||||
uint64_t nshards = (uint64_t)std::atoi(argv[5]);
|
||||
uint32_t max_degree = (uint64_t)std::atoi(argv[6]);
|
||||
std::string output_index(argv[7]);
|
||||
std::string output_medoids(argv[8]);
|
||||
|
||||
return diskann::merge_shards(vamana_prefix, vamana_suffix, idmaps_prefix, idmaps_suffix, nshards, max_degree,
|
||||
output_index, output_medoids);
|
||||
}
|
||||
39
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/partition_data.cpp
vendored
Normal file
39
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/partition_data.cpp
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <index.h>
|
||||
#include <math_utils.h>
|
||||
#include "cached_io.h"
|
||||
#include "partition.h"
|
||||
|
||||
// DEPRECATED: NEED TO REPROGRAM
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 7)
|
||||
{
|
||||
std::cout << "Usage:\n"
|
||||
<< argv[0]
|
||||
<< " datatype<int8/uint8/float> <data_path>"
|
||||
" <prefix_path> <sampling_rate> "
|
||||
" <num_partitions> <k_index>"
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
const std::string data_path(argv[2]);
|
||||
const std::string prefix_path(argv[3]);
|
||||
const float sampling_rate = (float)atof(argv[4]);
|
||||
const size_t num_partitions = (size_t)std::atoi(argv[5]);
|
||||
const size_t max_reps = 15;
|
||||
const size_t k_index = (size_t)std::atoi(argv[6]);
|
||||
|
||||
if (std::string(argv[1]) == std::string("float"))
|
||||
partition<float>(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index);
|
||||
else if (std::string(argv[1]) == std::string("int8"))
|
||||
partition<int8_t>(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index);
|
||||
else if (std::string(argv[1]) == std::string("uint8"))
|
||||
partition<uint8_t>(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index);
|
||||
else
|
||||
std::cout << "unsupported data format. use float/int8/uint8" << std::endl;
|
||||
}
|
||||
39
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/partition_with_ram_budget.cpp
vendored
Normal file
39
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/partition_with_ram_budget.cpp
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <index.h>
|
||||
#include <math_utils.h>
|
||||
#include "cached_io.h"
|
||||
#include "partition.h"
|
||||
|
||||
// DEPRECATED: NEED TO REPROGRAM
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 8)
|
||||
{
|
||||
std::cout << "Usage:\n"
|
||||
<< argv[0]
|
||||
<< " datatype<int8/uint8/float> <data_path>"
|
||||
" <prefix_path> <sampling_rate> "
|
||||
" <ram_budget(GB)> <graph_degree> <k_index>"
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
const std::string data_path(argv[2]);
|
||||
const std::string prefix_path(argv[3]);
|
||||
const float sampling_rate = (float)atof(argv[4]);
|
||||
const double ram_budget = (double)std::atof(argv[5]);
|
||||
const size_t graph_degree = (size_t)std::atoi(argv[6]);
|
||||
const size_t k_index = (size_t)std::atoi(argv[7]);
|
||||
|
||||
if (std::string(argv[1]) == std::string("float"))
|
||||
partition_with_ram_budget<float>(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index);
|
||||
else if (std::string(argv[1]) == std::string("int8"))
|
||||
partition_with_ram_budget<int8_t>(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index);
|
||||
else if (std::string(argv[1]) == std::string("uint8"))
|
||||
partition_with_ram_budget<uint8_t>(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index);
|
||||
else
|
||||
std::cout << "unsupported data format. use float/int8/uint8" << std::endl;
|
||||
}
|
||||
237
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/rand_data_gen.cpp
vendored
Normal file
237
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/rand_data_gen.cpp
vendored
Normal file
@@ -0,0 +1,237 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
#include <random>
|
||||
#include <cmath>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
// Writes `npts` random float vectors of length `ndims` to `writer`.
// Components are drawn from a normal distribution (mean 0, stddev 1). When
// rand_scale > 1 each vector is first multiplied by a factor drawn uniformly
// from [1, rand_scale]; when `normalization` is set each vector is rescaled
// to have L2 norm `norm`. Always returns 0.
int block_write_float(std::ofstream &writer, size_t ndims, size_t npts, bool normalization, float norm,
                      float rand_scale)
{
    float *point = new float[ndims];

    std::random_device seed_source{};
    std::mt19937 engine{seed_source()};
    std::normal_distribution<> gaussian{0, 1};
    std::uniform_real_distribution<> scale_dist(1.0, rand_scale);

    for (size_t pt = 0; pt < npts; pt++)
    {
        // The scale factor is drawn before the components, so the engine is
        // consumed in the same order as before.
        float scale = 1.0f;
        if (rand_scale > 1.0f)
        {
            scale = (float)scale_dist(engine);
        }

        for (size_t d = 0; d < ndims; ++d)
        {
            point[d] = scale * (float)gaussian(engine);
        }

        if (normalization)
        {
            float sq_norm = 0;
            for (size_t d = 0; d < ndims; ++d)
            {
                sq_norm += point[d] * point[d];
            }
            for (size_t d = 0; d < ndims; ++d)
            {
                point[d] = point[d] * norm / std::sqrt(sq_norm);
            }
        }

        writer.write((char *)point, ndims * sizeof(float));
    }

    delete[] point;
    return 0;
}
|
||||
|
||||
// Writes `npts` random int8 vectors of length `ndims` to `writer`. Each
// vector is drawn from a normal distribution (mean 0, stddev 1), rescaled to
// L2 norm `norm`, and every component is rounded to the nearest integer
// before being stored as int8. Always returns 0.
int block_write_int8(std::ofstream &writer, size_t ndims, size_t npts, float norm)
{
    float *point = new float[ndims];
    int8_t *quantized = new int8_t[ndims];

    std::random_device seed_source{};
    std::mt19937 engine{seed_source()};
    std::normal_distribution<> gaussian{0, 1};

    for (size_t pt = 0; pt < npts; pt++)
    {
        for (size_t d = 0; d < ndims; ++d)
        {
            point[d] = (float)gaussian(engine);
        }

        float sq_norm = 0;
        for (size_t d = 0; d < ndims; ++d)
        {
            sq_norm += point[d] * point[d];
        }

        for (size_t d = 0; d < ndims; ++d)
        {
            point[d] = point[d] * norm / std::sqrt(sq_norm);
        }

        for (size_t d = 0; d < ndims; ++d)
        {
            quantized[d] = (int8_t)std::round(point[d]);
        }

        writer.write((char *)quantized, ndims * sizeof(int8_t));
    }

    delete[] point;
    delete[] quantized;
    return 0;
}
|
||||
|
||||
// Writes `npts` random uint8 vectors of length `ndims` to `writer`. Each
// vector is drawn from a normal distribution (mean 0, stddev 1), rescaled to
// L2 norm `norm`, rounded to the nearest integer and shifted by +128 so that
// the stored values are unsigned. Always returns 0.
int block_write_uint8(std::ofstream &writer, size_t ndims, size_t npts, float norm)
{
    auto vec = new float[ndims];
    // The buffer holds unsigned bytes. The original used an int8_t buffer and
    // relied on narrowing `128 + x` back into a signed byte.
    auto vec_T = new uint8_t[ndims];

    std::random_device rd{};
    std::mt19937 gen{rd()};
    std::normal_distribution<> normal_rand{0, 1};

    for (size_t i = 0; i < npts; i++)
    {
        float sum = 0;
        for (size_t d = 0; d < ndims; ++d)
            vec[d] = (float)normal_rand(gen);
        for (size_t d = 0; d < ndims; ++d)
            sum += vec[d] * vec[d];
        for (size_t d = 0; d < ndims; ++d)
            vec[d] = vec[d] * norm / std::sqrt(sum);

        for (size_t d = 0; d < ndims; ++d)
        {
            // Round through int8_t as before, then shift in int arithmetic
            // and store the result as an unsigned byte.
            vec_T[d] = (uint8_t)(128 + (int8_t)std::round(vec[d]));
        }

        writer.write((char *)vec_T, ndims * sizeof(uint8_t));
    }

    delete[] vec;
    delete[] vec_T;
    return 0;
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, output_file;
|
||||
size_t ndims, npts;
|
||||
float norm, rand_scaling;
|
||||
bool normalization = false;
|
||||
try
|
||||
{
|
||||
po::options_description desc{"Arguments"};
|
||||
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
||||
desc.add_options()("output_file", po::value<std::string>(&output_file)->required(),
|
||||
"File name for saving the random vectors");
|
||||
desc.add_options()("ndims,D", po::value<uint64_t>(&ndims)->required(), "Dimensoinality of the vector");
|
||||
desc.add_options()("npts,N", po::value<uint64_t>(&npts)->required(), "Number of vectors");
|
||||
desc.add_options()("norm", po::value<float>(&norm)->default_value(-1.0f),
|
||||
"Norm of the vectors (if not specified, vectors are not normalized)");
|
||||
desc.add_options()("rand_scaling", po::value<float>(&rand_scaling)->default_value(1.0f),
|
||||
"Each vector will be scaled (if not explicitly normalized) by a factor randomly chosen from "
|
||||
"[1, rand_scale]. Only applicable for floating point data");
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8"))
|
||||
{
|
||||
std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (norm > 0.0)
|
||||
{
|
||||
normalization = true;
|
||||
}
|
||||
|
||||
if (rand_scaling < 1.0)
|
||||
{
|
||||
std::cout << "We will only scale the vector norms randomly in [1, value], so value must be >= 1." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if ((rand_scaling > 1.0) && (normalization == true))
|
||||
{
|
||||
std::cout << "Data cannot be normalized and randomly scaled at same time. Use one or the other." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type == std::string("int8") || data_type == std::string("uint8"))
|
||||
{
|
||||
if (norm > 127)
|
||||
{
|
||||
std::cerr << "Error: for int8/uint8 datatypes, L2 norm can not be "
|
||||
"greater "
|
||||
"than 127"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
if (rand_scaling > 1.0)
|
||||
{
|
||||
std::cout << "Data scaling only supported for floating point data." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
std::ofstream writer;
|
||||
writer.exceptions(std::ofstream::failbit | std::ofstream::badbit);
|
||||
writer.open(output_file, std::ios::binary);
|
||||
auto npts_u32 = (uint32_t)npts;
|
||||
auto ndims_u32 = (uint32_t)ndims;
|
||||
writer.write((char *)&npts_u32, sizeof(uint32_t));
|
||||
writer.write((char *)&ndims_u32, sizeof(uint32_t));
|
||||
|
||||
size_t blk_size = 131072;
|
||||
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
std::cout << "# blks: " << nblks << std::endl;
|
||||
|
||||
int ret = 0;
|
||||
for (size_t i = 0; i < nblks; i++)
|
||||
{
|
||||
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
if (data_type == std::string("float"))
|
||||
{
|
||||
ret = block_write_float(writer, ndims, cblk_size, normalization, norm, rand_scaling);
|
||||
}
|
||||
else if (data_type == std::string("int8"))
|
||||
{
|
||||
ret = block_write_int8(writer, ndims, cblk_size, norm);
|
||||
}
|
||||
else if (data_type == std::string("uint8"))
|
||||
{
|
||||
ret = block_write_uint8(writer, ndims, cblk_size, norm);
|
||||
}
|
||||
if (ret == 0)
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
else
|
||||
{
|
||||
writer.close();
|
||||
std::cout << "failed to write" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
writer.close();
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Index build failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
85
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/simulate_aggregate_recall.cpp
vendored
Normal file
85
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/simulate_aggregate_recall.cpp
vendored
Normal file
@@ -0,0 +1,85 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
#include <random>
|
||||
#include <cmath>
|
||||
|
||||
// Aggregate recall for one assignment of results to partitions.
//
// k_aggr  : number of results requested overall; used as the normaliser.
// k       : maximum number of results a single partition can return.
// npart   : number of entries read from count.
// count   : count[i] is the number of results assigned to partition i.
// recalls : recalls[m - 1] is the per-partition recall@m; must hold at least
//           min(max(count), k) entries.
//
// Each partition contributes recalls[m - 1] * m with m = min(count[i], k).
// Returns the summed contribution divided by k_aggr.
inline float aggregate_recall(const uint32_t k_aggr, const uint32_t k, const uint32_t npart, uint32_t *count,
                              const std::vector<float> &recalls)
{
    float found = 0;
    for (uint32_t i = 0; i < npart; ++i)
    {
        const size_t max_found = std::min(count[i], k);
        // A partition with nothing assigned contributes nothing. It has to be
        // skipped explicitly: recalls[max_found - 1] with max_found == 0 would
        // wrap around and read far outside the vector.
        if (max_found == 0)
        {
            continue;
        }
        found += recalls[max_found - 1] * max_found;
    }
    return found / (float)k_aggr;
}
|
||||
|
||||
void simulate(const uint32_t k_aggr, const uint32_t k, const uint32_t npart, const uint32_t nsim,
|
||||
const std::vector<float> &recalls)
|
||||
{
|
||||
std::random_device r;
|
||||
std::default_random_engine randeng(r());
|
||||
std::uniform_int_distribution<int> uniform_dist(0, npart - 1);
|
||||
|
||||
uint32_t *count = new uint32_t[npart];
|
||||
double aggr_recall = 0;
|
||||
|
||||
for (uint32_t i = 0; i < nsim; ++i)
|
||||
{
|
||||
for (uint32_t p = 0; p < npart; ++p)
|
||||
{
|
||||
count[p] = 0;
|
||||
}
|
||||
for (uint32_t t = 0; t < k_aggr; ++t)
|
||||
{
|
||||
count[uniform_dist(randeng)]++;
|
||||
}
|
||||
aggr_recall += aggregate_recall(k_aggr, k, npart, count, recalls);
|
||||
}
|
||||
|
||||
std::cout << "Aggregate recall is " << aggr_recall / (double)nsim << std::endl;
|
||||
delete[] count;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc < 6)
|
||||
{
|
||||
std::cout << argv[0] << " k_aggregate k_out npart nsim recall@1 recall@2 ... recall@k" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
const uint32_t k_aggr = atoi(argv[1]);
|
||||
const uint32_t k = atoi(argv[2]);
|
||||
const uint32_t npart = atoi(argv[3]);
|
||||
const uint32_t nsim = atoi(argv[4]);
|
||||
|
||||
std::vector<float> recalls;
|
||||
for (int ctr = 5; ctr < argc; ctr++)
|
||||
{
|
||||
recalls.push_back((float)atof(argv[ctr]));
|
||||
}
|
||||
|
||||
if (recalls.size() != k)
|
||||
{
|
||||
std::cerr << "Please input k numbers for recall@1, recall@2 .. recall@k" << std::endl;
|
||||
}
|
||||
if (k_aggr > npart * k)
|
||||
{
|
||||
std::cerr << "k_aggr must be <= k * npart" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
if (nsim <= npart * k_aggr)
|
||||
{
|
||||
std::cerr << "Choose nsim > npart*k_aggr" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
simulate(k_aggr, k, npart, nsim, recalls);
|
||||
|
||||
return 0;
|
||||
}
|
||||
147
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/stats_label_data.cpp
vendored
Normal file
147
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/stats_label_data.cpp
vendored
Normal file
@@ -0,0 +1,147 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <omp.h>
|
||||
#include <string.h>
|
||||
#include <atomic>
|
||||
#include <cstring>
|
||||
#include <iomanip>
|
||||
#include <set>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <unistd.h>
|
||||
#include <sys/stat.h>
|
||||
#include <time.h>
|
||||
#else
|
||||
#include <Windows.h>
|
||||
#endif
|
||||
namespace po = boost::program_options;
|
||||
|
||||
// Prints summary statistics for a labels file to stdout.
//
// Each line of the file is treated as one point and is split on ',' into that
// point's labels; '\n' and '\r' characters are stripped from every label.
//
// labels_file    : path of the text file to analyse.
// univeral_label : accepted for interface compatibility; not consulted here.
// density        : a point is counted as dense when it has at least this many
//                  labels.
//
// Reported: the fraction of dense points, the label found at every 5th
// percentile of the labels ordered by point count, the three most common
// labels, the point and label totals, and the two averages.
//
// If the file cannot be opened, or holds no labels at all, a message is
// written to stderr and nothing is reported: every statistic below indexes
// into the label list and would otherwise read an empty vector.
void stats_analysis(const std::string labels_file, std::string univeral_label, uint32_t density = 10)
{
    std::ifstream labels_stream(labels_file);
    if (!labels_stream.is_open())
    {
        std::cerr << "Could not open labels file " << labels_file << std::endl;
        return;
    }

    std::unordered_map<std::string, uint32_t> label_counts;
    long long point_cnt = 0;
    uint32_t dense_pts = 0;

    std::string token, line;
    while (std::getline(labels_stream, line))
    {
        point_cnt++;
        std::stringstream iss(line);
        uint32_t lbl_cnt = 0;
        while (std::getline(iss, token, ','))
        {
            lbl_cnt++;
            token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
            token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
            // operator[] value-initialises a missing entry to 0 before the increment.
            label_counts[token]++;
        }
        if (lbl_cnt >= density)
        {
            dense_pts++;
        }
    }

    if (label_counts.empty())
    {
        std::cerr << "No labels found in " << labels_file << std::endl;
        return;
    }

    std::cout << "fraction of dense points with >= " << density
              << " labels = " << (float)dense_pts / (float)point_cnt << std::endl;

    // Flatten the map so labels can be ordered by how many points carry them,
    // and track the overall total and the most popular label on the way.
    std::vector<std::pair<std::string, uint32_t>> label_count_vec;
    label_count_vec.reserve(label_counts.size());
    std::string label_with_max_points;
    uint32_t max_points = 0;
    long long sum = 0;
    for (const auto &lbl : label_counts)
    {
        label_count_vec.emplace_back(lbl.first, lbl.second);
        if (lbl.second > max_points)
        {
            max_points = lbl.second;
            label_with_max_points = lbl.first;
        }
        sum += lbl.second;
    }

    std::sort(label_count_vec.begin(), label_count_vec.end(),
              [](const std::pair<std::string, uint32_t> &lhs, const std::pair<std::string, uint32_t> &rhs) {
                  return lhs.second < rhs.second;
              });

    // An integer percentage counter gives exactly 20 steps; a float counter
    // incremented by 0.05 accumulates rounding error.
    const size_t num_labels = label_count_vec.size();
    for (uint32_t pct = 0; pct < 100; pct += 5)
    {
        const auto &entry = label_count_vec[(size_t)pct * num_labels / 100];
        std::cout << "Percentile " << pct << "\t" << entry.first << " with count=" << entry.second << std::endl;
    }

    std::cout << "Most common label "
              << "\t" << label_count_vec[num_labels - 1].first
              << " with count=" << label_count_vec[num_labels - 1].second << std::endl;
    if (num_labels > 1)
        std::cout << "Second common label "
                  << "\t" << label_count_vec[num_labels - 2].first
                  << " with count=" << label_count_vec[num_labels - 2].second << std::endl;
    if (num_labels > 2)
        std::cout << "Third common label "
                  << "\t" << label_count_vec[num_labels - 3].first
                  << " with count=" << label_count_vec[num_labels - 3].second << std::endl;

    const float avg_labels_per_pt = sum / (float)point_cnt;
    const float mean_label_size = sum / (float)label_counts.size();
    std::cout << "Total number of points = " << point_cnt << ", number of labels = " << label_counts.size()
              << std::endl;
    std::cout << "Average number of labels per point = " << avg_labels_per_pt << std::endl;
    std::cout << "Mean label size excluding 0 = " << mean_label_size << std::endl;
    std::cout << "Most popular label is " << label_with_max_points << " with " << max_points << " pts" << std::endl;
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string labels_file, universal_label;
|
||||
uint32_t density;
|
||||
|
||||
po::options_description desc{"Arguments"};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("labels_file", po::value<std::string>(&labels_file)->required(),
|
||||
"path to labels data file.");
|
||||
desc.add_options()("universal_label", po::value<std::string>(&universal_label)->required(),
|
||||
"Universal label used in labels file.");
|
||||
desc.add_options()("density", po::value<uint32_t>(&density)->default_value(1),
|
||||
"Number of labels each point in labels file, defaults to 1");
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cerr << e.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
stats_analysis(labels_file, universal_label, density);
|
||||
}
|
||||
121
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/tsv_to_bin.cpp
vendored
Normal file
121
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/tsv_to_bin.cpp
vendored
Normal file
@@ -0,0 +1,121 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
#include <vector>
#include "utils.h"
|
||||
|
||||
// Converts one block of text values to binary floats.
//
// Reads npts * ndims whitespace-separated values from reader with operator>>
// and writes them to writer as npts * ndims raw floats, in the order read.
//
// No error is reported if extraction fails; the affected slots receive
// whatever value `val` holds at that point.
void block_convert_float(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims)
{
    // Sized to exactly what is written; released automatically on return.
    std::vector<float> buffer(npts * ndims);

    // Initialised because a stream that is already in a failed state (for
    // example an input file that could not be opened) leaves the extraction
    // target untouched, and an indeterminate value would then be written out.
    float val = 0;
    for (float &slot : buffer)
    {
        reader >> val;
        slot = val;
    }
    writer.write(reinterpret_cast<const char *>(buffer.data()), buffer.size() * sizeof(float));
}
|
||||
|
||||
// Converts one block of text values to binary int8.
//
// Reads npts * ndims whitespace-separated integers from reader with
// operator>>, narrows each to int8_t, and writes them to writer as
// npts * ndims raw bytes, in the order read.
//
// No error is reported if extraction fails; the affected slots receive
// whatever value `val` holds at that point.
void block_convert_int8(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims)
{
    // Sized to exactly what is written; released automatically on return.
    std::vector<int8_t> buffer(npts * ndims);

    // Parsed as int so the text is read as a number, not a single character.
    // Initialised because a stream that is already in a failed state leaves
    // the extraction target untouched.
    int val = 0;
    for (int8_t &slot : buffer)
    {
        reader >> val;
        slot = static_cast<int8_t>(val);
    }
    writer.write(reinterpret_cast<const char *>(buffer.data()), buffer.size() * sizeof(int8_t));
}
|
||||
|
||||
// Converts one block of text values to binary uint8.
//
// Reads npts * ndims whitespace-separated integers from reader with
// operator>>, narrows each to uint8_t, and writes them to writer as
// npts * ndims raw bytes, in the order read.
//
// No error is reported if extraction fails; the affected slots receive
// whatever value `val` holds at that point.
void block_convert_uint8(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims)
{
    // Sized to exactly what is written; released automatically on return.
    std::vector<uint8_t> buffer(npts * ndims);

    // Parsed as int so the text is read as a number, not a single character.
    // Initialised because a stream that is already in a failed state leaves
    // the extraction target untouched.
    int val = 0;
    for (uint8_t &slot : buffer)
    {
        reader >> val;
        slot = static_cast<uint8_t>(val);
    }
    writer.write(reinterpret_cast<const char *>(buffer.data()), buffer.size() * sizeof(uint8_t));
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 6)
|
||||
{
|
||||
std::cout << argv[0]
|
||||
<< "<float/int8/uint8> input_filename.tsv output_filename.bin "
|
||||
"dim num_pts>"
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (std::string(argv[1]) != std::string("float") && std::string(argv[1]) != std::string("int8") &&
|
||||
std::string(argv[1]) != std::string("uint8"))
|
||||
{
|
||||
std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl;
|
||||
}
|
||||
|
||||
size_t ndims = atoi(argv[4]);
|
||||
size_t npts = atoi(argv[5]);
|
||||
|
||||
std::ifstream reader(argv[2], std::ios::binary | std::ios::ate);
|
||||
// size_t fsize = reader.tellg();
|
||||
reader.seekg(0, std::ios::beg);
|
||||
reader.seekg(0, std::ios::beg);
|
||||
|
||||
size_t blk_size = 131072;
|
||||
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
std::cout << "# blks: " << nblks << std::endl;
|
||||
std::ofstream writer(argv[3], std::ios::binary);
|
||||
auto npts_u32 = (uint32_t)npts;
|
||||
auto ndims_u32 = (uint32_t)ndims;
|
||||
writer.write((char *)&npts_u32, sizeof(uint32_t));
|
||||
writer.write((char *)&ndims_u32, sizeof(uint32_t));
|
||||
|
||||
for (size_t i = 0; i < nblks; i++)
|
||||
{
|
||||
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
if (std::string(argv[1]) == std::string("float"))
|
||||
{
|
||||
block_convert_float(reader, writer, cblk_size, ndims);
|
||||
}
|
||||
else if (std::string(argv[1]) == std::string("int8"))
|
||||
{
|
||||
block_convert_int8(reader, writer, cblk_size, ndims);
|
||||
}
|
||||
else if (std::string(argv[1]) == std::string("uint8"))
|
||||
{
|
||||
block_convert_uint8(reader, writer, cblk_size, ndims);
|
||||
}
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
}
|
||||
|
||||
reader.close();
|
||||
writer.close();
|
||||
}
|
||||
23
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/uint32_to_uint8.cpp
vendored
Normal file
23
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/uint32_to_uint8.cpp
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 3)
|
||||
{
|
||||
std::cout << argv[0] << " input_uint32_bin output_int8_bin" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
uint32_t *input;
|
||||
size_t npts, nd;
|
||||
diskann::load_bin<uint32_t>(argv[1], input, npts, nd);
|
||||
uint8_t *output = new uint8_t[npts * nd];
|
||||
diskann::convert_types<uint32_t, uint8_t>(input, output, npts, nd);
|
||||
diskann::save_bin<uint8_t>(argv[2], output, npts, nd);
|
||||
delete[] output;
|
||||
delete[] input;
|
||||
}
|
||||
23
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/uint8_to_float.cpp
vendored
Normal file
23
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/uint8_to_float.cpp
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 3)
|
||||
{
|
||||
std::cout << argv[0] << " input_uint8_bin output_float_bin" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
uint8_t *input;
|
||||
size_t npts, nd;
|
||||
diskann::load_bin<uint8_t>(argv[1], input, npts, nd);
|
||||
float *output = new float[npts * nd];
|
||||
diskann::convert_types<uint8_t, float>(input, output, npts, nd);
|
||||
diskann::save_bin<float>(argv[2], output, npts, nd);
|
||||
delete[] output;
|
||||
delete[] input;
|
||||
}
|
||||
163
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/vector_analysis.cpp
vendored
Normal file
163
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/vector_analysis.cpp
vendored
Normal file
@@ -0,0 +1,163 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <omp.h>
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <ctime>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <fcntl.h>
|
||||
#include <sys/stat.h>
|
||||
#include <time.h>
|
||||
#include <typeinfo>
|
||||
|
||||
#include "partition.h"
|
||||
#include "utils.h"
|
||||
|
||||
template <typename T> int analyze_norm(std::string base_file)
|
||||
{
|
||||
std::cout << "Analyzing data norms" << std::endl;
|
||||
T *data;
|
||||
size_t npts, ndims;
|
||||
diskann::load_bin<T>(base_file, data, npts, ndims);
|
||||
std::vector<float> norms(npts, 0);
|
||||
#pragma omp parallel for schedule(dynamic)
|
||||
for (int64_t i = 0; i < (int64_t)npts; i++)
|
||||
{
|
||||
for (size_t d = 0; d < ndims; d++)
|
||||
norms[i] += data[i * ndims + d] * data[i * ndims + d];
|
||||
norms[i] = std::sqrt(norms[i]);
|
||||
}
|
||||
std::sort(norms.begin(), norms.end());
|
||||
for (int p = 0; p < 100; p += 5)
|
||||
std::cout << "percentile " << p << ": " << norms[(uint64_t)(std::floor((p / 100.0) * npts))] << std::endl;
|
||||
std::cout << "percentile 100"
|
||||
<< ": " << norms[npts - 1] << std::endl;
|
||||
delete[] data;
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T> int normalize_base(std::string base_file, std::string out_file)
|
||||
{
|
||||
std::cout << "Normalizing base" << std::endl;
|
||||
T *data;
|
||||
size_t npts, ndims;
|
||||
diskann::load_bin<T>(base_file, data, npts, ndims);
|
||||
// std::vector<float> norms(npts, 0);
|
||||
#pragma omp parallel for schedule(dynamic)
|
||||
for (int64_t i = 0; i < (int64_t)npts; i++)
|
||||
{
|
||||
float pt_norm = 0;
|
||||
for (size_t d = 0; d < ndims; d++)
|
||||
pt_norm += data[i * ndims + d] * data[i * ndims + d];
|
||||
pt_norm = std::sqrt(pt_norm);
|
||||
for (size_t d = 0; d < ndims; d++)
|
||||
data[i * ndims + d] = static_cast<T>(data[i * ndims + d] / pt_norm);
|
||||
}
|
||||
diskann::save_bin<T>(out_file, data, npts, ndims);
|
||||
delete[] data;
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T> int augment_base(std::string base_file, std::string out_file, bool prep_base = true)
|
||||
{
|
||||
std::cout << "Analyzing data norms" << std::endl;
|
||||
T *data;
|
||||
size_t npts, ndims;
|
||||
diskann::load_bin<T>(base_file, data, npts, ndims);
|
||||
std::vector<float> norms(npts, 0);
|
||||
float max_norm = 0;
|
||||
#pragma omp parallel for schedule(dynamic)
|
||||
for (int64_t i = 0; i < (int64_t)npts; i++)
|
||||
{
|
||||
for (size_t d = 0; d < ndims; d++)
|
||||
norms[i] += data[i * ndims + d] * data[i * ndims + d];
|
||||
max_norm = norms[i] > max_norm ? norms[i] : max_norm;
|
||||
}
|
||||
// std::sort(norms.begin(), norms.end());
|
||||
max_norm = std::sqrt(max_norm);
|
||||
std::cout << "Max norm: " << max_norm << std::endl;
|
||||
T *new_data;
|
||||
size_t newdims = ndims + 1;
|
||||
new_data = new T[npts * newdims];
|
||||
for (size_t i = 0; i < npts; i++)
|
||||
{
|
||||
if (prep_base)
|
||||
{
|
||||
for (size_t j = 0; j < ndims; j++)
|
||||
{
|
||||
new_data[i * newdims + j] = static_cast<T>(data[i * ndims + j] / max_norm);
|
||||
}
|
||||
float diff = 1 - (norms[i] / (max_norm * max_norm));
|
||||
diff = diff <= 0 ? 0 : std::sqrt(diff);
|
||||
new_data[i * newdims + ndims] = static_cast<T>(diff);
|
||||
if (diff <= 0)
|
||||
{
|
||||
std::cout << i << " has large max norm, investigate if needed. diff = " << diff << std::endl;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (size_t j = 0; j < ndims; j++)
|
||||
{
|
||||
new_data[i * newdims + j] = static_cast<T>(data[i * ndims + j] / std::sqrt(norms[i]));
|
||||
}
|
||||
new_data[i * newdims + ndims] = 0;
|
||||
}
|
||||
}
|
||||
diskann::save_bin<T>(out_file, new_data, npts, newdims);
|
||||
delete[] new_data;
|
||||
delete[] data;
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T> int aux_main(char **argv)
|
||||
{
|
||||
std::string base_file(argv[2]);
|
||||
uint32_t option = atoi(argv[3]);
|
||||
if (option == 1)
|
||||
analyze_norm<T>(base_file);
|
||||
else if (option == 2)
|
||||
augment_base<T>(base_file, std::string(argv[4]), true);
|
||||
else if (option == 3)
|
||||
augment_base<T>(base_file, std::string(argv[4]), false);
|
||||
else if (option == 4)
|
||||
normalize_base<T>(base_file, std::string(argv[4]));
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc < 4)
|
||||
{
|
||||
std::cout << argv[0]
|
||||
<< " data_type [float/int8/uint8] base_bin_file "
|
||||
"[option: 1-norm analysis, 2-prep_base_for_mip, "
|
||||
"3-prep_query_for_mip, 4-normalize-vecs] [out_file for "
|
||||
"options 2/3/4]"
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (std::string(argv[1]) == std::string("float"))
|
||||
{
|
||||
aux_main<float>(argv);
|
||||
}
|
||||
else if (std::string(argv[1]) == std::string("int8"))
|
||||
{
|
||||
aux_main<int8_t>(argv);
|
||||
}
|
||||
else if (std::string(argv[1]) == std::string("uint8"))
|
||||
{
|
||||
aux_main<uint8_t>(argv);
|
||||
}
|
||||
else
|
||||
std::cout << "Unsupported type. Use float/int8/uint8." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user