Initial commit
This commit is contained in:
18
packages/leann-backend-diskann/third_party/DiskANN/include/restapi/common.h
vendored
Normal file
18
packages/leann-backend-diskann/third_party/DiskANN/include/restapi/common.h
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cpprest/base_uri.h>
|
||||
#include <restapi/search_wrapper.h>
|
||||
|
||||
namespace diskann
|
||||
{
|
||||
// Constants
|
||||
static const std::string VECTOR_KEY = "query", K_KEY = "k", INDICES_KEY = "indices", DISTANCES_KEY = "distances",
|
||||
TAGS_KEY = "tags", QUERY_ID_KEY = "query_id", ERROR_MESSAGE_KEY = "error", L_KEY = "Ls",
|
||||
TIME_TAKEN_KEY = "time_taken_in_us", PARTITION_KEY = "partition",
|
||||
UNKNOWN_ERROR = "unknown_error";
|
||||
const unsigned int DEFAULT_L = 100;
|
||||
|
||||
} // namespace diskann
|
||||
140
packages/leann-backend-diskann/third_party/DiskANN/include/restapi/search_wrapper.h
vendored
Normal file
140
packages/leann-backend-diskann/third_party/DiskANN/include/restapi/search_wrapper.h
vendored
Normal file
@@ -0,0 +1,140 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <stdexcept>
|
||||
|
||||
#include <index.h>
|
||||
#include <pq_flash_index.h>
|
||||
|
||||
namespace diskann
|
||||
{
|
||||
class SearchResult
|
||||
{
|
||||
public:
|
||||
SearchResult(unsigned int K, unsigned int elapsed_time_in_ms, const unsigned *const indices,
|
||||
const float *const distances, const std::string *const tags = nullptr,
|
||||
const unsigned *const partitions = nullptr);
|
||||
|
||||
const std::vector<unsigned int> &get_indices() const
|
||||
{
|
||||
return _indices;
|
||||
}
|
||||
const std::vector<float> &get_distances() const
|
||||
{
|
||||
return _distances;
|
||||
}
|
||||
bool tags_enabled() const
|
||||
{
|
||||
return _tags_enabled;
|
||||
}
|
||||
const std::vector<std::string> &get_tags() const
|
||||
{
|
||||
return _tags;
|
||||
}
|
||||
bool partitions_enabled() const
|
||||
{
|
||||
return _partitions_enabled;
|
||||
}
|
||||
const std::vector<unsigned> &get_partitions() const
|
||||
{
|
||||
return _partitions;
|
||||
}
|
||||
unsigned get_time() const
|
||||
{
|
||||
return _search_time_in_ms;
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned int _K;
|
||||
unsigned int _search_time_in_ms;
|
||||
std::vector<unsigned int> _indices;
|
||||
std::vector<float> _distances;
|
||||
|
||||
bool _tags_enabled;
|
||||
std::vector<std::string> _tags;
|
||||
|
||||
bool _partitions_enabled;
|
||||
std::vector<unsigned> _partitions;
|
||||
};
|
||||
|
||||
class SearchNotImplementedException : public std::logic_error
|
||||
{
|
||||
private:
|
||||
std::string _errormsg;
|
||||
|
||||
public:
|
||||
SearchNotImplementedException(const char *type) : std::logic_error("Not Implemented")
|
||||
{
|
||||
_errormsg = "Search with data type ";
|
||||
_errormsg += std::string(type);
|
||||
_errormsg += " not implemented : ";
|
||||
_errormsg += __FUNCTION__;
|
||||
}
|
||||
|
||||
virtual const char *what() const throw()
|
||||
{
|
||||
return _errormsg.c_str();
|
||||
}
|
||||
};
|
||||
|
||||
class BaseSearch
|
||||
{
|
||||
public:
|
||||
BaseSearch(const std::string &tagsFile = nullptr);
|
||||
virtual SearchResult search(const float *query, const unsigned int dimensions, const unsigned int K,
|
||||
const unsigned int Ls)
|
||||
{
|
||||
throw SearchNotImplementedException("float");
|
||||
}
|
||||
virtual SearchResult search(const int8_t *query, const unsigned int dimensions, const unsigned int K,
|
||||
const unsigned int Ls)
|
||||
{
|
||||
throw SearchNotImplementedException("int8_t");
|
||||
}
|
||||
|
||||
virtual SearchResult search(const uint8_t *query, const unsigned int dimensions, const unsigned int K,
|
||||
const unsigned int Ls)
|
||||
{
|
||||
throw SearchNotImplementedException("uint8_t");
|
||||
}
|
||||
|
||||
void lookup_tags(const unsigned K, const unsigned *indices, std::string *ret_tags);
|
||||
|
||||
protected:
|
||||
bool _tags_enabled;
|
||||
std::vector<std::string> _tags_str;
|
||||
};
|
||||
|
||||
template <typename T> class InMemorySearch : public BaseSearch
|
||||
{
|
||||
public:
|
||||
InMemorySearch(const std::string &baseFile, const std::string &indexFile, const std::string &tagsFile, Metric m,
|
||||
uint32_t num_threads, uint32_t search_l);
|
||||
virtual ~InMemorySearch();
|
||||
|
||||
SearchResult search(const T *query, const unsigned int dimensions, const unsigned int K, const unsigned int Ls);
|
||||
|
||||
private:
|
||||
unsigned int _dimensions, _numPoints;
|
||||
std::unique_ptr<diskann::Index<T>> _index;
|
||||
};
|
||||
|
||||
template <typename T> class PQFlashSearch : public BaseSearch
|
||||
{
|
||||
public:
|
||||
PQFlashSearch(const std::string &indexPrefix, const unsigned num_nodes_to_cache, const unsigned num_threads,
|
||||
const std::string &tagsFile, Metric m);
|
||||
virtual ~PQFlashSearch();
|
||||
|
||||
SearchResult search(const T *query, const unsigned int dimensions, const unsigned int K, const unsigned int Ls);
|
||||
|
||||
private:
|
||||
unsigned int _dimensions, _numPoints;
|
||||
std::unique_ptr<diskann::PQFlashIndex<T>> _index;
|
||||
std::shared_ptr<AlignedFileReader> reader;
|
||||
};
|
||||
} // namespace diskann
|
||||
45
packages/leann-backend-diskann/third_party/DiskANN/include/restapi/server.h
vendored
Normal file
45
packages/leann-backend-diskann/third_party/DiskANN/include/restapi/server.h
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <restapi/common.h>
|
||||
#include <cpprest/http_listener.h>
|
||||
|
||||
namespace diskann
|
||||
{
|
||||
class Server
|
||||
{
|
||||
public:
|
||||
Server(web::uri &url, std::vector<std::unique_ptr<diskann::BaseSearch>> &multi_searcher,
|
||||
const std::string &typestring);
|
||||
virtual ~Server();
|
||||
|
||||
pplx::task<void> open();
|
||||
pplx::task<void> close();
|
||||
|
||||
protected:
|
||||
template <class T> void handle_post(web::http::http_request message);
|
||||
|
||||
template <typename T>
|
||||
web::json::value toJsonArray(const std::vector<T> &v, std::function<web::json::value(const T &)> valConverter);
|
||||
web::json::value prepareResponse(const int64_t &queryId, const int k);
|
||||
|
||||
template <class T>
|
||||
void parseJson(const utility::string_t &body, unsigned int &k, int64_t &queryId, T *&queryVector,
|
||||
unsigned int &dimensions, unsigned &Ls);
|
||||
|
||||
web::json::value idsToJsonArray(const diskann::SearchResult &result);
|
||||
web::json::value distancesToJsonArray(const diskann::SearchResult &result);
|
||||
web::json::value tagsToJsonArray(const diskann::SearchResult &result);
|
||||
web::json::value partitionsToJsonArray(const diskann::SearchResult &result);
|
||||
|
||||
SearchResult aggregate_results(const unsigned K, const std::vector<diskann::SearchResult> &results);
|
||||
|
||||
private:
|
||||
bool _isDebug;
|
||||
std::unique_ptr<web::http::experimental::listener::http_listener> _listener;
|
||||
const bool _multi_search;
|
||||
std::vector<std::unique_ptr<diskann::BaseSearch>> _multi_searcher;
|
||||
};
|
||||
} // namespace diskann
|
||||
Reference in New Issue
Block a user