Initial commit

This commit is contained in:
yichuan520030910320
2025-06-30 09:05:05 +00:00
commit 46f6cc100b
1231 changed files with 278432 additions and 0 deletions

View File

@@ -0,0 +1,42 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_COMPILE_WARNING_AS_ERROR ON)
add_executable(build_memory_index build_memory_index.cpp)
target_link_libraries(build_memory_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
add_executable(build_stitched_index build_stitched_index.cpp)
target_link_libraries(build_stitched_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
add_executable(search_memory_index search_memory_index.cpp)
target_link_libraries(search_memory_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
add_executable(build_disk_index build_disk_index.cpp)
target_link_libraries(build_disk_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} ${DISKANN_ASYNC_LIB} Boost::program_options)
add_executable(search_disk_index search_disk_index.cpp)
target_link_libraries(search_disk_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
add_executable(range_search_disk_index range_search_disk_index.cpp)
target_link_libraries(range_search_disk_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
add_executable(test_streaming_scenario test_streaming_scenario.cpp)
target_link_libraries(test_streaming_scenario ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
add_executable(test_insert_deletes_consolidate test_insert_deletes_consolidate.cpp)
target_link_libraries(test_insert_deletes_consolidate ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
if (NOT MSVC)
install(TARGETS build_memory_index
build_stitched_index
search_memory_index
build_disk_index
search_disk_index
range_search_disk_index
test_streaming_scenario
test_insert_deletes_consolidate
RUNTIME
)
endif()

View File

@@ -0,0 +1,191 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <omp.h>
#include <boost/program_options.hpp>
#include "utils.h"
#include "disk_utils.h"
#include "math_utils.h"
#include "index.h"
#include "partition.h"
#include "program_options_utils.hpp"
namespace po = boost::program_options;
int main(int argc, char **argv)
{
std::string data_type, dist_fn, data_path, index_path_prefix, codebook_prefix, label_file, universal_label,
label_type;
uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold;
float B, M;
bool append_reorder_data = false;
bool use_opq = false;
po::options_description desc{
program_options_utils::make_program_description("build_disk_index", "Build a disk-based index.")};
try
{
desc.add_options()("help,h", "Print information on arguments");
// Required parameters
po::options_description required_configs("Required");
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
program_options_utils::DATA_TYPE_DESCRIPTION);
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
program_options_utils::INPUT_DATA_PATH);
required_configs.add_options()("search_DRAM_budget,B", po::value<float>(&B)->required(),
"DRAM budget in GB for searching the index to set the "
"compressed level for data while search happens");
required_configs.add_options()("build_DRAM_budget,M", po::value<float>(&M)->required(),
"DRAM budget in GB for building the index");
// Optional parameters
po::options_description optional_configs("Optional");
optional_configs.add_options()("num_threads,T",
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
program_options_utils::NUMBER_THREADS_DESCRIPTION);
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
program_options_utils::MAX_BUILD_DEGREE);
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
program_options_utils::GRAPH_BUILD_COMPLEXITY);
optional_configs.add_options()("QD", po::value<uint32_t>(&QD)->default_value(0),
" Quantized Dimension for compression");
optional_configs.add_options()("codebook_prefix", po::value<std::string>(&codebook_prefix)->default_value(""),
"Path prefix for pre-trained codebook");
optional_configs.add_options()("PQ_disk_bytes", po::value<uint32_t>(&disk_PQ)->default_value(0),
"Number of bytes to which vectors should be compressed "
"on SSD; 0 for no compression");
optional_configs.add_options()("append_reorder_data", po::bool_switch()->default_value(false),
"Include full precision data in the index. Use only in "
"conjuction with compressed data on SSD.");
optional_configs.add_options()("build_PQ_bytes", po::value<uint32_t>(&build_PQ)->default_value(0),
program_options_utils::BUIlD_GRAPH_PQ_BYTES);
optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false),
program_options_utils::USE_OPQ);
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
program_options_utils::LABEL_FILE);
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
program_options_utils::UNIVERSAL_LABEL);
optional_configs.add_options()("FilteredLbuild", po::value<uint32_t>(&Lf)->default_value(0),
program_options_utils::FILTERED_LBUILD);
optional_configs.add_options()("filter_threshold,F", po::value<uint32_t>(&filter_threshold)->default_value(0),
"Threshold to break up the existing nodes to generate new graph "
"internally where each node has a maximum F labels.");
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
program_options_utils::LABEL_TYPE_DESCRIPTION);
// Merge required and optional parameters
desc.add(required_configs).add(optional_configs);
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
if (vm["append_reorder_data"].as<bool>())
append_reorder_data = true;
if (vm["use_opq"].as<bool>())
use_opq = true;
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
bool use_filters = (label_file != "") ? true : false;
diskann::Metric metric;
if (dist_fn == std::string("l2"))
metric = diskann::Metric::L2;
else if (dist_fn == std::string("mips"))
metric = diskann::Metric::INNER_PRODUCT;
else if (dist_fn == std::string("cosine"))
metric = diskann::Metric::COSINE;
else
{
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
return -1;
}
if (append_reorder_data)
{
if (disk_PQ == 0)
{
std::cout << "Error: It is not necessary to append data for reordering "
"when vectors are not compressed on disk."
<< std::endl;
return -1;
}
if (data_type != std::string("float"))
{
std::cout << "Error: Appending data for reordering currently only "
"supported for float data type."
<< std::endl;
return -1;
}
}
std::string params = std::string(std::to_string(R)) + " " + std::string(std::to_string(L)) + " " +
std::string(std::to_string(B)) + " " + std::string(std::to_string(M)) + " " +
std::string(std::to_string(num_threads)) + " " + std::string(std::to_string(disk_PQ)) + " " +
std::string(std::to_string(append_reorder_data)) + " " +
std::string(std::to_string(build_PQ)) + " " + std::string(std::to_string(QD));
try
{
if (label_file != "" && label_type == "ushort")
{
if (data_type == std::string("int8"))
return diskann::build_disk_index<int8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, codebook_prefix, use_filters, label_file,
universal_label, filter_threshold, Lf);
else if (data_type == std::string("uint8"))
return diskann::build_disk_index<uint8_t, uint16_t>(
data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix,
use_filters, label_file, universal_label, filter_threshold, Lf);
else if (data_type == std::string("float"))
return diskann::build_disk_index<float, uint16_t>(
data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix,
use_filters, label_file, universal_label, filter_threshold, Lf);
else
{
diskann::cerr << "Error. Unsupported data type" << std::endl;
return -1;
}
}
else
{
if (data_type == std::string("int8"))
return diskann::build_disk_index<int8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, codebook_prefix, use_filters, label_file,
universal_label, filter_threshold, Lf);
else if (data_type == std::string("uint8"))
return diskann::build_disk_index<uint8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, codebook_prefix, use_filters, label_file,
universal_label, filter_threshold, Lf);
else if (data_type == std::string("float"))
return diskann::build_disk_index<float>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, codebook_prefix, use_filters, label_file,
universal_label, filter_threshold, Lf);
else
{
diskann::cerr << "Error. Unsupported data type" << std::endl;
return -1;
}
}
}
catch (const std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Index build failed." << std::endl;
return -1;
}
}

View File

@@ -0,0 +1,164 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <omp.h>
#include <cstring>
#include <boost/program_options.hpp>
#include "index.h"
#include "utils.h"
#include "program_options_utils.hpp"
#ifndef _WINDOWS
#include <sys/mman.h>
#include <unistd.h>
#else
#include <Windows.h>
#endif
#include "memory_mapper.h"
#include "ann_exception.h"
#include "index_factory.h"
namespace po = boost::program_options;
int main(int argc, char **argv)
{
std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type;
uint32_t num_threads, R, L, Lf, build_PQ_bytes;
float alpha;
bool use_pq_build, use_opq;
po::options_description desc{
program_options_utils::make_program_description("build_memory_index", "Build a memory-based DiskANN index.")};
try
{
desc.add_options()("help,h", "Print information on arguments");
// Required parameters
po::options_description required_configs("Required");
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
program_options_utils::DATA_TYPE_DESCRIPTION);
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
program_options_utils::INPUT_DATA_PATH);
// Optional parameters
po::options_description optional_configs("Optional");
optional_configs.add_options()("num_threads,T",
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
program_options_utils::NUMBER_THREADS_DESCRIPTION);
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
program_options_utils::MAX_BUILD_DEGREE);
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
program_options_utils::GRAPH_BUILD_COMPLEXITY);
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
program_options_utils::GRAPH_BUILD_ALPHA);
optional_configs.add_options()("build_PQ_bytes", po::value<uint32_t>(&build_PQ_bytes)->default_value(0),
program_options_utils::BUIlD_GRAPH_PQ_BYTES);
optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false),
program_options_utils::USE_OPQ);
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
program_options_utils::LABEL_FILE);
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
program_options_utils::UNIVERSAL_LABEL);
optional_configs.add_options()("FilteredLbuild", po::value<uint32_t>(&Lf)->default_value(0),
program_options_utils::FILTERED_LBUILD);
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
program_options_utils::LABEL_TYPE_DESCRIPTION);
// Merge required and optional parameters
desc.add(required_configs).add(optional_configs);
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
use_pq_build = (build_PQ_bytes > 0);
use_opq = vm["use_opq"].as<bool>();
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("mips"))
{
metric = diskann::Metric::INNER_PRODUCT;
}
else if (dist_fn == std::string("l2"))
{
metric = diskann::Metric::L2;
}
else if (dist_fn == std::string("cosine"))
{
metric = diskann::Metric::COSINE;
}
else
{
std::cout << "Unsupported distance function. Currently only L2/ Inner "
"Product/Cosine are supported."
<< std::endl;
return -1;
}
try
{
diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L << " alpha: " << alpha
<< " #threads: " << num_threads << std::endl;
size_t data_num, data_dim;
diskann::get_bin_metadata(data_path, data_num, data_dim);
auto index_build_params = diskann::IndexWriteParametersBuilder(L, R)
.with_filter_list_size(Lf)
.with_alpha(alpha)
.with_saturate_graph(false)
.with_num_threads(num_threads)
.build();
auto filter_params = diskann::IndexFilterParamsBuilder()
.with_universal_label(universal_label)
.with_label_file(label_file)
.with_save_path_prefix(index_path_prefix)
.build();
auto config = diskann::IndexConfigBuilder()
.with_metric(metric)
.with_dimension(data_dim)
.with_max_points(data_num)
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
.with_data_type(data_type)
.with_label_type(label_type)
.is_dynamic_index(false)
.with_index_write_params(index_build_params)
.is_enable_tags(false)
.is_use_opq(use_opq)
.is_pq_dist_build(use_pq_build)
.with_num_pq_chunks(build_PQ_bytes)
.build();
auto index_factory = diskann::IndexFactory(config);
auto index = index_factory.create_instance();
index->build(data_path, data_num, filter_params);
index->save(index_path_prefix.c_str());
index.reset();
return 0;
}
catch (const std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Index build failed." << std::endl;
return -1;
}
}

View File

@@ -0,0 +1,441 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <boost/program_options.hpp>
#include <chrono>
#include <cstdio>
#include <cstring>
#include <random>
#include <string>
#include <tuple>
#include "filter_utils.h"
#include <omp.h>
#ifndef _WINDOWS
#include <sys/uio.h>
#endif
#include "index.h"
#include "memory_mapper.h"
#include "parameters.h"
#include "utils.h"
#include "program_options_utils.hpp"
namespace po = boost::program_options;
typedef std::tuple<std::vector<std::vector<uint32_t>>, uint64_t> stitch_indices_return_values;
/*
* Inline function to display progress bar.
*/
inline void print_progress(double percentage)
{
int val = (int)(percentage * 100);
int lpad = (int)(percentage * PBWIDTH);
int rpad = PBWIDTH - lpad;
printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, "");
fflush(stdout);
}
/*
* Inline function to generate a random integer in a range.
*/
inline size_t random(size_t range_from, size_t range_to)
{
std::random_device rand_dev;
std::mt19937 generator(rand_dev());
std::uniform_int_distribution<size_t> distr(range_from, range_to);
return distr(generator);
}
/*
* function to handle command line parsing.
*
* Arguments are merely the inputs from the command line.
*/
void handle_args(int argc, char **argv, std::string &data_type, path &input_data_path, path &final_index_path_prefix,
path &label_data_path, std::string &universal_label, uint32_t &num_threads, uint32_t &R, uint32_t &L,
uint32_t &stitched_R, float &alpha)
{
po::options_description desc{
program_options_utils::make_program_description("build_stitched_index", "Build a stitched DiskANN index.")};
try
{
desc.add_options()("help,h", "Print information on arguments");
// Required parameters
po::options_description required_configs("Required");
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
program_options_utils::DATA_TYPE_DESCRIPTION);
required_configs.add_options()("index_path_prefix",
po::value<std::string>(&final_index_path_prefix)->required(),
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
required_configs.add_options()("data_path", po::value<std::string>(&input_data_path)->required(),
program_options_utils::INPUT_DATA_PATH);
// Optional parameters
po::options_description optional_configs("Optional");
optional_configs.add_options()("num_threads,T",
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
program_options_utils::NUMBER_THREADS_DESCRIPTION);
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
program_options_utils::MAX_BUILD_DEGREE);
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
program_options_utils::GRAPH_BUILD_COMPLEXITY);
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
program_options_utils::GRAPH_BUILD_ALPHA);
optional_configs.add_options()("label_file", po::value<std::string>(&label_data_path)->default_value(""),
program_options_utils::LABEL_FILE);
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
program_options_utils::UNIVERSAL_LABEL);
optional_configs.add_options()("stitched_R", po::value<uint32_t>(&stitched_R)->default_value(100),
"Degree to prune final graph down to");
// Merge required and optional parameters
desc.add(required_configs).add(optional_configs);
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
exit(0);
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
throw;
}
}
/*
* Custom index save to write the in-memory index to disk.
* Also writes required files for diskANN API -
* 1. labels_to_medoids
* 2. universal_label
* 3. data (redundant for static indices)
* 4. labels (redundant for static indices)
*/
void save_full_index(path final_index_path_prefix, path input_data_path, uint64_t final_index_size,
std::vector<std::vector<uint32_t>> stitched_graph,
tsl::robin_map<std::string, uint32_t> entry_points, std::string universal_label,
path label_data_path)
{
// aux. file 1
auto saving_index_timer = std::chrono::high_resolution_clock::now();
std::ifstream original_label_data_stream;
original_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
original_label_data_stream.open(label_data_path, std::ios::binary);
std::ofstream new_label_data_stream;
new_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
new_label_data_stream.open(final_index_path_prefix + "_labels.txt", std::ios::binary);
new_label_data_stream << original_label_data_stream.rdbuf();
original_label_data_stream.close();
new_label_data_stream.close();
// aux. file 2
std::ifstream original_input_data_stream;
original_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
original_input_data_stream.open(input_data_path, std::ios::binary);
std::ofstream new_input_data_stream;
new_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
new_input_data_stream.open(final_index_path_prefix + ".data", std::ios::binary);
new_input_data_stream << original_input_data_stream.rdbuf();
original_input_data_stream.close();
new_input_data_stream.close();
// aux. file 3
std::ofstream labels_to_medoids_writer;
labels_to_medoids_writer.exceptions(std::ios::badbit | std::ios::failbit);
labels_to_medoids_writer.open(final_index_path_prefix + "_labels_to_medoids.txt");
for (auto iter : entry_points)
labels_to_medoids_writer << iter.first << ", " << iter.second << std::endl;
labels_to_medoids_writer.close();
// aux. file 4 (only if we're using a universal label)
if (universal_label != "")
{
std::ofstream universal_label_writer;
universal_label_writer.exceptions(std::ios::badbit | std::ios::failbit);
universal_label_writer.open(final_index_path_prefix + "_universal_label.txt");
universal_label_writer << universal_label << std::endl;
universal_label_writer.close();
}
// main index
uint64_t index_num_frozen_points = 0, index_num_edges = 0;
uint32_t index_max_observed_degree = 0, index_entry_point = 0;
const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t);
for (auto &point_neighbors : stitched_graph)
{
index_max_observed_degree = std::max(index_max_observed_degree, (uint32_t)point_neighbors.size());
}
std::ofstream stitched_graph_writer;
stitched_graph_writer.exceptions(std::ios::badbit | std::ios::failbit);
stitched_graph_writer.open(final_index_path_prefix, std::ios_base::binary);
stitched_graph_writer.write((char *)&final_index_size, sizeof(uint64_t));
stitched_graph_writer.write((char *)&index_max_observed_degree, sizeof(uint32_t));
stitched_graph_writer.write((char *)&index_entry_point, sizeof(uint32_t));
stitched_graph_writer.write((char *)&index_num_frozen_points, sizeof(uint64_t));
size_t bytes_written = METADATA;
for (uint32_t node_point = 0; node_point < stitched_graph.size(); node_point++)
{
uint32_t current_node_num_neighbors = (uint32_t)stitched_graph[node_point].size();
std::vector<uint32_t> current_node_neighbors = stitched_graph[node_point];
stitched_graph_writer.write((char *)&current_node_num_neighbors, sizeof(uint32_t));
bytes_written += sizeof(uint32_t);
for (const auto &current_node_neighbor : current_node_neighbors)
{
stitched_graph_writer.write((char *)&current_node_neighbor, sizeof(uint32_t));
bytes_written += sizeof(uint32_t);
}
index_num_edges += current_node_num_neighbors;
}
if (bytes_written != final_index_size)
{
std::cerr << "Error: written bytes does not match allocated space" << std::endl;
throw;
}
stitched_graph_writer.close();
std::chrono::duration<double> saving_index_time = std::chrono::high_resolution_clock::now() - saving_index_timer;
std::cout << "Stitched graph written in " << saving_index_time.count() << " seconds" << std::endl;
std::cout << "Stitched graph average degree: " << ((float)index_num_edges) / ((float)(stitched_graph.size()))
<< std::endl;
std::cout << "Stitched graph max degree: " << index_max_observed_degree << std::endl << std::endl;
}
/*
* Unions the per-label graph indices together via the following policy:
* - any two nodes can only have at most one edge between them -
*
* Returns the "stitched" graph and its expected file size.
*/
template <typename T>
stitch_indices_return_values stitch_label_indices(
path final_index_path_prefix, uint32_t total_number_of_points, label_set all_labels,
tsl::robin_map<std::string, uint32_t> labels_to_number_of_points,
tsl::robin_map<std::string, uint32_t> &label_entry_points,
tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id_map)
{
size_t final_index_size = 0;
std::vector<std::vector<uint32_t>> stitched_graph(total_number_of_points);
auto stitching_index_timer = std::chrono::high_resolution_clock::now();
for (const auto &lbl : all_labels)
{
path curr_label_index_path(final_index_path_prefix + "_" + lbl);
std::vector<std::vector<uint32_t>> curr_label_index;
uint64_t curr_label_index_size;
uint32_t curr_label_entry_point;
std::tie(curr_label_index, curr_label_index_size) =
diskann::load_label_index(curr_label_index_path, labels_to_number_of_points[lbl]);
curr_label_entry_point = (uint32_t)random(0, curr_label_index.size());
label_entry_points[lbl] = label_id_to_orig_id_map[lbl][curr_label_entry_point];
for (uint32_t node_point = 0; node_point < curr_label_index.size(); node_point++)
{
uint32_t original_point_id = label_id_to_orig_id_map[lbl][node_point];
for (auto &node_neighbor : curr_label_index[node_point])
{
uint32_t original_neighbor_id = label_id_to_orig_id_map[lbl][node_neighbor];
std::vector<uint32_t> curr_point_neighbors = stitched_graph[original_point_id];
if (std::find(curr_point_neighbors.begin(), curr_point_neighbors.end(), original_neighbor_id) ==
curr_point_neighbors.end())
{
stitched_graph[original_point_id].push_back(original_neighbor_id);
final_index_size += sizeof(uint32_t);
}
}
}
}
const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t);
final_index_size += (total_number_of_points * sizeof(uint32_t) + METADATA);
std::chrono::duration<double> stitching_index_time =
std::chrono::high_resolution_clock::now() - stitching_index_timer;
std::cout << "stitched graph generated in memory in " << stitching_index_time.count() << " seconds" << std::endl;
return std::make_tuple(stitched_graph, final_index_size);
}
/*
* Applies the prune_neighbors function from src/index.cpp to
* every node in the stitched graph.
*
* This is an optional step, hence the saving of both the full
* and pruned graph.
*/
template <typename T>
void prune_and_save(path final_index_path_prefix, path full_index_path_prefix, path input_data_path,
std::vector<std::vector<uint32_t>> stitched_graph, uint32_t stitched_R,
tsl::robin_map<std::string, uint32_t> label_entry_points, std::string universal_label,
path label_data_path, uint32_t num_threads)
{
size_t dimension, number_of_label_points;
auto diskann_cout_buffer = diskann::cout.rdbuf(nullptr);
auto std_cout_buffer = std::cout.rdbuf(nullptr);
auto pruning_index_timer = std::chrono::high_resolution_clock::now();
diskann::get_bin_metadata(input_data_path, number_of_label_points, dimension);
diskann::Index<T> index(diskann::Metric::L2, dimension, number_of_label_points, nullptr, nullptr, 0, false, false,
false, false, 0, false);
// not searching this index, set search_l to 0
index.load(full_index_path_prefix.c_str(), num_threads, 1);
std::cout << "parsing labels" << std::endl;
index.prune_all_neighbors(stitched_R, 750, 1.2);
index.save((final_index_path_prefix).c_str());
diskann::cout.rdbuf(diskann_cout_buffer);
std::cout.rdbuf(std_cout_buffer);
std::chrono::duration<double> pruning_index_time = std::chrono::high_resolution_clock::now() - pruning_index_timer;
std::cout << "pruning performed in " << pruning_index_time.count() << " seconds\n" << std::endl;
}
/*
* Delete all temporary artifacts.
* In the process of creating the stitched index, some temporary artifacts are
* created:
* 1. the separate bin files for each labels' points
* 2. the separate diskANN indices built for each label
* 3. the '.data' file created while generating the indices
*/
void clean_up_artifacts(path input_data_path, path final_index_path_prefix, label_set all_labels)
{
for (const auto &lbl : all_labels)
{
path curr_label_input_data_path(input_data_path + "_" + lbl);
path curr_label_index_path(final_index_path_prefix + "_" + lbl);
path curr_label_index_path_data(curr_label_index_path + ".data");
if (std::remove(curr_label_index_path.c_str()) != 0)
throw;
if (std::remove(curr_label_input_data_path.c_str()) != 0)
throw;
if (std::remove(curr_label_index_path_data.c_str()) != 0)
throw;
}
}
int main(int argc, char **argv)
{
// 1. handle cmdline inputs
std::string data_type;
path input_data_path, final_index_path_prefix, label_data_path;
std::string universal_label;
uint32_t num_threads, R, L, stitched_R;
float alpha;
auto index_timer = std::chrono::high_resolution_clock::now();
handle_args(argc, argv, data_type, input_data_path, final_index_path_prefix, label_data_path, universal_label,
num_threads, R, L, stitched_R, alpha);
path labels_file_to_use = final_index_path_prefix + "_label_formatted.txt";
path labels_map_file = final_index_path_prefix + "_labels_map.txt";
convert_labels_string_to_int(label_data_path, labels_file_to_use, labels_map_file, universal_label);
// 2. parse label file and create necessary data structures
std::vector<label_set> point_ids_to_labels;
tsl::robin_map<std::string, uint32_t> labels_to_number_of_points;
label_set all_labels;
std::tie(point_ids_to_labels, labels_to_number_of_points, all_labels) =
diskann::parse_label_file(labels_file_to_use, universal_label);
// 3. for each label, make a separate data file
tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id_map;
uint32_t total_number_of_points = (uint32_t)point_ids_to_labels.size();
#ifndef _WINDOWS
if (data_type == "uint8")
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<uint8_t>(
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
else if (data_type == "int8")
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<int8_t>(
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
else if (data_type == "float")
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<float>(
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
else
throw;
#else
if (data_type == "uint8")
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<uint8_t>(
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
else if (data_type == "int8")
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<int8_t>(
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
else if (data_type == "float")
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<float>(
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
else
throw;
#endif
// 4. for each created data file, create a vanilla diskANN index
if (data_type == "uint8")
diskann::generate_label_indices<uint8_t>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
num_threads);
else if (data_type == "int8")
diskann::generate_label_indices<int8_t>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
num_threads);
else if (data_type == "float")
diskann::generate_label_indices<float>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
num_threads);
else
throw;
// 5. "stitch" the indices together
std::vector<std::vector<uint32_t>> stitched_graph;
tsl::robin_map<std::string, uint32_t> label_entry_points;
uint64_t stitched_graph_size;
if (data_type == "uint8")
std::tie(stitched_graph, stitched_graph_size) =
stitch_label_indices<uint8_t>(final_index_path_prefix, total_number_of_points, all_labels,
labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
else if (data_type == "int8")
std::tie(stitched_graph, stitched_graph_size) =
stitch_label_indices<int8_t>(final_index_path_prefix, total_number_of_points, all_labels,
labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
else if (data_type == "float")
std::tie(stitched_graph, stitched_graph_size) =
stitch_label_indices<float>(final_index_path_prefix, total_number_of_points, all_labels,
labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
else
throw;
path full_index_path_prefix = final_index_path_prefix + "_full";
// 5a. save the stitched graph to disk
save_full_index(full_index_path_prefix, input_data_path, stitched_graph_size, stitched_graph, label_entry_points,
universal_label, labels_file_to_use);
// 6. run a prune on the stitched index, and save to disk
if (data_type == "uint8")
prune_and_save<uint8_t>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
else if (data_type == "int8")
prune_and_save<int8_t>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
else if (data_type == "float")
prune_and_save<float>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
else
throw;
std::chrono::duration<double> index_time = std::chrono::high_resolution_clock::now() - index_timer;
std::cout << "pruned/stitched graph generated in " << index_time.count() << " seconds" << std::endl;
clean_up_artifacts(input_data_path, final_index_path_prefix, all_labels);
}

View File

@@ -0,0 +1,46 @@
<!-- Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT license. -->
# Integration Tests
The following tests use Python to prepare, run, verify, and tear down the rest api services.
We do make use of the built-in `unittest` library, but that's only to take advantage of test reporting purposes.
These are decidedly **not** _unit_ tests. These are end to end integration tests.
## Caveats
This has only been tested or built for Linux, though we have written platform agnostic Python for the smoke test
(i.e. using `os.path.join`, etc)
It has been tested on Python 3.9 and 3.10, but should work on Python 3.6+.
## How to Run
First, build the DiskANN RestAPI code; see $REPOSITORY_ROOT/workflows/rest_api.md for detailed instructions.
```bash
cd tests/python
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
export DISKANN_BUILD_DIR=/path/to/your/diskann/build
python -m unittest
```
## Smoke Test Failed, Now What?
The smoke test written takes advantage of temporary directories that are only valid during the
lifetime of the test. The contents of these directories include:
- Randomized vectors (first in tsv, then bin form) used to build the PQFlashIndex
- The PQFlashIndex files
It is useful to keep these around. By setting some environment variables, you can control whether an ephemeral,
temporary directory is used (and deleted on test completion), or left as an exercise for the developer to
clean up.
The valid environment variables are:
- `DISKANN_REST_TEST_WORKING_DIR` (example: `$USER/DiskANNRestTest`)
- If this is specified, it **must exist** and **must be writeable**. Any existing files will be clobbered.
- `DISKANN_REST_SERVER` (example: `http://127.0.0.1:10067`)
- Note that if this is set, no data will be generated, nor will a server be started; it is presumed you have done
all the work in creating and starting the rest server prior to running the test and just submits requests against it.

View File

@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
import numpy as np
import os
import subprocess
def output_vectors(
diskann_build_path: str,
temporary_file_path: str,
vectors: np.ndarray,
timeout: int = 60
) -> str:
vectors_as_tsv_path = os.path.join(temporary_file_path, "vectors.tsv")
with open(vectors_as_tsv_path, "w") as vectors_tsv_out:
for vector in vectors:
as_str = "\t".join((str(component) for component in vector))
print(as_str, file=vectors_tsv_out)
# there is probably a clever way to have numpy write out C++ friendly floats, so feel free to remove this in
# favor of something more sane later
vectors_as_bin_path = os.path.join(temporary_file_path, "vectors.bin")
tsv_to_bin_path = os.path.join(diskann_build_path, "apps", "utils", "tsv_to_bin")
number_of_points, dimensions = vectors.shape
args = [
tsv_to_bin_path,
"float",
vectors_as_tsv_path,
vectors_as_bin_path,
str(dimensions),
str(number_of_points)
]
completed = subprocess.run(args, timeout=timeout)
if completed.returncode != 0:
raise Exception(f"Unable to convert tsv to binary using tsv_to_bin, completed_process: {completed}")
return vectors_as_bin_path
def build_ssd_index(
diskann_build_path: str,
temporary_file_path: str,
vectors: np.ndarray,
per_process_timeout: int = 60 # this may not be long enough if you're doing something larger
):
vectors_as_bin_path = output_vectors(diskann_build_path, temporary_file_path, vectors, timeout=per_process_timeout)
ssd_builder_path = os.path.join(diskann_build_path, "apps", "build_disk_index")
args = [
ssd_builder_path,
"--data_type", "float",
"--dist_fn", "l2",
"--data_path", vectors_as_bin_path,
"--index_path_prefix", os.path.join(temporary_file_path, "smoke_test"),
"-R", "64",
"-L", "100",
"--search_DRAM_budget", "1",
"--build_DRAM_budget", "1",
"--num_threads", "1",
"--PQ_disk_bytes", "0"
]
completed = subprocess.run(args, timeout=per_process_timeout)
if completed.returncode != 0:
command_run = " ".join(args)
raise Exception(f"Unable to build a disk index with the command: '{command_run}'\ncompleted_process: {completed}\nstdout: {completed.stdout}\nstderr: {completed.stderr}")
# index is now built inside of temporary_file_path

View File

@@ -0,0 +1,379 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <atomic>
#include <cstring>
#include <iomanip>
#include <omp.h>
#include <set>
#include <boost/program_options.hpp>
#include "index.h"
#include "disk_utils.h"
#include "math_utils.h"
#include "memory_mapper.h"
#include "pq_flash_index.h"
#include "partition.h"
#include "timer.h"
#include "program_options_utils.hpp"
#ifndef _WINDOWS
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#include "linux_aligned_file_reader.h"
#else
#ifdef USE_BING_INFRA
#include "bing_aligned_file_reader.h"
#else
#include "windows_aligned_file_reader.h"
#endif
#endif
namespace po = boost::program_options;
#define WARMUP false
void print_stats(std::string category, std::vector<float> percentiles, std::vector<float> results)
{
diskann::cout << std::setw(20) << category << ": " << std::flush;
for (uint32_t s = 0; s < percentiles.size(); s++)
{
diskann::cout << std::setw(8) << percentiles[s] << "%";
}
diskann::cout << std::endl;
diskann::cout << std::setw(22) << " " << std::flush;
for (uint32_t s = 0; s < percentiles.size(); s++)
{
diskann::cout << std::setw(9) << results[s];
}
diskann::cout << std::endl;
}
template <typename T, typename LabelT = uint32_t>
int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix, const std::string &query_file,
std::string &gt_file, const uint32_t num_threads, const float search_range,
const uint32_t beamwidth, const uint32_t num_nodes_to_cache, const std::vector<uint32_t> &Lvec)
{
std::string pq_prefix = index_path_prefix + "_pq";
std::string disk_index_file = index_path_prefix + "_disk.index";
std::string warmup_query_file = index_path_prefix + "_sample_data.bin";
diskann::cout << "Search parameters: #threads: " << num_threads << ", ";
if (beamwidth <= 0)
diskann::cout << "beamwidth to be optimized for each L value" << std::endl;
else
diskann::cout << " beamwidth: " << beamwidth << std::endl;
// load query bin
T *query = nullptr;
std::vector<std::vector<uint32_t>> groundtruth_ids;
size_t query_num, query_dim, query_aligned_dim, gt_num;
diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);
bool calc_recall_flag = false;
if (gt_file != std::string("null") && file_exists(gt_file))
{
diskann::load_range_truthset(gt_file, groundtruth_ids,
gt_num); // use for range search type of truthset
// diskann::prune_truthset_for_range(gt_file, search_range,
// groundtruth_ids, gt_num); // use for traditional truthset
if (gt_num != query_num)
{
diskann::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl;
return -1;
}
calc_recall_flag = true;
}
std::shared_ptr<AlignedFileReader> reader = nullptr;
#ifdef _WINDOWS
#ifndef USE_BING_INFRA
reader.reset(new WindowsAlignedFileReader());
#else
reader.reset(new diskann::BingAlignedFileReader());
#endif
#else
reader.reset(new LinuxAlignedFileReader());
#endif
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> _pFlashIndex(
new diskann::PQFlashIndex<T, LabelT>(reader, metric));
int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str());
if (res != 0)
{
return res;
}
// cache bfs levels
std::vector<uint32_t> node_list;
diskann::cout << "Caching " << num_nodes_to_cache << " BFS nodes around medoid(s)" << std::endl;
_pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list);
// _pFlashIndex->generate_cache_list_from_sample_queries(
// warmup_query_file, 15, 6, num_nodes_to_cache, num_threads,
// node_list);
_pFlashIndex->load_cache_list(node_list);
node_list.clear();
node_list.shrink_to_fit();
omp_set_num_threads(num_threads);
uint64_t warmup_L = 20;
uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0;
T *warmup = nullptr;
if (WARMUP)
{
if (file_exists(warmup_query_file))
{
diskann::load_aligned_bin<T>(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim);
}
else
{
warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads);
warmup_dim = query_dim;
warmup_aligned_dim = query_aligned_dim;
diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T));
std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T));
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(-128, 127);
for (uint32_t i = 0; i < warmup_num; i++)
{
for (uint32_t d = 0; d < warmup_dim; d++)
{
warmup[i * warmup_aligned_dim + d] = (T)dis(gen);
}
}
}
diskann::cout << "Warming up index... " << std::flush;
std::vector<uint64_t> warmup_result_ids_64(warmup_num, 0);
std::vector<float> warmup_result_dists(warmup_num, 0);
#pragma omp parallel for schedule(dynamic, 1)
for (int64_t i = 0; i < (int64_t)warmup_num; i++)
{
_pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L,
warmup_result_ids_64.data() + (i * 1),
warmup_result_dists.data() + (i * 1), 4);
}
diskann::cout << "..done" << std::endl;
}
diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
diskann::cout.precision(2);
std::string recall_string = "Recall@rng=" + std::to_string(search_range);
diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" << std::setw(16) << "QPS" << std::setw(16)
<< "Mean Latency" << std::setw(16) << "99.9 Latency" << std::setw(16) << "Mean IOs" << std::setw(16)
<< "CPU (s)";
if (calc_recall_flag)
{
diskann::cout << std::setw(16) << recall_string << std::endl;
}
else
diskann::cout << std::endl;
diskann::cout << "==============================================================="
"==========================================="
<< std::endl;
std::vector<std::vector<std::vector<uint32_t>>> query_result_ids(Lvec.size());
uint32_t optimized_beamwidth = 2;
uint32_t max_list_size = 10000;
for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++)
{
uint32_t L = Lvec[test_id];
if (beamwidth <= 0)
{
optimized_beamwidth =
optimize_beamwidth(_pFlashIndex, warmup, warmup_num, warmup_aligned_dim, L, optimized_beamwidth);
}
else
optimized_beamwidth = beamwidth;
query_result_ids[test_id].clear();
query_result_ids[test_id].resize(query_num);
diskann::QueryStats *stats = new diskann::QueryStats[query_num];
auto s = std::chrono::high_resolution_clock::now();
#pragma omp parallel for schedule(dynamic, 1)
for (int64_t i = 0; i < (int64_t)query_num; i++)
{
std::vector<uint64_t> indices;
std::vector<float> distances;
uint32_t res_count =
_pFlashIndex->range_search(query + (i * query_aligned_dim), search_range, L, max_list_size, indices,
distances, optimized_beamwidth, stats + i);
query_result_ids[test_id][i].reserve(res_count);
query_result_ids[test_id][i].resize(res_count);
for (uint32_t idx = 0; idx < res_count; idx++)
query_result_ids[test_id][i][idx] = (uint32_t)indices[idx];
}
auto e = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = e - s;
auto qps = (1.0 * query_num) / (1.0 * diff.count());
auto mean_latency = diskann::get_mean_stats<float>(
stats, query_num, [](const diskann::QueryStats &stats) { return stats.total_us; });
auto latency_999 = diskann::get_percentile_stats<float>(
stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; });
auto mean_ios = diskann::get_mean_stats<uint32_t>(stats, query_num,
[](const diskann::QueryStats &stats) { return stats.n_ios; });
double mean_cpuus = diskann::get_mean_stats<float>(
stats, query_num, [](const diskann::QueryStats &stats) { return stats.cpu_us; });
double recall = 0;
double ratio_of_sums = 0;
if (calc_recall_flag)
{
recall =
diskann::calculate_range_search_recall((uint32_t)query_num, groundtruth_ids, query_result_ids[test_id]);
uint32_t total_true_positive = 0;
uint32_t total_positive = 0;
for (uint32_t i = 0; i < query_num; i++)
{
total_true_positive += (uint32_t)query_result_ids[test_id][i].size();
total_positive += (uint32_t)groundtruth_ids[i].size();
}
ratio_of_sums = (1.0 * total_true_positive) / (1.0 * total_positive);
}
diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps
<< std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios
<< std::setw(16) << mean_cpuus;
if (calc_recall_flag)
{
diskann::cout << std::setw(16) << recall << "," << ratio_of_sums << std::endl;
}
else
diskann::cout << std::endl;
}
diskann::cout << "Done searching. " << std::endl;
diskann::aligned_free(query);
if (warmup != nullptr)
diskann::aligned_free(warmup);
return 0;
}
int main(int argc, char **argv)
{
std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file;
uint32_t num_threads, W, num_nodes_to_cache;
std::vector<uint32_t> Lvec;
float range;
po::options_description desc{program_options_utils::make_program_description(
"range_search_disk_index", "Searches disk DiskANN indexes using ranges")};
try
{
desc.add_options()("help,h", "Print information on arguments");
// Required parameters
po::options_description required_configs("Required");
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
program_options_utils::DATA_TYPE_DESCRIPTION);
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
required_configs.add_options()("query_file", po::value<std::string>(&query_file)->required(),
program_options_utils::QUERY_FILE_DESCRIPTION);
required_configs.add_options()("search_list,L",
po::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
program_options_utils::SEARCH_LIST_DESCRIPTION);
required_configs.add_options()("range_threshold,K", po::value<float>(&range)->required(),
"Number of neighbors to be returned");
// Optional parameters
po::options_description optional_configs("Optional");
optional_configs.add_options()("num_threads,T",
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
program_options_utils::NUMBER_THREADS_DESCRIPTION);
optional_configs.add_options()("gt_file", po::value<std::string>(&gt_file)->default_value(std::string("null")),
program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION);
optional_configs.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
program_options_utils::NUMBER_OF_NODES_TO_CACHE);
optional_configs.add_options()("beamwidth,W", po::value<uint32_t>(&W)->default_value(2),
program_options_utils::BEAMWIDTH);
// Merge required and optional parameters
desc.add(required_configs).add(optional_configs);
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("mips"))
{
metric = diskann::Metric::INNER_PRODUCT;
}
else if (dist_fn == std::string("l2"))
{
metric = diskann::Metric::L2;
}
else if (dist_fn == std::string("cosine"))
{
metric = diskann::Metric::COSINE;
}
else
{
std::cout << "Unsupported distance function. Currently only L2/ Inner "
"Product/Cosine are supported."
<< std::endl;
return -1;
}
if ((data_type != std::string("float")) && (metric == diskann::Metric::INNER_PRODUCT))
{
std::cout << "Currently support only floating point data for Inner Product." << std::endl;
return -1;
}
try
{
if (data_type == std::string("float"))
return search_disk_index<float>(metric, index_path_prefix, query_file, gt_file, num_threads, range, W,
num_nodes_to_cache, Lvec);
else if (data_type == std::string("int8"))
return search_disk_index<int8_t>(metric, index_path_prefix, query_file, gt_file, num_threads, range, W,
num_nodes_to_cache, Lvec);
else if (data_type == std::string("uint8"))
return search_disk_index<uint8_t>(metric, index_path_prefix, query_file, gt_file, num_threads, range, W,
num_nodes_to_cache, Lvec);
else
{
std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
return -1;
}
}
catch (const std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Index search failed." << std::endl;
return -1;
}
}

View File

@@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
set(CMAKE_CXX_STANDARD 17)
add_executable(inmem_server inmem_server.cpp)
if(MSVC)
target_link_options(inmem_server PRIVATE /MACHINE:x64)
target_link_libraries(inmem_server debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
target_link_libraries(inmem_server optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
else()
target_link_libraries(inmem_server ${PROJECT_NAME} aio -ltcmalloc -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
endif()
add_executable(ssd_server ssd_server.cpp)
if(MSVC)
target_link_options(ssd_server PRIVATE /MACHINE:x64)
target_link_libraries(ssd_server debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
target_link_libraries(ssd_server optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
else()
target_link_libraries(ssd_server ${PROJECT_NAME} aio -ltcmalloc -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
endif()
add_executable(multiple_ssdindex_server multiple_ssdindex_server.cpp)
if(MSVC)
target_link_options(multiple_ssdindex_server PRIVATE /MACHINE:x64)
target_link_libraries(multiple_ssdindex_server debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
target_link_libraries(multiple_ssdindex_server optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
else()
target_link_libraries(multiple_ssdindex_server ${PROJECT_NAME} aio -ltcmalloc -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
endif()
add_executable(client client.cpp)
if(MSVC)
target_link_options(client PRIVATE /MACHINE:x64)
target_link_libraries(client debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
target_link_libraries(client optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
else()
target_link_libraries(client ${PROJECT_NAME} -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
endif()

View File

@@ -0,0 +1,124 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <ctime>
#include <functional>
#include <iomanip>
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <string>
#include <cstdlib>
#include <codecvt>
#include <boost/program_options.hpp>
#include <cpprest/http_client.h>
#include <restapi/common.h>
using namespace web;
using namespace web::http;
using namespace web::http::client;
using namespace diskann;
namespace po = boost::program_options;
template <typename T>
void query_loop(const std::string &ip_addr_port, const std::string &query_file, const unsigned nq, const unsigned Ls,
const unsigned k_value)
{
web::http::client::http_client client(U(ip_addr_port));
T *data;
size_t npts = 1, ndims = 128, rounded_dim = 128;
diskann::load_aligned_bin<T>(query_file, data, npts, ndims, rounded_dim);
for (unsigned i = 0; i < nq; ++i)
{
T *vec = data + i * rounded_dim;
web::http::http_request http_query(methods::POST);
web::json::value queryJson = web::json::value::object();
queryJson[QUERY_ID_KEY] = i;
queryJson[K_KEY] = k_value;
queryJson[L_KEY] = Ls;
for (size_t i = 0; i < ndims; ++i)
{
queryJson[VECTOR_KEY][i] = web::json::value::number(vec[i]);
}
http_query.set_body(queryJson);
client.request(http_query)
.then([](web::http::http_response response) -> pplx::task<utility::string_t> {
if (response.status_code() == status_codes::OK)
{
return response.extract_string();
}
std::cerr << "Query failed" << std::endl;
return pplx::task_from_result(utility::string_t());
})
.then([](pplx::task<utility::string_t> previousTask) {
try
{
std::cout << previousTask.get() << std::endl;
}
catch (http_exception const &e)
{
std::wcout << e.what() << std::endl;
}
})
.wait();
}
}
int main(int argc, char *argv[])
{
std::string data_type, query_file, address;
uint32_t num_queries;
uint32_t l_search, k_value;
po::options_description desc{"Arguments"};
try
{
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
desc.add_options()("query_file", po::value<std::string>(&query_file)->required(),
"File containing the queries to search");
desc.add_options()("num_queries,Q", po::value<uint32_t>(&num_queries)->required(),
"Number of queries to search");
desc.add_options()("l_search", po::value<uint32_t>(&l_search)->required(), "Value of L");
desc.add_options()("k_value,K", po::value<uint32_t>(&k_value)->default_value(10), "Value of K (default 10)");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << std::endl;
return -1;
}
if (data_type == std::string("float"))
{
query_loop<float>(address, query_file, num_queries, l_search, k_value);
}
else if (data_type == std::string("int8"))
{
query_loop<int8_t>(address, query_file, num_queries, l_search, k_value);
}
else if (data_type == std::string("uint8"))
{
query_loop<uint8_t>(address, query_file, num_queries, l_search, k_value);
}
else
{
std::cerr << "Unsupported type " << argv[2] << std::endl;
return -1;
}
return 0;
}

View File

@@ -0,0 +1,138 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <ctime>
#include <functional>
#include <iomanip>
#include <string>
#include <cstdlib>
#include <codecvt>
#include <boost/program_options.hpp>
#include <restapi/server.h>
using namespace diskann;
namespace po = boost::program_options;
std::unique_ptr<Server> g_httpServer(nullptr);
std::vector<std::unique_ptr<diskann::BaseSearch>> g_inMemorySearch;
void setup(const utility::string_t &address, const std::string &typestring)
{
web::http::uri_builder uriBldr(address);
auto uri = uriBldr.to_uri();
std::cout << "Attempting to start server on " << uri.to_string() << std::endl;
g_httpServer = std::unique_ptr<Server>(new Server(uri, g_inMemorySearch, typestring));
std::cout << "Created a server object" << std::endl;
g_httpServer->open().wait();
ucout << U"Listening for requests on: " << address << std::endl;
}
void teardown(const utility::string_t &address)
{
g_httpServer->close().wait();
}
int main(int argc, char *argv[])
{
std::string data_type, index_file, data_file, address, dist_fn, tags_file;
uint32_t num_threads;
uint32_t l_search;
po::options_description desc{"Arguments"};
try
{
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
desc.add_options()("data_file", po::value<std::string>(&data_file)->required(),
"File containing the data found in the index");
desc.add_options()("index_path_prefix", po::value<std::string>(&index_file)->required(),
"Path prefix for saving index file components");
desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->required(),
"Number of threads used for building index");
desc.add_options()("l_search", po::value<uint32_t>(&l_search)->required(), "Value of L");
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
"distance function <l2/mips>");
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
"Tags file location");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << std::endl;
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("l2"))
metric = diskann::Metric::L2;
else if (dist_fn == std::string("mips"))
metric = diskann::Metric::INNER_PRODUCT;
else
{
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
return -1;
}
if (data_type == std::string("float"))
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::InMemorySearch<float>(data_file, index_file, tags_file, metric, num_threads, l_search));
g_inMemorySearch.push_back(std::move(searcher));
}
else if (data_type == std::string("int8"))
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::InMemorySearch<int8_t>(data_file, index_file, tags_file, metric, num_threads, l_search));
g_inMemorySearch.push_back(std::move(searcher));
}
else if (data_type == std::string("uint8"))
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::InMemorySearch<uint8_t>(data_file, index_file, tags_file, metric, num_threads, l_search));
g_inMemorySearch.push_back(std::move(searcher));
}
else
{
std::cerr << "Unsupported data type " << argv[2] << std::endl;
}
while (1)
{
try
{
setup(address, data_type);
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
std::string line;
std::getline(std::cin, line);
if (line == "exit")
{
teardown(address);
g_httpServer->close().wait();
exit(0);
}
}
catch (const std::exception &ex)
{
std::cerr << "Exception occurred: " << ex.what() << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
catch (...)
{
std::cerr << "Unknown exception occurreed" << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
}
}

View File

@@ -0,0 +1,83 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <restapi/server.h>
#include <restapi/in_memory_search.h>
#include <codecvt>
#include <iostream>
std::unique_ptr<Server> g_httpServer(nullptr);
std::unique_ptr<diskann::InMemorySearch> g_inMemorySearch(nullptr);
void setup(const utility::string_t &address)
{
web::http::uri_builder uriBldr(address);
auto uri = uriBldr.to_uri();
std::wcout << L"Attempting to start server on " << uri.to_string() << std::endl;
g_httpServer = std::unique_ptr<Server>(new Server(uri, g_inMemorySearch));
g_httpServer->open().wait();
ucout << U"Listening for requests on: " << address << std::endl;
}
void teardown(const utility::string_t &address)
{
g_httpServer->close().wait();
}
void loadIndex(const char *indexFile, const char *baseFile, const char *idsFile)
{
auto nsgSearch = new diskann::InMemorySearch(baseFile, indexFile, idsFile, diskann::L2);
g_inMemorySearch = std::unique_ptr<diskann::InMemorySearch>(nsgSearch);
}
std::wstring getHostingAddress(const char *hostNameAndPort)
{
wchar_t buffer[4096];
mbstowcs_s(nullptr, buffer, sizeof(buffer) / sizeof(buffer[0]), hostNameAndPort,
sizeof(buffer) / sizeof(buffer[0]));
return std::wstring(buffer);
}
int main(int argc, char *argv[])
{
if (argc != 5)
{
std::cout << "Usage: nsg_server <ip_addr_and_port> <index_file> "
"<base_file> <ids_file> "
<< std::endl;
exit(1);
}
auto address = getHostingAddress(argv[1]);
loadIndex(argv[2], argv[3], argv[4]);
while (1)
{
try
{
setup(address);
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
std::string line;
std::getline(std::cin, line);
if (line == "exit")
{
teardown(address);
exit(0);
}
}
catch (const std::exception &ex)
{
std::cerr << "Exception occurred: " << ex.what() << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
catch (...)
{
std::cerr << "Unknown exception occurreed" << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
}
}

View File

@@ -0,0 +1,182 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <ctime>
#include <functional>
#include <iomanip>
#include <string>
#include <cstdlib>
#include <codecvt>
#include <boost/program_options.hpp>
#include <omp.h>
#include <restapi/server.h>
using namespace diskann;
namespace po = boost::program_options;
std::unique_ptr<Server> g_httpServer(nullptr);
std::vector<std::unique_ptr<diskann::BaseSearch>> g_ssdSearch;
void setup(const utility::string_t &address, const std::string &typestring)
{
web::http::uri_builder uriBldr(address);
auto uri = uriBldr.to_uri();
std::cout << "Attempting to start server on " << uri.to_string() << std::endl;
g_httpServer = std::unique_ptr<Server>(new Server(uri, g_ssdSearch, typestring));
std::cout << "Created a server object" << std::endl;
g_httpServer->open().wait();
ucout << U"Listening for requests on: " << address << std::endl;
}
void teardown(const utility::string_t &address)
{
g_httpServer->close().wait();
}
int main(int argc, char *argv[])
{
std::string data_type, index_prefix_paths, address, dist_fn, tags_file;
uint32_t num_nodes_to_cache;
uint32_t num_threads;
po::options_description desc{"Arguments"};
try
{
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("index_prefix_paths", po::value<std::string>(&index_prefix_paths)->required(),
"Path prefix for loading index file components");
desc.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
"Number of nodes to cache during search");
desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
"Number of threads used for building index (defaults to "
"omp_get_num_procs())");
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
"distance function <l2/mips>");
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
"Tags file location");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << std::endl;
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("l2"))
metric = diskann::Metric::L2;
else if (dist_fn == std::string("mips"))
metric = diskann::Metric::INNER_PRODUCT;
else
{
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
return -1;
}
std::vector<std::pair<std::string, std::string>> index_tag_paths;
std::ifstream index_in(index_prefix_paths);
if (!index_in.is_open())
{
std::cerr << "Could not open " << index_prefix_paths << std::endl;
exit(-1);
}
std::ifstream tags_in(tags_file);
if (!tags_in.is_open())
{
std::cerr << "Could not open " << tags_file << std::endl;
exit(-1);
}
std::string prefix, tagfile;
while (std::getline(index_in, prefix))
{
if (std::getline(tags_in, tagfile))
{
index_tag_paths.push_back(std::make_pair(prefix, tagfile));
}
else
{
std::cerr << "The number of tags specified does not match the number of "
"indices specified"
<< std::endl;
exit(-1);
}
}
index_in.close();
tags_in.close();
if (data_type == std::string("float"))
{
for (auto &index_tag : index_tag_paths)
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<float>(
index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
g_ssdSearch.push_back(std::move(searcher));
}
}
else if (data_type == std::string("int8"))
{
for (auto &index_tag : index_tag_paths)
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<int8_t>(
index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
g_ssdSearch.push_back(std::move(searcher));
}
}
else if (data_type == std::string("uint8"))
{
for (auto &index_tag : index_tag_paths)
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<uint8_t>(
index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
g_ssdSearch.push_back(std::move(searcher));
}
}
else
{
std::cerr << "Unsupported data type " << data_type << std::endl;
exit(-1);
}
while (1)
{
try
{
setup(address, data_type);
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
std::string line;
std::getline(std::cin, line);
if (line == "exit")
{
teardown(address);
g_httpServer->close().wait();
exit(0);
}
}
catch (const std::exception &ex)
{
std::cerr << "Exception occurred: " << ex.what() << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
catch (...)
{
std::cerr << "Unknown exception occurreed" << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
}
}

View File

@@ -0,0 +1,141 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <ctime>
#include <functional>
#include <iomanip>
#include <string>
#include <cstdlib>
#include <codecvt>
#include <boost/program_options.hpp>
#include <omp.h>
#include <restapi/server.h>
using namespace diskann;
namespace po = boost::program_options;
std::unique_ptr<Server> g_httpServer(nullptr);
std::vector<std::unique_ptr<diskann::BaseSearch>> g_ssdSearch;
void setup(const utility::string_t &address, const std::string &typestring)
{
web::http::uri_builder uriBldr(address);
auto uri = uriBldr.to_uri();
std::cout << "Attempting to start server on " << uri.to_string() << std::endl;
g_httpServer = std::unique_ptr<Server>(new Server(uri, g_ssdSearch, typestring));
std::cout << "Created a server object" << std::endl;
g_httpServer->open().wait();
ucout << U"Listening for requests on: " << address << std::endl;
}
void teardown(const utility::string_t &address)
{
g_httpServer->close().wait();
}
int main(int argc, char *argv[])
{
std::string data_type, index_path_prefix, address, dist_fn, tags_file;
uint32_t num_nodes_to_cache;
uint32_t num_threads;
po::options_description desc{"Arguments"};
try
{
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
desc.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
"Path prefix for loading index file components");
desc.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
"Number of nodes to cache during search");
desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
"Number of threads used for building index (defaults to "
"omp_get_num_procs())");
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
"distance function <l2/mips>");
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
"Tags file location");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << std::endl;
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("l2"))
metric = diskann::Metric::L2;
else if (dist_fn == std::string("mips"))
metric = diskann::Metric::INNER_PRODUCT;
else
{
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
return -1;
}
if (data_type == std::string("float"))
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::PQFlashSearch<float>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
g_ssdSearch.push_back(std::move(searcher));
}
else if (data_type == std::string("int8"))
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::PQFlashSearch<int8_t>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
g_ssdSearch.push_back(std::move(searcher));
}
else if (data_type == std::string("uint8"))
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::PQFlashSearch<uint8_t>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
g_ssdSearch.push_back(std::move(searcher));
}
else
{
std::cerr << "Unsupported data type " << argv[2] << std::endl;
exit(-1);
}
while (1)
{
try
{
setup(address, data_type);
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
std::string line;
std::getline(std::cin, line);
if (line == "exit")
{
teardown(address);
g_httpServer->close().wait();
exit(0);
}
}
catch (const std::exception &ex)
{
std::cerr << "Exception occurred: " << ex.what() << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
catch (...)
{
std::cerr << "Unknown exception occurreed" << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
}
}

View File

@@ -0,0 +1,499 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "common_includes.h"
#include <boost/program_options.hpp>
#include "index.h"
#include "disk_utils.h"
#include "math_utils.h"
#include "memory_mapper.h"
#include "partition.h"
#include "pq_flash_index.h"
#include "timer.h"
#include "percentile_stats.h"
#include "program_options_utils.hpp"
#ifndef _WINDOWS
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#include "linux_aligned_file_reader.h"
#else
#ifdef USE_BING_INFRA
#include "bing_aligned_file_reader.h"
#else
#include "windows_aligned_file_reader.h"
#endif
#endif
#define WARMUP false
namespace po = boost::program_options;
void print_stats(std::string category, std::vector<float> percentiles, std::vector<float> results)
{
diskann::cout << std::setw(20) << category << ": " << std::flush;
for (uint32_t s = 0; s < percentiles.size(); s++)
{
diskann::cout << std::setw(8) << percentiles[s] << "%";
}
diskann::cout << std::endl;
diskann::cout << std::setw(22) << " " << std::flush;
for (uint32_t s = 0; s < percentiles.size(); s++)
{
diskann::cout << std::setw(9) << results[s];
}
diskann::cout << std::endl;
}
template <typename T, typename LabelT = uint32_t>
int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix,
const std::string &result_output_prefix, const std::string &query_file, std::string &gt_file,
const uint32_t num_threads, const uint32_t recall_at, const uint32_t beamwidth,
const uint32_t num_nodes_to_cache, const uint32_t search_io_limit,
const std::vector<uint32_t> &Lvec, const float fail_if_recall_below,
const std::vector<std::string> &query_filters, const bool use_reorder_data = false)
{
diskann::cout << "Search parameters: #threads: " << num_threads << ", ";
if (beamwidth <= 0)
diskann::cout << "beamwidth to be optimized for each L value" << std::flush;
else
diskann::cout << " beamwidth: " << beamwidth << std::flush;
if (search_io_limit == std::numeric_limits<uint32_t>::max())
diskann::cout << "." << std::endl;
else
diskann::cout << ", io_limit: " << search_io_limit << "." << std::endl;
std::string warmup_query_file = index_path_prefix + "_sample_data.bin";
// load query bin
T *query = nullptr;
uint32_t *gt_ids = nullptr;
float *gt_dists = nullptr;
size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim;
diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);
bool filtered_search = false;
if (!query_filters.empty())
{
filtered_search = true;
if (query_filters.size() != 1 && query_filters.size() != query_num)
{
std::cout << "Error. Mismatch in number of queries and size of query "
"filters file"
<< std::endl;
return -1; // To return -1 or some other error handling?
}
}
bool calc_recall_flag = false;
if (gt_file != std::string("null") && gt_file != std::string("NULL") && file_exists(gt_file))
{
diskann::load_truthset(gt_file, gt_ids, gt_dists, gt_num, gt_dim);
if (gt_num != query_num)
{
diskann::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl;
}
calc_recall_flag = true;
}
std::shared_ptr<AlignedFileReader> reader = nullptr;
#ifdef _WINDOWS
#ifndef USE_BING_INFRA
reader.reset(new WindowsAlignedFileReader());
#else
reader.reset(new diskann::BingAlignedFileReader());
#endif
#else
reader.reset(new LinuxAlignedFileReader());
#endif
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> _pFlashIndex(
new diskann::PQFlashIndex<T, LabelT>(reader, metric));
int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str());
if (res != 0)
{
return res;
}
std::vector<uint32_t> node_list;
diskann::cout << "Caching " << num_nodes_to_cache << " nodes around medoid(s)" << std::endl;
_pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list);
// if (num_nodes_to_cache > 0)
// _pFlashIndex->generate_cache_list_from_sample_queries(warmup_query_file, 15, 6, num_nodes_to_cache,
// num_threads, node_list);
_pFlashIndex->load_cache_list(node_list);
node_list.clear();
node_list.shrink_to_fit();
omp_set_num_threads(num_threads);
uint64_t warmup_L = 20;
uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0;
T *warmup = nullptr;
if (WARMUP)
{
if (file_exists(warmup_query_file))
{
diskann::load_aligned_bin<T>(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim);
}
else
{
warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads);
warmup_dim = query_dim;
warmup_aligned_dim = query_aligned_dim;
diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T));
std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T));
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(-128, 127);
for (uint32_t i = 0; i < warmup_num; i++)
{
for (uint32_t d = 0; d < warmup_dim; d++)
{
warmup[i * warmup_aligned_dim + d] = (T)dis(gen);
}
}
}
diskann::cout << "Warming up index... " << std::flush;
std::vector<uint64_t> warmup_result_ids_64(warmup_num, 0);
std::vector<float> warmup_result_dists(warmup_num, 0);
#pragma omp parallel for schedule(dynamic, 1)
for (int64_t i = 0; i < (int64_t)warmup_num; i++)
{
_pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L,
warmup_result_ids_64.data() + (i * 1),
warmup_result_dists.data() + (i * 1), 4);
}
diskann::cout << "..done" << std::endl;
}
diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
diskann::cout.precision(2);
std::string recall_string = "Recall@" + std::to_string(recall_at);
diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" << std::setw(16) << "QPS" << std::setw(16)
<< "Mean Latency" << std::setw(16) << "99.9 Latency" << std::setw(16) << "Mean IOs" << std::setw(16)
<< "Mean IO (us)" << std::setw(16) << "CPU (s)";
if (calc_recall_flag)
{
diskann::cout << std::setw(16) << recall_string << std::endl;
}
else
diskann::cout << std::endl;
diskann::cout << "=================================================================="
"================================================================="
<< std::endl;
std::vector<std::vector<uint32_t>> query_result_ids(Lvec.size());
std::vector<std::vector<float>> query_result_dists(Lvec.size());
uint32_t optimized_beamwidth = 2;
double best_recall = 0.0;
for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++)
{
uint32_t L = Lvec[test_id];
if (L < recall_at)
{
diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl;
continue;
}
if (beamwidth <= 0)
{
diskann::cout << "Tuning beamwidth.." << std::endl;
optimized_beamwidth =
optimize_beamwidth(_pFlashIndex, warmup, warmup_num, warmup_aligned_dim, L, optimized_beamwidth);
}
else
optimized_beamwidth = beamwidth;
query_result_ids[test_id].resize(recall_at * query_num);
query_result_dists[test_id].resize(recall_at * query_num);
auto stats = new diskann::QueryStats[query_num];
std::vector<uint64_t> query_result_ids_64(recall_at * query_num);
auto s = std::chrono::high_resolution_clock::now();
#pragma omp parallel for schedule(dynamic, 1)
for (int64_t i = 0; i < (int64_t)query_num; i++)
{
if (!filtered_search)
{
_pFlashIndex->cached_beam_search(query + (i * query_aligned_dim), recall_at, L,
query_result_ids_64.data() + (i * recall_at),
query_result_dists[test_id].data() + (i * recall_at),
optimized_beamwidth, use_reorder_data, stats + i);
}
else
{
LabelT label_for_search;
if (query_filters.size() == 1)
{ // one label for all queries
label_for_search = _pFlashIndex->get_converted_label(query_filters[0]);
}
else
{ // one label for each query
label_for_search = _pFlashIndex->get_converted_label(query_filters[i]);
}
_pFlashIndex->cached_beam_search(
query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at),
query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search,
use_reorder_data, stats + i);
}
}
auto e = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = e - s;
double qps = (1.0 * query_num) / (1.0 * diff.count());
diskann::convert_types<uint64_t, uint32_t>(query_result_ids_64.data(), query_result_ids[test_id].data(),
query_num, recall_at);
auto mean_latency = diskann::get_mean_stats<float>(
stats, query_num, [](const diskann::QueryStats &stats) { return stats.total_us; });
auto latency_999 = diskann::get_percentile_stats<float>(
stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; });
auto mean_ios = diskann::get_mean_stats<uint32_t>(stats, query_num,
[](const diskann::QueryStats &stats) { return stats.n_ios; });
auto mean_cpuus = diskann::get_mean_stats<float>(stats, query_num,
[](const diskann::QueryStats &stats) { return stats.cpu_us; });
auto mean_io_us = diskann::get_mean_stats<float>(stats, query_num,
[](const diskann::QueryStats &stats) { return stats.io_us; });
double recall = 0;
if (calc_recall_flag)
{
recall = diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim,
query_result_ids[test_id].data(), recall_at, recall_at);
best_recall = std::max(recall, best_recall);
}
diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps
<< std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios
<< std::setw(16) << mean_io_us << std::setw(16) << mean_cpuus;
if (calc_recall_flag)
{
diskann::cout << std::setw(16) << recall << std::endl;
}
else
diskann::cout << std::endl;
delete[] stats;
}
diskann::cout << "Done searching. Now saving results " << std::endl;
uint64_t test_id = 0;
for (auto L : Lvec)
{
if (L < recall_at)
continue;
std::string cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_idx_uint32.bin";
diskann::save_bin<uint32_t>(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at);
cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_dists_float.bin";
diskann::save_bin<float>(cur_result_path, query_result_dists[test_id++].data(), query_num, recall_at);
}
diskann::aligned_free(query);
if (warmup != nullptr)
diskann::aligned_free(warmup);
return best_recall >= fail_if_recall_below ? 0 : -1;
}
int main(int argc, char **argv)
{
std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file, filter_label,
label_type, query_filters_file;
uint32_t num_threads, K, W, num_nodes_to_cache, search_io_limit;
std::vector<uint32_t> Lvec;
bool use_reorder_data = false;
float fail_if_recall_below = 0.0f;
po::options_description desc{
program_options_utils::make_program_description("search_disk_index", "Searches on-disk DiskANN indexes")};
try
{
desc.add_options()("help,h", "Print information on arguments");
// Required parameters
po::options_description required_configs("Required");
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
program_options_utils::DATA_TYPE_DESCRIPTION);
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
required_configs.add_options()("result_path", po::value<std::string>(&result_path_prefix)->required(),
program_options_utils::RESULT_PATH_DESCRIPTION);
required_configs.add_options()("query_file", po::value<std::string>(&query_file)->required(),
program_options_utils::QUERY_FILE_DESCRIPTION);
required_configs.add_options()("recall_at,K", po::value<uint32_t>(&K)->required(),
program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION);
required_configs.add_options()("search_list,L",
po::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
program_options_utils::SEARCH_LIST_DESCRIPTION);
// Optional parameters
po::options_description optional_configs("Optional");
optional_configs.add_options()("gt_file", po::value<std::string>(&gt_file)->default_value(std::string("null")),
program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION);
optional_configs.add_options()("beamwidth,W", po::value<uint32_t>(&W)->default_value(2),
program_options_utils::BEAMWIDTH);
optional_configs.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
program_options_utils::NUMBER_OF_NODES_TO_CACHE);
optional_configs.add_options()(
"search_io_limit",
po::value<uint32_t>(&search_io_limit)->default_value(std::numeric_limits<uint32_t>::max()),
"Max #IOs for search. Default value: uint32::max()");
optional_configs.add_options()("num_threads,T",
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
program_options_utils::NUMBER_THREADS_DESCRIPTION);
optional_configs.add_options()("use_reorder_data", po::bool_switch()->default_value(false),
"Include full precision data in the index. Use only in "
"conjuction with compressed data on SSD. Default value: false");
optional_configs.add_options()("filter_label",
po::value<std::string>(&filter_label)->default_value(std::string("")),
program_options_utils::FILTER_LABEL_DESCRIPTION);
optional_configs.add_options()("query_filters_file",
po::value<std::string>(&query_filters_file)->default_value(std::string("")),
program_options_utils::FILTERS_FILE_DESCRIPTION);
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
program_options_utils::LABEL_TYPE_DESCRIPTION);
optional_configs.add_options()("fail_if_recall_below",
po::value<float>(&fail_if_recall_below)->default_value(0.0f),
program_options_utils::FAIL_IF_RECALL_BELOW);
// Merge required and optional parameters
desc.add(required_configs).add(optional_configs);
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
if (vm["use_reorder_data"].as<bool>())
use_reorder_data = true;
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("mips"))
{
metric = diskann::Metric::INNER_PRODUCT;
}
else if (dist_fn == std::string("l2"))
{
metric = diskann::Metric::L2;
}
else if (dist_fn == std::string("cosine"))
{
metric = diskann::Metric::COSINE;
}
else
{
std::cout << "Unsupported distance function. Currently only L2/ Inner "
"Product/Cosine are supported."
<< std::endl;
return -1;
}
if ((data_type != std::string("float")) && (metric == diskann::Metric::INNER_PRODUCT))
{
std::cout << "Currently support only floating point data for Inner Product." << std::endl;
return -1;
}
if (use_reorder_data && data_type != std::string("float"))
{
std::cout << "Error: Reorder data for reordering currently only "
"supported for float data type."
<< std::endl;
return -1;
}
if (filter_label != "" && query_filters_file != "")
{
std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl;
return -1;
}
std::vector<std::string> query_filters;
if (filter_label != "")
{
query_filters.push_back(filter_label);
}
else if (query_filters_file != "")
{
query_filters = read_file_to_vector_of_strings(query_filters_file);
}
try
{
if (!query_filters.empty() && label_type == "ushort")
{
if (data_type == std::string("float"))
return search_disk_index<float, uint16_t>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
else if (data_type == std::string("int8"))
return search_disk_index<int8_t, uint16_t>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
else if (data_type == std::string("uint8"))
return search_disk_index<uint8_t, uint16_t>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
else
{
std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
return -1;
}
}
else
{
if (data_type == std::string("float"))
return search_disk_index<float>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
fail_if_recall_below, query_filters, use_reorder_data);
else if (data_type == std::string("int8"))
return search_disk_index<int8_t>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
fail_if_recall_below, query_filters, use_reorder_data);
else if (data_type == std::string("uint8"))
return search_disk_index<uint8_t>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
fail_if_recall_below, query_filters, use_reorder_data);
else
{
std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
return -1;
}
}
}
catch (const std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Index search failed." << std::endl;
return -1;
}
}

View File

@@ -0,0 +1,477 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <cstring>
#include <iomanip>
#include <algorithm>
#include <numeric>
#include <omp.h>
#include <set>
#include <string.h>
#include <boost/program_options.hpp>
#ifndef _WINDOWS
#include <sys/mman.h>
#include <sys/stat.h>
#include <time.h>
#include <unistd.h>
#endif
#include "index.h"
#include "memory_mapper.h"
#include "utils.h"
#include "program_options_utils.hpp"
#include "index_factory.h"
namespace po = boost::program_options;
template <typename T, typename LabelT = uint32_t>
int search_memory_index(diskann::Metric &metric, const std::string &index_path, const std::string &result_path_prefix,
const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads,
const uint32_t recall_at, const bool print_all_recalls, const std::vector<uint32_t> &Lvec,
const bool dynamic, const bool tags, const bool show_qps_per_thread,
const std::vector<std::string> &query_filters, const float fail_if_recall_below)
{
using TagT = uint32_t;
// Load the query file
T *query = nullptr;
uint32_t *gt_ids = nullptr;
float *gt_dists = nullptr;
size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim;
diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);
bool calc_recall_flag = false;
if (truthset_file != std::string("null") && file_exists(truthset_file))
{
diskann::load_truthset(truthset_file, gt_ids, gt_dists, gt_num, gt_dim);
if (gt_num != query_num)
{
std::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl;
}
calc_recall_flag = true;
}
else
{
diskann::cout << " Truthset file " << truthset_file << " not found. Not computing recall." << std::endl;
}
bool filtered_search = false;
if (!query_filters.empty())
{
filtered_search = true;
if (query_filters.size() != 1 && query_filters.size() != query_num)
{
std::cout << "Error. Mismatch in number of queries and size of query "
"filters file"
<< std::endl;
return -1; // To return -1 or some other error handling?
}
}
const size_t num_frozen_pts = diskann::get_graph_num_frozen_points(index_path);
auto config = diskann::IndexConfigBuilder()
.with_metric(metric)
.with_dimension(query_dim)
.with_max_points(0)
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
.with_data_type(diskann_type_to_name<T>())
.with_label_type(diskann_type_to_name<LabelT>())
.with_tag_type(diskann_type_to_name<TagT>())
.is_dynamic_index(dynamic)
.is_enable_tags(tags)
.is_concurrent_consolidate(false)
.is_pq_dist_build(false)
.is_use_opq(false)
.with_num_pq_chunks(0)
.with_num_frozen_pts(num_frozen_pts)
.build();
auto index_factory = diskann::IndexFactory(config);
auto index = index_factory.create_instance();
index->load(index_path.c_str(), num_threads, *(std::max_element(Lvec.begin(), Lvec.end())));
std::cout << "Index loaded" << std::endl;
if (metric == diskann::FAST_L2)
index->optimize_index_layout();
std::cout << "Using " << num_threads << " threads to search" << std::endl;
std::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
std::cout.precision(2);
const std::string qps_title = show_qps_per_thread ? "QPS/thread" : "QPS";
uint32_t table_width = 0;
if (tags)
{
std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title << std::setw(20) << "Mean Latency (mus)"
<< std::setw(15) << "99.9 Latency";
table_width += 4 + 12 + 20 + 15;
}
else
{
std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title << std::setw(18) << "Avg dist cmps"
<< std::setw(20) << "Mean Latency (mus)" << std::setw(15) << "99.9 Latency";
table_width += 4 + 12 + 18 + 20 + 15;
}
uint32_t recalls_to_print = 0;
const uint32_t first_recall = print_all_recalls ? 1 : recall_at;
if (calc_recall_flag)
{
for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++)
{
std::cout << std::setw(12) << ("Recall@" + std::to_string(curr_recall));
}
recalls_to_print = recall_at + 1 - first_recall;
table_width += recalls_to_print * 12;
}
std::cout << std::endl;
std::cout << std::string(table_width, '=') << std::endl;
std::vector<std::vector<uint32_t>> query_result_ids(Lvec.size());
std::vector<std::vector<float>> query_result_dists(Lvec.size());
std::vector<float> latency_stats(query_num, 0);
std::vector<uint32_t> cmp_stats;
if (not tags || filtered_search)
{
cmp_stats = std::vector<uint32_t>(query_num, 0);
}
std::vector<TagT> query_result_tags;
if (tags)
{
query_result_tags.resize(recall_at * query_num);
}
double best_recall = 0.0;
for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++)
{
uint32_t L = Lvec[test_id];
if (L < recall_at)
{
diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl;
continue;
}
query_result_ids[test_id].resize(recall_at * query_num);
query_result_dists[test_id].resize(recall_at * query_num);
std::vector<T *> res = std::vector<T *>();
auto s = std::chrono::high_resolution_clock::now();
omp_set_num_threads(num_threads);
#pragma omp parallel for schedule(dynamic, 1)
for (int64_t i = 0; i < (int64_t)query_num; i++)
{
auto qs = std::chrono::high_resolution_clock::now();
if (filtered_search && !tags)
{
std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];
auto retval = index->search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L,
query_result_ids[test_id].data() + i * recall_at,
query_result_dists[test_id].data() + i * recall_at);
cmp_stats[i] = retval.second;
}
else if (metric == diskann::FAST_L2)
{
index->search_with_optimized_layout(query + i * query_aligned_dim, recall_at, L,
query_result_ids[test_id].data() + i * recall_at);
}
else if (tags)
{
if (!filtered_search)
{
index->search_with_tags(query + i * query_aligned_dim, recall_at, L,
query_result_tags.data() + i * recall_at, nullptr, res);
}
else
{
std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];
index->search_with_tags(query + i * query_aligned_dim, recall_at, L,
query_result_tags.data() + i * recall_at, nullptr, res, true, raw_filter);
}
for (int64_t r = 0; r < (int64_t)recall_at; r++)
{
query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r];
}
}
else
{
cmp_stats[i] = index
->search(query + i * query_aligned_dim, recall_at, L,
query_result_ids[test_id].data() + i * recall_at)
.second;
}
auto qe = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = qe - qs;
latency_stats[i] = (float)(diff.count() * 1000000);
}
std::chrono::duration<double> diff = std::chrono::high_resolution_clock::now() - s;
double displayed_qps = query_num / diff.count();
if (show_qps_per_thread)
displayed_qps /= num_threads;
std::vector<double> recalls;
if (calc_recall_flag)
{
recalls.reserve(recalls_to_print);
for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++)
{
recalls.push_back(diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim,
query_result_ids[test_id].data(), recall_at, curr_recall));
}
}
std::sort(latency_stats.begin(), latency_stats.end());
double mean_latency =
std::accumulate(latency_stats.begin(), latency_stats.end(), 0.0) / static_cast<float>(query_num);
float avg_cmps = (float)std::accumulate(cmp_stats.begin(), cmp_stats.end(), 0) / (float)query_num;
if (tags && !filtered_search)
{
std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(20) << (float)mean_latency
<< std::setw(15) << (float)latency_stats[(uint64_t)(0.999 * query_num)];
}
else
{
std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(18) << avg_cmps
<< std::setw(20) << (float)mean_latency << std::setw(15)
<< (float)latency_stats[(uint64_t)(0.999 * query_num)];
}
for (double recall : recalls)
{
std::cout << std::setw(12) << recall;
best_recall = std::max(recall, best_recall);
}
std::cout << std::endl;
}
std::cout << "Done searching. Now saving results " << std::endl;
uint64_t test_id = 0;
for (auto L : Lvec)
{
if (L < recall_at)
{
diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl;
continue;
}
std::string cur_result_path_prefix = result_path_prefix + "_" + std::to_string(L);
std::string cur_result_path = cur_result_path_prefix + "_idx_uint32.bin";
diskann::save_bin<uint32_t>(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at);
cur_result_path = cur_result_path_prefix + "_dists_float.bin";
diskann::save_bin<float>(cur_result_path, query_result_dists[test_id].data(), query_num, recall_at);
test_id++;
}
diskann::aligned_free(query);
return best_recall >= fail_if_recall_below ? 0 : -1;
}
int main(int argc, char **argv)
{
std::string data_type, dist_fn, index_path_prefix, result_path, query_file, gt_file, filter_label, label_type,
query_filters_file;
uint32_t num_threads, K;
std::vector<uint32_t> Lvec;
bool print_all_recalls, dynamic, tags, show_qps_per_thread;
float fail_if_recall_below = 0.0f;
po::options_description desc{
program_options_utils::make_program_description("search_memory_index", "Searches in-memory DiskANN indexes")};
try
{
desc.add_options()("help,h", "Print this information on arguments");
// Required parameters
po::options_description required_configs("Required");
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
program_options_utils::DATA_TYPE_DESCRIPTION);
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
required_configs.add_options()("result_path", po::value<std::string>(&result_path)->required(),
program_options_utils::RESULT_PATH_DESCRIPTION);
required_configs.add_options()("query_file", po::value<std::string>(&query_file)->required(),
program_options_utils::QUERY_FILE_DESCRIPTION);
required_configs.add_options()("recall_at,K", po::value<uint32_t>(&K)->required(),
program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION);
required_configs.add_options()("search_list,L",
po::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
program_options_utils::SEARCH_LIST_DESCRIPTION);
// Optional parameters
po::options_description optional_configs("Optional");
optional_configs.add_options()("filter_label",
po::value<std::string>(&filter_label)->default_value(std::string("")),
program_options_utils::FILTER_LABEL_DESCRIPTION);
optional_configs.add_options()("query_filters_file",
po::value<std::string>(&query_filters_file)->default_value(std::string("")),
program_options_utils::FILTERS_FILE_DESCRIPTION);
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
program_options_utils::LABEL_TYPE_DESCRIPTION);
optional_configs.add_options()("gt_file", po::value<std::string>(&gt_file)->default_value(std::string("null")),
program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION);
optional_configs.add_options()("num_threads,T",
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
program_options_utils::NUMBER_THREADS_DESCRIPTION);
optional_configs.add_options()(
"dynamic", po::value<bool>(&dynamic)->default_value(false),
"Whether the index is dynamic. Dynamic indices must have associated tags. Default false.");
optional_configs.add_options()("tags", po::value<bool>(&tags)->default_value(false),
"Whether to search with external identifiers (tags). Default false.");
optional_configs.add_options()("fail_if_recall_below",
po::value<float>(&fail_if_recall_below)->default_value(0.0f),
program_options_utils::FAIL_IF_RECALL_BELOW);
// Output controls
po::options_description output_controls("Output controls");
output_controls.add_options()("print_all_recalls", po::bool_switch(&print_all_recalls),
"Print recalls at all positions, from 1 up to specified "
"recall_at value");
output_controls.add_options()("print_qps_per_thread", po::bool_switch(&show_qps_per_thread),
"Print overall QPS divided by the number of threads in "
"the output table");
// Merge required and optional parameters
desc.add(required_configs).add(optional_configs).add(output_controls);
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
diskann::Metric metric;
if ((dist_fn == std::string("mips")) && (data_type == std::string("float")))
{
metric = diskann::Metric::INNER_PRODUCT;
}
else if (dist_fn == std::string("l2"))
{
metric = diskann::Metric::L2;
}
else if (dist_fn == std::string("cosine"))
{
metric = diskann::Metric::COSINE;
}
else if ((dist_fn == std::string("fast_l2")) && (data_type == std::string("float")))
{
metric = diskann::Metric::FAST_L2;
}
else
{
std::cout << "Unsupported distance function. Currently only l2/ cosine are "
"supported in general, and mips/fast_l2 only for floating "
"point data."
<< std::endl;
return -1;
}
if (dynamic && not tags)
{
std::cerr << "Tags must be enabled while searching dynamically built indices" << std::endl;
return -1;
}
if (fail_if_recall_below < 0.0 || fail_if_recall_below >= 100.0)
{
std::cerr << "fail_if_recall_below parameter must be between 0 and 100%" << std::endl;
return -1;
}
if (filter_label != "" && query_filters_file != "")
{
std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl;
return -1;
}
std::vector<std::string> query_filters;
if (filter_label != "")
{
query_filters.push_back(filter_label);
}
else if (query_filters_file != "")
{
query_filters = read_file_to_vector_of_strings(query_filters_file);
}
try
{
if (!query_filters.empty() && label_type == "ushort")
{
if (data_type == std::string("int8"))
{
return search_memory_index<int8_t, uint16_t>(
metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("uint8"))
{
return search_memory_index<uint8_t, uint16_t>(
metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("float"))
{
return search_memory_index<float, uint16_t>(metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
else
{
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
return -1;
}
}
else
{
if (data_type == std::string("int8"))
{
return search_memory_index<int8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("uint8"))
{
return search_memory_index<uint8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("float"))
{
return search_memory_index<float>(metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
else
{
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
return -1;
}
}
}
catch (std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Index search failed." << std::endl;
return -1;
}
}

View File

@@ -0,0 +1,536 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <index.h>
#include <numeric>
#include <omp.h>
#include <string.h>
#include <time.h>
#include <timer.h>
#include <boost/program_options.hpp>
#include <future>
#include "utils.h"
#include "filter_utils.h"
#include "program_options_utils.hpp"
#include "index_factory.h"
#ifndef _WINDOWS
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#endif
#include "memory_mapper.h"
namespace po = boost::program_options;
// load_aligned_bin modified to read pieces of the file, but using ifstream
// instead of cached_ifstream.
template <typename T>
inline void load_aligned_bin_part(const std::string &bin_file, T *data, size_t offset_points, size_t points_to_read)
{
diskann::Timer timer;
std::ifstream reader;
reader.exceptions(std::ios::failbit | std::ios::badbit);
reader.open(bin_file, std::ios::binary | std::ios::ate);
size_t actual_file_size = reader.tellg();
reader.seekg(0, std::ios::beg);
int npts_i32, dim_i32;
reader.read((char *)&npts_i32, sizeof(int));
reader.read((char *)&dim_i32, sizeof(int));
size_t npts = (uint32_t)npts_i32;
size_t dim = (uint32_t)dim_i32;
size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t);
if (actual_file_size != expected_actual_file_size)
{
std::stringstream stream;
stream << "Error. File size mismatch. Actual size is " << actual_file_size << " while expected size is "
<< expected_actual_file_size << " npts = " << npts << " dim = " << dim << " size of <T>= " << sizeof(T)
<< std::endl;
std::cout << stream.str();
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}
if (offset_points + points_to_read > npts)
{
std::stringstream stream;
stream << "Error. Not enough points in file. Requested " << offset_points << " offset and " << points_to_read
<< " points, but have only " << npts << " points" << std::endl;
std::cout << stream.str();
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}
reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T));
const size_t rounded_dim = ROUND_UP(dim, 8);
for (size_t i = 0; i < points_to_read; i++)
{
reader.read((char *)(data + i * rounded_dim), dim * sizeof(T));
memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T));
}
reader.close();
const double elapsedSeconds = timer.elapsed() / 1000000.0;
std::cout << "Read " << points_to_read << " points using non-cached reads in " << elapsedSeconds << std::endl;
}
std::string get_save_filename(const std::string &save_path, size_t points_to_skip, size_t points_deleted,
size_t last_point_threshold)
{
std::string final_path = save_path;
if (points_to_skip > 0)
{
final_path += "skip" + std::to_string(points_to_skip) + "-";
}
final_path += "del" + std::to_string(points_deleted) + "-";
final_path += std::to_string(last_point_threshold);
return final_path;
}
template <typename T, typename TagT, typename LabelT>
void insert_till_next_checkpoint(diskann::AbstractIndex &index, size_t start, size_t end, int32_t thread_count, T *data,
size_t aligned_dim, std::vector<std::vector<LabelT>> &location_to_labels)
{
diskann::Timer insert_timer;
#pragma omp parallel for num_threads(thread_count) schedule(dynamic)
for (int64_t j = start; j < (int64_t)end; j++)
{
if (!location_to_labels.empty())
{
index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j),
location_to_labels[j - start]);
}
else
{
index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j));
}
}
const double elapsedSeconds = insert_timer.elapsed() / 1000000.0;
std::cout << "Insertion time " << elapsedSeconds << " seconds (" << (end - start) / elapsedSeconds
<< " points/second overall, " << (end - start) / elapsedSeconds / thread_count << " per thread)\n ";
}
template <typename T, typename TagT>
void delete_from_beginning(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params,
size_t points_to_skip, size_t points_to_delete_from_beginning)
{
try
{
std::cout << std::endl
<< "Lazy deleting points " << points_to_skip << " to "
<< points_to_skip + points_to_delete_from_beginning << "... ";
for (size_t i = points_to_skip; i < points_to_skip + points_to_delete_from_beginning; ++i)
index.lazy_delete(static_cast<TagT>(i + 1)); // Since tags are data location + 1
std::cout << "done." << std::endl;
auto report = index.consolidate_deletes(delete_params);
std::cout << "#active points: " << report._active_points << std::endl
<< "max points: " << report._max_points << std::endl
<< "empty slots: " << report._empty_slots << std::endl
<< "deletes processed: " << report._slots_released << std::endl
<< "latest delete size: " << report._delete_set_size << std::endl
<< "rate: (" << points_to_delete_from_beginning / report._time << " points/second overall, "
<< points_to_delete_from_beginning / report._time / delete_params.num_threads << " per thread)"
<< std::endl;
}
catch (std::system_error &e)
{
std::cout << "Exception caught in deletion thread: " << e.what() << std::endl;
}
}
template <typename T>
void build_incremental_index(const std::string &data_path, diskann::IndexWriteParameters &params, size_t points_to_skip,
size_t max_points_to_insert, size_t beginning_index_size, float start_point_norm,
uint32_t num_start_pts, size_t points_per_checkpoint, size_t checkpoints_per_snapshot,
const std::string &save_path, size_t points_to_delete_from_beginning,
size_t start_deletes_after, bool concurrent, const std::string &label_file,
const std::string &universal_label)
{
size_t dim, aligned_dim;
size_t num_points;
diskann::get_bin_metadata(data_path, num_points, dim);
aligned_dim = ROUND_UP(dim, 8);
bool has_labels = label_file != "";
using TagT = uint32_t;
using LabelT = uint32_t;
size_t current_point_offset = points_to_skip;
const size_t last_point_threshold = points_to_skip + max_points_to_insert;
bool enable_tags = true;
using TagT = uint32_t;
auto index_search_params = diskann::IndexSearchParams(params.search_list_size, params.num_threads);
diskann::IndexConfig index_config = diskann::IndexConfigBuilder()
.with_metric(diskann::L2)
.with_dimension(dim)
.with_max_points(max_points_to_insert)
.is_dynamic_index(true)
.with_index_write_params(params)
.with_index_search_params(index_search_params)
.with_data_type(diskann_type_to_name<T>())
.with_tag_type(diskann_type_to_name<TagT>())
.with_label_type(diskann_type_to_name<LabelT>())
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
.is_enable_tags(enable_tags)
.is_filtered(has_labels)
.with_num_frozen_pts(num_start_pts)
.is_concurrent_consolidate(concurrent)
.build();
diskann::IndexFactory index_factory = diskann::IndexFactory(index_config);
auto index = index_factory.create_instance();
if (universal_label != "")
{
LabelT u_label = 0;
index->set_universal_label(u_label);
}
if (points_to_skip > num_points)
{
throw diskann::ANNException("Asked to skip more points than in data file", -1, __FUNCSIG__, __FILE__, __LINE__);
}
if (max_points_to_insert == 0)
{
max_points_to_insert = num_points;
}
if (points_to_skip + max_points_to_insert > num_points)
{
max_points_to_insert = num_points - points_to_skip;
std::cerr << "WARNING: Reducing max_points_to_insert to " << max_points_to_insert
<< " points since the data file has only that many" << std::endl;
}
if (beginning_index_size > max_points_to_insert)
{
beginning_index_size = max_points_to_insert;
std::cerr << "WARNING: Reducing beginning index size to " << beginning_index_size
<< " points since the data file has only that many" << std::endl;
}
if (checkpoints_per_snapshot > 0 && beginning_index_size > points_per_checkpoint)
{
beginning_index_size = points_per_checkpoint;
std::cerr << "WARNING: Reducing beginning index size to " << beginning_index_size << std::endl;
}
T *data = nullptr;
diskann::alloc_aligned(
(void **)&data, std::max(points_per_checkpoint, beginning_index_size) * aligned_dim * sizeof(T), 8 * sizeof(T));
std::vector<TagT> tags(beginning_index_size);
std::iota(tags.begin(), tags.end(), 1 + static_cast<TagT>(current_point_offset));
load_aligned_bin_part(data_path, data, current_point_offset, beginning_index_size);
std::cout << "load aligned bin succeeded" << std::endl;
diskann::Timer timer;
if (beginning_index_size > 0)
{
index->build(data, beginning_index_size, tags);
}
else
{
index->set_start_points_at_random(static_cast<T>(start_point_norm));
}
const double elapsedSeconds = timer.elapsed() / 1000000.0;
std::cout << "Initial non-incremental index build time for " << beginning_index_size << " points took "
<< elapsedSeconds << " seconds (" << beginning_index_size / elapsedSeconds << " points/second)\n ";
current_point_offset += beginning_index_size;
if (points_to_delete_from_beginning > max_points_to_insert)
{
points_to_delete_from_beginning = static_cast<uint32_t>(max_points_to_insert);
std::cerr << "WARNING: Reducing points to delete from beginning to " << points_to_delete_from_beginning
<< " points since the data file has only that many" << std::endl;
}
std::vector<std::vector<LabelT>> location_to_labels;
if (concurrent)
{
// handle labels
const auto save_path_inc = get_save_filename(save_path + ".after-concurrent-delete-", points_to_skip,
points_to_delete_from_beginning, last_point_threshold);
std::string labels_file_to_use = save_path_inc + "_label_formatted.txt";
std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt";
if (has_labels)
{
convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
auto parse_result = diskann::parse_formatted_label_file<LabelT>(labels_file_to_use);
location_to_labels = std::get<0>(parse_result);
}
int32_t sub_threads = (params.num_threads + 1) / 2;
bool delete_launched = false;
std::future<void> delete_task;
diskann::Timer timer;
for (size_t start = current_point_offset; start < last_point_threshold;
start += points_per_checkpoint, current_point_offset += points_per_checkpoint)
{
const size_t end = std::min(start + points_per_checkpoint, last_point_threshold);
std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl;
auto insert_task = std::async(std::launch::async, [&]() {
load_aligned_bin_part(data_path, data, start, end - start);
insert_till_next_checkpoint<T, TagT, LabelT>(*index, start, end, sub_threads, data, aligned_dim,
location_to_labels);
});
insert_task.wait();
if (!delete_launched && end >= start_deletes_after &&
end >= points_to_skip + points_to_delete_from_beginning)
{
delete_launched = true;
diskann::IndexWriteParameters delete_params =
diskann::IndexWriteParametersBuilder(params).with_num_threads(sub_threads).build();
delete_task = std::async(std::launch::async, [&]() {
delete_from_beginning<T, TagT>(*index, delete_params, points_to_skip,
points_to_delete_from_beginning);
});
}
}
delete_task.wait();
std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n";
index->save(save_path_inc.c_str(), true);
}
else
{
const auto save_path_inc = get_save_filename(save_path + ".after-delete-", points_to_skip,
points_to_delete_from_beginning, last_point_threshold);
std::string labels_file_to_use = save_path_inc + "_label_formatted.txt";
std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt";
if (has_labels)
{
convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
auto parse_result = diskann::parse_formatted_label_file<LabelT>(labels_file_to_use);
location_to_labels = std::get<0>(parse_result);
}
size_t last_snapshot_points_threshold = 0;
size_t num_checkpoints_till_snapshot = checkpoints_per_snapshot;
for (size_t start = current_point_offset; start < last_point_threshold;
start += points_per_checkpoint, current_point_offset += points_per_checkpoint)
{
const size_t end = std::min(start + points_per_checkpoint, last_point_threshold);
std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl;
load_aligned_bin_part(data_path, data, start, end - start);
insert_till_next_checkpoint<T, TagT, LabelT>(*index, start, end, (int32_t)params.num_threads, data,
aligned_dim, location_to_labels);
if (checkpoints_per_snapshot > 0 && --num_checkpoints_till_snapshot == 0)
{
diskann::Timer save_timer;
const auto save_path_inc =
get_save_filename(save_path + ".inc-", points_to_skip, points_to_delete_from_beginning, end);
index->save(save_path_inc.c_str(), false);
const double elapsedSeconds = save_timer.elapsed() / 1000000.0;
const size_t points_saved = end - points_to_skip;
std::cout << "Saved " << points_saved << " points in " << elapsedSeconds << " seconds ("
<< points_saved / elapsedSeconds << " points/second)\n";
num_checkpoints_till_snapshot = checkpoints_per_snapshot;
last_snapshot_points_threshold = end;
}
std::cout << "Number of points in the index post insertion " << end << std::endl;
}
if (checkpoints_per_snapshot > 0 && last_snapshot_points_threshold != last_point_threshold)
{
const auto save_path_inc = get_save_filename(save_path + ".inc-", points_to_skip,
points_to_delete_from_beginning, last_point_threshold);
// index.save(save_path_inc.c_str(), false);
}
if (points_to_delete_from_beginning > 0)
{
delete_from_beginning<T, TagT>(*index, params, points_to_skip, points_to_delete_from_beginning);
}
index->save(save_path_inc.c_str(), true);
}
diskann::aligned_free(data);
}
int main(int argc, char **argv)
{
std::string data_type, dist_fn, data_path, index_path_prefix;
uint32_t num_threads, R, L, num_start_pts;
float alpha, start_point_norm;
size_t points_to_skip, max_points_to_insert, beginning_index_size, points_per_checkpoint, checkpoints_per_snapshot,
points_to_delete_from_beginning, start_deletes_after;
bool concurrent;
// label options
std::string label_file, label_type, universal_label;
std::uint32_t Lf, unique_labels_supported;
po::options_description desc{program_options_utils::make_program_description("test_insert_deletes_consolidate",
"Test insert deletes & consolidate")};
try
{
desc.add_options()("help,h", "Print information on arguments");
// Required parameters
po::options_description required_configs("Required");
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
program_options_utils::DATA_TYPE_DESCRIPTION);
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
program_options_utils::INPUT_DATA_PATH);
required_configs.add_options()("points_to_skip", po::value<uint64_t>(&points_to_skip)->required(),
"Skip these first set of points from file");
required_configs.add_options()("beginning_index_size", po::value<uint64_t>(&beginning_index_size)->required(),
"Batch build will be called on these set of points");
required_configs.add_options()("points_per_checkpoint", po::value<uint64_t>(&points_per_checkpoint)->required(),
"Insertions are done in batches of points_per_checkpoint");
required_configs.add_options()("checkpoints_per_snapshot",
po::value<uint64_t>(&checkpoints_per_snapshot)->required(),
"Save the index to disk every few checkpoints");
required_configs.add_options()("points_to_delete_from_beginning",
po::value<uint64_t>(&points_to_delete_from_beginning)->required(), "");
// Optional parameters
po::options_description optional_configs("Optional");
optional_configs.add_options()("num_threads,T",
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
program_options_utils::NUMBER_THREADS_DESCRIPTION);
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
program_options_utils::MAX_BUILD_DEGREE);
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
program_options_utils::GRAPH_BUILD_COMPLEXITY);
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
program_options_utils::GRAPH_BUILD_ALPHA);
optional_configs.add_options()("max_points_to_insert",
po::value<uint64_t>(&max_points_to_insert)->default_value(0),
"These number of points from the file are inserted after "
"points_to_skip");
optional_configs.add_options()("do_concurrent", po::value<bool>(&concurrent)->default_value(false), "");
optional_configs.add_options()("start_deletes_after",
po::value<uint64_t>(&start_deletes_after)->default_value(0), "");
optional_configs.add_options()("start_point_norm", po::value<float>(&start_point_norm)->default_value(0),
"Set the start point to a random point on a sphere of this radius");
// optional params for filters
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
"Input label file in txt format for Filtered Index search. "
"The file should contain comma separated filters for each node "
"with each line corresponding to a graph node");
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
"Universal label, if using it, only in conjunction with labels_file");
optional_configs.add_options()("FilteredLbuild,Lf", po::value<uint32_t>(&Lf)->default_value(0),
"Build complexity for filtered points, higher value "
"results in better graphs");
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
"Storage type of Labels <uint/ushort>, default value is uint which "
"will consume memory 4 bytes per filter");
optional_configs.add_options()("unique_labels_supported",
po::value<uint32_t>(&unique_labels_supported)->default_value(0),
"Number of unique labels supported by the dynamic index.");
optional_configs.add_options()(
"num_start_points",
po::value<uint32_t>(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC),
"Set the number of random start (frozen) points to use when "
"inserting and searching");
// Merge required and optional parameters
desc.add(required_configs).add(optional_configs);
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
if (beginning_index_size == 0)
if (start_point_norm == 0)
{
std::cout << "When beginning_index_size is 0, use a start "
"point with "
"appropriate norm"
<< std::endl;
return -1;
}
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
bool has_labels = false;
if (!label_file.empty() || label_file != "")
{
has_labels = true;
}
if (num_start_pts < unique_labels_supported)
{
num_start_pts = unique_labels_supported;
}
try
{
diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R)
.with_max_occlusion_size(500)
.with_alpha(alpha)
.with_num_threads(num_threads)
.with_filter_list_size(Lf)
.build();
if (data_type == std::string("int8"))
build_incremental_index<int8_t>(
data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm,
num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix,
points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label);
else if (data_type == std::string("uint8"))
build_incremental_index<uint8_t>(
data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm,
num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix,
points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label);
else if (data_type == std::string("float"))
build_incremental_index<float>(data_path, params, points_to_skip, max_points_to_insert,
beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint,
checkpoints_per_snapshot, index_path_prefix, points_to_delete_from_beginning,
start_deletes_after, concurrent, label_file, universal_label);
else
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
}
catch (const std::exception &e)
{
std::cerr << "Caught exception: " << e.what() << std::endl;
exit(-1);
}
catch (...)
{
std::cerr << "Caught unknown exception" << std::endl;
exit(-1);
}
return 0;
}

View File

@@ -0,0 +1,523 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <index.h>
#include <numeric>
#include <omp.h>
#include <string.h>
#include <time.h>
#include <timer.h>
#include <boost/program_options.hpp>
#include <future>
#include <abstract_index.h>
#include <index_factory.h>
#include "utils.h"
#include "filter_utils.h"
#include "program_options_utils.hpp"
#ifndef _WINDOWS
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#endif
#include "memory_mapper.h"
namespace po = boost::program_options;
// load_aligned_bin modified to read pieces of the file, but using ifstream
// instead of cached_ifstream.
template <typename T>
inline void load_aligned_bin_part(const std::string &bin_file, T *data, size_t offset_points, size_t points_to_read)
{
std::ifstream reader;
reader.exceptions(std::ios::failbit | std::ios::badbit);
reader.open(bin_file, std::ios::binary | std::ios::ate);
size_t actual_file_size = reader.tellg();
reader.seekg(0, std::ios::beg);
int npts_i32, dim_i32;
reader.read((char *)&npts_i32, sizeof(int));
reader.read((char *)&dim_i32, sizeof(int));
size_t npts = (uint32_t)npts_i32;
size_t dim = (uint32_t)dim_i32;
size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t);
if (actual_file_size != expected_actual_file_size)
{
std::stringstream stream;
stream << "Error. File size mismatch. Actual size is " << actual_file_size << " while expected size is "
<< expected_actual_file_size << " npts = " << npts << " dim = " << dim << " size of <T>= " << sizeof(T)
<< std::endl;
std::cout << stream.str();
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}
if (offset_points + points_to_read > npts)
{
std::stringstream stream;
stream << "Error. Not enough points in file. Requested " << offset_points << " offset and " << points_to_read
<< " points, but have only " << npts << " points" << std::endl;
std::cout << stream.str();
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}
reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T));
const size_t rounded_dim = ROUND_UP(dim, 8);
for (size_t i = 0; i < points_to_read; i++)
{
reader.read((char *)(data + i * rounded_dim), dim * sizeof(T));
memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T));
}
reader.close();
}
std::string get_save_filename(const std::string &save_path, size_t active_window, size_t consolidate_interval,
size_t max_points_to_insert)
{
std::string final_path = save_path;
final_path += "act" + std::to_string(active_window) + "-";
final_path += "cons" + std::to_string(consolidate_interval) + "-";
final_path += "max" + std::to_string(max_points_to_insert);
return final_path;
}
template <typename T, typename TagT, typename LabelT>
void insert_next_batch(diskann::AbstractIndex &index, size_t start, size_t end, size_t insert_threads, T *data,
size_t aligned_dim, std::vector<std::vector<LabelT>> &pts_to_labels)
{
try
{
diskann::Timer insert_timer;
std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl;
size_t num_failed = 0;
#pragma omp parallel for num_threads((int32_t)insert_threads) schedule(dynamic) reduction(+ : num_failed)
for (int64_t j = start; j < (int64_t)end; j++)
{
int insert_result = -1;
if (pts_to_labels.size() > 0)
{
insert_result = index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j),
pts_to_labels[j - start]);
}
else
{
insert_result = index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j));
}
if (insert_result != 0)
{
std::cerr << "Insert failed " << j << std::endl;
num_failed++;
}
}
const double elapsedSeconds = insert_timer.elapsed() / 1000000.0;
std::cout << "Insertion time " << elapsedSeconds << " seconds (" << (end - start) / elapsedSeconds
<< " points/second overall, " << (end - start) / elapsedSeconds / insert_threads << " per thread)"
<< std::endl;
if (num_failed > 0)
std::cout << num_failed << " of " << end - start << "inserts failed" << std::endl;
}
catch (std::system_error &e)
{
std::cout << "Exiting after catching exception in insertion task: " << e.what() << std::endl;
exit(-1);
}
}
template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t>
void delete_and_consolidate(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params, size_t start,
size_t end)
{
try
{
std::cout << std::endl << "Lazy deleting points " << start << " to " << end << "... ";
for (size_t i = start; i < end; ++i)
index.lazy_delete(static_cast<TagT>(1 + i));
std::cout << "lazy delete done." << std::endl;
auto report = index.consolidate_deletes(delete_params);
while (report._status != diskann::consolidation_report::status_code::SUCCESS)
{
int wait_time = 5;
if (report._status == diskann::consolidation_report::status_code::LOCK_FAIL)
{
diskann::cerr << "Unable to acquire consolidate delete lock after "
<< "deleting points " << start << " to " << end << ". Will retry in " << wait_time
<< "seconds." << std::endl;
}
else if (report._status == diskann::consolidation_report::status_code::INCONSISTENT_COUNT_ERROR)
{
diskann::cerr << "Inconsistent counts in data structure. "
<< "Will retry in " << wait_time << "seconds." << std::endl;
}
else
{
std::cerr << "Exiting after unknown error in consolidate delete" << std::endl;
exit(-1);
}
std::this_thread::sleep_for(std::chrono::seconds(wait_time));
report = index.consolidate_deletes(delete_params);
}
auto points_processed = report._active_points + report._slots_released;
auto deletion_rate = points_processed / report._time;
std::cout << "#active points: " << report._active_points << std::endl
<< "max points: " << report._max_points << std::endl
<< "empty slots: " << report._empty_slots << std::endl
<< "deletes processed: " << report._slots_released << std::endl
<< "latest delete size: " << report._delete_set_size << std::endl
<< "Deletion rate: " << deletion_rate << "/sec "
<< "Deletion rate: " << deletion_rate / delete_params.num_threads << "/thread/sec " << std::endl;
}
catch (std::system_error &e)
{
std::cerr << "Exiting after catching exception in deletion task: " << e.what() << std::endl;
exit(-1);
}
}
template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t>
void build_incremental_index(const std::string &data_path, const uint32_t L, const uint32_t R, const float alpha,
const uint32_t insert_threads, const uint32_t consolidate_threads,
size_t max_points_to_insert, size_t active_window, size_t consolidate_interval,
const float start_point_norm, uint32_t num_start_pts, const std::string &save_path,
const std::string &label_file, const std::string &universal_label, const uint32_t Lf)
{
const uint32_t C = 500;
const bool saturate_graph = false;
bool has_labels = label_file != "";
diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R)
.with_max_occlusion_size(C)
.with_alpha(alpha)
.with_saturate_graph(saturate_graph)
.with_num_threads(insert_threads)
.with_filter_list_size(Lf)
.build();
auto index_search_params = diskann::IndexSearchParams(L, insert_threads);
diskann::IndexWriteParameters delete_params = diskann::IndexWriteParametersBuilder(L, R)
.with_max_occlusion_size(C)
.with_alpha(alpha)
.with_saturate_graph(saturate_graph)
.with_num_threads(consolidate_threads)
.with_filter_list_size(Lf)
.build();
size_t dim, aligned_dim;
size_t num_points;
std::vector<std::vector<LabelT>> pts_to_labels;
const auto save_path_inc =
get_save_filename(save_path + ".after-streaming-", active_window, consolidate_interval, max_points_to_insert);
std::string labels_file_to_use = save_path_inc + "_label_formatted.txt";
std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt";
if (has_labels)
{
convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
auto parse_result = diskann::parse_formatted_label_file<LabelT>(labels_file_to_use);
pts_to_labels = std::get<0>(parse_result);
}
diskann::get_bin_metadata(data_path, num_points, dim);
diskann::cout << "metadata: file " << data_path << " has " << num_points << " points in " << dim << " dims"
<< std::endl;
aligned_dim = ROUND_UP(dim, 8);
auto index_config = diskann::IndexConfigBuilder()
.with_metric(diskann::L2)
.with_dimension(dim)
.with_max_points(active_window + 4 * consolidate_interval)
.is_dynamic_index(true)
.is_enable_tags(true)
.is_use_opq(false)
.is_filtered(has_labels)
.with_num_pq_chunks(0)
.is_pq_dist_build(false)
.with_num_frozen_pts(num_start_pts)
.with_tag_type(diskann_type_to_name<TagT>())
.with_label_type(diskann_type_to_name<LabelT>())
.with_data_type(diskann_type_to_name<T>())
.with_index_write_params(params)
.with_index_search_params(index_search_params)
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
.build();
diskann::IndexFactory index_factory = diskann::IndexFactory(index_config);
auto index = index_factory.create_instance();
if (universal_label != "")
{
LabelT u_label = 0;
index->set_universal_label(u_label);
}
if (max_points_to_insert == 0)
{
max_points_to_insert = num_points;
}
if (num_points < max_points_to_insert)
throw diskann::ANNException(std::string("num_points(") + std::to_string(num_points) +
") < max_points_to_insert(" + std::to_string(max_points_to_insert) + ")",
-1, __FUNCSIG__, __FILE__, __LINE__);
if (max_points_to_insert < active_window + consolidate_interval)
throw diskann::ANNException("ERROR: max_points_to_insert < "
"active_window + consolidate_interval",
-1, __FUNCSIG__, __FILE__, __LINE__);
if (consolidate_interval < max_points_to_insert / 1000)
throw diskann::ANNException("ERROR: consolidate_interval is too small", -1, __FUNCSIG__, __FILE__, __LINE__);
index->set_start_points_at_random(static_cast<T>(start_point_norm));
T *data = nullptr;
diskann::alloc_aligned((void **)&data, std::max(consolidate_interval, active_window) * aligned_dim * sizeof(T),
8 * sizeof(T));
std::vector<TagT> tags(max_points_to_insert);
std::iota(tags.begin(), tags.end(), static_cast<TagT>(0));
diskann::Timer timer;
std::vector<std::future<void>> delete_tasks;
auto insert_task = std::async(std::launch::async, [&]() {
load_aligned_bin_part(data_path, data, 0, active_window);
insert_next_batch<T, TagT, LabelT>(*index, (size_t)0, active_window, params.num_threads, data, aligned_dim,
pts_to_labels);
});
insert_task.wait();
for (size_t start = active_window; start + consolidate_interval <= max_points_to_insert;
start += consolidate_interval)
{
auto end = std::min(start + consolidate_interval, max_points_to_insert);
auto insert_task = std::async(std::launch::async, [&]() {
load_aligned_bin_part(data_path, data, start, end - start);
insert_next_batch<T, TagT, LabelT>(*index, start, end, params.num_threads, data, aligned_dim,
pts_to_labels);
});
insert_task.wait();
if (delete_tasks.size() > 0)
delete_tasks[delete_tasks.size() - 1].wait();
if (start >= active_window + consolidate_interval)
{
auto start_del = start - active_window - consolidate_interval;
auto end_del = start - active_window;
delete_tasks.emplace_back(std::async(std::launch::async, [&]() {
delete_and_consolidate<T, TagT, LabelT>(*index, delete_params, (size_t)start_del, (size_t)end_del);
}));
}
}
if (delete_tasks.size() > 0)
delete_tasks[delete_tasks.size() - 1].wait();
std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n";
index->save(save_path_inc.c_str(), true);
diskann::aligned_free(data);
}
int main(int argc, char **argv)
{
std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type;
uint32_t insert_threads, consolidate_threads, R, L, num_start_pts, Lf, unique_labels_supported;
float alpha, start_point_norm;
size_t max_points_to_insert, active_window, consolidate_interval;
po::options_description desc{program_options_utils::make_program_description("test_streaming_scenario",
"Test insert deletes & consolidate")};
try
{
desc.add_options()("help,h", "Print information on arguments");
// Required parameters
po::options_description required_configs("Required");
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
program_options_utils::DATA_TYPE_DESCRIPTION);
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
program_options_utils::INPUT_DATA_PATH);
required_configs.add_options()("active_window", po::value<uint64_t>(&active_window)->required(),
"Program maintains an index over an active window of "
"this size that slides through the data");
required_configs.add_options()("consolidate_interval", po::value<uint64_t>(&consolidate_interval)->required(),
"The program simultaneously adds this number of points to the "
"right of "
"the window while deleting the same number from the left");
required_configs.add_options()("start_point_norm", po::value<float>(&start_point_norm)->required(),
"Set the start point to a random point on a sphere of this radius");
// Optional parameters
po::options_description optional_configs("Optional");
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
program_options_utils::MAX_BUILD_DEGREE);
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
program_options_utils::GRAPH_BUILD_COMPLEXITY);
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
program_options_utils::GRAPH_BUILD_ALPHA);
optional_configs.add_options()("insert_threads",
po::value<uint32_t>(&insert_threads)->default_value(omp_get_num_procs() / 2),
"Number of threads used for inserting into the index (defaults to "
"omp_get_num_procs()/2)");
optional_configs.add_options()(
"consolidate_threads", po::value<uint32_t>(&consolidate_threads)->default_value(omp_get_num_procs() / 2),
"Number of threads used for consolidating deletes to "
"the index (defaults to omp_get_num_procs()/2)");
optional_configs.add_options()("max_points_to_insert",
po::value<uint64_t>(&max_points_to_insert)->default_value(0),
"The number of points from the file that the program streams "
"over ");
optional_configs.add_options()(
"num_start_points",
po::value<uint32_t>(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC),
"Set the number of random start (frozen) points to use when "
"inserting and searching");
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
"Input label file in txt format for Filtered Index search. "
"The file should contain comma separated filters for each node "
"with each line corresponding to a graph node");
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
"Universal label, if using it, only in conjunction with labels_file");
optional_configs.add_options()("FilteredLbuild,Lf", po::value<uint32_t>(&Lf)->default_value(0),
"Build complexity for filtered points, higher value "
"results in better graphs");
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
"Storage type of Labels <uint/ushort>, default value is uint which "
"will consume memory 4 bytes per filter");
optional_configs.add_options()("unique_labels_supported",
po::value<uint32_t>(&unique_labels_supported)->default_value(0),
"Number of unique labels supported by the dynamic index.");
// Merge required and optional parameters
desc.add(required_configs).add(optional_configs);
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
// Validate arguments
if (start_point_norm == 0)
{
std::cout << "When beginning_index_size is 0, use a start point with "
"appropriate norm"
<< std::endl;
return -1;
}
if (label_type != std::string("ushort") && label_type != std::string("uint"))
{
std::cerr << "Invalid label type. Supported types are uint and ushort" << std::endl;
return -1;
}
if (data_type != std::string("int8") && data_type != std::string("uint8") && data_type != std::string("float"))
{
std::cerr << "Invalid data type. Supported types are int8, uint8 and float" << std::endl;
return -1;
}
// TODO: Are additional distance functions supported?
if (dist_fn != std::string("l2") && dist_fn != std::string("mips"))
{
std::cerr << "Invalid distance function. Supported functions are l2 and mips" << std::endl;
return -1;
}
if (num_start_pts < unique_labels_supported)
{
num_start_pts = unique_labels_supported;
}
try
{
if (data_type == std::string("uint8"))
{
if (label_type == std::string("ushort"))
{
build_incremental_index<uint8_t, uint32_t, uint16_t>(
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
universal_label, Lf);
}
else if (label_type == std::string("uint"))
{
build_incremental_index<uint8_t, uint32_t, uint32_t>(
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
universal_label, Lf);
}
}
else if (data_type == std::string("int8"))
{
if (label_type == std::string("ushort"))
{
build_incremental_index<int8_t, uint32_t, uint16_t>(
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
universal_label, Lf);
}
else if (label_type == std::string("uint"))
{
build_incremental_index<int8_t, uint32_t, uint32_t>(
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
universal_label, Lf);
}
}
else if (data_type == std::string("float"))
{
if (label_type == std::string("ushort"))
{
build_incremental_index<float, uint32_t, uint16_t>(
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
universal_label, Lf);
}
else if (label_type == std::string("uint"))
{
build_incremental_index<float, uint32_t, uint32_t>(
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
universal_label, Lf);
}
}
}
catch (const std::exception &e)
{
std::cerr << "Caught exception: " << e.what() << std::endl;
exit(-1);
}
catch (...)
{
std::cerr << "Caught unknown exception" << std::endl;
exit(-1);
}
return 0;
}

View File

@@ -0,0 +1,110 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_COMPILE_WARNING_AS_ERROR ON)
add_executable(fvecs_to_bin fvecs_to_bin.cpp)
add_executable(fvecs_to_bvecs fvecs_to_bvecs.cpp)
add_executable(rand_data_gen rand_data_gen.cpp)
target_link_libraries(rand_data_gen ${PROJECT_NAME} Boost::program_options)
add_executable(float_bin_to_int8 float_bin_to_int8.cpp)
add_executable(ivecs_to_bin ivecs_to_bin.cpp)
add_executable(count_bfs_levels count_bfs_levels.cpp)
target_link_libraries(count_bfs_levels ${PROJECT_NAME} Boost::program_options)
add_executable(tsv_to_bin tsv_to_bin.cpp)
add_executable(bin_to_tsv bin_to_tsv.cpp)
add_executable(int8_to_float int8_to_float.cpp)
target_link_libraries(int8_to_float ${PROJECT_NAME})
add_executable(int8_to_float_scale int8_to_float_scale.cpp)
target_link_libraries(int8_to_float_scale ${PROJECT_NAME})
add_executable(uint8_to_float uint8_to_float.cpp)
target_link_libraries(uint8_to_float ${PROJECT_NAME})
add_executable(uint32_to_uint8 uint32_to_uint8.cpp)
target_link_libraries(uint32_to_uint8 ${PROJECT_NAME})
add_executable(vector_analysis vector_analysis.cpp)
target_link_libraries(vector_analysis ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
add_executable(gen_random_slice gen_random_slice.cpp)
target_link_libraries(gen_random_slice ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
add_executable(simulate_aggregate_recall simulate_aggregate_recall.cpp)
add_executable(calculate_recall calculate_recall.cpp)
target_link_libraries(calculate_recall ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
# Compute ground truth thing outside of DiskANN main source that depends on MKL.
add_executable(compute_groundtruth compute_groundtruth.cpp)
target_include_directories(compute_groundtruth PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES})
target_link_libraries(compute_groundtruth ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options)
add_executable(compute_groundtruth_for_filters compute_groundtruth_for_filters.cpp)
target_include_directories(compute_groundtruth_for_filters PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES})
target_link_libraries(compute_groundtruth_for_filters ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options)
add_executable(generate_pq generate_pq.cpp)
target_link_libraries(generate_pq ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
add_executable(partition_data partition_data.cpp)
target_link_libraries(partition_data ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
add_executable(partition_with_ram_budget partition_with_ram_budget.cpp)
target_link_libraries(partition_with_ram_budget ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
add_executable(merge_shards merge_shards.cpp)
target_link_libraries(merge_shards ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} ${DISKANN_ASYNC_LIB})
add_executable(create_disk_layout create_disk_layout.cpp)
target_link_libraries(create_disk_layout ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
add_executable(generate_synthetic_labels generate_synthetic_labels.cpp)
target_link_libraries(generate_synthetic_labels ${PROJECT_NAME} Boost::program_options)
add_executable(stats_label_data stats_label_data.cpp)
target_link_libraries(stats_label_data ${PROJECT_NAME} Boost::program_options)
if (NOT MSVC)
include(GNUInstallDirs)
install(TARGETS fvecs_to_bin
fvecs_to_bvecs
rand_data_gen
float_bin_to_int8
ivecs_to_bin
count_bfs_levels
tsv_to_bin
bin_to_tsv
int8_to_float
int8_to_float_scale
uint8_to_float
uint32_to_uint8
vector_analysis
gen_random_slice
simulate_aggregate_recall
calculate_recall
compute_groundtruth
compute_groundtruth_for_filters
generate_pq
partition_data
partition_with_ram_budget
merge_shards
create_disk_layout
generate_synthetic_labels
stats_label_data
RUNTIME
)
endif()

View File

@@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include "util.h"
void block_convert(std::ifstream &writr, std::ofstream &readr, float *read_buf, float *write_buf, uint64_t npts,
uint64_t ndims)
{
writr.write((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(unsigned)));
#pragma omp parallel for
for (uint64_t i = 0; i < npts; i++)
{
memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(float));
}
readr.read((char *)write_buf, npts * ndims * sizeof(float));
}
int main(int argc, char **argv)
{
if (argc != 3)
{
std::cout << argv[0] << " input_bin output_fvecs" << std::endl;
exit(-1);
}
std::ifstream readr(argv[1], std::ios::binary);
int npts_s32;
int ndims_s32;
readr.read((char *)&npts_s32, sizeof(int32_t));
readr.read((char *)&ndims_s32, sizeof(int32_t));
size_t npts = npts_s32;
size_t ndims = ndims_s32;
uint32_t ndims_u32 = (uint32_t)ndims_s32;
// uint64_t fsize = writr.tellg();
readr.seekg(0, std::ios::beg);
unsigned ndims_u32;
writr.write((char *)&ndims_u32, sizeof(unsigned));
writr.seekg(0, std::ios::beg);
uint64_t ndims = (uint64_t)ndims_u32;
uint64_t npts = fsize / ((ndims + 1) * sizeof(float));
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
uint64_t blk_size = 131072;
uint64_t nblks = ROUND_UP(npts, blk_size) / blk_size;
std::cout << "# blks: " << nblks << std::endl;
std::ofstream writr(argv[2], std::ios::binary);
float *read_buf = new float[npts * (ndims + 1)];
float *write_buf = new float[npts * ndims];
for (uint64_t i = 0; i < nblks; i++)
{
uint64_t cblk_size = std::min(npts - i * blk_size, blk_size);
block_convert(writr, readr, read_buf, write_buf, cblk_size, ndims);
std::cout << "Block #" << i << " written" << std::endl;
}
delete[] read_buf;
delete[] write_buf;
writr.close();
readr.close();
}

View File

@@ -0,0 +1,69 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include "utils.h"
template <class T>
void block_convert(std::ofstream &writer, std::ifstream &reader, T *read_buf, size_t npts, size_t ndims)
{
reader.read((char *)read_buf, npts * ndims * sizeof(float));
for (size_t i = 0; i < npts; i++)
{
for (size_t d = 0; d < ndims; d++)
{
writer << read_buf[d + i * ndims];
if (d < ndims - 1)
writer << "\t";
else
writer << "\n";
}
}
}
int main(int argc, char **argv)
{
if (argc != 4)
{
std::cout << argv[0] << " <float/int8/uint8> input_bin output_tsv" << std::endl;
exit(-1);
}
std::string type_string(argv[1]);
if ((type_string != std::string("float")) && (type_string != std::string("int8")) &&
(type_string != std::string("uin8")))
{
std::cerr << "Error: type not supported. Use float/int8/uint8" << std::endl;
}
std::ifstream reader(argv[2], std::ios::binary);
uint32_t npts_u32;
uint32_t ndims_u32;
reader.read((char *)&npts_u32, sizeof(uint32_t));
reader.read((char *)&ndims_u32, sizeof(uint32_t));
size_t npts = npts_u32;
size_t ndims = ndims_u32;
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
size_t blk_size = 131072;
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
std::ofstream writer(argv[3]);
char *read_buf = new char[blk_size * ndims * 4];
for (size_t i = 0; i < nblks; i++)
{
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
if (type_string == std::string("float"))
block_convert<float>(writer, reader, (float *)read_buf, cblk_size, ndims);
else if (type_string == std::string("int8"))
block_convert<int8_t>(writer, reader, (int8_t *)read_buf, cblk_size, ndims);
else if (type_string == std::string("uint8"))
block_convert<uint8_t>(writer, reader, (uint8_t *)read_buf, cblk_size, ndims);
std::cout << "Block #" << i << " written" << std::endl;
}
delete[] read_buf;
writer.close();
reader.close();
}

View File

@@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <cstddef>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <set>
#include <string>
#include <vector>
#include "utils.h"
#include "disk_utils.h"
int main(int argc, char **argv)
{
if (argc != 4)
{
std::cout << argv[0] << " <ground_truth_bin> <our_results_bin> <r> " << std::endl;
return -1;
}
uint32_t *gold_std = NULL;
float *gs_dist = nullptr;
uint32_t *our_results = NULL;
float *or_dist = nullptr;
size_t points_num, points_num_gs, points_num_or;
size_t dim_gs;
size_t dim_or;
diskann::load_truthset(argv[1], gold_std, gs_dist, points_num_gs, dim_gs);
diskann::load_truthset(argv[2], our_results, or_dist, points_num_or, dim_or);
if (points_num_gs != points_num_or)
{
std::cout << "Error. Number of queries mismatch in ground truth and "
"our results"
<< std::endl;
return -1;
}
points_num = points_num_gs;
uint32_t recall_at = std::atoi(argv[3]);
if ((dim_or < recall_at) || (recall_at > dim_gs))
{
std::cout << "ground truth has size " << dim_gs << "; our set has " << dim_or << " points. Asking for recall "
<< recall_at << std::endl;
return -1;
}
std::cout << "Calculating recall@" << recall_at << std::endl;
double recall_val = diskann::calculate_recall((uint32_t)points_num, gold_std, gs_dist, (uint32_t)dim_gs,
our_results, (uint32_t)dim_or, (uint32_t)recall_at);
// double avg_recall = (recall*1.0)/(points_num*1.0);
std::cout << "Avg. recall@" << recall_at << " is " << recall_val << "\n";
}

View File

@@ -0,0 +1,574 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <string>
#include <iostream>
#include <fstream>
#include <cassert>
#include <vector>
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <random>
#include <limits>
#include <cstring>
#include <queue>
#include <omp.h>
#include <mkl.h>
#include <boost/program_options.hpp>
#include <unordered_map>
#include <tsl/robin_map.h>
#include <tsl/robin_set.h>
#ifdef _WINDOWS
#include <malloc.h>
#else
#include <stdlib.h>
#endif
#include "filter_utils.h"
#include "utils.h"
// WORKS FOR UPTO 2 BILLION POINTS (as we use INT INSTEAD OF UNSIGNED)
#define PARTSIZE 10000000
#define ALIGNMENT 512
// custom types (for readability)
typedef tsl::robin_set<std::string> label_set;
typedef std::string path;
namespace po = boost::program_options;
template <class T> T div_round_up(const T numerator, const T denominator)
{
return (numerator % denominator == 0) ? (numerator / denominator) : 1 + (numerator / denominator);
}
using pairIF = std::pair<size_t, float>;
struct cmpmaxstruct
{
bool operator()(const pairIF &l, const pairIF &r)
{
return l.second < r.second;
};
};
using maxPQIFCS = std::priority_queue<pairIF, std::vector<pairIF>, cmpmaxstruct>;
template <class T> T *aligned_malloc(const size_t n, const size_t alignment)
{
#ifdef _WINDOWS
return (T *)_aligned_malloc(sizeof(T) * n, alignment);
#else
return static_cast<T *>(aligned_alloc(alignment, sizeof(T) * n));
#endif
}
inline bool custom_dist(const std::pair<uint32_t, float> &a, const std::pair<uint32_t, float> &b)
{
return a.second < b.second;
}
void compute_l2sq(float *const points_l2sq, const float *const matrix, const int64_t num_points, const uint64_t dim)
{
assert(points_l2sq != NULL);
#pragma omp parallel for schedule(static, 65536)
for (int64_t d = 0; d < num_points; ++d)
points_l2sq[d] = cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1,
matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1);
}
void distsq_to_points(const size_t dim,
float *dist_matrix, // Col Major, cols are queries, rows are points
size_t npoints, const float *const points,
const float *const points_l2sq, // points in Col major
size_t nqueries, const float *const queries,
const float *const queries_l2sq, // queries in Col major
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
{
bool ones_vec_alloc = false;
if (ones_vec == NULL)
{
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
ones_vec_alloc = true;
}
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-2.0, points, dim, queries, dim,
(float)0.0, dist_matrix, npoints);
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, points_l2sq, npoints,
ones_vec, nqueries, (float)1.0, dist_matrix, npoints);
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, ones_vec, npoints,
queries_l2sq, nqueries, (float)1.0, dist_matrix, npoints);
if (ones_vec_alloc)
delete[] ones_vec;
}
void inner_prod_to_points(const size_t dim,
float *dist_matrix, // Col Major, cols are queries, rows are points
size_t npoints, const float *const points, size_t nqueries, const float *const queries,
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
{
bool ones_vec_alloc = false;
if (ones_vec == NULL)
{
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
ones_vec_alloc = true;
}
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-1.0, points, dim, queries, dim,
(float)0.0, dist_matrix, npoints);
if (ones_vec_alloc)
delete[] ones_vec;
}
void exact_knn(const size_t dim, const size_t k,
size_t *const closest_points, // k * num_queries preallocated, col
// major, queries columns
float *const dist_closest_points, // k * num_queries
// preallocated, Dist to
// corresponding closes_points
size_t npoints,
float *points_in, // points in Col major
size_t nqueries, float *queries_in,
diskann::Metric metric = diskann::Metric::L2) // queries in Col major
{
float *points_l2sq = new float[npoints];
float *queries_l2sq = new float[nqueries];
compute_l2sq(points_l2sq, points_in, npoints, dim);
compute_l2sq(queries_l2sq, queries_in, nqueries, dim);
float *points = points_in;
float *queries = queries_in;
if (metric == diskann::Metric::COSINE)
{ // we convert cosine distance as
// normalized L2 distnace
points = new float[npoints * dim];
queries = new float[nqueries * dim];
#pragma omp parallel for schedule(static, 4096)
for (int64_t i = 0; i < (int64_t)npoints; i++)
{
float norm = std::sqrt(points_l2sq[i]);
if (norm == 0)
{
norm = std::numeric_limits<float>::epsilon();
}
for (uint32_t j = 0; j < dim; j++)
{
points[i * dim + j] = points_in[i * dim + j] / norm;
}
}
#pragma omp parallel for schedule(static, 4096)
for (int64_t i = 0; i < (int64_t)nqueries; i++)
{
float norm = std::sqrt(queries_l2sq[i]);
if (norm == 0)
{
norm = std::numeric_limits<float>::epsilon();
}
for (uint32_t j = 0; j < dim; j++)
{
queries[i * dim + j] = queries_in[i * dim + j] / norm;
}
}
// recalculate norms after normalizing, they should all be one.
compute_l2sq(points_l2sq, points, npoints, dim);
compute_l2sq(queries_l2sq, queries, nqueries, dim);
}
std::cout << "Going to compute " << k << " NNs for " << nqueries << " queries over " << npoints << " points in "
<< dim << " dimensions using";
if (metric == diskann::Metric::INNER_PRODUCT)
std::cout << " MIPS ";
else if (metric == diskann::Metric::COSINE)
std::cout << " Cosine ";
else
std::cout << " L2 ";
std::cout << "distance fn. " << std::endl;
size_t q_batch_size = (1 << 9);
float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints];
for (size_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b)
{
int64_t q_b = b * q_batch_size;
int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size;
if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE)
{
distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b,
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, queries_l2sq + q_b);
}
else
{
inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b,
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim);
}
std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" << std::endl;
#pragma omp parallel for schedule(dynamic, 16)
for (long long q = q_b; q < q_e; q++)
{
maxPQIFCS point_dist;
for (size_t p = 0; p < k; p++)
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
for (size_t p = k; p < npoints; p++)
{
if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints])
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
if (point_dist.size() > k)
point_dist.pop();
}
for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l)
{
closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().first;
dist_closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().second;
point_dist.pop();
}
assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k,
dist_closest_points + (ptrdiff_t)(q + 1) * (ptrdiff_t)k));
}
std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e << ")" << std::endl;
}
delete[] dist_matrix;
delete[] points_l2sq;
delete[] queries_l2sq;
if (metric == diskann::Metric::COSINE)
{
delete[] points;
delete[] queries;
}
}
template <typename T> inline int get_num_parts(const char *filename)
{
std::ifstream reader;
reader.exceptions(std::ios::failbit | std::ios::badbit);
reader.open(filename, std::ios::binary);
std::cout << "Reading bin file " << filename << " ...\n";
int npts_i32, ndims_i32;
reader.read((char *)&npts_i32, sizeof(int));
reader.read((char *)&ndims_i32, sizeof(int));
std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl;
reader.close();
uint32_t num_parts =
(npts_i32 % PARTSIZE) == 0 ? npts_i32 / PARTSIZE : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1;
std::cout << "Number of parts: " << num_parts << std::endl;
return num_parts;
}
template <typename T>
inline void load_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims, int part_num)
{
std::ifstream reader;
reader.exceptions(std::ios::failbit | std::ios::badbit);
reader.open(filename, std::ios::binary);
std::cout << "Reading bin file " << filename << " ...\n";
int npts_i32, ndims_i32;
reader.read((char *)&npts_i32, sizeof(int));
reader.read((char *)&ndims_i32, sizeof(int));
uint64_t start_id = part_num * PARTSIZE;
uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32);
npts = end_id - start_id;
ndims = (uint64_t)ndims_i32;
std::cout << "#pts in part = " << npts << ", #dims = " << ndims << ", size = " << npts * ndims * sizeof(T) << "B"
<< std::endl;
reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg);
T *data_T = new T[npts * ndims];
reader.read((char *)data_T, sizeof(T) * npts * ndims);
std::cout << "Finished reading part of the bin file." << std::endl;
reader.close();
data = aligned_malloc<float>(npts * ndims, ALIGNMENT);
#pragma omp parallel for schedule(dynamic, 32768)
for (int64_t i = 0; i < (int64_t)npts; i++)
{
for (int64_t j = 0; j < (int64_t)ndims; j++)
{
float cur_val_float = (float)data_T[i * ndims + j];
std::memcpy((char *)(data + i * ndims + j), (char *)&cur_val_float, sizeof(float));
}
}
delete[] data_T;
std::cout << "Finished converting part data to float." << std::endl;
}
template <typename T> inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims)
{
std::ofstream writer;
writer.exceptions(std::ios::failbit | std::ios::badbit);
writer.open(filename, std::ios::binary | std::ios::out);
std::cout << "Writing bin: " << filename << "\n";
int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
writer.write((char *)&npts_i32, sizeof(int));
writer.write((char *)&ndims_i32, sizeof(int));
std::cout << "bin: #pts = " << npts << ", #dims = " << ndims
<< ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B" << std::endl;
writer.write((char *)data, npts * ndims * sizeof(T));
writer.close();
std::cout << "Finished writing bin" << std::endl;
}
inline void save_groundtruth_as_one_file(const std::string filename, int32_t *data, float *distances, size_t npts,
size_t ndims)
{
std::ofstream writer(filename, std::ios::binary | std::ios::out);
int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
writer.write((char *)&npts_i32, sizeof(int));
writer.write((char *)&ndims_i32, sizeof(int));
std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, "
"npts*dim dist-matrix) with npts = "
<< npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int)
<< "B" << std::endl;
writer.write((char *)data, npts * ndims * sizeof(uint32_t));
writer.write((char *)distances, npts * ndims * sizeof(float));
writer.close();
std::cout << "Finished writing truthset" << std::endl;
}
template <typename T>
std::vector<std::vector<std::pair<uint32_t, float>>> processUnfilteredParts(const std::string &base_file,
size_t &nqueries, size_t &npoints,
size_t &dim, size_t &k, float *query_data,
const diskann::Metric &metric,
std::vector<uint32_t> &location_to_tag)
{
float *base_data = nullptr;
int num_parts = get_num_parts<T>(base_file.c_str());
std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
for (int p = 0; p < num_parts; p++)
{
size_t start_id = p * PARTSIZE;
load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);
size_t *closest_points_part = new size_t[nqueries * k];
float *dist_closest_points_part = new float[nqueries * k];
auto part_k = k < npoints ? k : npoints;
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data,
metric);
for (size_t i = 0; i < nqueries; i++)
{
for (size_t j = 0; j < part_k; j++)
{
if (!location_to_tag.empty())
if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0)
continue;
res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id),
dist_closest_points_part[i * part_k + j]));
}
}
delete[] closest_points_part;
delete[] dist_closest_points_part;
diskann::aligned_free(base_data);
}
return res;
};
template <typename T>
int aux_main(const std::string &base_file, const std::string &query_file, const std::string &gt_file, size_t k,
const diskann::Metric &metric, const std::string &tags_file = std::string(""))
{
size_t npoints, nqueries, dim;
float *query_data;
load_bin_as_float<T>(query_file.c_str(), query_data, nqueries, dim, 0);
if (nqueries > PARTSIZE)
std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE
<< ". Computing GT only for the first " << PARTSIZE << " queries." << std::endl;
// load tags
const bool tags_enabled = tags_file.empty() ? false : true;
std::vector<uint32_t> location_to_tag = diskann::loadTags(tags_file, base_file);
int *closest_points = new int[nqueries * k];
float *dist_closest_points = new float[nqueries * k];
std::vector<std::vector<std::pair<uint32_t, float>>> results =
processUnfilteredParts<T>(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag);
for (size_t i = 0; i < nqueries; i++)
{
std::vector<std::pair<uint32_t, float>> &cur_res = results[i];
std::sort(cur_res.begin(), cur_res.end(), custom_dist);
size_t j = 0;
for (auto iter : cur_res)
{
if (j == k)
break;
if (tags_enabled)
{
std::uint32_t index_with_tag = location_to_tag[iter.first];
closest_points[i * k + j] = (int32_t)index_with_tag;
}
else
{
closest_points[i * k + j] = (int32_t)iter.first;
}
if (metric == diskann::Metric::INNER_PRODUCT)
dist_closest_points[i * k + j] = -iter.second;
else
dist_closest_points[i * k + j] = iter.second;
++j;
}
if (j < k)
std::cout << "WARNING: found less than k GT entries for query " << i << std::endl;
}
save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k);
delete[] closest_points;
delete[] dist_closest_points;
diskann::aligned_free(query_data);
return 0;
}
void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim)
{
size_t read_blk_size = 64 * 1024 * 1024;
cached_ifstream reader(bin_file, read_blk_size);
diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl;
size_t actual_file_size = reader.get_file_size();
int npts_i32, dim_i32;
reader.read((char *)&npts_i32, sizeof(int));
reader.read((char *)&dim_i32, sizeof(int));
npts = (uint32_t)npts_i32;
dim = (uint32_t)dim_i32;
diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl;
int truthset_type = -1; // 1 means truthset has ids and distances, 2 means
// only ids, -1 is error
size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
if (actual_file_size == expected_file_size_with_dists)
truthset_type = 1;
size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
if (actual_file_size == expected_file_size_just_ids)
truthset_type = 2;
if (truthset_type == -1)
{
std::stringstream stream;
stream << "Error. File size mismatch. File should have bin format, with "
"npts followed by ngt followed by npts*ngt ids and optionally "
"followed by npts*ngt distance values; actual size: "
<< actual_file_size << ", expected: " << expected_file_size_with_dists << " or "
<< expected_file_size_just_ids;
diskann::cout << stream.str();
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}
ids = new uint32_t[npts * dim];
reader.read((char *)ids, npts * dim * sizeof(uint32_t));
if (truthset_type == 1)
{
dists = new float[npts * dim];
reader.read((char *)dists, npts * dim * sizeof(float));
}
}
int main(int argc, char **argv)
{
std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file;
uint64_t K;
try
{
po::options_description desc{"Arguments"};
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
"distance function <l2/mips/cosine>");
desc.add_options()("base_file", po::value<std::string>(&base_file)->required(),
"File containing the base vectors in binary format");
desc.add_options()("query_file", po::value<std::string>(&query_file)->required(),
"File containing the query vectors in binary format");
desc.add_options()("gt_file", po::value<std::string>(&gt_file)->required(),
"File name for the writing ground truth in binary "
"format, please don' append .bin at end if "
"no filter_label or filter_label_file is provided it "
"will save the file with '.bin' at end."
"else it will save the file as filename_label.bin");
desc.add_options()("K", po::value<uint64_t>(&K)->required(),
"Number of ground truth nearest neighbors to compute");
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
"File containing the tags in binary format");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8"))
{
std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl;
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("l2"))
{
metric = diskann::Metric::L2;
}
else if (dist_fn == std::string("mips"))
{
metric = diskann::Metric::INNER_PRODUCT;
}
else if (dist_fn == std::string("cosine"))
{
metric = diskann::Metric::COSINE;
}
else
{
std::cerr << "Unsupported distance function. Use l2/mips/cosine." << std::endl;
return -1;
}
try
{
if (data_type == std::string("float"))
aux_main<float>(base_file, query_file, gt_file, K, metric, tags_file);
if (data_type == std::string("int8"))
aux_main<int8_t>(base_file, query_file, gt_file, K, metric, tags_file);
if (data_type == std::string("uint8"))
aux_main<uint8_t>(base_file, query_file, gt_file, K, metric, tags_file);
}
catch (const std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Compute GT failed." << std::endl;
return -1;
}
}

View File

@@ -0,0 +1,919 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <string>
#include <iostream>
#include <fstream>
#include <cassert>
#include <vector>
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <random>
#include <limits>
#include <cstring>
#include <queue>
#include <omp.h>
#include <mkl.h>
#include <boost/program_options.hpp>
#include <unordered_map>
#include <tsl/robin_map.h>
#include <tsl/robin_set.h>
#ifdef _WINDOWS
#include <malloc.h>
#else
#include <stdlib.h>
#endif
#include "filter_utils.h"
#include "utils.h"
// WORKS FOR UPTO 2 BILLION POINTS (as we use INT INSTEAD OF UNSIGNED)
#define PARTSIZE 10000000
#define ALIGNMENT 512
// custom types (for readability)
typedef tsl::robin_set<std::string> label_set;
typedef std::string path;
namespace po = boost::program_options;
template <class T> T div_round_up(const T numerator, const T denominator)
{
return (numerator % denominator == 0) ? (numerator / denominator) : 1 + (numerator / denominator);
}
using pairIF = std::pair<size_t, float>;
struct cmpmaxstruct
{
bool operator()(const pairIF &l, const pairIF &r)
{
return l.second < r.second;
};
};
using maxPQIFCS = std::priority_queue<pairIF, std::vector<pairIF>, cmpmaxstruct>;
template <class T> T *aligned_malloc(const size_t n, const size_t alignment)
{
#ifdef _WINDOWS
return (T *)_aligned_malloc(sizeof(T) * n, alignment);
#else
return static_cast<T *>(aligned_alloc(alignment, sizeof(T) * n));
#endif
}
inline bool custom_dist(const std::pair<uint32_t, float> &a, const std::pair<uint32_t, float> &b)
{
return a.second < b.second;
}
void compute_l2sq(float *const points_l2sq, const float *const matrix, const int64_t num_points, const uint64_t dim)
{
assert(points_l2sq != NULL);
#pragma omp parallel for schedule(static, 65536)
for (int64_t d = 0; d < num_points; ++d)
points_l2sq[d] = cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1,
matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1);
}
void distsq_to_points(const size_t dim,
float *dist_matrix, // Col Major, cols are queries, rows are points
size_t npoints, const float *const points,
const float *const points_l2sq, // points in Col major
size_t nqueries, const float *const queries,
const float *const queries_l2sq, // queries in Col major
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
{
bool ones_vec_alloc = false;
if (ones_vec == NULL)
{
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
ones_vec_alloc = true;
}
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-2.0, points, dim, queries, dim,
(float)0.0, dist_matrix, npoints);
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, points_l2sq, npoints,
ones_vec, nqueries, (float)1.0, dist_matrix, npoints);
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, ones_vec, npoints,
queries_l2sq, nqueries, (float)1.0, dist_matrix, npoints);
if (ones_vec_alloc)
delete[] ones_vec;
}
void inner_prod_to_points(const size_t dim,
float *dist_matrix, // Col Major, cols are queries, rows are points
size_t npoints, const float *const points, size_t nqueries, const float *const queries,
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
{
bool ones_vec_alloc = false;
if (ones_vec == NULL)
{
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
ones_vec_alloc = true;
}
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-1.0, points, dim, queries, dim,
(float)0.0, dist_matrix, npoints);
if (ones_vec_alloc)
delete[] ones_vec;
}
void exact_knn(const size_t dim, const size_t k,
size_t *const closest_points, // k * num_queries preallocated, col
// major, queries columns
float *const dist_closest_points, // k * num_queries
// preallocated, Dist to
// corresponding closes_points
size_t npoints,
float *points_in, // points in Col major
size_t nqueries, float *queries_in,
diskann::Metric metric = diskann::Metric::L2) // queries in Col major
{
float *points_l2sq = new float[npoints];
float *queries_l2sq = new float[nqueries];
compute_l2sq(points_l2sq, points_in, npoints, dim);
compute_l2sq(queries_l2sq, queries_in, nqueries, dim);
float *points = points_in;
float *queries = queries_in;
if (metric == diskann::Metric::COSINE)
{ // we convert cosine distance as
// normalized L2 distnace
points = new float[npoints * dim];
queries = new float[nqueries * dim];
#pragma omp parallel for schedule(static, 4096)
for (int64_t i = 0; i < (int64_t)npoints; i++)
{
float norm = std::sqrt(points_l2sq[i]);
if (norm == 0)
{
norm = std::numeric_limits<float>::epsilon();
}
for (uint32_t j = 0; j < dim; j++)
{
points[i * dim + j] = points_in[i * dim + j] / norm;
}
}
#pragma omp parallel for schedule(static, 4096)
for (int64_t i = 0; i < (int64_t)nqueries; i++)
{
float norm = std::sqrt(queries_l2sq[i]);
if (norm == 0)
{
norm = std::numeric_limits<float>::epsilon();
}
for (uint32_t j = 0; j < dim; j++)
{
queries[i * dim + j] = queries_in[i * dim + j] / norm;
}
}
// recalculate norms after normalizing, they should all be one.
compute_l2sq(points_l2sq, points, npoints, dim);
compute_l2sq(queries_l2sq, queries, nqueries, dim);
}
std::cout << "Going to compute " << k << " NNs for " << nqueries << " queries over " << npoints << " points in "
<< dim << " dimensions using";
if (metric == diskann::Metric::INNER_PRODUCT)
std::cout << " MIPS ";
else if (metric == diskann::Metric::COSINE)
std::cout << " Cosine ";
else
std::cout << " L2 ";
std::cout << "distance fn. " << std::endl;
size_t q_batch_size = (1 << 9);
float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints];
for (uint64_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b)
{
int64_t q_b = b * q_batch_size;
int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size;
if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE)
{
distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b,
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, queries_l2sq + q_b);
}
else
{
inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b,
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim);
}
std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" << std::endl;
#pragma omp parallel for schedule(dynamic, 16)
for (long long q = q_b; q < q_e; q++)
{
maxPQIFCS point_dist;
for (size_t p = 0; p < k; p++)
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
for (size_t p = k; p < npoints; p++)
{
if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints])
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
if (point_dist.size() > k)
point_dist.pop();
}
for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l)
{
closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().first;
dist_closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().second;
point_dist.pop();
}
assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k,
dist_closest_points + (ptrdiff_t)(q + 1) * (ptrdiff_t)k));
}
std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e << ")" << std::endl;
}
delete[] dist_matrix;
delete[] points_l2sq;
delete[] queries_l2sq;
if (metric == diskann::Metric::COSINE)
{
delete[] points;
delete[] queries;
}
}
template <typename T> inline int get_num_parts(const char *filename)
{
std::ifstream reader;
reader.exceptions(std::ios::failbit | std::ios::badbit);
reader.open(filename, std::ios::binary);
std::cout << "Reading bin file " << filename << " ...\n";
int npts_i32, ndims_i32;
reader.read((char *)&npts_i32, sizeof(int));
reader.read((char *)&ndims_i32, sizeof(int));
std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl;
reader.close();
int num_parts = (npts_i32 % PARTSIZE) == 0 ? npts_i32 / PARTSIZE : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1;
std::cout << "Number of parts: " << num_parts << std::endl;
return num_parts;
}
template <typename T>
inline void load_bin_as_float(const char *filename, float *&data, size_t &npts_u64, size_t &ndims_u64, int part_num)
{
std::ifstream reader;
reader.exceptions(std::ios::failbit | std::ios::badbit);
reader.open(filename, std::ios::binary);
std::cout << "Reading bin file " << filename << " ...\n";
int npts_i32, ndims_i32;
reader.read((char *)&npts_i32, sizeof(int));
reader.read((char *)&ndims_i32, sizeof(int));
uint64_t start_id = part_num * PARTSIZE;
uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32);
npts_u64 = end_id - start_id;
ndims_u64 = (uint64_t)ndims_i32;
std::cout << "#pts in part = " << npts_u64 << ", #dims = " << ndims_u64
<< ", size = " << npts_u64 * ndims_u64 * sizeof(T) << "B" << std::endl;
reader.seekg(start_id * ndims_u64 * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg);
T *data_T = new T[npts_u64 * ndims_u64];
reader.read((char *)data_T, sizeof(T) * npts_u64 * ndims_u64);
std::cout << "Finished reading part of the bin file." << std::endl;
reader.close();
data = aligned_malloc<float>(npts_u64 * ndims_u64, ALIGNMENT);
#pragma omp parallel for schedule(dynamic, 32768)
for (int64_t i = 0; i < (int64_t)npts_u64; i++)
{
for (int64_t j = 0; j < (int64_t)ndims_u64; j++)
{
float cur_val_float = (float)data_T[i * ndims_u64 + j];
std::memcpy((char *)(data + i * ndims_u64 + j), (char *)&cur_val_float, sizeof(float));
}
}
delete[] data_T;
std::cout << "Finished converting part data to float." << std::endl;
}
template <typename T>
inline std::vector<size_t> load_filtered_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims,
int part_num, const char *label_file,
const std::string &filter_label,
const std::string &universal_label, size_t &npoints_filt,
std::vector<std::vector<std::string>> &pts_to_labels)
{
std::ifstream reader(filename, std::ios::binary);
if (reader.fail())
{
throw diskann::ANNException(std::string("Failed to open file ") + filename, -1);
}
std::cout << "Reading bin file " << filename << " ...\n";
int npts_i32, ndims_i32;
std::vector<size_t> rev_map;
reader.read((char *)&npts_i32, sizeof(int));
reader.read((char *)&ndims_i32, sizeof(int));
uint64_t start_id = part_num * PARTSIZE;
uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32);
npts = end_id - start_id;
ndims = (uint32_t)ndims_i32;
uint64_t nptsuint64_t = (uint64_t)npts;
uint64_t ndimsuint64_t = (uint64_t)ndims;
npoints_filt = 0;
std::cout << "#pts in part = " << npts << ", #dims = " << ndims
<< ", size = " << nptsuint64_t * ndimsuint64_t * sizeof(T) << "B" << std::endl;
std::cout << "start and end ids: " << start_id << ", " << end_id << std::endl;
reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg);
T *data_T = new T[nptsuint64_t * ndimsuint64_t];
reader.read((char *)data_T, sizeof(T) * nptsuint64_t * ndimsuint64_t);
std::cout << "Finished reading part of the bin file." << std::endl;
reader.close();
data = aligned_malloc<float>(nptsuint64_t * ndimsuint64_t, ALIGNMENT);
for (int64_t i = 0; i < (int64_t)nptsuint64_t; i++)
{
if (std::find(pts_to_labels[start_id + i].begin(), pts_to_labels[start_id + i].end(), filter_label) !=
pts_to_labels[start_id + i].end() ||
std::find(pts_to_labels[start_id + i].begin(), pts_to_labels[start_id + i].end(), universal_label) !=
pts_to_labels[start_id + i].end())
{
rev_map.push_back(start_id + i);
for (int64_t j = 0; j < (int64_t)ndimsuint64_t; j++)
{
float cur_val_float = (float)data_T[i * ndimsuint64_t + j];
std::memcpy((char *)(data + npoints_filt * ndimsuint64_t + j), (char *)&cur_val_float, sizeof(float));
}
npoints_filt++;
}
}
delete[] data_T;
std::cout << "Finished converting part data to float.. identified " << npoints_filt
<< " points matching the filter." << std::endl;
return rev_map;
}
template <typename T> inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims)
{
std::ofstream writer;
writer.exceptions(std::ios::failbit | std::ios::badbit);
writer.open(filename, std::ios::binary | std::ios::out);
std::cout << "Writing bin: " << filename << "\n";
int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
writer.write((char *)&npts_i32, sizeof(int));
writer.write((char *)&ndims_i32, sizeof(int));
std::cout << "bin: #pts = " << npts << ", #dims = " << ndims
<< ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B" << std::endl;
writer.write((char *)data, npts * ndims * sizeof(T));
writer.close();
std::cout << "Finished writing bin" << std::endl;
}
inline void save_groundtruth_as_one_file(const std::string filename, int32_t *data, float *distances, size_t npts,
size_t ndims)
{
std::ofstream writer(filename, std::ios::binary | std::ios::out);
int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
writer.write((char *)&npts_i32, sizeof(int));
writer.write((char *)&ndims_i32, sizeof(int));
std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, "
"npts*dim dist-matrix) with npts = "
<< npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int)
<< "B" << std::endl;
writer.write((char *)data, npts * ndims * sizeof(uint32_t));
writer.write((char *)distances, npts * ndims * sizeof(float));
writer.close();
std::cout << "Finished writing truthset" << std::endl;
}
inline void parse_label_file_into_vec(size_t &line_cnt, const std::string &map_file,
std::vector<std::vector<std::string>> &pts_to_labels)
{
std::ifstream infile(map_file);
std::string line, token;
std::set<std::string> labels;
infile.clear();
infile.seekg(0, std::ios::beg);
while (std::getline(infile, line))
{
std::istringstream iss(line);
std::vector<std::string> lbls(0);
getline(iss, token, '\t');
std::istringstream new_iss(token);
while (getline(new_iss, token, ','))
{
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
lbls.push_back(token);
labels.insert(token);
}
std::sort(lbls.begin(), lbls.end());
pts_to_labels.push_back(lbls);
}
std::cout << "Identified " << labels.size() << " distinct label(s), and populated labels for "
<< pts_to_labels.size() << " points" << std::endl;
}
template <typename T>
std::vector<std::vector<std::pair<uint32_t, float>>> processUnfilteredParts(const std::string &base_file,
size_t &nqueries, size_t &npoints,
size_t &dim, size_t &k, float *query_data,
const diskann::Metric &metric,
std::vector<uint32_t> &location_to_tag)
{
float *base_data = nullptr;
int num_parts = get_num_parts<T>(base_file.c_str());
std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
for (int p = 0; p < num_parts; p++)
{
size_t start_id = p * PARTSIZE;
load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);
size_t *closest_points_part = new size_t[nqueries * k];
float *dist_closest_points_part = new float[nqueries * k];
auto part_k = k < npoints ? k : npoints;
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data,
metric);
for (size_t i = 0; i < nqueries; i++)
{
for (uint64_t j = 0; j < part_k; j++)
{
if (!location_to_tag.empty())
if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0)
continue;
res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id),
dist_closest_points_part[i * part_k + j]));
}
}
delete[] closest_points_part;
delete[] dist_closest_points_part;
diskann::aligned_free(base_data);
}
return res;
};
template <typename T>
std::vector<std::vector<std::pair<uint32_t, float>>> processFilteredParts(
const std::string &base_file, const std::string &label_file, const std::string &filter_label,
const std::string &universal_label, size_t &nqueries, size_t &npoints, size_t &dim, size_t &k, float *query_data,
const diskann::Metric &metric, std::vector<uint32_t> &location_to_tag)
{
size_t npoints_filt = 0;
float *base_data = nullptr;
std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
int num_parts = get_num_parts<T>(base_file.c_str());
std::vector<std::vector<std::string>> pts_to_labels;
if (filter_label != "")
parse_label_file_into_vec(npoints, label_file, pts_to_labels);
for (int p = 0; p < num_parts; p++)
{
size_t start_id = p * PARTSIZE;
std::vector<size_t> rev_map;
if (filter_label != "")
rev_map = load_filtered_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p, label_file.c_str(),
filter_label, universal_label, npoints_filt, pts_to_labels);
size_t *closest_points_part = new size_t[nqueries * k];
float *dist_closest_points_part = new float[nqueries * k];
auto part_k = k < npoints_filt ? k : npoints_filt;
if (npoints_filt > 0)
{
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints_filt, base_data, nqueries,
query_data, metric);
}
for (size_t i = 0; i < nqueries; i++)
{
for (uint64_t j = 0; j < part_k; j++)
{
if (!location_to_tag.empty())
if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0)
continue;
res[i].push_back(std::make_pair((uint32_t)(rev_map[closest_points_part[i * part_k + j]]),
dist_closest_points_part[i * part_k + j]));
}
}
delete[] closest_points_part;
delete[] dist_closest_points_part;
diskann::aligned_free(base_data);
}
return res;
};
template <typename T>
int aux_main(const std::string &base_file, const std::string &label_file, const std::string &query_file,
const std::string &gt_file, size_t k, const std::string &universal_label, const diskann::Metric &metric,
const std::string &filter_label, const std::string &tags_file = std::string(""))
{
size_t npoints, nqueries, dim;
float *query_data = nullptr;
load_bin_as_float<T>(query_file.c_str(), query_data, nqueries, dim, 0);
if (nqueries > PARTSIZE)
std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE
<< ". Computing GT only for the first " << PARTSIZE << " queries." << std::endl;
// load tags
const bool tags_enabled = tags_file.empty() ? false : true;
std::vector<uint32_t> location_to_tag = diskann::loadTags(tags_file, base_file);
int *closest_points = new int[nqueries * k];
float *dist_closest_points = new float[nqueries * k];
std::vector<std::vector<std::pair<uint32_t, float>>> results;
if (filter_label == "")
{
results = processUnfilteredParts<T>(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag);
}
else
{
results = processFilteredParts<T>(base_file, label_file, filter_label, universal_label, nqueries, npoints, dim,
k, query_data, metric, location_to_tag);
}
for (size_t i = 0; i < nqueries; i++)
{
std::vector<std::pair<uint32_t, float>> &cur_res = results[i];
std::sort(cur_res.begin(), cur_res.end(), custom_dist);
size_t j = 0;
for (auto iter : cur_res)
{
if (j == k)
break;
if (tags_enabled)
{
std::uint32_t index_with_tag = location_to_tag[iter.first];
closest_points[i * k + j] = (int32_t)index_with_tag;
}
else
{
closest_points[i * k + j] = (int32_t)iter.first;
}
if (metric == diskann::Metric::INNER_PRODUCT)
dist_closest_points[i * k + j] = -iter.second;
else
dist_closest_points[i * k + j] = iter.second;
++j;
}
if (j < k)
std::cout << "WARNING: found less than k GT entries for query " << i << std::endl;
}
save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k);
delete[] closest_points;
delete[] dist_closest_points;
diskann::aligned_free(query_data);
return 0;
}
void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim)
{
size_t read_blk_size = 64 * 1024 * 1024;
cached_ifstream reader(bin_file, read_blk_size);
diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl;
size_t actual_file_size = reader.get_file_size();
int npts_i32, dim_i32;
reader.read((char *)&npts_i32, sizeof(int));
reader.read((char *)&dim_i32, sizeof(int));
npts = (uint32_t)npts_i32;
dim = (uint32_t)dim_i32;
diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl;
int truthset_type = -1; // 1 means truthset has ids and distances, 2 means
// only ids, -1 is error
size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
if (actual_file_size == expected_file_size_with_dists)
truthset_type = 1;
size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
if (actual_file_size == expected_file_size_just_ids)
truthset_type = 2;
if (truthset_type == -1)
{
std::stringstream stream;
stream << "Error. File size mismatch. File should have bin format, with "
"npts followed by ngt followed by npts*ngt ids and optionally "
"followed by npts*ngt distance values; actual size: "
<< actual_file_size << ", expected: " << expected_file_size_with_dists << " or "
<< expected_file_size_just_ids;
diskann::cout << stream.str();
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}
ids = new uint32_t[npts * dim];
reader.read((char *)ids, npts * dim * sizeof(uint32_t));
if (truthset_type == 1)
{
dists = new float[npts * dim];
reader.read((char *)dists, npts * dim * sizeof(float));
}
}
int main(int argc, char **argv)
{
std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file, label_file, filter_label,
universal_label, filter_label_file;
uint64_t K;
try
{
po::options_description desc{"Arguments"};
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(), "distance function <l2/mips>");
desc.add_options()("base_file", po::value<std::string>(&base_file)->required(),
"File containing the base vectors in binary format");
desc.add_options()("query_file", po::value<std::string>(&query_file)->required(),
"File containing the query vectors in binary format");
desc.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
"Input labels file in txt format if present");
desc.add_options()("filter_label", po::value<std::string>(&filter_label)->default_value(""),
"Input filter label if doing filtered groundtruth");
desc.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
"Universal label, if using it, only in conjunction with label_file");
desc.add_options()("gt_file", po::value<std::string>(&gt_file)->required(),
"File name for the writing ground truth in binary "
"format, please don' append .bin at end if "
"no filter_label or filter_label_file is provided it "
"will save the file with '.bin' at end."
"else it will save the file as filename_label.bin");
desc.add_options()("K", po::value<uint64_t>(&K)->required(),
"Number of ground truth nearest neighbors to compute");
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
"File containing the tags in binary format");
desc.add_options()("filter_label_file",
po::value<std::string>(&filter_label_file)->default_value(std::string("")),
"Filter file for Queries for Filtered Search ");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8"))
{
std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl;
return -1;
}
if (filter_label != "" && filter_label_file != "")
{
std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl;
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("l2"))
{
metric = diskann::Metric::L2;
}
else if (dist_fn == std::string("mips"))
{
metric = diskann::Metric::INNER_PRODUCT;
}
else if (dist_fn == std::string("cosine"))
{
metric = diskann::Metric::COSINE;
}
else
{
std::cerr << "Unsupported distance function. Use l2/mips/cosine." << std::endl;
return -1;
}
std::vector<std::string> filter_labels;
if (filter_label != "")
{
filter_labels.push_back(filter_label);
}
else if (filter_label_file != "")
{
filter_labels = read_file_to_vector_of_strings(filter_label_file, false);
}
// only when there is no filter label or 1 filter label for all queries
if (filter_labels.size() == 1)
{
try
{
if (data_type == std::string("float"))
aux_main<float>(base_file, label_file, query_file, gt_file, K, universal_label, metric,
filter_labels[0], tags_file);
if (data_type == std::string("int8"))
aux_main<int8_t>(base_file, label_file, query_file, gt_file, K, universal_label, metric,
filter_labels[0], tags_file);
if (data_type == std::string("uint8"))
aux_main<uint8_t>(base_file, label_file, query_file, gt_file, K, universal_label, metric,
filter_labels[0], tags_file);
}
catch (const std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Compute GT failed." << std::endl;
return -1;
}
}
else
{ // Each query has its own filter label
// Split up data and query bins into label specific ones
tsl::robin_map<std::string, uint32_t> labels_to_number_of_points;
tsl::robin_map<std::string, uint32_t> labels_to_number_of_queries;
label_set all_labels;
for (size_t i = 0; i < filter_labels.size(); i++)
{
std::string label = filter_labels[i];
all_labels.insert(label);
if (labels_to_number_of_queries.find(label) == labels_to_number_of_queries.end())
{
labels_to_number_of_queries[label] = 0;
}
labels_to_number_of_queries[label] += 1;
}
size_t npoints;
std::vector<std::vector<std::string>> point_to_labels;
parse_label_file_into_vec(npoints, label_file, point_to_labels);
std::vector<label_set> point_ids_to_labels(point_to_labels.size());
std::vector<label_set> query_ids_to_labels(filter_labels.size());
for (size_t i = 0; i < point_to_labels.size(); i++)
{
for (size_t j = 0; j < point_to_labels[i].size(); j++)
{
std::string label = point_to_labels[i][j];
if (all_labels.find(label) != all_labels.end())
{
point_ids_to_labels[i].insert(point_to_labels[i][j]);
if (labels_to_number_of_points.find(label) == labels_to_number_of_points.end())
{
labels_to_number_of_points[label] = 0;
}
labels_to_number_of_points[label] += 1;
}
}
}
for (size_t i = 0; i < filter_labels.size(); i++)
{
query_ids_to_labels[i].insert(filter_labels[i]);
}
tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id;
tsl::robin_map<std::string, std::vector<uint32_t>> label_query_id_to_orig_id;
if (data_type == std::string("float"))
{
label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<float>(
base_file, labels_to_number_of_points, point_ids_to_labels, all_labels);
label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<float>(
query_file, labels_to_number_of_queries, query_ids_to_labels,
all_labels); // query_filters acts like query_ids_to_labels
}
else if (data_type == std::string("int8"))
{
label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<int8_t>(
base_file, labels_to_number_of_points, point_ids_to_labels, all_labels);
label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<int8_t>(
query_file, labels_to_number_of_queries, query_ids_to_labels,
all_labels); // query_filters acts like query_ids_to_labels
}
else if (data_type == std::string("uint8"))
{
label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<uint8_t>(
base_file, labels_to_number_of_points, point_ids_to_labels, all_labels);
label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<uint8_t>(
query_file, labels_to_number_of_queries, query_ids_to_labels,
all_labels); // query_filters acts like query_ids_to_labels
}
else
{
diskann::cerr << "Invalid data type" << std::endl;
return -1;
}
// Generate label specific ground truths
try
{
for (const auto &label : all_labels)
{
std::string filtered_base_file = base_file + "_" + label;
std::string filtered_query_file = query_file + "_" + label;
std::string filtered_gt_file = gt_file + "_" + label;
if (data_type == std::string("float"))
aux_main<float>(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, "");
if (data_type == std::string("int8"))
aux_main<int8_t>(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, "");
if (data_type == std::string("uint8"))
aux_main<uint8_t>(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, "");
}
}
catch (const std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Compute GT failed." << std::endl;
return -1;
}
// Combine the label specific ground truths to produce a single GT file
uint32_t *gt_ids = nullptr;
float *gt_dists = nullptr;
size_t gt_num, gt_dim;
std::vector<std::vector<int32_t>> final_gt_ids;
std::vector<std::vector<float>> final_gt_dists;
uint32_t query_num = 0;
for (const auto &lbl : all_labels)
{
query_num += labels_to_number_of_queries[lbl];
}
for (uint32_t i = 0; i < query_num; i++)
{
final_gt_ids.push_back(std::vector<int32_t>(K));
final_gt_dists.push_back(std::vector<float>(K));
}
for (const auto &lbl : all_labels)
{
std::string filtered_gt_file = gt_file + "_" + lbl;
load_truthset(filtered_gt_file, gt_ids, gt_dists, gt_num, gt_dim);
for (uint32_t i = 0; i < labels_to_number_of_queries[lbl]; i++)
{
uint32_t orig_query_id = label_query_id_to_orig_id[lbl][i];
for (uint64_t j = 0; j < K; j++)
{
final_gt_ids[orig_query_id][j] = label_id_to_orig_id[lbl][gt_ids[i * K + j]];
final_gt_dists[orig_query_id][j] = gt_dists[i * K + j];
}
}
}
int32_t *closest_points = new int32_t[query_num * K];
float *dist_closest_points = new float[query_num * K];
for (uint32_t i = 0; i < query_num; i++)
{
for (uint32_t j = 0; j < K; j++)
{
closest_points[i * K + j] = final_gt_ids[i][j];
dist_closest_points[i * K + j] = final_gt_dists[i][j];
}
}
save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, query_num, K);
// cleanup artifacts
std::cout << "Cleaning up artifacts..." << std::endl;
tsl::robin_set<std::string> paths_to_clean{gt_file, base_file, query_file};
clean_up_artifacts(paths_to_clean, all_labels);
}
}

View File

@@ -0,0 +1,82 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <cstring>
#include <iomanip>
#include <algorithm>
#include <numeric>
#include <omp.h>
#include <set>
#include <string.h>
#include <boost/program_options.hpp>
#ifndef _WINDOWS
#include <sys/mman.h>
#include <sys/stat.h>
#include <time.h>
#include <unistd.h>
#endif
#include "utils.h"
#include "index.h"
#include "memory_mapper.h"
namespace po = boost::program_options;
template <typename T> void bfs_count(const std::string &index_path, uint32_t data_dims)
{
using TagT = uint32_t;
using LabelT = uint32_t;
diskann::Index<T, TagT, LabelT> index(diskann::Metric::L2, data_dims, 0, nullptr, nullptr, 0, false, false, false,
false, 0, false);
std::cout << "Index class instantiated" << std::endl;
index.load(index_path.c_str(), 1, 100);
std::cout << "Index loaded" << std::endl;
index.count_nodes_at_bfs_levels();
}
int main(int argc, char **argv)
{
std::string data_type, index_path_prefix;
uint32_t data_dims;
po::options_description desc{"Arguments"};
try
{
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
"Path prefix to the index");
desc.add_options()("data_dims", po::value<uint32_t>(&data_dims)->required(), "Dimensionality of the data");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
try
{
if (data_type == std::string("int8"))
bfs_count<int8_t>(index_path_prefix, data_dims);
else if (data_type == std::string("uint8"))
bfs_count<uint8_t>(index_path_prefix, data_dims);
if (data_type == std::string("float"))
bfs_count<float>(index_path_prefix, data_dims);
}
catch (std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Index BFS failed." << std::endl;
return -1;
}
}

View File

@@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <cmath>
#include <cstring>
#include <fstream>
#include <iostream>
#include <limits>
#include <vector>
#include "utils.h"
#include "disk_utils.h"
#include "cached_io.h"
template <typename T> int create_disk_layout(char **argv)
{
std::string base_file(argv[2]);
std::string vamana_file(argv[3]);
std::string output_file(argv[4]);
diskann::create_disk_layout<T>(base_file, vamana_file, output_file);
return 0;
}
int main(int argc, char **argv)
{
if (argc != 5)
{
std::cout << argv[0]
<< " data_type <float/int8/uint8> data_bin "
"vamana_index_file output_diskann_index_file"
<< std::endl;
exit(-1);
}
int ret_val = -1;
if (std::string(argv[1]) == std::string("float"))
ret_val = create_disk_layout<float>(argv);
else if (std::string(argv[1]) == std::string("int8"))
ret_val = create_disk_layout<int8_t>(argv);
else if (std::string(argv[1]) == std::string("uint8"))
ret_val = create_disk_layout<uint8_t>(argv);
else
{
std::cout << "unsupported type. use int8/uint8/float " << std::endl;
ret_val = -2;
}
return ret_val;
}

View File

@@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include "utils.h"
void block_convert(std::ofstream &writer, int8_t *write_buf, std::ifstream &reader, float *read_buf, size_t npts,
size_t ndims, float bias, float scale)
{
reader.read((char *)read_buf, npts * ndims * sizeof(float));
for (size_t i = 0; i < npts; i++)
{
for (size_t d = 0; d < ndims; d++)
{
write_buf[d + i * ndims] = (int8_t)((read_buf[d + i * ndims] - bias) * (254.0 / scale));
}
}
writer.write((char *)write_buf, npts * ndims);
}
int main(int argc, char **argv)
{
if (argc != 5)
{
std::cout << "Usage: " << argv[0] << " input_bin output_tsv bias scale" << std::endl;
exit(-1);
}
std::ifstream reader(argv[1], std::ios::binary);
uint32_t npts_u32;
uint32_t ndims_u32;
reader.read((char *)&npts_u32, sizeof(uint32_t));
reader.read((char *)&ndims_u32, sizeof(uint32_t));
size_t npts = npts_u32;
size_t ndims = ndims_u32;
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
size_t blk_size = 131072;
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
std::ofstream writer(argv[2], std::ios::binary);
auto read_buf = new float[blk_size * ndims];
auto write_buf = new int8_t[blk_size * ndims];
float bias = (float)atof(argv[3]);
float scale = (float)atof(argv[4]);
writer.write((char *)(&npts_u32), sizeof(uint32_t));
writer.write((char *)(&ndims_u32), sizeof(uint32_t));
for (size_t i = 0; i < nblks; i++)
{
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
block_convert(writer, write_buf, reader, read_buf, cblk_size, ndims, bias, scale);
std::cout << "Block #" << i << " written" << std::endl;
}
delete[] read_buf;
delete[] write_buf;
writer.close();
reader.close();
}

View File

@@ -0,0 +1,95 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include "utils.h"
// Convert float types
void block_convert_float(std::ifstream &reader, std::ofstream &writer, float *read_buf, float *write_buf, size_t npts,
size_t ndims)
{
reader.read((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(uint32_t)));
for (size_t i = 0; i < npts; i++)
{
memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(float));
}
writer.write((char *)write_buf, npts * ndims * sizeof(float));
}
// Convert byte types
void block_convert_byte(std::ifstream &reader, std::ofstream &writer, uint8_t *read_buf, uint8_t *write_buf,
size_t npts, size_t ndims)
{
reader.read((char *)read_buf, npts * (ndims * sizeof(uint8_t) + sizeof(uint32_t)));
for (size_t i = 0; i < npts; i++)
{
memcpy(write_buf + i * ndims, (read_buf + i * (ndims + sizeof(uint32_t))) + sizeof(uint32_t),
ndims * sizeof(uint8_t));
}
writer.write((char *)write_buf, npts * ndims * sizeof(uint8_t));
}
int main(int argc, char **argv)
{
if (argc != 4)
{
std::cout << argv[0] << " <float/int8/uint8> input_vecs output_bin" << std::endl;
exit(-1);
}
int datasize = sizeof(float);
if (strcmp(argv[1], "uint8") == 0 || strcmp(argv[1], "int8") == 0)
{
datasize = sizeof(uint8_t);
}
else if (strcmp(argv[1], "float") != 0)
{
std::cout << "Error: type not supported. Use float/int8/uint8" << std::endl;
exit(-1);
}
std::ifstream reader(argv[2], std::ios::binary | std::ios::ate);
size_t fsize = reader.tellg();
reader.seekg(0, std::ios::beg);
uint32_t ndims_u32;
reader.read((char *)&ndims_u32, sizeof(uint32_t));
reader.seekg(0, std::ios::beg);
size_t ndims = (size_t)ndims_u32;
size_t npts = fsize / ((ndims * datasize) + sizeof(uint32_t));
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
size_t blk_size = 131072;
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
std::cout << "# blks: " << nblks << std::endl;
std::ofstream writer(argv[3], std::ios::binary);
int32_t npts_s32 = (int32_t)npts;
int32_t ndims_s32 = (int32_t)ndims;
writer.write((char *)&npts_s32, sizeof(int32_t));
writer.write((char *)&ndims_s32, sizeof(int32_t));
size_t chunknpts = std::min(npts, blk_size);
uint8_t *read_buf = new uint8_t[chunknpts * ((ndims * datasize) + sizeof(uint32_t))];
uint8_t *write_buf = new uint8_t[chunknpts * ndims * datasize];
for (size_t i = 0; i < nblks; i++)
{
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
if (datasize == sizeof(float))
{
block_convert_float(reader, writer, (float *)read_buf, (float *)write_buf, cblk_size, ndims);
}
else
{
block_convert_byte(reader, writer, read_buf, write_buf, cblk_size, ndims);
}
std::cout << "Block #" << i << " written" << std::endl;
}
delete[] read_buf;
delete[] write_buf;
reader.close();
writer.close();
}

View File

@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include "utils.h"
void block_convert(std::ifstream &reader, std::ofstream &writer, float *read_buf, uint8_t *write_buf, size_t npts,
size_t ndims)
{
reader.read((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(uint32_t)));
for (size_t i = 0; i < npts; i++)
{
memcpy(write_buf + i * (ndims + 4), read_buf + i * (ndims + 1), sizeof(uint32_t));
for (size_t d = 0; d < ndims; d++)
write_buf[i * (ndims + 4) + 4 + d] = (uint8_t)read_buf[i * (ndims + 1) + 1 + d];
}
writer.write((char *)write_buf, npts * (ndims * 1 + 4));
}
int main(int argc, char **argv)
{
if (argc != 3)
{
std::cout << argv[0] << " input_fvecs output_bvecs(uint8)" << std::endl;
exit(-1);
}
std::ifstream reader(argv[1], std::ios::binary | std::ios::ate);
size_t fsize = reader.tellg();
reader.seekg(0, std::ios::beg);
uint32_t ndims_u32;
reader.read((char *)&ndims_u32, sizeof(uint32_t));
reader.seekg(0, std::ios::beg);
size_t ndims = (size_t)ndims_u32;
size_t npts = fsize / ((ndims + 1) * sizeof(float));
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
size_t blk_size = 131072;
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
std::cout << "# blks: " << nblks << std::endl;
std::ofstream writer(argv[2], std::ios::binary);
auto read_buf = new float[npts * (ndims + 1)];
auto write_buf = new uint8_t[npts * (ndims + 4)];
for (size_t i = 0; i < nblks; i++)
{
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims);
std::cout << "Block #" << i << " written" << std::endl;
}
delete[] read_buf;
delete[] write_buf;
reader.close();
writer.close();
}

View File

@@ -0,0 +1,58 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <omp.h>
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <ctime>
#include <iostream>
#include <iterator>
#include <map>
#include <sstream>
#include <string>
#include "partition.h"
#include "utils.h"
#include <fcntl.h>
#include <sys/stat.h>
#include <time.h>
#include <typeinfo>
template <typename T> int aux_main(char **argv)
{
std::string base_file(argv[2]);
std::string output_prefix(argv[3]);
float sampling_rate = (float)(std::atof(argv[4]));
gen_random_slice<T>(base_file, output_prefix, sampling_rate);
return 0;
}
int main(int argc, char **argv)
{
if (argc != 5)
{
std::cout << argv[0]
<< " data_type [float/int8/uint8] base_bin_file "
"sample_output_prefix sampling_probability"
<< std::endl;
exit(-1);
}
if (std::string(argv[1]) == std::string("float"))
{
aux_main<float>(argv);
}
else if (std::string(argv[1]) == std::string("int8"))
{
aux_main<int8_t>(argv);
}
else if (std::string(argv[1]) == std::string("uint8"))
{
aux_main<uint8_t>(argv);
}
else
std::cout << "Unsupported type. Use float/int8/uint8." << std::endl;
return 0;
}

View File

@@ -0,0 +1,70 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "math_utils.h"
#include "pq.h"
#include "partition.h"
#define KMEANS_ITERS_FOR_PQ 15
template <typename T>
bool generate_pq(const std::string &data_path, const std::string &index_prefix_path, const size_t num_pq_centers,
const size_t num_pq_chunks, const float sampling_rate, const bool opq)
{
std::string pq_pivots_path = index_prefix_path + "_pq_pivots.bin";
std::string pq_compressed_vectors_path = index_prefix_path + "_pq_compressed.bin";
// generates random sample and sets it to train_data and updates train_size
size_t train_size, train_dim;
float *train_data;
gen_random_slice<T>(data_path, sampling_rate, train_data, train_size, train_dim);
std::cout << "For computing pivots, loaded sample data of size " << train_size << std::endl;
if (opq)
{
diskann::generate_opq_pivots(train_data, train_size, (uint32_t)train_dim, (uint32_t)num_pq_centers,
(uint32_t)num_pq_chunks, pq_pivots_path, true);
}
else
{
diskann::generate_pq_pivots(train_data, train_size, (uint32_t)train_dim, (uint32_t)num_pq_centers,
(uint32_t)num_pq_chunks, KMEANS_ITERS_FOR_PQ, pq_pivots_path);
}
diskann::generate_pq_data_from_pivots<T>(data_path, (uint32_t)num_pq_centers, (uint32_t)num_pq_chunks,
pq_pivots_path, pq_compressed_vectors_path, true);
delete[] train_data;
return 0;
}
int main(int argc, char **argv)
{
if (argc != 7)
{
std::cout << "Usage: \n"
<< argv[0]
<< " <data_type[float/uint8/int8]> <data_file[.bin]>"
" <PQ_prefix_path> <target-bytes/data-point> "
"<sampling_rate> <PQ(0)/OPQ(1)>"
<< std::endl;
}
else
{
const std::string data_path(argv[2]);
const std::string index_prefix_path(argv[3]);
const size_t num_pq_centers = 256;
const size_t num_pq_chunks = (size_t)atoi(argv[4]);
const float sampling_rate = (float)atof(argv[5]);
const bool opq = atoi(argv[6]) == 0 ? false : true;
if (std::string(argv[1]) == std::string("float"))
generate_pq<float>(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq);
else if (std::string(argv[1]) == std::string("int8"))
generate_pq<int8_t>(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq);
else if (std::string(argv[1]) == std::string("uint8"))
generate_pq<uint8_t>(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq);
else
std::cout << "Error. wrong file type" << std::endl;
}
}

View File

@@ -0,0 +1,204 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include <random>
#include <boost/program_options.hpp>
#include <math.h>
#include <cmath>
#include "utils.h"
namespace po = boost::program_options;
class ZipfDistribution
{
public:
ZipfDistribution(uint64_t num_points, uint32_t num_labels)
: num_labels(num_labels), num_points(num_points),
uniform_zero_to_one(std::uniform_real_distribution<>(0.0, 1.0))
{
}
std::unordered_map<uint32_t, uint32_t> createDistributionMap()
{
std::unordered_map<uint32_t, uint32_t> map;
uint32_t primary_label_freq = (uint32_t)ceil(num_points * distribution_factor);
for (uint32_t i{1}; i < num_labels + 1; i++)
{
map[i] = (uint32_t)ceil(primary_label_freq / i);
}
return map;
}
int writeDistribution(std::ofstream &outfile)
{
auto distribution_map = createDistributionMap();
for (uint32_t i{0}; i < num_points; i++)
{
bool label_written = false;
for (auto it = distribution_map.cbegin(); it != distribution_map.cend(); it++)
{
auto label_selection_probability = std::bernoulli_distribution(distribution_factor / (double)it->first);
if (label_selection_probability(rand_engine) && distribution_map[it->first] > 0)
{
if (label_written)
{
outfile << ',';
}
outfile << it->first;
label_written = true;
// remove label from map if we have used all labels
distribution_map[it->first] -= 1;
}
}
if (!label_written)
{
outfile << 0;
}
if (i < num_points - 1)
{
outfile << '\n';
}
}
return 0;
}
int writeDistribution(std::string filename)
{
std::ofstream outfile(filename);
if (!outfile.is_open())
{
std::cerr << "Error: could not open output file " << filename << '\n';
return -1;
}
writeDistribution(outfile);
outfile.close();
}
private:
const uint32_t num_labels;
const uint64_t num_points;
const double distribution_factor = 0.7;
std::knuth_b rand_engine;
const std::uniform_real_distribution<double> uniform_zero_to_one;
};
int main(int argc, char **argv)
{
std::string output_file, distribution_type;
uint32_t num_labels;
uint64_t num_points;
try
{
po::options_description desc{"Arguments"};
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("output_file,O", po::value<std::string>(&output_file)->required(),
"Filename for saving the label file");
desc.add_options()("num_points,N", po::value<uint64_t>(&num_points)->required(), "Number of points in dataset");
desc.add_options()("num_labels,L", po::value<uint32_t>(&num_labels)->required(),
"Number of unique labels, up to 5000");
desc.add_options()("distribution_type,DT", po::value<std::string>(&distribution_type)->default_value("random"),
"Distribution function for labels <random/zipf/one_per_point> defaults "
"to random");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
if (num_labels > 5000)
{
std::cerr << "Error: num_labels must be 5000 or less" << '\n';
return -1;
}
if (num_points <= 0)
{
std::cerr << "Error: num_points must be greater than 0" << '\n';
return -1;
}
std::cout << "Generating synthetic labels for " << num_points << " points with " << num_labels << " unique labels"
<< '\n';
try
{
std::ofstream outfile(output_file);
if (!outfile.is_open())
{
std::cerr << "Error: could not open output file " << output_file << '\n';
return -1;
}
if (distribution_type == "zipf")
{
ZipfDistribution zipf(num_points, num_labels);
zipf.writeDistribution(outfile);
}
else if (distribution_type == "random")
{
for (size_t i = 0; i < num_points; i++)
{
bool label_written = false;
for (size_t j = 1; j <= num_labels; j++)
{
// 50% chance to assign each label
if (rand() > (RAND_MAX / 2))
{
if (label_written)
{
outfile << ',';
}
outfile << j;
label_written = true;
}
}
if (!label_written)
{
outfile << 0;
}
if (i < num_points - 1)
{
outfile << '\n';
}
}
}
else if (distribution_type == "one_per_point")
{
std::random_device rd; // obtain a random number from hardware
std::mt19937 gen(rd()); // seed the generator
std::uniform_int_distribution<> distr(0, num_labels); // define the range
for (size_t i = 0; i < num_points; i++)
{
outfile << distr(gen);
if (i != num_points - 1)
outfile << '\n';
}
}
if (outfile.is_open())
{
outfile.close();
}
std::cout << "Labels written to " << output_file << '\n';
}
catch (const std::exception &ex)
{
std::cerr << "Label generation failed: " << ex.what() << '\n';
return -1;
}
return 0;
}

View File

@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include "utils.h"
int main(int argc, char **argv)
{
if (argc != 3)
{
std::cout << argv[0] << " input_int8_bin output_float_bin" << std::endl;
exit(-1);
}
int8_t *input;
size_t npts, nd;
diskann::load_bin<int8_t>(argv[1], input, npts, nd);
float *output = new float[npts * nd];
diskann::convert_types<int8_t, float>(input, output, npts, nd);
diskann::save_bin<float>(argv[2], output, npts, nd);
delete[] output;
delete[] input;
}

View File

@@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include "utils.h"
void block_convert(std::ofstream &writer, float *write_buf, std::ifstream &reader, int8_t *read_buf, size_t npts,
size_t ndims, float bias, float scale)
{
reader.read((char *)read_buf, npts * ndims * sizeof(int8_t));
for (size_t i = 0; i < npts; i++)
{
for (size_t d = 0; d < ndims; d++)
{
write_buf[d + i * ndims] = (((float)read_buf[d + i * ndims] - bias) * scale);
}
}
writer.write((char *)write_buf, npts * ndims * sizeof(float));
}
int main(int argc, char **argv)
{
if (argc != 5)
{
std::cout << "Usage: " << argv[0] << " input-int8.bin output-float.bin bias scale" << std::endl;
exit(-1);
}
std::ifstream reader(argv[1], std::ios::binary);
uint32_t npts_u32;
uint32_t ndims_u32;
reader.read((char *)&npts_u32, sizeof(uint32_t));
reader.read((char *)&ndims_u32, sizeof(uint32_t));
size_t npts = npts_u32;
size_t ndims = ndims_u32;
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
size_t blk_size = 131072;
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
std::ofstream writer(argv[2], std::ios::binary);
auto read_buf = new int8_t[blk_size * ndims];
auto write_buf = new float[blk_size * ndims];
float bias = (float)atof(argv[3]);
float scale = (float)atof(argv[4]);
writer.write((char *)(&npts_u32), sizeof(uint32_t));
writer.write((char *)(&ndims_u32), sizeof(uint32_t));
for (size_t i = 0; i < nblks; i++)
{
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
block_convert(writer, write_buf, reader, read_buf, cblk_size, ndims, bias, scale);
std::cout << "Block #" << i << " written" << std::endl;
}
delete[] read_buf;
delete[] write_buf;
writer.close();
reader.close();
}

View File

@@ -0,0 +1,58 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include "utils.h"
void block_convert(std::ifstream &reader, std::ofstream &writer, uint32_t *read_buf, uint32_t *write_buf, size_t npts,
size_t ndims)
{
reader.read((char *)read_buf, npts * (ndims * sizeof(uint32_t) + sizeof(uint32_t)));
for (size_t i = 0; i < npts; i++)
{
memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(uint32_t));
}
writer.write((char *)write_buf, npts * ndims * sizeof(uint32_t));
}
int main(int argc, char **argv)
{
if (argc != 3)
{
std::cout << argv[0] << " input_ivecs output_bin" << std::endl;
exit(-1);
}
std::ifstream reader(argv[1], std::ios::binary | std::ios::ate);
size_t fsize = reader.tellg();
reader.seekg(0, std::ios::beg);
uint32_t ndims_u32;
reader.read((char *)&ndims_u32, sizeof(uint32_t));
reader.seekg(0, std::ios::beg);
size_t ndims = (size_t)ndims_u32;
size_t npts = fsize / ((ndims + 1) * sizeof(uint32_t));
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
size_t blk_size = 131072;
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
std::cout << "# blks: " << nblks << std::endl;
std::ofstream writer(argv[2], std::ios::binary);
int npts_s32 = (int)npts;
int ndims_s32 = (int)ndims;
writer.write((char *)&npts_s32, sizeof(int));
writer.write((char *)&ndims_s32, sizeof(int));
uint32_t *read_buf = new uint32_t[npts * (ndims + 1)];
uint32_t *write_buf = new uint32_t[npts * ndims];
for (size_t i = 0; i < nblks; i++)
{
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims);
std::cout << "Block #" << i << " written" << std::endl;
}
delete[] read_buf;
delete[] write_buf;
reader.close();
writer.close();
}

View File

@@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <algorithm>
#include <atomic>
#include <cassert>
#include <fstream>
#include <iostream>
#include <set>
#include <string>
#include <vector>
#include "disk_utils.h"
#include "cached_io.h"
#include "utils.h"
int main(int argc, char **argv)
{
if (argc != 9)
{
std::cout << argv[0]
<< " vamana_index_prefix[1] vamana_index_suffix[2] "
"idmaps_prefix[3] "
"idmaps_suffix[4] n_shards[5] max_degree[6] "
"output_vamana_path[7] "
"output_medoids_path[8]"
<< std::endl;
exit(-1);
}
std::string vamana_prefix(argv[1]);
std::string vamana_suffix(argv[2]);
std::string idmaps_prefix(argv[3]);
std::string idmaps_suffix(argv[4]);
uint64_t nshards = (uint64_t)std::atoi(argv[5]);
uint32_t max_degree = (uint64_t)std::atoi(argv[6]);
std::string output_index(argv[7]);
std::string output_medoids(argv[8]);
return diskann::merge_shards(vamana_prefix, vamana_suffix, idmaps_prefix, idmaps_suffix, nshards, max_degree,
output_index, output_medoids);
}

View File

@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <index.h>
#include <math_utils.h>
#include "cached_io.h"
#include "partition.h"
// DEPRECATED: NEED TO REPROGRAM
int main(int argc, char **argv)
{
if (argc != 7)
{
std::cout << "Usage:\n"
<< argv[0]
<< " datatype<int8/uint8/float> <data_path>"
" <prefix_path> <sampling_rate> "
" <num_partitions> <k_index>"
<< std::endl;
exit(-1);
}
const std::string data_path(argv[2]);
const std::string prefix_path(argv[3]);
const float sampling_rate = (float)atof(argv[4]);
const size_t num_partitions = (size_t)std::atoi(argv[5]);
const size_t max_reps = 15;
const size_t k_index = (size_t)std::atoi(argv[6]);
if (std::string(argv[1]) == std::string("float"))
partition<float>(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index);
else if (std::string(argv[1]) == std::string("int8"))
partition<int8_t>(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index);
else if (std::string(argv[1]) == std::string("uint8"))
partition<uint8_t>(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index);
else
std::cout << "unsupported data format. use float/int8/uint8" << std::endl;
}

View File

@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <index.h>
#include <math_utils.h>
#include "cached_io.h"
#include "partition.h"
// DEPRECATED: NEED TO REPROGRAM
int main(int argc, char **argv)
{
if (argc != 8)
{
std::cout << "Usage:\n"
<< argv[0]
<< " datatype<int8/uint8/float> <data_path>"
" <prefix_path> <sampling_rate> "
" <ram_budget(GB)> <graph_degree> <k_index>"
<< std::endl;
exit(-1);
}
const std::string data_path(argv[2]);
const std::string prefix_path(argv[3]);
const float sampling_rate = (float)atof(argv[4]);
const double ram_budget = (double)std::atof(argv[5]);
const size_t graph_degree = (size_t)std::atoi(argv[6]);
const size_t k_index = (size_t)std::atoi(argv[7]);
if (std::string(argv[1]) == std::string("float"))
partition_with_ram_budget<float>(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index);
else if (std::string(argv[1]) == std::string("int8"))
partition_with_ram_budget<int8_t>(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index);
else if (std::string(argv[1]) == std::string("uint8"))
partition_with_ram_budget<uint8_t>(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index);
else
std::cout << "unsupported data format. use float/int8/uint8" << std::endl;
}

View File

@@ -0,0 +1,237 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include <cstdlib>
#include <random>
#include <cmath>
#include <boost/program_options.hpp>
#include "utils.h"
namespace po = boost::program_options;
int block_write_float(std::ofstream &writer, size_t ndims, size_t npts, bool normalization, float norm,
float rand_scale)
{
auto vec = new float[ndims];
std::random_device rd{};
std::mt19937 gen{rd()};
std::normal_distribution<> normal_rand{0, 1};
std::uniform_real_distribution<> unif_dis(1.0, rand_scale);
for (size_t i = 0; i < npts; i++)
{
float sum = 0;
float scale = 1.0f;
if (rand_scale > 1.0f)
scale = (float)unif_dis(gen);
for (size_t d = 0; d < ndims; ++d)
vec[d] = scale * (float)normal_rand(gen);
if (normalization)
{
for (size_t d = 0; d < ndims; ++d)
sum += vec[d] * vec[d];
for (size_t d = 0; d < ndims; ++d)
vec[d] = vec[d] * norm / std::sqrt(sum);
}
writer.write((char *)vec, ndims * sizeof(float));
}
delete[] vec;
return 0;
}
int block_write_int8(std::ofstream &writer, size_t ndims, size_t npts, float norm)
{
auto vec = new float[ndims];
auto vec_T = new int8_t[ndims];
std::random_device rd{};
std::mt19937 gen{rd()};
std::normal_distribution<> normal_rand{0, 1};
for (size_t i = 0; i < npts; i++)
{
float sum = 0;
for (size_t d = 0; d < ndims; ++d)
vec[d] = (float)normal_rand(gen);
for (size_t d = 0; d < ndims; ++d)
sum += vec[d] * vec[d];
for (size_t d = 0; d < ndims; ++d)
vec[d] = vec[d] * norm / std::sqrt(sum);
for (size_t d = 0; d < ndims; ++d)
{
vec_T[d] = (int8_t)std::round(vec[d]);
}
writer.write((char *)vec_T, ndims * sizeof(int8_t));
}
delete[] vec;
delete[] vec_T;
return 0;
}
int block_write_uint8(std::ofstream &writer, size_t ndims, size_t npts, float norm)
{
auto vec = new float[ndims];
auto vec_T = new int8_t[ndims];
std::random_device rd{};
std::mt19937 gen{rd()};
std::normal_distribution<> normal_rand{0, 1};
for (size_t i = 0; i < npts; i++)
{
float sum = 0;
for (size_t d = 0; d < ndims; ++d)
vec[d] = (float)normal_rand(gen);
for (size_t d = 0; d < ndims; ++d)
sum += vec[d] * vec[d];
for (size_t d = 0; d < ndims; ++d)
vec[d] = vec[d] * norm / std::sqrt(sum);
for (size_t d = 0; d < ndims; ++d)
{
vec_T[d] = 128 + (int8_t)std::round(vec[d]);
}
writer.write((char *)vec_T, ndims * sizeof(uint8_t));
}
delete[] vec;
delete[] vec_T;
return 0;
}
int main(int argc, char **argv)
{
std::string data_type, output_file;
size_t ndims, npts;
float norm, rand_scaling;
bool normalization = false;
try
{
po::options_description desc{"Arguments"};
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("output_file", po::value<std::string>(&output_file)->required(),
"File name for saving the random vectors");
desc.add_options()("ndims,D", po::value<uint64_t>(&ndims)->required(), "Dimensoinality of the vector");
desc.add_options()("npts,N", po::value<uint64_t>(&npts)->required(), "Number of vectors");
desc.add_options()("norm", po::value<float>(&norm)->default_value(-1.0f),
"Norm of the vectors (if not specified, vectors are not normalized)");
desc.add_options()("rand_scaling", po::value<float>(&rand_scaling)->default_value(1.0f),
"Each vector will be scaled (if not explicitly normalized) by a factor randomly chosen from "
"[1, rand_scale]. Only applicable for floating point data");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8"))
{
std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl;
return -1;
}
if (norm > 0.0)
{
normalization = true;
}
if (rand_scaling < 1.0)
{
std::cout << "We will only scale the vector norms randomly in [1, value], so value must be >= 1." << std::endl;
return -1;
}
if ((rand_scaling > 1.0) && (normalization == true))
{
std::cout << "Data cannot be normalized and randomly scaled at same time. Use one or the other." << std::endl;
return -1;
}
if (data_type == std::string("int8") || data_type == std::string("uint8"))
{
if (norm > 127)
{
std::cerr << "Error: for int8/uint8 datatypes, L2 norm can not be "
"greater "
"than 127"
<< std::endl;
return -1;
}
if (rand_scaling > 1.0)
{
std::cout << "Data scaling only supported for floating point data." << std::endl;
return -1;
}
}
try
{
std::ofstream writer;
writer.exceptions(std::ofstream::failbit | std::ofstream::badbit);
writer.open(output_file, std::ios::binary);
auto npts_u32 = (uint32_t)npts;
auto ndims_u32 = (uint32_t)ndims;
writer.write((char *)&npts_u32, sizeof(uint32_t));
writer.write((char *)&ndims_u32, sizeof(uint32_t));
size_t blk_size = 131072;
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
std::cout << "# blks: " << nblks << std::endl;
int ret = 0;
for (size_t i = 0; i < nblks; i++)
{
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
if (data_type == std::string("float"))
{
ret = block_write_float(writer, ndims, cblk_size, normalization, norm, rand_scaling);
}
else if (data_type == std::string("int8"))
{
ret = block_write_int8(writer, ndims, cblk_size, norm);
}
else if (data_type == std::string("uint8"))
{
ret = block_write_uint8(writer, ndims, cblk_size, norm);
}
if (ret == 0)
std::cout << "Block #" << i << " written" << std::endl;
else
{
writer.close();
std::cout << "failed to write" << std::endl;
return -1;
}
}
writer.close();
}
catch (const std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Index build failed." << std::endl;
return -1;
}
return 0;
}

View File

@@ -0,0 +1,85 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include <cstdlib>
#include <random>
#include <cmath>
inline float aggregate_recall(const uint32_t k_aggr, const uint32_t k, const uint32_t npart, uint32_t *count,
const std::vector<float> &recalls)
{
float found = 0;
for (uint32_t i = 0; i < npart; ++i)
{
size_t max_found = std::min(count[i], k);
found += recalls[max_found - 1] * max_found;
}
return found / (float)k_aggr;
}
void simulate(const uint32_t k_aggr, const uint32_t k, const uint32_t npart, const uint32_t nsim,
const std::vector<float> &recalls)
{
std::random_device r;
std::default_random_engine randeng(r());
std::uniform_int_distribution<int> uniform_dist(0, npart - 1);
uint32_t *count = new uint32_t[npart];
double aggr_recall = 0;
for (uint32_t i = 0; i < nsim; ++i)
{
for (uint32_t p = 0; p < npart; ++p)
{
count[p] = 0;
}
for (uint32_t t = 0; t < k_aggr; ++t)
{
count[uniform_dist(randeng)]++;
}
aggr_recall += aggregate_recall(k_aggr, k, npart, count, recalls);
}
std::cout << "Aggregate recall is " << aggr_recall / (double)nsim << std::endl;
delete[] count;
}
int main(int argc, char **argv)
{
if (argc < 6)
{
std::cout << argv[0] << " k_aggregate k_out npart nsim recall@1 recall@2 ... recall@k" << std::endl;
exit(-1);
}
const uint32_t k_aggr = atoi(argv[1]);
const uint32_t k = atoi(argv[2]);
const uint32_t npart = atoi(argv[3]);
const uint32_t nsim = atoi(argv[4]);
std::vector<float> recalls;
for (int ctr = 5; ctr < argc; ctr++)
{
recalls.push_back((float)atof(argv[ctr]));
}
if (recalls.size() != k)
{
std::cerr << "Please input k numbers for recall@1, recall@2 .. recall@k" << std::endl;
}
if (k_aggr > npart * k)
{
std::cerr << "k_aggr must be <= k * npart" << std::endl;
exit(-1);
}
if (nsim <= npart * k_aggr)
{
std::cerr << "Choose nsim > npart*k_aggr" << std::endl;
exit(-1);
}
simulate(k_aggr, k, npart, nsim, recalls);
return 0;
}

View File

@@ -0,0 +1,147 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include <fstream>
#include <string>
#include <sstream>
#include <cstdint>
#include <vector>
#include <unordered_map>
#include <omp.h>
#include <string.h>
#include <atomic>
#include <cstring>
#include <iomanip>
#include <set>
#include <boost/program_options.hpp>
#include "utils.h"
#ifndef _WINDOWS
#include <sys/mman.h>
#include <unistd.h>
#include <sys/stat.h>
#include <time.h>
#else
#include <Windows.h>
#endif
namespace po = boost::program_options;
void stats_analysis(const std::string labels_file, std::string univeral_label, uint32_t density = 10)
{
std::string token, line;
std::ifstream labels_stream(labels_file);
std::unordered_map<std::string, uint32_t> label_counts;
std::string label_with_max_points;
uint32_t max_points = 0;
long long sum = 0;
long long point_cnt = 0;
float avg_labels_per_pt, mean_label_size;
std::vector<uint32_t> labels_per_point;
uint32_t dense_pts = 0;
if (labels_stream.is_open())
{
while (getline(labels_stream, line))
{
point_cnt++;
std::stringstream iss(line);
uint32_t lbl_cnt = 0;
while (getline(iss, token, ','))
{
lbl_cnt++;
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
if (label_counts.find(token) == label_counts.end())
label_counts[token] = 0;
label_counts[token]++;
}
if (lbl_cnt >= density)
{
dense_pts++;
}
labels_per_point.emplace_back(lbl_cnt);
}
}
std::cout << "fraction of dense points with >= " << density
<< " labels = " << (float)dense_pts / (float)labels_per_point.size() << std::endl;
std::sort(labels_per_point.begin(), labels_per_point.end());
std::vector<std::pair<std::string, uint32_t>> label_count_vec;
for (auto it = label_counts.begin(); it != label_counts.end(); it++)
{
auto &lbl = *it;
label_count_vec.emplace_back(std::make_pair(lbl.first, lbl.second));
if (lbl.second > max_points)
{
max_points = lbl.second;
label_with_max_points = lbl.first;
}
sum += lbl.second;
}
sort(label_count_vec.begin(), label_count_vec.end(),
[](const std::pair<std::string, uint32_t> &lhs, const std::pair<std::string, uint32_t> &rhs) {
return lhs.second < rhs.second;
});
for (float p = 0; p < 1; p += 0.05)
{
std::cout << "Percentile " << (100 * p) << "\t" << label_count_vec[(size_t)(p * label_count_vec.size())].first
<< " with count=" << label_count_vec[(size_t)(p * label_count_vec.size())].second << std::endl;
}
std::cout << "Most common label "
<< "\t" << label_count_vec[label_count_vec.size() - 1].first
<< " with count=" << label_count_vec[label_count_vec.size() - 1].second << std::endl;
if (label_count_vec.size() > 1)
std::cout << "Second common label "
<< "\t" << label_count_vec[label_count_vec.size() - 2].first
<< " with count=" << label_count_vec[label_count_vec.size() - 2].second << std::endl;
if (label_count_vec.size() > 2)
std::cout << "Third common label "
<< "\t" << label_count_vec[label_count_vec.size() - 3].first
<< " with count=" << label_count_vec[label_count_vec.size() - 3].second << std::endl;
avg_labels_per_pt = sum / (float)point_cnt;
mean_label_size = sum / (float)label_counts.size();
std::cout << "Total number of points = " << point_cnt << ", number of labels = " << label_counts.size()
<< std::endl;
std::cout << "Average number of labels per point = " << avg_labels_per_pt << std::endl;
std::cout << "Mean label size excluding 0 = " << mean_label_size << std::endl;
std::cout << "Most popular label is " << label_with_max_points << " with " << max_points << " pts" << std::endl;
}
int main(int argc, char **argv)
{
std::string labels_file, universal_label;
uint32_t density;
po::options_description desc{"Arguments"};
try
{
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("labels_file", po::value<std::string>(&labels_file)->required(),
"path to labels data file.");
desc.add_options()("universal_label", po::value<std::string>(&universal_label)->required(),
"Universal label used in labels file.");
desc.add_options()("density", po::value<uint32_t>(&density)->default_value(1),
"Number of labels each point in labels file, defaults to 1");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &e)
{
std::cerr << e.what() << '\n';
return -1;
}
stats_analysis(labels_file, universal_label, density);
}

View File

@@ -0,0 +1,121 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include "utils.h"
void block_convert_float(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims)
{
auto read_buf = new float[npts * (ndims + 1)];
auto cursor = read_buf;
float val;
for (size_t i = 0; i < npts; i++)
{
for (size_t d = 0; d < ndims; ++d)
{
reader >> val;
*cursor = val;
cursor++;
}
}
writer.write((char *)read_buf, npts * ndims * sizeof(float));
delete[] read_buf;
}
void block_convert_int8(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims)
{
auto read_buf = new int8_t[npts * (ndims + 1)];
auto cursor = read_buf;
int val;
for (size_t i = 0; i < npts; i++)
{
for (size_t d = 0; d < ndims; ++d)
{
reader >> val;
*cursor = (int8_t)val;
cursor++;
}
}
writer.write((char *)read_buf, npts * ndims * sizeof(uint8_t));
delete[] read_buf;
}
void block_convert_uint8(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims)
{
auto read_buf = new uint8_t[npts * (ndims + 1)];
auto cursor = read_buf;
int val;
for (size_t i = 0; i < npts; i++)
{
for (size_t d = 0; d < ndims; ++d)
{
reader >> val;
*cursor = (uint8_t)val;
cursor++;
}
}
writer.write((char *)read_buf, npts * ndims * sizeof(uint8_t));
delete[] read_buf;
}
int main(int argc, char **argv)
{
if (argc != 6)
{
std::cout << argv[0]
<< "<float/int8/uint8> input_filename.tsv output_filename.bin "
"dim num_pts>"
<< std::endl;
exit(-1);
}
if (std::string(argv[1]) != std::string("float") && std::string(argv[1]) != std::string("int8") &&
std::string(argv[1]) != std::string("uint8"))
{
std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl;
}
size_t ndims = atoi(argv[4]);
size_t npts = atoi(argv[5]);
std::ifstream reader(argv[2], std::ios::binary | std::ios::ate);
// size_t fsize = reader.tellg();
reader.seekg(0, std::ios::beg);
reader.seekg(0, std::ios::beg);
size_t blk_size = 131072;
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
std::cout << "# blks: " << nblks << std::endl;
std::ofstream writer(argv[3], std::ios::binary);
auto npts_u32 = (uint32_t)npts;
auto ndims_u32 = (uint32_t)ndims;
writer.write((char *)&npts_u32, sizeof(uint32_t));
writer.write((char *)&ndims_u32, sizeof(uint32_t));
for (size_t i = 0; i < nblks; i++)
{
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
if (std::string(argv[1]) == std::string("float"))
{
block_convert_float(reader, writer, cblk_size, ndims);
}
else if (std::string(argv[1]) == std::string("int8"))
{
block_convert_int8(reader, writer, cblk_size, ndims);
}
else if (std::string(argv[1]) == std::string("uint8"))
{
block_convert_uint8(reader, writer, cblk_size, ndims);
}
std::cout << "Block #" << i << " written" << std::endl;
}
reader.close();
writer.close();
}

View File

@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include "utils.h"
int main(int argc, char **argv)
{
if (argc != 3)
{
std::cout << argv[0] << " input_uint32_bin output_int8_bin" << std::endl;
exit(-1);
}
uint32_t *input;
size_t npts, nd;
diskann::load_bin<uint32_t>(argv[1], input, npts, nd);
uint8_t *output = new uint8_t[npts * nd];
diskann::convert_types<uint32_t, uint8_t>(input, output, npts, nd);
diskann::save_bin<uint8_t>(argv[2], output, npts, nd);
delete[] output;
delete[] input;
}

View File

@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include "utils.h"
int main(int argc, char **argv)
{
if (argc != 3)
{
std::cout << argv[0] << " input_uint8_bin output_float_bin" << std::endl;
exit(-1);
}
uint8_t *input;
size_t npts, nd;
diskann::load_bin<uint8_t>(argv[1], input, npts, nd);
float *output = new float[npts * nd];
diskann::convert_types<uint8_t, float>(input, output, npts, nd);
diskann::save_bin<float>(argv[2], output, npts, nd);
delete[] output;
delete[] input;
}

View File

@@ -0,0 +1,163 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <omp.h>
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <ctime>
#include <iostream>
#include <iterator>
#include <map>
#include <sstream>
#include <string>
#include <fcntl.h>
#include <sys/stat.h>
#include <time.h>
#include <typeinfo>
#include "partition.h"
#include "utils.h"
template <typename T> int analyze_norm(std::string base_file)
{
std::cout << "Analyzing data norms" << std::endl;
T *data;
size_t npts, ndims;
diskann::load_bin<T>(base_file, data, npts, ndims);
std::vector<float> norms(npts, 0);
#pragma omp parallel for schedule(dynamic)
for (int64_t i = 0; i < (int64_t)npts; i++)
{
for (size_t d = 0; d < ndims; d++)
norms[i] += data[i * ndims + d] * data[i * ndims + d];
norms[i] = std::sqrt(norms[i]);
}
std::sort(norms.begin(), norms.end());
for (int p = 0; p < 100; p += 5)
std::cout << "percentile " << p << ": " << norms[(uint64_t)(std::floor((p / 100.0) * npts))] << std::endl;
std::cout << "percentile 100"
<< ": " << norms[npts - 1] << std::endl;
delete[] data;
return 0;
}
template <typename T> int normalize_base(std::string base_file, std::string out_file)
{
std::cout << "Normalizing base" << std::endl;
T *data;
size_t npts, ndims;
diskann::load_bin<T>(base_file, data, npts, ndims);
// std::vector<float> norms(npts, 0);
#pragma omp parallel for schedule(dynamic)
for (int64_t i = 0; i < (int64_t)npts; i++)
{
float pt_norm = 0;
for (size_t d = 0; d < ndims; d++)
pt_norm += data[i * ndims + d] * data[i * ndims + d];
pt_norm = std::sqrt(pt_norm);
for (size_t d = 0; d < ndims; d++)
data[i * ndims + d] = static_cast<T>(data[i * ndims + d] / pt_norm);
}
diskann::save_bin<T>(out_file, data, npts, ndims);
delete[] data;
return 0;
}
template <typename T> int augment_base(std::string base_file, std::string out_file, bool prep_base = true)
{
std::cout << "Analyzing data norms" << std::endl;
T *data;
size_t npts, ndims;
diskann::load_bin<T>(base_file, data, npts, ndims);
std::vector<float> norms(npts, 0);
float max_norm = 0;
#pragma omp parallel for schedule(dynamic)
for (int64_t i = 0; i < (int64_t)npts; i++)
{
for (size_t d = 0; d < ndims; d++)
norms[i] += data[i * ndims + d] * data[i * ndims + d];
max_norm = norms[i] > max_norm ? norms[i] : max_norm;
}
// std::sort(norms.begin(), norms.end());
max_norm = std::sqrt(max_norm);
std::cout << "Max norm: " << max_norm << std::endl;
T *new_data;
size_t newdims = ndims + 1;
new_data = new T[npts * newdims];
for (size_t i = 0; i < npts; i++)
{
if (prep_base)
{
for (size_t j = 0; j < ndims; j++)
{
new_data[i * newdims + j] = static_cast<T>(data[i * ndims + j] / max_norm);
}
float diff = 1 - (norms[i] / (max_norm * max_norm));
diff = diff <= 0 ? 0 : std::sqrt(diff);
new_data[i * newdims + ndims] = static_cast<T>(diff);
if (diff <= 0)
{
std::cout << i << " has large max norm, investigate if needed. diff = " << diff << std::endl;
}
}
else
{
for (size_t j = 0; j < ndims; j++)
{
new_data[i * newdims + j] = static_cast<T>(data[i * ndims + j] / std::sqrt(norms[i]));
}
new_data[i * newdims + ndims] = 0;
}
}
diskann::save_bin<T>(out_file, new_data, npts, newdims);
delete[] new_data;
delete[] data;
return 0;
}
template <typename T> int aux_main(char **argv)
{
std::string base_file(argv[2]);
uint32_t option = atoi(argv[3]);
if (option == 1)
analyze_norm<T>(base_file);
else if (option == 2)
augment_base<T>(base_file, std::string(argv[4]), true);
else if (option == 3)
augment_base<T>(base_file, std::string(argv[4]), false);
else if (option == 4)
normalize_base<T>(base_file, std::string(argv[4]));
return 0;
}
int main(int argc, char **argv)
{
if (argc < 4)
{
std::cout << argv[0]
<< " data_type [float/int8/uint8] base_bin_file "
"[option: 1-norm analysis, 2-prep_base_for_mip, "
"3-prep_query_for_mip, 4-normalize-vecs] [out_file for "
"options 2/3/4]"
<< std::endl;
exit(-1);
}
if (std::string(argv[1]) == std::string("float"))
{
aux_main<float>(argv);
}
else if (std::string(argv[1]) == std::string("int8"))
{
aux_main<int8_t>(argv);
}
else if (std::string(argv[1]) == std::string("uint8"))
{
aux_main<uint8_t>(argv);
}
else
std::cout << "Unsupported type. Use float/int8/uint8." << std::endl;
return 0;
}