Initial commit

This commit is contained in:
yichuan520030910320
2025-06-30 09:05:05 +00:00
commit 46f6cc100b
1231 changed files with 278432 additions and 0 deletions

View File

@@ -0,0 +1,127 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <vector>
#include <string>
#include "types.h"
#include "windows_customizations.h"
#include "distance.h"
namespace diskann
{
template <typename data_t> class AbstractScratch;
template <typename data_t> class AbstractDataStore
{
public:
AbstractDataStore(const location_t capacity, const size_t dim);
virtual ~AbstractDataStore() = default;
// Returns the number of points loaded.
virtual location_t load(const std::string &filename) = 0;
// Why does save() take num_pts? The store only tracks capacity, and because
// we allow resizing, the store can end up with spare capacity. To optimize
// disk utilization, we pass the number of "true" points so that the store
// can discard the empty locations before saving.
virtual size_t save(const std::string &filename, const location_t num_pts) = 0;
DISKANN_DLLEXPORT virtual location_t capacity() const;
DISKANN_DLLEXPORT virtual size_t get_dims() const;
// Implementers can choose to return _dim if they are not
// concerned about memory alignment.
// Some distance metrics (like l2) need data vectors to be aligned, so we
// align the dimension by padding zeros.
virtual size_t get_aligned_dim() const = 0;
// Populate the store with vectors (either from a pointer or a bin file),
// potentially pre-processing the vectors if the metric requires it
// (e.g., normalizing vectors for cosine distance over floating-point vectors).
// Useful for bulk or static index building.
virtual void populate_data(const data_t *vectors, const location_t num_pts) = 0;
virtual void populate_data(const std::string &filename, const size_t offset) = 0;
// Save the first num_pts vectors back to a bin file.
// Note: this cannot undo the pre-processing done in populate_data.
virtual void extract_data_to_bin(const std::string &filename, const location_t num_pts) = 0;
// Returns the updated capacity of the datastore. Clients should check
// whether resize actually changed the capacity to new_num_points before
// proceeding with operations. See the code below:
// auto new_capacity = data_store->resize(new_num_points);
// if (new_capacity >= new_num_points) {
//     // PROCEED
// } else {
//     // ERROR
// }
virtual location_t resize(const location_t new_num_points);
// Operations on individual vectors: like populate_data, but one vector
// at a time. Useful in the streaming setting.
virtual void get_vector(const location_t i, data_t *dest) const = 0;
virtual void set_vector(const location_t i, const data_t *const vector) = 0;
virtual void prefetch_vector(const location_t loc) = 0;
// Internal shuffle operations to move vectors around. Bulk-moves all the
// vectors in [old_start_loc, old_start_loc + num_points) to
// [new_start_loc, new_start_loc + num_points) and sets the old positions
// to zero vectors.
virtual void move_vectors(const location_t old_start_loc, const location_t new_start_loc,
const location_t num_points) = 0;
// same as above, without resetting the vectors in [from_loc, from_loc +
// num_points) to zero
virtual void copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points) = 0;
// With the PQ Data Store PR, we have also changed iterate_to_fixed_point to NOT take the query
// from the scratch object. Therefore every data store has to implement preprocess_query, which
// at the very least copies the query into the scratch object; hence it is pure virtual.
virtual void preprocess_query(const data_t *aligned_query,
AbstractScratch<data_t> *query_scratch = nullptr) const = 0;
// distance functions.
virtual float get_distance(const data_t *query, const location_t loc) const = 0;
virtual void get_distance(const data_t *query, const location_t *locations, const uint32_t location_count,
float *distances, AbstractScratch<data_t> *scratch_space = nullptr) const = 0;
// Specific overload for index.cpp.
virtual void get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
std::vector<float> &distances, AbstractScratch<data_t> *scratch_space) const = 0;
virtual float get_distance(const location_t loc1, const location_t loc2) const = 0;
// stats of the data stored in store
// Returns the point in the dataset that is closest to the mean of all points
// in the dataset
virtual location_t calculate_medoid() const = 0;
// REFACTOR PQ TODO: Each data store knows about its distance function, so this is
// redundant. However, we don't have an OptimizedDataStore yet, and to preserve code
// compatibility, we are exposing this function.
virtual Distance<data_t> *get_dist_fn() const = 0;
// search helpers
// if the base data is aligned per the request of the metric, this will tell
// how to align the query vector in a consistent manner
virtual size_t get_alignment_factor() const = 0;
protected:
// Expand the datastore to new_num_points. Returns the new capacity created,
// which should be == new_num_points in the normal case. Implementers can also
// return _capacity to indicate that they are not implementing this method.
virtual location_t expand(const location_t new_num_points) = 0;
// Shrink the datastore to new_num_points. It is NOT an error if shrink
// doesn't reduce the capacity, so callers need to check the returned value;
// a "default" implementation may simply return _capacity.
virtual location_t shrink(const location_t new_num_points) = 0;
location_t _capacity;
size_t _dim;
};
} // namespace diskann
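// --- Illustrative sketch (not part of the original header) ---
// The resize() contract described above, as a hypothetical helper
// (grow_store is an assumption used only for illustration): callers must
// check that the returned capacity actually covers the requested size.
template <typename data_t>
bool grow_store(diskann::AbstractDataStore<data_t> &store, const diskann::location_t new_num_points)
{
    diskann::location_t new_capacity = store.resize(new_num_points);
    // resize() may decline (or fail) to grow; proceed only when the request
    // was covered.
    return new_capacity >= new_num_points;
}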

View File

@@ -0,0 +1,68 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <string>
#include <vector>
#include "types.h"
namespace diskann
{
class AbstractGraphStore
{
public:
AbstractGraphStore(const size_t total_pts, const size_t reserve_graph_degree)
: _capacity(total_pts), _reserve_graph_degree(reserve_graph_degree)
{
}
virtual ~AbstractGraphStore() = default;
// returns tuple of <nodes_read, start, num_frozen_points>
virtual std::tuple<uint32_t, uint32_t, size_t> load(const std::string &index_path_prefix,
const size_t num_points) = 0;
virtual int store(const std::string &index_path_prefix, const size_t num_points, const size_t num_fz_points,
const uint32_t start) = 0;
// Not synchronised; the user should take a lock when necessary.
virtual const std::vector<location_t> &get_neighbours(const location_t i) const = 0;
virtual void add_neighbour(const location_t i, location_t neighbour_id) = 0;
virtual void clear_neighbours(const location_t i) = 0;
virtual void swap_neighbours(const location_t a, location_t b) = 0;
virtual void set_neighbours(const location_t i, std::vector<location_t> &neighbours) = 0;
virtual size_t resize_graph(const size_t new_size) = 0;
virtual void clear_graph() = 0;
virtual uint32_t get_max_observed_degree() = 0;
// set during load
virtual size_t get_max_range_of_graph() = 0;
// Total internal points _max_points + _num_frozen_points
size_t get_total_points()
{
return _capacity;
}
protected:
// Internal function, changes total points when resize_graph is called.
void set_total_points(size_t new_capacity)
{
_capacity = new_capacity;
}
size_t get_reserve_graph_degree()
{
return _reserve_graph_degree;
}
private:
size_t _capacity;
size_t _reserve_graph_degree;
};
} // namespace diskann
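// --- Illustrative sketch (not part of the original header) ---
// get_neighbours()/add_neighbour() are not synchronised, so concurrent
// callers must serialise access themselves. A hedged caller-side pattern;
// the per-node mutex is an assumption, not part of this interface.
#include <mutex>
inline void append_edge_locked(diskann::AbstractGraphStore &graph, std::mutex &node_lock,
                               const diskann::location_t from, const diskann::location_t to)
{
    std::lock_guard<std::mutex> guard(node_lock); // protects node `from`
    graph.add_neighbour(from, to);
}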

View File

@@ -0,0 +1,129 @@
#pragma once
#include "distance.h"
#include "parameters.h"
#include "utils.h"
#include "types.h"
#include "index_config.h"
#include "index_build_params.h"
#include <any>
namespace diskann
{
struct consolidation_report
{
enum status_code
{
SUCCESS = 0,
FAIL = 1,
LOCK_FAIL = 2,
INCONSISTENT_COUNT_ERROR = 3
};
status_code _status;
size_t _active_points, _max_points, _empty_slots, _slots_released, _delete_set_size, _num_calls_to_process_delete;
double _time;
consolidation_report(status_code status, size_t active_points, size_t max_points, size_t empty_slots,
size_t slots_released, size_t delete_set_size, size_t num_calls_to_process_delete,
double time_secs)
: _status(status), _active_points(active_points), _max_points(max_points), _empty_slots(empty_slots),
_slots_released(slots_released), _delete_set_size(delete_set_size),
_num_calls_to_process_delete(num_calls_to_process_delete), _time(time_secs)
{
}
};
/* A type-erased, template-independent class for interacting with Index. Uses type erasure to add
virtual implementations of methods that can take any type (using std::any) and provides a clean API
that can be inherited by different types of Index.
*/
class AbstractIndex
{
public:
AbstractIndex() = default;
virtual ~AbstractIndex() = default;
virtual void build(const std::string &data_file, const size_t num_points_to_load,
IndexFilterParams &build_params) = 0;
template <typename data_type, typename tag_type>
void build(const data_type *data, const size_t num_points_to_load, const std::vector<tag_type> &tags);
virtual void save(const char *filename, bool compact_before_save = false) = 0;
#ifdef EXEC_ENV_OLS
virtual void load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l) = 0;
#else
virtual void load(const char *index_file, uint32_t num_threads, uint32_t search_l) = 0;
#endif
// For FastL2 search on optimized layout
template <typename data_type>
void search_with_optimized_layout(const data_type *query, size_t K, size_t L, uint32_t *indices);
// Initialize space for res_vectors before calling.
template <typename data_type, typename tag_type>
size_t search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags,
float *distances, std::vector<data_type *> &res_vectors, bool use_filters = false,
const std::string filter_label = "");
// Added a search overload that takes L as a parameter, so that we
// can customize L on a per-query basis without tampering with "Parameters".
// IDType is either uint32_t or uint64_t.
template <typename data_type, typename IDType>
std::pair<uint32_t, uint32_t> search(const data_type *query, const size_t K, const uint32_t L, IDType *indices,
float *distances = nullptr);
// Filtered search support
// IndexType is either uint32_t or uint64_t
template <typename IndexType>
std::pair<uint32_t, uint32_t> search_with_filters(const DataType &query, const std::string &raw_label,
const size_t K, const uint32_t L, IndexType *indices,
float *distances);
// Insert a point with labels; labels must be present for a filtered index.
template <typename data_type, typename tag_type, typename label_type>
int insert_point(const data_type *point, const tag_type tag, const std::vector<label_type> &labels);
// Insert a point for an unfiltered index build. Do not use with a filtered index.
template <typename data_type, typename tag_type> int insert_point(const data_type *point, const tag_type tag);
// Delete the point with the given tag; returns -1 if the point cannot be deleted.
template <typename tag_type> int lazy_delete(const tag_type &tag);
// Batch-delete tags, populating failed_tags with any tags that could not be deleted.
template <typename tag_type>
void lazy_delete(const std::vector<tag_type> &tags, std::vector<tag_type> &failed_tags);
template <typename tag_type> void get_active_tags(tsl::robin_set<tag_type> &active_tags);
template <typename data_type> void set_start_points_at_random(data_type radius, uint32_t random_seed = 0);
virtual consolidation_report consolidate_deletes(const IndexWriteParameters &parameters) = 0;
virtual void optimize_index_layout() = 0;
// memory should be allocated for vec before calling this function
template <typename tag_type, typename data_type> int get_vector_by_tag(tag_type &tag, data_type *vec);
template <typename label_type> void set_universal_label(const label_type universal_label);
private:
virtual void _build(const DataType &data, const size_t num_points_to_load, TagVector &tags) = 0;
virtual std::pair<uint32_t, uint32_t> _search(const DataType &query, const size_t K, const uint32_t L,
std::any &indices, float *distances = nullptr) = 0;
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query, const std::string &filter_label,
const size_t K, const uint32_t L, std::any &indices,
float *distances) = 0;
virtual int _insert_point(const DataType &data_point, const TagType tag, Labelvector &labels) = 0;
virtual int _insert_point(const DataType &data_point, const TagType tag) = 0;
virtual int _lazy_delete(const TagType &tag) = 0;
virtual void _lazy_delete(TagVector &tags, TagVector &failed_tags) = 0;
virtual void _get_active_tags(TagRobinSet &active_tags) = 0;
virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) = 0;
virtual int _get_vector_by_tag(TagType &tag, DataType &vec) = 0;
virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags,
float *distances, DataVector &res_vectors, bool use_filters = false,
const std::string filter_label = "") = 0;
virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) = 0;
virtual void _set_universal_label(const LabelType universal_label) = 0;
};
} // namespace diskann
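// --- Illustrative sketch (not part of the original header) ---
// The type-erasure pattern above, reduced to a single method: the public
// template packs the typed arguments into the std::any-based wrappers and
// forwards to the private virtual, which the concrete Index implementation
// unpacks. This is a sketch of the idea, not the actual member definition:
//
//   template <typename data_type, typename IDType>
//   std::pair<uint32_t, uint32_t> AbstractIndex::search(const data_type *query, const size_t K,
//                                                       const uint32_t L, IDType *indices, float *distances)
//   {
//       DataType any_query = query;     // erase data_type
//       std::any any_indices = indices; // erase IDType
//       return _search(any_query, K, L, any_indices, distances);
//   }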

View File

@@ -0,0 +1,35 @@
#pragma once
namespace diskann
{
template <typename data_t> class PQScratch;
// By somewhat more than a coincidence, it seems that both InMemQueryScratch
// and SSDQueryScratch have the aligned query and PQScratch objects. So we
// can put them in a neat hierarchy and keep PQScratch as a standalone class.
template <typename data_t> class AbstractScratch
{
public:
AbstractScratch() = default;
// This class does not take any responsibility for memory management of
// its members. It is the responsibility of the derived classes to do so.
virtual ~AbstractScratch() = default;
// Scratch objects should not be copied
AbstractScratch(const AbstractScratch &) = delete;
AbstractScratch &operator=(const AbstractScratch &) = delete;
data_t *aligned_query_T()
{
return _aligned_query_T;
}
PQScratch<data_t> *pq_scratch()
{
return _pq_scratch;
}
protected:
data_t *_aligned_query_T = nullptr;
PQScratch<data_t> *_pq_scratch = nullptr;
};
} // namespace diskann
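// --- Illustrative sketch (not part of the original header) ---
// Per the comment above, derived classes own the memory behind
// _aligned_query_T and _pq_scratch. A hypothetical minimal subclass
// (MySketchScratch is an assumption; real scratch classes use aligned
// allocation rather than plain new):
#include <cstddef>
template <typename data_t> class MySketchScratch : public diskann::AbstractScratch<data_t>
{
  public:
    explicit MySketchScratch(std::size_t aligned_dim)
    {
        this->_aligned_query_T = new data_t[aligned_dim](); // zero-initialised
    }
    ~MySketchScratch() override
    {
        delete[] this->_aligned_query_T; // the base class never frees it
    }
};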

View File

@@ -0,0 +1,138 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#define MAX_IO_DEPTH 128
#include <vector>
#include <atomic>
#ifdef __linux__
#include <fcntl.h>
#include <libaio.h>
#include <unistd.h>
#include <malloc.h>
typedef io_context_t IOContext;
#elif __APPLE__
#include <dispatch/dispatch.h>
#include <fcntl.h>
#include <unistd.h>
struct IOContext
{
int fd;
dispatch_io_t channel;
dispatch_queue_t queue;
dispatch_group_t grp;
};
#elif _WINDOWS
#include <Windows.h>
#include <minwinbase.h>
#include <malloc.h>
#ifndef USE_BING_INFRA
struct IOContext
{
HANDLE fhandle = NULL;
HANDLE iocp = NULL;
std::vector<OVERLAPPED> reqs;
};
#else
#include "IDiskPriorityIO.h"
#include <atomic>
// TODO: Caller code is very callous about copying IOContext objects
// all over the place. MUST verify that it won't cause leaks/logical
// errors.
// Because of such callous copying, we have to use ptr->atomic instead
// of atomic, as atomic is not copyable.
struct IOContext
{
enum Status
{
READ_WAIT = 0,
READ_SUCCESS,
READ_FAILED,
PROCESS_COMPLETE
};
std::shared_ptr<ANNIndex::IDiskPriorityIO> m_pDiskIO = nullptr;
std::shared_ptr<std::vector<ANNIndex::AsyncReadRequest>> m_pRequests;
std::shared_ptr<std::vector<Status>> m_pRequestsStatus;
// WaitOnAddress on this memory to wait for the IO-completion signal;
// the reader should signal this memory after IO completion.
// TODO: WindowsAlignedFileReader can be modified to take advantage of this
// and can largely share code with the file reader for Bing.
mutable volatile long m_completeCount = 0;
IOContext()
: m_pRequestsStatus(new std::vector<Status>()), m_pRequests(new std::vector<ANNIndex::AsyncReadRequest>())
{
(*m_pRequestsStatus).reserve(MAX_IO_DEPTH);
(*m_pRequests).reserve(MAX_IO_DEPTH);
}
};
#endif
#endif
#include <cstdio>
#include <mutex>
#include <thread>
#include "tsl/robin_map.h"
#include "utils.h"
// NOTE :: all 3 fields must be 512-aligned
struct AlignedRead
{
uint64_t offset; // where to read from
uint64_t len; // how much to read
void *buf; // where to read into
AlignedRead() : offset(0), len(0), buf(nullptr)
{
}
AlignedRead(uint64_t offset, uint64_t len, void *buf) : offset(offset), len(len), buf(buf)
{
assert(IS_512_ALIGNED(offset));
assert(IS_512_ALIGNED(len));
assert(IS_512_ALIGNED(buf));
// assert(malloc_usable_size(buf) >= len);
}
};
class AlignedFileReader
{
protected:
tsl::robin_map<std::thread::id, IOContext> ctx_map;
std::mutex ctx_mut;
public:
// returns the thread-specific context
// returns (io_context_t)(-1) if thread is not registered
virtual IOContext &get_ctx() = 0;
virtual ~AlignedFileReader() {}
// register thread-id for a context
virtual void register_thread() = 0;
// de-register thread-id for a context
virtual void deregister_thread() = 0;
virtual void deregister_all_threads() = 0;
// Open & close ops
// Blocking calls
virtual void open(const std::string &fname) = 0;
virtual void close() = 0;
// process batch of aligned requests in parallel
// NOTE :: blocking call
virtual void read(std::vector<AlignedRead> &read_reqs, IOContext &ctx, bool async = false) = 0;
#ifdef USE_BING_INFRA
// wait for completion of one request in a batch of requests
virtual void wait(IOContext &ctx, int &completedIndex) = 0;
#endif
};
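// --- Illustrative sketch (not part of the original header) ---
// Building a request that satisfies the 512-alignment contract above.
// std::aligned_alloc is C++17 and not available on MSVC (use _aligned_malloc
// there); the 4096-byte sector size is an assumption for illustration.
#include <cstdint>
#include <cstdlib>
inline AlignedRead make_sector_read(uint64_t sector_id)
{
    const uint64_t len = 4096;                     // multiple of 512
    void *buf = std::aligned_alloc(512, len);      // 512-aligned buffer; caller frees with std::free
    return AlignedRead(sector_id * len, len, buf); // offset is also 512-aligned
}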

View File

@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <string>
#include <stdexcept>
#include <system_error>
#include "windows_customizations.h"
#include <cstdint>
#ifndef _WINDOWS
#define __FUNCSIG__ __PRETTY_FUNCTION__
#endif
namespace diskann
{
class ANNException : public std::runtime_error
{
public:
DISKANN_DLLEXPORT ANNException(const std::string &message, int errorCode);
DISKANN_DLLEXPORT ANNException(const std::string &message, int errorCode, const std::string &funcSig,
const std::string &fileName, uint32_t lineNum);
private:
int _errorCode;
};
class FileException : public ANNException
{
public:
DISKANN_DLLEXPORT FileException(const std::string &filename, std::system_error &e, const std::string &funcSig,
const std::string &fileName, uint32_t lineNum);
};
} // namespace diskann
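// --- Illustrative sketch (not part of the original header) ---
// A typical throw site; the macro captures the enclosing function signature
// (__FUNCSIG__ on MSVC, remapped to __PRETTY_FUNCTION__ above):
//
//   throw diskann::ANNException("bad parameter value", -1, __FUNCSIG__, __FILE__, __LINE__);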

View File

@@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <cstdint>
#include <cstddef>
#include <vector>
#include <any>
#include "tsl/robin_set.h"
namespace AnyWrapper
{
/*
* Base struct to hold a reference to the data.
* Note: no memory management; the caller needs to keep the object alive.
*/
struct AnyReference
{
template <typename Ty> AnyReference(Ty &reference) : _data(&reference)
{
}
template <typename Ty> Ty &get()
{
auto ptr = std::any_cast<Ty *>(_data);
return *ptr;
}
private:
std::any _data;
};
struct AnyRobinSet : public AnyReference
{
template <typename T> AnyRobinSet(const tsl::robin_set<T> &robin_set) : AnyReference(robin_set)
{
}
template <typename T> AnyRobinSet(tsl::robin_set<T> &robin_set) : AnyReference(robin_set)
{
}
};
struct AnyVector : public AnyReference
{
template <typename T> AnyVector(const std::vector<T> &vector) : AnyReference(vector)
{
}
template <typename T> AnyVector(std::vector<T> &vector) : AnyReference(vector)
{
}
};
} // namespace AnyWrapper
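// --- Illustrative sketch (not part of the original header) ---
// AnyReference stores a pointer inside std::any, so get<Ty>() must name
// exactly the wrapped type: wrapping through the const overload stores a
// pointer-to-const, and a mismatched get<>() makes std::any_cast throw.
// A hedged usage example:
inline void any_vector_demo()
{
    std::vector<uint32_t> ids = {1, 2, 3};
    AnyWrapper::AnyVector wrapped(ids);                // holds &ids, no copy
    auto &same = wrapped.get<std::vector<uint32_t>>(); // the same object back
    same.push_back(4);                                 // mutates ids directly
}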

View File

@@ -0,0 +1,26 @@
#pragma once
#ifdef __APPLE__
#include "aligned_file_reader.h"
class AppleAlignedFileReader : public AlignedFileReader
{
private:
uint64_t file_sz;
FileHandle file_desc;
public:
AppleAlignedFileReader();
~AppleAlignedFileReader();
IOContext &get_ctx();
void register_thread();
void deregister_thread();
void deregister_all_threads();
void open(const std::string &fname);
void close();
void read(std::vector<AlignedRead> &read_reqs, IOContext &ctx, bool async = false);
};
#endif

View File

@@ -0,0 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
namespace boost
{
#ifndef BOOST_DYNAMIC_BITSET_FWD_HPP
template <typename Block = unsigned long, typename Allocator = std::allocator<Block>> class dynamic_bitset;
#endif
} // namespace boost

View File

@@ -0,0 +1,217 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include "logger.h"
#include "ann_exception.h"
// sequential cached reads
class cached_ifstream
{
public:
cached_ifstream()
{
}
cached_ifstream(const std::string &filename, uint64_t cacheSize) : cache_size(cacheSize), cur_off(0)
{
reader.exceptions(std::ifstream::failbit | std::ifstream::badbit);
this->open(filename, cache_size);
}
~cached_ifstream()
{
delete[] cache_buf;
reader.close();
}
void open(const std::string &filename, uint64_t cacheSize)
{
this->cur_off = 0;
try
{
reader.open(filename, std::ios::binary | std::ios::ate);
fsize = reader.tellg();
reader.seekg(0, std::ios::beg);
assert(reader.is_open());
assert(cacheSize > 0);
cacheSize = (std::min)(cacheSize, fsize);
this->cache_size = cacheSize;
cache_buf = new char[cacheSize];
reader.read(cache_buf, cacheSize);
diskann::cout << "Opened: " << filename.c_str() << ", size: " << fsize << ", cache_size: " << cacheSize
<< std::endl;
}
catch (std::system_error &e)
{
throw diskann::FileException(filename, e, __FUNCSIG__, __FILE__, __LINE__);
}
}
size_t get_file_size()
{
return fsize;
}
void read(char *read_buf, uint64_t n_bytes)
{
assert(cache_buf != nullptr);
assert(read_buf != nullptr);
if (n_bytes <= (cache_size - cur_off))
{
// case 1: cache contains all data
memcpy(read_buf, cache_buf + cur_off, n_bytes);
cur_off += n_bytes;
}
else
{
// case 2: cache contains some data
uint64_t cached_bytes = cache_size - cur_off;
if (n_bytes - cached_bytes > fsize - reader.tellg())
{
std::stringstream stream;
stream << "Reading beyond end of file" << std::endl;
stream << "n_bytes: " << n_bytes << " cached_bytes: " << cached_bytes << " fsize: " << fsize
<< " current pos:" << reader.tellg() << std::endl;
diskann::cout << stream.str() << std::endl;
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}
memcpy(read_buf, cache_buf + cur_off, cached_bytes);
// go to disk and fetch more data
reader.read(read_buf + cached_bytes, n_bytes - cached_bytes);
// reset cur off
cur_off = cache_size;
uint64_t size_left = fsize - reader.tellg();
if (size_left >= cache_size)
{
reader.read(cache_buf, cache_size);
cur_off = 0;
}
// note that if size_left < cache_size, then cur_off = cache_size,
// so subsequent reads will all be directly from file
}
}
private:
// underlying ifstream
std::ifstream reader;
// # bytes to cache in one shot read
uint64_t cache_size = 0;
// underlying buf for cache
char *cache_buf = nullptr;
// offset into cache_buf for cur_pos
uint64_t cur_off = 0;
// file size
uint64_t fsize = 0;
};
// sequential cached writes
class cached_ofstream
{
public:
cached_ofstream(const std::string &filename, uint64_t cache_size) : cache_size(cache_size), cur_off(0)
{
writer.exceptions(std::ofstream::failbit | std::ofstream::badbit);
try
{
writer.open(filename, std::ios::binary);
assert(writer.is_open());
assert(cache_size > 0);
cache_buf = new char[cache_size];
diskann::cout << "Opened: " << filename.c_str() << ", cache_size: " << cache_size << std::endl;
}
catch (std::system_error &e)
{
throw diskann::FileException(filename, e, __FUNCSIG__, __FILE__, __LINE__);
}
}
~cached_ofstream()
{
this->close();
}
void close()
{
// dump any remaining data in memory
if (cur_off > 0)
{
this->flush_cache();
}
if (cache_buf != nullptr)
{
delete[] cache_buf;
cache_buf = nullptr;
}
if (writer.is_open())
writer.close();
diskann::cout << "Finished writing " << fsize << "B" << std::endl;
}
size_t get_file_size()
{
return fsize;
}
// writes n_bytes from write_buf to the underlying ofstream/cache
void write(char *write_buf, uint64_t n_bytes)
{
assert(cache_buf != nullptr);
if (n_bytes <= (cache_size - cur_off))
{
// case 1: cache can take all data
memcpy(cache_buf + cur_off, write_buf, n_bytes);
cur_off += n_bytes;
}
else
{
// case 2: cache can't hold all the data
// go to disk and write existing cache data
writer.write(cache_buf, cur_off);
fsize += cur_off;
// write the new data to disk
writer.write(write_buf, n_bytes);
fsize += n_bytes;
// memset all cache data and reset cur_off
memset(cache_buf, 0, cache_size);
cur_off = 0;
}
}
void flush_cache()
{
assert(cache_buf != nullptr);
writer.write(cache_buf, cur_off);
fsize += cur_off;
memset(cache_buf, 0, cache_size);
cur_off = 0;
}
void reset()
{
flush_cache();
writer.seekp(0);
}
private:
// underlying ofstream
std::ofstream writer;
// # bytes to cache for one shot write
uint64_t cache_size = 0;
// underlying buf for cache
char *cache_buf = nullptr;
// offset into cache_buf for cur_pos
uint64_t cur_off = 0;
// file size
uint64_t fsize = 0;
};
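// --- Illustrative sketch (not part of the original header) ---
// Hedged usage of the cached streams above; the file names and cache sizes
// are placeholders. Reads are served from the in-memory cache until it is
// exhausted and then refilled from disk, as read() documents.
#include <algorithm>
#include <vector>
inline void copy_with_cache(const std::string &src, const std::string &dst)
{
    cached_ifstream in(src, 64 * 1024 * 1024);  // 64 MiB read cache
    cached_ofstream out(dst, 64 * 1024 * 1024); // 64 MiB write cache
    std::vector<char> chunk(1 << 20);           // copy in 1 MiB pieces
    uint64_t left = in.get_file_size();
    while (left > 0)
    {
        uint64_t n = std::min<uint64_t>(left, chunk.size());
        in.read(chunk.data(), n);
        out.write(chunk.data(), n);
        left -= n;
    }
} // cached_ofstream flushes its cache and closes in the destructor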

View File

@@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <algorithm>
#include <atomic>
#include <cassert>
#include <chrono>
#include <climits>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fcntl.h>
#include <fstream>
#include <iostream>
#include <iomanip>
#include <omp.h>
#include <queue>
#include <random>
#include <set>
#include <shared_mutex>
#include <sys/stat.h>
#include <sstream>
#include <unordered_map>
#include <vector>

View File

@@ -0,0 +1,132 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <mutex>
#include <queue>
#include <thread>
#include <type_traits>
#include <unordered_set>
namespace diskann
{
template <typename T> class ConcurrentQueue
{
typedef std::chrono::microseconds chrono_us_t;
typedef std::unique_lock<std::mutex> mutex_locker;
std::queue<T> q;
std::mutex mut;
std::mutex push_mut;
std::mutex pop_mut;
std::condition_variable push_cv;
std::condition_variable pop_cv;
T null_T;
public:
ConcurrentQueue()
{
}
ConcurrentQueue(T nullT)
{
this->null_T = nullT;
}
~ConcurrentQueue()
{
this->push_cv.notify_all();
this->pop_cv.notify_all();
}
// queue stats
uint64_t size()
{
mutex_locker lk(this->mut);
uint64_t ret = q.size();
lk.unlock();
return ret;
}
bool empty()
{
return (this->size() == 0);
}
// PUSH BACK
void push(T &new_val)
{
mutex_locker lk(this->mut);
this->q.push(new_val);
lk.unlock();
}
template <class Iterator> void insert(Iterator iter_begin, Iterator iter_end)
{
mutex_locker lk(this->mut);
for (Iterator it = iter_begin; it != iter_end; it++)
{
this->q.push(*it);
}
lk.unlock();
}
// POP FRONT
T pop()
{
mutex_locker lk(this->mut);
if (this->q.empty())
{
lk.unlock();
return this->null_T;
}
else
{
T ret = this->q.front();
this->q.pop();
// diskann::cout << "thread_id: " << std::this_thread::get_id() <<
// ", ctx: "
// << ret.ctx << "\n";
lk.unlock();
return ret;
}
}
// register for notifications
void wait_for_push_notify(chrono_us_t wait_time = chrono_us_t{10})
{
mutex_locker lk(this->push_mut);
this->push_cv.wait_for(lk, wait_time);
lk.unlock();
}
void wait_for_pop_notify(chrono_us_t wait_time = chrono_us_t{10})
{
mutex_locker lk(this->pop_mut);
this->pop_cv.wait_for(lk, wait_time);
lk.unlock();
}
// just notify functions
void push_notify_one()
{
this->push_cv.notify_one();
}
void push_notify_all()
{
this->push_cv.notify_all();
}
void pop_notify_one()
{
this->pop_cv.notify_one();
}
void pop_notify_all()
{
this->pop_cv.notify_all();
}
};
} // namespace diskann
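// --- Illustrative sketch (not part of the original header) ---
// The queue performs no signalling on its own: producers push and then
// notify, while consumers receive null_T from an empty queue and may block
// briefly in wait_for_push_notify(). A hedged producer/consumer pairing:
//
//   diskann::ConcurrentQueue<int> q(-1); // -1 serves as null_T
//   // producer:
//   int v = 42;
//   q.push(v);
//   q.push_notify_one();
//   // consumer:
//   int got = q.pop();
//   while (got == -1) // empty queue returned null_T
//   {
//       q.wait_for_push_notify(); // waits up to 10 us by default
//       got = q.pop();
//   }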

View File

@@ -0,0 +1,285 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <vector>
#include <limits>
#include <algorithm>
#include <stdexcept>
#ifndef __APPLE__
#include <immintrin.h>
#include <smmintrin.h>
#include <tmmintrin.h>
#include "simd_utils.h"
#endif
extern bool Avx2SupportedCPU;
#ifdef _WINDOWS
// SIMD implementation of Cosine similarity. Taken from hnsw library.
/**
* Non-metric Space Library
*
* Authors: Bilegsaikhan Naidan (https://github.com/bileg), Leonid Boytsov
* (http://boytsov.info). With contributions from Lawrence Cayton
* (http://lcayton.com/) and others.
*
* For the complete list of contributors and further details see:
* https://github.com/searchivarius/NonMetricSpaceLib
*
* Copyright (c) 2014
*
* This code is released under the
* Apache License Version 2.0 http://www.apache.org/licenses/.
*
*/
namespace diskann
{
using namespace std;
#define PORTABLE_ALIGN16 __declspec(align(16))
static float NormScalarProductSIMD2(const int8_t *pVect1, const int8_t *pVect2, uint32_t qty)
{
if (Avx2SupportedCPU)
{
__m256 cos, p1Len, p2Len;
cos = p1Len = p2Len = _mm256_setzero_ps();
while (qty >= 32)
{
__m256i rx = _mm256_load_si256((__m256i *)pVect1), ry = _mm256_load_si256((__m256i *)pVect2);
cos = _mm256_add_ps(cos, _mm256_mul_epi8(rx, ry));
p1Len = _mm256_add_ps(p1Len, _mm256_mul_epi8(rx, rx));
p2Len = _mm256_add_ps(p2Len, _mm256_mul_epi8(ry, ry));
pVect1 += 32;
pVect2 += 32;
qty -= 32;
}
while (qty > 0)
{
__m128i rx = _mm_load_si128((__m128i *)pVect1), ry = _mm_load_si128((__m128i *)pVect2);
cos = _mm256_add_ps(cos, _mm256_mul32_pi8(rx, ry));
p1Len = _mm256_add_ps(p1Len, _mm256_mul32_pi8(rx, rx));
p2Len = _mm256_add_ps(p2Len, _mm256_mul32_pi8(ry, ry));
pVect1 += 4;
pVect2 += 4;
qty -= 4;
}
cos = _mm256_hadd_ps(_mm256_hadd_ps(cos, cos), cos);
p1Len = _mm256_hadd_ps(_mm256_hadd_ps(p1Len, p1Len), p1Len);
p2Len = _mm256_hadd_ps(_mm256_hadd_ps(p2Len, p2Len), p2Len);
float denominator = max(numeric_limits<float>::min() * 2, sqrt(p1Len.m256_f32[0] + p1Len.m256_f32[4]) *
sqrt(p2Len.m256_f32[0] + p2Len.m256_f32[4]));
float cosine = (cos.m256_f32[0] + cos.m256_f32[4]) / denominator;
return max(float(-1), min(float(1), cosine));
}
__m128 cos, p1Len, p2Len;
cos = p1Len = p2Len = _mm_setzero_ps();
__m128i rx, ry;
while (qty >= 16)
{
rx = _mm_load_si128((__m128i *)pVect1);
ry = _mm_load_si128((__m128i *)pVect2);
cos = _mm_add_ps(cos, _mm_mul_epi8(rx, ry));
p1Len = _mm_add_ps(p1Len, _mm_mul_epi8(rx, rx));
p2Len = _mm_add_ps(p2Len, _mm_mul_epi8(ry, ry));
pVect1 += 16;
pVect2 += 16;
qty -= 16;
}
while (qty > 0)
{
rx = _mm_load_si128((__m128i *)pVect1);
ry = _mm_load_si128((__m128i *)pVect2);
cos = _mm_add_ps(cos, _mm_mul32_pi8(rx, ry));
p1Len = _mm_add_ps(p1Len, _mm_mul32_pi8(rx, rx));
p2Len = _mm_add_ps(p2Len, _mm_mul32_pi8(ry, ry));
pVect1 += 4;
pVect2 += 4;
qty -= 4;
}
cos = _mm_hadd_ps(_mm_hadd_ps(cos, cos), cos);
p1Len = _mm_hadd_ps(_mm_hadd_ps(p1Len, p1Len), p1Len);
p2Len = _mm_hadd_ps(_mm_hadd_ps(p2Len, p2Len), p2Len);
float norm1 = p1Len.m128_f32[0];
float norm2 = p2Len.m128_f32[0];
static const float eps = numeric_limits<float>::min() * 2;
if (norm1 < eps)
{ /*
* This shouldn't normally happen for this space, but
* if it does, we don't want to get NANs
*/
if (norm2 < eps)
{
return 1;
}
return 0;
}
/*
* Sometimes due to rounding errors, we get values > 1 or < -1.
* This throws off other functions that use scalar product, e.g., acos
*/
return max(float(-1), min(float(1), cos.m128_f32[0] / sqrt(norm1) / sqrt(norm2)));
}
static float NormScalarProductSIMD(const float *pVect1, const float *pVect2, uint32_t qty)
{
// Didn't get significant performance gain compared with 128bit version.
static const float eps = numeric_limits<float>::min() * 2;
if (Avx2SupportedCPU)
{
uint32_t qty8 = qty / 8;
const float *pEnd1 = pVect1 + 8 * qty8;
const float *pEnd2 = pVect1 + qty;
__m256 v1, v2;
__m256 sum_prod = _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, 0);
__m256 sum_square1 = sum_prod;
__m256 sum_square2 = sum_prod;
while (pVect1 < pEnd1)
{
v1 = _mm256_loadu_ps(pVect1);
pVect1 += 8;
v2 = _mm256_loadu_ps(pVect2);
pVect2 += 8;
sum_prod = _mm256_add_ps(sum_prod, _mm256_mul_ps(v1, v2));
sum_square1 = _mm256_add_ps(sum_square1, _mm256_mul_ps(v1, v1));
sum_square2 = _mm256_add_ps(sum_square2, _mm256_mul_ps(v2, v2));
}
float PORTABLE_ALIGN16 TmpResProd[8];
float PORTABLE_ALIGN16 TmpResSquare1[8];
float PORTABLE_ALIGN16 TmpResSquare2[8];
_mm256_store_ps(TmpResProd, sum_prod);
_mm256_store_ps(TmpResSquare1, sum_square1);
_mm256_store_ps(TmpResSquare2, sum_square2);
float sum = 0.0f;
float norm1 = 0.0f;
float norm2 = 0.0f;
for (uint32_t i = 0; i < 8; ++i)
{
sum += TmpResProd[i];
norm1 += TmpResSquare1[i];
norm2 += TmpResSquare2[i];
}
while (pVect1 < pEnd2)
{
sum += (*pVect1) * (*pVect2);
norm1 += (*pVect1) * (*pVect1);
norm2 += (*pVect2) * (*pVect2);
++pVect1;
++pVect2;
}
if (norm1 < eps)
{
return norm2 < eps ? 1.0f : 0.0f;
}
return max(float(-1), min(float(1), sum / sqrt(norm1) / sqrt(norm2)));
}
__m128 v1, v2;
__m128 sum_prod = _mm_set1_ps(0);
__m128 sum_square1 = sum_prod;
__m128 sum_square2 = sum_prod;
while (qty >= 4)
{
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
sum_square1 = _mm_add_ps(sum_square1, _mm_mul_ps(v1, v1));
sum_square2 = _mm_add_ps(sum_square2, _mm_mul_ps(v2, v2));
qty -= 4;
}
float sum = sum_prod.m128_f32[0] + sum_prod.m128_f32[1] + sum_prod.m128_f32[2] + sum_prod.m128_f32[3];
float norm1 = sum_square1.m128_f32[0] + sum_square1.m128_f32[1] + sum_square1.m128_f32[2] + sum_square1.m128_f32[3];
float norm2 = sum_square2.m128_f32[0] + sum_square2.m128_f32[1] + sum_square2.m128_f32[2] + sum_square2.m128_f32[3];
if (norm1 < eps)
{
return norm2 < eps ? 1.0f : 0.0f;
}
return max(float(-1), min(float(1), sum / sqrt(norm1) / sqrt(norm2)));
}
static float NormScalarProductSIMD2(const float *pVect1, const float *pVect2, uint32_t qty)
{
return NormScalarProductSIMD(pVect1, pVect2, qty);
}
template <class T> static float CosineSimilarity2(const T *p1, const T *p2, uint32_t qty)
{
return std::max(0.0f, 1.0f - NormScalarProductSIMD2(p1, p2, qty));
}
// static template float CosineSimilarity2<__int8>(const __int8* pVect1,
// const __int8* pVect2, size_t qty);
// static template float CosineSimilarity2<float>(const float* pVect1,
// const float* pVect2, size_t qty);
template <class T> static void CosineSimilarityNormalize(T *pVector, uint32_t qty)
{
T sum = 0;
for (uint32_t i = 0; i < qty; ++i)
{
sum += pVector[i] * pVector[i];
}
sum = sqrt(sum);
if (sum == 0)
{
// guard against a zero norm before inverting, to avoid dividing by zero
sum = numeric_limits<T>::min();
}
sum = 1 / sum;
for (uint32_t i = 0; i < qty; ++i)
{
pVector[i] *= sum;
}
}
// template static void CosineSimilarityNormalize<float>(float* pVector,
// size_t qty);
// template static void CosineSimilarityNormalize<double>(double* pVector,
// size_t qty);
template <> void CosineSimilarityNormalize(__int8 * /*pVector*/, uint32_t /*qty*/)
{
throw std::runtime_error("For int8 type vector, you can not use cosine distance!");
}
template <> void CosineSimilarityNormalize(__int16 * /*pVector*/, uint32_t /*qty*/)
{
throw std::runtime_error("For int16 type vector, you can not use cosine distance!");
}
template <> void CosineSimilarityNormalize(int * /*pVector*/, uint32_t /*qty*/)
{
throw std::runtime_error("For int type vector, you can not use cosine distance!");
}
} // namespace diskann
#endif
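// --- Illustrative sketch (not part of the original header) ---
// A portable scalar reference for the SIMD routines above (the guarded code
// is _WINDOWS-only): cosine similarity is dot(a,b) / (|a| * |b|), and
// CosineSimilarity2 converts it to the distance max(0, 1 - cosine).
#include <algorithm>
#include <cmath>
#include <limits>
inline float cosine_distance_scalar(const float *a, const float *b, uint32_t qty)
{
    float dot = 0.0f, norm_a = 0.0f, norm_b = 0.0f;
    for (uint32_t i = 0; i < qty; i++)
    {
        dot += a[i] * b[i];
        norm_a += a[i] * a[i];
        norm_b += b[i] * b[i];
    }
    static const float eps = std::numeric_limits<float>::min() * 2;
    if (norm_a < eps)
        return norm_b < eps ? 0.0f : 1.0f; // mirrors the zero-norm handling above
    float cosine = dot / (std::sqrt(norm_a) * std::sqrt(norm_b));
    cosine = std::max(-1.0f, std::min(1.0f, cosine)); // clamp rounding error
    return std::max(0.0f, 1.0f - cosine);
}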

View File

@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <stdint.h>
namespace diskann
{
namespace defaults
{
const float ALPHA = 1.2f;
const uint32_t NUM_THREADS = 0;
const uint32_t MAX_OCCLUSION_SIZE = 750;
const bool HAS_LABELS = false;
const uint32_t FILTER_LIST_SIZE = 0;
const uint32_t NUM_FROZEN_POINTS_STATIC = 0;
const uint32_t NUM_FROZEN_POINTS_DYNAMIC = 1;
// In-mem index related limits
const float GRAPH_SLACK_FACTOR = 1.3f;
// SSD Index related limits
const uint64_t MAX_GRAPH_DEGREE = 512;
const uint64_t SECTOR_LEN = 4096;
const uint64_t MAX_N_SECTOR_READS = 128;
// The following constants should always be specified explicitly, but serve
// as sensible defaults at CLI / Python boundaries.
const uint32_t MAX_DEGREE = 64;
const uint32_t BUILD_LIST_SIZE = 100;
const bool SATURATE_GRAPH = false;
const uint32_t SEARCH_LIST_SIZE = 100;
} // namespace defaults
} // namespace diskann

View File

@@ -0,0 +1,108 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <algorithm>
#include <fcntl.h>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <random>
#include <set>
#ifdef __APPLE__
#else
#include <malloc.h>
#endif
#ifdef _WINDOWS
#include <Windows.h>
typedef HANDLE FileHandle;
#else
#include <unistd.h>
typedef int FileHandle;
#endif
#include "cached_io.h"
#include "common_includes.h"
#include "utils.h"
#include "windows_customizations.h"
namespace diskann
{
const size_t MAX_SAMPLE_POINTS_FOR_WARMUP = 100000;
const double PQ_TRAINING_SET_FRACTION = 0.1;
const double SPACE_FOR_CACHED_NODES_IN_GB = 0.25;
const double THRESHOLD_FOR_CACHING_IN_GB = 1.0;
const uint32_t NUM_NODES_TO_CACHE = 250000;
const uint32_t WARMUP_L = 20;
const uint32_t NUM_KMEANS_REPS = 12;
template <typename T, typename LabelT> class PQFlashIndex;
DISKANN_DLLEXPORT double get_memory_budget(const std::string &mem_budget_str);
DISKANN_DLLEXPORT double get_memory_budget(double search_ram_budget_in_gb);
DISKANN_DLLEXPORT void add_new_file_to_single_index(std::string index_file, std::string new_file);
DISKANN_DLLEXPORT size_t calculate_num_pq_chunks(double final_index_ram_limit, size_t points_num, uint32_t dim);
DISKANN_DLLEXPORT void read_idmap(const std::string &fname, std::vector<uint32_t> &ivecs);
#ifdef EXEC_ENV_OLS
template <typename T>
DISKANN_DLLEXPORT T *load_warmup(MemoryMappedFiles &files, const std::string &cache_warmup_file, uint64_t &warmup_num,
uint64_t warmup_dim, uint64_t warmup_aligned_dim);
#else
template <typename T>
DISKANN_DLLEXPORT T *load_warmup(const std::string &cache_warmup_file, uint64_t &warmup_num, uint64_t warmup_dim,
uint64_t warmup_aligned_dim);
#endif
DISKANN_DLLEXPORT int merge_shards(const std::string &vamana_prefix, const std::string &vamana_suffix,
const std::string &idmaps_prefix, const std::string &idmaps_suffix,
const uint64_t nshards, uint32_t max_degree, const std::string &output_vamana,
const std::string &medoids_file, bool use_filters = false,
const std::string &labels_to_medoids_file = std::string(""));
DISKANN_DLLEXPORT void extract_shard_labels(const std::string &in_label_file, const std::string &shard_ids_bin,
const std::string &shard_label_file);
template <typename T>
DISKANN_DLLEXPORT std::string preprocess_base_file(const std::string &infile, const std::string &indexPrefix,
diskann::Metric &distMetric);
template <typename T, typename LabelT = uint32_t>
DISKANN_DLLEXPORT int build_merged_vamana_index(std::string base_file, diskann::Metric _compareMetric, uint32_t L,
uint32_t R, double sampling_rate, double ram_budget,
std::string mem_index_path, std::string medoids_file,
std::string centroids_file, size_t build_pq_bytes, bool use_opq,
uint32_t num_threads, bool use_filters = false,
const std::string &label_file = std::string(""),
const std::string &labels_to_medoids_file = std::string(""),
const std::string &universal_label = "", const uint32_t Lf = 0);
template <typename T, typename LabelT>
DISKANN_DLLEXPORT uint32_t optimize_beamwidth(std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> &_pFlashIndex,
T *tuning_sample, uint64_t tuning_sample_num,
uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
uint32_t start_bw = 2);
template <typename T, typename LabelT = uint32_t>
DISKANN_DLLEXPORT int build_disk_index(
const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters,
diskann::Metric _compareMetric, bool use_opq = false,
const std::string &codebook_prefix = "", // default is empty for no codebook pass in
bool use_filters = false,
const std::string &label_file = std::string(""), // default is empty string for no label_file
const std::string &universal_label = "", const uint32_t filter_threshold = 0,
const uint32_t Lf = 0); // default is empty string for no universal label
template <typename T>
DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, const std::string mem_index_file,
const std::string output_file,
const std::string reorder_data_file = std::string(""));
} // namespace diskann
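// --- Illustrative sketch (not part of the original header) ---
// A hedged call to build_disk_index. The contents of the parameter string
// are defined elsewhere in the library, so "<build params>" is a
// placeholder rather than the actual format:
//
//   int rc = diskann::build_disk_index<float>("base.fbin", "index_prefix", "<build params>",
//                                             diskann::Metric::L2);
//   if (rc != 0)
//   {
//       // handle build failure
//   }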

View File

@@ -0,0 +1,236 @@
#pragma once
#include "windows_customizations.h"
#include <cstring>
#include <cstdint>
namespace diskann
{
enum Metric
{
L2 = 0,
INNER_PRODUCT = 1,
COSINE = 2,
FAST_L2 = 3
};
template <typename T> class Distance
{
public:
DISKANN_DLLEXPORT Distance(diskann::Metric dist_metric) : _distance_metric(dist_metric)
{
}
// distance comparison function
DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, uint32_t length) const = 0;
// Needed only for COSINE-BYTE and INNER_PRODUCT-BYTE
DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, const float normA, const float normB,
uint32_t length) const;
// For MIPS, normalization adds an extra dimension to the vectors.
// This function lets callers know if the normalization process
// changes the dimension.
DISKANN_DLLEXPORT virtual uint32_t post_normalization_dimension(uint32_t orig_dimension) const;
DISKANN_DLLEXPORT virtual diskann::Metric get_metric() const;
// This is for efficiency. If no normalization is required, callers
// can simply ignore the preprocess_base_points() function.
DISKANN_DLLEXPORT virtual bool preprocessing_required() const;
// Check the preprocessing_required() function before calling this.
// Clients can call the function like this:
//
// if (metric->preprocessing_required()) {
//     // split the data into batches of batch_size and, for each batch, call:
//     metric->preprocess_base_points(data_batch, orig_dim, batch_size);
// }
//
// TODO: This does not take into account the case for SSD inner product
// where the dimensions change after normalization.
DISKANN_DLLEXPORT virtual void preprocess_base_points(T *original_data, const size_t orig_dim,
const size_t num_points);
// Invokes normalization for a single vector during search. The scratch space
// has to be created by the caller keeping track of the fact that
// normalization might change the dimension of the query vector.
DISKANN_DLLEXPORT virtual void preprocess_query(const T *query_vec, const size_t query_dim, T *scratch_query);
// If an algorithm has a requirement that some data be aligned to a certain
// boundary it can use this function to indicate that requirement. Currently,
// we are setting it to 8 because that works well for AVX2. If we have AVX512
// implementations of distance algos, they might have to set this to 16
// (depending on how they are implemented)
DISKANN_DLLEXPORT virtual size_t get_required_alignment() const;
// Providing a default implementation for the virtual destructor because we
// don't expect most metric implementations to need it.
DISKANN_DLLEXPORT virtual ~Distance() = default;
protected:
diskann::Metric _distance_metric;
size_t _alignment_factor = 8;
};
class DistanceCosineInt8 : public Distance<int8_t>
{
public:
DistanceCosineInt8() : Distance<int8_t>(diskann::Metric::COSINE)
{
}
DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t length) const;
};
class DistanceL2Int8 : public Distance<int8_t>
{
public:
DistanceL2Int8() : Distance<int8_t>(diskann::Metric::L2)
{
}
DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t size) const;
};
// AVX implementations. Borrowed from HNSW code.
class AVXDistanceL2Int8 : public Distance<int8_t>
{
public:
AVXDistanceL2Int8() : Distance<int8_t>(diskann::Metric::L2)
{
}
DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t length) const;
};
class DistanceCosineFloat : public Distance<float>
{
public:
DistanceCosineFloat() : Distance<float>(diskann::Metric::COSINE)
{
}
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const;
};
class DistanceL2Float : public Distance<float>
{
public:
DistanceL2Float() : Distance<float>(diskann::Metric::L2)
{
}
#ifdef _WINDOWS
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t size) const;
#else
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t size) const __attribute__((hot));
#endif
};
class AVXDistanceL2Float : public Distance<float>
{
public:
AVXDistanceL2Float() : Distance<float>(diskann::Metric::L2)
{
}
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const;
};
template <typename T> class SlowDistanceL2 : public Distance<T>
{
public:
SlowDistanceL2() : Distance<T>(diskann::Metric::L2)
{
}
DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, uint32_t length) const;
};
class SlowDistanceCosineUInt8 : public Distance<uint8_t>
{
public:
SlowDistanceCosineUInt8() : Distance<uint8_t>(diskann::Metric::COSINE)
{
}
DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, uint32_t length) const;
};
class DistanceL2UInt8 : public Distance<uint8_t>
{
public:
DistanceL2UInt8() : Distance<uint8_t>(diskann::Metric::L2)
{
}
DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, uint32_t size) const;
};
template <typename T> class DistanceInnerProduct : public Distance<T>
{
public:
DistanceInnerProduct() : Distance<T>(diskann::Metric::INNER_PRODUCT)
{
}
DistanceInnerProduct(diskann::Metric metric) : Distance<T>(metric)
{
}
inline float inner_product(const T *a, const T *b, unsigned size) const;
inline float compare(const T *a, const T *b, unsigned size) const
{
float result = inner_product(a, b, size);
// if (result < 0)
// return std::numeric_limits<float>::max();
// else
return -result;
}
};
template <typename T> class DistanceFastL2 : public DistanceInnerProduct<T>
{
// currently defined only for float.
// templated for future use.
public:
DistanceFastL2() : DistanceInnerProduct<T>(diskann::Metric::FAST_L2)
{
}
float norm(const T *a, unsigned size) const;
float compare(const T *a, const T *b, float norm, unsigned size) const;
};
class AVXDistanceInnerProductFloat : public Distance<float>
{
public:
AVXDistanceInnerProductFloat() : Distance<float>(diskann::Metric::INNER_PRODUCT)
{
}
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const;
};
class AVXNormalizedCosineDistanceFloat : public Distance<float>
{
private:
AVXDistanceInnerProductFloat _innerProduct;
protected:
void normalize_and_copy(const float *a, uint32_t length, float *a_norm) const;
public:
AVXNormalizedCosineDistanceFloat() : Distance<float>(diskann::Metric::COSINE)
{
}
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const override
{
// compare() on the inner product returns the negated dot product, so
// 1 + compare(a, b) yields 1 - cos(a, b): a distance in [0, 2] for
// normalized inputs.
return 1.0f + _innerProduct.compare(a, b, length);
}
DISKANN_DLLEXPORT virtual uint32_t post_normalization_dimension(uint32_t orig_dimension) const override;
DISKANN_DLLEXPORT virtual bool preprocessing_required() const override;
DISKANN_DLLEXPORT virtual void preprocess_base_points(float *original_data, const size_t orig_dim,
const size_t num_points) override;
DISKANN_DLLEXPORT virtual void preprocess_query(const float *query_vec, const size_t query_dim,
float *scratch_query_vector) override;
};
template <typename T> Distance<T> *get_distance_function(Metric m);
} // namespace diskann
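// --- Illustrative sketch (not part of the original header) ---
// The intended call sequence for a metric that preprocesses its inputs;
// data, query, dim and num_points are placeholders:
//
//   diskann::Distance<float> *metric = diskann::get_distance_function<float>(diskann::Metric::COSINE);
//   if (metric->preprocessing_required())
//   {
//       metric->preprocess_base_points(data, dim, num_points);
//   }
//   float d = metric->compare(query, data /* first base vector */, (uint32_t)dim);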

View File

@@ -0,0 +1,675 @@
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: embedding.proto
#ifndef GOOGLE_PROTOBUF_INCLUDED_embedding_2eproto
#define GOOGLE_PROTOBUF_INCLUDED_embedding_2eproto
#include <limits>
#include <string>
#include <google/protobuf/port_def.inc>
#if PROTOBUF_VERSION < 3012000
#error This file was generated by a newer version of protoc which is
#error incompatible with your Protocol Buffer headers. Please update
#error your headers.
#endif
#if 3012004 < PROTOBUF_MIN_PROTOC_VERSION
#error This file was generated by an older version of protoc which is
#error incompatible with your Protocol Buffer headers. Please
#error regenerate this file with a newer version of protoc.
#endif
#include <google/protobuf/port_undef.inc>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/arena.h>
#include <google/protobuf/arenastring.h>
#include <google/protobuf/generated_message_table_driven.h>
#include <google/protobuf/generated_message_util.h>
#include <google/protobuf/inlined_string_field.h>
#include <google/protobuf/metadata_lite.h>
#include <google/protobuf/generated_message_reflection.h>
#include <google/protobuf/message.h>
#include <google/protobuf/repeated_field.h> // IWYU pragma: export
#include <google/protobuf/extension_set.h> // IWYU pragma: export
#include <google/protobuf/unknown_field_set.h>
// @@protoc_insertion_point(includes)
#include <google/protobuf/port_def.inc>
#define PROTOBUF_INTERNAL_EXPORT_embedding_2eproto
PROTOBUF_NAMESPACE_OPEN
namespace internal {
class AnyMetadata;
} // namespace internal
PROTOBUF_NAMESPACE_CLOSE
// Internal implementation detail -- do not use these members.
struct TableStruct_embedding_2eproto {
static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[]
PROTOBUF_SECTION_VARIABLE(protodesc_cold);
static const ::PROTOBUF_NAMESPACE_ID::internal::AuxillaryParseTableField aux[]
PROTOBUF_SECTION_VARIABLE(protodesc_cold);
static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[2]
PROTOBUF_SECTION_VARIABLE(protodesc_cold);
static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[];
static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[];
static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[];
};
extern const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_embedding_2eproto;
namespace protoembedding {
class NodeEmbeddingRequest;
class NodeEmbeddingRequestDefaultTypeInternal;
extern NodeEmbeddingRequestDefaultTypeInternal _NodeEmbeddingRequest_default_instance_;
class NodeEmbeddingResponse;
class NodeEmbeddingResponseDefaultTypeInternal;
extern NodeEmbeddingResponseDefaultTypeInternal _NodeEmbeddingResponse_default_instance_;
} // namespace protoembedding
PROTOBUF_NAMESPACE_OPEN
template<> ::protoembedding::NodeEmbeddingRequest* Arena::CreateMaybeMessage<::protoembedding::NodeEmbeddingRequest>(Arena*);
template<> ::protoembedding::NodeEmbeddingResponse* Arena::CreateMaybeMessage<::protoembedding::NodeEmbeddingResponse>(Arena*);
PROTOBUF_NAMESPACE_CLOSE
namespace protoembedding {
// ===================================================================
class NodeEmbeddingRequest PROTOBUF_FINAL :
public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:protoembedding.NodeEmbeddingRequest) */ {
public:
inline NodeEmbeddingRequest() : NodeEmbeddingRequest(nullptr) {};
virtual ~NodeEmbeddingRequest();
NodeEmbeddingRequest(const NodeEmbeddingRequest& from);
NodeEmbeddingRequest(NodeEmbeddingRequest&& from) noexcept
: NodeEmbeddingRequest() {
*this = ::std::move(from);
}
inline NodeEmbeddingRequest& operator=(const NodeEmbeddingRequest& from) {
CopyFrom(from);
return *this;
}
inline NodeEmbeddingRequest& operator=(NodeEmbeddingRequest&& from) noexcept {
if (GetArena() == from.GetArena()) {
if (this != &from) InternalSwap(&from);
} else {
CopyFrom(from);
}
return *this;
}
static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() {
return GetDescriptor();
}
static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() {
return GetMetadataStatic().descriptor;
}
static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() {
return GetMetadataStatic().reflection;
}
static const NodeEmbeddingRequest& default_instance();
static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY
static inline const NodeEmbeddingRequest* internal_default_instance() {
return reinterpret_cast<const NodeEmbeddingRequest*>(
&_NodeEmbeddingRequest_default_instance_);
}
static constexpr int kIndexInFileMessages =
0;
friend void swap(NodeEmbeddingRequest& a, NodeEmbeddingRequest& b) {
a.Swap(&b);
}
inline void Swap(NodeEmbeddingRequest* other) {
if (other == this) return;
if (GetArena() == other->GetArena()) {
InternalSwap(other);
} else {
::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other);
}
}
void UnsafeArenaSwap(NodeEmbeddingRequest* other) {
if (other == this) return;
GOOGLE_DCHECK(GetArena() == other->GetArena());
InternalSwap(other);
}
// implements Message ----------------------------------------------
inline NodeEmbeddingRequest* New() const final {
return CreateMaybeMessage<NodeEmbeddingRequest>(nullptr);
}
NodeEmbeddingRequest* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final {
return CreateMaybeMessage<NodeEmbeddingRequest>(arena);
}
void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final;
void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final;
void CopyFrom(const NodeEmbeddingRequest& from);
void MergeFrom(const NodeEmbeddingRequest& from);
PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final;
bool IsInitialized() const final;
size_t ByteSizeLong() const final;
const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final;
::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize(
::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final;
int GetCachedSize() const final { return _cached_size_.Get(); }
private:
inline void SharedCtor();
inline void SharedDtor();
void SetCachedSize(int size) const final;
void InternalSwap(NodeEmbeddingRequest* other);
friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata;
static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() {
return "protoembedding.NodeEmbeddingRequest";
}
protected:
explicit NodeEmbeddingRequest(::PROTOBUF_NAMESPACE_ID::Arena* arena);
private:
static void ArenaDtor(void* object);
inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena);
public:
::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final;
private:
static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() {
::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_embedding_2eproto);
return ::descriptor_table_embedding_2eproto.file_level_metadata[kIndexInFileMessages];
}
public:
// nested types ----------------------------------------------------
// accessors -------------------------------------------------------
enum : int {
kNodeIdsFieldNumber = 1,
};
// repeated uint32 node_ids = 1;
int node_ids_size() const;
private:
int _internal_node_ids_size() const;
public:
void clear_node_ids();
private:
::PROTOBUF_NAMESPACE_ID::uint32 _internal_node_ids(int index) const;
const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >&
_internal_node_ids() const;
void _internal_add_node_ids(::PROTOBUF_NAMESPACE_ID::uint32 value);
::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >*
_internal_mutable_node_ids();
public:
::PROTOBUF_NAMESPACE_ID::uint32 node_ids(int index) const;
void set_node_ids(int index, ::PROTOBUF_NAMESPACE_ID::uint32 value);
void add_node_ids(::PROTOBUF_NAMESPACE_ID::uint32 value);
const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >&
node_ids() const;
::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >*
mutable_node_ids();
// @@protoc_insertion_point(class_scope:protoembedding.NodeEmbeddingRequest)
private:
class _Internal;
template <typename T> friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper;
typedef void InternalArenaConstructable_;
typedef void DestructorSkippable_;
::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 > node_ids_;
mutable std::atomic<int> _node_ids_cached_byte_size_;
mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_;
friend struct ::TableStruct_embedding_2eproto;
};
// -------------------------------------------------------------------
class NodeEmbeddingResponse PROTOBUF_FINAL :
public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:protoembedding.NodeEmbeddingResponse) */ {
public:
inline NodeEmbeddingResponse() : NodeEmbeddingResponse(nullptr) {};
virtual ~NodeEmbeddingResponse();
NodeEmbeddingResponse(const NodeEmbeddingResponse& from);
NodeEmbeddingResponse(NodeEmbeddingResponse&& from) noexcept
: NodeEmbeddingResponse() {
*this = ::std::move(from);
}
inline NodeEmbeddingResponse& operator=(const NodeEmbeddingResponse& from) {
CopyFrom(from);
return *this;
}
inline NodeEmbeddingResponse& operator=(NodeEmbeddingResponse&& from) noexcept {
if (GetArena() == from.GetArena()) {
if (this != &from) InternalSwap(&from);
} else {
CopyFrom(from);
}
return *this;
}
static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() {
return GetDescriptor();
}
static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() {
return GetMetadataStatic().descriptor;
}
static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() {
return GetMetadataStatic().reflection;
}
static const NodeEmbeddingResponse& default_instance();
static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY
static inline const NodeEmbeddingResponse* internal_default_instance() {
return reinterpret_cast<const NodeEmbeddingResponse*>(
&_NodeEmbeddingResponse_default_instance_);
}
static constexpr int kIndexInFileMessages =
1;
friend void swap(NodeEmbeddingResponse& a, NodeEmbeddingResponse& b) {
a.Swap(&b);
}
inline void Swap(NodeEmbeddingResponse* other) {
if (other == this) return;
if (GetArena() == other->GetArena()) {
InternalSwap(other);
} else {
::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other);
}
}
void UnsafeArenaSwap(NodeEmbeddingResponse* other) {
if (other == this) return;
GOOGLE_DCHECK(GetArena() == other->GetArena());
InternalSwap(other);
}
// implements Message ----------------------------------------------
inline NodeEmbeddingResponse* New() const final {
return CreateMaybeMessage<NodeEmbeddingResponse>(nullptr);
}
NodeEmbeddingResponse* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final {
return CreateMaybeMessage<NodeEmbeddingResponse>(arena);
}
void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final;
void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final;
void CopyFrom(const NodeEmbeddingResponse& from);
void MergeFrom(const NodeEmbeddingResponse& from);
PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final;
bool IsInitialized() const final;
size_t ByteSizeLong() const final;
const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final;
::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize(
::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final;
int GetCachedSize() const final { return _cached_size_.Get(); }
private:
inline void SharedCtor();
inline void SharedDtor();
void SetCachedSize(int size) const final;
void InternalSwap(NodeEmbeddingResponse* other);
friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata;
static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() {
return "protoembedding.NodeEmbeddingResponse";
}
protected:
explicit NodeEmbeddingResponse(::PROTOBUF_NAMESPACE_ID::Arena* arena);
private:
static void ArenaDtor(void* object);
inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena);
public:
::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final;
private:
static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() {
::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_embedding_2eproto);
return ::descriptor_table_embedding_2eproto.file_level_metadata[kIndexInFileMessages];
}
public:
// nested types ----------------------------------------------------
// accessors -------------------------------------------------------
enum : int {
kDimensionsFieldNumber = 2,
kMissingIdsFieldNumber = 3,
kEmbeddingsDataFieldNumber = 1,
};
// repeated int32 dimensions = 2;
int dimensions_size() const;
private:
int _internal_dimensions_size() const;
public:
void clear_dimensions();
private:
::PROTOBUF_NAMESPACE_ID::int32 _internal_dimensions(int index) const;
const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >&
_internal_dimensions() const;
void _internal_add_dimensions(::PROTOBUF_NAMESPACE_ID::int32 value);
::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >*
_internal_mutable_dimensions();
public:
::PROTOBUF_NAMESPACE_ID::int32 dimensions(int index) const;
void set_dimensions(int index, ::PROTOBUF_NAMESPACE_ID::int32 value);
void add_dimensions(::PROTOBUF_NAMESPACE_ID::int32 value);
const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >&
dimensions() const;
::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >*
mutable_dimensions();
// repeated uint32 missing_ids = 3;
int missing_ids_size() const;
private:
int _internal_missing_ids_size() const;
public:
void clear_missing_ids();
private:
::PROTOBUF_NAMESPACE_ID::uint32 _internal_missing_ids(int index) const;
const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >&
_internal_missing_ids() const;
void _internal_add_missing_ids(::PROTOBUF_NAMESPACE_ID::uint32 value);
::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >*
_internal_mutable_missing_ids();
public:
::PROTOBUF_NAMESPACE_ID::uint32 missing_ids(int index) const;
void set_missing_ids(int index, ::PROTOBUF_NAMESPACE_ID::uint32 value);
void add_missing_ids(::PROTOBUF_NAMESPACE_ID::uint32 value);
const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >&
missing_ids() const;
::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >*
mutable_missing_ids();
// bytes embeddings_data = 1;
void clear_embeddings_data();
const std::string& embeddings_data() const;
void set_embeddings_data(const std::string& value);
void set_embeddings_data(std::string&& value);
void set_embeddings_data(const char* value);
void set_embeddings_data(const void* value, size_t size);
std::string* mutable_embeddings_data();
std::string* release_embeddings_data();
void set_allocated_embeddings_data(std::string* embeddings_data);
GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for"
" string fields are deprecated and will be removed in a"
" future release.")
std::string* unsafe_arena_release_embeddings_data();
GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for"
" string fields are deprecated and will be removed in a"
" future release.")
void unsafe_arena_set_allocated_embeddings_data(
std::string* embeddings_data);
private:
const std::string& _internal_embeddings_data() const;
void _internal_set_embeddings_data(const std::string& value);
std::string* _internal_mutable_embeddings_data();
public:
// @@protoc_insertion_point(class_scope:protoembedding.NodeEmbeddingResponse)
private:
class _Internal;
template <typename T> friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper;
typedef void InternalArenaConstructable_;
typedef void DestructorSkippable_;
::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 > dimensions_;
mutable std::atomic<int> _dimensions_cached_byte_size_;
::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 > missing_ids_;
mutable std::atomic<int> _missing_ids_cached_byte_size_;
::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr embeddings_data_;
mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_;
friend struct ::TableStruct_embedding_2eproto;
};
// ===================================================================
// ===================================================================
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // __GNUC__
// NodeEmbeddingRequest
// repeated uint32 node_ids = 1;
inline int NodeEmbeddingRequest::_internal_node_ids_size() const {
return node_ids_.size();
}
inline int NodeEmbeddingRequest::node_ids_size() const {
return _internal_node_ids_size();
}
inline void NodeEmbeddingRequest::clear_node_ids() {
node_ids_.Clear();
}
inline ::PROTOBUF_NAMESPACE_ID::uint32 NodeEmbeddingRequest::_internal_node_ids(int index) const {
return node_ids_.Get(index);
}
inline ::PROTOBUF_NAMESPACE_ID::uint32 NodeEmbeddingRequest::node_ids(int index) const {
// @@protoc_insertion_point(field_get:protoembedding.NodeEmbeddingRequest.node_ids)
return _internal_node_ids(index);
}
inline void NodeEmbeddingRequest::set_node_ids(int index, ::PROTOBUF_NAMESPACE_ID::uint32 value) {
node_ids_.Set(index, value);
// @@protoc_insertion_point(field_set:protoembedding.NodeEmbeddingRequest.node_ids)
}
inline void NodeEmbeddingRequest::_internal_add_node_ids(::PROTOBUF_NAMESPACE_ID::uint32 value) {
node_ids_.Add(value);
}
inline void NodeEmbeddingRequest::add_node_ids(::PROTOBUF_NAMESPACE_ID::uint32 value) {
_internal_add_node_ids(value);
// @@protoc_insertion_point(field_add:protoembedding.NodeEmbeddingRequest.node_ids)
}
inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >&
NodeEmbeddingRequest::_internal_node_ids() const {
return node_ids_;
}
inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >&
NodeEmbeddingRequest::node_ids() const {
// @@protoc_insertion_point(field_list:protoembedding.NodeEmbeddingRequest.node_ids)
return _internal_node_ids();
}
inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >*
NodeEmbeddingRequest::_internal_mutable_node_ids() {
return &node_ids_;
}
inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >*
NodeEmbeddingRequest::mutable_node_ids() {
// @@protoc_insertion_point(field_mutable_list:protoembedding.NodeEmbeddingRequest.node_ids)
return _internal_mutable_node_ids();
}
// -------------------------------------------------------------------
// NodeEmbeddingResponse
// bytes embeddings_data = 1;
inline void NodeEmbeddingResponse::clear_embeddings_data() {
embeddings_data_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena());
}
inline const std::string& NodeEmbeddingResponse::embeddings_data() const {
// @@protoc_insertion_point(field_get:protoembedding.NodeEmbeddingResponse.embeddings_data)
return _internal_embeddings_data();
}
inline void NodeEmbeddingResponse::set_embeddings_data(const std::string& value) {
_internal_set_embeddings_data(value);
// @@protoc_insertion_point(field_set:protoembedding.NodeEmbeddingResponse.embeddings_data)
}
inline std::string* NodeEmbeddingResponse::mutable_embeddings_data() {
// @@protoc_insertion_point(field_mutable:protoembedding.NodeEmbeddingResponse.embeddings_data)
return _internal_mutable_embeddings_data();
}
inline const std::string& NodeEmbeddingResponse::_internal_embeddings_data() const {
return embeddings_data_.Get();
}
inline void NodeEmbeddingResponse::_internal_set_embeddings_data(const std::string& value) {
embeddings_data_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena());
}
inline void NodeEmbeddingResponse::set_embeddings_data(std::string&& value) {
embeddings_data_.Set(
&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena());
// @@protoc_insertion_point(field_set_rvalue:protoembedding.NodeEmbeddingResponse.embeddings_data)
}
inline void NodeEmbeddingResponse::set_embeddings_data(const char* value) {
GOOGLE_DCHECK(value != nullptr);
embeddings_data_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value),
GetArena());
// @@protoc_insertion_point(field_set_char:protoembedding.NodeEmbeddingResponse.embeddings_data)
}
inline void NodeEmbeddingResponse::set_embeddings_data(const void* value,
size_t size) {
embeddings_data_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(
reinterpret_cast<const char*>(value), size), GetArena());
// @@protoc_insertion_point(field_set_pointer:protoembedding.NodeEmbeddingResponse.embeddings_data)
}
inline std::string* NodeEmbeddingResponse::_internal_mutable_embeddings_data() {
return embeddings_data_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena());
}
inline std::string* NodeEmbeddingResponse::release_embeddings_data() {
// @@protoc_insertion_point(field_release:protoembedding.NodeEmbeddingResponse.embeddings_data)
return embeddings_data_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena());
}
inline void NodeEmbeddingResponse::set_allocated_embeddings_data(std::string* embeddings_data) {
if (embeddings_data != nullptr) {
} else {
}
embeddings_data_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), embeddings_data,
GetArena());
// @@protoc_insertion_point(field_set_allocated:protoembedding.NodeEmbeddingResponse.embeddings_data)
}
inline std::string* NodeEmbeddingResponse::unsafe_arena_release_embeddings_data() {
// @@protoc_insertion_point(field_unsafe_arena_release:protoembedding.NodeEmbeddingResponse.embeddings_data)
GOOGLE_DCHECK(GetArena() != nullptr);
return embeddings_data_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(),
GetArena());
}
inline void NodeEmbeddingResponse::unsafe_arena_set_allocated_embeddings_data(
std::string* embeddings_data) {
GOOGLE_DCHECK(GetArena() != nullptr);
if (embeddings_data != nullptr) {
} else {
}
embeddings_data_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(),
embeddings_data, GetArena());
// @@protoc_insertion_point(field_unsafe_arena_set_allocated:protoembedding.NodeEmbeddingResponse.embeddings_data)
}
// repeated int32 dimensions = 2;
inline int NodeEmbeddingResponse::_internal_dimensions_size() const {
return dimensions_.size();
}
inline int NodeEmbeddingResponse::dimensions_size() const {
return _internal_dimensions_size();
}
inline void NodeEmbeddingResponse::clear_dimensions() {
dimensions_.Clear();
}
inline ::PROTOBUF_NAMESPACE_ID::int32 NodeEmbeddingResponse::_internal_dimensions(int index) const {
return dimensions_.Get(index);
}
inline ::PROTOBUF_NAMESPACE_ID::int32 NodeEmbeddingResponse::dimensions(int index) const {
// @@protoc_insertion_point(field_get:protoembedding.NodeEmbeddingResponse.dimensions)
return _internal_dimensions(index);
}
inline void NodeEmbeddingResponse::set_dimensions(int index, ::PROTOBUF_NAMESPACE_ID::int32 value) {
dimensions_.Set(index, value);
// @@protoc_insertion_point(field_set:protoembedding.NodeEmbeddingResponse.dimensions)
}
inline void NodeEmbeddingResponse::_internal_add_dimensions(::PROTOBUF_NAMESPACE_ID::int32 value) {
dimensions_.Add(value);
}
inline void NodeEmbeddingResponse::add_dimensions(::PROTOBUF_NAMESPACE_ID::int32 value) {
_internal_add_dimensions(value);
// @@protoc_insertion_point(field_add:protoembedding.NodeEmbeddingResponse.dimensions)
}
inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >&
NodeEmbeddingResponse::_internal_dimensions() const {
return dimensions_;
}
inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >&
NodeEmbeddingResponse::dimensions() const {
// @@protoc_insertion_point(field_list:protoembedding.NodeEmbeddingResponse.dimensions)
return _internal_dimensions();
}
inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >*
NodeEmbeddingResponse::_internal_mutable_dimensions() {
return &dimensions_;
}
inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >*
NodeEmbeddingResponse::mutable_dimensions() {
// @@protoc_insertion_point(field_mutable_list:protoembedding.NodeEmbeddingResponse.dimensions)
return _internal_mutable_dimensions();
}
// repeated uint32 missing_ids = 3;
inline int NodeEmbeddingResponse::_internal_missing_ids_size() const {
return missing_ids_.size();
}
inline int NodeEmbeddingResponse::missing_ids_size() const {
return _internal_missing_ids_size();
}
inline void NodeEmbeddingResponse::clear_missing_ids() {
missing_ids_.Clear();
}
inline ::PROTOBUF_NAMESPACE_ID::uint32 NodeEmbeddingResponse::_internal_missing_ids(int index) const {
return missing_ids_.Get(index);
}
inline ::PROTOBUF_NAMESPACE_ID::uint32 NodeEmbeddingResponse::missing_ids(int index) const {
// @@protoc_insertion_point(field_get:protoembedding.NodeEmbeddingResponse.missing_ids)
return _internal_missing_ids(index);
}
inline void NodeEmbeddingResponse::set_missing_ids(int index, ::PROTOBUF_NAMESPACE_ID::uint32 value) {
missing_ids_.Set(index, value);
// @@protoc_insertion_point(field_set:protoembedding.NodeEmbeddingResponse.missing_ids)
}
inline void NodeEmbeddingResponse::_internal_add_missing_ids(::PROTOBUF_NAMESPACE_ID::uint32 value) {
missing_ids_.Add(value);
}
inline void NodeEmbeddingResponse::add_missing_ids(::PROTOBUF_NAMESPACE_ID::uint32 value) {
_internal_add_missing_ids(value);
// @@protoc_insertion_point(field_add:protoembedding.NodeEmbeddingResponse.missing_ids)
}
inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >&
NodeEmbeddingResponse::_internal_missing_ids() const {
return missing_ids_;
}
inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >&
NodeEmbeddingResponse::missing_ids() const {
// @@protoc_insertion_point(field_list:protoembedding.NodeEmbeddingResponse.missing_ids)
return _internal_missing_ids();
}
inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >*
NodeEmbeddingResponse::_internal_mutable_missing_ids() {
return &missing_ids_;
}
inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >*
NodeEmbeddingResponse::mutable_missing_ids() {
// @@protoc_insertion_point(field_mutable_list:protoembedding.NodeEmbeddingResponse.missing_ids)
return _internal_mutable_missing_ids();
}
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif // __GNUC__
// -------------------------------------------------------------------
// @@protoc_insertion_point(namespace_scope)
} // namespace protoembedding
// @@protoc_insertion_point(global_scope)
#include <google/protobuf/port_undef.inc>
#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_embedding_2eproto
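// Illustrative usage of the generated messages above (a sketch, not part of
// the generated header). It assumes, based only on the field names, that
// embeddings_data packs the vectors back-to-back as float32 with per-vector
// lengths given by dimensions, and that missing_ids lists node ids with no
// stored embedding.
#include <cstdint>
#include <cstring>
#include <vector>

inline protoembedding::NodeEmbeddingRequest make_request(const std::vector<uint32_t> &ids)
{
    protoembedding::NodeEmbeddingRequest req;
    for (uint32_t id : ids)
        req.add_node_ids(id); // repeated uint32 node_ids = 1
    return req;
}

inline std::vector<std::vector<float>> unpack_embeddings(const protoembedding::NodeEmbeddingResponse &resp)
{
    std::vector<std::vector<float>> vectors;
    const char *cursor = resp.embeddings_data().data();
    for (int i = 0; i < resp.dimensions_size(); ++i)
    {
        std::vector<float> vec(resp.dimensions(i));
        std::memcpy(vec.data(), cursor, vec.size() * sizeof(float));
        cursor += vec.size() * sizeof(float);
        vectors.push_back(std::move(vec));
    }
    return vectors;
}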


@@ -0,0 +1,118 @@
#pragma once
#include <vector>
#include <memory>
#include <string>
#include <stdexcept>
#ifdef PYBIND11_EMBEDDED
#include <pybind11/embed.h>
#else
#include <pybind11/pybind11.h>
#endif
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
namespace py = pybind11;
namespace diskann
{
class PYBIND11_EXPORT EmbeddingComputer
{
public:
static EmbeddingComputer &getInstance()
{
static EmbeddingComputer instance;
return instance;
}
// Non-copyable: copying the singleton would run the destructor twice and
// tear down the embedded interpreter prematurely.
EmbeddingComputer(const EmbeddingComputer &) = delete;
EmbeddingComputer &operator=(const EmbeddingComputer &) = delete;
void initialize(const std::string &model_path)
{
try
{
py::module_ sys = py::module_::import("sys");
py::module_ os = py::module_::import("os");
// Add the directory containing embedd_micro.py to Python path
std::string micro_dir = "micro";
sys.attr("path").attr("append")(micro_dir);
// Import our module
py::module_ embedd = py::module_::import("embedd_micro");
// Create benchmark config
py::object config = embedd.attr("BenchmarkConfig")(model_path, // model_path
py::list(), // empty batch_sizes
256, // seq_length
1, // num_runs
true, // use_fp16
false, // use_cuda_graphs
false // use_flash_attention
);
// Create benchmark instance
benchmark = embedd.attr("Benchmark")(config);
}
catch (const std::exception &e)
{
throw std::runtime_error("Failed to initialize Python embedding computer: " + std::string(e.what()));
}
}
template <typename T>
std::vector<float> computeEmbeddings(const std::vector<T *> &points, size_t dim, size_t batch_size = 32)
{
try
{
// Convert points to numpy array
std::vector<T> flattened_points;
flattened_points.reserve(points.size() * dim);
for (const auto &point : points)
{
flattened_points.insert(flattened_points.end(), point, point + dim);
}
py::array_t<T> points_array({static_cast<py::ssize_t>(points.size()), static_cast<py::ssize_t>(dim)},
flattened_points.data());
// Call compute_embeddings
py::object result = benchmark.attr("compute_embeddings")(points_array, batch_size);
// Convert result back to C++
py::array_t<float> np_result = result.cast<py::array_t<float>>();
py::buffer_info buf = np_result.request();
float *ptr = static_cast<float *>(buf.ptr);
return std::vector<float>(ptr, ptr + buf.size);
}
catch (const std::exception &e)
{
throw std::runtime_error("Failed to compute embeddings: " + std::string(e.what()));
}
}
private:
EmbeddingComputer()
{
#ifdef PYBIND11_EMBEDDED
if (!Py_IsInitialized())
{
py::initialize_interpreter();
}
#endif
}
~EmbeddingComputer()
{
#ifdef PYBIND11_EMBEDDED
if (Py_IsInitialized())
{
py::finalize_interpreter();
}
#endif
}
py::object benchmark;
};
} // namespace diskann
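// Usage sketch (illustrative only; "embedding_computer.h" and "model.onnx"
// are assumed names, and the 128-dim inputs are placeholders): initialize the
// singleton once, then batch two vectors through the Python Benchmark.
#include "embedding_computer.h" // hypothetical name for the header above
#include <cstddef>
#include <vector>

int main()
{
    auto &computer = diskann::EmbeddingComputer::getInstance();
    computer.initialize("model.onnx"); // imports micro/embedd_micro.py and builds a Benchmark
    const size_t dim = 128;
    std::vector<float> a(dim, 0.5f), b(dim, 1.0f);
    std::vector<float *> points{a.data(), b.data()};
    std::vector<float> embeddings = computer.computeEmbeddings(points, dim, /*batch_size=*/32);
    return embeddings.empty() ? 1 : 0;
}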


@@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <stdexcept>
namespace diskann
{
class NotImplementedException : public std::logic_error
{
public:
NotImplementedException() : std::logic_error("Function not yet implemented.")
{
}
};
} // namespace diskann
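// Usage sketch: the exception above is intended for stubbing code paths that
// a concrete implementation does not support yet.
inline void not_supported_yet()
{
    throw diskann::NotImplementedException();
}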


@@ -0,0 +1,221 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <algorithm>
#include <fcntl.h>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <random>
#include <set>
#include <tuple>
#include <string>
#include <tsl/robin_map.h>
#include <tsl/robin_set.h>
#ifdef __APPLE__
#else
#include <malloc.h>
#endif
#ifdef _WINDOWS
#include <Windows.h>
typedef HANDLE FileHandle;
#else
#include <unistd.h>
typedef int FileHandle;
#endif
#ifndef _WINDOWS
#include <sys/uio.h>
#endif
#include "cached_io.h"
#include "common_includes.h"
#include "memory_mapper.h"
#include "utils.h"
#include "windows_customizations.h"
// custom types (for readability)
typedef tsl::robin_set<std::string> label_set;
typedef std::string path;
// structs for returning multiple items from a function
typedef std::tuple<std::vector<label_set>, tsl::robin_map<std::string, uint32_t>, tsl::robin_set<std::string>>
parse_label_file_return_values;
typedef std::tuple<std::vector<std::vector<uint32_t>>, uint64_t> load_label_index_return_values;
namespace diskann
{
template <typename T>
DISKANN_DLLEXPORT void generate_label_indices(path input_data_path, path final_index_path_prefix, label_set all_labels,
unsigned R, unsigned L, float alpha, unsigned num_threads);
DISKANN_DLLEXPORT load_label_index_return_values load_label_index(path label_index_path,
uint32_t label_number_of_points);
template <typename LabelT>
DISKANN_DLLEXPORT std::tuple<std::vector<std::vector<LabelT>>, tsl::robin_set<LabelT>> parse_formatted_label_file(
path label_file);
DISKANN_DLLEXPORT parse_label_file_return_values parse_label_file(path label_data_path, std::string universal_label);
template <typename T>
DISKANN_DLLEXPORT tsl::robin_map<std::string, std::vector<uint32_t>> generate_label_specific_vector_files_compat(
path input_data_path, tsl::robin_map<std::string, uint32_t> labels_to_number_of_points,
std::vector<label_set> point_ids_to_labels, label_set all_labels);
/*
* For each label, generates a file containing all vectors that have said label.
* Also copies data from original bin file to new dimension-aligned file.
*
* Utilizes POSIX functions mmap and writev in order to minimize memory
* overhead, so we include an STL version as well.
*
* Each data file is saved under the following format:
* input_data_path + "_" + label
*/
#ifndef _WINDOWS
template <typename T>
inline tsl::robin_map<std::string, std::vector<uint32_t>> generate_label_specific_vector_files(
path input_data_path, tsl::robin_map<std::string, uint32_t> labels_to_number_of_points,
std::vector<label_set> point_ids_to_labels, label_set all_labels)
{
#ifndef _WINDOWS
auto file_writing_timer = std::chrono::high_resolution_clock::now();
diskann::MemoryMapper input_data(input_data_path);
char *input_start = input_data.getBuf();
uint32_t number_of_points, dimension;
std::memcpy(&number_of_points, input_start, sizeof(uint32_t));
std::memcpy(&dimension, input_start + sizeof(uint32_t), sizeof(uint32_t));
const uint32_t VECTOR_SIZE = dimension * sizeof(T);
const size_t METADATA = 2 * sizeof(uint32_t);
if (number_of_points != point_ids_to_labels.size())
{
std::cerr << "Error: number of points in labels file and data file differ." << std::endl;
throw diskann::ANNException("number of points in labels file and data file differ", -1, __FUNCSIG__, __FILE__, __LINE__);
}
tsl::robin_map<std::string, iovec *> label_to_iovec_map;
tsl::robin_map<std::string, uint32_t> label_to_curr_iovec;
tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id;
// setup iovec list for each label
for (const auto &lbl : all_labels)
{
iovec *label_iovecs = (iovec *)malloc(labels_to_number_of_points[lbl] * sizeof(iovec));
if (label_iovecs == nullptr)
{
throw diskann::ANNException("failed to allocate iovec buffer for label " + lbl, -1, __FUNCSIG__, __FILE__, __LINE__);
}
label_to_iovec_map[lbl] = label_iovecs;
label_to_curr_iovec[lbl] = 0;
label_id_to_orig_id[lbl].reserve(labels_to_number_of_points[lbl]);
}
// each point added to corresponding per-label iovec list
for (uint32_t point_id = 0; point_id < number_of_points; point_id++)
{
char *curr_point = input_start + METADATA + (VECTOR_SIZE * point_id);
iovec curr_iovec;
curr_iovec.iov_base = curr_point;
curr_iovec.iov_len = VECTOR_SIZE;
for (const auto &lbl : point_ids_to_labels[point_id])
{
*(label_to_iovec_map[lbl] + label_to_curr_iovec[lbl]) = curr_iovec;
label_to_curr_iovec[lbl]++;
label_id_to_orig_id[lbl].push_back(point_id);
}
}
// write each label iovec to resp. file
for (const auto &lbl : all_labels)
{
int label_input_data_fd;
path curr_label_input_data_path(input_data_path + "_" + lbl);
uint32_t curr_num_pts = labels_to_number_of_points[lbl];
label_input_data_fd =
open(curr_label_input_data_path.c_str(), O_CREAT | O_WRONLY | O_TRUNC | O_APPEND, (mode_t)0644);
if (label_input_data_fd == -1)
throw diskann::ANNException("failed to open file " + curr_label_input_data_path, -1, __FUNCSIG__, __FILE__, __LINE__);
// write metadata
uint32_t metadata[2] = {curr_num_pts, dimension};
ssize_t return_value = write(label_input_data_fd, metadata, sizeof(uint32_t) * 2);
if (return_value == -1)
{
throw diskann::ANNException("failed to write metadata to file " + curr_label_input_data_path, -1, __FUNCSIG__, __FILE__, __LINE__);
}
// limits on number of iovec structs per writev means we need to perform
// multiple writevs
size_t i = 0;
while (curr_num_pts > IOV_MAX)
{
return_value = writev(label_input_data_fd, (label_to_iovec_map[lbl] + (IOV_MAX * i)), IOV_MAX);
if (return_value == -1)
{
close(label_input_data_fd);
throw diskann::ANNException("writev failed for file " + curr_label_input_data_path, -1, __FUNCSIG__, __FILE__, __LINE__);
}
curr_num_pts -= IOV_MAX;
i += 1;
}
return_value = writev(label_input_data_fd, (label_to_iovec_map[lbl] + (IOV_MAX * i)), curr_num_pts);
if (return_value == -1)
{
close(label_input_data_fd);
throw diskann::ANNException("writev failed for file " + curr_label_input_data_path, -1, __FUNCSIG__, __FILE__, __LINE__);
}
free(label_to_iovec_map[lbl]);
close(label_input_data_fd);
}
std::chrono::duration<double> file_writing_time = std::chrono::high_resolution_clock::now() - file_writing_timer;
std::cout << "generated " << all_labels.size() << " label-specific vector files for index building in time "
<< file_writing_time.count() << "\n"
<< std::endl;
return label_id_to_orig_id;
#endif
}
#endif
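// Reader sketch for the per-label files written above (illustrative only):
// each file is named input_data_path + "_" + label and holds two uint32 values
// (point count, dimension) followed by count * dimension values of type T.
template <typename T>
inline std::vector<T> read_label_specific_vector_file(const path &input_data_path, const std::string &label,
                                                      uint32_t &npts, uint32_t &dim)
{
    std::ifstream reader(input_data_path + "_" + label, std::ios::binary);
    reader.read(reinterpret_cast<char *>(&npts), sizeof(uint32_t));
    reader.read(reinterpret_cast<char *>(&dim), sizeof(uint32_t));
    std::vector<T> data((size_t)npts * dim);
    reader.read(reinterpret_cast<char *>(data.data()), data.size() * sizeof(T));
    return data;
}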
inline std::vector<uint32_t> loadTags(const std::string &tags_file, const std::string &base_file)
{
const bool tags_enabled = !tags_file.empty();
std::vector<uint32_t> location_to_tag;
if (tags_enabled)
{
size_t tag_file_ndims, tag_file_npts;
std::uint32_t *tag_data;
diskann::load_bin<std::uint32_t>(tags_file, tag_data, tag_file_npts, tag_file_ndims);
if (tag_file_ndims != 1)
{
diskann::cerr << "tags file error" << std::endl;
throw diskann::ANNException("tag file error", -1, __FUNCSIG__, __FILE__, __LINE__);
}
// check if the point count match
size_t base_file_npts, base_file_ndims;
diskann::get_bin_metadata(base_file, base_file_npts, base_file_ndims);
if (base_file_npts != tag_file_npts)
{
diskann::cerr << "point num in tags file mismatch" << std::endl;
throw diskann::ANNException("point num in tags file mismatch", -1, __FUNCSIG__, __FILE__, __LINE__);
}
location_to_tag.assign(tag_data, tag_data + tag_file_npts);
delete[] tag_data;
}
return location_to_tag;
}
} // namespace diskann


@@ -0,0 +1,89 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <shared_mutex>
#include <memory>
#include "tsl/robin_map.h"
#include "tsl/robin_set.h"
#include "tsl/sparse_map.h"
// #include "boost/dynamic_bitset.hpp"
#include "abstract_data_store.h"
#include "distance.h"
#include "natural_number_map.h"
#include "natural_number_set.h"
#include "aligned_file_reader.h"
namespace diskann
{
template <typename data_t> class InMemDataStore : public AbstractDataStore<data_t>
{
public:
InMemDataStore(const location_t capacity, const size_t dim, std::unique_ptr<Distance<data_t>> distance_fn);
virtual ~InMemDataStore();
virtual location_t load(const std::string &filename) override;
virtual size_t save(const std::string &filename, const location_t num_points) override;
virtual size_t get_aligned_dim() const override;
// Populate internal data from unaligned data while doing alignment and any
// normalization that is required.
virtual void populate_data(const data_t *vectors, const location_t num_pts) override;
virtual void populate_data(const std::string &filename, const size_t offset) override;
virtual void extract_data_to_bin(const std::string &filename, const location_t num_pts) override;
virtual void get_vector(const location_t i, data_t *target) const override;
virtual void set_vector(const location_t i, const data_t *const vector) override;
virtual void prefetch_vector(const location_t loc) override;
virtual void move_vectors(const location_t old_location_start, const location_t new_location_start,
const location_t num_points) override;
virtual void copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points) override;
virtual void preprocess_query(const data_t *query, AbstractScratch<data_t> *query_scratch) const override;
virtual float get_distance(const data_t *preprocessed_query, const location_t loc) const override;
virtual float get_distance(const location_t loc1, const location_t loc2) const override;
virtual void get_distance(const data_t *preprocessed_query, const location_t *locations,
const uint32_t location_count, float *distances,
AbstractScratch<data_t> *scratch) const override;
virtual void get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
std::vector<float> &distances, AbstractScratch<data_t> *scratch_space) const override;
virtual location_t calculate_medoid() const override;
virtual Distance<data_t> *get_dist_fn() const override;
virtual size_t get_alignment_factor() const override;
protected:
virtual location_t expand(const location_t new_size) override;
virtual location_t shrink(const location_t new_size) override;
virtual location_t load_impl(const std::string &filename);
#ifdef EXEC_ENV_OLS
virtual location_t load_impl(AlignedFileReader &reader);
#endif
private:
data_t *_data = nullptr;
size_t _aligned_dim;
// It may seem weird to put the distance metric along with the data store class,
// but this gives us perf benefits as the datastore can do distance
// computations during search and compute norms of vectors internally without
// having to copy data back and forth.
std::unique_ptr<Distance<data_t>> _distance_fn;
// in case we need to save vector norms for optimization
std::shared_ptr<float[]> _pre_computed_norms;
};
} // namespace diskann
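// Usage sketch (illustrative; "base.bin" and the sizes are placeholders):
// build a float store over an L2 distance function, load a DiskANN bin file,
// and copy the first vector out. get_distance_function is assumed here to be
// the Metric-to-implementation factory from distance.h, included above.
inline void in_mem_data_store_example()
{
    std::unique_ptr<diskann::Distance<float>> dist(diskann::get_distance_function<float>(diskann::L2));
    diskann::InMemDataStore<float> store(/*capacity=*/10000, /*dim=*/128, std::move(dist));
    diskann::location_t loaded = store.load("base.bin");
    std::vector<float> vec(store.get_aligned_dim());
    if (loaded > 0)
        store.get_vector(0, vec.data());
}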


@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include "abstract_graph_store.h"
namespace diskann
{
class InMemGraphStore : public AbstractGraphStore
{
public:
InMemGraphStore(const size_t total_pts, const size_t reserve_graph_degree);
// returns tuple of <nodes_read, start, num_frozen_points>
virtual std::tuple<uint32_t, uint32_t, size_t> load(const std::string &index_path_prefix,
const size_t num_points) override;
virtual int store(const std::string &index_path_prefix, const size_t num_points, const size_t num_frozen_points,
const uint32_t start) override;
virtual const std::vector<location_t> &get_neighbours(const location_t i) const override;
virtual void add_neighbour(const location_t i, location_t neighbour_id) override;
virtual void clear_neighbours(const location_t i) override;
virtual void swap_neighbours(const location_t a, location_t b) override;
virtual void set_neighbours(const location_t i, std::vector<location_t> &neighbors) override;
virtual size_t resize_graph(const size_t new_size) override;
virtual void clear_graph() override;
virtual size_t get_max_range_of_graph() override;
virtual uint32_t get_max_observed_degree() override;
protected:
virtual std::tuple<uint32_t, uint32_t, size_t> load_impl(const std::string &filename, size_t expected_num_points);
#ifdef EXEC_ENV_OLS
virtual std::tuple<uint32_t, uint32_t, size_t> load_impl(AlignedFileReader &reader, size_t expected_num_points);
#endif
int save_graph(const std::string &index_path_prefix, const size_t active_points, const size_t num_frozen_points,
const uint32_t start);
private:
size_t _max_range_of_graph = 0;
uint32_t _max_observed_degree = 0;
std::vector<std::vector<uint32_t>> _graph;
};
} // namespace diskann
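// Usage sketch (sizes are placeholders): reserve a graph for 1000 nodes with
// degree-64 headroom, set one adjacency list, then append a neighbour to it.
inline size_t in_mem_graph_store_example()
{
    diskann::InMemGraphStore graph(/*total_pts=*/1000, /*reserve_graph_degree=*/64);
    std::vector<diskann::location_t> nbrs{1, 2, 3};
    graph.set_neighbours(0, nbrs);
    graph.add_neighbour(0, 4);
    return graph.get_neighbours(0).size(); // expected: 4
}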


@@ -0,0 +1,452 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include "common_includes.h"
#ifdef EXEC_ENV_OLS
#include "aligned_file_reader.h"
#endif
#include "distance.h"
#include "locking.h"
#include "natural_number_map.h"
#include "natural_number_set.h"
#include "neighbor.h"
#include "parameters.h"
#include "utils.h"
#include "windows_customizations.h"
#include "scratch.h"
#include "in_mem_data_store.h"
#include "in_mem_graph_store.h"
#include "abstract_index.h"
#include "quantized_distance.h"
#include "pq_data_store.h"
#define OVERHEAD_FACTOR 1.1
#define EXPAND_IF_FULL 0
#define DEFAULT_MAXC 750
namespace diskann
{
inline double estimate_ram_usage(size_t size, uint32_t dim, uint32_t datasize, uint32_t degree)
{
double size_of_data = ((double)size) * ROUND_UP(dim, 8) * datasize;
double size_of_graph = ((double)size) * degree * sizeof(uint32_t) * defaults::GRAPH_SLACK_FACTOR;
double size_of_locks = ((double)size) * sizeof(non_recursive_mutex);
double size_of_outer_vector = ((double)size) * sizeof(ptrdiff_t);
return OVERHEAD_FACTOR * (size_of_data + size_of_graph + size_of_locks + size_of_outer_vector);
}
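// Worked example (illustrative): a rough build-time budget for one million
// 128-dimensional float vectors at graph degree 64. ROUND_UP(128, 8) == 128,
// so the data term alone is 10^6 * 128 * 4 bytes (~512 MB); the graph, lock
// and bookkeeping terms are added before everything is scaled by
// OVERHEAD_FACTOR.
inline double example_build_ram_bytes()
{
    return estimate_ram_usage(1000000, 128, sizeof(float), 64);
}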
template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> class Index : public AbstractIndex
{
/**************************************************************************
*
* Public functions acquire one or more of _update_lock, _consolidate_lock,
* _tag_lock, _delete_lock before calling protected functions which DO NOT
* acquire these locks. They might acquire locks on _locks[i]
*
**************************************************************************/
public:
// Constructor for bulk operations and for creating the index object solely
// for loading a pre-existing index.
DISKANN_DLLEXPORT Index(const IndexConfig &index_config, std::shared_ptr<AbstractDataStore<T>> data_store,
std::unique_ptr<AbstractGraphStore> graph_store,
std::shared_ptr<AbstractDataStore<T>> pq_data_store = nullptr);
// Constructor for incremental index
DISKANN_DLLEXPORT Index(Metric m, const size_t dim, const size_t max_points,
const std::shared_ptr<IndexWriteParameters> index_parameters,
const std::shared_ptr<IndexSearchParams> index_search_params,
const size_t num_frozen_pts = 0, const bool dynamic_index = false,
const bool enable_tags = false, const bool concurrent_consolidate = false,
const bool pq_dist_build = false, const size_t num_pq_chunks = 0,
const bool use_opq = false, const bool filtered_index = false);
DISKANN_DLLEXPORT ~Index();
// Saves graph, data, metadata and associated tags.
DISKANN_DLLEXPORT void save(const char *filename, bool compact_before_save = false) override;
// Load functions
#ifdef EXEC_ENV_OLS
DISKANN_DLLEXPORT void load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l);
#else
// Reads the number of frozen points from graph's metadata file section.
DISKANN_DLLEXPORT static size_t get_graph_num_frozen_points(const std::string &graph_file);
DISKANN_DLLEXPORT void load(const char *index_file, uint32_t num_threads, uint32_t search_l) override;
#endif
// get some private variables
DISKANN_DLLEXPORT size_t get_num_points();
DISKANN_DLLEXPORT size_t get_max_points();
DISKANN_DLLEXPORT bool detect_common_filters(uint32_t point_id, bool search_invocation,
const std::vector<LabelT> &incoming_labels);
// Batch build from a file. Optionally pass tags vector.
DISKANN_DLLEXPORT void build(const char *filename, const size_t num_points_to_load,
const std::vector<TagT> &tags = std::vector<TagT>());
// Batch build from a file. Optionally pass tags file.
DISKANN_DLLEXPORT void build(const char *filename, const size_t num_points_to_load, const char *tag_filename);
// Batch build from a data array, which must pad vectors to aligned_dim
DISKANN_DLLEXPORT void build(const T *data, const size_t num_points_to_load, const std::vector<TagT> &tags);
// Based on filter params builds a filtered or unfiltered index
DISKANN_DLLEXPORT void build(const std::string &data_file, const size_t num_points_to_load,
IndexFilterParams &filter_params) override;
// Filtered Support
DISKANN_DLLEXPORT void build_filtered_index(const char *filename, const std::string &label_file,
const size_t num_points_to_load,
const std::vector<TagT> &tags = std::vector<TagT>());
DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);
// Get converted integer label from string to int map (_label_map)
DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &raw_label);
// Set starting point of an index before inserting any points incrementally.
// The data count should be equal to _num_frozen_pts * _aligned_dim.
DISKANN_DLLEXPORT void set_start_points(const T *data, size_t data_count);
// Set starting points to random points on a sphere of certain radius.
// A fixed random seed can be specified for scenarios where it's important
// to have higher consistency between index builds.
DISKANN_DLLEXPORT void set_start_points_at_random(T radius, uint32_t random_seed = 0);
// For FastL2 search on a static index, we interleave the data with the graph
DISKANN_DLLEXPORT void optimize_index_layout() override;
// For FastL2 search on optimized layout
DISKANN_DLLEXPORT void search_with_optimized_layout(const T *query, size_t K, size_t L, uint32_t *indices);
// Added search overload that takes L as parameter, so that we
// can customize L on a per-query basis without tampering with "Parameters"
template <typename IDType>
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search(const T *query, const size_t K, const uint32_t L,
IDType *indices, float *distances = nullptr);
// Initialize space for res_vectors before calling.
DISKANN_DLLEXPORT size_t search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags,
float *distances, std::vector<T *> &res_vectors, bool use_filters = false,
const std::string filter_label = "");
// Filter support search
template <typename IndexType>
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search_with_filters(const T *query, const LabelT &filter_label,
const size_t K, const uint32_t L,
IndexType *indices, float *distances);
// Will fail if tag already in the index or if tag=0.
DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag);
// Will fail if tag already in the index or if tag=0.
DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag, const std::vector<LabelT> &label);
// Call this before issuing deletions; it sets the relevant flags.
DISKANN_DLLEXPORT int enable_delete();
// Record deleted point now and restructure graph later. Return -1 if tag
// not found, 0 if OK.
DISKANN_DLLEXPORT int lazy_delete(const TagT &tag);
// Record deleted points now and restructure graph later. Add to failed_tags
// if tag not found.
DISKANN_DLLEXPORT void lazy_delete(const std::vector<TagT> &tags, std::vector<TagT> &failed_tags);
// Call after a series of lazy deletions
// Returns number of live points left after consolidation
// If _conc_consolidates is set in the ctor, then this call can be invoked
// alongside inserts and lazy deletes, else it acquires _update_lock
DISKANN_DLLEXPORT consolidation_report consolidate_deletes(const IndexWriteParameters &parameters) override;
DISKANN_DLLEXPORT void prune_all_neighbors(const uint32_t max_degree, const uint32_t max_occlusion,
const float alpha);
DISKANN_DLLEXPORT bool is_index_saved();
// repositions frozen points to the end of _data - if they have been moved
// during deletion
DISKANN_DLLEXPORT void reposition_frozen_point_to_end();
DISKANN_DLLEXPORT void reposition_points(uint32_t old_location_start, uint32_t new_location_start,
uint32_t num_locations);
// DISKANN_DLLEXPORT void save_index_as_one_file(bool flag);
DISKANN_DLLEXPORT void get_active_tags(tsl::robin_set<TagT> &active_tags);
// memory should be allocated for vec before calling this function
DISKANN_DLLEXPORT int get_vector_by_tag(TagT &tag, T *vec);
DISKANN_DLLEXPORT void print_status();
DISKANN_DLLEXPORT void count_nodes_at_bfs_levels();
// This variable MUST be updated if the number of entries in the metadata
// change.
DISKANN_DLLEXPORT static const int METADATA_ROWS = 5;
DISKANN_DLLEXPORT void get_degree_stats(size_t &max_deg, size_t &min_deg, size_t &avg_deg, size_t &cnt_deg);
DISKANN_DLLEXPORT void dump_degree_stats(std::string filename);
// ********************************
//
// Internals of the library
//
// ********************************
protected:
// overload of abstract index virtual methods
virtual void _build(const DataType &data, const size_t num_points_to_load, TagVector &tags) override;
virtual std::pair<uint32_t, uint32_t> _search(const DataType &query, const size_t K, const uint32_t L,
std::any &indices, float *distances = nullptr) override;
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query,
const std::string &filter_label_raw, const size_t K,
const uint32_t L, std::any &indices,
float *distances) override;
virtual int _insert_point(const DataType &data_point, const TagType tag) override;
virtual int _insert_point(const DataType &data_point, const TagType tag, Labelvector &labels) override;
virtual int _lazy_delete(const TagType &tag) override;
virtual void _lazy_delete(TagVector &tags, TagVector &failed_tags) override;
virtual void _get_active_tags(TagRobinSet &active_tags) override;
virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) override;
virtual int _get_vector_by_tag(TagType &tag, DataType &vec) override;
virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) override;
virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags,
float *distances, DataVector &res_vectors, bool use_filters = false,
const std::string filter_label = "") override;
virtual void _set_universal_label(const LabelType universal_label) override;
// No copy/assign.
Index(const Index<T, TagT, LabelT> &) = delete;
Index<T, TagT, LabelT> &operator=(const Index<T, TagT, LabelT> &) = delete;
// Use after _data and _nd have been populated
// Acquire exclusive _update_lock before calling
void build_with_data_populated(const std::vector<TagT> &tags);
// generates 1 frozen point that will never be deleted from the graph
// This is not visible to the user
void generate_frozen_point();
// determines the navigating node of the graph by calculating the medoid of the data
uint32_t calculate_entry_point();
void parse_label_file(const std::string &label_file, size_t &num_pts_labels);
std::unordered_map<std::string, LabelT> load_label_map(const std::string &map_file);
// Returns the locations of start point and frozen points suitable for use
// with iterate_to_fixed_point.
std::vector<uint32_t> get_init_ids();
// The query to use is placed in scratch->aligned_query
std::pair<uint32_t, uint32_t> iterate_to_fixed_point(InMemQueryScratch<T> *scratch, const uint32_t Lindex,
const std::vector<uint32_t> &init_ids, bool use_filter,
const std::vector<LabelT> &filters, bool search_invocation);
void search_for_point_and_prune(int location, uint32_t Lindex, std::vector<uint32_t> &pruned_list,
InMemQueryScratch<T> *scratch, bool use_filter = false,
uint32_t filteredLindex = 0);
void prune_neighbors(const uint32_t location, std::vector<Neighbor> &pool, std::vector<uint32_t> &pruned_list,
InMemQueryScratch<T> *scratch);
void prune_neighbors(const uint32_t location, std::vector<Neighbor> &pool, const uint32_t range,
const uint32_t max_candidate_size, const float alpha, std::vector<uint32_t> &pruned_list,
InMemQueryScratch<T> *scratch);
// Prunes candidates in @pool to a shorter list @result
// @pool must be sorted before calling
void occlude_list(const uint32_t location, std::vector<Neighbor> &pool, const float alpha, const uint32_t degree,
const uint32_t maxc, std::vector<uint32_t> &result, InMemQueryScratch<T> *scratch,
const tsl::robin_set<uint32_t> *const delete_set_ptr = nullptr);
// add reverse links from all the visited nodes to node n.
void inter_insert(uint32_t n, std::vector<uint32_t> &pruned_list, const uint32_t range,
InMemQueryScratch<T> *scratch);
void inter_insert(uint32_t n, std::vector<uint32_t> &pruned_list, InMemQueryScratch<T> *scratch);
// Acquire exclusive _update_lock before calling
void link();
// Acquire exclusive _tag_lock and _delete_lock before calling
int reserve_location();
// Acquire exclusive _tag_lock before calling
size_t release_location(int location);
size_t release_locations(const tsl::robin_set<uint32_t> &locations);
// Resize the index when no slots are left for insertion.
// Acquire exclusive _update_lock and _tag_lock before calling.
void resize(size_t new_max_points);
// Acquire unique lock on _update_lock, _consolidate_lock, _tag_lock
// and _delete_lock before calling these functions.
// Renumber nodes, update tag and location maps and compact the
// graph, mode = _consolidated_order in case of lazy deletion and
// _compacted_order in case of eager deletion
DISKANN_DLLEXPORT void compact_data();
DISKANN_DLLEXPORT void compact_frozen_point();
// Remove deleted nodes from adjacency list of node loc
// Replace removed neighbors with second order neighbors.
// Also acquires _locks[i] for i = loc and out-neighbors of loc.
void process_delete(const tsl::robin_set<uint32_t> &old_delete_set, size_t loc, const uint32_t range,
const uint32_t maxc, const float alpha, InMemQueryScratch<T> *scratch);
void initialize_query_scratch(uint32_t num_threads, uint32_t search_l, uint32_t indexing_l, uint32_t r,
uint32_t maxc, size_t dim);
// Do not call without acquiring appropriate locks
// call public member functions save and load to invoke these.
DISKANN_DLLEXPORT size_t save_graph(std::string filename);
DISKANN_DLLEXPORT size_t save_data(std::string filename);
DISKANN_DLLEXPORT size_t save_tags(std::string filename);
DISKANN_DLLEXPORT size_t save_delete_list(const std::string &filename);
#ifdef EXEC_ENV_OLS
DISKANN_DLLEXPORT size_t load_graph(AlignedFileReader &reader, size_t expected_num_points);
DISKANN_DLLEXPORT size_t load_data(AlignedFileReader &reader);
DISKANN_DLLEXPORT size_t load_tags(AlignedFileReader &reader);
DISKANN_DLLEXPORT size_t load_delete_set(AlignedFileReader &reader);
#else
DISKANN_DLLEXPORT size_t load_graph(const std::string filename, size_t expected_num_points);
DISKANN_DLLEXPORT size_t load_data(std::string filename0);
DISKANN_DLLEXPORT size_t load_tags(const std::string tag_file_name);
DISKANN_DLLEXPORT size_t load_delete_set(const std::string &filename);
#endif
private:
// Distance functions
Metric _dist_metric = diskann::L2;
// Data
std::shared_ptr<AbstractDataStore<T>> _data_store;
// Graph related data structures
std::unique_ptr<AbstractGraphStore> _graph_store;
char *_opt_graph = nullptr;
// Dimensions
size_t _dim = 0;
size_t _nd = 0; // number of active points i.e. existing in the graph
size_t _max_points = 0; // total number of points in given data set
// _num_frozen_pts is the number of points which are used as initial
// candidates when iterating to closest point(s). These are not visible
// externally and won't be returned by search. At least 1 frozen point is
// needed for a dynamic index. The frozen points have consecutive locations.
// See also _start below.
size_t _num_frozen_pts = 0;
size_t _frozen_pts_used = 0;
size_t _node_size;
size_t _data_len;
size_t _neighbor_len;
// Start point of the search. When _num_frozen_pts is greater than zero,
// this is the location of the first frozen point. Otherwise, this is a
// location of one of the points in index.
uint32_t _start = 0;
bool _has_built = false;
bool _saturate_graph = false;
bool _save_as_one_file = false; // plan to support in next version
bool _dynamic_index = false;
bool _enable_tags = false;
bool _normalize_vecs = false; // Using normalized L2 for cosine.
bool _deletes_enabled = false;
// Filter Support
bool _filtered_index = false;
// Location to label is only updated during insert_point(), all other reads are protected by
// default as a location can only be released at end of consolidate deletes
std::vector<std::vector<LabelT>> _location_to_labels;
tsl::robin_set<LabelT> _labels;
std::string _labels_file;
std::unordered_map<LabelT, uint32_t> _label_to_start_id;
std::unordered_map<uint32_t, uint32_t> _medoid_counts;
bool _use_universal_label = false;
LabelT _universal_label = 0;
uint32_t _filterIndexingQueueSize;
std::unordered_map<std::string, LabelT> _label_map;
// Indexing parameters
uint32_t _indexingQueueSize;
uint32_t _indexingRange;
uint32_t _indexingMaxC;
float _indexingAlpha;
uint32_t _indexingThreads;
// Query scratch data structures
ConcurrentQueue<InMemQueryScratch<T> *> _query_scratch;
// Flags for PQ based distance calculation
bool _pq_dist = false;
bool _use_opq = false;
size_t _num_pq_chunks = 0;
// REFACTOR
// uint8_t *_pq_data = nullptr;
std::shared_ptr<QuantizedDistance<T>> _pq_distance_fn = nullptr;
std::shared_ptr<AbstractDataStore<T>> _pq_data_store = nullptr;
bool _pq_generated = false;
FixedChunkPQTable _pq_table;
//
// Data structures, locks and flags for dynamic indexing and tags
//
// lazy_delete removes entry from _location_to_tag and _tag_to_location. If
// _location_to_tag does not resolve a location, infer that it was deleted.
tsl::sparse_map<TagT, uint32_t> _tag_to_location;
natural_number_map<uint32_t, TagT> _location_to_tag;
// _empty_slots has unallocated slots and those freed by consolidate_delete.
// _delete_set has locations marked deleted by lazy_delete. Will not be
// immediately available for insert. consolidate_delete will release these
// slots to _empty_slots.
natural_number_set<uint32_t> _empty_slots;
std::unique_ptr<tsl::robin_set<uint32_t>> _delete_set;
bool _data_compacted = true; // true if data has been compacted
bool _is_saved = false; // Checking if the index is already saved.
bool _conc_consolidate = false; // use _lock while searching
// Acquire locks in the order below when acquiring multiple locks
std::shared_timed_mutex // RW mutex between save/load (exclusive lock) and
_update_lock; // search/inserts/deletes/consolidate (shared lock)
std::shared_timed_mutex // Ensure only one consolidate or compact_data is
_consolidate_lock; // ever active
std::shared_timed_mutex // RW lock for _tag_to_location,
_tag_lock; // _location_to_tag, _empty_slots, _nd, _max_points, _label_to_start_id
std::shared_timed_mutex // RW Lock on _delete_set and _data_compacted
_delete_lock; // variable
// Per-node lock, cardinality = _max_points + _num_frozen_pts
std::vector<non_recursive_mutex> _locks;
static const float INDEX_GROWTH_FACTOR;
};
} // namespace diskann
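// End-to-end usage sketch for the dynamic path declared above, kept as a
// comment because IndexWriteParameters/IndexSearchParams construction lives in
// parameters.h, outside this header; treat write_params/search_params, dim,
// max_points, K and L as given.
//
//   diskann::Index<float> index(diskann::L2, dim, max_points, write_params, search_params,
//                               /*num_frozen_pts=*/1, /*dynamic_index=*/true, /*enable_tags=*/true);
//   index.insert_point(vec.data(), /*tag=*/42);                 // fails if tag is 0 or already present
//   std::vector<uint32_t> ids(K);
//   std::vector<float> dists(K);
//   index.search(query.data(), K, L, ids.data(), dists.data()); // per-query L, no Parameters object
//   index.enable_delete();                                      // must precede any deletions
//   index.lazy_delete(42);                                      // marks now; graph repaired later
//   auto report = index.consolidate_deletes(*write_params);     // reclaims deleted slots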


@@ -0,0 +1,73 @@
#pragma once
#include "common_includes.h"
#include "parameters.h"
namespace diskann
{
struct IndexFilterParams
{
public:
std::string save_path_prefix;
std::string label_file;
std::string tags_file;
std::string universal_label;
uint32_t filter_threshold = 0;
private:
IndexFilterParams(const std::string &save_path_prefix, const std::string &label_file,
const std::string &universal_label, uint32_t filter_threshold)
: save_path_prefix(save_path_prefix), label_file(label_file), universal_label(universal_label),
filter_threshold(filter_threshold)
{
}
friend class IndexFilterParamsBuilder;
};
class IndexFilterParamsBuilder
{
public:
IndexFilterParamsBuilder() = default;
IndexFilterParamsBuilder &with_save_path_prefix(const std::string &save_path_prefix)
{
if (save_path_prefix.empty())
throw ANNException("Error: save_path_prefix can't be empty", -1);
this->_save_path_prefix = save_path_prefix;
return *this;
}
IndexFilterParamsBuilder &with_label_file(const std::string &label_file)
{
this->_label_file = label_file;
return *this;
}
IndexFilterParamsBuilder &with_universal_label(const std::string &universal_label)
{
this->_universal_label = universal_label;
return *this;
}
IndexFilterParamsBuilder &with_filter_threshold(const std::uint32_t &filter_threshold)
{
this->_filter_threshold = filter_threshold;
return *this;
}
IndexFilterParams build()
{
return IndexFilterParams(_save_path_prefix, _label_file, _universal_label, _filter_threshold);
}
IndexFilterParamsBuilder(const IndexFilterParamsBuilder &) = delete;
IndexFilterParamsBuilder &operator=(const IndexFilterParamsBuilder &) = delete;
private:
std::string _save_path_prefix;
std::string _label_file;
std::string _tags_file;
std::string _universal_label;
uint32_t _filter_threshold = 0;
};
} // namespace diskann
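// Builder usage sketch (all values are placeholders): assemble filter params
// for a filtered build. The builder throws if save_path_prefix is empty.
inline diskann::IndexFilterParams example_filter_params()
{
    return diskann::IndexFilterParamsBuilder()
        .with_save_path_prefix("indices/my_index")
        .with_label_file("labels.txt")
        .with_universal_label("0")
        .with_filter_threshold(5000)
        .build();
}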


@@ -0,0 +1,256 @@
#pragma once
#include "common_includes.h"
#include "parameters.h"
namespace diskann
{
enum class DataStoreStrategy
{
MEMORY
};
enum class GraphStoreStrategy
{
MEMORY
};
struct IndexConfig
{
DataStoreStrategy data_strategy;
GraphStoreStrategy graph_strategy;
Metric metric;
size_t dimension;
size_t max_points;
bool dynamic_index;
bool enable_tags;
bool pq_dist_build;
bool concurrent_consolidate;
bool use_opq;
bool filtered_index;
size_t num_pq_chunks;
size_t num_frozen_pts;
std::string label_type;
std::string tag_type;
std::string data_type;
// Params for building index
std::shared_ptr<IndexWriteParameters> index_write_params;
// Params for searching index
std::shared_ptr<IndexSearchParams> index_search_params;
private:
IndexConfig(DataStoreStrategy data_strategy, GraphStoreStrategy graph_strategy, Metric metric, size_t dimension,
size_t max_points, size_t num_pq_chunks, size_t num_frozen_points, bool dynamic_index, bool enable_tags,
bool pq_dist_build, bool concurrent_consolidate, bool use_opq, bool filtered_index,
std::string &data_type, const std::string &tag_type, const std::string &label_type,
std::shared_ptr<IndexWriteParameters> index_write_params,
std::shared_ptr<IndexSearchParams> index_search_params)
: data_strategy(data_strategy), graph_strategy(graph_strategy), metric(metric), dimension(dimension),
max_points(max_points), dynamic_index(dynamic_index), enable_tags(enable_tags), pq_dist_build(pq_dist_build),
concurrent_consolidate(concurrent_consolidate), use_opq(use_opq), filtered_index(filtered_index),
num_pq_chunks(num_pq_chunks), num_frozen_pts(num_frozen_points), label_type(label_type), tag_type(tag_type),
data_type(data_type), index_write_params(index_write_params), index_search_params(index_search_params)
{
}
friend class IndexConfigBuilder;
};
class IndexConfigBuilder
{
public:
IndexConfigBuilder() = default;
IndexConfigBuilder &with_metric(Metric m)
{
this->_metric = m;
return *this;
}
IndexConfigBuilder &with_graph_load_store_strategy(GraphStoreStrategy graph_strategy)
{
this->_graph_strategy = graph_strategy;
return *this;
}
IndexConfigBuilder &with_data_load_store_strategy(DataStoreStrategy data_strategy)
{
this->_data_strategy = data_strategy;
return *this;
}
IndexConfigBuilder &with_dimension(size_t dimension)
{
this->_dimension = dimension;
return *this;
}
IndexConfigBuilder &with_max_points(size_t max_points)
{
this->_max_points = max_points;
return *this;
}
IndexConfigBuilder &is_dynamic_index(bool dynamic_index)
{
this->_dynamic_index = dynamic_index;
return *this;
}
IndexConfigBuilder &is_enable_tags(bool enable_tags)
{
this->_enable_tags = enable_tags;
return *this;
}
IndexConfigBuilder &is_pq_dist_build(bool pq_dist_build)
{
this->_pq_dist_build = pq_dist_build;
return *this;
}
IndexConfigBuilder &is_concurrent_consolidate(bool concurrent_consolidate)
{
this->_concurrent_consolidate = concurrent_consolidate;
return *this;
}
IndexConfigBuilder &is_use_opq(bool use_opq)
{
this->_use_opq = use_opq;
return *this;
}
IndexConfigBuilder &is_filtered(bool is_filtered)
{
this->_filtered_index = is_filtered;
return *this;
}
IndexConfigBuilder &with_num_pq_chunks(size_t num_pq_chunks)
{
this->_num_pq_chunks = num_pq_chunks;
return *this;
}
IndexConfigBuilder &with_num_frozen_pts(size_t num_frozen_pts)
{
this->_num_frozen_pts = num_frozen_pts;
return *this;
}
IndexConfigBuilder &with_label_type(const std::string &label_type)
{
this->_label_type = label_type;
return *this;
}
IndexConfigBuilder &with_tag_type(const std::string &tag_type)
{
this->_tag_type = tag_type;
return *this;
}
IndexConfigBuilder &with_data_type(const std::string &data_type)
{
this->_data_type = data_type;
return *this;
}
IndexConfigBuilder &with_index_write_params(IndexWriteParameters &index_write_params)
{
this->_index_write_params = std::make_shared<IndexWriteParameters>(index_write_params);
return *this;
}
IndexConfigBuilder &with_index_write_params(std::shared_ptr<IndexWriteParameters> index_write_params_ptr)
{
if (index_write_params_ptr == nullptr)
{
diskann::cout << "Passed, empty build_params while creating index config" << std::endl;
return *this;
}
this->_index_write_params = index_write_params_ptr;
return *this;
}
IndexConfigBuilder &with_index_search_params(IndexSearchParams &search_params)
{
this->_index_search_params = std::make_shared<IndexSearchParams>(search_params);
return *this;
}
IndexConfigBuilder &with_index_search_params(std::shared_ptr<IndexSearchParams> search_params_ptr)
{
if (search_params_ptr == nullptr)
{
diskann::cout << "Passed, empty search_params while creating index config" << std::endl;
return *this;
}
this->_index_search_params = search_params_ptr;
return *this;
}
IndexConfig build()
{
if (_data_type == "" || _data_type.empty())
throw ANNException("Error: data_type can not be empty", -1);
if (_dynamic_index && _num_frozen_pts == 0)
{
_num_frozen_pts = 1;
}
if (_dynamic_index)
{
if (_index_search_params != nullptr && _index_search_params->initial_search_list_size == 0)
throw ANNException("Error: please pass initial_search_list_size for building dynamic index.", -1);
}
// sanity check
if (_dynamic_index && _num_frozen_pts == 0)
{
diskann::cout << "_num_frozen_pts passed as 0 for dynamic_index. Setting it to 1 for safety." << std::endl;
_num_frozen_pts = 1;
}
return IndexConfig(_data_strategy, _graph_strategy, _metric, _dimension, _max_points, _num_pq_chunks,
_num_frozen_pts, _dynamic_index, _enable_tags, _pq_dist_build, _concurrent_consolidate,
_use_opq, _filtered_index, _data_type, _tag_type, _label_type, _index_write_params,
_index_search_params);
}
IndexConfigBuilder(const IndexConfigBuilder &) = delete;
IndexConfigBuilder &operator=(const IndexConfigBuilder &) = delete;
private:
DataStoreStrategy _data_strategy;
GraphStoreStrategy _graph_strategy;
Metric _metric;
size_t _dimension;
size_t _max_points;
bool _dynamic_index = false;
bool _enable_tags = false;
bool _pq_dist_build = false;
bool _concurrent_consolidate = false;
bool _use_opq = false;
bool _filtered_index{defaults::HAS_LABELS};
size_t _num_pq_chunks = 0;
size_t _num_frozen_pts{defaults::NUM_FROZEN_POINTS_STATIC};
std::string _label_type{"uint32"};
std::string _tag_type{"uint32"};
std::string _data_type;
std::shared_ptr<IndexWriteParameters> _index_write_params;
std::shared_ptr<IndexSearchParams> _index_search_params;
};
} // namespace diskann
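For illustration, a minimal usage sketch of the builder above, assembling a config for a dynamic in-memory index. The DataStoreStrategy::MEMORY and GraphStoreStrategy::MEMORY enumerators and all parameter values are assumptions/placeholders, not prescribed by this header:

auto write_params = diskann::IndexWriteParametersBuilder(/*L=*/100, /*R=*/64).build();
auto search_params = std::make_shared<diskann::IndexSearchParams>(/*initial L=*/100, /*threads=*/4);
diskann::IndexConfig config = diskann::IndexConfigBuilder()
                                  .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)   // assumed enumerator
                                  .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY) // assumed enumerator
                                  .with_metric(diskann::Metric::L2)
                                  .with_dimension(128)
                                  .with_max_points(1000000)
                                  .with_data_type("float")
                                  .is_dynamic_index(true)
                                  .is_enable_tags(true)
                                  .with_index_write_params(write_params)
                                  .with_index_search_params(search_params)
                                  .build();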

View File

@@ -0,0 +1,51 @@
#pragma once
#include "index.h"
#include "abstract_graph_store.h"
#include "in_mem_graph_store.h"
#include "pq_data_store.h"
namespace diskann
{
class IndexFactory
{
public:
DISKANN_DLLEXPORT explicit IndexFactory(const IndexConfig &config);
DISKANN_DLLEXPORT std::unique_ptr<AbstractIndex> create_instance();
DISKANN_DLLEXPORT static std::unique_ptr<AbstractGraphStore> construct_graphstore(
const GraphStoreStrategy strategy, const size_t size, const size_t reserve_graph_degree);
template <typename T>
DISKANN_DLLEXPORT static std::shared_ptr<AbstractDataStore<T>> construct_datastore(DataStoreStrategy strategy,
size_t num_points,
size_t dimension, Metric m);
// For now PQDataStore incorporates within itself all variants of quantization that we support. In the
// future it may be necessary to introduce an AbstractPQDataStore class to separate various quantization
// flavours.
template <typename T>
DISKANN_DLLEXPORT static std::shared_ptr<PQDataStore<T>> construct_pq_datastore(DataStoreStrategy strategy,
size_t num_points, size_t dimension,
Metric m, size_t num_pq_chunks,
bool use_opq);
template <typename T> static Distance<T> *construct_inmem_distance_fn(Metric m);
private:
void check_config();
template <typename data_type, typename tag_type, typename label_type>
std::unique_ptr<AbstractIndex> create_instance();
std::unique_ptr<AbstractIndex> create_instance(const std::string &data_type, const std::string &tag_type,
const std::string &label_type);
template <typename data_type>
std::unique_ptr<AbstractIndex> create_instance(const std::string &tag_type, const std::string &label_type);
template <typename data_type, typename tag_type>
std::unique_ptr<AbstractIndex> create_instance(const std::string &label_type);
std::unique_ptr<IndexConfig> _config;
};
} // namespace diskann
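A hedged sketch of the factory's static helpers, based only on the declarations above (the strategy enumerators are assumptions, as in the config sketch earlier):

// Build the two stores an in-memory index needs.
auto data_store = diskann::IndexFactory::construct_datastore<float>(
    diskann::DataStoreStrategy::MEMORY, /*num_points=*/100000, /*dimension=*/96, diskann::Metric::L2);
auto graph_store = diskann::IndexFactory::construct_graphstore(
    diskann::GraphStoreStrategy::MEMORY, /*size=*/100000, /*reserve_graph_degree=*/64);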

View File

@@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#ifndef _WINDOWS
#ifndef __APPLE__
#include "aligned_file_reader.h"
class LinuxAlignedFileReader : public AlignedFileReader
{
private:
uint64_t file_sz;
FileHandle file_desc;
io_context_t bad_ctx = (io_context_t)-1;
public:
LinuxAlignedFileReader();
~LinuxAlignedFileReader();
IOContext &get_ctx();
// register thread-id for a context
void register_thread();
// de-register thread-id for a context
void deregister_thread();
void deregister_all_threads();
// Open & close ops
// Blocking calls
void open(const std::string &fname);
void close();
// process batch of aligned requests in parallel
// NOTE :: blocking call
void read(std::vector<AlignedRead> &read_reqs, IOContext &ctx, bool async = false);
};
#endif
#endif

View File

@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <mutex>
#ifdef _WINDOWS
#include "windows_slim_lock.h"
#endif
namespace diskann
{
#ifdef _WINDOWS
using non_recursive_mutex = windows_exclusive_slim_lock;
using LockGuard = windows_exclusive_slim_lock_guard;
#else
using non_recursive_mutex = std::mutex;
using LockGuard = std::lock_guard<non_recursive_mutex>;
#endif
} // namespace diskann

View File

@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <functional>
#include <iostream>
#include "windows_customizations.h"
#ifdef EXEC_ENV_OLS
#ifndef ENABLE_CUSTOM_LOGGER
#define ENABLE_CUSTOM_LOGGER
#endif // !ENABLE_CUSTOM_LOGGER
#endif // EXEC_ENV_OLS
namespace diskann
{
#ifdef ENABLE_CUSTOM_LOGGER
DISKANN_DLLEXPORT extern std::basic_ostream<char> cout;
DISKANN_DLLEXPORT extern std::basic_ostream<char> cerr;
#else
using std::cerr;
using std::cout;
#endif
enum class DISKANN_DLLEXPORT LogLevel
{
LL_Info = 0,
LL_Error,
LL_Count
};
#ifdef ENABLE_CUSTOM_LOGGER
DISKANN_DLLEXPORT void SetCustomLogger(std::function<void(LogLevel, const char *)> logger);
#endif
} // namespace diskann
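Under ENABLE_CUSTOM_LOGGER, a host application can capture everything written to diskann::cout/diskann::cerr. A minimal sketch of registering a logger (assumes <cstdio> is available):

#ifdef ENABLE_CUSTOM_LOGGER
diskann::SetCustomLogger([](diskann::LogLevel level, const char *message) {
    // Forward library output to the host's own logging facility.
    std::fprintf(level == diskann::LogLevel::LL_Error ? stderr : stdout, "%s\n", message);
});
#endif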

View File

@@ -0,0 +1,61 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <sstream>
#include <mutex>
#include "ann_exception.h"
#include "logger.h"
namespace diskann
{
#ifdef ENABLE_CUSTOM_LOGGER
class ANNStreamBuf : public std::basic_streambuf<char>
{
public:
DISKANN_DLLEXPORT explicit ANNStreamBuf(FILE *fp);
DISKANN_DLLEXPORT ~ANNStreamBuf();
DISKANN_DLLEXPORT bool is_open() const
{
return true; // because stdout and stderr are always open.
}
DISKANN_DLLEXPORT void close();
DISKANN_DLLEXPORT virtual int underflow();
DISKANN_DLLEXPORT virtual int overflow(int c);
DISKANN_DLLEXPORT virtual int sync();
private:
FILE *_fp;
char *_buf;
int _bufIndex;
std::mutex _mutex;
LogLevel _logLevel;
int flush();
void logImpl(char *str, int numchars);
// Why the two buffer-sizes? If we are running normally, we are basically
// interacting with a character output system, so we short-circuit the
// output process by keeping an empty buffer and writing each character
// to stdout/stderr. But if we are running in OLS, we have to take all
// the text that is written to diskann::cout/diskann::cerr, consolidate it
// and push it out in one-shot, because the OLS infra does not give us
// character based output. Therefore, we use a larger buffer that is large
// enough to store the longest message, and continuously add characters
// to it. When the calling code outputs a std::endl or std::flush, sync()
// will be called and will output a log level, component name, and the text
// that has been collected. (sync() is also called if the buffer is full, so
// overflows/missing text are not a concern).
// This implies calling code _must_ either print std::endl or std::flush
// to ensure that the message is written immediately.
static const int BUFFER_SIZE = 1024;
ANNStreamBuf(const ANNStreamBuf &);
ANNStreamBuf &operator=(const ANNStreamBuf &);
};
#endif
} // namespace diskann

View File

@@ -0,0 +1,87 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include "common_includes.h"
#include "utils.h"
namespace math_utils
{
float calc_distance(float *vec_1, float *vec_2, size_t dim);
// compute l2-squared norms of data stored in row-major num_points * dim;
// vecs_l2sq needs to be pre-allocated
void compute_vecs_l2sq(float *vecs_l2sq, float *data, const size_t num_points, const size_t dim);
void rotate_data_randomly(float *data, size_t num_points, size_t dim, float *rot_mat, float *&new_mat,
bool transpose_rot = false);
// calculate closest center to data of num_points * dim (row major)
// centers is num_centers * dim (row major)
// data_l2sq has pre-computed squared norms of data
// centers_l2sq has pre-computed squared norms of centers
// pre-allocated center_index will contain id of k nearest centers
// pre-allocated dist_matrix should be num_points * num_centers and contain
// squared distances
// Ideally used only by compute_closest_centers
void compute_closest_centers_in_block(const float *const data, const size_t num_points, const size_t dim,
const float *const centers, const size_t num_centers,
const float *const docs_l2sq, const float *const centers_l2sq,
uint32_t *center_index, float *const dist_matrix, size_t k = 1);
// Given data in num_points * new_dim row major
// Pivots stored in full_pivot_data as k * new_dim row major
// Calculate the closest pivot for each point and store it in vector
// closest_centers_ivf (which needs to be allocated outside)
// Additionally, if inverted_index is not null (and pre-allocated), it will
// return the inverted index for each center. If pts_norms_squared is
// not null, it will assume that point norms are pre-computed and use those
// values
void compute_closest_centers(float *data, size_t num_points, size_t dim, float *pivot_data, size_t num_centers,
size_t k, uint32_t *closest_centers_ivf, std::vector<size_t> *inverted_index = NULL,
float *pts_norms_squared = NULL);
// if to_subtract is 1, will subtract the nearest center from each row; else
// will add it. Output will be in data_load itself.
// Nearest centers need to be provided in closest_centers.
void process_residuals(float *data_load, size_t num_points, size_t dim, float *cur_pivot_data, size_t num_centers,
uint32_t *closest_centers, bool to_subtract);
} // namespace math_utils
namespace kmeans
{
// run Lloyds one iteration
// Given data in row major num_points * dim, and centers in row major
// num_centers * dim
// And squared lengths of data points, output the closest center to each data
// point, update centers, and also return inverted index.
// If closest_centers == NULL, will allocate memory and return.
// Similarly, if closest_docs == NULL, will allocate memory and return.
float lloyds_iter(float *data, size_t num_points, size_t dim, float *centers, size_t num_centers, float *docs_l2sq,
std::vector<size_t> *closest_docs, uint32_t *&closest_center);
// Run Lloyds until max_reps or stopping criterion
// If you pass NULL for closest_docs and closest_center, it will NOT return
// the results, else it will assume appropriate allocation as closest_docs = new
// vector<size_t> [num_centers], and closest_center = new size_t[num_points]
// Final centers are output in centers as row major num_centers * dim
//
float run_lloyds(float *data, size_t num_points, size_t dim, float *centers, const size_t num_centers,
const size_t max_reps, std::vector<size_t> *closest_docs, uint32_t *closest_center);
// assumes memory is already allocated for pivot_data as new
// float[num_centers * dim]; selects num_centers points at random as pivots
void selecting_pivots(float *data, size_t num_points, size_t dim, float *pivot_data, size_t num_centers);
void kmeanspp_selecting_pivots(float *data, size_t num_points, size_t dim, float *pivot_data, size_t num_centers);
} // namespace kmeans
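A sketch of the calling convention for selecting_pivots plus run_lloyds, inferred from the comments above; sizes are placeholders and buffer allocation follows the stated contract:

size_t num_points = 10000, dim = 64, num_centers = 256;
std::vector<float> data(num_points * dim);      // filled with training vectors by the caller
std::vector<float> centers(num_centers * dim);  // pre-allocated output
kmeans::selecting_pivots(data.data(), num_points, dim, centers.data(), num_centers);
// Passing NULL for both result arguments: only the centers are updated.
float residual = kmeans::run_lloyds(data.data(), num_points, dim, centers.data(), num_centers,
                                    /*max_reps=*/12, /*closest_docs=*/nullptr, /*closest_center=*/nullptr);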

View File

@@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#ifndef _WINDOWS
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#else
#include <Windows.h>
#endif
#include <string>
namespace diskann
{
class MemoryMapper
{
private:
#ifndef _WINDOWS
int _fd;
#else
HANDLE _bareFile;
HANDLE _fd;
#endif
char *_buf;
size_t _fileSize;
const char *_fileName;
public:
MemoryMapper(const char *filename);
MemoryMapper(const std::string &filename);
char *getBuf();
size_t getFileSize();
~MemoryMapper();
};
} // namespace diskann
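A minimal sketch of using MemoryMapper to hand a file's bytes to a parser (the path is a placeholder):

diskann::MemoryMapper mapper("/path/to/index.bin");
const char *bytes = mapper.getBuf();
size_t nbytes = mapper.getFileSize();
// ... parse bytes[0 .. nbytes) ...
// The mapping is released when `mapper` goes out of scope.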

View File

@@ -0,0 +1,86 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <memory>
#include <type_traits>
#include <vector>
#include <boost/dynamic_bitset.hpp>
namespace diskann
{
// A map whose key is a natural number (from 0 onwards) and maps to a value.
// Made as both memory and performance efficient map for scenario such as
// DiskANN location-to-tag map. There, the pool of numbers is consecutive from
// zero to some max value, and it's expected that most if not all keys from 0
// up to some current maximum will be present in the map. The memory usage of
// the map is determined by the largest inserted key since it uses vector as a
// backing store and bitset for presence indication.
//
// Thread-safety: this class is not thread-safe in general.
// Exception: multiple read-only operations are safe on the object only if
// there are no writers to it in parallel.
template <typename Key, typename Value> class natural_number_map
{
public:
static_assert(std::is_trivial<Key>::value, "Key must be a trivial type");
// Represents a reference to an element in the map. Used while iterating
// over map entries.
struct position
{
size_t _key;
// The number of keys that were enumerated when iterating through the
// map so far. Used to early-terminate enumeration when there are no
// more entries in the map.
size_t _keys_already_enumerated;
// Returns whether it's valid to access the element at this position in
// the map.
bool is_valid() const;
};
natural_number_map();
void reserve(size_t count);
size_t size() const;
void set(Key key, Value value);
void erase(Key key);
bool contains(Key key) const;
bool try_get(Key key, Value &value) const;
// Returns the value at the specified position. Prerequisite: position is
// valid.
Value get(const position &pos) const;
// Finds the first element in the map, if any. Invalidated by changes in the
// map.
position find_first() const;
// Finds the next element in the map after the specified position.
// Invalidated by changes in the map.
position find_next(const position &after_position) const;
void clear();
private:
// Number of entries in the map. Not the same as size() of the
// _values_vector below.
size_t _size;
// Array of values. The key is the index of the value.
std::vector<Value> _values_vector;
// Values that are in the set have the corresponding bit index set
// to 1.
//
// Use a pointer here to allow for forward declaration of dynamic_bitset
// in public headers to avoid making boost a dependency for clients
// of DiskANN.
std::unique_ptr<boost::dynamic_bitset<>> _values_bitset;
};
} // namespace diskann
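A usage sketch of the map for the location-to-tag scenario described above:

diskann::natural_number_map<uint32_t, uint64_t> loc_to_tag;
loc_to_tag.reserve(1024);
loc_to_tag.set(0, 42); // location 0 carries tag 42
loc_to_tag.set(5, 77);
uint64_t tag;
if (loc_to_tag.try_get(5, tag))
{
    // tag == 77
}
// Enumerate all present entries; positions are invalidated by writes.
for (auto pos = loc_to_tag.find_first(); pos.is_valid(); pos = loc_to_tag.find_next(pos))
{
    uint64_t value = loc_to_tag.get(pos);
    (void)value;
}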

View File

@@ -0,0 +1,50 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <memory>
#include <type_traits>
#include "boost_dynamic_bitset_fwd.h"
namespace diskann
{
// A set of natural numbers (from 0 onwards). Made for scenario where the
// pool of numbers is consecutive from zero to some max value and very
// efficient methods for "add to set", "get any value from set", "is in set"
// are needed. The memory usage of the set is determined by the largest
// number of inserted entries (uses a vector as a backing store) as well as
// the largest value to be placed in it (uses bitset as well).
//
// Thread-safety: this class is not thread-safe in general.
// Exception: multiple read-only operations (e.g. is_in_set, empty, size) are
// safe on the object only if there are no writers to it in parallel.
template <typename T> class natural_number_set
{
public:
static_assert(std::is_trivial<T>::value, "Identifier must be a trivial type");
natural_number_set();
bool is_empty() const;
void reserve(size_t count);
void insert(T id);
T pop_any();
void clear();
size_t size() const;
bool is_in_set(T id) const;
private:
// Values that are currently in set.
std::vector<T> _values_vector;
// Values that are in the set have the corresponding bit index set
// to 1.
//
// Use a pointer here to allow for forward declaration of dynamic_bitset
// in public headers to avoid making boost a dependency for clients
// of DiskANN.
std::unique_ptr<boost::dynamic_bitset<>> _values_bitset;
};
} // namespace diskann

View File

@@ -0,0 +1,152 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <cstddef>
#include <cstring> // std::memmove used by NeighborPriorityQueue::insert
#include <mutex>
#include <vector>
#include "utils.h"
namespace diskann
{
struct Neighbor
{
unsigned id;
float distance;
bool expanded;
Neighbor() = default;
Neighbor(unsigned id, float distance) : id{id}, distance{distance}, expanded(false)
{
}
inline bool operator<(const Neighbor &other) const
{
return distance < other.distance || (distance == other.distance && id < other.id);
}
inline bool operator==(const Neighbor &other) const
{
return (id == other.id);
}
};
// Invariant: after every `insert` and `closest_unexpanded()`, `_cur` points to
// the first Neighbor which is unexpanded.
class NeighborPriorityQueue
{
public:
NeighborPriorityQueue() : _size(0), _capacity(0), _cur(0)
{
}
explicit NeighborPriorityQueue(size_t capacity) : _size(0), _capacity(capacity), _cur(0), _data(capacity + 1)
{
}
// Inserts the item in order into the set, up to the set's capacity.
// The item will be dropped if it has the same id as an existing
// set item or a greater distance than the final
// item in the set. The cursor used by closest_unexpanded() will be
// set to the lowest index of an unexpanded item.
void insert(const Neighbor &nbr)
{
if (_size == _capacity && _data[_size - 1] < nbr)
{
return;
}
size_t lo = 0, hi = _size;
while (lo < hi)
{
size_t mid = (lo + hi) >> 1;
if (nbr < _data[mid])
{
hi = mid;
// Make sure the same id isn't inserted into the set
}
else if (_data[mid].id == nbr.id)
{
return;
}
else
{
lo = mid + 1;
}
}
if (lo < _capacity)
{
std::memmove(&_data[lo + 1], &_data[lo], (_size - lo) * sizeof(Neighbor));
}
_data[lo] = {nbr.id, nbr.distance};
if (_size < _capacity)
{
_size++;
}
if (lo < _cur)
{
_cur = lo;
}
}
Neighbor closest_unexpanded()
{
_data[_cur].expanded = true;
size_t pre = _cur;
while (_cur < _size && _data[_cur].expanded)
{
_cur++;
}
return _data[pre];
}
bool has_unexpanded_node() const
{
return _cur < _size;
}
size_t size() const
{
return _size;
}
size_t capacity() const
{
return _capacity;
}
void reserve(size_t capacity)
{
if (capacity + 1 > _data.size())
{
_data.resize(capacity + 1);
}
_capacity = capacity;
}
Neighbor &operator[](size_t i)
{
return _data[i];
}
Neighbor operator[](size_t i) const
{
return _data[i];
}
void clear()
{
_size = 0;
_cur = 0;
}
private:
size_t _size, _capacity, _cur;
std::vector<Neighbor> _data;
};
} // namespace diskann
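A sketch of the best-first search loop this queue is built for; entry_id, entry_dist, and the expansion step are placeholders:

diskann::NeighborPriorityQueue frontier(/*capacity=*/64);
frontier.insert(diskann::Neighbor(entry_id, entry_dist));
while (frontier.has_unexpanded_node())
{
    diskann::Neighbor n = frontier.closest_unexpanded(); // marks n as expanded
    // Expand n.id: compute distances to its graph neighbors and insert them.
    // Duplicate ids and candidates past the capacity boundary are dropped by insert().
}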

View File

@@ -0,0 +1,119 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <sstream>
#include <typeinfo>
#include <unordered_map>
#include "omp.h"
#include "defaults.h"
namespace diskann
{
class IndexWriteParameters
{
public:
const uint32_t search_list_size; // L
const uint32_t max_degree; // R
const bool saturate_graph;
const uint32_t max_occlusion_size; // C
const float alpha;
const uint32_t num_threads;
const uint32_t filter_list_size; // Lf
IndexWriteParameters(const uint32_t search_list_size, const uint32_t max_degree, const bool saturate_graph,
const uint32_t max_occlusion_size, const float alpha, const uint32_t num_threads,
const uint32_t filter_list_size)
: search_list_size(search_list_size), max_degree(max_degree), saturate_graph(saturate_graph),
max_occlusion_size(max_occlusion_size), alpha(alpha), num_threads(num_threads),
filter_list_size(filter_list_size)
{
}
friend class IndexWriteParametersBuilder;
};
class IndexSearchParams
{
public:
IndexSearchParams(const uint32_t initial_search_list_size, const uint32_t num_search_threads)
: initial_search_list_size(initial_search_list_size), num_search_threads(num_search_threads)
{
}
const uint32_t initial_search_list_size; // search L
const uint32_t num_search_threads; // search threads
};
class IndexWriteParametersBuilder
{
/**
* Fluent builder pattern to keep track of the 7 non-default properties
* and their order. The basic ctor was getting unwieldy.
*/
public:
IndexWriteParametersBuilder(const uint32_t search_list_size, // L
const uint32_t max_degree // R
)
: _search_list_size(search_list_size), _max_degree(max_degree)
{
}
IndexWriteParametersBuilder &with_max_occlusion_size(const uint32_t max_occlusion_size)
{
_max_occlusion_size = max_occlusion_size;
return *this;
}
IndexWriteParametersBuilder &with_saturate_graph(const bool saturate_graph)
{
_saturate_graph = saturate_graph;
return *this;
}
IndexWriteParametersBuilder &with_alpha(const float alpha)
{
_alpha = alpha;
return *this;
}
IndexWriteParametersBuilder &with_num_threads(const uint32_t num_threads)
{
_num_threads = num_threads == 0 ? omp_get_num_procs() : num_threads;
return *this;
}
IndexWriteParametersBuilder &with_filter_list_size(const uint32_t filter_list_size)
{
_filter_list_size = filter_list_size == 0 ? _search_list_size : filter_list_size;
return *this;
}
IndexWriteParameters build() const
{
return IndexWriteParameters(_search_list_size, _max_degree, _saturate_graph, _max_occlusion_size, _alpha,
_num_threads, _filter_list_size);
}
IndexWriteParametersBuilder(const IndexWriteParameters &wp)
: _search_list_size(wp.search_list_size), _max_degree(wp.max_degree),
_max_occlusion_size(wp.max_occlusion_size), _saturate_graph(wp.saturate_graph), _alpha(wp.alpha),
_num_threads(wp.num_threads), _filter_list_size(wp.filter_list_size)
{
}
IndexWriteParametersBuilder(const IndexWriteParametersBuilder &) = delete;
IndexWriteParametersBuilder &operator=(const IndexWriteParametersBuilder &) = delete;
private:
uint32_t _search_list_size{};
uint32_t _max_degree{};
uint32_t _max_occlusion_size{defaults::MAX_OCCLUSION_SIZE};
bool _saturate_graph{defaults::SATURATE_GRAPH};
float _alpha{defaults::ALPHA};
uint32_t _num_threads{defaults::NUM_THREADS};
uint32_t _filter_list_size{defaults::FILTER_LIST_SIZE};
};
} // namespace diskann
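A sketch of typical build-parameter construction with the builder above (values are placeholders):

diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(/*L=*/100, /*R=*/64)
                                           .with_alpha(1.2f)
                                           .with_num_threads(0) // 0 => omp_get_num_procs()
                                           .with_saturate_graph(false)
                                           .build();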

View File

@@ -0,0 +1,49 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <cassert>
#include <sstream>
#include <stack>
#include <string>
#include <unordered_map>
#include "neighbor.h"
#include "parameters.h"
#include "tsl/robin_set.h"
#include "utils.h"
#include "windows_customizations.h"
template <typename T>
void gen_random_slice(const std::string base_file, const std::string output_prefix, double sampling_rate);
template <typename T>
void gen_random_slice(const std::string data_file, double p_val, float *&sampled_data, size_t &slice_size,
size_t &ndims);
template <typename T>
void gen_random_slice(const T *inputdata, size_t npts, size_t ndims, double p_val, float *&sampled_data,
size_t &slice_size);
int estimate_cluster_sizes(float *test_data_float, size_t num_test, float *pivots, const size_t num_centers,
const size_t dim, const size_t k_base, std::vector<size_t> &cluster_sizes);
template <typename T>
int shard_data_into_clusters(const std::string data_file, float *pivots, const size_t num_centers, const size_t dim,
const size_t k_base, std::string prefix_path);
template <typename T>
int shard_data_into_clusters_only_ids(const std::string data_file, float *pivots, const size_t num_centers,
const size_t dim, const size_t k_base, std::string prefix_path);
template <typename T>
int retrieve_shard_data_from_ids(const std::string data_file, std::string idmap_filename, std::string data_filename);
template <typename T>
int partition(const std::string data_file, const float sampling_rate, size_t num_centers, size_t max_k_means_reps,
const std::string prefix_path, size_t k_base);
template <typename T>
int partition_with_ram_budget(const std::string data_file, const double sampling_rate, double ram_budget,
size_t graph_degree, const std::string prefix_path, size_t k_base);

View File

@@ -0,0 +1,65 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <algorithm> // std::sort used by get_percentile_stats
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <functional>
#ifdef _WINDOWS
#include <numeric>
#endif
#include <string>
#include <vector>
#include "distance.h"
#include "parameters.h"
namespace diskann
{
struct QueryStats
{
float total_us = 0; // total time to process query in micros
float io_us = 0; // total time spent in IO
float cpu_us = 0; // total time spent in CPU
unsigned n_4k = 0; // # of 4kB reads
unsigned n_8k = 0; // # of 8kB reads
unsigned n_12k = 0; // # of 12kB reads
unsigned n_ios = 0; // total # of IOs issued
unsigned read_size = 0; // total # of bytes read
unsigned n_cmps_saved = 0; // # cmps saved
unsigned n_cmps = 0; // # cmps
unsigned n_cache_hits = 0; // # cache_hits
unsigned n_hops = 0; // # search hops
};
template <typename T>
inline T get_percentile_stats(QueryStats *stats, uint64_t len, float percentile,
const std::function<T(const QueryStats &)> &member_fn)
{
std::vector<T> vals(len);
for (uint64_t i = 0; i < len; i++)
{
vals[i] = member_fn(stats[i]);
}
std::sort(vals.begin(), vals.end(), [](const T &left, const T &right) { return left < right; });
auto retval = vals[(uint64_t)(percentile * len)];
vals.clear();
return retval;
}
template <typename T>
inline double get_mean_stats(QueryStats *stats, uint64_t len, const std::function<T(const QueryStats &)> &member_fn)
{
double avg = 0;
for (uint64_t i = 0; i < len; i++)
{
avg += (double)member_fn(stats[i]);
}
return avg / len;
}
} // namespace diskann
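For instance, aggregating per-query stats collected during a benchmark run might look like this sketch (num_queries is a placeholder):

std::vector<diskann::QueryStats> stats(num_queries); // filled by the search loop
double mean_us = diskann::get_mean_stats<float>(stats.data(), stats.size(),
                                                [](const diskann::QueryStats &qs) { return qs.total_us; });
float p99_us = diskann::get_percentile_stats<float>(stats.data(), stats.size(), 0.99f,
                                                    [](const diskann::QueryStats &qs) { return qs.total_us; });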

View File

@@ -0,0 +1,93 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include "utils.h"
#include "pq_common.h"
namespace diskann
{
class FixedChunkPQTable
{
float *tables = nullptr; // pq_tables = float array of size [256 * ndims]
uint64_t ndims = 0; // ndims = true dimension of vectors
uint64_t n_chunks = 0;
bool use_rotation = false;
uint32_t *chunk_offsets = nullptr;
float *centroid = nullptr;
float *tables_tr = nullptr; // same as pq_tables, but col-major
float *rotmat_tr = nullptr;
public:
FixedChunkPQTable();
virtual ~FixedChunkPQTable();
#ifdef EXEC_ENV_OLS
void load_pq_centroid_bin(MemoryMappedFiles &files, const char *pq_table_file, size_t num_chunks);
#else
void load_pq_centroid_bin(const char *pq_table_file, size_t num_chunks);
#endif
uint32_t get_num_chunks();
void preprocess_query(float *query_vec);
// assumes pre-processed query
void populate_chunk_distances(const float *query_vec, float *dist_vec);
float l2_distance(const float *query_vec, uint8_t *base_vec);
float inner_product(const float *query_vec, uint8_t *base_vec);
// assumes no rotation is involved
void inflate_vector(uint8_t *base_vec, float *out_vec);
void populate_chunk_inner_products(const float *query_vec, float *dist_vec);
};
void aggregate_coords(const std::vector<unsigned> &ids, const uint8_t *all_coords, const uint64_t ndims, uint8_t *out);
void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_nchunks, const float *pq_dists,
std::vector<float> &dists_out);
// Need to replace calls to these with calls to vector& based functions above
void aggregate_coords(const unsigned *ids, const uint64_t n_ids, const uint8_t *all_coords, const uint64_t ndims,
uint8_t *out);
void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_nchunks, const float *pq_dists,
float *dists_out);
DISKANN_DLLEXPORT int generate_pq_pivots(const float *const train_data, size_t num_train, unsigned dim,
unsigned num_centers, unsigned num_pq_chunks, unsigned max_k_means_reps,
std::string pq_pivots_path, bool make_zero_mean = false);
DISKANN_DLLEXPORT int generate_opq_pivots(const float *train_data, size_t num_train, unsigned dim, unsigned num_centers,
unsigned num_pq_chunks, std::string opq_pivots_path,
bool make_zero_mean = false);
DISKANN_DLLEXPORT int generate_pq_pivots_simplified(const float *train_data, size_t num_train, size_t dim,
size_t num_pq_chunks, std::vector<float> &pivot_data_vector);
template <typename T>
int generate_pq_data_from_pivots(const std::string &data_file, unsigned num_centers, unsigned num_pq_chunks,
const std::string &pq_pivots_path, const std::string &pq_compressed_vectors_path,
bool use_opq = false);
DISKANN_DLLEXPORT int generate_pq_data_from_pivots_simplified(const float *data, const size_t num,
const float *pivot_data, const size_t pivots_num,
const size_t dim, const size_t num_pq_chunks,
std::vector<uint8_t> &pq);
template <typename T>
void generate_disk_quantized_data(const std::string &data_file_to_use, const std::string &disk_pq_pivots_path,
const std::string &disk_pq_compressed_vectors_path,
const diskann::Metric compareMetric, const double p_val, size_t &disk_pq_dims);
template <typename T>
void generate_quantized_data(const std::string &data_file_to_use, const std::string &pq_pivots_path,
const std::string &pq_compressed_vectors_path, const diskann::Metric compareMetric,
const double p_val, const uint64_t num_pq_chunks, const bool use_opq,
const std::string &codebook_prefix = "");
} // namespace diskann
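A hedged sketch of how the table and the two free lookup functions chain together during graph traversal. Here pq_table (a loaded FixedChunkPQTable), query (a mutable float array), neighbor_ids, and all_pq_codes (the num_points * n_chunks compressed matrix) are placeholders:

// Once per query: center/rotate the query, then build the per-chunk table.
pq_table.preprocess_query(query);
std::vector<float> chunk_dists((size_t)256 * pq_table.get_num_chunks());
pq_table.populate_chunk_distances(query, chunk_dists.data());
// Per hop: gather the candidates' PQ codes, then sum table entries per id.
std::vector<uint8_t> nbr_codes(neighbor_ids.size() * pq_table.get_num_chunks());
aggregate_coords(neighbor_ids, all_pq_codes, pq_table.get_num_chunks(), nbr_codes.data());
std::vector<float> dists_out;
pq_dist_lookup(nbr_codes.data(), neighbor_ids.size(), pq_table.get_num_chunks(), chunk_dists.data(), dists_out);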

View File

@@ -0,0 +1,30 @@
#pragma once
#include <string>
#include <sstream>
#define NUM_PQ_BITS 8
#define NUM_PQ_CENTROIDS (1 << NUM_PQ_BITS)
#define MAX_OPQ_ITERS 20
#define NUM_KMEANS_REPS_PQ 12
#define MAX_PQ_TRAINING_SET_SIZE 256000
#define MAX_PQ_CHUNKS 512
namespace diskann
{
inline std::string get_quantized_vectors_filename(const std::string &prefix, bool use_opq, uint32_t num_chunks)
{
return prefix + (use_opq ? "_opq" : "_pq") + std::to_string(num_chunks) + "_compressed.bin";
}
inline std::string get_pivot_data_filename(const std::string &prefix, bool use_opq, uint32_t num_chunks)
{
return prefix + (use_opq ? "_opq" : "_pq") + std::to_string(num_chunks) + "_pivots.bin";
}
inline std::string get_rotation_matrix_suffix(const std::string &pivot_data_filename)
{
return pivot_data_filename + "_rotation_matrix.bin";
}
} // namespace diskann
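For example, with prefix "/data/idx" and 32 chunks, the helpers above yield:

// get_quantized_vectors_filename("/data/idx", false, 32) -> "/data/idx_pq32_compressed.bin"
// get_quantized_vectors_filename("/data/idx", true, 32)  -> "/data/idx_opq32_compressed.bin"
// get_pivot_data_filename("/data/idx", false, 32)        -> "/data/idx_pq32_pivots.bin"
// get_rotation_matrix_suffix("/data/idx_opq32_pivots.bin")
//                            -> "/data/idx_opq32_pivots.bin_rotation_matrix.bin"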

View File

@@ -0,0 +1,97 @@
#pragma once
#include <memory>
#include "distance.h"
#include "quantized_distance.h"
#include "pq.h"
#include "abstract_data_store.h"
namespace diskann
{
// REFACTOR TODO: By default, the PQDataStore is an in-memory datastore because both Vamana and
// DiskANN treat it the same way. But with DiskPQ, that may need to change.
template <typename data_t> class PQDataStore : public AbstractDataStore<data_t>
{
public:
PQDataStore(size_t dim, location_t num_points, size_t num_pq_chunks, std::unique_ptr<Distance<data_t>> distance_fn,
std::unique_ptr<QuantizedDistance<data_t>> pq_distance_fn);
PQDataStore(const PQDataStore &) = delete;
PQDataStore &operator=(const PQDataStore &) = delete;
~PQDataStore();
// Load quantized vectors from a set of files. Here filename is treated
// as a prefix and the files are assumed to be named with DiskANN
// conventions.
virtual location_t load(const std::string &file_prefix) override;
// Save quantized vectors to a set of files whose names start with
// file_prefix.
// Currently, the plan is to save the quantized vectors to the quantized
// vectors file.
virtual size_t save(const std::string &file_prefix, const location_t num_points) override;
// Since the base class function is pure virtual, we need to declare it here, even though the alignment
// concept is not needed for quantized data stores.
virtual size_t get_aligned_dim() const override;
// Populate quantized data from unaligned data using PQ functionality
virtual void populate_data(const data_t *vectors, const location_t num_pts) override;
virtual void populate_data(const std::string &filename, const size_t offset) override;
virtual void extract_data_to_bin(const std::string &filename, const location_t num_pts) override;
virtual void get_vector(const location_t i, data_t *target) const override;
virtual void set_vector(const location_t i, const data_t *const vector) override;
virtual void prefetch_vector(const location_t loc) override;
virtual void move_vectors(const location_t old_location_start, const location_t new_location_start,
const location_t num_points) override;
virtual void copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points) override;
virtual void preprocess_query(const data_t *query, AbstractScratch<data_t> *scratch) const override;
virtual float get_distance(const data_t *query, const location_t loc) const override;
virtual float get_distance(const location_t loc1, const location_t loc2) const override;
// NOTE: Caller must invoke "PQDistance->preprocess_query" ONCE before calling
// this function.
virtual void get_distance(const data_t *preprocessed_query, const location_t *locations,
const uint32_t location_count, float *distances,
AbstractScratch<data_t> *scratch_space) const override;
// NOTE: Caller must invoke "PQDistance->preprocess_query" ONCE before calling
// this function.
virtual void get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
std::vector<float> &distances, AbstractScratch<data_t> *scratch_space) const override;
// We are returning the distance function that is used for full precision
// vectors here, not the PQ distance function. This is because the callers
// all are expecting a Distance<T> not QuantizedDistance<T>.
virtual Distance<data_t> *get_dist_fn() const override;
virtual location_t calculate_medoid() const override;
virtual size_t get_alignment_factor() const override;
protected:
virtual location_t expand(const location_t new_size) override;
virtual location_t shrink(const location_t new_size) override;
virtual location_t load_impl(const std::string &filename);
#ifdef EXEC_ENV_OLS
virtual location_t load_impl(AlignedFileReader &reader);
#endif
private:
uint8_t *_quantized_data = nullptr;
size_t _num_chunks = 0;
// REFACTOR TODO: Doing this temporarily before refactoring OPQ into
// its own class. Remove later.
bool _use_opq = false;
Metric _distance_metric;
std::unique_ptr<Distance<data_t>> _distance_fn = nullptr;
std::unique_ptr<QuantizedDistance<data_t>> _pq_distance_fn = nullptr;
};
} // namespace diskann

View File

@@ -0,0 +1,286 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include "common_includes.h"
#include "aligned_file_reader.h"
#include "concurrent_queue.h"
#include "neighbor.h"
#include "parameters.h"
#include "percentile_stats.h"
#include "pq.h"
#include "utils.h"
#include "windows_customizations.h"
#include "scratch.h"
#include "tsl/robin_map.h"
#include "tsl/robin_set.h"
#define FULL_PRECISION_REORDER_MULTIPLIER 3
namespace diskann
{
template <typename T, typename LabelT = uint32_t> class PQFlashIndex
{
public:
DISKANN_DLLEXPORT PQFlashIndex(std::shared_ptr<AlignedFileReader> &fileReader,
std::shared_ptr<AlignedFileReader> &graphReader,
diskann::Metric metric = diskann::Metric::L2);
DISKANN_DLLEXPORT ~PQFlashIndex();
#ifdef EXEC_ENV_OLS
DISKANN_DLLEXPORT int load(diskann::MemoryMappedFiles &files, uint32_t num_threads, const char *index_prefix,
const char *pq_prefix = nullptr);
#else
// load compressed data, and obtains the handle to the disk-resident index
DISKANN_DLLEXPORT int load(uint32_t num_threads, const char *index_prefix, const char *pq_prefix = nullptr,
const char *partition_prefix = nullptr);
#endif
#ifdef EXEC_ENV_OLS
DISKANN_DLLEXPORT int load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads,
const char *index_filepath, const char *pivots_filepath,
const char *compressed_filepath, const char *graph_file);
#else
DISKANN_DLLEXPORT int load_from_separate_paths(uint32_t num_threads, const char *index_filepath,
const char *pivots_filepath, const char *compressed_filepath,
const char *graph_file, const char *partition_file);
#endif
DISKANN_DLLEXPORT void load_cache_list(std::vector<uint32_t> &node_list);
#ifdef EXEC_ENV_OLS
DISKANN_DLLEXPORT void generate_cache_list_from_sample_queries(MemoryMappedFiles &files, std::string sample_bin,
uint64_t l_search, uint64_t beamwidth,
uint64_t num_nodes_to_cache, uint32_t nthreads,
std::vector<uint32_t> &node_list);
#else
DISKANN_DLLEXPORT void generate_cache_list_from_sample_queries(std::string sample_bin, uint64_t l_search,
uint64_t beamwidth, uint64_t num_nodes_to_cache,
uint32_t num_threads,
std::vector<uint32_t> &node_list);
#endif
DISKANN_DLLEXPORT void cache_bfs_levels(uint64_t num_nodes_to_cache, std::vector<uint32_t> &node_list,
const bool shuffle = false);
DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search,
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
const bool use_reorder_data = false, QueryStats *stats = nullptr,
const bool USE_DEFERRED_FETCH = false,
const bool skip_search_reorder = false,
const bool recompute_beighbor_embeddings = false,
const bool dedup_node_dis = false, float prune_ratio = 0,
const bool batch_recompute = false, bool global_pruning = false);
DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search,
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
const bool use_filter, const LabelT &filter_label,
const bool use_reorder_data = false, QueryStats *stats = nullptr,
const bool USE_DEFERRED_FETCH = false,
const bool skip_search_reorder = false,
const bool recompute_beighbor_embeddings = false,
const bool dedup_node_dis = false, float prune_ratio = 0,
const bool batch_recompute = false, bool global_pruning = false);
DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search,
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
const uint32_t io_limit, const bool use_reorder_data = false,
QueryStats *stats = nullptr, const bool USE_DEFERRED_FETCH = false,
const bool skip_search_reorder = false,
const bool recompute_beighbor_embeddings = false,
const bool dedup_node_dis = false, float prune_ratio = 0,
const bool batch_recompute = false, bool global_pruning = false);
DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search,
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
const bool use_filter, const LabelT &filter_label,
const uint32_t io_limit, const bool use_reorder_data = false,
QueryStats *stats = nullptr, const bool USE_DEFERRED_FETCH = false,
const bool skip_search_reorder = false,
const bool recompute_beighbor_embeddings = false,
const bool dedup_node_dis = false, float prune_ratio = 0,
const bool batch_recompute = false, bool global_pruning = false);
DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &filter_label);
DISKANN_DLLEXPORT uint32_t range_search(const T *query1, const double range, const uint64_t min_l_search,
const uint64_t max_l_search, std::vector<uint64_t> &indices,
std::vector<float> &distances, const uint64_t min_beam_width,
QueryStats *stats = nullptr);
DISKANN_DLLEXPORT uint64_t get_data_dim();
std::shared_ptr<AlignedFileReader> &reader;
DISKANN_DLLEXPORT diskann::Metric get_metric();
//
// node_ids: input list of node_ids to be read
// coord_buffers: pointers to pre-allocated buffers that coords need to be copied to. If null, don't copy.
// nbr_buffers: pre-allocated buffers to copy neighbors into
//
// returns a vector of bool, one for each node_id: true if the read succeeded, else false
//
DISKANN_DLLEXPORT std::vector<bool> read_nodes(const std::vector<uint32_t> &node_ids,
std::vector<T *> &coord_buffers,
std::vector<std::pair<uint32_t, uint32_t *>> &nbr_buffers);
DISKANN_DLLEXPORT std::vector<std::uint8_t> get_pq_vector(std::uint64_t vid);
DISKANN_DLLEXPORT uint64_t get_num_points();
protected:
DISKANN_DLLEXPORT void use_medoids_data_as_centroids();
DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = 4096);
DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);
private:
DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id);
std::unordered_map<std::string, LabelT> load_label_map(std::basic_istream<char> &infile);
DISKANN_DLLEXPORT void parse_label_file(std::basic_istream<char> &infile, size_t &num_pts_labels);
DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts,
uint32_t &num_total_labels);
DISKANN_DLLEXPORT void generate_random_labels(std::vector<LabelT> &labels, const uint32_t num_labels,
const uint32_t nthreads);
void reset_stream_for_reading(std::basic_istream<char> &infile);
// sector # on disk where node_id is present with in the graph part
DISKANN_DLLEXPORT uint64_t get_node_sector(uint64_t node_id);
// ptr to start of the node
DISKANN_DLLEXPORT char *offset_to_node(char *sector_buf, uint64_t node_id);
// returns region of `node_buf` containing [NNBRS][NBR_ID(uint32_t)]
DISKANN_DLLEXPORT uint32_t *offset_to_node_nhood(char *node_buf);
// returns region of `node_buf` containing [COORD(T)]
DISKANN_DLLEXPORT T *offset_to_node_coords(char *node_buf);
DISKANN_DLLEXPORT int load_graph_index(const std::string &graph_index_file);
DISKANN_DLLEXPORT int read_partition_info(const std::string &partition_bin);
DISKANN_DLLEXPORT int read_neighbors(const std::string &graph_index_file, uint64_t target_node_id);
// index info for multi-node sectors
// nhood of node `i` is in sector: [i / nnodes_per_sector]
// offset in sector: [(i % nnodes_per_sector) * max_node_len]
//
// index info for multi-sector nodes
// nhood of node `i` is in sector: [i * DIV_ROUND_UP(_max_node_len, SECTOR_LEN)]
// offset in sector: [0]
//
// Common info
// coords start at offset
// #nbrs of node `i`: *(unsigned*) (offset + disk_bytes_per_point)
// nbrs of node `i` : (unsigned*) (offset + disk_bytes_per_point + 1)
uint64_t _max_node_len = 0;
uint64_t _nnodes_per_sector = 0; // 0 for multi-sector nodes, >0 for multi-node sectors
uint64_t _max_degree = 0;
uint64_t _C = 0;
// Data used for searching with re-order vectors
uint64_t _ndims_reorder_vecs = 0;
uint64_t _reorder_data_start_sector = 0;
uint64_t _nvecs_per_sector = 0;
diskann::Metric metric = diskann::Metric::L2;
// used only for inner product search to re-scale the result value
// (due to the pre-processing of base during index build)
float _max_base_norm = 0.0f;
// data info
uint64_t _num_points = 0;
uint64_t _num_frozen_points = 0;
uint64_t _frozen_location = 0;
uint64_t _data_dim = 0;
uint64_t _aligned_dim = 0;
uint64_t _disk_bytes_per_point = 0; // Number of bytes
std::string _disk_index_file;
std::vector<std::pair<uint32_t, uint32_t>> _node_visit_counter;
// PQ data
// _n_chunks = # of chunks ndims is split into
// data: char * _n_chunks
// chunk_size = chunk size of each dimension chunk
// pq_tables = float* [[2^8 * [chunk_size]] * _n_chunks]
uint8_t *data = nullptr;
uint64_t _n_chunks;
FixedChunkPQTable _pq_table;
// distance comparator
std::shared_ptr<Distance<T>> _dist_cmp;
std::shared_ptr<Distance<float>> _dist_cmp_float;
// for very large datasets: we use PQ even for the disk resident index
bool _use_disk_index_pq = false;
uint64_t _disk_pq_n_chunks = 0;
FixedChunkPQTable _disk_pq_table;
// medoid/start info
// graph has one entry point by default,
// we can optionally have multiple starting points
uint32_t *_medoids = nullptr;
// defaults to 1
size_t _num_medoids;
// by default, it is empty. If there are multiple
// centroids, we pick the medoid corresponding to the
// closest centroid as the starting point of search
float *_centroid_data = nullptr;
// nhood_cache; the uint32_t in _nhood_cache are offsets into _nhood_cache_buf
unsigned *_nhood_cache_buf = nullptr;
tsl::robin_map<uint32_t, std::pair<uint32_t, uint32_t *>> _nhood_cache;
// coord_cache; the T* in _coord_cache are offsets into _coord_cache_buf
T *_coord_cache_buf = nullptr;
tsl::robin_map<uint32_t, T *> _coord_cache;
// thread-specific scratch
ConcurrentQueue<SSDThreadData<T> *> _thread_data;
uint64_t _max_nthreads;
bool _load_flag = false;
bool _count_visited_nodes = false;
bool _reorder_data_exists = false;
uint64_t _reoreder_data_offset = 0;
// filter support
uint32_t *_pts_to_label_offsets = nullptr;
uint32_t *_pts_to_label_counts = nullptr;
LabelT *_pts_to_labels = nullptr;
std::unordered_map<LabelT, std::vector<uint32_t>> _filter_to_medoid_ids;
bool _use_universal_label = false;
LabelT _universal_filter_label;
tsl::robin_set<uint32_t> _dummy_pts;
tsl::robin_set<uint32_t> _has_dummy_pts;
tsl::robin_map<uint32_t, uint32_t> _dummy_to_real_map;
tsl::robin_map<uint32_t, std::vector<uint32_t>> _real_to_dummy_map;
std::unordered_map<std::string, LabelT> _label_map;
private:
bool _use_partition = false;
std::shared_ptr<AlignedFileReader> graph_reader; // Graph file reader
std::string _graph_index_file; // Graph file path
uint64_t _graph_node_len; // Graph node length
uint64_t _emb_node_len; // Embedding node length
// Partition related data structures
uint64_t _num_partitions; // Number of partitions
std::vector<std::vector<uint32_t>> _graph_partitions; // Partition information
std::vector<uint32_t> _id2partition; // ID to partition mapping
#ifdef EXEC_ENV_OLS
// Set to a larger value than the actual header to accommodate
// any additions we make to the header. This is an outer limit
// on how big the header can be.
static const int HEADER_SIZE = defaults::SECTOR_LEN;
char *getHeaderBytes();
#endif
};
} // namespace diskann
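A minimal load-then-search sketch using only the public methods declared above. Reader construction is platform-specific (LinuxAlignedFileReader on Linux, per the earlier header); the paths, parameter values, and the query pointer are placeholders:

std::shared_ptr<AlignedFileReader> reader(new LinuxAlignedFileReader());
std::shared_ptr<AlignedFileReader> graph_reader(new LinuxAlignedFileReader());
diskann::PQFlashIndex<float> index(reader, graph_reader, diskann::Metric::L2);
if (index.load(/*num_threads=*/8, "/data/disk_index") != 0)
{
    /* handle load failure */
}
std::vector<uint64_t> res_ids(10);
std::vector<float> res_dists(10);
index.cached_beam_search(query, /*k_search=*/10, /*l_search=*/100,
                         res_ids.data(), res_dists.data(), /*beam_width=*/4);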

View File

@@ -0,0 +1,87 @@
#pragma once
#include "quantized_distance.h"
namespace diskann
{
template <typename data_t> class PQL2Distance : public QuantizedDistance<data_t>
{
public:
// REFACTOR TODO: We could take a file prefix here and load the
// PQ pivots file, so that the distance object is initialized
// immediately after construction. But this would not work well
// with our data store concept where the store is created first
// and data populated after.
// REFACTOR TODO: Ideally, we should only read the num_chunks from
// the pivots file. However, we read the pivots file only later, but
// clients can call functions like get_<xxx>_filename without calling
// load_pivot_data. Hence this. The TODO is whether we should check
// that the num_chunks from the file is the same as this one.
PQL2Distance(uint32_t num_chunks, bool use_opq = false);
virtual ~PQL2Distance() override;
virtual bool is_opq() const override;
virtual std::string get_quantized_vectors_filename(const std::string &prefix) const override;
virtual std::string get_pivot_data_filename(const std::string &prefix) const override;
virtual std::string get_rotation_matrix_suffix(const std::string &pq_pivots_filename) const override;
#ifdef EXEC_ENV_OLS
virtual void load_pivot_data(MemoryMappedFiles &files, const std::string &pq_table_file,
size_t num_chunks) override;
#else
virtual void load_pivot_data(const std::string &pq_table_file, size_t num_chunks) override;
#endif
// Number of chunks in the PQ table. Depends on the compression level used.
// Has to be < ndim
virtual uint32_t get_num_chunks() const override;
// Preprocess the query by computing chunk distances from the query vector to
// various centroids. Since we don't want this class to do scratch management,
// we will take a PQScratch object which can come either from Index class or
// PQFlashIndex class.
virtual void preprocess_query(const data_t *aligned_query, uint32_t original_dim,
PQScratch<data_t> &pq_scratch) override;
// Distance function used for graph traversal. This function must be called
// after
// preprocess_query. The reason we do not call preprocess ourselves is because
// that function has to be called once per query, while this function is
// called at each iteration of the graph walk. NOTE: This function expects
// 1. the query to be preprocessed using preprocess_query()
// 2. the scratch object to contain the quantized vectors corresponding to ids
// in aligned_pq_coord_scratch. Done by calling aggregate_coords()
//
virtual void preprocessed_distance(PQScratch<data_t> &pq_scratch, const uint32_t id_count,
float *dists_out) override;
// Same as above, but returns the distances in a vector instead of an array.
// Convenience function for index.cpp.
virtual void preprocessed_distance(PQScratch<data_t> &pq_scratch, const uint32_t n_ids,
std::vector<float> &dists_out) override;
// Currently this function is required for DiskPQ. However, it too can be
// subsumed under preprocessed_distance if we add the appropriate scratch
// variables to PQScratch and initialize them in
// pq_flash_index.cpp::disk_iterate_to_fixed_point()
virtual float brute_force_distance(const float *query_vec, uint8_t *base_vec) override;
protected:
// assumes pre-processed query
virtual void prepopulate_chunkwise_distances(const float *query_vec, float *dist_vec);
// assumes no rotation is involved
// virtual void inflate_vector(uint8_t *base_vec, float *out_vec);
float *_tables = nullptr; // pq_tables = float array of size [256 * ndims]
uint64_t _ndims = 0; // ndims = true dimension of vectors
uint64_t _num_chunks = 0;
bool _is_opq = false;
uint32_t *_chunk_offsets = nullptr;
float *_centroid = nullptr;
float *_tables_tr = nullptr; // same as pq_tables, but col-major
float *_rotmat_tr = nullptr;
};
} // namespace diskann

View File

@@ -0,0 +1,23 @@
#pragma once
#include <cstdint>
#include "pq_common.h"
#include "utils.h"
namespace diskann
{
template <typename T> class PQScratch
{
public:
float *aligned_pqtable_dist_scratch = nullptr; // MUST BE AT LEAST [256 * NCHUNKS]
float *aligned_dist_scratch = nullptr; // MUST BE AT LEAST diskann MAX_DEGREE
uint8_t *aligned_pq_coord_scratch = nullptr; // AT LEAST [N_CHUNKS * MAX_DEGREE]
float *rotated_query = nullptr;
float *aligned_query_float = nullptr;
PQScratch(size_t graph_degree, size_t aligned_dim);
void initialize(size_t dim, const T *query, const float norm = 1.0f);
virtual ~PQScratch();
};
} // namespace diskann

View File

@@ -0,0 +1,81 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <string>
namespace program_options_utils
{
inline const std::string make_program_description(const char *executable_name, const char *description)
{
return std::string("\n")
.append(description)
.append("\n\n")
.append("Usage: ")
.append(executable_name)
.append(" [OPTIONS]");
}
// Required parameters
const char *DATA_TYPE_DESCRIPTION = "data type, one of {int8, uint8, float} - float is single precision (32 bit)";
const char *DISTANCE_FUNCTION_DESCRIPTION =
"distance function {l2, mips, fast_l2, cosine}. 'fast l2' and 'mips' only support data_type float";
const char *INDEX_PATH_PREFIX_DESCRIPTION = "Path prefix to the index, e.g. '/mnt/data/my_ann_index'";
const char *RESULT_PATH_DESCRIPTION =
"Path prefix for saving results of the queries, e.g. '/mnt/data/query_file_X.bin'";
const char *QUERY_FILE_DESCRIPTION = "Query file in binary format, e.g. '/mnt/data/query_file_X.bin'";
const char *NUMBER_OF_RESULTS_DESCRIPTION = "Number of neighbors to be returned (K in the DiskANN white paper)";
const char *SEARCH_LIST_DESCRIPTION =
"Size of search list to use. This value is the number of neighbor/distance pairs to keep in memory at the same "
"time while performing a query. This can also be described as the size of the working set at query time. This "
"must be greater than or equal to the number of results/neighbors to return (K in the white paper). Corresponds "
"to L in the DiskANN white paper.";
const char *INPUT_DATA_PATH = "Input data file in bin format. This is the file you want to build the index over. "
"File format: Shape of the vector followed by the vector of embeddings as binary data.";
// Optional parameters
const char *FILTER_LABEL_DESCRIPTION =
"Filter to use when running a query. 'filter_label' and 'query_filters_file' are mutually exclusive.";
const char *FILTERS_FILE_DESCRIPTION =
"Filter file for Queries for Filtered Search. File format is text with one filter per line. File must "
"have exactly one filter OR the same number of filters as there are queries in the 'query_file'.";
const char *LABEL_TYPE_DESCRIPTION =
"Storage type of Labels {uint/uint32, ushort/uint16}, default value is uint which will consume 4 bytes of memory "
"per filter. 'uint' is an alias for 'uint32' and 'ushort' is an alias for 'uint16'.";
const char *GROUND_TRUTH_FILE_DESCRIPTION =
"ground truth file for the queryset"; // what's the format, what's the requirements? does it need to include an
// entry for every item or just a small subset? I have so many questions about
// this file
const char *NUMBER_THREADS_DESCRIPTION = "Number of threads used for building index. Defaults to number of logical "
"processor cores on your this machine returned by omp_get_num_procs()";
const char *FAIL_IF_RECALL_BELOW =
"Value between 0 (inclusive) and 100 (exclusive) indicating the recall tolerance percentage threshold before "
"program fails with a non-zero exit code. The default value of 0 means that the program will complete "
"successfully with any recall value. A non-zero value indicates the floor for acceptable recall values. If the "
"calculated recall value is below this threshold then the program will write out the results but return a non-zero "
"exit code as a signal that the recall was not acceptable."; // does it continue running or die immediately? Will I
// still get my results even if the return code is -1?
const char *NUMBER_OF_NODES_TO_CACHE = "Number of BFS nodes around medoid(s) to cache. Default value: 0";
const char *BEAMWIDTH = "Beamwidth for search. Set 0 to optimize internally. Default value: 2";
const char *MAX_BUILD_DEGREE = "Maximum graph degree";
const char *GRAPH_BUILD_COMPLEXITY =
"Size of the search working set during build time. This is the numer of neighbor/distance pairs to keep in memory "
"while building the index. Higher value results in a higher quality graph but it will take more time to build the "
"graph.";
const char *GRAPH_BUILD_ALPHA = "Alpha controls the density and diameter of the graph; set 1 for a sparse graph, "
"1.2 or 1.4 for denser graphs with lower diameter";
const char *BUIlD_GRAPH_PQ_BYTES = "Number of PQ bytes to build the index; 0 for full precision build";
const char *USE_OPQ = "Use Optimized Product Quantization (OPQ).";
const char *LABEL_FILE = "Input label file in txt format for Filtered Index build. The file should contain comma "
"separated filters for each node with each line corresponding to a graph node";
const char *UNIVERSAL_LABEL =
"Universal label. Use only in conjunction with a label file for filtered index build. If a "
"graph node has every label, a special universal filter can be assigned to the "
"point instead of comma-separated filters for that point. The universal label must be assigned to such nodes "
"in the labels file itself, instead of listing all labels for the node; DiskANN will not automatically assign a "
"universal label to a node.";
const char *FILTERED_LBUILD = "Build complexity for filtered points, higher value results in better graphs";
} // namespace program_options_utils

View File

@@ -0,0 +1,9 @@
#pragma once
#include "embedding.pb.h"
// This header ensures that the protobuf files are included correctly
// and provides a namespace alias for convenience
namespace diskann {
namespace proto = protoembedding;
}

View File

@@ -0,0 +1,56 @@
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "abstract_scratch.h"
namespace diskann
{
template <typename data_t> class PQScratch;
template <typename data_t> class QuantizedDistance
{
public:
QuantizedDistance() = default;
QuantizedDistance(const QuantizedDistance &) = delete;
QuantizedDistance &operator=(const QuantizedDistance &) = delete;
virtual ~QuantizedDistance() = default;
virtual bool is_opq() const = 0;
virtual std::string get_quantized_vectors_filename(const std::string &prefix) const = 0;
virtual std::string get_pivot_data_filename(const std::string &prefix) const = 0;
virtual std::string get_rotation_matrix_suffix(const std::string &pq_pivots_filename) const = 0;
// Loading the PQ centroid table need not be part of the abstract class.
// However, we want to indicate that this function will change once we have a
// file reader hierarchy, so leave it here as-is.
#ifdef EXEC_ENV_OLS
virtual void load_pivot_data(MemoryMappedFiles &files, const std::string &pq_table_file, size_t num_chunks) = 0;
#else
virtual void load_pivot_data(const std::string &pq_table_file, size_t num_chunks) = 0;
#endif
// Number of chunks in the PQ table. Depends on the compression level used.
// Has to be < ndim
virtual uint32_t get_num_chunks() const = 0;
// Preprocess the query by computing chunk distances from the query vector to
// various centroids. Since we don't want this class to do scratch management,
// we will take a PQScratch object which can come either from Index class or
// PQFlashIndex class.
virtual void preprocess_query(const data_t *query_vec, uint32_t query_dim, PQScratch<data_t> &pq_scratch) = 0;
// Workhorse
// This function must be called after preprocess_query
virtual void preprocessed_distance(PQScratch<data_t> &pq_scratch, const uint32_t id_count, float *dists_out) = 0;
// Same as above, but convenience function for index.cpp.
virtual void preprocessed_distance(PQScratch<data_t> &pq_scratch, const uint32_t n_ids,
std::vector<float> &dists_out) = 0;
// Currently this function is required for DiskPQ. However, it too can be subsumed
// under preprocessed_distance if we add the appropriate scratch variables to
// PQScratch and initialize them in pq_flash_index.cpp::disk_iterate_to_fixed_point()
virtual float brute_force_distance(const float *query_vec, uint8_t *base_vec) = 0;
};
} // namespace diskann
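
A minimal sketch of the intended call order (illustrative only: `pq_dist` stands for any concrete
QuantizedDistance implementation, and the candidate ids are assumed to have been staged in the
scratch object by the caller):

// Sketch, not part of the header: preprocess_query must precede any
// preprocessed_distance call for the same query.
void example_pq_distances(diskann::QuantizedDistance<float> &pq_dist,
                          diskann::PQScratch<float> &scratch,
                          const float *query, uint32_t query_dim,
                          uint32_t num_candidates)
{
    // Step 1: compute per-chunk distances from the query to the centroids.
    pq_dist.preprocess_query(query, query_dim, scratch);
    // Step 2: aggregate the chunk distances for a batch of candidate ids.
    std::vector<float> dists(num_candidates);
    pq_dist.preprocessed_distance(scratch, num_candidates, dists);
}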

View File

@@ -0,0 +1,18 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <cpprest/base_uri.h>
#include <restapi/search_wrapper.h>
namespace diskann
{
// Constants
static const std::string VECTOR_KEY = "query", K_KEY = "k", INDICES_KEY = "indices", DISTANCES_KEY = "distances",
TAGS_KEY = "tags", QUERY_ID_KEY = "query_id", ERROR_MESSAGE_KEY = "error", L_KEY = "Ls",
TIME_TAKEN_KEY = "time_taken_in_us", PARTITION_KEY = "partition",
UNKNOWN_ERROR = "unknown_error";
const unsigned int DEFAULT_L = 100;
} // namespace diskann
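
For reference, a hedged sketch of a response assembled from these keys with the cpprest JSON API;
the exact layout produced by the server implementation may differ:

#include <cpprest/json.h>
inline web::json::value example_response()
{
    // The string literals mirror the key constants declared above.
    web::json::value r = web::json::value::object();
    r[U("query_id")] = web::json::value::number((int64_t)7);
    r[U("k")] = web::json::value::number((int64_t)10);
    r[U("time_taken_in_us")] = web::json::value::number((int64_t)1523);
    r[U("indices")] = web::json::value::array();   // K neighbor ids
    r[U("distances")] = web::json::value::array(); // K distances
    return r;
}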

View File

@@ -0,0 +1,140 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <string>
#include <vector>
#include <stdexcept>
#include <index.h>
#include <pq_flash_index.h>
namespace diskann
{
class SearchResult
{
public:
SearchResult(unsigned int K, unsigned int elapsed_time_in_ms, const unsigned *const indices,
const float *const distances, const std::string *const tags = nullptr,
const unsigned *const partitions = nullptr);
const std::vector<unsigned int> &get_indices() const
{
return _indices;
}
const std::vector<float> &get_distances() const
{
return _distances;
}
bool tags_enabled() const
{
return _tags_enabled;
}
const std::vector<std::string> &get_tags() const
{
return _tags;
}
bool partitions_enabled() const
{
return _partitions_enabled;
}
const std::vector<unsigned> &get_partitions() const
{
return _partitions;
}
unsigned get_time() const
{
return _search_time_in_ms;
}
private:
unsigned int _K;
unsigned int _search_time_in_ms;
std::vector<unsigned int> _indices;
std::vector<float> _distances;
bool _tags_enabled;
std::vector<std::string> _tags;
bool _partitions_enabled;
std::vector<unsigned> _partitions;
};
class SearchNotImplementedException : public std::logic_error
{
private:
std::string _errormsg;
public:
SearchNotImplementedException(const char *type) : std::logic_error("Not Implemented")
{
_errormsg = "Search with data type ";
_errormsg += std::string(type);
_errormsg += " not implemented : ";
_errormsg += __FUNCTION__;
}
virtual const char *what() const throw()
{
return _errormsg.c_str();
}
};
class BaseSearch
{
public:
    BaseSearch(const std::string &tagsFile = ""); // empty string (not nullptr) to avoid UB in std::string
virtual SearchResult search(const float *query, const unsigned int dimensions, const unsigned int K,
const unsigned int Ls)
{
throw SearchNotImplementedException("float");
}
virtual SearchResult search(const int8_t *query, const unsigned int dimensions, const unsigned int K,
const unsigned int Ls)
{
throw SearchNotImplementedException("int8_t");
}
virtual SearchResult search(const uint8_t *query, const unsigned int dimensions, const unsigned int K,
const unsigned int Ls)
{
throw SearchNotImplementedException("uint8_t");
}
void lookup_tags(const unsigned K, const unsigned *indices, std::string *ret_tags);
protected:
bool _tags_enabled;
std::vector<std::string> _tags_str;
};
template <typename T> class InMemorySearch : public BaseSearch
{
public:
InMemorySearch(const std::string &baseFile, const std::string &indexFile, const std::string &tagsFile, Metric m,
uint32_t num_threads, uint32_t search_l);
virtual ~InMemorySearch();
SearchResult search(const T *query, const unsigned int dimensions, const unsigned int K, const unsigned int Ls);
private:
unsigned int _dimensions, _numPoints;
std::unique_ptr<diskann::Index<T>> _index;
};
template <typename T> class PQFlashSearch : public BaseSearch
{
public:
PQFlashSearch(const std::string &indexPrefix, const unsigned num_nodes_to_cache, const unsigned num_threads,
const std::string &tagsFile, Metric m);
virtual ~PQFlashSearch();
SearchResult search(const T *query, const unsigned int dimensions, const unsigned int K, const unsigned int Ls);
private:
unsigned int _dimensions, _numPoints;
std::unique_ptr<diskann::PQFlashIndex<T>> _index;
std::shared_ptr<AlignedFileReader> reader;
};
} // namespace diskann
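
A hedged usage sketch of the in-memory wrapper (file paths and parameter values are placeholders,
not taken from the repository):

void example_in_memory_search()
{
    diskann::InMemorySearch<float> searcher("base.bin", "index_prefix", "tags.txt",
                                            diskann::Metric::L2,
                                            /*num_threads=*/8, /*search_l=*/100);
    std::vector<float> query(128, 0.0f); // 128-dim placeholder query
    diskann::SearchResult res = searcher.search(query.data(), 128, /*K=*/10, /*Ls=*/100);
    const std::vector<unsigned> &ids = res.get_indices(); // K nearest ids
    (void)ids;
}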

View File

@@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <restapi/common.h>
#include <cpprest/http_listener.h>
namespace diskann
{
class Server
{
public:
Server(web::uri &url, std::vector<std::unique_ptr<diskann::BaseSearch>> &multi_searcher,
const std::string &typestring);
virtual ~Server();
pplx::task<void> open();
pplx::task<void> close();
protected:
template <class T> void handle_post(web::http::http_request message);
template <typename T>
web::json::value toJsonArray(const std::vector<T> &v, std::function<web::json::value(const T &)> valConverter);
web::json::value prepareResponse(const int64_t &queryId, const int k);
template <class T>
void parseJson(const utility::string_t &body, unsigned int &k, int64_t &queryId, T *&queryVector,
unsigned int &dimensions, unsigned &Ls);
web::json::value idsToJsonArray(const diskann::SearchResult &result);
web::json::value distancesToJsonArray(const diskann::SearchResult &result);
web::json::value tagsToJsonArray(const diskann::SearchResult &result);
web::json::value partitionsToJsonArray(const diskann::SearchResult &result);
SearchResult aggregate_results(const unsigned K, const std::vector<diskann::SearchResult> &results);
private:
bool _isDebug;
std::unique_ptr<web::http::experimental::listener::http_listener> _listener;
const bool _multi_search;
std::vector<std::unique_ptr<diskann::BaseSearch>> _multi_searcher;
};
} // namespace diskann
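
A hedged sketch of standing the listener up and tearing it down (searcher construction elided;
open() and close() return pplx tasks as declared above):

void example_serve(std::vector<std::unique_ptr<diskann::BaseSearch>> &searchers)
{
    web::uri url(U("http://localhost:8080/search")); // placeholder endpoint
    diskann::Server server(url, searchers, "float");
    server.open().wait();  // block until the HTTP listener is ready
    // ... serve requests ...
    server.close().wait(); // block until the listener shuts down
}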

View File

@@ -0,0 +1,216 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <vector>
#include "boost_dynamic_bitset_fwd.h"
// #include "boost/dynamic_bitset.hpp"
#include "tsl/robin_set.h"
#include "tsl/robin_map.h"
#include "tsl/sparse_map.h"
#include "aligned_file_reader.h"
#include "abstract_scratch.h"
#include "neighbor.h"
#include "defaults.h"
#include "concurrent_queue.h"
namespace diskann
{
template <typename T> class PQScratch;
//
// AbstractScratch space for in-memory index based search
//
template <typename T> class InMemQueryScratch : public AbstractScratch<T>
{
public:
~InMemQueryScratch();
InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, uint32_t r, uint32_t maxc, size_t dim, size_t aligned_dim,
size_t alignment_factor, bool init_pq_scratch = false);
void resize_for_new_L(uint32_t new_search_l);
void clear();
inline uint32_t get_L()
{
return _L;
}
inline uint32_t get_R()
{
return _R;
}
inline uint32_t get_maxc()
{
return _maxc;
}
inline T *aligned_query()
{
return this->_aligned_query_T;
}
inline PQScratch<T> *pq_scratch()
{
return this->_pq_scratch;
}
inline std::vector<Neighbor> &pool()
{
return _pool;
}
inline NeighborPriorityQueue &best_l_nodes()
{
return _best_l_nodes;
}
inline std::vector<float> &occlude_factor()
{
return _occlude_factor;
}
inline tsl::robin_set<uint32_t> &inserted_into_pool_rs()
{
return _inserted_into_pool_rs;
}
inline boost::dynamic_bitset<> &inserted_into_pool_bs()
{
return *_inserted_into_pool_bs;
}
inline std::vector<uint32_t> &id_scratch()
{
return _id_scratch;
}
inline std::vector<float> &dist_scratch()
{
return _dist_scratch;
}
inline tsl::robin_set<uint32_t> &expanded_nodes_set()
{
return _expanded_nodes_set;
}
inline std::vector<Neighbor> &expanded_nodes_vec()
{
return _expanded_nghrs_vec;
}
inline std::vector<uint32_t> &occlude_list_output()
{
return _occlude_list_output;
}
private:
uint32_t _L;
uint32_t _R;
uint32_t _maxc;
// _pool stores all neighbors explored from best_L_nodes.
// Usually around L+R, but could be higher.
// Initialized to 3L+R for some slack, expands as needed.
std::vector<Neighbor> _pool;
// _best_l_nodes is reserved for storing best L entries
// Underlying storage is L+1 to support inserts
NeighborPriorityQueue _best_l_nodes;
// _occlude_factor.size() >= pool.size() in occlude_list function
// _pool is clipped to maxc in occlude_list before affecting _occlude_factor
// _occlude_factor is initialized to maxc size
std::vector<float> _occlude_factor;
// Capacity initialized to 20L
tsl::robin_set<uint32_t> _inserted_into_pool_rs;
// Use a pointer here to allow for forward declaration of dynamic_bitset
// in public headers to avoid making boost a dependency for clients
// of DiskANN.
boost::dynamic_bitset<> *_inserted_into_pool_bs;
// _id_scratch.size() must be > R*GRAPH_SLACK_FACTOR for iterate_to_fp
std::vector<uint32_t> _id_scratch;
// _dist_scratch must be > R*GRAPH_SLACK_FACTOR for iterate_to_fp
// _dist_scratch should be at least the size of id_scratch
std::vector<float> _dist_scratch;
// Buffers used in process delete, capacity increases as needed
tsl::robin_set<uint32_t> _expanded_nodes_set;
std::vector<Neighbor> _expanded_nghrs_vec;
std::vector<uint32_t> _occlude_list_output;
};
//
// AbstractScratch space for SSD index based search
//
template <typename T> class SSDQueryScratch : public AbstractScratch<T>
{
public:
T *coord_scratch = nullptr; // MUST BE AT LEAST [sizeof(T) * data_dim]
char *sector_scratch = nullptr; // MUST BE AT LEAST [MAX_N_SECTOR_READS * SECTOR_LEN]
size_t sector_idx = 0; // index of next [SECTOR_LEN] scratch to use
tsl::robin_set<size_t> visited;
NeighborPriorityQueue retset;
std::vector<Neighbor> full_retset;
SSDQueryScratch(size_t aligned_dim, size_t visited_reserve);
~SSDQueryScratch();
void reset();
};
template <typename T> class SSDThreadData
{
public:
SSDQueryScratch<T> scratch;
IOContext ctx;
SSDThreadData(size_t aligned_dim, size_t visited_reserve);
void clear();
};
//
// Class to avoid the hassle of pushing and popping the query scratch.
//
template <typename T> class ScratchStoreManager
{
public:
ScratchStoreManager(ConcurrentQueue<T *> &query_scratch) : _scratch_pool(query_scratch)
{
_scratch = query_scratch.pop();
while (_scratch == nullptr)
{
query_scratch.wait_for_push_notify();
_scratch = query_scratch.pop();
}
}
T *scratch_space()
{
return _scratch;
}
~ScratchStoreManager()
{
_scratch->clear();
_scratch_pool.push(_scratch);
_scratch_pool.push_notify_all();
}
void destroy()
{
while (!_scratch_pool.empty())
{
auto scratch = _scratch_pool.pop();
while (scratch == nullptr)
{
_scratch_pool.wait_for_push_notify();
scratch = _scratch_pool.pop();
}
delete scratch;
}
}
private:
T *_scratch;
ConcurrentQueue<T *> &_scratch_pool;
ScratchStoreManager(const ScratchStoreManager<T> &);
ScratchStoreManager &operator=(const ScratchStoreManager<T> &);
};
} // namespace diskann
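
A hedged usage sketch of the RAII manager: the constructor blocks until a scratch object is
available in the shared pool, and the destructor clears and returns it even if the search path
throws:

void example_with_scratch(diskann::ConcurrentQueue<diskann::InMemQueryScratch<float> *> &pool)
{
    diskann::ScratchStoreManager<diskann::InMemQueryScratch<float>> manager(pool);
    diskann::InMemQueryScratch<float> *scratch = manager.scratch_space();
    // ... run the query using scratch->pool(), scratch->best_l_nodes(), etc. ...
    (void)scratch;
} // destructor: scratch->clear(), pushed back to the pool, waiters notified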

View File

@@ -0,0 +1,106 @@
#pragma once
#ifdef _WINDOWS
#include <immintrin.h>
#include <smmintrin.h>
#include <tmmintrin.h>
#include <intrin.h>
#else
#include <immintrin.h>
#endif
namespace diskann
{
static inline __m256 _mm256_mul_epi8(__m256i X)
{
__m256i zero = _mm256_setzero_si256();
__m256i sign_x = _mm256_cmpgt_epi8(zero, X);
__m256i xlo = _mm256_unpacklo_epi8(X, sign_x);
__m256i xhi = _mm256_unpackhi_epi8(X, sign_x);
return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(xlo, xlo), _mm256_madd_epi16(xhi, xhi)));
}
static inline __m128 _mm_mulhi_epi8(__m128i X)
{
__m128i zero = _mm_setzero_si128();
__m128i sign_x = _mm_cmplt_epi8(X, zero);
__m128i xhi = _mm_unpackhi_epi8(X, sign_x);
return _mm_cvtepi32_ps(_mm_add_epi32(_mm_setzero_si128(), _mm_madd_epi16(xhi, xhi)));
}
static inline __m128 _mm_mulhi_epi8_shift32(__m128i X)
{
__m128i zero = _mm_setzero_si128();
X = _mm_srli_epi64(X, 32);
__m128i sign_x = _mm_cmplt_epi8(X, zero);
__m128i xhi = _mm_unpackhi_epi8(X, sign_x);
return _mm_cvtepi32_ps(_mm_add_epi32(_mm_setzero_si128(), _mm_madd_epi16(xhi, xhi)));
}
static inline __m128 _mm_mul_epi8(__m128i X, __m128i Y)
{
__m128i zero = _mm_setzero_si128();
__m128i sign_x = _mm_cmplt_epi8(X, zero);
__m128i sign_y = _mm_cmplt_epi8(Y, zero);
__m128i xlo = _mm_unpacklo_epi8(X, sign_x);
__m128i xhi = _mm_unpackhi_epi8(X, sign_x);
__m128i ylo = _mm_unpacklo_epi8(Y, sign_y);
__m128i yhi = _mm_unpackhi_epi8(Y, sign_y);
return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(xlo, ylo), _mm_madd_epi16(xhi, yhi)));
}
static inline __m128 _mm_mul_epi8(__m128i X)
{
__m128i zero = _mm_setzero_si128();
__m128i sign_x = _mm_cmplt_epi8(X, zero);
__m128i xlo = _mm_unpacklo_epi8(X, sign_x);
__m128i xhi = _mm_unpackhi_epi8(X, sign_x);
return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(xlo, xlo), _mm_madd_epi16(xhi, xhi)));
}
static inline __m128 _mm_mul32_pi8(__m128i X, __m128i Y)
{
__m128i xlo = _mm_cvtepi8_epi16(X), ylo = _mm_cvtepi8_epi16(Y);
return _mm_cvtepi32_ps(_mm_unpacklo_epi32(_mm_madd_epi16(xlo, ylo), _mm_setzero_si128()));
}
static inline __m256 _mm256_mul_epi8(__m256i X, __m256i Y)
{
__m256i zero = _mm256_setzero_si256();
__m256i sign_x = _mm256_cmpgt_epi8(zero, X);
__m256i sign_y = _mm256_cmpgt_epi8(zero, Y);
__m256i xlo = _mm256_unpacklo_epi8(X, sign_x);
__m256i xhi = _mm256_unpackhi_epi8(X, sign_x);
__m256i ylo = _mm256_unpacklo_epi8(Y, sign_y);
__m256i yhi = _mm256_unpackhi_epi8(Y, sign_y);
return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(xlo, ylo), _mm256_madd_epi16(xhi, yhi)));
}
static inline __m256 _mm256_mul32_pi8(__m128i X, __m128i Y)
{
__m256i xlo = _mm256_cvtepi8_epi16(X), ylo = _mm256_cvtepi8_epi16(Y);
return _mm256_blend_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(xlo, ylo)), _mm256_setzero_ps(), 252);
}
static inline float _mm256_reduce_add_ps(__m256 x)
{
/* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */
const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
/* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */
const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
/* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */
const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
/* Conversion to float is a no-op on x86-64 */
return _mm_cvtss_f32(x32);
}
} // namespace diskann
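
A hedged usage sketch combining the helpers above: the squared L2 norm of 32 int8 values
(unaligned load; one 8-lane partial-sum vector reduced to a scalar):

#include <cstdint>
inline float example_int8_norm_sq(const int8_t *vec /* 32 values */)
{
    __m256i x = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(vec));
    __m256 partial = diskann::_mm256_mul_epi8(x); // 8 float partial sums of vec[i]*vec[i]
    return diskann::_mm256_reduce_add_ps(partial);
}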

View File

@@ -0,0 +1,68 @@
#pragma once
#include <cstdint>
#include <type_traits>
namespace diskann
{
#pragma pack(push, 1)
struct tag_uint128
{
std::uint64_t _data1 = 0;
std::uint64_t _data2 = 0;
bool operator==(const tag_uint128 &other) const
{
return _data1 == other._data1 && _data2 == other._data2;
}
bool operator==(std::uint64_t other) const
{
return _data1 == other && _data2 == 0;
}
tag_uint128 &operator=(const tag_uint128 &other)
{
_data1 = other._data1;
_data2 = other._data2;
return *this;
}
tag_uint128 &operator=(std::uint64_t other)
{
_data1 = other;
_data2 = 0;
return *this;
}
};
#pragma pack(pop)
} // namespace diskann
namespace std
{
// Hash 128 input bits down to 64 bits of output.
// This is intended to be a reasonably good hash function.
inline std::uint64_t Hash128to64(const std::uint64_t &low, const std::uint64_t &high)
{
// Murmur-inspired hashing.
const std::uint64_t kMul = 0x9ddfea08eb382d69ULL;
std::uint64_t a = (low ^ high) * kMul;
a ^= (a >> 47);
std::uint64_t b = (high ^ a) * kMul;
b ^= (b >> 47);
b *= kMul;
return b;
}
template <> struct hash<diskann::tag_uint128>
{
size_t operator()(const diskann::tag_uint128 &key) const noexcept
{
        return Hash128to64(key._data1, key._data2);
}
};
} // namespace std
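
A hedged usage sketch: with the std::hash specialization above, tag_uint128 works directly as a
key in standard (or tsl) hash containers:

#include <unordered_map>
inline void example_tag_lookup()
{
    std::unordered_map<diskann::tag_uint128, uint32_t> tag_to_location;
    diskann::tag_uint128 tag;
    tag = 42;                 // sets _data1 = 42, _data2 = 0
    tag_to_location[tag] = 7; // hashed via Hash128to64
}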

View File

@@ -0,0 +1,40 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <chrono>
#include <string>
namespace diskann
{
class Timer
{
typedef std::chrono::high_resolution_clock _clock;
std::chrono::time_point<_clock> check_point;
public:
Timer() : check_point(_clock::now())
{
}
void reset()
{
check_point = _clock::now();
}
long long elapsed() const
{
return std::chrono::duration_cast<std::chrono::microseconds>(_clock::now() - check_point).count();
}
float elapsed_seconds() const
{
return (float)elapsed() / 1000000.0f;
}
std::string elapsed_seconds_for_step(const std::string &step) const
{
return std::string("Time for ") + step + std::string(": ") + std::to_string(elapsed_seconds()) +
std::string(" seconds");
}
};
} // namespace diskann
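
A hedged usage sketch:

#include <iostream>
inline void example_timing()
{
    diskann::Timer timer;
    // ... work to be measured ...
    std::cout << timer.elapsed_seconds_for_step("index build") << std::endl;
    // prints e.g.: Time for index build: 1.234567 seconds
}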

View File

@@ -0,0 +1,2 @@
DisableFormat: true
SortIncludes: false

View File

@@ -0,0 +1,330 @@
/**
* MIT License
*
* Copyright (c) 2017 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_ROBIN_GROWTH_POLICY_H
#define TSL_ROBIN_GROWTH_POLICY_H
#include <algorithm>
#include <array>
#include <climits>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <limits>
#include <ratio>
#include <stdexcept>
#ifndef tsl_assert
# ifdef TSL_DEBUG
# define tsl_assert(expr) assert(expr)
# else
# define tsl_assert(expr) (static_cast<void>(0))
# endif
#endif
/**
* If exceptions are enabled, throw the exception passed in parameter, otherwise call std::terminate.
*/
#ifndef TSL_THROW_OR_TERMINATE
# if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || (defined (_MSC_VER) && defined (_CPPUNWIND))) && !defined(TSL_NO_EXCEPTIONS)
# define TSL_THROW_OR_TERMINATE(ex, msg) throw ex(msg)
# else
# ifdef NDEBUG
# define TSL_THROW_OR_TERMINATE(ex, msg) std::terminate()
# else
# include <cstdio>
# define TSL_THROW_OR_TERMINATE(ex, msg) do { std::fprintf(stderr, msg); std::terminate(); } while(0)
# endif
# endif
#endif
#ifndef TSL_LIKELY
# if defined(__GNUC__) || defined(__clang__)
# define TSL_LIKELY(exp) (__builtin_expect(!!(exp), true))
# else
# define TSL_LIKELY(exp) (exp)
# endif
#endif
namespace tsl {
namespace rh {
/**
* Grow the hash table by a factor of GrowthFactor keeping the bucket count to a power of two. It allows
* the table to use a mask operation instead of a modulo operation to map a hash to a bucket.
*
* GrowthFactor must be a power of two >= 2.
*/
template<std::size_t GrowthFactor>
class power_of_two_growth_policy {
public:
/**
* Called on the hash table creation and on rehash. The number of buckets for the table is passed in parameter.
     * This number is a minimum; the policy may update it to a higher value if needed (but not lower).
*
* If 0 is given, min_bucket_count_in_out must still be 0 after the policy creation and
* bucket_for_hash must always return 0 in this case.
*/
explicit power_of_two_growth_policy(std::size_t& min_bucket_count_in_out) {
if(min_bucket_count_in_out > max_bucket_count()) {
            TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maximum size.");
}
if(min_bucket_count_in_out > 0) {
min_bucket_count_in_out = round_up_to_power_of_two(min_bucket_count_in_out);
m_mask = min_bucket_count_in_out - 1;
}
else {
m_mask = 0;
}
}
/**
* Return the bucket [0, bucket_count()) to which the hash belongs.
* If bucket_count() is 0, it must always return 0.
*/
std::size_t bucket_for_hash(std::size_t hash) const noexcept {
return hash & m_mask;
}
/**
* Return the number of buckets that should be used on next growth.
*/
std::size_t next_bucket_count() const {
if((m_mask + 1) > max_bucket_count() / GrowthFactor) {
            TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maximum size.");
}
return (m_mask + 1) * GrowthFactor;
}
/**
* Return the maximum number of buckets supported by the policy.
*/
std::size_t max_bucket_count() const {
// Largest power of two.
return ((std::numeric_limits<std::size_t>::max)() / 2) + 1;
}
/**
* Reset the growth policy as if it was created with a bucket count of 0.
* After a clear, the policy must always return 0 when bucket_for_hash is called.
*/
void clear() noexcept {
m_mask = 0;
}
private:
static std::size_t round_up_to_power_of_two(std::size_t value) {
if(is_power_of_two(value)) {
return value;
}
if(value == 0) {
return 1;
}
--value;
for(std::size_t i = 1; i < sizeof(std::size_t) * CHAR_BIT; i *= 2) {
value |= value >> i;
}
return value + 1;
}
static constexpr bool is_power_of_two(std::size_t value) {
return value != 0 && (value & (value - 1)) == 0;
}
protected:
static_assert(is_power_of_two(GrowthFactor) && GrowthFactor >= 2, "GrowthFactor must be a power of two >= 2.");
std::size_t m_mask;
};
/**
* Grow the hash table by GrowthFactor::num / GrowthFactor::den and use a modulo to map a hash
 * to a bucket. Slower, but useful if you want slower growth.
*/
template<class GrowthFactor = std::ratio<3, 2>>
class mod_growth_policy {
public:
explicit mod_growth_policy(std::size_t& min_bucket_count_in_out) {
if(min_bucket_count_in_out > max_bucket_count()) {
            TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maximum size.");
}
if(min_bucket_count_in_out > 0) {
m_mod = min_bucket_count_in_out;
}
else {
m_mod = 1;
}
}
std::size_t bucket_for_hash(std::size_t hash) const noexcept {
return hash % m_mod;
}
std::size_t next_bucket_count() const {
if(m_mod == max_bucket_count()) {
            TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maximum size.");
}
const double next_bucket_count = std::ceil(double(m_mod) * REHASH_SIZE_MULTIPLICATION_FACTOR);
if(!std::isnormal(next_bucket_count)) {
            TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maximum size.");
}
if(next_bucket_count > double(max_bucket_count())) {
return max_bucket_count();
}
else {
return std::size_t(next_bucket_count);
}
}
std::size_t max_bucket_count() const {
return MAX_BUCKET_COUNT;
}
void clear() noexcept {
m_mod = 1;
}
private:
static constexpr double REHASH_SIZE_MULTIPLICATION_FACTOR = 1.0 * GrowthFactor::num / GrowthFactor::den;
static const std::size_t MAX_BUCKET_COUNT =
std::size_t(double(
(std::numeric_limits<std::size_t>::max)() / REHASH_SIZE_MULTIPLICATION_FACTOR
));
static_assert(REHASH_SIZE_MULTIPLICATION_FACTOR >= 1.1, "Growth factor should be >= 1.1.");
std::size_t m_mod;
};
namespace detail {
static constexpr const std::array<std::size_t, 40> PRIMES = {{
1ul, 5ul, 17ul, 29ul, 37ul, 53ul, 67ul, 79ul, 97ul, 131ul, 193ul, 257ul, 389ul, 521ul, 769ul, 1031ul,
1543ul, 2053ul, 3079ul, 6151ul, 12289ul, 24593ul, 49157ul, 98317ul, 196613ul, 393241ul, 786433ul,
1572869ul, 3145739ul, 6291469ul, 12582917ul, 25165843ul, 50331653ul, 100663319ul, 201326611ul,
402653189ul, 805306457ul, 1610612741ul, 3221225473ul, 4294967291ul
}};
template<unsigned int IPrime>
static constexpr std::size_t mod(std::size_t hash) { return hash % PRIMES[IPrime]; }
// MOD_PRIME[iprime](hash) returns hash % PRIMES[iprime]. This table allows for faster modulo as the
// compiler can optimize the modulo code better with a constant known at the compilation.
static constexpr const std::array<std::size_t(*)(std::size_t), 40> MOD_PRIME = {{
&mod<0>, &mod<1>, &mod<2>, &mod<3>, &mod<4>, &mod<5>, &mod<6>, &mod<7>, &mod<8>, &mod<9>, &mod<10>,
&mod<11>, &mod<12>, &mod<13>, &mod<14>, &mod<15>, &mod<16>, &mod<17>, &mod<18>, &mod<19>, &mod<20>,
&mod<21>, &mod<22>, &mod<23>, &mod<24>, &mod<25>, &mod<26>, &mod<27>, &mod<28>, &mod<29>, &mod<30>,
&mod<31>, &mod<32>, &mod<33>, &mod<34>, &mod<35>, &mod<36>, &mod<37> , &mod<38>, &mod<39>
}};
}
/**
 * Grow the hash table by using prime numbers as bucket count. Slower than tsl::rh::power_of_two_growth_policy in
 * general, but it will probably distribute the values better across the buckets when the hash function is poor.
*
 * To allow the compiler to optimize the modulo operation, a lookup table is used with constant prime numbers.
*
* With a switch the code would look like:
* \code
* switch(iprime) { // iprime is the current prime of the hash table
* case 0: hash % 5ul;
* break;
* case 1: hash % 17ul;
* break;
* case 2: hash % 29ul;
* break;
* ...
* }
* \endcode
*
* Due to the constant variable in the modulo the compiler is able to optimize the operation
 * by a series of multiplications, subtractions and shifts.
*
 * The 'hash % 5' could become something like 'hash - ((hash * 0xCCCCCCCD) >> 34) * 5' in a 64-bit environment.
*/
class prime_growth_policy {
public:
explicit prime_growth_policy(std::size_t& min_bucket_count_in_out) {
auto it_prime = std::lower_bound(detail::PRIMES.begin(),
detail::PRIMES.end(), min_bucket_count_in_out);
if(it_prime == detail::PRIMES.end()) {
            TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maximum size.");
}
m_iprime = static_cast<unsigned int>(std::distance(detail::PRIMES.begin(), it_prime));
if(min_bucket_count_in_out > 0) {
min_bucket_count_in_out = *it_prime;
}
else {
min_bucket_count_in_out = 0;
}
}
std::size_t bucket_for_hash(std::size_t hash) const noexcept {
return detail::MOD_PRIME[m_iprime](hash);
}
std::size_t next_bucket_count() const {
if(m_iprime + 1 >= detail::PRIMES.size()) {
            TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maximum size.");
}
return detail::PRIMES[m_iprime + 1];
}
std::size_t max_bucket_count() const {
return detail::PRIMES.back();
}
void clear() noexcept {
m_iprime = 0;
}
private:
unsigned int m_iprime;
static_assert((std::numeric_limits<decltype(m_iprime)>::max)() >= detail::PRIMES.size(),
"The type of m_iprime is not big enough.");
};
}
}
#endif
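
A hedged illustration of the power-of-two policy's contract: the requested bucket count is rounded
up, the mask reproduces a modulo, and growth multiplies by GrowthFactor:

#include <cassert>
inline void example_growth_policy()
{
    std::size_t bucket_count = 100;
    tsl::rh::power_of_two_growth_policy<2> policy(bucket_count);
    assert(bucket_count == 128);                        // rounded up to a power of two
    assert(policy.bucket_for_hash(1000) == 1000 % 128); // mask == modulo for powers of two
    assert(policy.next_bucket_count() == 256);          // grows by GrowthFactor
}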

View File

File diff suppressed because it is too large

View File

@@ -0,0 +1,668 @@
/**
* MIT License
*
* Copyright (c) 2017 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_ROBIN_MAP_H
#define TSL_ROBIN_MAP_H
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <memory>
#include <type_traits>
#include <utility>
#include "robin_hash.h"
namespace tsl {
/**
 * Implementation of a hash map using open-addressing and the robin hood hashing algorithm with backward shift deletion.
*
* For operations modifying the hash map (insert, erase, rehash, ...), the strong exception guarantee
* is only guaranteed when the expression `std::is_nothrow_swappable<std::pair<Key, T>>::value &&
* std::is_nothrow_move_constructible<std::pair<Key, T>>::value` is true, otherwise if an exception
 * is thrown during the swap or the move, the hash map may end up in an undefined state. Per the standard
* a `Key` or `T` with a noexcept copy constructor and no move constructor also satisfies the
* `std::is_nothrow_move_constructible<std::pair<Key, T>>::value` criterion (and will thus guarantee the
* strong exception for the map).
*
* When `StoreHash` is true, 32 bits of the hash are stored alongside the values. It can improve
 * the performance during lookups if the `KeyEqual` function takes time (if it causes a cache miss, for example)
* as we then compare the stored hashes before comparing the keys. When `tsl::rh::power_of_two_growth_policy` is used
 * as `GrowthPolicy`, it may also speed up the rehash process as we can avoid recalculating the hash.
 * When it is detected that storing the hash will not incur any memory penalty due to alignment (i.e.
* `sizeof(tsl::detail_robin_hash::bucket_entry<ValueType, true>) ==
* sizeof(tsl::detail_robin_hash::bucket_entry<ValueType, false>)`) and `tsl::rh::power_of_two_growth_policy` is
 * used, the hash will be stored even if `StoreHash` is false so that we can speed up the rehash (but it will
* not be used on lookups unless `StoreHash` is true).
*
* `GrowthPolicy` defines how the map grows and consequently how a hash value is mapped to a bucket.
* By default the map uses `tsl::rh::power_of_two_growth_policy`. This policy keeps the number of buckets
* to a power of two and uses a mask to map the hash to a bucket instead of the slow modulo.
* Other growth policies are available and you may define your own growth policy,
* check `tsl::rh::power_of_two_growth_policy` for the interface.
*
* If the destructor of `Key` or `T` throws an exception, the behaviour of the class is undefined.
*
* Iterators invalidation:
* - clear, operator=, reserve, rehash: always invalidate the iterators.
* - insert, emplace, emplace_hint, operator[]: if there is an effective insert, invalidate the iterators.
* - erase: always invalidate the iterators.
*/
template<class Key,
class T,
class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Allocator = std::allocator<std::pair<Key, T>>,
bool StoreHash = false,
class GrowthPolicy = tsl::rh::power_of_two_growth_policy<2>>
class robin_map {
private:
template<typename U>
using has_is_transparent = tsl::detail_robin_hash::has_is_transparent<U>;
class KeySelect {
public:
using key_type = Key;
const key_type& operator()(const std::pair<Key, T>& key_value) const noexcept {
return key_value.first;
}
key_type& operator()(std::pair<Key, T>& key_value) noexcept {
return key_value.first;
}
};
class ValueSelect {
public:
using value_type = T;
const value_type& operator()(const std::pair<Key, T>& key_value) const noexcept {
return key_value.second;
}
value_type& operator()(std::pair<Key, T>& key_value) noexcept {
return key_value.second;
}
};
using ht = detail_robin_hash::robin_hash<std::pair<Key, T>, KeySelect, ValueSelect,
Hash, KeyEqual, Allocator, StoreHash, GrowthPolicy>;
public:
using key_type = typename ht::key_type;
using mapped_type = T;
using value_type = typename ht::value_type;
using size_type = typename ht::size_type;
using difference_type = typename ht::difference_type;
using hasher = typename ht::hasher;
using key_equal = typename ht::key_equal;
using allocator_type = typename ht::allocator_type;
using reference = typename ht::reference;
using const_reference = typename ht::const_reference;
using pointer = typename ht::pointer;
using const_pointer = typename ht::const_pointer;
using iterator = typename ht::iterator;
using const_iterator = typename ht::const_iterator;
public:
/*
* Constructors
*/
robin_map(): robin_map(ht::DEFAULT_INIT_BUCKETS_SIZE) {
}
explicit robin_map(size_type bucket_count,
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator()):
m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR)
{
}
robin_map(size_type bucket_count,
const Allocator& alloc): robin_map(bucket_count, Hash(), KeyEqual(), alloc)
{
}
robin_map(size_type bucket_count,
const Hash& hash,
const Allocator& alloc): robin_map(bucket_count, hash, KeyEqual(), alloc)
{
}
explicit robin_map(const Allocator& alloc): robin_map(ht::DEFAULT_INIT_BUCKETS_SIZE, alloc) {
}
template<class InputIt>
robin_map(InputIt first, InputIt last,
size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator()): robin_map(bucket_count, hash, equal, alloc)
{
insert(first, last);
}
template<class InputIt>
robin_map(InputIt first, InputIt last,
size_type bucket_count,
const Allocator& alloc): robin_map(first, last, bucket_count, Hash(), KeyEqual(), alloc)
{
}
template<class InputIt>
robin_map(InputIt first, InputIt last,
size_type bucket_count,
const Hash& hash,
const Allocator& alloc): robin_map(first, last, bucket_count, hash, KeyEqual(), alloc)
{
}
robin_map(std::initializer_list<value_type> init,
size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator()):
robin_map(init.begin(), init.end(), bucket_count, hash, equal, alloc)
{
}
robin_map(std::initializer_list<value_type> init,
size_type bucket_count,
const Allocator& alloc):
robin_map(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(), alloc)
{
}
robin_map(std::initializer_list<value_type> init,
size_type bucket_count,
const Hash& hash,
const Allocator& alloc):
robin_map(init.begin(), init.end(), bucket_count, hash, KeyEqual(), alloc)
{
}
robin_map& operator=(std::initializer_list<value_type> ilist) {
m_ht.clear();
m_ht.reserve(ilist.size());
m_ht.insert(ilist.begin(), ilist.end());
return *this;
}
allocator_type get_allocator() const { return m_ht.get_allocator(); }
/*
* Iterators
*/
iterator begin() noexcept { return m_ht.begin(); }
const_iterator begin() const noexcept { return m_ht.begin(); }
const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
iterator end() noexcept { return m_ht.end(); }
const_iterator end() const noexcept { return m_ht.end(); }
const_iterator cend() const noexcept { return m_ht.cend(); }
/*
* Capacity
*/
bool empty() const noexcept { return m_ht.empty(); }
size_type size() const noexcept { return m_ht.size(); }
size_type max_size() const noexcept { return m_ht.max_size(); }
/*
* Modifiers
*/
void clear() noexcept { m_ht.clear(); }
std::pair<iterator, bool> insert(const value_type& value) {
return m_ht.insert(value);
}
template<class P, typename std::enable_if<std::is_constructible<value_type, P&&>::value>::type* = nullptr>
std::pair<iterator, bool> insert(P&& value) {
return m_ht.emplace(std::forward<P>(value));
}
std::pair<iterator, bool> insert(value_type&& value) {
return m_ht.insert(std::move(value));
}
iterator insert(const_iterator hint, const value_type& value) {
return m_ht.insert(hint, value);
}
template<class P, typename std::enable_if<std::is_constructible<value_type, P&&>::value>::type* = nullptr>
iterator insert(const_iterator hint, P&& value) {
return m_ht.emplace_hint(hint, std::forward<P>(value));
}
iterator insert(const_iterator hint, value_type&& value) {
return m_ht.insert(hint, std::move(value));
}
template<class InputIt>
void insert(InputIt first, InputIt last) {
m_ht.insert(first, last);
}
void insert(std::initializer_list<value_type> ilist) {
m_ht.insert(ilist.begin(), ilist.end());
}
template<class M>
std::pair<iterator, bool> insert_or_assign(const key_type& k, M&& obj) {
return m_ht.insert_or_assign(k, std::forward<M>(obj));
}
template<class M>
std::pair<iterator, bool> insert_or_assign(key_type&& k, M&& obj) {
return m_ht.insert_or_assign(std::move(k), std::forward<M>(obj));
}
template<class M>
iterator insert_or_assign(const_iterator hint, const key_type& k, M&& obj) {
return m_ht.insert_or_assign(hint, k, std::forward<M>(obj));
}
template<class M>
iterator insert_or_assign(const_iterator hint, key_type&& k, M&& obj) {
return m_ht.insert_or_assign(hint, std::move(k), std::forward<M>(obj));
}
/**
* Due to the way elements are stored, emplace will need to move or copy the key-value once.
* The method is equivalent to insert(value_type(std::forward<Args>(args)...));
*
* Mainly here for compatibility with the std::unordered_map interface.
*/
template<class... Args>
std::pair<iterator, bool> emplace(Args&&... args) {
return m_ht.emplace(std::forward<Args>(args)...);
}
/**
* Due to the way elements are stored, emplace_hint will need to move or copy the key-value once.
* The method is equivalent to insert(hint, value_type(std::forward<Args>(args)...));
*
* Mainly here for compatibility with the std::unordered_map interface.
*/
template<class... Args>
iterator emplace_hint(const_iterator hint, Args&&... args) {
return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
}
template<class... Args>
std::pair<iterator, bool> try_emplace(const key_type& k, Args&&... args) {
return m_ht.try_emplace(k, std::forward<Args>(args)...);
}
template<class... Args>
std::pair<iterator, bool> try_emplace(key_type&& k, Args&&... args) {
return m_ht.try_emplace(std::move(k), std::forward<Args>(args)...);
}
template<class... Args>
iterator try_emplace(const_iterator hint, const key_type& k, Args&&... args) {
return m_ht.try_emplace(hint, k, std::forward<Args>(args)...);
}
template<class... Args>
iterator try_emplace(const_iterator hint, key_type&& k, Args&&... args) {
return m_ht.try_emplace(hint, std::move(k), std::forward<Args>(args)...);
}
iterator erase(iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator first, const_iterator last) { return m_ht.erase(first, last); }
size_type erase(const key_type& key) { return m_ht.erase(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed up the lookup of the value if you already have the hash.
*/
size_type erase(const key_type& key, std::size_t precalculated_hash) {
return m_ht.erase(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type erase(const K& key) { return m_ht.erase(key); }
/**
* @copydoc erase(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed up the lookup of the value if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type erase(const K& key, std::size_t precalculated_hash) {
return m_ht.erase(key, precalculated_hash);
}
void swap(robin_map& other) { other.m_ht.swap(m_ht); }
/*
* Lookup
*/
T& at(const Key& key) { return m_ht.at(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
T& at(const Key& key, std::size_t precalculated_hash) { return m_ht.at(key, precalculated_hash); }
const T& at(const Key& key) const { return m_ht.at(key); }
/**
* @copydoc at(const Key& key, std::size_t precalculated_hash)
*/
const T& at(const Key& key, std::size_t precalculated_hash) const { return m_ht.at(key, precalculated_hash); }
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
T& at(const K& key) { return m_ht.at(key); }
/**
* @copydoc at(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
T& at(const K& key, std::size_t precalculated_hash) { return m_ht.at(key, precalculated_hash); }
/**
* @copydoc at(const K& key)
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
const T& at(const K& key) const { return m_ht.at(key); }
/**
* @copydoc at(const K& key, std::size_t precalculated_hash)
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
const T& at(const K& key, std::size_t precalculated_hash) const { return m_ht.at(key, precalculated_hash); }
T& operator[](const Key& key) { return m_ht[key]; }
T& operator[](Key&& key) { return m_ht[std::move(key)]; }
size_type count(const Key& key) const { return m_ht.count(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
size_type count(const Key& key, std::size_t precalculated_hash) const {
return m_ht.count(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type count(const K& key) const { return m_ht.count(key); }
/**
* @copydoc count(const K& key) const
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type count(const K& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); }
iterator find(const Key& key) { return m_ht.find(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
iterator find(const Key& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
const_iterator find(const Key& key) const { return m_ht.find(key); }
/**
* @copydoc find(const Key& key, std::size_t precalculated_hash)
*/
const_iterator find(const Key& key, std::size_t precalculated_hash) const {
return m_ht.find(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
iterator find(const K& key) { return m_ht.find(key); }
/**
* @copydoc find(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
iterator find(const K& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
/**
* @copydoc find(const K& key)
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
const_iterator find(const K& key) const { return m_ht.find(key); }
/**
* @copydoc find(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
const_iterator find(const K& key, std::size_t precalculated_hash) const {
return m_ht.find(key, precalculated_hash);
}
std::pair<iterator, iterator> equal_range(const Key& key) { return m_ht.equal_range(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
std::pair<iterator, iterator> equal_range(const Key& key, std::size_t precalculated_hash) {
return m_ht.equal_range(key, precalculated_hash);
}
std::pair<const_iterator, const_iterator> equal_range(const Key& key) const { return m_ht.equal_range(key); }
/**
* @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
*/
std::pair<const_iterator, const_iterator> equal_range(const Key& key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<iterator, iterator> equal_range(const K& key) { return m_ht.equal_range(key); }
/**
* @copydoc equal_range(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<iterator, iterator> equal_range(const K& key, std::size_t precalculated_hash) {
return m_ht.equal_range(key, precalculated_hash);
}
/**
* @copydoc equal_range(const K& key)
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<const_iterator, const_iterator> equal_range(const K& key) const { return m_ht.equal_range(key); }
/**
* @copydoc equal_range(const K& key, std::size_t precalculated_hash)
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<const_iterator, const_iterator> equal_range(const K& key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, precalculated_hash);
}
/*
* Bucket interface
*/
size_type bucket_count() const { return m_ht.bucket_count(); }
size_type max_bucket_count() const { return m_ht.max_bucket_count(); }
/*
* Hash policy
*/
float load_factor() const { return m_ht.load_factor(); }
float max_load_factor() const { return m_ht.max_load_factor(); }
void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
void rehash(size_type count) { m_ht.rehash(count); }
void reserve(size_type count) { m_ht.reserve(count); }
/*
* Observers
*/
hasher hash_function() const { return m_ht.hash_function(); }
key_equal key_eq() const { return m_ht.key_eq(); }
/*
* Other
*/
/**
* Convert a const_iterator to an iterator.
*/
iterator mutable_iterator(const_iterator pos) {
return m_ht.mutable_iterator(pos);
}
friend bool operator==(const robin_map& lhs, const robin_map& rhs) {
if(lhs.size() != rhs.size()) {
return false;
}
for(const auto& element_lhs: lhs) {
const auto it_element_rhs = rhs.find(element_lhs.first);
if(it_element_rhs == rhs.cend() || element_lhs.second != it_element_rhs->second) {
return false;
}
}
return true;
}
friend bool operator!=(const robin_map& lhs, const robin_map& rhs) {
return !operator==(lhs, rhs);
}
friend void swap(robin_map& lhs, robin_map& rhs) {
lhs.swap(rhs);
}
private:
ht m_ht;
};
/**
* Same as `tsl::robin_map<Key, T, Hash, KeyEqual, Allocator, StoreHash, tsl::rh::prime_growth_policy>`.
*/
template<class Key,
class T,
class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Allocator = std::allocator<std::pair<Key, T>>,
bool StoreHash = false>
using robin_pg_map = robin_map<Key, T, Hash, KeyEqual, Allocator, StoreHash, tsl::rh::prime_growth_policy>;
} // end namespace tsl
#endif
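
A hedged usage sketch, including the precalculated-hash overloads documented above:

#include <string>
inline void example_robin_map()
{
    tsl::robin_map<std::string, int> map;
    map["alpha"] = 1;
    map.insert({"beta", 2});
    const std::size_t h = map.hash_function()(std::string("alpha"));
    auto it = map.find("alpha", h); // reuses the hash instead of recomputing it
    if (it != map.end())
    {
        int value = it->second;
        (void)value;
    }
}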

View File

@@ -0,0 +1,535 @@
/**
* MIT License
*
* Copyright (c) 2017 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_ROBIN_SET_H
#define TSL_ROBIN_SET_H
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <memory>
#include <type_traits>
#include <utility>
#include "robin_hash.h"
namespace tsl {
/**
 * Implementation of a hash set using open-addressing and the robin hood hashing algorithm with backward shift deletion.
*
* For operations modifying the hash set (insert, erase, rehash, ...), the strong exception guarantee
* is only guaranteed when the expression `std::is_nothrow_swappable<Key>::value &&
* std::is_nothrow_move_constructible<Key>::value` is true, otherwise if an exception
 * is thrown during the swap or the move, the hash set may end up in an undefined state. Per the standard
* a `Key` with a noexcept copy constructor and no move constructor also satisfies the
* `std::is_nothrow_move_constructible<Key>::value` criterion (and will thus guarantee the
* strong exception for the set).
*
* When `StoreHash` is true, 32 bits of the hash are stored alongside the values. It can improve
 * the performance during lookups if the `KeyEqual` function takes time (or causes a cache miss, for example)
* as we then compare the stored hashes before comparing the keys. When `tsl::rh::power_of_two_growth_policy` is used
 * as `GrowthPolicy`, it may also speed up the rehash process as we can avoid recalculating the hash.
 * When it is detected that storing the hash will not incur any memory penalty due to alignment (i.e.
* `sizeof(tsl::detail_robin_hash::bucket_entry<ValueType, true>) ==
* sizeof(tsl::detail_robin_hash::bucket_entry<ValueType, false>)`) and `tsl::rh::power_of_two_growth_policy` is
 * used, the hash will be stored even if `StoreHash` is false so that we can speed up the rehash (but it will
* not be used on lookups unless `StoreHash` is true).
*
* `GrowthPolicy` defines how the set grows and consequently how a hash value is mapped to a bucket.
* By default the set uses `tsl::rh::power_of_two_growth_policy`. This policy keeps the number of buckets
 * to a power of two and uses a mask to map the hash to a bucket instead of the slow modulo.
* Other growth policies are available and you may define your own growth policy,
* check `tsl::rh::power_of_two_growth_policy` for the interface.
*
* If the destructor of `Key` throws an exception, the behaviour of the class is undefined.
*
* Iterators invalidation:
* - clear, operator=, reserve, rehash: always invalidate the iterators.
* - insert, emplace, emplace_hint, operator[]: if there is an effective insert, invalidate the iterators.
* - erase: always invalidate the iterators.
*/
template<class Key,
class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Allocator = std::allocator<Key>,
bool StoreHash = false,
class GrowthPolicy = tsl::rh::power_of_two_growth_policy<2>>
class robin_set {
private:
template<typename U>
using has_is_transparent = tsl::detail_robin_hash::has_is_transparent<U>;
class KeySelect {
public:
using key_type = Key;
const key_type& operator()(const Key& key) const noexcept {
return key;
}
key_type& operator()(Key& key) noexcept {
return key;
}
};
using ht = detail_robin_hash::robin_hash<Key, KeySelect, void,
Hash, KeyEqual, Allocator, StoreHash, GrowthPolicy>;
public:
using key_type = typename ht::key_type;
using value_type = typename ht::value_type;
using size_type = typename ht::size_type;
using difference_type = typename ht::difference_type;
using hasher = typename ht::hasher;
using key_equal = typename ht::key_equal;
using allocator_type = typename ht::allocator_type;
using reference = typename ht::reference;
using const_reference = typename ht::const_reference;
using pointer = typename ht::pointer;
using const_pointer = typename ht::const_pointer;
using iterator = typename ht::iterator;
using const_iterator = typename ht::const_iterator;
/*
* Constructors
*/
robin_set(): robin_set(ht::DEFAULT_INIT_BUCKETS_SIZE) {
}
explicit robin_set(size_type bucket_count,
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator()):
m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR)
{
}
robin_set(size_type bucket_count,
const Allocator& alloc): robin_set(bucket_count, Hash(), KeyEqual(), alloc)
{
}
robin_set(size_type bucket_count,
const Hash& hash,
const Allocator& alloc): robin_set(bucket_count, hash, KeyEqual(), alloc)
{
}
explicit robin_set(const Allocator& alloc): robin_set(ht::DEFAULT_INIT_BUCKETS_SIZE, alloc) {
}
template<class InputIt>
robin_set(InputIt first, InputIt last,
size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator()): robin_set(bucket_count, hash, equal, alloc)
{
insert(first, last);
}
template<class InputIt>
robin_set(InputIt first, InputIt last,
size_type bucket_count,
const Allocator& alloc): robin_set(first, last, bucket_count, Hash(), KeyEqual(), alloc)
{
}
template<class InputIt>
robin_set(InputIt first, InputIt last,
size_type bucket_count,
const Hash& hash,
const Allocator& alloc): robin_set(first, last, bucket_count, hash, KeyEqual(), alloc)
{
}
robin_set(std::initializer_list<value_type> init,
size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator()):
robin_set(init.begin(), init.end(), bucket_count, hash, equal, alloc)
{
}
robin_set(std::initializer_list<value_type> init,
size_type bucket_count,
const Allocator& alloc):
robin_set(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(), alloc)
{
}
robin_set(std::initializer_list<value_type> init,
size_type bucket_count,
const Hash& hash,
const Allocator& alloc):
robin_set(init.begin(), init.end(), bucket_count, hash, KeyEqual(), alloc)
{
}
robin_set& operator=(std::initializer_list<value_type> ilist) {
m_ht.clear();
m_ht.reserve(ilist.size());
m_ht.insert(ilist.begin(), ilist.end());
return *this;
}
allocator_type get_allocator() const { return m_ht.get_allocator(); }
/*
* Iterators
*/
iterator begin() noexcept { return m_ht.begin(); }
const_iterator begin() const noexcept { return m_ht.begin(); }
const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
iterator end() noexcept { return m_ht.end(); }
const_iterator end() const noexcept { return m_ht.end(); }
const_iterator cend() const noexcept { return m_ht.cend(); }
/*
* Capacity
*/
bool empty() const noexcept { return m_ht.empty(); }
size_type size() const noexcept { return m_ht.size(); }
size_type max_size() const noexcept { return m_ht.max_size(); }
/*
* Modifiers
*/
void clear() noexcept { m_ht.clear(); }
std::pair<iterator, bool> insert(const value_type& value) {
return m_ht.insert(value);
}
std::pair<iterator, bool> insert(value_type&& value) {
return m_ht.insert(std::move(value));
}
iterator insert(const_iterator hint, const value_type& value) {
return m_ht.insert(hint, value);
}
iterator insert(const_iterator hint, value_type&& value) {
return m_ht.insert(hint, std::move(value));
}
template<class InputIt>
void insert(InputIt first, InputIt last) {
m_ht.insert(first, last);
}
void insert(std::initializer_list<value_type> ilist) {
m_ht.insert(ilist.begin(), ilist.end());
}
/**
* Due to the way elements are stored, emplace will need to move or copy the key once.
* The method is equivalent to insert(value_type(std::forward<Args>(args)...));
*
* Mainly here for compatibility with the std::unordered_set interface.
*/
template<class... Args>
std::pair<iterator, bool> emplace(Args&&... args) {
return m_ht.emplace(std::forward<Args>(args)...);
}
/**
* Due to the way elements are stored, emplace_hint will need to move or copy the key once.
* The method is equivalent to insert(hint, value_type(std::forward<Args>(args)...));
*
* Mainly here for compatibility with the std::unordered_set interface.
*/
template<class... Args>
iterator emplace_hint(const_iterator hint, Args&&... args) {
return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
}
iterator erase(iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator first, const_iterator last) { return m_ht.erase(first, last); }
size_type erase(const key_type& key) { return m_ht.erase(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
*/
size_type erase(const key_type& key, std::size_t precalculated_hash) {
return m_ht.erase(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type erase(const K& key) { return m_ht.erase(key); }
/**
* @copydoc erase(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type erase(const K& key, std::size_t precalculated_hash) {
return m_ht.erase(key, precalculated_hash);
}
void swap(robin_set& other) { other.m_ht.swap(m_ht); }
/*
* Lookup
*/
size_type count(const Key& key) const { return m_ht.count(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
*/
size_type count(const Key& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); }
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type count(const K& key) const { return m_ht.count(key); }
/**
* @copydoc count(const K& key) const
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type count(const K& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); }
iterator find(const Key& key) { return m_ht.find(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
*/
iterator find(const Key& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
const_iterator find(const Key& key) const { return m_ht.find(key); }
/**
* @copydoc find(const Key& key, std::size_t precalculated_hash)
*/
const_iterator find(const Key& key, std::size_t precalculated_hash) const { return m_ht.find(key, precalculated_hash); }
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
iterator find(const K& key) { return m_ht.find(key); }
/**
* @copydoc find(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
iterator find(const K& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
/**
* @copydoc find(const K& key)
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
const_iterator find(const K& key) const { return m_ht.find(key); }
/**
* @copydoc find(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
const_iterator find(const K& key, std::size_t precalculated_hash) const { return m_ht.find(key, precalculated_hash); }
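/*
 * Illustrative sketch, not part of the library: the transparent overloads
 * above only take part in overload resolution when KeyEqual defines the
 * typedef is_transparent. With the hypothetical C++17 functors below, a
 * robin_set<std::string> can be probed with a const char* without building
 * a temporary std::string for the lookup:
 *
 *   #include <functional>
 *   #include <string>
 *   #include <string_view>
 *
 *   struct string_hash {  // hashes both key forms consistently via string_view
 *       std::size_t operator()(std::string_view s) const { return std::hash<std::string_view>{}(s); }
 *   };
 *
 *   struct string_equal {
 *       using is_transparent = void;  // enables the template find/count/erase overloads
 *       bool operator()(std::string_view lhs, std::string_view rhs) const { return lhs == rhs; }
 *   };
 *
 *   // tsl::robin_set<std::string, string_hash, string_equal> set;
 *   // set.insert("value");
 *   // auto it = set.find("value");  // const char* probe, no std::string temporary
 */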
std::pair<iterator, iterator> equal_range(const Key& key) { return m_ht.equal_range(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
*/
std::pair<iterator, iterator> equal_range(const Key& key, std::size_t precalculated_hash) {
return m_ht.equal_range(key, precalculated_hash);
}
std::pair<const_iterator, const_iterator> equal_range(const Key& key) const { return m_ht.equal_range(key); }
/**
* @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
*/
std::pair<const_iterator, const_iterator> equal_range(const Key& key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<iterator, iterator> equal_range(const K& key) { return m_ht.equal_range(key); }
/**
* @copydoc equal_range(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<iterator, iterator> equal_range(const K& key, std::size_t precalculated_hash) {
return m_ht.equal_range(key, precalculated_hash);
}
/**
* @copydoc equal_range(const K& key)
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<const_iterator, const_iterator> equal_range(const K& key) const { return m_ht.equal_range(key); }
/**
* @copydoc equal_range(const K& key, std::size_t precalculated_hash)
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<const_iterator, const_iterator> equal_range(const K& key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, precalculated_hash);
}
/*
* Bucket interface
*/
size_type bucket_count() const { return m_ht.bucket_count(); }
size_type max_bucket_count() const { return m_ht.max_bucket_count(); }
/*
* Hash policy
*/
float load_factor() const { return m_ht.load_factor(); }
float max_load_factor() const { return m_ht.max_load_factor(); }
void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
void rehash(size_type count) { m_ht.rehash(count); }
void reserve(size_type count) { m_ht.reserve(count); }
/*
* Observers
*/
hasher hash_function() const { return m_ht.hash_function(); }
key_equal key_eq() const { return m_ht.key_eq(); }
/*
* Other
*/
/**
* Convert a const_iterator to an iterator.
*/
iterator mutable_iterator(const_iterator pos) {
return m_ht.mutable_iterator(pos);
}
friend bool operator==(const robin_set& lhs, const robin_set& rhs) {
if(lhs.size() != rhs.size()) {
return false;
}
for(const auto& element_lhs: lhs) {
const auto it_element_rhs = rhs.find(element_lhs);
if(it_element_rhs == rhs.cend()) {
return false;
}
}
return true;
}
friend bool operator!=(const robin_set& lhs, const robin_set& rhs) {
return !operator==(lhs, rhs);
}
friend void swap(robin_set& lhs, robin_set& rhs) {
lhs.swap(rhs);
}
private:
ht m_ht;
};
/**
* Same as `tsl::robin_set<Key, Hash, KeyEqual, Allocator, StoreHash, tsl::rh::prime_growth_policy>`.
*/
template<class Key,
class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Allocator = std::allocator<Key>,
bool StoreHash = false>
using robin_pg_set = robin_set<Key, Hash, KeyEqual, Allocator, StoreHash, tsl::rh::prime_growth_policy>;
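/*
 * Illustrative usage sketch, not part of the library (the include path and
 * variable names are assumptions):
 *
 *   #include <cstddef>
 *   #include <cstdint>
 *   #include <iostream>
 *   #include <tsl/robin_set.h>
 *
 *   int main() {
 *       tsl::robin_set<std::int64_t> set{1, 2, 3};
 *
 *       // Hash once and reuse it for several operations on the same key.
 *       const std::size_t hash = set.hash_function()(2);
 *       if (set.find(2, hash) != set.end()) {
 *           set.erase(2, hash);
 *       }
 *       std::cout << set.size() << '\n';  // prints 2
 *   }
 */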
} // end namespace tsl
#endif

View File

@@ -0,0 +1,301 @@
/**
* MIT License
*
* Copyright (c) 2017 Thibaut Goetghebuer-Planchon <tessil@gmx.com>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_SPARSE_GROWTH_POLICY_H
#define TSL_SPARSE_GROWTH_POLICY_H
#include <algorithm>
#include <array>
#include <climits>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <limits>
#include <ratio>
#include <stdexcept>
namespace tsl {
namespace sh {
/**
* Grow the hash table by a factor of GrowthFactor, keeping the bucket count at a
* power of two. It allows the table to use a mask operation instead of a modulo
* operation to map a hash to a bucket.
*
* GrowthFactor must be a power of two >= 2.
*/
template <std::size_t GrowthFactor>
class power_of_two_growth_policy {
public:
/**
* Called on hash table creation and on rehash. The number of buckets for
* the table is passed as a parameter. This number is a minimum; the policy
* may update this value to a higher value if needed (but not a lower one).
*
* If 0 is given, min_bucket_count_in_out must still be 0 after the policy
* creation and bucket_for_hash must always return 0 in this case.
*/
explicit power_of_two_growth_policy(std::size_t &min_bucket_count_in_out) {
if (min_bucket_count_in_out > max_bucket_count()) {
throw std::length_error("The hash table exceeds its maximum size.");
}
if (min_bucket_count_in_out > 0) {
min_bucket_count_in_out =
round_up_to_power_of_two(min_bucket_count_in_out);
m_mask = min_bucket_count_in_out - 1;
} else {
m_mask = 0;
}
}
/**
* Return the bucket [0, bucket_count()) to which the hash belongs.
* If bucket_count() is 0, it must always return 0.
*/
std::size_t bucket_for_hash(std::size_t hash) const noexcept {
return hash & m_mask;
}
/**
* Return the number of buckets that should be used on next growth.
*/
std::size_t next_bucket_count() const {
if ((m_mask + 1) > max_bucket_count() / GrowthFactor) {
throw std::length_error("The hash table exceeds its maximum size.");
}
return (m_mask + 1) * GrowthFactor;
}
/**
* Return the maximum number of buckets supported by the policy.
*/
std::size_t max_bucket_count() const {
// Largest power of two.
return (std::numeric_limits<std::size_t>::max() / 2) + 1;
}
/**
* Reset the growth policy as if it was created with a bucket count of 0.
* After a clear, the policy must always return 0 when bucket_for_hash is
* called.
*/
void clear() noexcept { m_mask = 0; }
private:
static std::size_t round_up_to_power_of_two(std::size_t value) {
if (is_power_of_two(value)) {
return value;
}
if (value == 0) {
return 1;
}
--value;
for (std::size_t i = 1; i < sizeof(std::size_t) * CHAR_BIT; i *= 2) {
value |= value >> i;
}
return value + 1;
}
static constexpr bool is_power_of_two(std::size_t value) {
return value != 0 && (value & (value - 1)) == 0;
}
protected:
static_assert(is_power_of_two(GrowthFactor) && GrowthFactor >= 2,
"GrowthFactor must be a power of two >= 2.");
std::size_t m_mask;
};
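/*
 * Illustrative sketch, not part of the library: with a power-of-two bucket
 * count, the modulo used to map a hash to a bucket reduces to a single mask
 * operation, which is exactly what bucket_for_hash computes with m_mask.
 *
 *   #include <cassert>
 *   #include <cstddef>
 *
 *   void mask_equals_modulo(std::size_t hash) {
 *       const std::size_t bucket_count = 16;        // any power of two
 *       const std::size_t mask = bucket_count - 1;  // 0b1111
 *       assert(hash % bucket_count == (hash & mask));
 *   }
 */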
/**
* Grow the hash table by GrowthFactor::num / GrowthFactor::den and use a modulo
* to map a hash to a bucket. Slower, but it can be useful if you want more
* gradual growth.
*/
template <class GrowthFactor = std::ratio<3, 2>>
class mod_growth_policy {
public:
explicit mod_growth_policy(std::size_t &min_bucket_count_in_out) {
if (min_bucket_count_in_out > max_bucket_count()) {
throw std::length_error("The hash table exceeds its maximum size.");
}
if (min_bucket_count_in_out > 0) {
m_mod = min_bucket_count_in_out;
} else {
m_mod = 1;
}
}
std::size_t bucket_for_hash(std::size_t hash) const noexcept {
return hash % m_mod;
}
std::size_t next_bucket_count() const {
if (m_mod == max_bucket_count()) {
throw std::length_error("The hash table exceeds its maximum size.");
}
const double next_bucket_count =
std::ceil(double(m_mod) * REHASH_SIZE_MULTIPLICATION_FACTOR);
if (!std::isnormal(next_bucket_count)) {
throw std::length_error("The hash table exceeds its maximum size.");
}
if (next_bucket_count > double(max_bucket_count())) {
return max_bucket_count();
} else {
return std::size_t(next_bucket_count);
}
}
std::size_t max_bucket_count() const { return MAX_BUCKET_COUNT; }
void clear() noexcept { m_mod = 1; }
private:
static constexpr double REHASH_SIZE_MULTIPLICATION_FACTOR =
1.0 * GrowthFactor::num / GrowthFactor::den;
static const std::size_t MAX_BUCKET_COUNT =
std::size_t(double(std::numeric_limits<std::size_t>::max() /
REHASH_SIZE_MULTIPLICATION_FACTOR));
static_assert(REHASH_SIZE_MULTIPLICATION_FACTOR >= 1.1,
"Growth factor should be >= 1.1.");
std::size_t m_mod;
};
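/*
 * Illustrative sketch, not part of the library: with the default GrowthFactor
 * of std::ratio<3, 2>, next_bucket_count() multiplies the bucket count by 1.5
 * and rounds up, e.g. 16 -> 24 -> 36 -> 54 -> 81 -> ...
 *
 *   #include <cmath>
 *   #include <cstddef>
 *
 *   std::size_t next_count(std::size_t current) {  // mirrors the arithmetic above
 *       return static_cast<std::size_t>(std::ceil(current * 3.0 / 2.0));
 *   }
 */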
/**
* Grow the hash table by using prime numbers as bucket count. Slower than
* tsl::sh::power_of_two_growth_policy in general, but it will probably
* distribute the values across the buckets better when the hash function is poor.
*
* To allow the compiler to optimize the modulo operation, a lookup table is
* used with constant prime numbers.
*
* With a switch the code would look like:
* \code
* switch(iprime) { // iprime is the current prime of the hash table
* case 0: hash % 5ul;
* break;
* case 1: hash % 17ul;
* break;
* case 2: hash % 29ul;
* break;
* ...
* }
* \endcode
*
* Because the modulo operand is a compile-time constant, the compiler is able
* to optimize the operation into a series of multiplications, subtractions and shifts.
*
* The 'hash % 5' could become something like 'hash - ((hash * 0xCCCCCCCD) >> 34)
* * 5' in a 64-bit environment.
*/
class prime_growth_policy {
public:
explicit prime_growth_policy(std::size_t &min_bucket_count_in_out) {
auto it_prime = std::lower_bound(primes().begin(), primes().end(),
min_bucket_count_in_out);
if (it_prime == primes().end()) {
throw std::length_error("The hash table exceeds its maximum size.");
}
m_iprime =
static_cast<unsigned int>(std::distance(primes().begin(), it_prime));
if (min_bucket_count_in_out > 0) {
min_bucket_count_in_out = *it_prime;
} else {
min_bucket_count_in_out = 0;
}
}
std::size_t bucket_for_hash(std::size_t hash) const noexcept {
return mod_prime()[m_iprime](hash);
}
std::size_t next_bucket_count() const {
if (m_iprime + 1 >= primes().size()) {
throw std::length_error("The hash table exceeds its maximum size.");
}
return primes()[m_iprime + 1];
}
std::size_t max_bucket_count() const { return primes().back(); }
void clear() noexcept { m_iprime = 0; }
private:
static const std::array<std::size_t, 40> &primes() {
static const std::array<std::size_t, 40> PRIMES = {
{1ul, 5ul, 17ul, 29ul, 37ul,
53ul, 67ul, 79ul, 97ul, 131ul,
193ul, 257ul, 389ul, 521ul, 769ul,
1031ul, 1543ul, 2053ul, 3079ul, 6151ul,
12289ul, 24593ul, 49157ul, 98317ul, 196613ul,
393241ul, 786433ul, 1572869ul, 3145739ul, 6291469ul,
12582917ul, 25165843ul, 50331653ul, 100663319ul, 201326611ul,
402653189ul, 805306457ul, 1610612741ul, 3221225473ul, 4294967291ul}};
static_assert(
std::numeric_limits<decltype(m_iprime)>::max() >= PRIMES.size(),
"The type of m_iprime is not big enough.");
return PRIMES;
}
static const std::array<std::size_t (*)(std::size_t), 40> &mod_prime() {
// MOD_PRIME[iprime](hash) returns hash % PRIMES[iprime]. This table allows
// for faster modulo as the compiler can optimize the modulo code better
// with a constant known at the compilation.
static const std::array<std::size_t (*)(std::size_t), 40> MOD_PRIME = {
{&mod<0>, &mod<1>, &mod<2>, &mod<3>, &mod<4>, &mod<5>, &mod<6>,
&mod<7>, &mod<8>, &mod<9>, &mod<10>, &mod<11>, &mod<12>, &mod<13>,
&mod<14>, &mod<15>, &mod<16>, &mod<17>, &mod<18>, &mod<19>, &mod<20>,
&mod<21>, &mod<22>, &mod<23>, &mod<24>, &mod<25>, &mod<26>, &mod<27>,
&mod<28>, &mod<29>, &mod<30>, &mod<31>, &mod<32>, &mod<33>, &mod<34>,
&mod<35>, &mod<36>, &mod<37>, &mod<38>, &mod<39>}};
return MOD_PRIME;
}
template <unsigned int IPrime>
static std::size_t mod(std::size_t hash) {
return hash % primes()[IPrime];
}
private:
unsigned int m_iprime;
};
} // namespace sh
} // namespace tsl
#endif

View File

File diff suppressed because it is too large

View File

@@ -0,0 +1,800 @@
/**
* MIT License
*
* Copyright (c) 2017 Thibaut Goetghebuer-Planchon <tessil@gmx.com>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_SPARSE_MAP_H
#define TSL_SPARSE_MAP_H
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <memory>
#include <type_traits>
#include <utility>
#include "sparse_hash.h"
namespace tsl {
/**
* Implementation of a sparse hash map using open-addressing with quadratic
* probing. The goal of the hash map is to be as memory-efficient as
* possible, even at a low load factor, while keeping reasonable performance.
*
* `GrowthPolicy` defines how the map grows and consequently how a hash value is
* mapped to a bucket. By default the map uses
* `tsl::sh::power_of_two_growth_policy`. This policy keeps the number of
* buckets to a power of two and uses a mask to map the hash to a bucket instead
* of the slow modulo. Other growth policies are available and you may define
* your own growth policy, check `tsl::sh::power_of_two_growth_policy` for the
* interface.
*
* `ExceptionSafety` defines the exception guarantee provided by the class. By
* default only the basic exception safety is guaranteed, which means that all
* resources used by the hash map will be freed (no memory leaks) but the hash
* map may end up in an undefined state if an exception is thrown (undefined
* here means that some elements may be missing). This can ONLY happen on rehash
* (either on insert or if `rehash` is called explicitly) and will occur if the
* Allocator can't allocate memory (`std::bad_alloc`) or if the copy constructor
* (when a nothrow move constructor is not available) throws an exception. This
* can be avoided by calling `reserve` beforehand. This basic guarantee is
* similar to the one of `google::sparse_hash_map` and `spp::sparse_hash_map`.
* It is possible to ask for the strong exception guarantee with
* `tsl::sh::exception_safety::strong`, the drawback is that the map will be
* slower on rehashes and will also need more memory on rehashes.
*
* `Sparsity` defines how much the hash map will compromise between insertion
* speed and memory usage. A high sparsity means less memory usage but longer
* insertion times, and vice-versa for low sparsity. The default
* `tsl::sh::sparsity::medium` sparsity offers a good compromise. It doesn't
* change the lookup speed.
*
* `Key` and `T` must be nothrow move constructible and/or copy constructible.
*
* If the destructor of `Key` or `T` throws an exception, the behaviour of the
* class is undefined.
*
* Iterators invalidation:
* - clear, operator=, reserve, rehash: always invalidate the iterators.
* - insert, emplace, emplace_hint, operator[]: if there is an effective
* insert, invalidate the iterators.
* - erase: always invalidate the iterators.
*/
template <class Key, class T, class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Allocator = std::allocator<std::pair<Key, T>>,
class GrowthPolicy = tsl::sh::power_of_two_growth_policy<2>,
tsl::sh::exception_safety ExceptionSafety =
tsl::sh::exception_safety::basic,
tsl::sh::sparsity Sparsity = tsl::sh::sparsity::medium>
class sparse_map {
private:
template <typename U>
using has_is_transparent = tsl::detail_sparse_hash::has_is_transparent<U>;
class KeySelect {
public:
using key_type = Key;
const key_type &operator()(
const std::pair<Key, T> &key_value) const noexcept {
return key_value.first;
}
key_type &operator()(std::pair<Key, T> &key_value) noexcept {
return key_value.first;
}
};
class ValueSelect {
public:
using value_type = T;
const value_type &operator()(
const std::pair<Key, T> &key_value) const noexcept {
return key_value.second;
}
value_type &operator()(std::pair<Key, T> &key_value) noexcept {
return key_value.second;
}
};
using ht = detail_sparse_hash::sparse_hash<
std::pair<Key, T>, KeySelect, ValueSelect, Hash, KeyEqual, Allocator,
GrowthPolicy, ExceptionSafety, Sparsity, tsl::sh::probing::quadratic>;
public:
using key_type = typename ht::key_type;
using mapped_type = T;
using value_type = typename ht::value_type;
using size_type = typename ht::size_type;
using difference_type = typename ht::difference_type;
using hasher = typename ht::hasher;
using key_equal = typename ht::key_equal;
using allocator_type = typename ht::allocator_type;
using reference = typename ht::reference;
using const_reference = typename ht::const_reference;
using pointer = typename ht::pointer;
using const_pointer = typename ht::const_pointer;
using iterator = typename ht::iterator;
using const_iterator = typename ht::const_iterator;
public:
/*
* Constructors
*/
sparse_map() : sparse_map(ht::DEFAULT_INIT_BUCKET_COUNT) {}
explicit sparse_map(size_type bucket_count, const Hash &hash = Hash(),
const KeyEqual &equal = KeyEqual(),
const Allocator &alloc = Allocator())
: m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR) {}
sparse_map(size_type bucket_count, const Allocator &alloc)
: sparse_map(bucket_count, Hash(), KeyEqual(), alloc) {}
sparse_map(size_type bucket_count, const Hash &hash, const Allocator &alloc)
: sparse_map(bucket_count, hash, KeyEqual(), alloc) {}
explicit sparse_map(const Allocator &alloc)
: sparse_map(ht::DEFAULT_INIT_BUCKET_COUNT, alloc) {}
template <class InputIt>
sparse_map(InputIt first, InputIt last,
size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT,
const Hash &hash = Hash(), const KeyEqual &equal = KeyEqual(),
const Allocator &alloc = Allocator())
: sparse_map(bucket_count, hash, equal, alloc) {
insert(first, last);
}
template <class InputIt>
sparse_map(InputIt first, InputIt last, size_type bucket_count,
const Allocator &alloc)
: sparse_map(first, last, bucket_count, Hash(), KeyEqual(), alloc) {}
template <class InputIt>
sparse_map(InputIt first, InputIt last, size_type bucket_count,
const Hash &hash, const Allocator &alloc)
: sparse_map(first, last, bucket_count, hash, KeyEqual(), alloc) {}
sparse_map(std::initializer_list<value_type> init,
size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT,
const Hash &hash = Hash(), const KeyEqual &equal = KeyEqual(),
const Allocator &alloc = Allocator())
: sparse_map(init.begin(), init.end(), bucket_count, hash, equal, alloc) {
}
sparse_map(std::initializer_list<value_type> init, size_type bucket_count,
const Allocator &alloc)
: sparse_map(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(),
alloc) {}
sparse_map(std::initializer_list<value_type> init, size_type bucket_count,
const Hash &hash, const Allocator &alloc)
: sparse_map(init.begin(), init.end(), bucket_count, hash, KeyEqual(),
alloc) {}
sparse_map &operator=(std::initializer_list<value_type> ilist) {
m_ht.clear();
m_ht.reserve(ilist.size());
m_ht.insert(ilist.begin(), ilist.end());
return *this;
}
allocator_type get_allocator() const { return m_ht.get_allocator(); }
/*
* Iterators
*/
iterator begin() noexcept { return m_ht.begin(); }
const_iterator begin() const noexcept { return m_ht.begin(); }
const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
iterator end() noexcept { return m_ht.end(); }
const_iterator end() const noexcept { return m_ht.end(); }
const_iterator cend() const noexcept { return m_ht.cend(); }
/*
* Capacity
*/
bool empty() const noexcept { return m_ht.empty(); }
size_type size() const noexcept { return m_ht.size(); }
size_type max_size() const noexcept { return m_ht.max_size(); }
/*
* Modifiers
*/
void clear() noexcept { m_ht.clear(); }
std::pair<iterator, bool> insert(const value_type &value) {
return m_ht.insert(value);
}
template <class P, typename std::enable_if<std::is_constructible<
value_type, P &&>::value>::type * = nullptr>
std::pair<iterator, bool> insert(P &&value) {
return m_ht.emplace(std::forward<P>(value));
}
std::pair<iterator, bool> insert(value_type &&value) {
return m_ht.insert(std::move(value));
}
iterator insert(const_iterator hint, const value_type &value) {
return m_ht.insert_hint(hint, value);
}
template <class P, typename std::enable_if<std::is_constructible<
value_type, P &&>::value>::type * = nullptr>
iterator insert(const_iterator hint, P &&value) {
return m_ht.emplace_hint(hint, std::forward<P>(value));
}
iterator insert(const_iterator hint, value_type &&value) {
return m_ht.insert_hint(hint, std::move(value));
}
template <class InputIt>
void insert(InputIt first, InputIt last) {
m_ht.insert(first, last);
}
void insert(std::initializer_list<value_type> ilist) {
m_ht.insert(ilist.begin(), ilist.end());
}
template <class M>
std::pair<iterator, bool> insert_or_assign(const key_type &k, M &&obj) {
return m_ht.insert_or_assign(k, std::forward<M>(obj));
}
template <class M>
std::pair<iterator, bool> insert_or_assign(key_type &&k, M &&obj) {
return m_ht.insert_or_assign(std::move(k), std::forward<M>(obj));
}
template <class M>
iterator insert_or_assign(const_iterator hint, const key_type &k, M &&obj) {
return m_ht.insert_or_assign(hint, k, std::forward<M>(obj));
}
template <class M>
iterator insert_or_assign(const_iterator hint, key_type &&k, M &&obj) {
return m_ht.insert_or_assign(hint, std::move(k), std::forward<M>(obj));
}
/**
* Due to the way elements are stored, emplace will need to move or copy the
* key-value once. The method is equivalent to
* `insert(value_type(std::forward<Args>(args)...));`.
*
* Mainly here for compatibility with the `std::unordered_map` interface.
*/
template <class... Args>
std::pair<iterator, bool> emplace(Args &&...args) {
return m_ht.emplace(std::forward<Args>(args)...);
}
/**
* Due to the way elements are stored, emplace_hint will need to move or copy
* the key-value once. The method is equivalent to `insert(hint,
* value_type(std::forward<Args>(args)...));`.
*
* Mainly here for compatibility with the `std::unordered_map` interface.
*/
template <class... Args>
iterator emplace_hint(const_iterator hint, Args &&...args) {
return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
}
template <class... Args>
std::pair<iterator, bool> try_emplace(const key_type &k, Args &&...args) {
return m_ht.try_emplace(k, std::forward<Args>(args)...);
}
template <class... Args>
std::pair<iterator, bool> try_emplace(key_type &&k, Args &&...args) {
return m_ht.try_emplace(std::move(k), std::forward<Args>(args)...);
}
template <class... Args>
iterator try_emplace(const_iterator hint, const key_type &k, Args &&...args) {
return m_ht.try_emplace_hint(hint, k, std::forward<Args>(args)...);
}
template <class... Args>
iterator try_emplace(const_iterator hint, key_type &&k, Args &&...args) {
return m_ht.try_emplace_hint(hint, std::move(k),
std::forward<Args>(args)...);
}
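/*
 * Illustrative sketch, not part of the library: how insert_or_assign and
 * try_emplace differ (the keys and values here are hypothetical).
 *
 *   tsl::sparse_map<std::string, int> map;
 *   map.try_emplace("a", 1);       // inserts {"a", 1}
 *   map.try_emplace("a", 2);       // key exists: no insertion, value stays 1
 *   map.insert_or_assign("a", 3);  // key exists: assigns, map.at("a") == 3
 */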
iterator erase(iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator first, const_iterator last) {
return m_ht.erase(first, last);
}
size_type erase(const key_type &key) { return m_ht.erase(key); }
/**
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
size_type erase(const key_type &key, std::size_t precalculated_hash) {
return m_ht.erase(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef
* `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
* comparable to `Key`.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
size_type erase(const K &key) {
return m_ht.erase(key);
}
/**
* @copydoc erase(const K& key)
*
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
size_type erase(const K &key, std::size_t precalculated_hash) {
return m_ht.erase(key, precalculated_hash);
}
void swap(sparse_map &other) { other.m_ht.swap(m_ht); }
/*
* Lookup
*/
T &at(const Key &key) { return m_ht.at(key); }
/**
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
T &at(const Key &key, std::size_t precalculated_hash) {
return m_ht.at(key, precalculated_hash);
}
const T &at(const Key &key) const { return m_ht.at(key); }
/**
* @copydoc at(const Key& key, std::size_t precalculated_hash)
*/
const T &at(const Key &key, std::size_t precalculated_hash) const {
return m_ht.at(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef
* `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
* comparable to `Key`.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
T &at(const K &key) {
return m_ht.at(key);
}
/**
* @copydoc at(const K& key)
*
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
T &at(const K &key, std::size_t precalculated_hash) {
return m_ht.at(key, precalculated_hash);
}
/**
* @copydoc at(const K& key)
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
const T &at(const K &key) const {
return m_ht.at(key);
}
/**
* @copydoc at(const K& key, std::size_t precalculated_hash)
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
const T &at(const K &key, std::size_t precalculated_hash) const {
return m_ht.at(key, precalculated_hash);
}
T &operator[](const Key &key) { return m_ht[key]; }
T &operator[](Key &&key) { return m_ht[std::move(key)]; }
size_type count(const Key &key) const { return m_ht.count(key); }
/**
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
size_type count(const Key &key, std::size_t precalculated_hash) const {
return m_ht.count(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef
* `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
* comparable to `Key`.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
size_type count(const K &key) const {
return m_ht.count(key);
}
/**
* @copydoc count(const K& key) const
*
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
size_type count(const K &key, std::size_t precalculated_hash) const {
return m_ht.count(key, precalculated_hash);
}
iterator find(const Key &key) { return m_ht.find(key); }
/**
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
iterator find(const Key &key, std::size_t precalculated_hash) {
return m_ht.find(key, precalculated_hash);
}
const_iterator find(const Key &key) const { return m_ht.find(key); }
/**
* @copydoc find(const Key& key, std::size_t precalculated_hash)
*/
const_iterator find(const Key &key, std::size_t precalculated_hash) const {
return m_ht.find(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef
* `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
* comparable to `Key`.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
iterator find(const K &key) {
return m_ht.find(key);
}
/**
* @copydoc find(const K& key)
*
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
iterator find(const K &key, std::size_t precalculated_hash) {
return m_ht.find(key, precalculated_hash);
}
/**
* @copydoc find(const K& key)
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
const_iterator find(const K &key) const {
return m_ht.find(key);
}
/**
* @copydoc find(const K& key)
*
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
const_iterator find(const K &key, std::size_t precalculated_hash) const {
return m_ht.find(key, precalculated_hash);
}
bool contains(const Key &key) const { return m_ht.contains(key); }
/**
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
bool contains(const Key &key, std::size_t precalculated_hash) const {
return m_ht.contains(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef
* `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
* comparable to `Key`.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
bool contains(const K &key) const {
return m_ht.contains(key);
}
/**
* @copydoc contains(const K& key) const
*
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
bool contains(const K &key, std::size_t precalculated_hash) const {
return m_ht.contains(key, precalculated_hash);
}
std::pair<iterator, iterator> equal_range(const Key &key) {
return m_ht.equal_range(key);
}
/**
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
std::pair<iterator, iterator> equal_range(const Key &key,
std::size_t precalculated_hash) {
return m_ht.equal_range(key, precalculated_hash);
}
std::pair<const_iterator, const_iterator> equal_range(const Key &key) const {
return m_ht.equal_range(key);
}
/**
* @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
*/
std::pair<const_iterator, const_iterator> equal_range(
const Key &key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef
* `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
* comparable to `Key`.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
std::pair<iterator, iterator> equal_range(const K &key) {
return m_ht.equal_range(key);
}
/**
* @copydoc equal_range(const K& key)
*
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
std::pair<iterator, iterator> equal_range(const K &key,
std::size_t precalculated_hash) {
return m_ht.equal_range(key, precalculated_hash);
}
/**
* @copydoc equal_range(const K& key)
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
std::pair<const_iterator, const_iterator> equal_range(const K &key) const {
return m_ht.equal_range(key);
}
/**
* @copydoc equal_range(const K& key, std::size_t precalculated_hash)
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
std::pair<const_iterator, const_iterator> equal_range(
const K &key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, precalculated_hash);
}
/*
* Bucket interface
*/
size_type bucket_count() const { return m_ht.bucket_count(); }
size_type max_bucket_count() const { return m_ht.max_bucket_count(); }
/*
* Hash policy
*/
float load_factor() const { return m_ht.load_factor(); }
float max_load_factor() const { return m_ht.max_load_factor(); }
void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
void rehash(size_type count) { m_ht.rehash(count); }
void reserve(size_type count) { m_ht.reserve(count); }
/*
* Observers
*/
hasher hash_function() const { return m_ht.hash_function(); }
key_equal key_eq() const { return m_ht.key_eq(); }
/*
* Other
*/
/**
* Convert a `const_iterator` to an `iterator`.
*/
iterator mutable_iterator(const_iterator pos) {
return m_ht.mutable_iterator(pos);
}
/**
* Serialize the map through the `serializer` parameter.
*
* The `serializer` parameter must be a function object that supports the
* following call:
* - `template<typename U> void operator()(const U& value);` where the types
* `std::uint64_t`, `float` and `std::pair<Key, T>` must be supported for U.
*
* The implementation leaves binary compatibility (endianness, IEEE 754 for
* floats, ...) of the types it serializes in the hands of the `Serializer`
* function object if compatibility is required.
*/
template <class Serializer>
void serialize(Serializer &serializer) const {
m_ht.serialize(serializer);
}
/**
* Deserialize a previously serialized map through the `deserializer`
* parameter.
*
* The `deserializer` parameter must be a function object that supports the
* following calls:
* - `template<typename U> U operator()();` where the types `std::uint64_t`,
* `float` and `std::pair<Key, T>` must be supported for U.
*
* If the deserialized hash map type is hash compatible with the serialized
* map, the deserialization process can be sped up by setting
* `hash_compatible` to true. To be hash compatible, the Hash, KeyEqual and
* GrowthPolicy must behave the same way as the ones used on the serialized
* map. The `std::size_t` must also be of the same size as the one on the
* platform used to serialize the map. If these criteria are not met, the
* behaviour is undefined with `hash_compatible` set to true.
*
* The behaviour is undefined if the type `Key` and `T` of the `sparse_map`
* are not the same as the types used during serialization.
*
* The implementation leaves binary compatibility (endianness, IEEE 754 for
* floats, size of int, ...) of the types it deserializes in the hands of the
* `Deserializer` function object if compatibility is required.
*/
template <class Deserializer>
static sparse_map deserialize(Deserializer &deserializer,
bool hash_compatible = false) {
sparse_map map(0);
map.m_ht.deserialize(deserializer, hash_compatible);
return map;
}
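/*
 * Illustrative sketch, not part of the library: a minimal serializer and
 * deserializer pair matching the interface documented above. It writes raw
 * bytes to a stream, so it is only a sketch for trivially copyable Key/T and
 * is not portable across platforms (endianness, type sizes); all names are
 * hypothetical.
 *
 *   #include <istream>
 *   #include <ostream>
 *
 *   class stream_serializer {
 *    public:
 *       explicit stream_serializer(std::ostream& os) : m_os(os) {}
 *       template <class U>
 *       void operator()(const U& value) {
 *           m_os.write(reinterpret_cast<const char*>(&value), sizeof(U));
 *       }
 *    private:
 *       std::ostream& m_os;
 *   };
 *
 *   class stream_deserializer {
 *    public:
 *       explicit stream_deserializer(std::istream& is) : m_is(is) {}
 *       template <class U>
 *       U operator()() {
 *           U value;
 *           m_is.read(reinterpret_cast<char*>(&value), sizeof(U));
 *           return value;
 *       }
 *    private:
 *       std::istream& m_is;
 *   };
 *
 *   // stream_serializer ser(file_out);    map.serialize(ser);
 *   // stream_deserializer dser(file_in);  auto restored = decltype(map)::deserialize(dser);
 */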
friend bool operator==(const sparse_map &lhs, const sparse_map &rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
for (const auto &element_lhs : lhs) {
const auto it_element_rhs = rhs.find(element_lhs.first);
if (it_element_rhs == rhs.cend() ||
element_lhs.second != it_element_rhs->second) {
return false;
}
}
return true;
}
friend bool operator!=(const sparse_map &lhs, const sparse_map &rhs) {
return !operator==(lhs, rhs);
}
friend void swap(sparse_map &lhs, sparse_map &rhs) { lhs.swap(rhs); }
private:
ht m_ht;
};
/**
* Same as `tsl::sparse_map<Key, T, Hash, KeyEqual, Allocator,
* tsl::sh::prime_growth_policy>`.
*/
template <class Key, class T, class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Allocator = std::allocator<std::pair<Key, T>>>
using sparse_pg_map =
sparse_map<Key, T, Hash, KeyEqual, Allocator, tsl::sh::prime_growth_policy>;
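/*
 * Illustrative usage sketch, not part of the library (the include path is an
 * assumption):
 *
 *   #include <iostream>
 *   #include <string>
 *   #include <tsl/sparse_map.h>
 *
 *   int main() {
 *       tsl::sparse_map<std::string, int> map = {{"a", 1}, {"b", 2}};
 *       map["c"] = 3;
 *       for (const auto& kv : map) {
 *           std::cout << kv.first << " -> " << kv.second << '\n';
 *       }
 *   }
 */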
} // end namespace tsl
#endif

View File

@@ -0,0 +1,655 @@
/**
* MIT License
*
* Copyright (c) 2017 Thibaut Goetghebuer-Planchon <tessil@gmx.com>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_SPARSE_SET_H
#define TSL_SPARSE_SET_H
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <memory>
#include <type_traits>
#include <utility>
#include "sparse_hash.h"
namespace tsl {
/**
* Implementation of a sparse hash set using open-addressing with quadratic
* probing. The goal of the hash set is to be as memory-efficient as
* possible, even at a low load factor, while keeping reasonable performance.
*
* `GrowthPolicy` defines how the set grows and consequently how a hash value is
* mapped to a bucket. By default the set uses
* `tsl::sh::power_of_two_growth_policy`. This policy keeps the number of
* buckets to a power of two and uses a mask to map the hash to a bucket instead
* of the slow modulo. Other growth policies are available and you may define
* your own growth policy, check `tsl::sh::power_of_two_growth_policy` for the
* interface.
*
* `ExceptionSafety` defines the exception guarantee provided by the class. By
* default only the basic exception safety is guaranteed, which means that all
* resources used by the hash set will be freed (no memory leaks) but the hash
* set may end up in an undefined state if an exception is thrown (undefined
* here means that some elements may be missing). This can ONLY happen on rehash
* (either on insert or if `rehash` is called explicitly) and will occur if the
* Allocator can't allocate memory (`std::bad_alloc`) or if the copy constructor
* (when a nothrow move constructor is not available) throws an exception. This
* can be avoided by calling `reserve` beforehand. This basic guarantee is
* similar to the one of `google::sparse_hash_map` and `spp::sparse_hash_map`.
* It is possible to ask for the strong exception guarantee with
* `tsl::sh::exception_safety::strong`, the drawback is that the set will be
* slower on rehashes and will also need more memory on rehashes.
*
* `Sparsity` defines how much the hash set will compromise between insertion
* speed and memory usage. A high sparsity means less memory usage but longer
* insertion times, and vice-versa for low sparsity. The default
* `tsl::sh::sparsity::medium` sparsity offers a good compromise. It doesn't
* change the lookup speed.
*
* `Key` must be nothrow move constructible and/or copy constructible.
*
* If the destructor of `Key` throws an exception, the behaviour of the class is
* undefined.
*
* Iterators invalidation:
* - clear, operator=, reserve, rehash: always invalidate the iterators.
* - insert, emplace, emplace_hint: if there is an effective insert, invalidate
* the iterators.
* - erase: always invalidate the iterators.
*/
template <class Key, class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Allocator = std::allocator<Key>,
class GrowthPolicy = tsl::sh::power_of_two_growth_policy<2>,
tsl::sh::exception_safety ExceptionSafety =
tsl::sh::exception_safety::basic,
tsl::sh::sparsity Sparsity = tsl::sh::sparsity::medium>
class sparse_set {
private:
template <typename U>
using has_is_transparent = tsl::detail_sparse_hash::has_is_transparent<U>;
class KeySelect {
public:
using key_type = Key;
const key_type &operator()(const Key &key) const noexcept { return key; }
key_type &operator()(Key &key) noexcept { return key; }
};
using ht =
detail_sparse_hash::sparse_hash<Key, KeySelect, void, Hash, KeyEqual,
Allocator, GrowthPolicy, ExceptionSafety,
Sparsity, tsl::sh::probing::quadratic>;
public:
using key_type = typename ht::key_type;
using value_type = typename ht::value_type;
using size_type = typename ht::size_type;
using difference_type = typename ht::difference_type;
using hasher = typename ht::hasher;
using key_equal = typename ht::key_equal;
using allocator_type = typename ht::allocator_type;
using reference = typename ht::reference;
using const_reference = typename ht::const_reference;
using pointer = typename ht::pointer;
using const_pointer = typename ht::const_pointer;
using iterator = typename ht::iterator;
using const_iterator = typename ht::const_iterator;
/*
* Constructors
*/
sparse_set() : sparse_set(ht::DEFAULT_INIT_BUCKET_COUNT) {}
explicit sparse_set(size_type bucket_count, const Hash &hash = Hash(),
const KeyEqual &equal = KeyEqual(),
const Allocator &alloc = Allocator())
: m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR) {}
sparse_set(size_type bucket_count, const Allocator &alloc)
: sparse_set(bucket_count, Hash(), KeyEqual(), alloc) {}
sparse_set(size_type bucket_count, const Hash &hash, const Allocator &alloc)
: sparse_set(bucket_count, hash, KeyEqual(), alloc) {}
explicit sparse_set(const Allocator &alloc)
: sparse_set(ht::DEFAULT_INIT_BUCKET_COUNT, alloc) {}
template <class InputIt>
sparse_set(InputIt first, InputIt last,
size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT,
const Hash &hash = Hash(), const KeyEqual &equal = KeyEqual(),
const Allocator &alloc = Allocator())
: sparse_set(bucket_count, hash, equal, alloc) {
insert(first, last);
}
template <class InputIt>
sparse_set(InputIt first, InputIt last, size_type bucket_count,
const Allocator &alloc)
: sparse_set(first, last, bucket_count, Hash(), KeyEqual(), alloc) {}
template <class InputIt>
sparse_set(InputIt first, InputIt last, size_type bucket_count,
const Hash &hash, const Allocator &alloc)
: sparse_set(first, last, bucket_count, hash, KeyEqual(), alloc) {}
sparse_set(std::initializer_list<value_type> init,
size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT,
const Hash &hash = Hash(), const KeyEqual &equal = KeyEqual(),
const Allocator &alloc = Allocator())
: sparse_set(init.begin(), init.end(), bucket_count, hash, equal, alloc) {
}
sparse_set(std::initializer_list<value_type> init, size_type bucket_count,
const Allocator &alloc)
: sparse_set(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(),
alloc) {}
sparse_set(std::initializer_list<value_type> init, size_type bucket_count,
const Hash &hash, const Allocator &alloc)
: sparse_set(init.begin(), init.end(), bucket_count, hash, KeyEqual(),
alloc) {}
sparse_set &operator=(std::initializer_list<value_type> ilist) {
m_ht.clear();
m_ht.reserve(ilist.size());
m_ht.insert(ilist.begin(), ilist.end());
return *this;
}
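/*
 * Illustrative sketch, not part of the library: as the class documentation
 * notes, rehashing is the only operation where the basic exception guarantee
 * can leave the container in an undefined state, so reserving up front before
 * a batch of inserts sidesteps it entirely.
 *
 *   tsl::sparse_set<int> set;
 *   set.reserve(1000);       // grow once, up front
 *   for (int i = 0; i < 1000; ++i) {
 *       set.insert(i);       // no rehash occurs during this loop
 *   }
 */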
allocator_type get_allocator() const { return m_ht.get_allocator(); }
/*
* Iterators
*/
iterator begin() noexcept { return m_ht.begin(); }
const_iterator begin() const noexcept { return m_ht.begin(); }
const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
iterator end() noexcept { return m_ht.end(); }
const_iterator end() const noexcept { return m_ht.end(); }
const_iterator cend() const noexcept { return m_ht.cend(); }
/*
* Capacity
*/
bool empty() const noexcept { return m_ht.empty(); }
size_type size() const noexcept { return m_ht.size(); }
size_type max_size() const noexcept { return m_ht.max_size(); }
/*
* Modifiers
*/
void clear() noexcept { m_ht.clear(); }
std::pair<iterator, bool> insert(const value_type &value) {
return m_ht.insert(value);
}
std::pair<iterator, bool> insert(value_type &&value) {
return m_ht.insert(std::move(value));
}
iterator insert(const_iterator hint, const value_type &value) {
return m_ht.insert_hint(hint, value);
}
iterator insert(const_iterator hint, value_type &&value) {
return m_ht.insert_hint(hint, std::move(value));
}
template <class InputIt>
void insert(InputIt first, InputIt last) {
m_ht.insert(first, last);
}
void insert(std::initializer_list<value_type> ilist) {
m_ht.insert(ilist.begin(), ilist.end());
}
/**
* Due to the way elements are stored, emplace will need to move or copy the
* key once. The method is equivalent to
* `insert(value_type(std::forward<Args>(args)...));`.
*
* Mainly here for compatibility with the `std::unordered_set` interface.
*/
template <class... Args>
std::pair<iterator, bool> emplace(Args &&...args) {
return m_ht.emplace(std::forward<Args>(args)...);
}
/**
* Due to the way elements are stored, emplace_hint will need to move or copy
* the key once. The method is equivalent to `insert(hint,
* value_type(std::forward<Args>(args)...));`.
*
* Mainly here for compatibility with the `std::unordered_set` interface.
*/
template <class... Args>
iterator emplace_hint(const_iterator hint, Args &&...args) {
return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
}
iterator erase(iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator first, const_iterator last) {
return m_ht.erase(first, last);
}
size_type erase(const key_type &key) { return m_ht.erase(key); }
/**
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
size_type erase(const key_type &key, std::size_t precalculated_hash) {
return m_ht.erase(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef
* `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
* comparable to `Key`.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
size_type erase(const K &key) {
return m_ht.erase(key);
}
/**
* @copydoc erase(const K& key)
*
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
size_type erase(const K &key, std::size_t precalculated_hash) {
return m_ht.erase(key, precalculated_hash);
}
void swap(sparse_set &other) { other.m_ht.swap(m_ht); }
/*
* Lookup
*/
size_type count(const Key &key) const { return m_ht.count(key); }
/**
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
size_type count(const Key &key, std::size_t precalculated_hash) const {
return m_ht.count(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef
* `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
* comparable to `Key`.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
size_type count(const K &key) const {
return m_ht.count(key);
}
/**
* @copydoc count(const K& key) const
*
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
size_type count(const K &key, std::size_t precalculated_hash) const {
return m_ht.count(key, precalculated_hash);
}
iterator find(const Key &key) { return m_ht.find(key); }
/**
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
iterator find(const Key &key, std::size_t precalculated_hash) {
return m_ht.find(key, precalculated_hash);
}
const_iterator find(const Key &key) const { return m_ht.find(key); }
/**
* @copydoc find(const Key& key, std::size_t precalculated_hash)
*/
const_iterator find(const Key &key, std::size_t precalculated_hash) const {
return m_ht.find(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef
* `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
* comparable to `Key`.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
iterator find(const K &key) {
return m_ht.find(key);
}
/**
* @copydoc find(const K& key)
*
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
iterator find(const K &key, std::size_t precalculated_hash) {
return m_ht.find(key, precalculated_hash);
}
/**
* @copydoc find(const K& key)
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
const_iterator find(const K &key) const {
return m_ht.find(key);
}
/**
* @copydoc find(const K& key)
*
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
const_iterator find(const K &key, std::size_t precalculated_hash) const {
return m_ht.find(key, precalculated_hash);
}
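  /**
   * Example (an illustrative sketch): heterogeneous lookup with a transparent
   * hash/equality pair, so a `std::string_view` probes a set of `std::string`
   * without materialising a temporary string. The functor names below are
   * illustrative, not part of this header.
   *
   *     #include <functional>
   *     #include <string>
   *     #include <string_view>
   *
   *     struct string_hash {
   *         using is_transparent = void;
   *         std::size_t operator()(std::string_view sv) const {
   *             return std::hash<std::string_view>{}(sv);
   *         }
   *     };
   *     struct string_equal {
   *         using is_transparent = void;
   *         bool operator()(std::string_view lhs, std::string_view rhs) const {
   *             return lhs == rhs;
   *         }
   *     };
   *
   *     tsl::sparse_set<std::string, string_hash, string_equal> s{"alpha"};
   *     auto it = s.find(std::string_view("alpha"));  // no std::string built
   */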
bool contains(const Key &key) const { return m_ht.contains(key); }
/**
   * Use the hash value `precalculated_hash` instead of hashing the key. The
   * hash value should be the same as `hash_function()(key)`, otherwise the
   * behaviour is undefined. Useful to speed-up the lookup if you already have
   * the hash.
*/
bool contains(const Key &key, std::size_t precalculated_hash) const {
return m_ht.contains(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef
   * `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
   * comparable to `Key`.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
bool contains(const K &key) const {
return m_ht.contains(key);
}
/**
* @copydoc contains(const K& key) const
*
   * Use the hash value `precalculated_hash` instead of hashing the key. The
   * hash value should be the same as `hash_function()(key)`, otherwise the
   * behaviour is undefined. Useful to speed-up the lookup if you already have
   * the hash.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
bool contains(const K &key, std::size_t precalculated_hash) const {
return m_ht.contains(key, precalculated_hash);
}
std::pair<iterator, iterator> equal_range(const Key &key) {
return m_ht.equal_range(key);
}
/**
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
std::pair<iterator, iterator> equal_range(const Key &key,
std::size_t precalculated_hash) {
return m_ht.equal_range(key, precalculated_hash);
}
std::pair<const_iterator, const_iterator> equal_range(const Key &key) const {
return m_ht.equal_range(key);
}
/**
* @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
*/
std::pair<const_iterator, const_iterator> equal_range(
const Key &key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef
* `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
* comparable to `Key`.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
std::pair<iterator, iterator> equal_range(const K &key) {
return m_ht.equal_range(key);
}
/**
* @copydoc equal_range(const K& key)
*
* Use the hash value `precalculated_hash` instead of hashing the key. The
* hash value should be the same as `hash_function()(key)`, otherwise the
* behaviour is undefined. Useful to speed-up the lookup if you already have
* the hash.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
std::pair<iterator, iterator> equal_range(const K &key,
std::size_t precalculated_hash) {
return m_ht.equal_range(key, precalculated_hash);
}
/**
* @copydoc equal_range(const K& key)
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
std::pair<const_iterator, const_iterator> equal_range(const K &key) const {
return m_ht.equal_range(key);
}
/**
* @copydoc equal_range(const K& key, std::size_t precalculated_hash)
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
std::pair<const_iterator, const_iterator> equal_range(
const K &key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, precalculated_hash);
}
/*
* Bucket interface
*/
size_type bucket_count() const { return m_ht.bucket_count(); }
size_type max_bucket_count() const { return m_ht.max_bucket_count(); }
/*
* Hash policy
*/
float load_factor() const { return m_ht.load_factor(); }
float max_load_factor() const { return m_ht.max_load_factor(); }
void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
void rehash(size_type count) { m_ht.rehash(count); }
void reserve(size_type count) { m_ht.reserve(count); }
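  /**
   * Example (an illustrative sketch): `reserve` sizes the bucket array for
   * the expected element count up front, so the bulk insert below triggers
   * no intermediate rehash:
   *
   *     tsl::sparse_set<int> s;
   *     s.reserve(1000000);  // single growth instead of repeated rehashes
   *     for (int i = 0; i < 1000000; ++i) {
   *         s.insert(i);
   *     }
   */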
/*
* Observers
*/
hasher hash_function() const { return m_ht.hash_function(); }
key_equal key_eq() const { return m_ht.key_eq(); }
/*
* Other
*/
/**
* Convert a `const_iterator` to an `iterator`.
*/
iterator mutable_iterator(const_iterator pos) {
return m_ht.mutable_iterator(pos);
}
/**
* Serialize the set through the `serializer` parameter.
*
* The `serializer` parameter must be a function object that supports the
* following call:
   * - `void operator()(const U& value);` where the types `std::uint64_t`,
   * `float` and `Key` must be supported for `U`.
*
* The implementation leaves binary compatibility (endianness, IEEE 754 for
* floats, ...) of the types it serializes in the hands of the `Serializer`
* function object if compatibility is required.
*/
template <class Serializer>
void serialize(Serializer &serializer) const {
m_ht.serialize(serializer);
}
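  /**
   * Example (an illustrative sketch; the struct name is not part of this
   * header): a minimal binary serializer writing to a `std::ofstream`. It
   * satisfies the `void operator()(const U&)` requirement for trivially
   * copyable `U`, but makes no endianness guarantees.
   *
   *     #include <cstdint>
   *     #include <fstream>
   *
   *     struct ofstream_serializer {
   *         std::ofstream& out;
   *         template <class U>
   *         void operator()(const U& value) {
   *             out.write(reinterpret_cast<const char*>(&value), sizeof(U));
   *         }
   *     };
   *
   *     tsl::sparse_set<std::int32_t> s{1, 2, 3};
   *     std::ofstream out("set.bin", std::ios::binary);
   *     ofstream_serializer serializer{out};
   *     s.serialize(serializer);
   */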
/**
* Deserialize a previously serialized set through the `deserializer`
* parameter.
*
* The `deserializer` parameter must be a function object that supports the
* following calls:
   * - `template<typename U> U operator()();` where the types `std::uint64_t`,
   * `float` and `Key` must be supported for `U`.
*
* If the deserialized hash set type is hash compatible with the serialized
* set, the deserialization process can be sped up by setting
   * `hash_compatible` to true. To be hash compatible, the `Hash`, `KeyEqual` and
   * `GrowthPolicy` must behave in the same way as the ones used on the serialized
   * set. The `std::size_t` must also be of the same size as the one on the
   * platform used to serialize the set. If these criteria are not met, the
   * behaviour is undefined with `hash_compatible` set to true.
*
* The behaviour is undefined if the type `Key` of the `sparse_set` is not the
* same as the type used during serialization.
*
* The implementation leaves binary compatibility (endianness, IEEE 754 for
* floats, size of int, ...) of the types it deserializes in the hands of the
* `Deserializer` function object if compatibility is required.
*/
template <class Deserializer>
static sparse_set deserialize(Deserializer &deserializer,
bool hash_compatible = false) {
sparse_set set(0);
set.m_ht.deserialize(deserializer, hash_compatible);
return set;
}
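  /**
   * Example (an illustrative sketch): the deserializer matching the
   * `ofstream_serializer` sketch above; `hash_compatible` is left at its
   * always-safe default of false.
   *
   *     #include <cstdint>
   *     #include <fstream>
   *
   *     struct ifstream_deserializer {
   *         std::ifstream& in;
   *         template <class U>
   *         U operator()() {
   *             U value;
   *             in.read(reinterpret_cast<char*>(&value), sizeof(U));
   *             return value;
   *         }
   *     };
   *
   *     std::ifstream in("set.bin", std::ios::binary);
   *     ifstream_deserializer deserializer{in};
   *     auto s = tsl::sparse_set<std::int32_t>::deserialize(deserializer);
   */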
friend bool operator==(const sparse_set &lhs, const sparse_set &rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
for (const auto &element_lhs : lhs) {
const auto it_element_rhs = rhs.find(element_lhs);
if (it_element_rhs == rhs.cend()) {
return false;
}
}
return true;
}
friend bool operator!=(const sparse_set &lhs, const sparse_set &rhs) {
return !operator==(lhs, rhs);
}
friend void swap(sparse_set &lhs, sparse_set &rhs) { lhs.swap(rhs); }
private:
ht m_ht;
};
/**
* Same as `tsl::sparse_set<Key, Hash, KeyEqual, Allocator,
* tsl::sh::prime_growth_policy>`.
*/
template <class Key, class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Allocator = std::allocator<Key>>
using sparse_pg_set =
sparse_set<Key, Hash, KeyEqual, Allocator, tsl::sh::prime_growth_policy>;
} // end namespace tsl
#endif

View File

@@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <cstdint>
#include <cstddef>
#include <any>
#include "any_wrappers.h"
namespace diskann
{
typedef uint32_t location_t;
using DataType = std::any;
using TagType = std::any;
using LabelType = std::any;
using TagVector = AnyWrapper::AnyVector;
using DataVector = AnyWrapper::AnyVector;
using Labelvector = AnyWrapper::AnyVector;
using TagRobinSet = AnyWrapper::AnyRobinSet;
} // namespace diskann

View File

File diff suppressed because it is too large

View File

@@ -0,0 +1,57 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#ifdef _WINDOWS
#ifndef USE_BING_INFRA
#include <Windows.h>
#include <fcntl.h>
#include <malloc.h>
#include <minwinbase.h>
#include <cstdio>
#include <mutex>
#include <thread>
#include "aligned_file_reader.h"
#include "tsl/robin_map.h"
#include "utils.h"
#include "windows_customizations.h"
class WindowsAlignedFileReader : public AlignedFileReader
{
private:
#ifdef UNICODE
std::wstring m_filename;
#else
std::string m_filename;
#endif
protected:
// virtual IOContext createContext();
public:
DISKANN_DLLEXPORT WindowsAlignedFileReader(){};
DISKANN_DLLEXPORT virtual ~WindowsAlignedFileReader(){};
// Open & close ops
// Blocking calls
DISKANN_DLLEXPORT virtual void open(const std::string &fname) override;
DISKANN_DLLEXPORT virtual void close() override;
DISKANN_DLLEXPORT virtual void register_thread() override;
DISKANN_DLLEXPORT virtual void deregister_thread() override
{
// TODO: Needs implementation.
}
DISKANN_DLLEXPORT virtual void deregister_all_threads() override
{
// TODO: Needs implementation.
}
DISKANN_DLLEXPORT virtual IOContext &get_ctx() override;
// process batch of aligned requests in parallel
// NOTE :: blocking call for the calling thread, but safe to invoke concurrently from multiple threads
DISKANN_DLLEXPORT virtual void read(std::vector<AlignedRead> &read_reqs, IOContext &ctx, bool async) override;
};
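// Example (an illustrative sketch, assuming the `AlignedRead` request struct
// from aligned_file_reader.h with `offset`, `len` and `buf` members, all of
// which must be sector-aligned): a blocking batch read on the calling thread.
//
//     WindowsAlignedFileReader reader;
//     reader.open("index.bin");
//     reader.register_thread();                // sets up a per-thread IOContext
//
//     void *buf = _aligned_malloc(4096, 512);  // sector-aligned destination
//     std::vector<AlignedRead> reqs;
//     reqs.push_back(AlignedRead(/*offset=*/0, /*len=*/4096, buf));
//     reader.read(reqs, reader.get_ctx(), /*async=*/false);
//
//     _aligned_free(buf);
//     reader.close();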
#endif // USE_BING_INFRA
#endif //_WINDOWS

View File

@@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#ifdef _WINDOWS
#ifdef _WINDLL
#define DISKANN_DLLEXPORT __declspec(dllexport)
#else
#define DISKANN_DLLEXPORT __declspec(dllimport)
#endif
#else
#define DISKANN_DLLEXPORT
#endif
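// Example (an illustrative sketch; the class name is hypothetical): marking a
// symbol so that it is dllexport'ed when building the DiskANN DLL,
// dllimport'ed by consumers of the DLL, and a no-op on non-Windows builds.
//
//     class MyIndex
//     {
//       public:
//         DISKANN_DLLEXPORT void load(const std::string &path);
//     };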

View File

@@ -0,0 +1,73 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif
#include "Windows.h"
namespace diskann
{
// A thin C++ wrapper around the exclusive-locking functionality of the
// Windows SlimReaderWriterLock.
//
// The SlimReaderWriterLock is simpler/more lightweight than std::mutex
// (8 bytes vs 80 bytes), which is useful in the scenario where DiskANN has
// one lock per vector in the index. It does not support recursive locking and
// requires Windows Vista or later.
//
// Full documentation can be found at:
// https://msdn.microsoft.com/en-us/library/windows/desktop/aa904937(v=vs.85).aspx
class windows_exclusive_slim_lock
{
public:
windows_exclusive_slim_lock() : _lock(SRWLOCK_INIT)
{
}
// The lock is non-copyable. This also disables move constructor/operator=.
windows_exclusive_slim_lock(const windows_exclusive_slim_lock &) = delete;
windows_exclusive_slim_lock &operator=(const windows_exclusive_slim_lock &) = delete;
void lock()
{
return AcquireSRWLockExclusive(&_lock);
}
bool try_lock()
{
return TryAcquireSRWLockExclusive(&_lock) != FALSE;
}
void unlock()
{
return ReleaseSRWLockExclusive(&_lock);
}
private:
SRWLOCK _lock;
};
// An exclusive lock over a SlimReaderWriterLock.
class windows_exclusive_slim_lock_guard
{
public:
windows_exclusive_slim_lock_guard(windows_exclusive_slim_lock &p_lock) : _lock(p_lock)
{
_lock.lock();
}
// The lock is non-copyable. This also disables move constructor/operator=.
windows_exclusive_slim_lock_guard(const windows_exclusive_slim_lock_guard &) = delete;
windows_exclusive_slim_lock_guard &operator=(const windows_exclusive_slim_lock_guard &) = delete;
~windows_exclusive_slim_lock_guard()
{
_lock.unlock();
}
private:
windows_exclusive_slim_lock &_lock;
};
} // namespace diskann
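// Example (an illustrative sketch; the surrounding class is hypothetical):
// one slim lock per vector, acquired through the RAII guard for the duration
// of an update. Since the lock is only 8 bytes, a per-vector lock array stays
// cheap even for large indices.
//
//     #include <cstddef>
//     #include <vector>
//
//     class VectorStore
//     {
//       public:
//         explicit VectorStore(size_t num_vectors) : _locks(num_vectors)
//         {
//         }
//         void update(size_t loc)
//         {
//             diskann::windows_exclusive_slim_lock_guard guard(_locks[loc]);
//             // ... mutate the vector at loc while holding the lock ...
//         }
//       private:
//         std::vector<diskann::windows_exclusive_slim_lock> _locks;
//     };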