Initial commit

This commit is contained in:
yichuan520030910320
2025-06-30 09:05:05 +00:00
commit 46f6cc100b
1231 changed files with 278432 additions and 0 deletions

View File

@@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
set(CMAKE_CXX_STANDARD 17)
add_executable(inmem_server inmem_server.cpp)
if(MSVC)
target_link_options(inmem_server PRIVATE /MACHINE:x64)
target_link_libraries(inmem_server debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
target_link_libraries(inmem_server optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
else()
target_link_libraries(inmem_server ${PROJECT_NAME} aio -ltcmalloc -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
endif()
add_executable(ssd_server ssd_server.cpp)
if(MSVC)
target_link_options(ssd_server PRIVATE /MACHINE:x64)
target_link_libraries(ssd_server debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
target_link_libraries(ssd_server optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
else()
target_link_libraries(ssd_server ${PROJECT_NAME} aio -ltcmalloc -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
endif()
add_executable(multiple_ssdindex_server multiple_ssdindex_server.cpp)
if(MSVC)
target_link_options(multiple_ssdindex_server PRIVATE /MACHINE:x64)
target_link_libraries(multiple_ssdindex_server debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
target_link_libraries(multiple_ssdindex_server optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
else()
target_link_libraries(multiple_ssdindex_server ${PROJECT_NAME} aio -ltcmalloc -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
endif()
add_executable(client client.cpp)
if(MSVC)
target_link_options(client PRIVATE /MACHINE:x64)
target_link_libraries(client debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
target_link_libraries(client optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
else()
target_link_libraries(client ${PROJECT_NAME} -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
endif()

View File

@@ -0,0 +1,124 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <ctime>
#include <functional>
#include <iomanip>
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <string>
#include <cstdlib>
#include <codecvt>
#include <boost/program_options.hpp>
#include <cpprest/http_client.h>
#include <restapi/common.h>
using namespace web;
using namespace web::http;
using namespace web::http::client;
using namespace diskann;
namespace po = boost::program_options;
template <typename T>
void query_loop(const std::string &ip_addr_port, const std::string &query_file, const unsigned nq, const unsigned Ls,
const unsigned k_value)
{
web::http::client::http_client client(U(ip_addr_port));
T *data;
size_t npts = 1, ndims = 128, rounded_dim = 128;
diskann::load_aligned_bin<T>(query_file, data, npts, ndims, rounded_dim);
for (unsigned i = 0; i < nq; ++i)
{
T *vec = data + i * rounded_dim;
web::http::http_request http_query(methods::POST);
web::json::value queryJson = web::json::value::object();
queryJson[QUERY_ID_KEY] = i;
queryJson[K_KEY] = k_value;
queryJson[L_KEY] = Ls;
for (size_t i = 0; i < ndims; ++i)
{
queryJson[VECTOR_KEY][i] = web::json::value::number(vec[i]);
}
http_query.set_body(queryJson);
client.request(http_query)
.then([](web::http::http_response response) -> pplx::task<utility::string_t> {
if (response.status_code() == status_codes::OK)
{
return response.extract_string();
}
std::cerr << "Query failed" << std::endl;
return pplx::task_from_result(utility::string_t());
})
.then([](pplx::task<utility::string_t> previousTask) {
try
{
std::cout << previousTask.get() << std::endl;
}
catch (http_exception const &e)
{
std::wcout << e.what() << std::endl;
}
})
.wait();
}
}
int main(int argc, char *argv[])
{
std::string data_type, query_file, address;
uint32_t num_queries;
uint32_t l_search, k_value;
po::options_description desc{"Arguments"};
try
{
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
desc.add_options()("query_file", po::value<std::string>(&query_file)->required(),
"File containing the queries to search");
desc.add_options()("num_queries,Q", po::value<uint32_t>(&num_queries)->required(),
"Number of queries to search");
desc.add_options()("l_search", po::value<uint32_t>(&l_search)->required(), "Value of L");
desc.add_options()("k_value,K", po::value<uint32_t>(&k_value)->default_value(10), "Value of K (default 10)");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << std::endl;
return -1;
}
if (data_type == std::string("float"))
{
query_loop<float>(address, query_file, num_queries, l_search, k_value);
}
else if (data_type == std::string("int8"))
{
query_loop<int8_t>(address, query_file, num_queries, l_search, k_value);
}
else if (data_type == std::string("uint8"))
{
query_loop<uint8_t>(address, query_file, num_queries, l_search, k_value);
}
else
{
std::cerr << "Unsupported type " << argv[2] << std::endl;
return -1;
}
return 0;
}

View File

@@ -0,0 +1,138 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <ctime>
#include <functional>
#include <iomanip>
#include <string>
#include <cstdlib>
#include <codecvt>
#include <boost/program_options.hpp>
#include <restapi/server.h>
using namespace diskann;
namespace po = boost::program_options;
std::unique_ptr<Server> g_httpServer(nullptr);
std::vector<std::unique_ptr<diskann::BaseSearch>> g_inMemorySearch;
void setup(const utility::string_t &address, const std::string &typestring)
{
web::http::uri_builder uriBldr(address);
auto uri = uriBldr.to_uri();
std::cout << "Attempting to start server on " << uri.to_string() << std::endl;
g_httpServer = std::unique_ptr<Server>(new Server(uri, g_inMemorySearch, typestring));
std::cout << "Created a server object" << std::endl;
g_httpServer->open().wait();
ucout << U"Listening for requests on: " << address << std::endl;
}
void teardown(const utility::string_t &address)
{
g_httpServer->close().wait();
}
int main(int argc, char *argv[])
{
std::string data_type, index_file, data_file, address, dist_fn, tags_file;
uint32_t num_threads;
uint32_t l_search;
po::options_description desc{"Arguments"};
try
{
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
desc.add_options()("data_file", po::value<std::string>(&data_file)->required(),
"File containing the data found in the index");
desc.add_options()("index_path_prefix", po::value<std::string>(&index_file)->required(),
"Path prefix for saving index file components");
desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->required(),
"Number of threads used for building index");
desc.add_options()("l_search", po::value<uint32_t>(&l_search)->required(), "Value of L");
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
"distance function <l2/mips>");
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
"Tags file location");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << std::endl;
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("l2"))
metric = diskann::Metric::L2;
else if (dist_fn == std::string("mips"))
metric = diskann::Metric::INNER_PRODUCT;
else
{
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
return -1;
}
if (data_type == std::string("float"))
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::InMemorySearch<float>(data_file, index_file, tags_file, metric, num_threads, l_search));
g_inMemorySearch.push_back(std::move(searcher));
}
else if (data_type == std::string("int8"))
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::InMemorySearch<int8_t>(data_file, index_file, tags_file, metric, num_threads, l_search));
g_inMemorySearch.push_back(std::move(searcher));
}
else if (data_type == std::string("uint8"))
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::InMemorySearch<uint8_t>(data_file, index_file, tags_file, metric, num_threads, l_search));
g_inMemorySearch.push_back(std::move(searcher));
}
else
{
std::cerr << "Unsupported data type " << argv[2] << std::endl;
}
while (1)
{
try
{
setup(address, data_type);
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
std::string line;
std::getline(std::cin, line);
if (line == "exit")
{
teardown(address);
g_httpServer->close().wait();
exit(0);
}
}
catch (const std::exception &ex)
{
std::cerr << "Exception occurred: " << ex.what() << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
catch (...)
{
std::cerr << "Unknown exception occurreed" << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
}
}

View File

@@ -0,0 +1,83 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <restapi/server.h>
#include <restapi/in_memory_search.h>
#include <codecvt>
#include <iostream>
std::unique_ptr<Server> g_httpServer(nullptr);
std::unique_ptr<diskann::InMemorySearch> g_inMemorySearch(nullptr);
void setup(const utility::string_t &address)
{
web::http::uri_builder uriBldr(address);
auto uri = uriBldr.to_uri();
std::wcout << L"Attempting to start server on " << uri.to_string() << std::endl;
g_httpServer = std::unique_ptr<Server>(new Server(uri, g_inMemorySearch));
g_httpServer->open().wait();
ucout << U"Listening for requests on: " << address << std::endl;
}
void teardown(const utility::string_t &address)
{
g_httpServer->close().wait();
}
void loadIndex(const char *indexFile, const char *baseFile, const char *idsFile)
{
auto nsgSearch = new diskann::InMemorySearch(baseFile, indexFile, idsFile, diskann::L2);
g_inMemorySearch = std::unique_ptr<diskann::InMemorySearch>(nsgSearch);
}
std::wstring getHostingAddress(const char *hostNameAndPort)
{
wchar_t buffer[4096];
mbstowcs_s(nullptr, buffer, sizeof(buffer) / sizeof(buffer[0]), hostNameAndPort,
sizeof(buffer) / sizeof(buffer[0]));
return std::wstring(buffer);
}
int main(int argc, char *argv[])
{
if (argc != 5)
{
std::cout << "Usage: nsg_server <ip_addr_and_port> <index_file> "
"<base_file> <ids_file> "
<< std::endl;
exit(1);
}
auto address = getHostingAddress(argv[1]);
loadIndex(argv[2], argv[3], argv[4]);
while (1)
{
try
{
setup(address);
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
std::string line;
std::getline(std::cin, line);
if (line == "exit")
{
teardown(address);
exit(0);
}
}
catch (const std::exception &ex)
{
std::cerr << "Exception occurred: " << ex.what() << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
catch (...)
{
std::cerr << "Unknown exception occurreed" << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
}
}

View File

@@ -0,0 +1,182 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <ctime>
#include <functional>
#include <iomanip>
#include <string>
#include <cstdlib>
#include <codecvt>
#include <boost/program_options.hpp>
#include <omp.h>
#include <restapi/server.h>
using namespace diskann;
namespace po = boost::program_options;
std::unique_ptr<Server> g_httpServer(nullptr);
std::vector<std::unique_ptr<diskann::BaseSearch>> g_ssdSearch;
void setup(const utility::string_t &address, const std::string &typestring)
{
web::http::uri_builder uriBldr(address);
auto uri = uriBldr.to_uri();
std::cout << "Attempting to start server on " << uri.to_string() << std::endl;
g_httpServer = std::unique_ptr<Server>(new Server(uri, g_ssdSearch, typestring));
std::cout << "Created a server object" << std::endl;
g_httpServer->open().wait();
ucout << U"Listening for requests on: " << address << std::endl;
}
void teardown(const utility::string_t &address)
{
g_httpServer->close().wait();
}
int main(int argc, char *argv[])
{
std::string data_type, index_prefix_paths, address, dist_fn, tags_file;
uint32_t num_nodes_to_cache;
uint32_t num_threads;
po::options_description desc{"Arguments"};
try
{
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("index_prefix_paths", po::value<std::string>(&index_prefix_paths)->required(),
"Path prefix for loading index file components");
desc.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
"Number of nodes to cache during search");
desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
"Number of threads used for building index (defaults to "
"omp_get_num_procs())");
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
"distance function <l2/mips>");
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
"Tags file location");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << std::endl;
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("l2"))
metric = diskann::Metric::L2;
else if (dist_fn == std::string("mips"))
metric = diskann::Metric::INNER_PRODUCT;
else
{
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
return -1;
}
std::vector<std::pair<std::string, std::string>> index_tag_paths;
std::ifstream index_in(index_prefix_paths);
if (!index_in.is_open())
{
std::cerr << "Could not open " << index_prefix_paths << std::endl;
exit(-1);
}
std::ifstream tags_in(tags_file);
if (!tags_in.is_open())
{
std::cerr << "Could not open " << tags_file << std::endl;
exit(-1);
}
std::string prefix, tagfile;
while (std::getline(index_in, prefix))
{
if (std::getline(tags_in, tagfile))
{
index_tag_paths.push_back(std::make_pair(prefix, tagfile));
}
else
{
std::cerr << "The number of tags specified does not match the number of "
"indices specified"
<< std::endl;
exit(-1);
}
}
index_in.close();
tags_in.close();
if (data_type == std::string("float"))
{
for (auto &index_tag : index_tag_paths)
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<float>(
index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
g_ssdSearch.push_back(std::move(searcher));
}
}
else if (data_type == std::string("int8"))
{
for (auto &index_tag : index_tag_paths)
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<int8_t>(
index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
g_ssdSearch.push_back(std::move(searcher));
}
}
else if (data_type == std::string("uint8"))
{
for (auto &index_tag : index_tag_paths)
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<uint8_t>(
index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
g_ssdSearch.push_back(std::move(searcher));
}
}
else
{
std::cerr << "Unsupported data type " << data_type << std::endl;
exit(-1);
}
while (1)
{
try
{
setup(address, data_type);
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
std::string line;
std::getline(std::cin, line);
if (line == "exit")
{
teardown(address);
g_httpServer->close().wait();
exit(0);
}
}
catch (const std::exception &ex)
{
std::cerr << "Exception occurred: " << ex.what() << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
catch (...)
{
std::cerr << "Unknown exception occurreed" << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
}
}

View File

@@ -0,0 +1,141 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <ctime>
#include <functional>
#include <iomanip>
#include <string>
#include <cstdlib>
#include <codecvt>
#include <boost/program_options.hpp>
#include <omp.h>
#include <restapi/server.h>
using namespace diskann;
namespace po = boost::program_options;
std::unique_ptr<Server> g_httpServer(nullptr);
std::vector<std::unique_ptr<diskann::BaseSearch>> g_ssdSearch;
void setup(const utility::string_t &address, const std::string &typestring)
{
web::http::uri_builder uriBldr(address);
auto uri = uriBldr.to_uri();
std::cout << "Attempting to start server on " << uri.to_string() << std::endl;
g_httpServer = std::unique_ptr<Server>(new Server(uri, g_ssdSearch, typestring));
std::cout << "Created a server object" << std::endl;
g_httpServer->open().wait();
ucout << U"Listening for requests on: " << address << std::endl;
}
void teardown(const utility::string_t &address)
{
g_httpServer->close().wait();
}
int main(int argc, char *argv[])
{
std::string data_type, index_path_prefix, address, dist_fn, tags_file;
uint32_t num_nodes_to_cache;
uint32_t num_threads;
po::options_description desc{"Arguments"};
try
{
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
desc.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
"Path prefix for loading index file components");
desc.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
"Number of nodes to cache during search");
desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
"Number of threads used for building index (defaults to "
"omp_get_num_procs())");
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
"distance function <l2/mips>");
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
"Tags file location");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << std::endl;
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("l2"))
metric = diskann::Metric::L2;
else if (dist_fn == std::string("mips"))
metric = diskann::Metric::INNER_PRODUCT;
else
{
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
return -1;
}
if (data_type == std::string("float"))
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::PQFlashSearch<float>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
g_ssdSearch.push_back(std::move(searcher));
}
else if (data_type == std::string("int8"))
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::PQFlashSearch<int8_t>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
g_ssdSearch.push_back(std::move(searcher));
}
else if (data_type == std::string("uint8"))
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::PQFlashSearch<uint8_t>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
g_ssdSearch.push_back(std::move(searcher));
}
else
{
std::cerr << "Unsupported data type " << argv[2] << std::endl;
exit(-1);
}
while (1)
{
try
{
setup(address, data_type);
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
std::string line;
std::getline(std::cin, line);
if (line == "exit")
{
teardown(address);
g_httpServer->close().wait();
exit(0);
}
}
catch (const std::exception &ex)
{
std::cerr << "Exception occurred: " << ex.what() << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
catch (...)
{
std::cerr << "Unknown exception occurreed" << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
}
}