Initial commit

packages/leann-backend-hnsw/third_party/faiss/contrib/README.md (vendored, new file, 76 lines)
@@ -0,0 +1,76 @@
# The contrib modules

The contrib directory contains helper modules for Faiss for various tasks.

## Code structure

The contrib directory gets compiled into the module faiss.contrib.
Note that although some of the modules may depend on additional modules (e.g. GPU Faiss, pytorch, hdf5), they are not necessarily compiled in, to avoid adding dependencies. It is the user's responsibility to provide them.

In contrib, we are progressively dropping Python 2 support.

## List of contrib modules

### rpc.py

A very simple Remote Procedure Call library, where function parameters and results are pickled, for use with client_server.py.

### client_server.py

The server handles requests to a Faiss index. The client calls the remote index.
This is mainly to shard datasets over several machines, see [Distributed index](https://github.com/facebookresearch/faiss/wiki/Indexes-that-do-not-fit-in-RAM#distributed-index).
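A minimal usage sketch (illustrative, not from the original docs; the host names, port, and index file are placeholders):

```python
import faiss
from faiss.contrib.client_server import run_index_server, ClientIndex

# on each server machine, load the local shard and serve it (blocking call):
# index = faiss.read_index("shard0.index")
# run_index_server(index, 12010)

# on the client machine, connect to all shards and search them in parallel:
# client = ClientIndex([("server0", 12010), ("server1", 12010)])
# client.set_nprobe(16)
# D, I = client.search(xq, 10)   # xq is an (nq, d) float32 numpy array
```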
### ondisk.py

Encloses the main logic to merge indexes into an on-disk index.
See [On-disk storage](https://github.com/facebookresearch/faiss/wiki/Indexes-that-do-not-fit-in-RAM#on-disk-storage).

### exhaustive_search.py

Computes the ground-truth search results for a dataset that possibly does not fit in RAM. Uses GPU if available.
Tested in `tests/test_contrib.TestComputeGT`.

### torch_utils.py

Interoperability functions for pytorch and Faiss: importing this will allow pytorch Tensors (CPU or GPU) to be used as arguments to Faiss indexes and other functions. Torch GPU tensors can only be used with Faiss GPU indexes. If this is imported with a package that supports Faiss GPU, the necessary stream synchronization with the current pytorch stream will be performed automatically.

Numpy ndarrays can continue to be used in the Faiss python interface after importing this file. All arguments must be uniformly either numpy ndarrays or Torch tensors; no mixing is allowed.

Tested in `tests/test_contrib_torch.py` (CPU) and `gpu/test/test_contrib_torch_gpu.py` (GPU).
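For illustration (not part of the original README), a minimal sketch of the interop described above, assuming pytorch is installed:

```python
import faiss
import faiss.contrib.torch_utils  # patches Faiss functions to accept torch tensors
import torch

index = faiss.IndexFlatL2(64)
xb = torch.rand(1000, 64)                    # CPU float32 tensor
index.add(xb)                                # used directly, no numpy conversion
D, I = index.search(torch.rand(5, 64), 10)   # D, I come back as torch tensors
```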
### inspect_tools.py

Functions to inspect C++ objects wrapped by SWIG. Most often this just means reading fields and converting them to the proper python array.

### ivf_tools.py

A few functions to override the coarse quantizer in IVF, providing additional flexibility for assignment.

### datasets.py

(may require h5py)

Definition of how to access data for some standard datasets.

### factory_tools.py

Functions related to factory strings.

### evaluation.py

A few non-trivial evaluation functions for search results.

### clustering.py

Contains:

- a Python implementation of k-means that can be used for special datatypes (e.g. sparse matrices);
- a 2-level clustering routine and a function that can apply it to train an IndexIVF.

### big_batch_search.py

Searches IVF indexes one centroid after another. Useful for large databases that do not fit in RAM *and* a large number of queries.
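A usage sketch (illustrative, not from the original README), assuming `index` is a trained and populated `IndexIVFFlat` and `xq` an (nq, d) float32 query matrix:

```python
from faiss.contrib.big_batch_search import big_batch_search

# threaded=1 overlaps the preparation of the next bucket with the
# distance computation of the current one
D, I = big_batch_search(index, xq, k=10, method="knn_function", threaded=1)
```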
packages/leann-backend-hnsw/third_party/faiss/contrib/__init__.py (vendored, new file, 0 lines)

packages/leann-backend-hnsw/third_party/faiss/contrib/big_batch_search.py (vendored, new file, 515 lines)
@@ -0,0 +1,515 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time
import pickle
import os
import logging
from multiprocessing.pool import ThreadPool
import threading
import _thread
from queue import Queue
import traceback
import datetime

import numpy as np
import faiss

from faiss.contrib.inspect_tools import get_invlist


class BigBatchSearcher:
    """
    Object that manages all the data related to the computation
    except the actual within-bucket matching and the organization of the
    computation (parallel or not)
    """

    def __init__(
            self,
            index, xq, k,
            verbose=0,
            use_float16=False):

        # verbosity
        self.verbose = verbose
        self.tictoc = []

        self.xq = xq
        self.index = index
        self.use_float16 = use_float16
        keep_max = faiss.is_similarity_metric(index.metric_type)
        self.rh = faiss.ResultHeap(len(xq), k, keep_max=keep_max)
        self.t_accu = [0] * 6
        self.t_display = self.t0 = time.time()

    def start_t_accu(self):
        self.t_accu_t0 = time.time()

    def stop_t_accu(self, n):
        self.t_accu[n] += time.time() - self.t_accu_t0

    def tic(self, name):
        self.tictoc = (name, time.time())
        if self.verbose > 0:
            print(name, end="\r", flush=True)

    def toc(self):
        name, t0 = self.tictoc
        dt = time.time() - t0
        if self.verbose > 0:
            print(f"{name}: {dt:.3f} s")
        return dt

    def report(self, l):
        if self.verbose == 1 or (
                self.verbose == 2 and (
                    l > 1000 and time.time() < self.t_display + 1.0
                )
        ):
            return
        t = time.time() - self.t0
        print(
            f"[{t:.1f} s] list {l}/{self.index.nlist} "
            f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} "
            f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} "
            f"wait in {self.t_accu[4]:.3f} "
            f"wait out {self.t_accu[5]:.3f} "
            f"eta {datetime.timedelta(seconds=t*self.index.nlist/(l+1)-t)} "
            f"mem {faiss.get_mem_usage_kb()}",
            end="\r" if self.verbose <= 2 else "\n",
            flush=True,
        )
        self.t_display = time.time()

    def coarse_quantization(self):
        self.tic("coarse quantization")
        bs = 65536
        nq = len(self.xq)
        q_assign = np.empty((nq, self.index.nprobe), dtype='int32')
        for i0 in range(0, nq, bs):
            i1 = min(nq, i0 + bs)
            q_dis_i, q_assign_i = self.index.quantizer.search(
                self.xq[i0:i1], self.index.nprobe)
            # q_dis[i0:i1] = q_dis_i
            q_assign[i0:i1] = q_assign_i
        self.toc()
        self.q_assign = q_assign

    def reorder_assign(self):
        self.tic("bucket sort")
        q_assign = self.q_assign
        q_assign += 1   # move -1 -> 0
        self.bucket_lims = faiss.matrix_bucket_sort_inplace(
            self.q_assign, nbucket=self.index.nlist + 1, nt=16)
        self.query_ids = self.q_assign.ravel()
        if self.verbose > 0:
            print(' number of -1s:', self.bucket_lims[1])
        self.bucket_lims = self.bucket_lims[1:]  # shift back to ignore -1s
        del self.q_assign   # inplace so let's forget about the old version...
        self.toc()

    def prepare_bucket(self, l):
        """ prepare the queries and database items for bucket l """
        t0 = time.time()
        index = self.index
        # prepare queries
        i0, i1 = self.bucket_lims[l], self.bucket_lims[l + 1]
        q_subset = self.query_ids[i0:i1]
        xq_l = self.xq[q_subset]
        if self.by_residual:
            xq_l = xq_l - index.quantizer.reconstruct(l)
        t1 = time.time()
        # prepare database side
        list_ids, xb_l = get_invlist(index.invlists, l)

        if self.decode_func is None:
            xb_l = xb_l.ravel()
        else:
            xb_l = self.decode_func(xb_l)

        if self.use_float16:
            xb_l = xb_l.astype('float16')
            xq_l = xq_l.astype('float16')

        t2 = time.time()
        self.t_accu[0] += t1 - t0
        self.t_accu[1] += t2 - t1
        return q_subset, xq_l, list_ids, xb_l

    def add_results_to_heap(self, q_subset, D, list_ids, I):
        """ add the bucket results to the heap structure """
        if D is None:
            return
        t0 = time.time()
        if I is None:
            I = list_ids
        else:
            I = list_ids[I]
        self.rh.add_result_subset(q_subset, D, I)
        self.t_accu[3] += time.time() - t0

    def sizes_in_checkpoint(self):
        return (self.xq.shape, self.index.nprobe, self.index.nlist)

    def write_checkpoint(self, fname, completed):
        # write to temp file then move to final file
        tmpname = fname + ".tmp"
        with open(tmpname, "wb") as f:
            pickle.dump(
                {
                    "sizes": self.sizes_in_checkpoint(),
                    "completed": completed,
                    "rh": (self.rh.D, self.rh.I),
                }, f, -1)
        os.replace(tmpname, fname)

    def read_checkpoint(self, fname):
        with open(fname, "rb") as f:
            ckp = pickle.load(f)
        assert ckp["sizes"] == self.sizes_in_checkpoint()
        self.rh.D[:] = ckp["rh"][0]
        self.rh.I[:] = ckp["rh"][1]
        return ckp["completed"]


class BlockComputer:
    """ computation within one bucket """

    def __init__(
            self,
            index,
            method="knn_function",
            pairwise_distances=faiss.pairwise_distances,
            knn=faiss.knn):

        self.index = index
        if index.__class__ == faiss.IndexIVFFlat:
            index_help = faiss.IndexFlat(index.d, index.metric_type)
            decode_func = lambda x: x.view("float32")
            by_residual = False
        elif index.__class__ == faiss.IndexIVFPQ:
            index_help = faiss.IndexPQ(
                index.d, index.pq.M, index.pq.nbits, index.metric_type)
            index_help.pq = index.pq
            decode_func = index_help.pq.decode
            index_help.is_trained = True
            by_residual = index.by_residual
        elif index.__class__ == faiss.IndexIVFScalarQuantizer:
            index_help = faiss.IndexScalarQuantizer(
                index.d, index.sq.qtype, index.metric_type)
            index_help.sq = index.sq
            decode_func = index_help.sq.decode
            index_help.is_trained = True
            by_residual = index.by_residual
        else:
            raise RuntimeError(f"index type {index.__class__} not supported")
        self.index_help = index_help
        self.decode_func = None if method == "index" else decode_func
        self.by_residual = by_residual
        self.method = method
        self.pairwise_distances = pairwise_distances
        self.knn = knn

    def block_search(self, xq_l, xb_l, list_ids, k, **extra_args):
        metric_type = self.index.metric_type
        if xq_l.size == 0 or xb_l.size == 0:
            D = I = None
        elif self.method == "index":
            faiss.copy_array_to_vector(xb_l, self.index_help.codes)
            self.index_help.ntotal = len(list_ids)
            D, I = self.index_help.search(xq_l, k)
        elif self.method == "pairwise_distances":
            # TODO implement blockwise to avoid mem blowup
            D = self.pairwise_distances(xq_l, xb_l, metric=metric_type)
            I = None
        elif self.method == "knn_function":
            D, I = self.knn(xq_l, xb_l, k, metric=metric_type, **extra_args)

        return D, I


def big_batch_search(
        index, xq, k,
        method="knn_function",
        pairwise_distances=faiss.pairwise_distances,
        knn=faiss.knn,
        verbose=0,
        threaded=0,
        use_float16=False,
        prefetch_threads=1,
        computation_threads=1,
        q_assign=None,
        checkpoint=None,
        checkpoint_freq=7200,
        start_list=0,
        end_list=None,
        crash_at=-1
        ):
    """
    Search queries xq in the IVF index, with a search function that collects
    batches of query vectors per inverted list. This can be faster than a
    regular search with the index.
    Supports IVFFlat, IVFPQ and IVFScalarQuantizer.

    Supports three computation methods:
    method = "index":
        build a flat index and populate it separately for each inverted list
    method = "pairwise_distances":
        decompress codes and compute all pairwise distances for the queries
        and index and add result to heap
    method = "knn_function":
        decompress codes and compute knn results for the queries

    threaded=0: sequential execution
    threaded=1: prefetch next bucket while computing the current one
    threaded=2: prefetch prefetch_threads buckets at a time.

    computation_threads>1: the knn function will get an additional thread_id
    argument that tells which worker should handle this.

    In threaded mode, the computation is tiled with the bucket preparation and
    the writeback of results (useful to maximize GPU utilization).

    use_float16: convert all matrices to float16 (faster for GPU gemm)

    q_assign: override coarse assignment, should be a matrix of size nq * nprobe

    checkpointing (only for threaded > 1):
    checkpoint: file where the checkpoints are stored
    checkpoint_freq: interval between checkpoints, in seconds

    start_list, end_list: process only a subset of invlists
    """
    nprobe = index.nprobe

    assert method in ("index", "pairwise_distances", "knn_function")

    mem_queries = xq.nbytes
    mem_assign = len(xq) * nprobe * np.dtype('int32').itemsize
    mem_res = len(xq) * k * (
        np.dtype('int64').itemsize
        + np.dtype('float32').itemsize
    )
    mem_tot = mem_queries + mem_assign + mem_res
    if verbose > 0:
        logging.info(
            f"memory: queries {mem_queries} assign {mem_assign} "
            f"result {mem_res} total {mem_tot} = {mem_tot / (1<<30):.3f} GiB"
        )

    bbs = BigBatchSearcher(
        index, xq, k,
        verbose=verbose,
        use_float16=use_float16
    )

    comp = BlockComputer(
        index,
        method=method,
        pairwise_distances=pairwise_distances,
        knn=knn
    )

    bbs.decode_func = comp.decode_func

    bbs.by_residual = comp.by_residual
    if q_assign is None:
        bbs.coarse_quantization()
    else:
        bbs.q_assign = q_assign
    bbs.reorder_assign()

    if end_list is None:
        end_list = index.nlist

    completed = set()
    if checkpoint is not None:
        assert (start_list, end_list) == (0, index.nlist)
        if os.path.exists(checkpoint):
            logging.info(f"recovering checkpoint: {checkpoint}")
            completed = bbs.read_checkpoint(checkpoint)
            logging.info(f" already completed: {len(completed)}")
        else:
            logging.info("no checkpoint: starting from scratch")

    if threaded == 0:
        # simple sequential version

        for l in range(start_list, end_list):
            bbs.report(l)
            q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(l)
            t0i = time.time()
            D, I = comp.block_search(xq_l, xb_l, list_ids, k)
            bbs.t_accu[2] += time.time() - t0i
            bbs.add_results_to_heap(q_subset, D, list_ids, I)

    elif threaded == 1:

        # parallel version with granularity 1

        def add_results_and_prefetch(to_add, l):
            """ perform the addition for the previous bucket and
            prefetch the next (if applicable) """
            if to_add is not None:
                bbs.add_results_to_heap(*to_add)
            if l < index.nlist:
                return bbs.prepare_bucket(l)

        prefetched_bucket = bbs.prepare_bucket(start_list)
        to_add = None
        pool = ThreadPool(1)

        for l in range(start_list, end_list):
            bbs.report(l)
            prefetched_bucket_a = pool.apply_async(
                add_results_and_prefetch, (to_add, l + 1))
            q_subset, xq_l, list_ids, xb_l = prefetched_bucket
            bbs.start_t_accu()
            D, I = comp.block_search(xq_l, xb_l, list_ids, k)
            bbs.stop_t_accu(2)
            to_add = q_subset, D, list_ids, I
            bbs.start_t_accu()
            prefetched_bucket = prefetched_bucket_a.get()
            bbs.stop_t_accu(4)

        bbs.add_results_to_heap(*to_add)
        pool.close()
    else:

        def task_manager_thread(
            task,
            pool_size,
            start_task,
            end_task,
            completed,
            output_queue,
            input_queue,
        ):
            try:
                with ThreadPool(pool_size) as pool:
                    res = [pool.apply_async(
                        task,
                        args=(i, output_queue, input_queue))
                        for i in range(start_task, end_task)
                        if i not in completed]
                    for r in res:
                        r.get()
                    pool.close()
                    pool.join()
                output_queue.put(None)
            except:
                traceback.print_exc()
                _thread.interrupt_main()
                raise

        def task_manager(*args):
            task_manager = threading.Thread(
                target=task_manager_thread,
                args=args,
            )
            task_manager.daemon = True
            task_manager.start()
            return task_manager

        def prepare_task(task_id, output_queue, input_queue=None):
            try:
                logging.info(f"Prepare start: {task_id}")
                q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(task_id)
                output_queue.put((task_id, q_subset, xq_l, list_ids, xb_l))
                logging.info(f"Prepare end: {task_id}")
            except:
                traceback.print_exc()
                _thread.interrupt_main()
                raise

        def compute_task(task_id, output_queue, input_queue):
            try:
                logging.info(f"Compute start: {task_id}")
                t_wait_out = 0
                while True:
                    t0 = time.time()
                    logging.info(f'Compute input: task {task_id}')
                    input_value = input_queue.get()
                    t_wait_in = time.time() - t0
                    if input_value is None:
                        # signal for other compute tasks
                        input_queue.put(None)
                        break
                    centroid, q_subset, xq_l, list_ids, xb_l = input_value
                    logging.info(f'Compute work: task {task_id}, centroid {centroid}')
                    t0 = time.time()
                    if computation_threads > 1:
                        D, I = comp.block_search(
                            xq_l, xb_l, list_ids, k, thread_id=task_id
                        )
                    else:
                        D, I = comp.block_search(xq_l, xb_l, list_ids, k)
                    t_compute = time.time() - t0
                    logging.info(f'Compute output: task {task_id}, centroid {centroid}')
                    t0 = time.time()
                    output_queue.put(
                        (centroid, t_wait_in, t_wait_out, t_compute, q_subset, D, list_ids, I)
                    )
                    t_wait_out = time.time() - t0
                logging.info(f"Compute end: {task_id}")
            except:
                traceback.print_exc()
                _thread.interrupt_main()
                raise

        prepare_to_compute_queue = Queue(2)
        compute_to_main_queue = Queue(2)
        compute_task_manager = task_manager(
            compute_task,
            computation_threads,
            0,
            computation_threads,
            set(),
            compute_to_main_queue,
            prepare_to_compute_queue,
        )
        prepare_task_manager = task_manager(
            prepare_task,
            prefetch_threads,
            start_list,
            end_list,
            completed,
            prepare_to_compute_queue,
            None,
        )

        t_checkpoint = time.time()
        while True:
            logging.info("Waiting for result")
            value = compute_to_main_queue.get()
            if not value:
                break
            centroid, t_wait_in, t_wait_out, t_compute, q_subset, D, list_ids, I = value
            # to test checkpointing
            if centroid == crash_at:
                1 / 0
            bbs.t_accu[2] += t_compute
            bbs.t_accu[4] += t_wait_in
            bbs.t_accu[5] += t_wait_out
            logging.info(f"Adding to heap start: centroid {centroid}")
            bbs.add_results_to_heap(q_subset, D, list_ids, I)
            logging.info(f"Adding to heap end: centroid {centroid}")
            completed.add(centroid)
            bbs.report(centroid)
            if checkpoint is not None:
                if time.time() - t_checkpoint > checkpoint_freq:
                    logging.info("writing checkpoint")
                    bbs.write_checkpoint(checkpoint, completed)
                    t_checkpoint = time.time()

        prepare_task_manager.join()
        compute_task_manager.join()

    bbs.tic("finalize heap")
    bbs.rh.finalize()
    bbs.toc()

    return bbs.rh.D, bbs.rh.I
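# Illustrative sketch (not part of the original file): resumable search with
# checkpointing, which the docstring above restricts to threaded > 1. The
# file name "bbs.ckpt" is a placeholder.
#
#   D, I = big_batch_search(
#       index, xq, k=10,
#       threaded=2, prefetch_threads=2, computation_threads=2,
#       checkpoint="bbs.ckpt", checkpoint_freq=600,
#   )
#
# If the process dies, rerunning the same call reloads the partial result
# heap from "bbs.ckpt" and skips the inverted lists already completed.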
packages/leann-backend-hnsw/third_party/faiss/contrib/client_server.py (vendored, new executable file, 91 lines)
@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from multiprocessing.pool import ThreadPool
import faiss
from typing import List, Tuple

from . import rpc

############################################################
# Server implementation
############################################################


class SearchServer(rpc.Server):
    """ Exposes a Faiss index via RPC """

    def __init__(self, s: int, index: faiss.Index):
        rpc.Server.__init__(self, s)
        self.index = index
        self.index_ivf = faiss.extract_index_ivf(index)

    def set_nprobe(self, nprobe: int) -> None:
        """ set nprobe field """
        self.index_ivf.nprobe = nprobe

    def get_ntotal(self) -> int:
        return self.index.ntotal

    def __getattr__(self, f):
        # all other functions get forwarded to the index
        return getattr(self.index, f)


def run_index_server(index: faiss.Index, port: int, v6: bool = False):
    """ serve requests for that index forever """
    rpc.run_server(
        lambda s: SearchServer(s, index),
        port, v6=v6)


############################################################
# Client implementation
############################################################

class ClientIndex:
    """manages a set of remote sub-indexes. The sub_indexes search a
    subset of the inverted lists. Searches are merged afterwards
    """

    def __init__(self, machine_ports: List[Tuple[str, int]], v6: bool = False):
        """ connect to a series of (host, port) pairs """
        self.sub_indexes = []
        for machine, port in machine_ports:
            self.sub_indexes.append(rpc.Client(machine, port, v6))

        self.ni = len(self.sub_indexes)
        # pool of threads. Each thread manages one sub-index.
        self.pool = ThreadPool(self.ni)
        # test connection...
        self.ntotal = self.get_ntotal()
        self.verbose = False

    def set_nprobe(self, nprobe: int) -> None:
        self.pool.map(
            lambda idx: idx.set_nprobe(nprobe),
            self.sub_indexes
        )

    def set_omp_num_threads(self, nt: int) -> None:
        self.pool.map(
            lambda idx: idx.set_omp_num_threads(nt),
            self.sub_indexes
        )

    def get_ntotal(self) -> int:
        return sum(self.pool.map(
            lambda idx: idx.get_ntotal(),
            self.sub_indexes
        ))

    def search(self, x, k: int):

        rh = faiss.ResultHeap(x.shape[0], k)

        for Di, Ii in self.pool.imap(lambda idx: idx.search(x, k), self.sub_indexes):
            rh.add_result(Di, Ii)
        rh.finalize()
        return rh.D, rh.I
packages/leann-backend-hnsw/third_party/faiss/contrib/clustering.py (vendored, new file, 428 lines)
@@ -0,0 +1,428 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
This contrib module contains a few routines useful to do clustering variants.
"""

import numpy as np
import faiss
import time
from multiprocessing.pool import ThreadPool


try:
    import scipy.sparse
except ImportError:
    print("scipy not accessible, Python k-means will not work")


def print_nop(*arg, **kwargs):
    pass


def two_level_clustering(xt, nc1, nc2, rebalance=True, clustering_niter=25, **args):
    """
    perform 2-level clustering on a training set xt.
    nc1 and nc2 are the number of clusters at each level, the final number of
    clusters is nc2. Additional arguments are passed to the Kmeans object.

    Rebalance allocates the number of sub-clusters depending on the number of
    first-level assignments.
    """
    d = xt.shape[1]

    verbose = args.get("verbose", False)

    log = print if verbose else print_nop

    log(f"2-level clustering of {xt.shape} nb 1st level clusters = {nc1} total {nc2}")
    log("perform coarse training")

    km = faiss.Kmeans(
        d, nc1, niter=clustering_niter,
        max_points_per_centroid=2000,
        **args
    )
    km.train(xt)

    iteration_stats = [km.iteration_stats]
    log()

    # coarse centroids
    centroids1 = km.centroids

    log("assigning the training set")
    t0 = time.time()
    _, assign1 = km.assign(xt)
    bc = np.bincount(assign1, minlength=nc1)
    log(f"done in {time.time() - t0:.2f} s. Sizes of clusters {min(bc)}-{max(bc)}")
    o = assign1.argsort()
    del km

    if not rebalance:
        # make sure the sub-clusters sum up to exactly nc2
        cc = np.arange(nc1 + 1) * nc2 // nc1
        all_nc2 = cc[1:] - cc[:-1]
    else:
        bc_sum = np.cumsum(bc)
        all_nc2 = bc_sum * nc2 // bc_sum[-1]
        all_nc2[1:] -= all_nc2[:-1]
        assert sum(all_nc2) == nc2
    log(f"nb 2nd-level centroids {min(all_nc2)}-{max(all_nc2)}")

    # train sub-clusters
    i0 = 0
    c2 = []
    t0 = time.time()
    for c1 in range(nc1):
        nc2 = int(all_nc2[c1])
        log(f"[{time.time() - t0:.2f} s] training sub-cluster {c1}/{nc1} nc2={nc2}\r", end="", flush=True)
        i1 = i0 + bc[c1]
        subset = o[i0:i1]
        assert np.all(assign1[subset] == c1)
        km = faiss.Kmeans(d, nc2, **args)
        xtsub = xt[subset]
        km.train(xtsub)
        iteration_stats.append(km.iteration_stats)
        c2.append(km.centroids)
        del km
        i0 = i1
    log(f"done in {time.time() - t0:.2f} s")
    return np.vstack(c2), iteration_stats
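# Usage sketch (illustrative, not part of the original file): cluster a
# float32 training matrix xt into 1024 centroids through 32 coarse clusters.
#
#   centroids, stats = two_level_clustering(xt, 32, 1024)
#
# train_ivf_index_with_2level below wraps this to train an IndexIVF directly.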
def train_ivf_index_with_2level(index, xt, **args):
    """
    Applies 2-level clustering to an index_ivf embedded in an index.
    """
    # handle PreTransforms
    index = faiss.downcast_index(index)
    if isinstance(index, faiss.IndexPreTransform):
        for i in range(index.chain.size()):
            vt = index.chain.at(i)
            vt.train(xt)
            xt = vt.apply(xt)
        train_ivf_index_with_2level(index.index, xt, **args)
        index.is_trained = True
        return
    assert isinstance(index, faiss.IndexIVF)
    assert index.metric_type == faiss.METRIC_L2
    # now do 2-level clustering
    nc1 = int(np.sqrt(index.nlist))
    print("REBALANCE=", args)

    centroids, _ = two_level_clustering(xt, nc1, index.nlist, **args)
    index.quantizer.train(centroids)
    index.quantizer.add(centroids)
    # finish training
    index.train(xt)


###############################################################################
# K-means implementation in Python
#
# It relies on DatasetAssign, an abstraction of the training vectors that offers
# the minimal set of operations to perform k-means clustering.
###############################################################################


class DatasetAssign:
    """Wrapper for a matrix that offers a function to assign the vectors
    to centroids. All other implementations offer the same interface"""

    def __init__(self, x):
        self.x = np.ascontiguousarray(x, dtype='float32')

    def count(self):
        return self.x.shape[0]

    def dim(self):
        return self.x.shape[1]

    def get_subset(self, indices):
        return self.x[indices]

    def perform_search(self, centroids):
        return faiss.knn(self.x, centroids, 1)

    def assign_to(self, centroids, weights=None):
        D, I = self.perform_search(centroids)

        I = I.ravel()
        D = D.ravel()
        nc, d = centroids.shape
        sum_per_centroid = np.zeros((nc, d), dtype='float32')
        if weights is None:
            np.add.at(sum_per_centroid, I, self.x)
        else:
            np.add.at(sum_per_centroid, I, weights[:, np.newaxis] * self.x)

        return I, D, sum_per_centroid


class DatasetAssignGPU(DatasetAssign):
    """ GPU version of the previous """

    def __init__(self, x, gpu_id, verbose=False):
        DatasetAssign.__init__(self, x)
        index = faiss.IndexFlatL2(x.shape[1])
        if gpu_id >= 0:
            self.index = faiss.index_cpu_to_gpu(
                faiss.StandardGpuResources(),
                gpu_id, index)
        else:
            # -1 -> assign to all GPUs
            self.index = faiss.index_cpu_to_all_gpus(index)

    def perform_search(self, centroids):
        self.index.reset()
        self.index.add(centroids)
        return self.index.search(self.x, 1)


def sparse_assign_to_dense(xq, xb, xq_norms=None, xb_norms=None):
    """ assignment function for the case where xq is sparse and xb is dense.
    Uses a matrix multiplication. The squared norms can be provided if
    available.
    """
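    # The squared L2 distance decomposes as
    #     d2(q, b) = ||q||^2 - 2 <q, b> + ||b||^2
    # so the only query-vs-base term is the sparse-dense product xq @ xb.T;
    # the ||q||^2 term does not affect the argmin and is only added to D at
    # the end.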
    nq = xq.shape[0]
    nb = xb.shape[0]
    if xb_norms is None:
        xb_norms = (xb ** 2).sum(1)
    if xq_norms is None:
        xq_norms = np.array(xq.power(2).sum(1))
    d2 = xb_norms - 2 * xq @ xb.T
    I = d2.argmin(axis=1)
    D = d2.ravel()[I + np.arange(nq) * nb] + xq_norms.ravel()
    return D, I


def sparse_assign_to_dense_blocks(
        xq, xb, xq_norms=None, xb_norms=None, qbs=16384, bbs=16384, nt=None):
    """
    decomposes the sparse_assign_to_dense function into blocks to avoid a
    possible memory blow up. Can be run in multithreaded mode, because scipy's
    sparse-dense matrix multiplication is single-threaded.
    """
    nq = xq.shape[0]
    nb = xb.shape[0]
    D = np.empty(nq, dtype="float32")
    D.fill(np.inf)
    I = -np.ones(nq, dtype=int)

    if xb_norms is None:
        xb_norms = (xb ** 2).sum(1)

    def handle_query_block(i):
        xq_block = xq[i : i + qbs]
        Iblock = I[i : i + qbs]
        Dblock = D[i : i + qbs]
        if xq_norms is None:
            xq_norms_block = np.array(xq_block.power(2).sum(1))
        else:
            xq_norms_block = xq_norms[i : i + qbs]
        for j in range(0, nb, bbs):
            Di, Ii = sparse_assign_to_dense(
                xq_block,
                xb[j : j + bbs],
                xq_norms=xq_norms_block,
                xb_norms=xb_norms[j : j + bbs],
            )
            if j == 0:
                Iblock[:] = Ii
                Dblock[:] = Di
            else:
                mask = Di < Dblock
                Iblock[mask] = Ii[mask] + j
                Dblock[mask] = Di[mask]

    if nt == 0 or nt == 1 or nq <= qbs:
        list(map(handle_query_block, range(0, nq, qbs)))
    else:
        pool = ThreadPool(nt)
        pool.map(handle_query_block, range(0, nq, qbs))

    return D, I


class DatasetAssignSparse(DatasetAssign):
    """Sparse version of DatasetAssign, for a scipy.sparse CSR matrix.
    Offers the same interface as the other implementations"""

    def __init__(self, x):
        assert x.__class__ == scipy.sparse.csr_matrix
        self.x = x
        self.squared_norms = np.array(x.power(2).sum(1))

    def get_subset(self, indices):
        return np.array(self.x[indices].todense())

    def perform_search(self, centroids):
        return sparse_assign_to_dense_blocks(
            self.x, centroids, xq_norms=self.squared_norms)

    def assign_to(self, centroids, weights=None):
        D, I = self.perform_search(centroids)

        I = I.ravel()
        D = D.ravel()
        n = self.x.shape[0]
        if weights is None:
            weights = np.ones(n, dtype='float32')
        nc = len(centroids)

        m = scipy.sparse.csc_matrix(
            (weights, I, np.arange(n + 1)),
            shape=(nc, n))
        sum_per_centroid = np.array((m * self.x).todense())

        return I, D, sum_per_centroid


def imbalance_factor(k, assign):
    assign = np.ascontiguousarray(assign, dtype='int64')
    return faiss.imbalance_factor(len(assign), k, faiss.swig_ptr(assign))


def check_if_torch(x):
    if x.__class__ == np.ndarray:
        return False
    import torch
    if isinstance(x, torch.Tensor):
        return True
    raise NotImplementedError(f"Unknown tensor type {type(x)}")


def reassign_centroids(hassign, centroids, rs=None):
    """ reassign centroids when some of them collapse """
    if rs is None:
        rs = np.random
    k, d = centroids.shape
    nsplit = 0
    is_torch = check_if_torch(centroids)

    empty_cents = np.where(hassign == 0)[0]

    if len(empty_cents) == 0:
        return 0

    if is_torch:
        import torch
        fac = torch.ones_like(centroids[0])
    else:
        fac = np.ones_like(centroids[0])
    fac[::2] += 1 / 1024.
    fac[1::2] -= 1 / 1024.

    # this is a single pass unless there are more than k/2
    # empty centroids
    while len(empty_cents) > 0:
        # choose which centroids to split (numpy)
        probas = hassign.astype('float') - 1
        probas[probas < 0] = 0
        probas /= probas.sum()
        nnz = (probas > 0).sum()

        nreplace = min(nnz, empty_cents.size)
        cjs = rs.choice(k, size=nreplace, p=probas)

        for ci, cj in zip(empty_cents[:nreplace], cjs):

            c = centroids[cj]
            centroids[ci] = c * fac
            centroids[cj] = c / fac

            hassign[ci] = hassign[cj] // 2
            hassign[cj] -= hassign[ci]
            nsplit += 1

        empty_cents = empty_cents[nreplace:]

    return nsplit


def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
           return_stats=False):
    """Pure python kmeans implementation. Follows the Faiss C++ version
    quite closely, but takes a DatasetAssign instead of a training data
    matrix. Also redo is not implemented.

    For the torch implementation, the centroids are tensors (possibly on GPU),
    but the indices remain numpy on CPU.
    """
    n, d = data.count(), data.dim()
    log = print if verbose else print_nop

    log(("Clustering %d points in %dD to %d clusters, " +
         "%d iterations seed %d") % (n, d, k, niter, seed))

    rs = np.random.RandomState(seed)
    print("preproc...")
    t0 = time.time()
    # initialization
    perm = rs.choice(n, size=k, replace=False)
    centroids = data.get_subset(perm)
    is_torch = check_if_torch(centroids)

    iteration_stats = []

    log("  done")
    t_search_tot = 0
    obj = []
    for i in range(niter):
        t0s = time.time()

        log('assigning', end='\r', flush=True)
        assign, D, sums = data.assign_to(centroids)

        log('compute centroids', end='\r', flush=True)

        t_search_tot += time.time() - t0s

        err = D.sum()
        if is_torch:
            err = err.item()
        obj.append(err)

        hassign = np.bincount(assign, minlength=k)

        fac = hassign.reshape(-1, 1).astype('float32')
        fac[fac == 0] = 1  # quiet warning
        if is_torch:
            import torch
            fac = torch.from_numpy(fac).to(sums.device)

        centroids = sums / fac

        nsplit = reassign_centroids(hassign, centroids, rs)

        s = {
            "obj": err,
            "time": (time.time() - t0),
            "time_search": t_search_tot,
            "imbalance_factor": imbalance_factor(k, assign),
            "nsplit": nsplit
        }

        log(("  Iteration %d (%.2f s, search %.2f s): "
             "objective=%g imbalance=%.3f nsplit=%d") % (
            i, s["time"], s["time_search"],
            err, s["imbalance_factor"],
            nsplit)
        )
        iteration_stats.append(s)

        if checkpoint is not None:
            log('storing centroids in', checkpoint)
            if is_torch:
                import torch
                torch.save(centroids, checkpoint)
            else:
                np.save(checkpoint, centroids)

    if return_stats:
        return centroids, iteration_stats
    else:
        return centroids
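# Usage sketch (illustrative, not part of the original file): run the pure
# Python k-means on a dense float32 matrix x, or on a scipy CSR matrix x_csr.
#
#   centroids = kmeans(256, DatasetAssign(x), niter=20)
#   centroids = kmeans(256, DatasetAssignSparse(x_csr), niter=20)
#
# DatasetAssignGPU(x, gpu_id) runs the assignment step on a GPU instead.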
packages/leann-backend-hnsw/third_party/faiss/contrib/datasets.py (vendored, new file, 387 lines)
@@ -0,0 +1,387 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import numpy as np
import faiss
import getpass


from .vecs_io import fvecs_read, ivecs_read, bvecs_mmap, fvecs_mmap
from .exhaustive_search import knn

class Dataset:
    """ Generic abstract class for a test dataset """

    def __init__(self):
        """ the constructor should set the following fields: """
        self.d = -1
        self.metric = 'L2'  # or IP
        self.nq = -1
        self.nb = -1
        self.nt = -1

    def get_queries(self):
        """ return the queries as a (nq, d) array """
        raise NotImplementedError()

    def get_train(self, maxtrain=None):
        """ return the training vectors as a (nt, d) array """
        raise NotImplementedError()

    def get_database(self):
        """ return the database vectors as a (nb, d) array """
        raise NotImplementedError()

    def database_iterator(self, bs=128, split=(1, 0)):
        """returns an iterator on database vectors.
        bs is the number of vectors per batch
        split = (nsplit, rank) means the dataset is split in nsplit
        shards and we want shard number rank
        The default implementation just iterates over the full matrix
        returned by get_database.
        """
        xb = self.get_database()
        nsplit, rank = split
        i0, i1 = self.nb * rank // nsplit, self.nb * (rank + 1) // nsplit
        for j0 in range(i0, i1, bs):
            yield xb[j0: min(j0 + bs, i1)]

    def get_groundtruth(self, k=None):
        """ return the ground truth for k-nearest neighbor search """
        raise NotImplementedError()

    def get_groundtruth_range(self, thresh=None):
        """ return the ground truth for range search """
        raise NotImplementedError()

    def __str__(self):
        return (f"dataset in dimension {self.d}, with metric {self.metric}, "
                f"size: Q {self.nq} B {self.nb} T {self.nt}")

    def check_sizes(self):
        """ runs the previous and checks the sizes of the matrices """
        assert self.get_queries().shape == (self.nq, self.d)
        if self.nt > 0:
            xt = self.get_train(maxtrain=123)
            assert xt.shape == (123, self.d), "shape=%s" % (xt.shape, )
        assert self.get_database().shape == (self.nb, self.d)
        assert self.get_groundtruth(k=13).shape == (self.nq, 13)


class SyntheticDataset(Dataset):
    """A dataset that is not completely random but still challenging to
    index
    """

    def __init__(self, d, nt, nb, nq, metric='L2', seed=1338):
        Dataset.__init__(self)
        self.d, self.nt, self.nb, self.nq = d, nt, nb, nq
        d1 = 10  # intrinsic dimension (more or less)
        n = nb + nt + nq
        rs = np.random.RandomState(seed)
        x = rs.normal(size=(n, d1))
        x = np.dot(x, rs.rand(d1, d))
        # now we have a d1-dim ellipsoid in d-dimensional space
        # higher factor (>4) -> higher frequency -> less linear
        x = x * (rs.rand(d) * 4 + 0.1)
        x = np.sin(x)
        x = x.astype('float32')
        self.metric = metric
        self.xt = x[:nt]
        self.xb = x[nt:nt + nb]
        self.xq = x[nt + nb:]

    def get_queries(self):
        return self.xq

    def get_train(self, maxtrain=None):
        maxtrain = maxtrain if maxtrain is not None else self.nt
        return self.xt[:maxtrain]

    def get_database(self):
        return self.xb

    def get_groundtruth(self, k=100):
        return knn(
            self.xq, self.xb, k,
            faiss.METRIC_L2 if self.metric == 'L2' else faiss.METRIC_INNER_PRODUCT
        )[1]


############################################################################
# The following datasets are a few standard open-source datasets
# they should be stored in a directory, and we start by guessing where
# that directory is
############################################################################

username = getpass.getuser()

for dataset_basedir in (
        '/datasets01/simsearch/041218/',
        '/mnt/vol/gfsai-flash3-east/ai-group/datasets/simsearch/',
        f'/home/{username}/simsearch/data/'):
    if os.path.exists(dataset_basedir):
        break
else:
    # users can link their data directory to `./data`
    dataset_basedir = 'data/'


def set_dataset_basedir(path):
    global dataset_basedir
    dataset_basedir = path


class DatasetSIFT1M(Dataset):
    """
    The original dataset is available at: http://corpus-texmex.irisa.fr/
    (ANN_SIFT1M)
    """

    def __init__(self):
        Dataset.__init__(self)
        self.d, self.nt, self.nb, self.nq = 128, 100000, 1000000, 10000
        self.basedir = dataset_basedir + 'sift1M/'

    def get_queries(self):
        return fvecs_read(self.basedir + "sift_query.fvecs")

    def get_train(self, maxtrain=None):
        maxtrain = maxtrain if maxtrain is not None else self.nt
        return fvecs_read(self.basedir + "sift_learn.fvecs")[:maxtrain]

    def get_database(self):
        return fvecs_read(self.basedir + "sift_base.fvecs")

    def get_groundtruth(self, k=None):
        gt = ivecs_read(self.basedir + "sift_groundtruth.ivecs")
        if k is not None:
            assert k <= 100
            gt = gt[:, :k]
        return gt


def sanitize(x):
    return np.ascontiguousarray(x, dtype='float32')


class DatasetBigANN(Dataset):
    """
    The original dataset is available at: http://corpus-texmex.irisa.fr/
    (ANN_SIFT1B)
    """

    def __init__(self, nb_M=1000):
        Dataset.__init__(self)
        assert nb_M in (1, 2, 5, 10, 20, 50, 100, 200, 500, 1000)
        self.nb_M = nb_M
        nb = nb_M * 10**6
        self.d, self.nt, self.nb, self.nq = 128, 10**8, nb, 10000
        self.basedir = dataset_basedir + 'bigann/'

    def get_queries(self):
        return sanitize(bvecs_mmap(self.basedir + 'bigann_query.bvecs')[:])

    def get_train(self, maxtrain=None):
        maxtrain = maxtrain if maxtrain is not None else self.nt
        return sanitize(bvecs_mmap(self.basedir + 'bigann_learn.bvecs')[:maxtrain])

    def get_groundtruth(self, k=None):
        gt = ivecs_read(self.basedir + 'gnd/idx_%dM.ivecs' % self.nb_M)
        if k is not None:
            assert k <= 100
            gt = gt[:, :k]
        return gt

    def get_database(self):
        assert self.nb_M < 100, "dataset too large, use iterator"
        return sanitize(bvecs_mmap(self.basedir + 'bigann_base.bvecs')[:self.nb])

    def database_iterator(self, bs=128, split=(1, 0)):
        xb = bvecs_mmap(self.basedir + 'bigann_base.bvecs')
        nsplit, rank = split
        i0, i1 = self.nb * rank // nsplit, self.nb * (rank + 1) // nsplit
        for j0 in range(i0, i1, bs):
            yield sanitize(xb[j0: min(j0 + bs, i1)])


class DatasetDeep1B(Dataset):
    """
    See
    https://github.com/facebookresearch/faiss/tree/main/benchs#getting-deep1b
    on how to get the data
    """

    def __init__(self, nb=10**9):
        Dataset.__init__(self)
        nb_to_name = {
            10**5: '100k',
            10**6: '1M',
            10**7: '10M',
            10**8: '100M',
            10**9: '1B'
        }
        assert nb in nb_to_name
        self.d, self.nt, self.nb, self.nq = 96, 358480000, nb, 10000
        self.basedir = dataset_basedir + 'deep1b/'
        self.gt_fname = "%sdeep%s_groundtruth.ivecs" % (
            self.basedir, nb_to_name[self.nb])

    def get_queries(self):
        return sanitize(fvecs_read(self.basedir + "deep1B_queries.fvecs"))

    def get_train(self, maxtrain=None):
        maxtrain = maxtrain if maxtrain is not None else self.nt
        return sanitize(fvecs_mmap(self.basedir + "learn.fvecs")[:maxtrain])

    def get_groundtruth(self, k=None):
        gt = ivecs_read(self.gt_fname)
        if k is not None:
            assert k <= 100
            gt = gt[:, :k]
        return gt

    def get_database(self):
        assert self.nb <= 10**8, "dataset too large, use iterator"
        return sanitize(fvecs_mmap(self.basedir + "base.fvecs")[:self.nb])

    def database_iterator(self, bs=128, split=(1, 0)):
        xb = fvecs_mmap(self.basedir + "base.fvecs")
        nsplit, rank = split
        i0, i1 = self.nb * rank // nsplit, self.nb * (rank + 1) // nsplit
        for j0 in range(i0, i1, bs):
            yield sanitize(xb[j0: min(j0 + bs, i1)])


class DatasetGlove(Dataset):
    """
    Data from http://ann-benchmarks.com/glove-100-angular.hdf5
    """

    def __init__(self, loc=None, download=False):
        import h5py
        assert not download, "not implemented"
        if not loc:
            loc = dataset_basedir + 'glove/glove-100-angular.hdf5'
        self.glove_h5py = h5py.File(loc, 'r')
        # IP and L2 are equivalent in this case, but it is traditionally seen as an IP dataset
        self.metric = 'IP'
        self.d, self.nt = 100, 0
        self.nb = self.glove_h5py['train'].shape[0]
        self.nq = self.glove_h5py['test'].shape[0]

    def get_queries(self):
        xq = np.array(self.glove_h5py['test'])
        faiss.normalize_L2(xq)
        return xq

    def get_database(self):
        xb = np.array(self.glove_h5py['train'])
        faiss.normalize_L2(xb)
        return xb

    def get_groundtruth(self, k=None):
        gt = self.glove_h5py['neighbors']
        if k is not None:
            assert k <= 100
            gt = gt[:, :k]
        return gt


class DatasetMusic100(Dataset):
    """
    get dataset from
    https://github.com/stanis-morozov/ip-nsw#dataset
    """

    def __init__(self):
        Dataset.__init__(self)
        self.d, self.nt, self.nb, self.nq = 100, 0, 10**6, 10000
        self.metric = 'IP'
        self.basedir = dataset_basedir + 'music-100/'

    def get_queries(self):
        xq = np.fromfile(self.basedir + 'query_music100.bin', dtype='float32')
        xq = xq.reshape(-1, 100)
        return xq

    def get_database(self):
        xb = np.fromfile(self.basedir + 'database_music100.bin', dtype='float32')
        xb = xb.reshape(-1, 100)
        return xb

    def get_groundtruth(self, k=None):
        gt = np.load(self.basedir + 'gt.npy')
        if k is not None:
            assert k <= 100
            gt = gt[:, :k]
        return gt


class DatasetGIST1M(Dataset):
    """
    The original dataset is available at: http://corpus-texmex.irisa.fr/
    (ANN_GIST1M)
    """

    def __init__(self):
        Dataset.__init__(self)
        self.d, self.nt, self.nb, self.nq = 960, 100000, 1000000, 10000
        self.basedir = dataset_basedir + 'gist1M/'

    def get_queries(self):
        return fvecs_read(self.basedir + "gist_query.fvecs")

    def get_train(self, maxtrain=None):
        maxtrain = maxtrain if maxtrain is not None else self.nt
        return fvecs_read(self.basedir + "gist_learn.fvecs")[:maxtrain]

    def get_database(self):
        return fvecs_read(self.basedir + "gist_base.fvecs")

    def get_groundtruth(self, k=None):
        gt = ivecs_read(self.basedir + "gist_groundtruth.ivecs")
        if k is not None:
            assert k <= 100
            gt = gt[:, :k]
        return gt


def dataset_from_name(dataset='deep1M', download=False):
    """ converts a string describing a dataset to a Dataset object
    Supports sift1M, bigann1M..bigann1B, deep1M..deep1B, music-100 and glove
    """

    if dataset == 'sift1M':
        return DatasetSIFT1M()

    elif dataset == 'gist1M':
        return DatasetGIST1M()

    elif dataset.startswith('bigann'):
        dbsize = 1000 if dataset == "bigann1B" else int(dataset[6:-1])
        return DatasetBigANN(nb_M=dbsize)

    elif dataset.startswith("deep"):

        szsuf = dataset[4:]
        if szsuf[-1] == 'M':
            dbsize = 10 ** 6 * int(szsuf[:-1])
        elif szsuf == '1B':
            dbsize = 10 ** 9
        elif szsuf[-1] == 'k':
            dbsize = 1000 * int(szsuf[:-1])
        else:
            assert False, "did not recognize suffix " + szsuf
        return DatasetDeep1B(nb=dbsize)

    elif dataset == "music-100":
        return DatasetMusic100()

    elif dataset == "glove":
        return DatasetGlove(download=download)

    else:
        raise RuntimeError("unknown dataset " + dataset)
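# Usage sketch (illustrative, not part of the original file); assumes the
# SIFT1M files are present under dataset_basedir + 'sift1M/'.
#
#   ds = dataset_from_name("sift1M")
#   xq = ds.get_queries()
#   gt = ds.get_groundtruth(k=100)
#
# SyntheticDataset(d=32, nt=10000, nb=100000, nq=100) provides the same
# interface without any files on disk.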
packages/leann-backend-hnsw/third_party/faiss/contrib/evaluation.py (vendored, new file, 492 lines)
@@ -0,0 +1,492 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import unittest
import time
import faiss

from multiprocessing.pool import ThreadPool

###############################################################
# Simple functions to evaluate knn results

def knn_intersection_measure(I1, I2):
    """ computes the intersection measure of two result tables
    """
    nq, rank = I1.shape
    assert I2.shape == (nq, rank)
    ninter = sum(
        np.intersect1d(I1[i], I2[i]).size
        for i in range(nq)
    )
    return ninter / I1.size

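# Usage sketch (illustrative, not part of the original file): compare the
# result ids of an approximate index against a ground-truth table.
#
#   recall_like = knn_intersection_measure(I_approx, I_gt)
#
# The value is 1.0 when both tables contain exactly the same ids per query.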
###############################################################
|
||||
# Range search results can be compared with Precision-Recall
|
||||
|
||||
def filter_range_results(lims, D, I, thresh):
|
||||
""" select a set of results """
|
||||
nq = lims.size - 1
|
||||
mask = D < thresh
|
||||
new_lims = np.zeros_like(lims)
|
||||
for i in range(nq):
|
||||
new_lims[i + 1] = new_lims[i] + mask[lims[i] : lims[i + 1]].sum()
|
||||
return new_lims, D[mask], I[mask]
|
||||
|
||||
|
||||
def range_PR(lims_ref, Iref, lims_new, Inew, mode="overall"):
|
||||
"""compute the precision and recall of range search results. The
|
||||
function does not take the distances into account. """
|
||||
|
||||
def ref_result_for(i):
|
||||
return Iref[lims_ref[i]:lims_ref[i + 1]]
|
||||
|
||||
def new_result_for(i):
|
||||
return Inew[lims_new[i]:lims_new[i + 1]]
|
||||
|
||||
nq = lims_ref.size - 1
|
||||
assert lims_new.size - 1 == nq
|
||||
|
||||
ninter = np.zeros(nq, dtype="int64")
|
||||
|
||||
def compute_PR_for(q):
|
||||
|
||||
# ground truth results for this query
|
||||
gt_ids = ref_result_for(q)
|
||||
|
||||
# results for this query
|
||||
new_ids = new_result_for(q)
|
||||
|
||||
# there are no set functions in numpy so let's do this
|
||||
inter = np.intersect1d(gt_ids, new_ids)
|
||||
|
||||
ninter[q] = len(inter)
|
||||
|
||||
# run in a thread pool, which helps in spite of the GIL
|
||||
pool = ThreadPool(20)
|
||||
pool.map(compute_PR_for, range(nq))
|
||||
|
||||
return counts_to_PR(
|
||||
lims_ref[1:] - lims_ref[:-1],
|
||||
lims_new[1:] - lims_new[:-1],
|
||||
ninter,
|
||||
mode=mode
|
||||
)
|
||||
|
||||
|
||||
def counts_to_PR(ngt, nres, ninter, mode="overall"):
|
||||
""" computes a precision-recall for a ser of queries.
|
||||
ngt = nb of GT results per query
|
||||
nres = nb of found results per query
|
||||
ninter = nb of correct results per query (smaller than nres of course)
|
||||
"""
|
||||
|
||||
if mode == "overall":
|
||||
ngt, nres, ninter = ngt.sum(), nres.sum(), ninter.sum()
|
||||
|
||||
if nres > 0:
|
||||
precision = ninter / nres
|
||||
else:
|
||||
precision = 1.0
|
||||
|
||||
if ngt > 0:
|
||||
recall = ninter / ngt
|
||||
elif nres == 0:
|
||||
recall = 1.0
|
||||
else:
|
||||
recall = 0.0
|
||||
|
||||
return precision, recall
|
||||
|
||||
elif mode == "average":
|
||||
# average precision and recall over queries
|
||||
|
||||
mask = ngt == 0
|
||||
ngt[mask] = 1
|
||||
|
||||
recalls = ninter / ngt
|
||||
recalls[mask] = (nres[mask] == 0).astype(float)
|
||||
|
||||
# avoid division by 0
|
||||
mask = nres == 0
|
||||
assert np.all(ninter[mask] == 0)
|
||||
ninter[mask] = 1
|
||||
nres[mask] = 1
|
||||
|
||||
precisions = ninter / nres
|
||||
|
||||
return precisions.mean(), recalls.mean()
|
||||
|
||||
else:
|
||||
raise AssertionError()
|
||||
|
||||
def sort_range_res_2(lims, D, I):
|
||||
""" sort 2 arrays using the first as key """
|
||||
I2 = np.empty_like(I)
|
||||
D2 = np.empty_like(D)
|
||||
nq = len(lims) - 1
|
||||
for i in range(nq):
|
||||
l0, l1 = lims[i], lims[i + 1]
|
||||
ii = I[l0:l1]
|
||||
di = D[l0:l1]
|
||||
o = di.argsort()
|
||||
I2[l0:l1] = ii[o]
|
||||
D2[l0:l1] = di[o]
|
||||
return I2, D2
|
||||
|
||||
|
||||
def sort_range_res_1(lims, I):
|
||||
I2 = np.empty_like(I)
|
||||
nq = len(lims) - 1
|
||||
for i in range(nq):
|
||||
l0, l1 = lims[i], lims[i + 1]
|
||||
I2[l0:l1] = I[l0:l1]
|
||||
I2[l0:l1].sort()
|
||||
return I2
|
||||
|
||||
|

def range_PR_multiple_thresholds(
        lims_ref, Iref,
        lims_new, Dnew, Inew,
        thresholds,
        mode="overall", do_sort="ref,new"
):
    """ compute precision-recall values for range search results
    for several thresholds on the "new" results.
    This is to plot PR curves
    """
    # ref should be sorted by ids
    if "ref" in do_sort:
        Iref = sort_range_res_1(lims_ref, Iref)

    # new should be sorted by distances
    if "new" in do_sort:
        Inew, Dnew = sort_range_res_2(lims_new, Dnew, Inew)

    def ref_result_for(i):
        return Iref[lims_ref[i]:lims_ref[i + 1]]

    def new_result_for(i):
        l0, l1 = lims_new[i], lims_new[i + 1]
        return Inew[l0:l1], Dnew[l0:l1]

    nq = lims_ref.size - 1
    assert lims_new.size - 1 == nq

    nt = len(thresholds)
    counts = np.zeros((nq, nt, 3), dtype="int64")

    def compute_PR_for(q):
        gt_ids = ref_result_for(q)
        res_ids, res_dis = new_result_for(q)

        counts[q, :, 0] = len(gt_ids)

        if res_dis.size == 0:
            # the rest remains at 0
            return

        # which offsets we are interested in
        nres = np.searchsorted(res_dis, thresholds)
        counts[q, :, 1] = nres

        if gt_ids.size == 0:
            return

        # find number of TPs at each stage in the result list
        ii = np.searchsorted(gt_ids, res_ids)
        ii[ii == len(gt_ids)] = -1
        n_ok = np.cumsum(gt_ids[ii] == res_ids)

        # focus on threshold points
        n_ok = np.hstack(([0], n_ok))
        counts[q, :, 2] = n_ok[nres]

    pool = ThreadPool(20)
    pool.map(compute_PR_for, range(nq))
    # print(counts.transpose(2, 1, 0))

    precisions = np.zeros(nt)
    recalls = np.zeros(nt)
    for t in range(nt):
        p, r = counts_to_PR(
            counts[:, t, 0], counts[:, t, 1], counts[:, t, 2],
            mode=mode
        )
        precisions[t] = p
        recalls[t] = r

    return precisions, recalls
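
# Illustrative sketch (not from the faiss sources): plot a PR curve from
# range-search results against a reference. Assumes matplotlib is available.
def _demo_range_PR_curve(lims_ref, Iref, lims_new, Dnew, Inew):
    import matplotlib.pyplot as plt
    thresholds = np.linspace(Dnew.min(), Dnew.max(), 20)
    precisions, recalls = range_PR_multiple_thresholds(
        lims_ref, Iref, lims_new, Dnew, Inew, thresholds, mode="average")
    plt.plot(recalls, precisions)
    plt.xlabel("recall")
    plt.ylabel("precision")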


###############################################################
# Functions that compare search results with a reference result.
# They are intended for use in tests

def _cluster_tables_with_tolerance(tab1, tab2, thr):
    """ for two tables, cluster them by merging values closer than thr.
    Returns the cluster ids for each table element """
    tab = np.hstack([tab1, tab2])
    tab.sort()
    n = len(tab)
    diffs = np.ones(n)
    diffs[1:] = tab[1:] - tab[:-1]
    unique_vals = tab[diffs > thr]
    idx1 = np.searchsorted(unique_vals, tab1, side='right') - 1
    idx2 = np.searchsorted(unique_vals, tab2, side='right') - 1
    return idx1, idx2


def check_ref_knn_with_draws(Dref, Iref, Dnew, Inew, rtol=1e-5):
    """ test that knn search results are identical, with possible ties.
    Raise if not. """
    np.testing.assert_allclose(Dref, Dnew, rtol=rtol)
    # here we have to be careful because of draws
    testcase = unittest.TestCase()   # because it makes nice error messages
    for i in range(len(Iref)):
        if np.all(Iref[i] == Inew[i]):  # easy case
            continue

        # otherwise collect elements per distance
        r = rtol * Dref[i].max()

        DrefC, DnewC = _cluster_tables_with_tolerance(Dref[i], Dnew[i], r)

        for dis in np.unique(DrefC):
            if dis == DrefC[-1]:
                continue
            mask = DrefC == dis
            testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))


def check_ref_range_results(Lref, Dref, Iref,
                            Lnew, Dnew, Inew):
    """ compare range search results wrt. a reference result,
    throw if it fails """
    np.testing.assert_array_equal(Lref, Lnew)
    nq = len(Lref) - 1
    for i in range(nq):
        l0, l1 = Lref[i], Lref[i + 1]
        Ii_ref = Iref[l0:l1]
        Ii_new = Inew[l0:l1]
        Di_ref = Dref[l0:l1]
        Di_new = Dnew[l0:l1]
        if np.all(Ii_ref == Ii_new):  # easy
            pass
        else:
            def sort_by_ids(I, D):
                o = I.argsort()
                return I[o], D[o]
            # sort both
            (Ii_ref, Di_ref) = sort_by_ids(Ii_ref, Di_ref)
            (Ii_new, Di_new) = sort_by_ids(Ii_new, Di_new)
        np.testing.assert_array_equal(Ii_ref, Ii_new)
        np.testing.assert_array_almost_equal(Di_ref, Di_new, decimal=5)


###############################################################
# OperatingPoints functions
# this is the Python version of the AutoTune object in C++

class OperatingPoints:
    """
    Manages a set of search parameters with associated performance and time.
    Keeps the Pareto optimal points.
    """

    def __init__(self):
        # list of (key, perf, t)
        self.operating_points = [
            # (self.do_nothing_key(), 0.0, 0.0)
        ]
        self.suboptimal_points = []

    def compare_keys(self, k1, k2):
        """ return 1 if k1 >= k2 (k1 at least as slow and as accurate),
        -1 if k2 >= k1, 0 if the keys are not comparable """
        raise NotImplementedError

    def do_nothing_key(self):
        """ parameters that say we do nothing: takes 0 time and has
        0 performance """
        raise NotImplementedError

    def is_pareto_optimal(self, perf_new, t_new):
        for _, perf, t in self.operating_points:
            if perf >= perf_new and t <= t_new:
                return False
        return True

    def predict_bounds(self, key):
        """ predicts the bound on time and performance """
        min_time = 0.0
        max_perf = 1.0
        for key2, perf, t in self.operating_points + self.suboptimal_points:
            cmp = self.compare_keys(key, key2)
            if cmp > 0:  # key >= key2, so key is at least as slow
                if t > min_time:
                    min_time = t
            if cmp < 0:  # key <= key2, so key is at most as accurate
                if perf < max_perf:
                    max_perf = perf
        return max_perf, min_time

    def should_run_experiment(self, key):
        (max_perf, min_time) = self.predict_bounds(key)
        return self.is_pareto_optimal(max_perf, min_time)

    def add_operating_point(self, key, perf, t):
        if self.is_pareto_optimal(perf, t):
            i = 0
            # maybe it shadows some other operating point completely?
            while i < len(self.operating_points):
                _, perf2, t2 = self.operating_points[i]
                if perf >= perf2 and t < t2:
                    self.suboptimal_points.append(
                        self.operating_points.pop(i))
                else:
                    i += 1
            self.operating_points.append((key, perf, t))
            return True
        else:
            self.suboptimal_points.append((key, perf, t))
            return False


class OperatingPointsWithRanges(OperatingPoints):
    """
    Set of parameters that are each picked from a discrete range of values.
    An increase of each parameter is assumed to make the operation slower
    and more accurate.
    A key = int array of indices in the ordered set of parameters.
    """

    def __init__(self):
        OperatingPoints.__init__(self)
        # list of (name, values)
        self.ranges = []

    def add_range(self, name, values):
        self.ranges.append((name, values))

    def compare_keys(self, k1, k2):
        if np.all(k1 >= k2):
            return 1
        if np.all(k2 >= k1):
            return -1
        return 0

    def do_nothing_key(self):
        return np.zeros(len(self.ranges), dtype=int)

    def num_experiments(self):
        return int(np.prod([len(values) for name, values in self.ranges]))

    def sample_experiments(self, n_autotune, rs=None):
        """ sample a set of experiments of max size n_autotune
        (run all experiments in random order if n_autotune is 0)
        """
        assert n_autotune == 0 or n_autotune >= 2
        totex = self.num_experiments()
        if rs is None:
            rs = np.random.RandomState(123)
        if n_autotune == 0 or totex < n_autotune:
            experiments = rs.permutation(totex - 2)
        else:
            experiments = rs.choice(
                totex - 2, size=n_autotune - 2, replace=False)

        experiments = [0, totex - 1] + [int(cno) + 1 for cno in experiments]
        return experiments

    def cno_to_key(self, cno):
        """Convert a sequential experiment number to a key"""
        k = np.zeros(len(self.ranges), dtype=int)
        for i, (name, values) in enumerate(self.ranges):
            k[i] = cno % len(values)
            cno //= len(values)
        assert cno == 0
        return k

    def get_parameters(self, k):
        """Convert a key to a dictionary with parameter values"""
        return {
            name: values[k[i]]
            for i, (name, values) in enumerate(self.ranges)
        }

    def restrict_range(self, name, max_val):
        """ remove too large values from a range"""
        for name2, values in self.ranges:
            if name == name2:
                val2 = [v for v in values if v < max_val]
                values[:] = val2
                return
        raise RuntimeError(f"parameter {name} not found")
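
# Illustrative sketch (hypothetical parameter names and fake measurements):
# explore a small parameter grid and keep only the Pareto-optimal settings.
# In real use, perf and t would come from measuring recall and search time
# on an index configured with op.get_parameters(key).
def _demo_operating_points():
    op = OperatingPointsWithRanges()
    op.add_range("nprobe", [1, 4, 16, 64])
    op.add_range("efSearch", [16, 64, 256])
    rs = np.random.RandomState(0)
    for cno in op.sample_experiments(n_autotune=0, rs=rs):
        key = op.cno_to_key(cno)
        if not op.should_run_experiment(key):
            continue
        perf = sum(key) / 5.0        # fake accuracy measurement
        t = 0.1 * (1 + sum(key))     # fake timing measurement
        op.add_operating_point(key, perf, t)
    return op.operating_points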


###############################################################
# Timer object

class TimerIter:
    def __init__(self, timer):
        self.ts = []
        self.runs = timer.runs
        self.timer = timer
        if timer.nt >= 0:
            faiss.omp_set_num_threads(timer.nt)

    def __next__(self):
        timer = self.timer
        self.runs -= 1
        self.ts.append(time.time())
        total_time = self.ts[-1] - self.ts[0] if len(self.ts) >= 2 else 0
        if self.runs == -1 or total_time > timer.max_secs:
            if timer.nt >= 0:
                faiss.omp_set_num_threads(timer.remember_nt)
            ts = np.array(self.ts)
            times = ts[1:] - ts[:-1]
            if len(times) == timer.runs:
                timer.times = times[timer.warmup:]
            else:
                # if timeout, we use all the runs
                timer.times = times[:]
            raise StopIteration


class RepeatTimer:
    """
    This is yet another timer object. It is adapted to Faiss by
    taking a number of openmp threads to set on input. It should be called
    in an explicit loop as:

    timer = RepeatTimer(warmup=1, nt=1, runs=6)

    for _ in timer:
        # perform operation

    print(f"time={timer.ms():.1f} ± {timer.ms_std():.1f} ms")

    the same timer can be re-used. In that case it is reset each time it
    enters a loop. It focuses on ms-scale times because for second scale
    it's usually less relevant to repeat the operation.
    """
    def __init__(self, warmup=0, nt=-1, runs=1, max_secs=np.inf):
        assert warmup < runs
        self.warmup = warmup
        self.nt = nt
        self.runs = runs
        self.max_secs = max_secs
        self.remember_nt = faiss.omp_get_max_threads()

    def __iter__(self):
        return TimerIter(self)

    def ms(self):
        return np.mean(self.times) * 1000

    def ms_std(self):
        return np.std(self.times) * 1000 if len(self.times) > 1 else 0.0

    def nruns(self):
        """ effective number of runs (may be lower than runs - warmup due to timeout)"""
        return len(self.times)
367
packages/leann-backend-hnsw/third_party/faiss/contrib/exhaustive_search.py
vendored
Normal file
@@ -0,0 +1,367 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import faiss
import time
import numpy as np

import logging

LOG = logging.getLogger(__name__)


def knn_ground_truth(xq, db_iterator, k, metric_type=faiss.METRIC_L2, shard=False, ngpu=-1):
    """Computes the exact KNN search results for a dataset that possibly
    does not fit in RAM but for which we have an iterator that
    returns it block by block.
    """
    LOG.info("knn_ground_truth queries size %s k=%d" % (xq.shape, k))
    t0 = time.time()
    nq, d = xq.shape
    keep_max = faiss.is_similarity_metric(metric_type)
    rh = faiss.ResultHeap(nq, k, keep_max=keep_max)

    index = faiss.IndexFlat(d, metric_type)
    if ngpu == -1:
        ngpu = faiss.get_num_gpus()

    if ngpu:
        LOG.info('running on %d GPUs' % ngpu)
        co = faiss.GpuMultipleClonerOptions()
        co.shard = shard
        index = faiss.index_cpu_to_all_gpus(index, co=co, ngpu=ngpu)

    # compute ground-truth by blocks, and add to heaps
    i0 = 0
    for xbi in db_iterator:
        ni = xbi.shape[0]
        index.add(xbi)
        D, I = index.search(xq, k)
        I += i0
        rh.add_result(D, I)
        index.reset()
        i0 += ni
        LOG.info("%d db elements, %.3f s" % (i0, time.time() - t0))

    rh.finalize()
    LOG.info("GT time: %.3f s (%d vectors)" % (time.time() - t0, i0))

    return rh.D, rh.I


# knn function used to be here
knn = faiss.knn
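
# Illustrative sketch (synthetic data, not from the faiss sources): compute
# the exact ground truth for a database produced block by block by a generator.
def _demo_knn_ground_truth():
    rs = np.random.RandomState(0)
    xq = rs.rand(100, 32).astype('float32')

    def db_blocks(nblock=10, bs=1000):
        for _ in range(nblock):
            yield rs.rand(bs, 32).astype('float32')

    D, I = knn_ground_truth(xq, db_blocks(), k=10)
    return D, I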


def range_search_gpu(xq, r2, index_gpu, index_cpu, gpu_k=1024):
    """GPU does not support range search, so we emulate it with
    knn search + fallback to CPU index.

    The index_cpu can either be:
    - a CPU index that supports range search
    - a numpy table, that will be used to construct a Flat index if needed.
    - None. In that case, at most gpu_k results will be returned
    """
    nq, d = xq.shape
    is_binary_index = isinstance(index_gpu, faiss.IndexBinary)
    keep_max = faiss.is_similarity_metric(index_gpu.metric_type)
    r2 = int(r2) if is_binary_index else float(r2)
    k = min(index_gpu.ntotal, gpu_k)
    LOG.debug(
        f"GPU search {nq} queries with {k=:} {is_binary_index=:} {keep_max=:}")
    t0 = time.time()
    D, I = index_gpu.search(xq, k)
    t1 = time.time() - t0
    if is_binary_index:
        assert d * 8 < 32768  # let's compact the distance matrix
        D = D.astype('int16')
    t2 = 0
    lim_remain = None
    if index_cpu is not None:
        if not keep_max:
            mask = D[:, k - 1] < r2
        else:
            mask = D[:, k - 1] > r2
        if mask.sum() > 0:
            LOG.debug("CPU search remain %d" % mask.sum())
            t0 = time.time()
            if isinstance(index_cpu, np.ndarray):
                # then it is in fact an array that we have to wrap in a
                # flat index
                xb = index_cpu
                if is_binary_index:
                    index_cpu = faiss.IndexBinaryFlat(d * 8)
                else:
                    index_cpu = faiss.IndexFlat(d, index_gpu.metric_type)
                index_cpu.add(xb)
            lim_remain, D_remain, I_remain = index_cpu.range_search(xq[mask], r2)
            if is_binary_index:
                D_remain = D_remain.astype('int16')
            t2 = time.time() - t0
    LOG.debug("combine")
    t0 = time.time()

    CombinerRangeKNN = (
        faiss.CombinerRangeKNNint16 if is_binary_index else
        faiss.CombinerRangeKNNfloat
    )

    combiner = CombinerRangeKNN(nq, k, r2, keep_max)
    if True:
        sp = faiss.swig_ptr
        combiner.I = sp(I)
        combiner.D = sp(D)
        # combiner.set_knn_result(sp(I), sp(D))
        if lim_remain is not None:
            combiner.mask = sp(mask)
            combiner.D_remain = sp(D_remain)
            combiner.lim_remain = sp(lim_remain.view("int64"))
            combiner.I_remain = sp(I_remain)
            # combiner.set_range_result(sp(mask), sp(lim_remain.view("int64")), sp(D_remain), sp(I_remain))
        L_res = np.empty(nq + 1, dtype='int64')
        combiner.compute_sizes(sp(L_res))
        nres = L_res[-1]
        D_res = np.empty(nres, dtype=D.dtype)
        I_res = np.empty(nres, dtype='int64')
        combiner.write_result(sp(D_res), sp(I_res))
    else:
        D_res, I_res = [], []
        nr = 0
        for i in range(nq):
            if not mask[i]:
                if index_gpu.metric_type == faiss.METRIC_L2:
                    nv = (D[i, :] < r2).sum()
                else:
                    nv = (D[i, :] > r2).sum()
                D_res.append(D[i, :nv])
                I_res.append(I[i, :nv])
            else:
                l0, l1 = lim_remain[nr], lim_remain[nr + 1]
                D_res.append(D_remain[l0:l1])
                I_res.append(I_remain[l0:l1])
                nr += 1
        L_res = np.cumsum([0] + [len(di) for di in D_res])
        D_res = np.hstack(D_res)
        I_res = np.hstack(I_res)
    t3 = time.time() - t0
    LOG.debug(f"times {t1:.3f}s {t2:.3f}s {t3:.3f}s")
    return L_res, D_res, I_res


def range_ground_truth(xq, db_iterator, threshold, metric_type=faiss.METRIC_L2,
                       shard=False, ngpu=-1):
    """Computes the range search results for a dataset that possibly
    does not fit in RAM but for which we have an iterator that
    returns it block by block.
    """
    nq, d = xq.shape
    t0 = time.time()
    xq = np.ascontiguousarray(xq, dtype='float32')

    index = faiss.IndexFlat(d, metric_type)
    if ngpu == -1:
        ngpu = faiss.get_num_gpus()
    if ngpu:
        LOG.info('running on %d GPUs' % ngpu)
        co = faiss.GpuMultipleClonerOptions()
        co.shard = shard
        index_gpu = faiss.index_cpu_to_all_gpus(index, co=co, ngpu=ngpu)

    # compute ground-truth by blocks
    i0 = 0
    D = [[] for _i in range(nq)]
    I = [[] for _i in range(nq)]
    for xbi in db_iterator:
        ni = xbi.shape[0]
        if ngpu > 0:
            index_gpu.add(xbi)
            lims_i, Di, Ii = range_search_gpu(xq, threshold, index_gpu, xbi)
            index_gpu.reset()
        else:
            index.add(xbi)
            lims_i, Di, Ii = index.range_search(xq, threshold)
            index.reset()
        Ii += i0
        for j in range(nq):
            l0, l1 = lims_i[j], lims_i[j + 1]
            if l1 > l0:
                D[j].append(Di[l0:l1])
                I[j].append(Ii[l0:l1])
        i0 += ni
        LOG.info("%d db elements, %.3f s" % (i0, time.time() - t0))

    empty_I = np.zeros(0, dtype='int64')
    empty_D = np.zeros(0, dtype='float32')
    D = [(np.hstack(i) if i != [] else empty_D) for i in D]
    I = [(np.hstack(i) if i != [] else empty_I) for i in I]
    sizes = [len(i) for i in I]
    assert len(sizes) == nq
    lims = np.zeros(nq + 1, dtype="uint64")
    lims[1:] = np.cumsum(sizes)
    return lims, np.hstack(D), np.hstack(I)


def threshold_radius_nres(nres, dis, ids, thresh, keep_max=False):
    """ restrict results (given as per-query counts) to those below
    (or above, if keep_max) the threshold """
    if keep_max:
        mask = dis > thresh
    else:
        mask = dis < thresh
    new_nres = np.zeros_like(nres)
    o = 0
    for i, nr in enumerate(nres):
        nr = int(nr)   # avoid issues with int64 + uint64
        new_nres[i] = mask[o:o + nr].sum()
        o += nr
    return new_nres, dis[mask], ids[mask]


def threshold_radius(lims, dis, ids, thresh, keep_max=False):
    """ restrict range-search results to those below a given radius """
    if keep_max:
        mask = dis > thresh
    else:
        mask = dis < thresh
    new_lims = np.zeros_like(lims)
    n = len(lims) - 1
    for i in range(n):
        l0, l1 = lims[i], lims[i + 1]
        new_lims[i + 1] = new_lims[i] + mask[l0:l1].sum()
    return new_lims, dis[mask], ids[mask]


def apply_maxres(res_batches, target_nres, keep_max=False):
    """find radius that reduces number of results to target_nres, and
    applies it in-place to the result batches used in
    range_search_max_results"""
    alldis = np.hstack([dis for _, dis, _ in res_batches])
    assert len(alldis) > target_nres
    if keep_max:
        alldis.partition(len(alldis) - target_nres - 1)
        radius = alldis[-1 - target_nres]
    else:
        alldis.partition(target_nres)
        radius = alldis[target_nres]

    if alldis.dtype == 'float32':
        radius = float(radius)
    else:
        radius = int(radius)
    LOG.debug('   setting radius to %s' % radius)
    totres = 0
    for i, (nres, dis, ids) in enumerate(res_batches):
        nres, dis, ids = threshold_radius_nres(
            nres, dis, ids, radius, keep_max=keep_max)
        totres += len(dis)
        res_batches[i] = nres, dis, ids
    LOG.debug('   updated previous results, new nb results %d' % totres)
    return radius, totres


def range_search_max_results(index, query_iterator, radius,
                             max_results=None, min_results=None,
                             shard=False, ngpu=0, clip_to_min=False):
    """Performs a range search with many queries (given by an iterator)
    and adjusts the threshold on-the-fly so that the total results
    table does not grow larger than max_results.

    If ngpu != 0, the function moves the index to this many GPUs to
    speed up search.
    """
    # TODO: all result manipulations are in python, should move to C++ if perf
    # critical
    is_binary_index = isinstance(index, faiss.IndexBinary)

    if min_results is None:
        assert max_results is not None
        min_results = int(0.8 * max_results)

    if max_results is None:
        assert min_results is not None
        max_results = int(min_results * 1.5)

    if ngpu == -1:
        ngpu = faiss.get_num_gpus()

    if ngpu:
        LOG.info('running on %d GPUs' % ngpu)
        co = faiss.GpuMultipleClonerOptions()
        co.shard = shard
        index_gpu = faiss.index_cpu_to_all_gpus(index, co=co, ngpu=ngpu)
    else:
        index_gpu = None

    t_start = time.time()
    t_search = t_post_process = 0
    qtot = totres = raw_totres = 0
    res_batches = []

    for xqi in query_iterator:
        t0 = time.time()
        LOG.debug(f"searching {len(xqi)} vectors")
        if index_gpu:
            lims_i, Di, Ii = range_search_gpu(xqi, radius, index_gpu, index)
        else:
            lims_i, Di, Ii = index.range_search(xqi, radius)

        nres_i = lims_i[1:] - lims_i[:-1]
        raw_totres += len(Di)
        qtot += len(xqi)

        t1 = time.time()
        if is_binary_index:
            # weird Faiss quirk that returns floats for Hamming distances
            Di = Di.astype('int16')

        totres += len(Di)
        res_batches.append((nres_i, Di, Ii))

        if max_results is not None and totres > max_results:
            LOG.info('too many results %d > %d, scaling back radius' %
                     (totres, max_results))
            radius, totres = apply_maxres(
                res_batches, min_results,
                keep_max=index.metric_type == faiss.METRIC_INNER_PRODUCT
            )
        t2 = time.time()
        t_search += t1 - t0
        t_post_process += t2 - t1
        LOG.debug('   [%.3f s] %d queries done, %d results' % (
            time.time() - t_start, qtot, totres))

    LOG.info(
        'search done in %.3f s + %.3f s, total %d results, end threshold %g' % (
            t_search, t_post_process, totres, radius)
    )

    if clip_to_min and totres > min_results:
        radius, totres = apply_maxres(
            res_batches, min_results,
            keep_max=index.metric_type == faiss.METRIC_INNER_PRODUCT
        )

    nres = np.hstack([nres_i for nres_i, dis_i, ids_i in res_batches])
    dis = np.hstack([dis_i for nres_i, dis_i, ids_i in res_batches])
    ids = np.hstack([ids_i for nres_i, dis_i, ids_i in res_batches])

    lims = np.zeros(len(nres) + 1, dtype='uint64')
    lims[1:] = np.cumsum(nres)

    return radius, lims, dis, ids


def exponential_query_iterator(xq, start_bs=32, max_bs=20000):
    """ produces batches of progressively increasing sizes. This is useful to
    adjust the search radius progressively without overflowing with
    intermediate results """
    nq = len(xq)
    bs = start_bs
    i = 0
    while i < nq:
        xqi = xq[i:i + bs]
        yield xqi
        if bs < max_bs:
            bs *= 2
        i += len(xqi)
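
# Illustrative sketch (hypothetical initial radius): run a large range search
# while capping the size of the result table, feeding queries in
# exponentially growing batches so the radius is adjusted early.
def _demo_range_search_max_results(index, xq, initial_radius):
    radius, lims, dis, ids = range_search_max_results(
        index, exponential_query_iterator(xq), initial_radius,
        max_results=10**6)
    return radius, lims, dis, ids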
149
packages/leann-backend-hnsw/third_party/faiss/contrib/factory_tools.py
vendored
Normal file
@@ -0,0 +1,149 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import faiss
import re


def get_code_size(d, indexkey):
    """ size of one vector in an index in dimension d
    constructed with factory string indexkey"""

    if indexkey == "Flat":
        return d * 4

    if indexkey.endswith(",RFlat"):
        return d * 4 + get_code_size(d, indexkey[:-len(",RFlat")])

    mo = re.match("IVF\\d+(_HNSW32)?,(.*)$", indexkey)
    if mo:
        return get_code_size(d, mo.group(2))

    mo = re.match("IVF\\d+\\(.*\\)?,(.*)$", indexkey)
    if mo:
        return get_code_size(d, mo.group(1))

    mo = re.match("IMI\\d+x2,(.*)$", indexkey)
    if mo:
        return get_code_size(d, mo.group(1))

    mo = re.match("(.*),Refine\\((.*)\\)$", indexkey)
    if mo:
        return get_code_size(d, mo.group(1)) + get_code_size(d, mo.group(2))

    mo = re.match('PQ(\\d+)x(\\d+)(fs|fsr)?$', indexkey)
    if mo:
        return (int(mo.group(1)) * int(mo.group(2)) + 7) // 8

    mo = re.match('PQ(\\d+)\\+(\\d+)$', indexkey)
    if mo:
        return (int(mo.group(1)) + int(mo.group(2)))

    mo = re.match('PQ(\\d+)$', indexkey)
    if mo:
        return int(mo.group(1))

    if indexkey == "HNSW32" or indexkey == "HNSW32,Flat":
        return d * 4 + 64 * 4  # roughly

    if indexkey == 'SQ8':
        return d
    elif indexkey == 'SQ4':
        return (d + 1) // 2
    elif indexkey == 'SQ6':
        return (d * 6 + 7) // 8
    elif indexkey == 'SQfp16':
        return d * 2
    elif indexkey == 'SQbf16':
        return d * 2

    mo = re.match('PCAR?(\\d+),(.*)$', indexkey)
    if mo:
        return get_code_size(int(mo.group(1)), mo.group(2))
    mo = re.match('OPQ\\d+_(\\d+),(.*)$', indexkey)
    if mo:
        return get_code_size(int(mo.group(1)), mo.group(2))
    mo = re.match('OPQ\\d+,(.*)$', indexkey)
    if mo:
        return get_code_size(d, mo.group(1))
    mo = re.match('RR(\\d+),(.*)$', indexkey)
    if mo:
        return get_code_size(int(mo.group(1)), mo.group(2))
    raise RuntimeError("cannot parse " + indexkey)
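
# Illustrative sketch: per-vector code sizes for a few factory strings in
# dimension d=128; the values follow directly from the rules above.
def _demo_get_code_size():
    d = 128
    assert get_code_size(d, "Flat") == 4 * d            # raw float32
    assert get_code_size(d, "PQ16") == 16               # 16 bytes per vector
    assert get_code_size(d, "IVF1024,PQ32x8") == 32     # IVF adds no code bytes
    assert get_code_size(d, "OPQ16_64,PQ16") == 16      # transforms are size-free
    assert get_code_size(d, "SQ6") == (6 * d + 7) // 8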


def get_hnsw_M(index):
    return index.hnsw.cum_nneighbor_per_level.at(1) // 2


def reverse_index_factory(index):
    """
    attempts to get the factory string the index was built with
    """
    index = faiss.downcast_index(index)
    if isinstance(index, faiss.IndexFlat):
        return "Flat"
    elif isinstance(index, faiss.IndexIVF):
        quantizer = faiss.downcast_index(index.quantizer)

        if isinstance(quantizer, faiss.IndexFlat):
            prefix = f"IVF{index.nlist}"
        elif isinstance(quantizer, faiss.MultiIndexQuantizer):
            prefix = f"IMI{quantizer.pq.M}x{quantizer.pq.nbits}"
        elif isinstance(quantizer, faiss.IndexHNSW):
            prefix = f"IVF{index.nlist}_HNSW{get_hnsw_M(quantizer)}"
        else:
            prefix = f"IVF{index.nlist}({reverse_index_factory(quantizer)})"

        if isinstance(index, faiss.IndexIVFFlat):
            return prefix + ",Flat"
        if isinstance(index, faiss.IndexIVFScalarQuantizer):
            return prefix + ",SQ8"
        if isinstance(index, faiss.IndexIVFPQ):
            return prefix + f",PQ{index.pq.M}x{index.pq.nbits}"
        if isinstance(index, faiss.IndexIVFPQFastScan):
            return prefix + f",PQ{index.pq.M}x{index.pq.nbits}fs"

    elif isinstance(index, faiss.IndexPreTransform):
        if index.chain.size() != 1:
            raise NotImplementedError()
        vt = faiss.downcast_VectorTransform(index.chain.at(0))
        if isinstance(vt, faiss.OPQMatrix):
            prefix = f"OPQ{vt.M}_{vt.d_out}"
        elif isinstance(vt, faiss.ITQTransform):
            prefix = f"ITQ{vt.itq.d_out}"
        elif isinstance(vt, faiss.PCAMatrix):
            assert vt.eigen_power == 0
            prefix = "PCA" + ("R" if vt.random_rotation else "") + str(vt.d_out)
        else:
            raise NotImplementedError()
        return f"{prefix},{reverse_index_factory(index.index)}"

    elif isinstance(index, faiss.IndexHNSW):
        return f"HNSW{get_hnsw_M(index)}"

    elif isinstance(index, faiss.IndexRefine):
        return f"{reverse_index_factory(index.base_index)},Refine({reverse_index_factory(index.refine_index)})"

    elif isinstance(index, faiss.IndexPQFastScan):
        return f"PQ{index.pq.M}x{index.pq.nbits}fs"

    elif isinstance(index, faiss.IndexPQ):
        return f"PQ{index.pq.M}x{index.pq.nbits}"

    elif isinstance(index, faiss.IndexLSH):
        return "LSH" + ("r" if index.rotate_data else "") + ("t" if index.train_thresholds else "")

    elif isinstance(index, faiss.IndexScalarQuantizer):
        sqtypes = {
            faiss.ScalarQuantizer.QT_8bit: "8",
            faiss.ScalarQuantizer.QT_4bit: "4",
            faiss.ScalarQuantizer.QT_6bit: "6",
            faiss.ScalarQuantizer.QT_fp16: "fp16",
            faiss.ScalarQuantizer.QT_bf16: "bf16",
        }
        return f"SQ{sqtypes[index.sq.qtype]}"

    raise NotImplementedError()
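
# Illustrative round-trip sketch: build an index from a factory string and
# recover the string from the index object.
def _demo_reverse_index_factory():
    index = faiss.index_factory(64, "IVF100,PQ8x8")
    assert reverse_index_factory(index) == "IVF100,PQ8x8"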
117
packages/leann-backend-hnsw/third_party/faiss/contrib/inspect_tools.py
vendored
Normal file
@@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import faiss


def get_invlist(invlists, l):
    """ returns the inverted lists content as a pair of (list_ids, list_codes).
    The codes are reshaped to a proper size
    """
    invlists = faiss.downcast_InvertedLists(invlists)
    ls = invlists.list_size(l)
    list_ids = np.zeros(ls, dtype='int64')
    ids = codes = None
    try:
        ids = invlists.get_ids(l)
        if ls > 0:
            faiss.memcpy(faiss.swig_ptr(list_ids), ids, list_ids.nbytes)
        codes = invlists.get_codes(l)
        if invlists.code_size != faiss.InvertedLists.INVALID_CODE_SIZE:
            list_codes = np.zeros((ls, invlists.code_size), dtype='uint8')
        else:
            # it's a BlockInvertedLists
            npb = invlists.n_per_block
            bs = invlists.block_size
            ls_round = (ls + npb - 1) // npb
            list_codes = np.zeros((ls_round, bs // npb, npb), dtype='uint8')
        if ls > 0:
            faiss.memcpy(faiss.swig_ptr(list_codes), codes, list_codes.nbytes)
    finally:
        if ids is not None:
            invlists.release_ids(l, ids)
        if codes is not None:
            invlists.release_codes(l, codes)
    return list_ids, list_codes
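
# Illustrative sketch: dump the sizes and contents of the first few inverted
# lists of a populated IVF index.
def _demo_get_invlist(index_ivf):
    for l in range(min(5, index_ivf.nlist)):
        list_ids, list_codes = get_invlist(index_ivf.invlists, l)
        print(f"list {l}: {len(list_ids)} entries, codes shape {list_codes.shape}")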


def get_invlist_sizes(invlists):
    """ return the array of sizes of the inverted lists """
    return np.array([
        invlists.list_size(i)
        for i in range(invlists.nlist)
    ], dtype='int64')


def print_object_fields(obj):
    """ print the values of all fields of an object known to SWIG """

    for name in obj.__class__.__swig_getmethods__:
        print(f"{name} = {getattr(obj, name)}")


def get_pq_centroids(pq):
    """ return the PQ centroids as an array """
    cen = faiss.vector_to_array(pq.centroids)
    return cen.reshape(pq.M, pq.ksub, pq.dsub)


def get_LinearTransform_matrix(pca):
    """ extract matrix + bias from the PCA object
    works for any linear transform (OPQ, random rotation, etc.)
    """
    b = faiss.vector_to_array(pca.b)
    A = faiss.vector_to_array(pca.A).reshape(pca.d_out, pca.d_in)
    return A, b


def make_LinearTransform_matrix(A, b=None):
    """ make a linear transform from a matrix and a bias term (optional)"""
    d_out, d_in = A.shape
    if b is not None:
        assert b.shape == (d_out, )
    lt = faiss.LinearTransform(d_in, d_out, b is not None)
    faiss.copy_array_to_vector(A.ravel(), lt.A)
    if b is not None:
        faiss.copy_array_to_vector(b, lt.b)
    lt.is_trained = True
    lt.set_is_orthonormal()
    return lt


def get_additive_quantizer_codebooks(aq):
    """ return the codebooks of an additive quantizer """
    codebooks = faiss.vector_to_array(aq.codebooks).reshape(-1, aq.d)
    co = faiss.vector_to_array(aq.codebook_offsets)
    return [
        codebooks[co[i]:co[i + 1]]
        for i in range(aq.M)
    ]


def get_flat_data(index):
    """ copy and return the data matrix in an IndexFlat """
    xb = faiss.vector_to_array(index.codes).view("float32")
    return xb.reshape(index.ntotal, index.d)


def get_flat_codes(index_flat):
    """ get the codes from an IndexFlatCodes as an array """
    return faiss.vector_to_array(index_flat.codes).reshape(
        index_flat.ntotal, index_flat.code_size)


def get_NSG_neighbors(nsg):
    """ get the neighbor list for the vectors stored in the NSG structure, as
    a N-by-K matrix of indices """
    graph = nsg.get_final_graph()
    neighbors = np.zeros((graph.N, graph.K), dtype='int32')
    faiss.memcpy(
        faiss.swig_ptr(neighbors),
        graph.data,
        neighbors.nbytes
    )
    return neighbors
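
# Illustrative sketch: build a random linear transform and check that the
# SWIG-patched apply() matches the numpy matrix product (an assumption of
# this sketch is the patched VectorTransform.apply method).
def _demo_make_LinearTransform_matrix():
    rs = np.random.RandomState(0)
    A = rs.randn(16, 32).astype('float32')
    b = rs.randn(16).astype('float32')
    lt = make_LinearTransform_matrix(A, b)
    x = rs.randn(5, 32).astype('float32')
    np.testing.assert_allclose(lt.apply(x), x @ A.T + b, atol=1e-4)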
148
packages/leann-backend-hnsw/third_party/faiss/contrib/ivf_tools.py
vendored
Normal file
@@ -0,0 +1,148 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import faiss

from faiss.contrib.inspect_tools import get_invlist_sizes


def add_preassigned(index_ivf, x, a, ids=None):
    """
    Add elements to an IVF index, where the assignment is already computed
    """
    n, d = x.shape
    assert a.shape == (n, )
    if isinstance(index_ivf, faiss.IndexBinaryIVF):
        d *= 8
    assert d == index_ivf.d
    if ids is not None:
        assert ids.shape == (n, )
        ids = faiss.swig_ptr(ids)
    index_ivf.add_core(
        n, faiss.swig_ptr(x), ids, faiss.swig_ptr(a)
    )


def search_preassigned(index_ivf, xq, k, list_nos, coarse_dis=None):
    """
    Perform a search in the IVF index, with predefined lists to search into.
    Supports indexes with pretransforms (as opposed to the
    IndexIVF.search_preassigned, that cannot be applied with pretransform).
    """
    if isinstance(index_ivf, faiss.IndexPreTransform):
        assert index_ivf.chain.size() == 1, "chain must have only one component"
        transform = faiss.downcast_VectorTransform(index_ivf.chain.at(0))
        xq = transform.apply(xq)
        index_ivf = faiss.downcast_index(index_ivf.index)
    n, d = xq.shape
    if isinstance(index_ivf, faiss.IndexBinaryIVF):
        d *= 8
        dis_type = "int32"
    else:
        dis_type = "float32"

    assert d == index_ivf.d
    assert list_nos.shape == (n, index_ivf.nprobe)

    # the coarse distances are used in IVFPQ with L2 distance and
    # by_residual=True; otherwise we provide dummy coarse_dis
    if coarse_dis is None:
        coarse_dis = np.zeros((n, index_ivf.nprobe), dtype=dis_type)
    else:
        assert coarse_dis.shape == (n, index_ivf.nprobe)

    return index_ivf.search_preassigned(xq, k, list_nos, coarse_dis)
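
# Illustrative sketch (plain IVF index, no pretransform): compute the coarse
# assignment with the quantizer ourselves, then add and search with the
# precomputed assignments.
def _demo_preassigned(index_ivf, xb, xq, k=10):
    _, assign = index_ivf.quantizer.search(xb, 1)
    add_preassigned(index_ivf, xb, assign.ravel())
    cd, list_nos = index_ivf.quantizer.search(xq, index_ivf.nprobe)
    return search_preassigned(index_ivf, xq, k, list_nos, coarse_dis=cd)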


def range_search_preassigned(index_ivf, x, radius, list_nos, coarse_dis=None):
    """
    Perform a range search in the IVF index, with predefined lists to
    search into
    """
    n, d = x.shape
    if isinstance(index_ivf, faiss.IndexBinaryIVF):
        d *= 8
        dis_type = "int32"
    else:
        dis_type = "float32"

    # the coarse distances are used in IVFPQ with L2 distance and
    # by_residual=True; otherwise we provide dummy coarse_dis
    if coarse_dis is None:
        coarse_dis = np.empty((n, index_ivf.nprobe), dtype=dis_type)
    else:
        assert coarse_dis.shape == (n, index_ivf.nprobe)

    assert d == index_ivf.d
    assert list_nos.shape == (n, index_ivf.nprobe)

    res = faiss.RangeSearchResult(n)
    sp = faiss.swig_ptr

    index_ivf.range_search_preassigned_c(
        n, sp(x), radius,
        sp(list_nos), sp(coarse_dis),
        res
    )
    # get pointers and copy them
    lims = faiss.rev_swig_ptr(res.lims, n + 1).copy()
    num_results = int(lims[-1])
    dist = faiss.rev_swig_ptr(res.distances, num_results).copy()
    indices = faiss.rev_swig_ptr(res.labels, num_results).copy()
    return lims, dist, indices


def replace_ivf_quantizer(index_ivf, new_quantizer):
    """ replace the IVF quantizer with the given quantizer and return the
    old quantizer"""
    if new_quantizer.ntotal == 0:
        centroids = index_ivf.quantizer.reconstruct_n()
        new_quantizer.train(centroids)
        new_quantizer.add(centroids)
    else:
        assert new_quantizer.ntotal == index_ivf.nlist

    # cleanly dealloc old quantizer
    old_own = index_ivf.own_fields
    index_ivf.own_fields = False
    old_quantizer = faiss.downcast_index(index_ivf.quantizer)
    old_quantizer.this.own(old_own)
    index_ivf.quantizer = new_quantizer

    if hasattr(index_ivf, "referenced_objects"):
        index_ivf.referenced_objects.append(new_quantizer)
    else:
        index_ivf.referenced_objects = [new_quantizer]
    return old_quantizer


def permute_invlists(index_ivf, perm):
    """ Apply some permutation to the inverted lists, and modify the quantizer
    entries accordingly.
    Perm is an array of size nlist, where old_index = perm[new_index]
    """
    nlist, = perm.shape
    assert index_ivf.nlist == nlist
    quantizer = faiss.downcast_index(index_ivf.quantizer)
    assert quantizer.ntotal == index_ivf.nlist
    perm = np.ascontiguousarray(perm, dtype='int64')

    # just make sure it's a permutation...
    bc = np.bincount(perm, minlength=nlist)
    assert np.all(bc == np.ones(nlist, dtype=int))

    # handle quantizer
    quantizer.permute_entries(perm)

    # handle inverted lists
    invlists = faiss.downcast_InvertedLists(index_ivf.invlists)
    invlists.permute_invlists(faiss.swig_ptr(perm))


def sort_invlists_by_size(index_ivf):
    """ sort the inverted lists by increasing size, adjusting the quantizer """
    invlist_sizes = get_invlist_sizes(index_ivf.invlists)
    perm = np.argsort(invlist_sizes)
    permute_invlists(index_ivf, perm)
59
packages/leann-backend-hnsw/third_party/faiss/contrib/ondisk.py
vendored
Normal file
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List
import faiss
import logging

LOG = logging.getLogger(__name__)


def merge_ondisk(
    trained_index: faiss.Index, shard_fnames: List[str], ivfdata_fname: str, shift_ids=False
) -> None:
    """Add the contents of the indexes stored in shard_fnames into the index
    trained_index. The on-disk data is stored in ivfdata_fname"""
    assert not isinstance(
        trained_index, faiss.IndexIVFPQR
    ), "IndexIVFPQR is not supported as an on disk index."
    # merge the shards into an on-disk index
    # first load the inverted lists
    ivfs = []
    for fname in shard_fnames:
        # the IO_FLAG_MMAP is to avoid actually loading the data, thus
        # the total size of the inverted lists can exceed the
        # available RAM
        LOG.info("read " + fname)
        index = faiss.read_index(fname, faiss.IO_FLAG_MMAP)
        index_ivf = faiss.extract_index_ivf(index)
        ivfs.append(index_ivf.invlists)

        # avoid that the invlists get deallocated with the index
        index_ivf.own_invlists = False

    # construct the output index
    index = trained_index
    index_ivf = faiss.extract_index_ivf(index)

    assert index.ntotal == 0, "works only on empty index"

    # prepare the output inverted lists. They will be written
    # to merged_index.ivfdata
    invlists = faiss.OnDiskInvertedLists(
        index_ivf.nlist, index_ivf.code_size, ivfdata_fname
    )

    # merge all the inverted lists
    ivf_vector = faiss.InvertedListsPtrVector()
    for ivf in ivfs:
        ivf_vector.push_back(ivf)

    LOG.info("merge %d inverted lists " % ivf_vector.size())
    ntotal = invlists.merge_from_multiple(ivf_vector.data(), ivf_vector.size(), shift_ids)

    # now replace the inverted lists in the output index
    index.ntotal = index_ivf.ntotal = ntotal
    index_ivf.replace_invlists(invlists, True)
    invlists.this.disown()
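
# Illustrative sketch (hypothetical file names): merge sharded IVF indexes
# into one index whose inverted lists live on disk.
def _demo_merge_ondisk():
    trained = faiss.read_index("trained.index")
    shards = ["shard_%d.index" % i for i in range(4)]
    merge_ondisk(trained, shards, "merged.ivfdata")
    faiss.write_index(trained, "merged.index")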
258
packages/leann-backend-hnsw/third_party/faiss/contrib/rpc.py
vendored
Executable file
@@ -0,0 +1,258 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Simplistic RPC implementation.
Exposes all functions of a Server object.

This code is for demonstration purposes only, and does not include certain
security protections. It is not meant to be run on an untrusted network or
in a production environment.
"""

import errno
import importlib
import os
import pickle
import sys
import _thread
import traceback
import socket
import logging

LOG = logging.getLogger(__name__)

# default
PORT = 12032

safe_modules = {
    'numpy',
    'numpy.core.multiarray',
}


class RestrictedUnpickler(pickle.Unpickler):

    def find_class(self, module, name):
        # Only allow safe modules.
        if module in safe_modules:
            return getattr(importlib.import_module(module), name)
        # Forbid everything else.
        raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
                                     (module, name))


class FileSock:
    " wraps a socket so that it is usable by pickle/cPickle "

    def __init__(self, sock):
        self.sock = sock
        self.nr = 0

    def write(self, buf):
        bs = 512 * 1024
        ns = 0
        while ns < len(buf):
            sent = self.sock.send(buf[ns:ns + bs])
            ns += sent

    def read(self, bs=512 * 1024):
        self.nr += 1
        b = []
        nb = 0
        while nb < bs:
            rb = self.sock.recv(bs - nb)
            if not rb:
                break
            b.append(rb)
            nb += len(rb)
        return b''.join(b)

    def readline(self):
        """may be optimized..."""
        s = bytes()
        while True:
            c = self.read(1)
            s += c
            if len(c) == 0 or chr(c[0]) == '\n':
                return s

class ClientExit(Exception):
    pass


class ServerException(Exception):
    pass


class Server:
    """
    server protocol. Methods from classes that subclass Server can be called
    transparently from a client
    """

    def __init__(self, s, logf=sys.stderr, log_prefix=''):
        self.logf = logf
        self.log_prefix = log_prefix

        # connection

        self.conn = s
        self.fs = FileSock(s)

    def log(self, s):
        self.logf.write("Server log %s: %s\n" % (self.log_prefix, s))

    def one_function(self):
        """
        Executes a single function with associated I/O.
        Protocol:
        - the arguments and results are serialized with the pickle protocol
        - client sends: (fname, args)
            fname = method name to call
            args = tuple of arguments
        - server sends result: (st, ret)
            st = None, or the exception string if one was raised during
                 execution
            ret = return value or None if st is not None
        """

        try:
            (fname, args) = RestrictedUnpickler(self.fs).load()
        except EOFError:
            raise ClientExit("read args")
        self.log("executing method %s" % fname)
        st = None
        ret = None
        try:
            f = getattr(self, fname)
        except AttributeError:
            st = AttributeError("unknown method " + fname)
            self.log("unknown method")

        if st is None:
            try:
                ret = f(*args)
            except Exception as e:
                # due to a bug (in mod_python?), ServerException cannot be
                # unpickled, so send the string and make the exception on
                # the client side
                st = "".join(traceback.format_tb(sys.exc_info()[2])) + str(e)
                self.log("exception in method")
                traceback.print_exc(50, self.logf)
                self.logf.flush()

        LOG.info("return")
        try:
            pickle.dump((st, ret), self.fs, protocol=4)
        except EOFError:
            raise ClientExit("function return")

    def exec_loop(self):
        """ main execution loop. Loops and handles exit states"""

        self.log("in exec_loop")
        try:
            while True:
                self.one_function()
        except ClientExit as e:
            self.log("ClientExit %s" % e)
        except socket.error as e:
            self.log("socket error %s" % e)
            traceback.print_exc(50, self.logf)
        except EOFError:
            self.log("EOF during communication")
            traceback.print_exc(50, self.logf)
        except BaseException:
            # unexpected
            traceback.print_exc(50, sys.stderr)
            sys.exit(1)

        LOG.info("exit server")

    def exec_loop_cleanup(self):
        pass

    ###################################################################
    # spying stuff

    def get_ps_stats(self):
        ret = ''
        f = os.popen(
            "echo ============ `hostname` uptime:; uptime;" +
            "echo ============ self:; " +
            "ps -p %d -o pid,vsize,rss,%%cpu,nlwp,psr; " % os.getpid() +
            "echo ============ run queue:;" +
            "ps ar -o user,pid,%cpu,%mem,ni,nlwp,psr,vsz,rss,cputime,command")
        for l in f:
            ret += l
        return ret

class Client:
    """
    Methods of the server object can be called transparently. Exceptions are
    re-raised.
    """
    def __init__(self, HOST, port=PORT, v6=False):
        socktype = socket.AF_INET6 if v6 else socket.AF_INET

        sock = socket.socket(socktype, socket.SOCK_STREAM)
        LOG.info("connecting to %s:%d, socket type: %s", HOST, port, socktype)
        sock.connect((HOST, port))
        self.sock = sock
        self.fs = FileSock(sock)

    def generic_fun(self, fname, args):
        pickle.dump((fname, args), self.fs, protocol=4)
        return self.get_result()

    def get_result(self):
        (st, ret) = RestrictedUnpickler(self.fs).load()
        if st is not None:
            raise ServerException(st)
        else:
            return ret

    def __getattr__(self, name):
        return lambda *x: self.generic_fun(name, x)


def run_server(new_handler, port=PORT, report_to_file=None, v6=False):

    HOST = ''  # Symbolic name meaning the local host
    socktype = socket.AF_INET6 if v6 else socket.AF_INET
    s = socket.socket(socktype, socket.SOCK_STREAM)
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

    LOG.info("bind %s:%d", HOST, port)
    s.bind((HOST, port))
    s.listen(5)

    LOG.info("accepting connections")
    if report_to_file is not None:
        LOG.info('storing host+port in %s', report_to_file)
        with open(report_to_file, 'w') as f:
            f.write('%s:%d ' % (socket.gethostname(), port))

    while True:
        try:
            conn, addr = s.accept()
        except socket.error as e:
            if e.errno == errno.EINTR:
                continue
            raise

        LOG.info('Connected to %s', addr)

        ibs = new_handler(conn)

        tid = _thread.start_new_thread(ibs.exec_loop, ())

        LOG.debug("Thread ID: %d", tid)
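
# Illustrative usage sketch (hypothetical host name):
#
# server side:
#     class MyServer(Server):
#         def add(self, a, b):
#             return a + b
#     run_server(MyServer, port=12032)
#
# client side:
#     client = Client("server.hostname", port=12032)
#     print(client.add(1, 2))   # prints 3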
6
packages/leann-backend-hnsw/third_party/faiss/contrib/torch/README.md
vendored
Normal file
@@ -0,0 +1,6 @@
# The Torch contrib

This contrib directory contains a few Pytorch routines that
are useful for similarity search. They do not necessarily depend on Faiss.

The code is designed to work with CPU and GPU tensors.
0
packages/leann-backend-hnsw/third_party/faiss/contrib/torch/__init__.py
vendored
Normal file
60
packages/leann-backend-hnsw/third_party/faiss/contrib/torch/clustering.py
vendored
Normal file
@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
This contrib module contains Pytorch code for k-means clustering
"""
import faiss
import faiss.contrib.torch_utils
import torch

# the kmeans can produce both torch and numpy centroids
from faiss.contrib.clustering import kmeans


class DatasetAssign:
    """Wrapper for a tensor that offers a function to assign the vectors
    to centroids. All other implementations offer the same interface"""

    def __init__(self, x):
        self.x = x

    def count(self):
        return self.x.shape[0]

    def dim(self):
        return self.x.shape[1]

    def get_subset(self, indices):
        return self.x[indices]

    def perform_search(self, centroids):
        return faiss.knn(self.x, centroids, 1)

    def assign_to(self, centroids, weights=None):
        D, I = self.perform_search(centroids)

        I = I.ravel()
        D = D.ravel()
        nc, d = centroids.shape

        sum_per_centroid = torch.zeros_like(centroids)
        if weights is None:
            sum_per_centroid.index_add_(0, I, self.x)
        else:
            sum_per_centroid.index_add_(0, I, self.x * weights[:, None])

        # the clustering code expects the indices as a numpy array
        return I.cpu().numpy(), D, sum_per_centroid


class DatasetAssignGPU(DatasetAssign):

    def __init__(self, res, x):
        DatasetAssign.__init__(self, x)
        self.res = res

    def perform_search(self, centroids):
        return faiss.knn_gpu(self.res, self.x, centroids, 1)
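
# Illustrative sketch (assumes a CUDA device and a faiss build with GPU
# support): run the contrib k-means on a GPU tensor via DatasetAssignGPU.
def _demo_torch_kmeans():
    x = torch.rand(10000, 64, device="cuda")
    res = faiss.StandardGpuResources()
    data = DatasetAssignGPU(res, x)
    centroids = kmeans(1024, data)
    return centroids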
96
packages/leann-backend-hnsw/third_party/faiss/contrib/torch/quantization.py
vendored
Normal file
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
This contrib module contains Pytorch code for quantization.
"""

import torch
import faiss
import math
from faiss.contrib.torch import clustering
# the kmeans can produce both torch and numpy centroids


class Quantizer:

    def __init__(self, d, code_size):
        """
        d: dimension of vectors
        code_size: nb of bytes of the code (per vector)
        """
        self.d = d
        self.code_size = code_size

    def train(self, x):
        """
        takes a n-by-d array and performs training
        """
        pass

    def encode(self, x):
        """
        takes a n-by-d float array, encodes to an n-by-code_size uint8 array
        """
        pass

    def decode(self, codes):
        """
        takes a n-by-code_size uint8 array, returns a n-by-d array
        """
        pass


class VectorQuantizer(Quantizer):

    def __init__(self, d, k):
        code_size = int(math.ceil(math.log2(k) / 8))
        Quantizer.__init__(self, d, code_size)
        self.k = k

    def train(self, x):
        pass


class ProductQuantizer(Quantizer):
    def __init__(self, d, M, nbits):
        """ M: number of subvectors, d%M == 0
        nbits: number of bits that each vector is encoded into
        """
        assert d % M == 0
        assert nbits == 8  # todo: implement other nbits values
        code_size = int(math.ceil(M * nbits / 8))
        Quantizer.__init__(self, d, code_size)
        self.M = M
        self.nbits = nbits
        self.code_size = code_size

    def train(self, x):
        nc = 2 ** self.nbits
        sd = self.d // self.M
        dev = x.device
        dtype = x.dtype
        self.codebook = torch.zeros((self.M, nc, sd), device=dev, dtype=dtype)
        for m in range(self.M):
            xsub = x[:, m * self.d // self.M: (m + 1) * self.d // self.M]
            data = clustering.DatasetAssign(xsub.contiguous())
            self.codebook[m] = clustering.kmeans(2 ** self.nbits, data)

    def encode(self, x):
        codes = torch.zeros((x.shape[0], self.code_size), dtype=torch.uint8)
        for m in range(self.M):
            xsub = x[:, m * self.d // self.M:(m + 1) * self.d // self.M]
            _, I = faiss.knn(xsub.contiguous(), self.codebook[m], 1)
            codes[:, m] = I.ravel()
        return codes

    def decode(self, codes):
        idxs = [codes[:, m].long() for m in range(self.M)]
        vectors = [self.codebook[m, idxs[m], :] for m in range(self.M)]
        stacked_vectors = torch.stack(vectors, dim=1)
        cbd = self.codebook.shape[-1]
        x_rec = stacked_vectors.reshape(-1, cbd * self.M)
        return x_rec
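
# Illustrative sketch (synthetic data): train a product quantizer and measure
# the reconstruction error of the encode/decode round trip.
def _demo_product_quantizer():
    x = torch.rand(20000, 64)
    pq = ProductQuantizer(64, M=8, nbits=8)
    pq.train(x)
    codes = pq.encode(x[:100])
    x_rec = pq.decode(codes)
    return ((x[:100] - x_rec) ** 2).sum(1).mean()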
764
packages/leann-backend-hnsw/third_party/faiss/contrib/torch_utils.py
vendored
Normal file
@@ -0,0 +1,764 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
|
||||
This is a set of function wrappers that override the default numpy versions.
|
||||
|
||||
Interoperability functions for pytorch and Faiss: Importing this will allow
|
||||
pytorch Tensors (CPU or GPU) to be used as arguments to Faiss indexes and
|
||||
other functions. Torch GPU tensors can only be used with Faiss GPU indexes.
|
||||
If this is imported with a package that supports Faiss GPU, the necessary
|
||||
stream synchronization with the current pytorch stream will be automatically
|
||||
performed.
|
||||
|
||||
Numpy ndarrays can continue to be used in the Faiss python interface after
|
||||
importing this file. All arguments must be uniformly either numpy ndarrays
|
||||
or Torch tensors; no mixing is allowed.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
import faiss
|
||||
import torch
|
||||
import contextlib
|
||||
import inspect
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||

##################################################################
# Equivalent of swig_ptr for Torch tensors
##################################################################

def swig_ptr_from_UInt8Tensor(x):
    """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
    assert x.is_contiguous()
    assert x.dtype == torch.uint8
    return faiss.cast_integer_to_uint8_ptr(
        x.untyped_storage().data_ptr() + x.storage_offset())


def swig_ptr_from_HalfTensor(x):
    """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
    assert x.is_contiguous()
    assert x.dtype == torch.float16
    # no canonical half type in C/C++
    return faiss.cast_integer_to_void_ptr(
        x.untyped_storage().data_ptr() + x.storage_offset() * 2)


def swig_ptr_from_FloatTensor(x):
    """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
    assert x.is_contiguous()
    assert x.dtype == torch.float32
    return faiss.cast_integer_to_float_ptr(
        x.untyped_storage().data_ptr() + x.storage_offset() * 4)


def swig_ptr_from_BFloat16Tensor(x):
    """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
    assert x.is_contiguous()
    assert x.dtype == torch.bfloat16
    return faiss.cast_integer_to_void_ptr(
        x.untyped_storage().data_ptr() + x.storage_offset() * 2)


def swig_ptr_from_IntTensor(x):
    """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
    assert x.is_contiguous()
    assert x.dtype == torch.int32, 'dtype=%s' % x.dtype
    return faiss.cast_integer_to_int_ptr(
        x.untyped_storage().data_ptr() + x.storage_offset() * 4)


def swig_ptr_from_IndicesTensor(x):
    """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
    assert x.is_contiguous()
    assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
    return faiss.cast_integer_to_idx_t_ptr(
        x.untyped_storage().data_ptr() + x.storage_offset() * 8)
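
# Usage note (illustration): these helpers hand raw data pointers to the SWIG
# layer, so the tensor must stay alive and contiguous for as long as the
# pointer is in use. storage_offset() counts elements, hence the per-dtype
# byte multipliers above. A minimal sketch, assuming a CPU float32 tensor:
#
#     x = torch.zeros(10, 4, dtype=torch.float32)
#     ptr = swig_ptr_from_FloatTensor(x)   # valid only while x is referenced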

##################################################################
# utilities
##################################################################

@contextlib.contextmanager
def using_stream(res, pytorch_stream=None):
    """ Creates a scoping object to make Faiss GPU use the same stream
        as pytorch, based on torch.cuda.current_stream().
        Or, a specific pytorch stream can be passed in as a second
        argument, in which case we will use that stream.
    """

    if pytorch_stream is None:
        pytorch_stream = torch.cuda.current_stream()

    # This is the cudaStream_t that we wish to use
    cuda_stream_s = faiss.cast_integer_to_cudastream_t(pytorch_stream.cuda_stream)

    # So we can revert GpuResources stream state upon exit
    prior_dev = torch.cuda.current_device()
    prior_stream = res.getDefaultStream(torch.cuda.current_device())

    res.setDefaultStream(torch.cuda.current_device(), cuda_stream_s)

    # Do the user work
    try:
        yield
    finally:
        res.setDefaultStream(prior_dev, prior_stream)
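
# Usage sketch (illustrative; assumes a GPU build of Faiss and a
# StandardGpuResources object): order Faiss GPU work on the current
# pytorch CUDA stream, e.g.
#
#     res = faiss.StandardGpuResources()
#     index = faiss.GpuIndexFlatL2(res, 64)
#     with using_stream(res):
#         index.add(xb_gpu)   # xb_gpu: a contiguous float32 CUDA tensor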

def torch_replace_method(the_class, name, replacement,
                         ignore_missing=False, ignore_no_base=False):
    try:
        orig_method = getattr(the_class, name)
    except AttributeError:
        if ignore_missing:
            return
        raise
    if orig_method.__name__ == 'torch_replacement_' + name:
        # replacement was done in parent class
        return

    # We should already have the numpy replacement methods patched
    assert ignore_no_base or (orig_method.__name__ == 'replacement_' + name)
    setattr(the_class, name + '_numpy', orig_method)
    setattr(the_class, name, replacement)
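
# After patching (illustration): the method dispatches on argument type while
# the original numpy wrapper stays reachable under a `_numpy` suffix, e.g.
#
#     index.search(x_torch, k)     # torch path (torch_replacement_search)
#     index.search_numpy(x_np, k)  # original numpy replacement, unchanged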

##################################################################
# Setup wrappers
##################################################################

def handle_torch_Index(the_class):
    def torch_replacement_add(self, x):
        if type(x) is np.ndarray:
            # forward to faiss __init__.py base method
            return self.add_numpy(x)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.add_c(n, x_ptr)
        else:
            # CPU torch
            self.add_c(n, x_ptr)

    def torch_replacement_add_with_ids(self, x, ids):
        if type(x) is np.ndarray:
            # forward to faiss __init__.py base method
            return self.add_with_ids_numpy(x, ids)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        assert type(ids) is torch.Tensor
        assert ids.shape == (n, ), 'not same number of vectors as ids'
        ids_ptr = swig_ptr_from_IndicesTensor(ids)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.add_with_ids_c(n, x_ptr, ids_ptr)
        else:
            # CPU torch
            self.add_with_ids_c(n, x_ptr, ids_ptr)

    def torch_replacement_assign(self, x, k, labels=None):
        if type(x) is np.ndarray:
            # forward to faiss __init__.py base method
            return self.assign_numpy(x, k, labels)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        if labels is None:
            labels = torch.empty(n, k, device=x.device, dtype=torch.int64)
        else:
            assert type(labels) is torch.Tensor
            assert labels.shape == (n, k)
        L_ptr = swig_ptr_from_IndicesTensor(labels)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.assign_c(n, x_ptr, L_ptr, k)
        else:
            # CPU torch
            self.assign_c(n, x_ptr, L_ptr, k)

        return labels

    def torch_replacement_train(self, x):
        if type(x) is np.ndarray:
            # forward to faiss __init__.py base method
            return self.train_numpy(x)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.train_c(n, x_ptr)
        else:
            # CPU torch
            self.train_c(n, x_ptr)
    def search_methods_common(x, k, D, I):
        n, d = x.shape
        x_ptr = swig_ptr_from_FloatTensor(x)

        if D is None:
            D = torch.empty(n, k, device=x.device, dtype=torch.float32)
        else:
            assert type(D) is torch.Tensor
            assert D.shape == (n, k)
        D_ptr = swig_ptr_from_FloatTensor(D)

        if I is None:
            I = torch.empty(n, k, device=x.device, dtype=torch.int64)
        else:
            assert type(I) is torch.Tensor
            assert I.shape == (n, k)
        I_ptr = swig_ptr_from_IndicesTensor(I)

        return x_ptr, D_ptr, I_ptr, D, I

    def torch_replacement_search(self, x, k, D=None, I=None):
        if type(x) is np.ndarray:
            # forward to faiss __init__.py base method
            return self.search_numpy(x, k, D=D, I=I)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d

        x_ptr, D_ptr, I_ptr, D, I = search_methods_common(x, k, D, I)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.search_c(n, x_ptr, k, D_ptr, I_ptr)
        else:
            # CPU torch
            self.search_c(n, x_ptr, k, D_ptr, I_ptr)

        return D, I

    def torch_replacement_search_and_reconstruct(self, x, k, D=None, I=None, R=None):
        if type(x) is np.ndarray:
            # Forward to faiss __init__.py base method
            return self.search_and_reconstruct_numpy(x, k, D=D, I=I, R=R)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d

        x_ptr, D_ptr, I_ptr, D, I = search_methods_common(x, k, D, I)

        if R is None:
            R = torch.empty(n, k, d, device=x.device, dtype=torch.float32)
        else:
            assert type(R) is torch.Tensor
            assert R.shape == (n, k, d)
        R_ptr = swig_ptr_from_FloatTensor(R)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.search_and_reconstruct_c(n, x_ptr, k, D_ptr, I_ptr, R_ptr)
        else:
            # CPU torch
            self.search_and_reconstruct_c(n, x_ptr, k, D_ptr, I_ptr, R_ptr)

        return D, I, R

    def torch_replacement_search_preassigned(self, x, k, Iq, Dq, *, D=None, I=None):
        if type(x) is np.ndarray:
            # forward to faiss __init__.py base method
            return self.search_preassigned_numpy(x, k, Iq, Dq, D=D, I=I)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d

        x_ptr, D_ptr, I_ptr, D, I = search_methods_common(x, k, D, I)

        assert Iq.shape == (n, self.nprobe)
        Iq = Iq.contiguous()
        Iq_ptr = swig_ptr_from_IndicesTensor(Iq)

        if Dq is not None:
            Dq = Dq.contiguous()
            assert Dq.shape == Iq.shape
            Dq_ptr = swig_ptr_from_FloatTensor(Dq)
        else:
            Dq_ptr = None

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.search_preassigned_c(n, x_ptr, k, Iq_ptr, Dq_ptr, D_ptr, I_ptr, False)
        else:
            # CPU torch
            self.search_preassigned_c(n, x_ptr, k, Iq_ptr, Dq_ptr, D_ptr, I_ptr, False)

        return D, I
    def torch_replacement_remove_ids(self, x):
        # Not yet implemented
        assert type(x) is not torch.Tensor, 'remove_ids not yet implemented for torch'
        return self.remove_ids_numpy(x)

    def torch_replacement_reconstruct(self, key, x=None):
        # No tensor inputs are required, but with importing this module, we
        # assume that the default should be torch tensors. If we are passed a
        # numpy array, however, assume that the user is overriding this default
        if (x is not None) and (type(x) is np.ndarray):
            # Forward to faiss __init__.py base method
            return self.reconstruct_numpy(key, x)

        # If the index is a CPU index, the default device is CPU, otherwise we
        # produce a GPU tensor
        device = torch.device('cpu')
        if hasattr(self, 'getDevice'):
            # same device as the index
            device = torch.device('cuda', self.getDevice())

        if x is None:
            x = torch.empty(self.d, device=device, dtype=torch.float32)
        else:
            assert type(x) is torch.Tensor
            assert x.shape == (self.d, )
        x_ptr = swig_ptr_from_FloatTensor(x)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.reconstruct_c(key, x_ptr)
        else:
            # CPU torch
            self.reconstruct_c(key, x_ptr)

        return x

    def torch_replacement_reconstruct_n(self, n0=0, ni=-1, x=None):
        if ni == -1:
            ni = self.ntotal

        # No tensor inputs are required, but with importing this module, we
        # assume that the default should be torch tensors. If we are passed a
        # numpy array, however, assume that the user is overriding this default
        if (x is not None) and (type(x) is np.ndarray):
            # Forward to faiss __init__.py base method
            return self.reconstruct_n_numpy(n0, ni, x)

        # If the index is a CPU index, the default device is CPU, otherwise we
        # produce a GPU tensor
        device = torch.device('cpu')
        if hasattr(self, 'getDevice'):
            # same device as the index
            device = torch.device('cuda', self.getDevice())

        if x is None:
            x = torch.empty(ni, self.d, device=device, dtype=torch.float32)
        else:
            assert type(x) is torch.Tensor
            assert x.shape == (ni, self.d)
        x_ptr = swig_ptr_from_FloatTensor(x)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.reconstruct_n_c(n0, ni, x_ptr)
        else:
            # CPU torch
            self.reconstruct_n_c(n0, ni, x_ptr)

        return x
    def torch_replacement_update_vectors(self, keys, x):
        if type(keys) is np.ndarray:
            # Forward to faiss __init__.py base method
            return self.update_vectors_numpy(keys, x)

        assert type(keys) is torch.Tensor
        (n, ) = keys.shape
        keys_ptr = swig_ptr_from_IndicesTensor(keys)

        assert type(x) is torch.Tensor
        assert x.shape == (n, self.d)
        x_ptr = swig_ptr_from_FloatTensor(x)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.update_vectors_c(n, keys_ptr, x_ptr)
        else:
            # CPU torch
            self.update_vectors_c(n, keys_ptr, x_ptr)

    # Until the GPU version is implemented, we do not support pre-allocated
    # output buffers
    def torch_replacement_range_search(self, x, thresh):
        if type(x) is np.ndarray:
            # Forward to faiss __init__.py base method
            return self.range_search_numpy(x, thresh)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        assert not x.is_cuda, 'Range search using GPU tensor not yet implemented'
        assert not hasattr(self, 'getDevice'), 'Range search on GPU index not yet implemented'

        res = faiss.RangeSearchResult(n)
        self.range_search_c(n, x_ptr, thresh, res)

        # get pointers and copy them
        # FIXME: no rev_swig_ptr equivalent for torch.Tensor, just convert
        # np to torch
        # NOTE: torch does not support np.uint64, just np.int64
        lims = torch.from_numpy(faiss.rev_swig_ptr(res.lims, n + 1).copy().astype('int64'))
        nd = int(lims[-1])
        D = torch.from_numpy(faiss.rev_swig_ptr(res.distances, nd).copy())
        I = torch.from_numpy(faiss.rev_swig_ptr(res.labels, nd).copy())

        return lims, D, I
    def torch_replacement_sa_encode(self, x, codes=None):
        if type(x) is np.ndarray:
            # Forward to faiss __init__.py base method
            return self.sa_encode_numpy(x, codes)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        if codes is None:
            codes = torch.empty(n, self.sa_code_size(), dtype=torch.uint8)
        else:
            assert codes.shape == (n, self.sa_code_size())
        codes_ptr = swig_ptr_from_UInt8Tensor(codes)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.sa_encode_c(n, x_ptr, codes_ptr)
        else:
            # CPU torch
            self.sa_encode_c(n, x_ptr, codes_ptr)

        return codes

    def torch_replacement_sa_decode(self, codes, x=None):
        if type(codes) is np.ndarray:
            # Forward to faiss __init__.py base method
            return self.sa_decode_numpy(codes, x)

        assert type(codes) is torch.Tensor
        n, cs = codes.shape
        assert cs == self.sa_code_size()
        codes_ptr = swig_ptr_from_UInt8Tensor(codes)

        if x is None:
            x = torch.empty(n, self.d, dtype=torch.float32)
        else:
            assert type(x) is torch.Tensor
            assert x.shape == (n, self.d)
        x_ptr = swig_ptr_from_FloatTensor(x)

        if codes.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.sa_decode_c(n, codes_ptr, x_ptr)
        else:
            # CPU torch
            self.sa_decode_c(n, codes_ptr, x_ptr)

        return x
    torch_replace_method(the_class, 'add', torch_replacement_add)
    torch_replace_method(the_class, 'add_with_ids', torch_replacement_add_with_ids)
    torch_replace_method(the_class, 'assign', torch_replacement_assign)
    torch_replace_method(the_class, 'train', torch_replacement_train)
    torch_replace_method(the_class, 'search', torch_replacement_search)
    torch_replace_method(the_class, 'remove_ids', torch_replacement_remove_ids)
    torch_replace_method(the_class, 'reconstruct', torch_replacement_reconstruct)
    torch_replace_method(the_class, 'reconstruct_n', torch_replacement_reconstruct_n)
    torch_replace_method(the_class, 'range_search', torch_replacement_range_search)
    torch_replace_method(the_class, 'update_vectors', torch_replacement_update_vectors,
                         ignore_missing=True)
    torch_replace_method(the_class, 'search_and_reconstruct',
                         torch_replacement_search_and_reconstruct, ignore_missing=True)
    torch_replace_method(the_class, 'search_preassigned',
                         torch_replacement_search_preassigned, ignore_missing=True)
    torch_replace_method(the_class, 'sa_encode', torch_replacement_sa_encode)
    torch_replace_method(the_class, 'sa_decode', torch_replacement_sa_decode)


faiss_module = sys.modules['faiss']

# Re-patch anything that inherits from faiss.Index to add the torch bindings
for symbol in dir(faiss_module):
    obj = getattr(faiss_module, symbol)
    if inspect.isclass(obj):
        the_class = obj
        if issubclass(the_class, faiss.Index):
            handle_torch_Index(the_class)
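
# End-to-end sketch (hypothetical helper for illustration; assumes a CPU
# build, and is not called anywhere). After the re-patching above, torch
# tensors flow through the standard index API and come back as torch tensors.
def _torch_index_demo(d=32, nb=1000, nq=5, k=4):
    xb = torch.rand(nb, d, dtype=torch.float32)
    xq = torch.rand(nq, d, dtype=torch.float32)
    index = faiss.IndexFlatL2(d)
    index.add(xb)                       # torch path of torch_replacement_add
    D, I = index.search(xq, k)          # torch tensors on xq's device
    lims, RD, RI = index.range_search(xq, 0.5)  # query i hits: lims[i]:lims[i + 1]
    codes = index.sa_encode(xb)         # uint8 codes, round-tripped below
    xb_rec = index.sa_decode(codes)
    return D, I, lims, RD, RI, xb_rec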


# allows torch tensor usage with knn
def torch_replacement_knn(xq, xb, k, metric=faiss.METRIC_L2, metric_arg=0):
    if type(xb) is np.ndarray:
        # Forward to faiss __init__.py base method
        return faiss.knn_numpy(xq, xb, k, metric=metric, metric_arg=metric_arg)

    nb, d = xb.size()
    assert xb.is_contiguous()
    assert xb.dtype == torch.float32
    assert not xb.is_cuda, "use knn_gpu for GPU tensors"

    nq, d2 = xq.size()
    assert d2 == d
    assert xq.is_contiguous()
    assert xq.dtype == torch.float32
    assert not xq.is_cuda, "use knn_gpu for GPU tensors"

    D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
    I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
    I_ptr = swig_ptr_from_IndicesTensor(I)
    D_ptr = swig_ptr_from_FloatTensor(D)
    xb_ptr = swig_ptr_from_FloatTensor(xb)
    xq_ptr = swig_ptr_from_FloatTensor(xq)

    if metric == faiss.METRIC_L2:
        faiss.knn_L2sqr(
            xq_ptr, xb_ptr,
            d, nq, nb, k, D_ptr, I_ptr
        )
    elif metric == faiss.METRIC_INNER_PRODUCT:
        faiss.knn_inner_product(
            xq_ptr, xb_ptr,
            d, nq, nb, k, D_ptr, I_ptr
        )
    else:
        faiss.knn_extra_metrics(
            xq_ptr, xb_ptr,
            d, nq, nb, metric, metric_arg, k, D_ptr, I_ptr
        )

    return D, I


torch_replace_method(faiss_module, 'knn', torch_replacement_knn, True, True)
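
# Usage sketch (illustration): with the replacement installed, faiss.knn
# accepts CPU torch tensors directly (GPU tensors go through knn_gpu):
#
#     xb = torch.rand(1000, 16)
#     xq = torch.rand(10, 16)
#     D, I = faiss.knn(xq, xb, k=5)   # defaults to METRIC_L2
#     D, I = faiss.knn(xq, xb, k=5, metric=faiss.METRIC_INNER_PRODUCT)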


# allows torch tensor usage with bfKnn
def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRIC_L2, device=-1, use_cuvs=False):
    if type(xb) is np.ndarray:
        # Forward to faiss __init__.py base method
        return faiss.knn_gpu_numpy(res, xq, xb, k, D, I, metric, device)

    nb, d = xb.size()
    if xb.is_contiguous():
        xb_row_major = True
    elif xb.t().is_contiguous():
        xb = xb.t()
        xb_row_major = False
    else:
        raise TypeError('xb matrix should be row or column-major')

    if xb.dtype == torch.float32:
        xb_type = faiss.DistanceDataType_F32
        xb_ptr = swig_ptr_from_FloatTensor(xb)
    elif xb.dtype == torch.float16:
        xb_type = faiss.DistanceDataType_F16
        xb_ptr = swig_ptr_from_HalfTensor(xb)
    elif xb.dtype == torch.bfloat16:
        xb_type = faiss.DistanceDataType_BF16
        xb_ptr = swig_ptr_from_BFloat16Tensor(xb)
    else:
        raise TypeError('xb must be float32, float16 or bfloat16')

    nq, d2 = xq.size()
    assert d2 == d
    if xq.is_contiguous():
        xq_row_major = True
    elif xq.t().is_contiguous():
        xq = xq.t()
        xq_row_major = False
    else:
        raise TypeError('xq matrix should be row or column-major')

    if xq.dtype == torch.float32:
        xq_type = faiss.DistanceDataType_F32
        xq_ptr = swig_ptr_from_FloatTensor(xq)
    elif xq.dtype == torch.float16:
        xq_type = faiss.DistanceDataType_F16
        xq_ptr = swig_ptr_from_HalfTensor(xq)
    elif xq.dtype == torch.bfloat16:
        xq_type = faiss.DistanceDataType_BF16
        xq_ptr = swig_ptr_from_BFloat16Tensor(xq)
    else:
        raise TypeError('xq must be float32, float16 or bfloat16')

    if D is None:
        D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
    else:
        assert D.shape == (nq, k)
        # interface takes void*, we need to check this
        assert D.dtype == torch.float32

    if I is None:
        I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
    else:
        assert I.shape == (nq, k)

    if I.dtype == torch.int64:
        I_type = faiss.IndicesDataType_I64
        I_ptr = swig_ptr_from_IndicesTensor(I)
    elif I.dtype == torch.int32:
        I_type = faiss.IndicesDataType_I32
        I_ptr = swig_ptr_from_IntTensor(I)
    else:
        raise TypeError('I must be i64 or i32')

    D_ptr = swig_ptr_from_FloatTensor(D)

    args = faiss.GpuDistanceParams()
    args.metric = metric
    args.k = k
    args.dims = d
    args.vectors = xb_ptr
    args.vectorsRowMajor = xb_row_major
    args.vectorType = xb_type
    args.numVectors = nb
    args.queries = xq_ptr
    args.queriesRowMajor = xq_row_major
    args.queryType = xq_type
    args.numQueries = nq
    args.outDistances = D_ptr
    args.outIndices = I_ptr
    args.outIndicesType = I_type
    args.device = device
    args.use_cuvs = use_cuvs

    with using_stream(res):
        faiss.bfKnn(res, args)

    return D, I


torch_replace_method(faiss_module, 'knn_gpu', torch_replacement_knn_gpu, True, True)
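
# Usage sketch (illustration; requires a GPU build): brute-force k-NN through
# bfKnn with torch CUDA tensors; float16/bfloat16 inputs are also accepted:
#
#     res = faiss.StandardGpuResources()
#     xb = torch.rand(100000, 64, device='cuda')
#     xq = torch.rand(16, 64, device='cuda')
#     D, I = faiss.knn_gpu(res, xq, xb, k=10)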


# allows torch tensor usage with bfKnn for all pairwise distances
def torch_replacement_pairwise_distance_gpu(res, xq, xb, D=None, metric=faiss.METRIC_L2, device=-1):
    if type(xb) is np.ndarray:
        # Forward to faiss __init__.py base method
        return faiss.pairwise_distance_gpu_numpy(res, xq, xb, D, metric)

    nb, d = xb.size()
    if xb.is_contiguous():
        xb_row_major = True
    elif xb.t().is_contiguous():
        xb = xb.t()
        xb_row_major = False
    else:
        raise TypeError('xb matrix should be row or column-major')

    if xb.dtype == torch.float32:
        xb_type = faiss.DistanceDataType_F32
        xb_ptr = swig_ptr_from_FloatTensor(xb)
    elif xb.dtype == torch.float16:
        xb_type = faiss.DistanceDataType_F16
        xb_ptr = swig_ptr_from_HalfTensor(xb)
    else:
        raise TypeError('xb must be float32 or float16')

    nq, d2 = xq.size()
    assert d2 == d
    if xq.is_contiguous():
        xq_row_major = True
    elif xq.t().is_contiguous():
        xq = xq.t()
        xq_row_major = False
    else:
        raise TypeError('xq matrix should be row or column-major')

    if xq.dtype == torch.float32:
        xq_type = faiss.DistanceDataType_F32
        xq_ptr = swig_ptr_from_FloatTensor(xq)
    elif xq.dtype == torch.float16:
        xq_type = faiss.DistanceDataType_F16
        xq_ptr = swig_ptr_from_HalfTensor(xq)
    else:
        raise TypeError('xq must be float32 or float16')

    if D is None:
        D = torch.empty(nq, nb, device=xb.device, dtype=torch.float32)
    else:
        assert D.shape == (nq, nb)
        # interface takes void*, we need to check this
        assert D.dtype == torch.float32

    D_ptr = swig_ptr_from_FloatTensor(D)

    args = faiss.GpuDistanceParams()
    args.metric = metric
    args.k = -1  # selects all pairwise distances
    args.dims = d
    args.vectors = xb_ptr
    args.vectorsRowMajor = xb_row_major
    args.vectorType = xb_type
    args.numVectors = nb
    args.queries = xq_ptr
    args.queriesRowMajor = xq_row_major
    args.queryType = xq_type
    args.numQueries = nq
    args.outDistances = D_ptr
    args.device = device

    with using_stream(res):
        faiss.bfKnn(res, args)

    return D


torch_replace_method(faiss_module, 'pairwise_distance_gpu', torch_replacement_pairwise_distance_gpu, True, True)
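
# Usage sketch (illustration; requires a GPU build): the full nq x nb
# distance matrix, computed by bfKnn with k == -1:
#
#     res = faiss.StandardGpuResources()
#     xq = torch.rand(16, 64, device='cuda')
#     xb = torch.rand(1000, 64, device='cuda')
#     D = faiss.pairwise_distance_gpu(res, xq, xb)   # (nq, nb) float32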
60
packages/leann-backend-hnsw/third_party/faiss/contrib/vecs_io.py
vendored
Normal file
@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys
import numpy as np

"""
I/O functions in fvecs, bvecs, ivecs formats
definition of the formats here: http://corpus-texmex.irisa.fr/
"""


def ivecs_read(fname):
    a = np.fromfile(fname, dtype='int32')
    if sys.byteorder == 'big':
        a.byteswap(inplace=True)
    d = a[0]
    return a.reshape(-1, d + 1)[:, 1:].copy()


def fvecs_read(fname):
    return ivecs_read(fname).view('float32')


def ivecs_mmap(fname):
    assert sys.byteorder != 'big'
    a = np.memmap(fname, dtype='int32', mode='r')
    d = a[0]
    return a.reshape(-1, d + 1)[:, 1:]


def fvecs_mmap(fname):
    return ivecs_mmap(fname).view('float32')


def bvecs_mmap(fname):
    x = np.memmap(fname, dtype='uint8', mode='r')
    if sys.byteorder == 'big':
        da = x[:4][::-1].copy()
        d = da.view('int32')[0]
    else:
        d = x[:4].view('int32')[0]
    return x.reshape(-1, d + 4)[:, 4:]


def ivecs_write(fname, m):
    n, d = m.shape
    m1 = np.empty((n, d + 1), dtype='int32')
    m1[:, 0] = d
    m1[:, 1:] = m
    if sys.byteorder == 'big':
        m1.byteswap(inplace=True)
    m1.tofile(fname)


def fvecs_write(fname, m):
    m = m.astype('float32')
    ivecs_write(fname, m.view('int32'))
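
# Usage sketch (illustration): write/read round-trip for the fvecs format;
# each row is stored as an int32 dimension header followed by d float32s:
#
#     x = np.random.rand(100, 16).astype('float32')
#     fvecs_write('/tmp/x.fvecs', x)
#     assert np.array_equal(fvecs_read('/tmp/x.fvecs'), x)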