Initial commit

yichuan520030910320
2025-06-30 09:05:05 +00:00
commit 46f6cc100b
1231 changed files with 278432 additions and 0 deletions

View File

@@ -0,0 +1,76 @@
# The contrib modules
The contrib directory contains helper modules for Faiss for various tasks.
## Code structure
The contrib directory is compiled into the module faiss.contrib.
Note that although some of the modules may depend on additional packages (e.g. GPU Faiss, pytorch, hdf5), these are not compiled in, to avoid adding dependencies. It is the user's responsibility to provide them.
In contrib, we are progressively dropping Python 2 support.
## List of contrib modules
### rpc.py
A very simple Remote Procedure Call library, in which function parameters and results are pickled. Used by `client_server.py`.
### client_server.py
The server handles requests to a Faiss index. The client calls the remote index.
This is mainly to shard datasets over several machines, see [Distributed index](https://github.com/facebookresearch/faiss/wiki/Indexes-that-do-not-fit-in-RAM#distributed-index)
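A minimal sketch of the intended flow (hostnames, ports and the index file are placeholders):
```python
import numpy as np
from faiss.contrib.client_server import ClientIndex

# on each server machine, serve one shard of the index:
#     import faiss
#     from faiss.contrib.client_server import run_index_server
#     index = faiss.read_index("shard0.index")
#     run_index_server(index, 12010)   # blocks, serving search requests

# on the client, connect to all shards and search them transparently
client = ClientIndex([("machine1", 12010), ("machine2", 12010)])
client.set_nprobe(16)
xq = np.random.rand(100, 64).astype('float32')
D, I = client.search(xq, k=10)       # results are merged over the shards
```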
### ondisk.py
Encloses the main logic to merge indexes into an on-disk index.
See [On-disk storage](https://github.com/facebookresearch/faiss/wiki/Indexes-that-do-not-fit-in-RAM#on-disk-storage)
### exhaustive_search.py
Computes the ground-truth search results for a dataset that possibly does not fit in RAM. Uses GPU if available.
Tested in `tests/test_contrib.TestComputeGT`
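For example, combined with a `Dataset` object from `datasets.py` (a self-contained sketch on synthetic data):
```python
from faiss.contrib.datasets import SyntheticDataset
from faiss.contrib.exhaustive_search import knn_ground_truth

ds = SyntheticDataset(32, 0, 10000, 100)   # d, nt, nb, nq
# the database is streamed block by block, so it never has to fit in RAM
gt_D, gt_I = knn_ground_truth(ds.get_queries(), ds.database_iterator(bs=1024), k=10)
```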
### torch_utils.py
Interoperability functions for pytorch and Faiss: Importing this will allow pytorch Tensors (CPU or GPU) to be used as arguments to Faiss indexes and other functions. Torch GPU tensors can only be used with Faiss GPU indexes. If this is imported with a package that supports Faiss GPU, the necessary stream synchronization with the current pytorch stream will be automatically performed.
Numpy ndarrays can continue to be used in the Faiss python interface after importing this file. All arguments must be uniformly either numpy ndarrays or Torch tensors; no mixing is allowed.
Tested in `tests/test_contrib_torch.py` (CPU) and `gpu/test/test_contrib_torch_gpu.py` (GPU).
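A minimal CPU sketch:
```python
import torch
import faiss
import faiss.contrib.torch_utils  # patches the Faiss Python API to accept torch tensors

index = faiss.IndexFlatL2(64)
xb = torch.rand(1000, 64)          # float32 CPU tensor
index.add(xb)
D, I = index.search(xb[:5], 10)    # D and I are returned as torch tensors
```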
### inspect_tools.py
Functions to inspect C++ objects wrapped by SWIG. Most often this just means reading
fields and converting them to the proper python array.
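For example, `get_invlist` (also used by `big_batch_search.py` below) returns the ids and codes stored in one inverted list of an IVF index:
```python
import numpy as np
import faiss
from faiss.contrib.inspect_tools import get_invlist

xb = np.random.rand(1000, 16).astype('float32')
index = faiss.index_factory(16, "IVF32,Flat")
index.train(xb)
index.add(xb)
# ids and raw codes of inverted list number 5
list_ids, codes = get_invlist(index.invlists, 5)
```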
### ivf_tools.py
A few functions to override the coarse quantizer in IVF, providing additional flexibility for assignment.
### datasets.py
(may require h5py)
Definition of how to access data for some standard datasets.
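Typical usage, assuming the dataset files are present in the data directory:
```python
from faiss.contrib.datasets import dataset_from_name

ds = dataset_from_name("sift1M")
print(ds)                     # dimension, metric and sizes
xq = ds.get_queries()
gt = ds.get_groundtruth(k=10)
```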
### factory_tools.py
Functions related to factory strings.
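For example, `reverse_index_factory` attempts to recover a factory string from a constructed index (a sketch; the exact string returned may vary across Faiss versions):
```python
import faiss
from faiss.contrib.factory_tools import reverse_index_factory

index = faiss.index_factory(64, "IVF100,PQ8")
print(reverse_index_factory(index))   # an equivalent factory string
```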
### evaluation.py
A few non-trivial evaluation functions for search results.
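For example, `knn_intersection_measure` compares an approximate result table against an exact reference:
```python
import numpy as np
import faiss
from faiss.contrib.evaluation import knn_intersection_measure

xb = np.random.rand(1000, 32).astype('float32')
xq = np.random.rand(50, 32).astype('float32')
_, I_ref = faiss.knn(xq, xb, 10)               # exact reference results
index = faiss.index_factory(32, "IVF20,Flat")
index.train(xb)
index.add(xb)
_, I_new = index.search(xq, 10)                # approximate results
print(knn_intersection_measure(I_ref, I_new))  # fraction of overlap, in [0, 1]
```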
### clustering.py
Contains:
- a Python implementation of k-means that can be used for special data types (e.g. sparse matrices).
- a 2-level clustering routine and a function that can apply it to train an IndexIVF
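A short sketch of both entry points (toy sizes):
```python
import numpy as np
import faiss
from faiss.contrib.clustering import kmeans, DatasetAssign, train_ivf_index_with_2level

x = np.random.rand(10000, 32).astype('float32')

# pure-Python k-means; DatasetAssign can be swapped for a sparse or GPU variant
centroids = kmeans(256, DatasetAssign(x), niter=10)

# 2-level clustering to train the coarse quantizer of an IVF index
index = faiss.index_factory(32, "IVF256,Flat")
train_ivf_index_with_2level(index, x)
index.add(x)
```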
### big_batch_search.py
Searches IVF indexes one inverted list after another. Useful for large
databases that do not fit in RAM *and* for large numbers of queries.
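A usage sketch with synthetic data (toy sizes):
```python
import numpy as np
import faiss
from faiss.contrib.big_batch_search import big_batch_search

xb = np.random.rand(100000, 32).astype('float32')
xq = np.random.rand(20000, 32).astype('float32')
index = faiss.index_factory(32, "IVF1024,Flat")
index.train(xb)
index.add(xb)
index.nprobe = 16
# visit the index one inverted list at a time, prefetching the next one
D, I = big_batch_search(index, xq, k=10, method="knn_function", threaded=1)
```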

View File

@@ -0,0 +1,515 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import time
import pickle
import os
import logging
from multiprocessing.pool import ThreadPool
import threading
import _thread
from queue import Queue
import traceback
import datetime
import numpy as np
import faiss
from faiss.contrib.inspect_tools import get_invlist
class BigBatchSearcher:
"""
Object that manages all the data related to the computation
except the actual within-bucket matching and the organization of the
computation (parallel or not)
"""
def __init__(
self,
index, xq, k,
verbose=0,
use_float16=False):
# verbosity
self.verbose = verbose
self.tictoc = []
self.xq = xq
self.index = index
self.use_float16 = use_float16
keep_max = faiss.is_similarity_metric(index.metric_type)
self.rh = faiss.ResultHeap(len(xq), k, keep_max=keep_max)
self.t_accu = [0] * 6
self.t_display = self.t0 = time.time()
def start_t_accu(self):
self.t_accu_t0 = time.time()
def stop_t_accu(self, n):
self.t_accu[n] += time.time() - self.t_accu_t0
def tic(self, name):
self.tictoc = (name, time.time())
if self.verbose > 0:
print(name, end="\r", flush=True)
def toc(self):
name, t0 = self.tictoc
dt = time.time() - t0
if self.verbose > 0:
print(f"{name}: {dt:.3f} s")
return dt
def report(self, l):
if self.verbose == 1 or (
self.verbose == 2 and (
l > 1000 and time.time() < self.t_display + 1.0
)
):
return
t = time.time() - self.t0
print(
f"[{t:.1f} s] list {l}/{self.index.nlist} "
f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} "
f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} "
f"wait in {self.t_accu[4]:.3f} "
f"wait out {self.t_accu[5]:.3f} "
f"eta {datetime.timedelta(seconds=t*self.index.nlist/(l+1)-t)} "
f"mem {faiss.get_mem_usage_kb()}",
end="\r" if self.verbose <= 2 else "\n",
flush=True,
)
self.t_display = time.time()
def coarse_quantization(self):
self.tic("coarse quantization")
bs = 65536
nq = len(self.xq)
q_assign = np.empty((nq, self.index.nprobe), dtype='int32')
for i0 in range(0, nq, bs):
i1 = min(nq, i0 + bs)
q_dis_i, q_assign_i = self.index.quantizer.search(
self.xq[i0:i1], self.index.nprobe)
# q_dis[i0:i1] = q_dis_i
q_assign[i0:i1] = q_assign_i
self.toc()
self.q_assign = q_assign
def reorder_assign(self):
self.tic("bucket sort")
q_assign = self.q_assign
q_assign += 1 # move -1 -> 0
self.bucket_lims = faiss.matrix_bucket_sort_inplace(
self.q_assign, nbucket=self.index.nlist + 1, nt=16)
self.query_ids = self.q_assign.ravel()
if self.verbose > 0:
print(' number of -1s:', self.bucket_lims[1])
self.bucket_lims = self.bucket_lims[1:] # shift back to ignore -1s
del self.q_assign # inplace so let's forget about the old version...
self.toc()
def prepare_bucket(self, l):
""" prepare the queries and database items for bucket l"""
t0 = time.time()
index = self.index
# prepare queries
i0, i1 = self.bucket_lims[l], self.bucket_lims[l + 1]
q_subset = self.query_ids[i0:i1]
xq_l = self.xq[q_subset]
if self.by_residual:
xq_l = xq_l - index.quantizer.reconstruct(l)
t1 = time.time()
# prepare database side
list_ids, xb_l = get_invlist(index.invlists, l)
if self.decode_func is None:
xb_l = xb_l.ravel()
else:
xb_l = self.decode_func(xb_l)
if self.use_float16:
xb_l = xb_l.astype('float16')
xq_l = xq_l.astype('float16')
t2 = time.time()
self.t_accu[0] += t1 - t0
self.t_accu[1] += t2 - t1
return q_subset, xq_l, list_ids, xb_l
def add_results_to_heap(self, q_subset, D, list_ids, I):
"""add the bucket results to the heap structure"""
if D is None:
return
t0 = time.time()
if I is None:
I = list_ids
else:
I = list_ids[I]
self.rh.add_result_subset(q_subset, D, I)
self.t_accu[3] += time.time() - t0
def sizes_in_checkpoint(self):
return (self.xq.shape, self.index.nprobe, self.index.nlist)
def write_checkpoint(self, fname, completed):
# write to temp file then move to final file
tmpname = fname + ".tmp"
with open(tmpname, "wb") as f:
pickle.dump(
{
"sizes": self.sizes_in_checkpoint(),
"completed": completed,
"rh": (self.rh.D, self.rh.I),
}, f, -1)
os.replace(tmpname, fname)
def read_checkpoint(self, fname):
with open(fname, "rb") as f:
ckp = pickle.load(f)
assert ckp["sizes"] == self.sizes_in_checkpoint()
self.rh.D[:] = ckp["rh"][0]
self.rh.I[:] = ckp["rh"][1]
return ckp["completed"]
class BlockComputer:
""" computation within one bucket """
def __init__(
self,
index,
method="knn_function",
pairwise_distances=faiss.pairwise_distances,
knn=faiss.knn):
self.index = index
if index.__class__ == faiss.IndexIVFFlat:
index_help = faiss.IndexFlat(index.d, index.metric_type)
decode_func = lambda x: x.view("float32")
by_residual = False
elif index.__class__ == faiss.IndexIVFPQ:
index_help = faiss.IndexPQ(
index.d, index.pq.M, index.pq.nbits, index.metric_type)
index_help.pq = index.pq
decode_func = index_help.pq.decode
index_help.is_trained = True
by_residual = index.by_residual
elif index.__class__ == faiss.IndexIVFScalarQuantizer:
index_help = faiss.IndexScalarQuantizer(
index.d, index.sq.qtype, index.metric_type)
index_help.sq = index.sq
decode_func = index_help.sq.decode
index_help.is_trained = True
by_residual = index.by_residual
else:
raise RuntimeError(f"index type {index.__class__} not supported")
self.index_help = index_help
self.decode_func = None if method == "index" else decode_func
self.by_residual = by_residual
self.method = method
self.pairwise_distances = pairwise_distances
self.knn = knn
def block_search(self, xq_l, xb_l, list_ids, k, **extra_args):
metric_type = self.index.metric_type
if xq_l.size == 0 or xb_l.size == 0:
D = I = None
elif self.method == "index":
faiss.copy_array_to_vector(xb_l, self.index_help.codes)
self.index_help.ntotal = len(list_ids)
D, I = self.index_help.search(xq_l, k)
elif self.method == "pairwise_distances":
# TODO implement blockwise to avoid mem blowup
D = self.pairwise_distances(xq_l, xb_l, metric=metric_type)
I = None
elif self.method == "knn_function":
D, I = self.knn(xq_l, xb_l, k, metric=metric_type, **extra_args)
return D, I
def big_batch_search(
index, xq, k,
method="knn_function",
pairwise_distances=faiss.pairwise_distances,
knn=faiss.knn,
verbose=0,
threaded=0,
use_float16=False,
prefetch_threads=1,
computation_threads=1,
q_assign=None,
checkpoint=None,
checkpoint_freq=7200,
start_list=0,
end_list=None,
crash_at=-1
):
"""
Search queries xq in the IVF index, with a search function that collects
    batches of query vectors per inverted list. This can be faster than a
    regular per-query search.
Supports IVFFlat, IVFPQ and IVFScalarQuantizer.
Supports three computation methods:
method = "index":
        build a flat index and populate it separately for each bucket
method = "pairwise_distances":
decompress codes and compute all pairwise distances for the queries
and index and add result to heap
method = "knn_function":
decompress codes and compute knn results for the queries
threaded=0: sequential execution
threaded=1: prefetch next bucket while computing the current one
threaded=2: prefetch prefetch_threads buckets at a time.
    computation_threads>1: the knn function will get an additional thread_id
    argument that tells which worker should handle this bucket.
    In threaded mode, the computation is overlapped with the bucket preparation
    and the writeback of results (useful to maximize GPU utilization).
use_float16: convert all matrices to float16 (faster for GPU gemm)
q_assign: override coarse assignment, should be a matrix of size nq * nprobe
checkpointing (only for threaded > 1):
checkpoint: file where the checkpoints are stored
    checkpoint_freq: when to perform checkpointing, in seconds. Should be a multiple of threaded
start_list, end_list: process only a subset of invlists
"""
nprobe = index.nprobe
assert method in ("index", "pairwise_distances", "knn_function")
mem_queries = xq.nbytes
mem_assign = len(xq) * nprobe * np.dtype('int32').itemsize
mem_res = len(xq) * k * (
np.dtype('int64').itemsize
+ np.dtype('float32').itemsize
)
mem_tot = mem_queries + mem_assign + mem_res
if verbose > 0:
logging.info(
f"memory: queries {mem_queries} assign {mem_assign} "
f"result {mem_res} total {mem_tot} = {mem_tot / (1<<30):.3f} GiB"
)
bbs = BigBatchSearcher(
index, xq, k,
verbose=verbose,
use_float16=use_float16
)
comp = BlockComputer(
index,
method=method,
pairwise_distances=pairwise_distances,
knn=knn
)
bbs.decode_func = comp.decode_func
bbs.by_residual = comp.by_residual
if q_assign is None:
bbs.coarse_quantization()
else:
bbs.q_assign = q_assign
bbs.reorder_assign()
if end_list is None:
end_list = index.nlist
completed = set()
if checkpoint is not None:
assert (start_list, end_list) == (0, index.nlist)
if os.path.exists(checkpoint):
logging.info(f"recovering checkpoint: {checkpoint}")
completed = bbs.read_checkpoint(checkpoint)
logging.info(f" already completed: {len(completed)}")
else:
logging.info("no checkpoint: starting from scratch")
if threaded == 0:
# simple sequential version
for l in range(start_list, end_list):
bbs.report(l)
q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(l)
t0i = time.time()
D, I = comp.block_search(xq_l, xb_l, list_ids, k)
bbs.t_accu[2] += time.time() - t0i
bbs.add_results_to_heap(q_subset, D, list_ids, I)
elif threaded == 1:
# parallel version with granularity 1
def add_results_and_prefetch(to_add, l):
""" perform the addition for the previous bucket and
prefetch the next (if applicable) """
if to_add is not None:
bbs.add_results_to_heap(*to_add)
if l < index.nlist:
return bbs.prepare_bucket(l)
prefetched_bucket = bbs.prepare_bucket(start_list)
to_add = None
pool = ThreadPool(1)
for l in range(start_list, end_list):
bbs.report(l)
prefetched_bucket_a = pool.apply_async(
add_results_and_prefetch, (to_add, l + 1))
q_subset, xq_l, list_ids, xb_l = prefetched_bucket
bbs.start_t_accu()
D, I = comp.block_search(xq_l, xb_l, list_ids, k)
bbs.stop_t_accu(2)
to_add = q_subset, D, list_ids, I
bbs.start_t_accu()
prefetched_bucket = prefetched_bucket_a.get()
bbs.stop_t_accu(4)
bbs.add_results_to_heap(*to_add)
pool.close()
else:
def task_manager_thread(
task,
pool_size,
start_task,
end_task,
completed,
output_queue,
input_queue,
):
try:
with ThreadPool(pool_size) as pool:
res = [pool.apply_async(
task,
args=(i, output_queue, input_queue))
for i in range(start_task, end_task)
if i not in completed]
for r in res:
r.get()
pool.close()
pool.join()
output_queue.put(None)
except:
traceback.print_exc()
_thread.interrupt_main()
raise
def task_manager(*args):
task_manager = threading.Thread(
target=task_manager_thread,
args=args,
)
task_manager.daemon = True
task_manager.start()
return task_manager
def prepare_task(task_id, output_queue, input_queue=None):
try:
logging.info(f"Prepare start: {task_id}")
q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(task_id)
output_queue.put((task_id, q_subset, xq_l, list_ids, xb_l))
logging.info(f"Prepare end: {task_id}")
except:
traceback.print_exc()
_thread.interrupt_main()
raise
def compute_task(task_id, output_queue, input_queue):
try:
logging.info(f"Compute start: {task_id}")
t_wait_out = 0
while True:
t0 = time.time()
logging.info(f'Compute input: task {task_id}')
input_value = input_queue.get()
t_wait_in = time.time() - t0
if input_value is None:
# signal for other compute tasks
input_queue.put(None)
break
centroid, q_subset, xq_l, list_ids, xb_l = input_value
logging.info(f'Compute work: task {task_id}, centroid {centroid}')
t0 = time.time()
if computation_threads > 1:
D, I = comp.block_search(
xq_l, xb_l, list_ids, k, thread_id=task_id
)
else:
D, I = comp.block_search(xq_l, xb_l, list_ids, k)
t_compute = time.time() - t0
logging.info(f'Compute output: task {task_id}, centroid {centroid}')
t0 = time.time()
output_queue.put(
(centroid, t_wait_in, t_wait_out, t_compute, q_subset, D, list_ids, I)
)
t_wait_out = time.time() - t0
logging.info(f"Compute end: {task_id}")
except:
traceback.print_exc()
_thread.interrupt_main()
raise
prepare_to_compute_queue = Queue(2)
compute_to_main_queue = Queue(2)
compute_task_manager = task_manager(
compute_task,
computation_threads,
0,
computation_threads,
set(),
compute_to_main_queue,
prepare_to_compute_queue,
)
prepare_task_manager = task_manager(
prepare_task,
prefetch_threads,
start_list,
end_list,
completed,
prepare_to_compute_queue,
None,
)
t_checkpoint = time.time()
while True:
logging.info("Waiting for result")
value = compute_to_main_queue.get()
if not value:
break
centroid, t_wait_in, t_wait_out, t_compute, q_subset, D, list_ids, I = value
# to test checkpointing
if centroid == crash_at:
1 / 0
bbs.t_accu[2] += t_compute
bbs.t_accu[4] += t_wait_in
bbs.t_accu[5] += t_wait_out
logging.info(f"Adding to heap start: centroid {centroid}")
bbs.add_results_to_heap(q_subset, D, list_ids, I)
logging.info(f"Adding to heap end: centroid {centroid}")
completed.add(centroid)
bbs.report(centroid)
if checkpoint is not None:
if time.time() - t_checkpoint > checkpoint_freq:
logging.info("writing checkpoint")
bbs.write_checkpoint(checkpoint, completed)
t_checkpoint = time.time()
prepare_task_manager.join()
compute_task_manager.join()
bbs.tic("finalize heap")
bbs.rh.finalize()
bbs.toc()
return bbs.rh.D, bbs.rh.I

View File

@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from multiprocessing.pool import ThreadPool
import faiss
from typing import List, Tuple
from . import rpc
############################################################
# Server implementation
############################################################
class SearchServer(rpc.Server):
""" Assign version that can be exposed via RPC """
def __init__(self, s: int, index: faiss.Index):
rpc.Server.__init__(self, s)
self.index = index
self.index_ivf = faiss.extract_index_ivf(index)
    def set_nprobe(self, nprobe: int) -> None:
""" set nprobe field """
self.index_ivf.nprobe = nprobe
def get_ntotal(self) -> int:
return self.index.ntotal
def __getattr__(self, f):
# all other functions get forwarded to the index
return getattr(self.index, f)
def run_index_server(index: faiss.Index, port: int, v6: bool = False):
""" serve requests for that index forerver """
rpc.run_server(
lambda s: SearchServer(s, index),
port, v6=v6)
############################################################
# Client implementation
############################################################
class ClientIndex:
"""manages a set of distance sub-indexes. The sub_indexes search a
subset of the inverted lists. Searches are merged afterwards
"""
def __init__(self, machine_ports: List[Tuple[str, int]], v6: bool = False):
""" connect to a series of (host, port) pairs """
self.sub_indexes = []
for machine, port in machine_ports:
self.sub_indexes.append(rpc.Client(machine, port, v6))
self.ni = len(self.sub_indexes)
# pool of threads. Each thread manages one sub-index.
self.pool = ThreadPool(self.ni)
# test connection...
self.ntotal = self.get_ntotal()
self.verbose = False
def set_nprobe(self, nprobe: int) -> None:
self.pool.map(
lambda idx: idx.set_nprobe(nprobe),
self.sub_indexes
)
def set_omp_num_threads(self, nt: int) -> None:
self.pool.map(
lambda idx: idx.set_omp_num_threads(nt),
self.sub_indexes
)
    def get_ntotal(self) -> int:
return sum(self.pool.map(
lambda idx: idx.get_ntotal(),
self.sub_indexes
))
def search(self, x, k: int):
rh = faiss.ResultHeap(x.shape[0], k)
for Di, Ii in self.pool.imap(lambda idx: idx.search(x, k), self.sub_indexes):
rh.add_result(Di, Ii)
rh.finalize()
return rh.D, rh.I

View File

@@ -0,0 +1,428 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
This contrib module contains a few routines useful for clustering variants.
"""
import numpy as np
import faiss
import time
from multiprocessing.pool import ThreadPool
try:
import scipy.sparse
except ImportError:
print("scipy not accessible, Python k-means will not work")
def print_nop(*arg, **kwargs):
pass
def two_level_clustering(xt, nc1, nc2, rebalance=True, clustering_niter=25, **args):
"""
perform 2-level clustering on a training set xt
nc1 and nc2 are the number of clusters at each level, the final number of
clusters is nc2. Additional arguments are passed to the Kmeans object.
    Rebalance allocates the number of sub-clusters depending on the sizes of
    the first-level clusters.
"""
d = xt.shape[1]
verbose = args.get("verbose", False)
log = print if verbose else print_nop
log(f"2-level clustering of {xt.shape} nb 1st level clusters = {nc1} total {nc2}")
log("perform coarse training")
km = faiss.Kmeans(
d, nc1, niter=clustering_niter,
max_points_per_centroid=2000,
**args
)
km.train(xt)
iteration_stats = [km.iteration_stats]
log()
# coarse centroids
centroids1 = km.centroids
log("assigning the training set")
t0 = time.time()
_, assign1 = km.assign(xt)
bc = np.bincount(assign1, minlength=nc1)
log(f"done in {time.time() - t0:.2f} s. Sizes of clusters {min(bc)}-{max(bc)}")
o = assign1.argsort()
del km
if not rebalance:
# make sure the sub-clusters sum up to exactly nc2
cc = np.arange(nc1 + 1) * nc2 // nc1
all_nc2 = cc[1:] - cc[:-1]
else:
bc_sum = np.cumsum(bc)
all_nc2 = bc_sum * nc2 // bc_sum[-1]
all_nc2[1:] -= all_nc2[:-1]
assert sum(all_nc2) == nc2
log(f"nb 2nd-level centroids {min(all_nc2)}-{max(all_nc2)}")
# train sub-clusters
i0 = 0
c2 = []
t0 = time.time()
for c1 in range(nc1):
nc2 = int(all_nc2[c1])
log(f"[{time.time() - t0:.2f} s] training sub-cluster {c1}/{nc1} nc2={nc2}\r", end="", flush=True)
i1 = i0 + bc[c1]
subset = o[i0:i1]
assert np.all(assign1[subset] == c1)
km = faiss.Kmeans(d, nc2, **args)
xtsub = xt[subset]
km.train(xtsub)
iteration_stats.append(km.iteration_stats)
c2.append(km.centroids)
del km
i0 = i1
log(f"done in {time.time() - t0:.2f} s")
return np.vstack(c2), iteration_stats
def train_ivf_index_with_2level(index, xt, **args):
"""
Applies 2-level clustering to an index_ivf embedded in an index.
"""
# handle PreTransforms
index = faiss.downcast_index(index)
if isinstance(index, faiss.IndexPreTransform):
for i in range(index.chain.size()):
vt = index.chain.at(i)
vt.train(xt)
xt = vt.apply(xt)
train_ivf_index_with_2level(index.index, xt, **args)
index.is_trained = True
return
assert isinstance(index, faiss.IndexIVF)
assert index.metric_type == faiss.METRIC_L2
# now do 2-level clustering
nc1 = int(np.sqrt(index.nlist))
print("REBALANCE=", args)
centroids, _ = two_level_clustering(xt, nc1, index.nlist, **args)
index.quantizer.train(centroids)
index.quantizer.add(centroids)
# finish training
index.train(xt)
###############################################################################
# K-means implementation in Python
#
# It relies on DatasetAssign, an abstraction of the training vectors that offers
# the minimal set of operations to perform k-means clustering.
###############################################################################
class DatasetAssign:
"""Wrapper for a matrix that offers a function to assign the vectors
to centroids. All other implementations offer the same interface"""
def __init__(self, x):
self.x = np.ascontiguousarray(x, dtype='float32')
def count(self):
return self.x.shape[0]
def dim(self):
return self.x.shape[1]
def get_subset(self, indices):
return self.x[indices]
def perform_search(self, centroids):
return faiss.knn(self.x, centroids, 1)
def assign_to(self, centroids, weights=None):
D, I = self.perform_search(centroids)
I = I.ravel()
D = D.ravel()
nc, d = centroids.shape
sum_per_centroid = np.zeros((nc, d), dtype='float32')
if weights is None:
np.add.at(sum_per_centroid, I, self.x)
else:
np.add.at(sum_per_centroid, I, weights[:, np.newaxis] * self.x)
return I, D, sum_per_centroid
class DatasetAssignGPU(DatasetAssign):
""" GPU version of the previous """
def __init__(self, x, gpu_id, verbose=False):
DatasetAssign.__init__(self, x)
index = faiss.IndexFlatL2(x.shape[1])
if gpu_id >= 0:
self.index = faiss.index_cpu_to_gpu(
faiss.StandardGpuResources(),
gpu_id, index)
else:
# -1 -> assign to all GPUs
self.index = faiss.index_cpu_to_all_gpus(index)
def perform_search(self, centroids):
self.index.reset()
self.index.add(centroids)
return self.index.search(self.x, 1)
def sparse_assign_to_dense(xq, xb, xq_norms=None, xb_norms=None):
""" assignment function for xq is sparse, xb is dense
uses a matrix multiplication. The squared norms can be provided if
available.
"""
nq = xq.shape[0]
nb = xb.shape[0]
if xb_norms is None:
xb_norms = (xb ** 2).sum(1)
if xq_norms is None:
xq_norms = np.array(xq.power(2).sum(1))
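    # squared L2 distance decomposes as ||q||^2 - 2 <q, b> + ||b||^2; the
    # ||q||^2 term is constant per query, so it is added back after the argmin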
d2 = xb_norms - 2 * xq @ xb.T
I = d2.argmin(axis=1)
D = d2.ravel()[I + np.arange(nq) * nb] + xq_norms.ravel()
return D, I
def sparse_assign_to_dense_blocks(
xq, xb, xq_norms=None, xb_norms=None, qbs=16384, bbs=16384, nt=None):
"""
decomposes the sparse_assign_to_dense function into blocks to avoid a
possible memory blow up. Can be run in multithreaded mode, because scipy's
sparse-dense matrix multiplication is single-threaded.
"""
nq = xq.shape[0]
nb = xb.shape[0]
D = np.empty(nq, dtype="float32")
D.fill(np.inf)
I = -np.ones(nq, dtype=int)
if xb_norms is None:
xb_norms = (xb ** 2).sum(1)
def handle_query_block(i):
xq_block = xq[i : i + qbs]
Iblock = I[i : i + qbs]
Dblock = D[i : i + qbs]
if xq_norms is None:
xq_norms_block = np.array(xq_block.power(2).sum(1))
else:
xq_norms_block = xq_norms[i : i + qbs]
for j in range(0, nb, bbs):
Di, Ii = sparse_assign_to_dense(
xq_block,
xb[j : j + bbs],
xq_norms=xq_norms_block,
xb_norms=xb_norms[j : j + bbs],
)
if j == 0:
Iblock[:] = Ii
Dblock[:] = Di
else:
mask = Di < Dblock
Iblock[mask] = Ii[mask] + j
Dblock[mask] = Di[mask]
if nt == 0 or nt == 1 or nq <= qbs:
list(map(handle_query_block, range(0, nq, qbs)))
else:
pool = ThreadPool(nt)
pool.map(handle_query_block, range(0, nq, qbs))
return D, I
class DatasetAssignSparse(DatasetAssign):
"""Wrapper for a matrix that offers a function to assign the vectors
to centroids. All other implementations offer the same interface"""
def __init__(self, x):
assert x.__class__ == scipy.sparse.csr_matrix
self.x = x
self.squared_norms = np.array(x.power(2).sum(1))
def get_subset(self, indices):
return np.array(self.x[indices].todense())
def perform_search(self, centroids):
return sparse_assign_to_dense_blocks(
self.x, centroids, xq_norms=self.squared_norms)
def assign_to(self, centroids, weights=None):
D, I = self.perform_search(centroids)
I = I.ravel()
D = D.ravel()
n = self.x.shape[0]
if weights is None:
weights = np.ones(n, dtype='float32')
nc = len(centroids)
m = scipy.sparse.csc_matrix(
(weights, I, np.arange(n + 1)),
shape=(nc, n))
sum_per_centroid = np.array((m * self.x).todense())
return I, D, sum_per_centroid
def imbalance_factor(k, assign):
assign = np.ascontiguousarray(assign, dtype='int64')
return faiss.imbalance_factor(len(assign), k, faiss.swig_ptr(assign))
def check_if_torch(x):
if x.__class__ == np.ndarray:
return False
import torch
if isinstance(x, torch.Tensor):
return True
raise NotImplementedError(f"Unknown tensor type {type(x)}")
def reassign_centroids(hassign, centroids, rs=None):
""" reassign centroids when some of them collapse """
if rs is None:
rs = np.random
k, d = centroids.shape
nsplit = 0
is_torch = check_if_torch(centroids)
empty_cents = np.where(hassign == 0)[0]
if len(empty_cents) == 0:
return 0
if is_torch:
import torch
fac = torch.ones_like(centroids[0])
else:
fac = np.ones_like(centroids[0])
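    # perturbation factors: the two copies of a split centroid are scaled in
    # opposite directions so that they can drift apart in later iterations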
fac[::2] += 1 / 1024.
fac[1::2] -= 1 / 1024.
# this is a single pass unless there are more than k/2
# empty centroids
while len(empty_cents) > 0:
# choose which centroids to split (numpy)
probas = hassign.astype('float') - 1
probas[probas < 0] = 0
probas /= probas.sum()
nnz = (probas > 0).sum()
nreplace = min(nnz, empty_cents.size)
cjs = rs.choice(k, size=nreplace, p=probas)
for ci, cj in zip(empty_cents[:nreplace], cjs):
c = centroids[cj]
centroids[ci] = c * fac
centroids[cj] = c / fac
hassign[ci] = hassign[cj] // 2
hassign[cj] -= hassign[ci]
nsplit += 1
empty_cents = empty_cents[nreplace:]
return nsplit
def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
return_stats=False):
"""Pure python kmeans implementation. Follows the Faiss C++ version
quite closely, but takes a DatasetAssign instead of a training data
matrix. Also redo is not implemented.
For the torch implementation, the centroids are tensors (possibly on GPU),
but the indices remain numpy on CPU.
"""
n, d = data.count(), data.dim()
log = print if verbose else print_nop
log(("Clustering %d points in %dD to %d clusters, " +
"%d iterations seed %d") % (n, d, k, niter, seed))
rs = np.random.RandomState(seed)
print("preproc...")
t0 = time.time()
# initialization
perm = rs.choice(n, size=k, replace=False)
centroids = data.get_subset(perm)
is_torch = check_if_torch(centroids)
iteration_stats = []
log(" done")
t_search_tot = 0
obj = []
for i in range(niter):
t0s = time.time()
log('assigning', end='\r', flush=True)
assign, D, sums = data.assign_to(centroids)
log('compute centroids', end='\r', flush=True)
        t_search_tot += time.time() - t0s
err = D.sum()
if is_torch:
err = err.item()
obj.append(err)
hassign = np.bincount(assign, minlength=k)
fac = hassign.reshape(-1, 1).astype('float32')
fac[fac == 0] = 1 # quiet warning
if is_torch:
import torch
fac = torch.from_numpy(fac).to(sums.device)
centroids = sums / fac
nsplit = reassign_centroids(hassign, centroids, rs)
s = {
"obj": err,
"time": (time.time() - t0),
"time_search": t_search_tot,
"imbalance_factor": imbalance_factor(k, assign),
"nsplit": nsplit
}
log((" Iteration %d (%.2f s, search %.2f s): "
"objective=%g imbalance=%.3f nsplit=%d") % (
i, s["time"], s["time_search"],
err, s["imbalance_factor"],
nsplit)
)
iteration_stats.append(s)
if checkpoint is not None:
log('storing centroids in', checkpoint)
if is_torch:
import torch
torch.save(centroids, checkpoint)
else:
np.save(checkpoint, centroids)
if return_stats:
return centroids, iteration_stats
else:
return centroids

View File

@@ -0,0 +1,387 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
import faiss
import getpass
from .vecs_io import fvecs_read, ivecs_read, bvecs_mmap, fvecs_mmap
from .exhaustive_search import knn
class Dataset:
""" Generic abstract class for a test dataset """
def __init__(self):
""" the constructor should set the following fields: """
self.d = -1
self.metric = 'L2' # or IP
self.nq = -1
self.nb = -1
self.nt = -1
def get_queries(self):
""" return the queries as a (nq, d) array """
raise NotImplementedError()
def get_train(self, maxtrain=None):
""" return the queries as a (nt, d) array """
raise NotImplementedError()
def get_database(self):
""" return the queries as a (nb, d) array """
raise NotImplementedError()
def database_iterator(self, bs=128, split=(1, 0)):
"""returns an iterator on database vectors.
bs is the number of vectors per batch
split = (nsplit, rank) means the dataset is split in nsplit
shards and we want shard number rank
The default implementation just iterates over the full matrix
        returned by get_database.
"""
xb = self.get_database()
nsplit, rank = split
i0, i1 = self.nb * rank // nsplit, self.nb * (rank + 1) // nsplit
for j0 in range(i0, i1, bs):
yield xb[j0: min(j0 + bs, i1)]
def get_groundtruth(self, k=None):
""" return the ground truth for k-nearest neighbor search """
raise NotImplementedError()
def get_groundtruth_range(self, thresh=None):
""" return the ground truth for range search """
raise NotImplementedError()
def __str__(self):
return (f"dataset in dimension {self.d}, with metric {self.metric}, "
f"size: Q {self.nq} B {self.nb} T {self.nt}")
def check_sizes(self):
""" runs the previous and checks the sizes of the matrices """
assert self.get_queries().shape == (self.nq, self.d)
if self.nt > 0:
xt = self.get_train(maxtrain=123)
assert xt.shape == (123, self.d), "shape=%s" % (xt.shape, )
assert self.get_database().shape == (self.nb, self.d)
assert self.get_groundtruth(k=13).shape == (self.nq, 13)
class SyntheticDataset(Dataset):
"""A dataset that is not completely random but still challenging to
index
"""
def __init__(self, d, nt, nb, nq, metric='L2', seed=1338):
Dataset.__init__(self)
self.d, self.nt, self.nb, self.nq = d, nt, nb, nq
d1 = 10 # intrinsic dimension (more or less)
n = nb + nt + nq
rs = np.random.RandomState(seed)
x = rs.normal(size=(n, d1))
x = np.dot(x, rs.rand(d1, d))
# now we have a d1-dim ellipsoid in d-dimensional space
# higher factor (>4) -> higher frequency -> less linear
x = x * (rs.rand(d) * 4 + 0.1)
x = np.sin(x)
x = x.astype('float32')
self.metric = metric
self.xt = x[:nt]
self.xb = x[nt:nt + nb]
self.xq = x[nt + nb:]
def get_queries(self):
return self.xq
def get_train(self, maxtrain=None):
maxtrain = maxtrain if maxtrain is not None else self.nt
return self.xt[:maxtrain]
def get_database(self):
return self.xb
def get_groundtruth(self, k=100):
return knn(
self.xq, self.xb, k,
faiss.METRIC_L2 if self.metric == 'L2' else faiss.METRIC_INNER_PRODUCT
)[1]
############################################################################
# The following datasets are a few standard open-source datasets
# they should be stored in a directory, and we start by guessing where
# that directory is
############################################################################
username = getpass.getuser()
for dataset_basedir in (
'/datasets01/simsearch/041218/',
'/mnt/vol/gfsai-flash3-east/ai-group/datasets/simsearch/',
f'/home/{username}/simsearch/data/'):
if os.path.exists(dataset_basedir):
break
else:
# users can link their data directory to `./data`
dataset_basedir = 'data/'
def set_dataset_basedir(path):
global dataset_basedir
dataset_basedir = path
class DatasetSIFT1M(Dataset):
"""
The original dataset is available at: http://corpus-texmex.irisa.fr/
(ANN_SIFT1M)
"""
def __init__(self):
Dataset.__init__(self)
self.d, self.nt, self.nb, self.nq = 128, 100000, 1000000, 10000
self.basedir = dataset_basedir + 'sift1M/'
def get_queries(self):
return fvecs_read(self.basedir + "sift_query.fvecs")
def get_train(self, maxtrain=None):
maxtrain = maxtrain if maxtrain is not None else self.nt
return fvecs_read(self.basedir + "sift_learn.fvecs")[:maxtrain]
def get_database(self):
return fvecs_read(self.basedir + "sift_base.fvecs")
def get_groundtruth(self, k=None):
gt = ivecs_read(self.basedir + "sift_groundtruth.ivecs")
if k is not None:
assert k <= 100
gt = gt[:, :k]
return gt
def sanitize(x):
return np.ascontiguousarray(x, dtype='float32')
class DatasetBigANN(Dataset):
"""
The original dataset is available at: http://corpus-texmex.irisa.fr/
(ANN_SIFT1B)
"""
def __init__(self, nb_M=1000):
Dataset.__init__(self)
assert nb_M in (1, 2, 5, 10, 20, 50, 100, 200, 500, 1000)
self.nb_M = nb_M
nb = nb_M * 10**6
self.d, self.nt, self.nb, self.nq = 128, 10**8, nb, 10000
self.basedir = dataset_basedir + 'bigann/'
def get_queries(self):
return sanitize(bvecs_mmap(self.basedir + 'bigann_query.bvecs')[:])
def get_train(self, maxtrain=None):
maxtrain = maxtrain if maxtrain is not None else self.nt
return sanitize(bvecs_mmap(self.basedir + 'bigann_learn.bvecs')[:maxtrain])
def get_groundtruth(self, k=None):
gt = ivecs_read(self.basedir + 'gnd/idx_%dM.ivecs' % self.nb_M)
if k is not None:
assert k <= 100
gt = gt[:, :k]
return gt
def get_database(self):
assert self.nb_M < 100, "dataset too large, use iterator"
return sanitize(bvecs_mmap(self.basedir + 'bigann_base.bvecs')[:self.nb])
def database_iterator(self, bs=128, split=(1, 0)):
xb = bvecs_mmap(self.basedir + 'bigann_base.bvecs')
nsplit, rank = split
i0, i1 = self.nb * rank // nsplit, self.nb * (rank + 1) // nsplit
for j0 in range(i0, i1, bs):
yield sanitize(xb[j0: min(j0 + bs, i1)])
class DatasetDeep1B(Dataset):
"""
See
https://github.com/facebookresearch/faiss/tree/main/benchs#getting-deep1b
on how to get the data
"""
def __init__(self, nb=10**9):
Dataset.__init__(self)
nb_to_name = {
10**5: '100k',
10**6: '1M',
10**7: '10M',
10**8: '100M',
10**9: '1B'
}
assert nb in nb_to_name
self.d, self.nt, self.nb, self.nq = 96, 358480000, nb, 10000
self.basedir = dataset_basedir + 'deep1b/'
self.gt_fname = "%sdeep%s_groundtruth.ivecs" % (
self.basedir, nb_to_name[self.nb])
def get_queries(self):
return sanitize(fvecs_read(self.basedir + "deep1B_queries.fvecs"))
def get_train(self, maxtrain=None):
maxtrain = maxtrain if maxtrain is not None else self.nt
return sanitize(fvecs_mmap(self.basedir + "learn.fvecs")[:maxtrain])
def get_groundtruth(self, k=None):
gt = ivecs_read(self.gt_fname)
if k is not None:
assert k <= 100
gt = gt[:, :k]
return gt
def get_database(self):
assert self.nb <= 10**8, "dataset too large, use iterator"
return sanitize(fvecs_mmap(self.basedir + "base.fvecs")[:self.nb])
def database_iterator(self, bs=128, split=(1, 0)):
xb = fvecs_mmap(self.basedir + "base.fvecs")
nsplit, rank = split
i0, i1 = self.nb * rank // nsplit, self.nb * (rank + 1) // nsplit
for j0 in range(i0, i1, bs):
yield sanitize(xb[j0: min(j0 + bs, i1)])
class DatasetGlove(Dataset):
"""
Data from http://ann-benchmarks.com/glove-100-angular.hdf5
"""
def __init__(self, loc=None, download=False):
import h5py
assert not download, "not implemented"
if not loc:
loc = dataset_basedir + 'glove/glove-100-angular.hdf5'
self.glove_h5py = h5py.File(loc, 'r')
# IP and L2 are equivalent in this case, but it is traditionally seen as an IP dataset
self.metric = 'IP'
self.d, self.nt = 100, 0
self.nb = self.glove_h5py['train'].shape[0]
self.nq = self.glove_h5py['test'].shape[0]
def get_queries(self):
xq = np.array(self.glove_h5py['test'])
faiss.normalize_L2(xq)
return xq
def get_database(self):
xb = np.array(self.glove_h5py['train'])
faiss.normalize_L2(xb)
return xb
def get_groundtruth(self, k=None):
gt = self.glove_h5py['neighbors']
if k is not None:
assert k <= 100
gt = gt[:, :k]
return gt
class DatasetMusic100(Dataset):
"""
get dataset from
https://github.com/stanis-morozov/ip-nsw#dataset
"""
def __init__(self):
Dataset.__init__(self)
self.d, self.nt, self.nb, self.nq = 100, 0, 10**6, 10000
self.metric = 'IP'
self.basedir = dataset_basedir + 'music-100/'
def get_queries(self):
xq = np.fromfile(self.basedir + 'query_music100.bin', dtype='float32')
xq = xq.reshape(-1, 100)
return xq
def get_database(self):
xb = np.fromfile(self.basedir + 'database_music100.bin', dtype='float32')
xb = xb.reshape(-1, 100)
return xb
def get_groundtruth(self, k=None):
gt = np.load(self.basedir + 'gt.npy')
if k is not None:
assert k <= 100
gt = gt[:, :k]
return gt
class DatasetGIST1M(Dataset):
"""
The original dataset is available at: http://corpus-texmex.irisa.fr/
    (ANN_GIST1M)
"""
def __init__(self):
Dataset.__init__(self)
self.d, self.nt, self.nb, self.nq = 960, 100000, 1000000, 10000
self.basedir = dataset_basedir + 'gist1M/'
def get_queries(self):
return fvecs_read(self.basedir + "gist_query.fvecs")
def get_train(self, maxtrain=None):
maxtrain = maxtrain if maxtrain is not None else self.nt
return fvecs_read(self.basedir + "gist_learn.fvecs")[:maxtrain]
def get_database(self):
return fvecs_read(self.basedir + "gist_base.fvecs")
def get_groundtruth(self, k=None):
gt = ivecs_read(self.basedir + "gist_groundtruth.ivecs")
if k is not None:
assert k <= 100
gt = gt[:, :k]
return gt
def dataset_from_name(dataset='deep1M', download=False):
""" converts a string describing a dataset to a Dataset object
Supports sift1M, bigann1M..bigann1B, deep1M..deep1B, music-100 and glove
"""
if dataset == 'sift1M':
return DatasetSIFT1M()
elif dataset == 'gist1M':
return DatasetGIST1M()
elif dataset.startswith('bigann'):
dbsize = 1000 if dataset == "bigann1B" else int(dataset[6:-1])
return DatasetBigANN(nb_M=dbsize)
elif dataset.startswith("deep"):
szsuf = dataset[4:]
if szsuf[-1] == 'M':
dbsize = 10 ** 6 * int(szsuf[:-1])
elif szsuf == '1B':
dbsize = 10 ** 9
elif szsuf[-1] == 'k':
dbsize = 1000 * int(szsuf[:-1])
else:
assert False, "did not recognize suffix " + szsuf
return DatasetDeep1B(nb=dbsize)
elif dataset == "music-100":
return DatasetMusic100()
elif dataset == "glove":
return DatasetGlove(download=download)
else:
raise RuntimeError("unknown dataset " + dataset)

View File

@@ -0,0 +1,492 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import unittest
import time
import faiss
from multiprocessing.pool import ThreadPool
###############################################################
# Simple functions to evaluate knn results
def knn_intersection_measure(I1, I2):
""" computes the intersection measure of two result tables
"""
nq, rank = I1.shape
assert I2.shape == (nq, rank)
ninter = sum(
np.intersect1d(I1[i], I2[i]).size
for i in range(nq)
)
return ninter / I1.size
###############################################################
# Range search results can be compared with Precision-Recall
def filter_range_results(lims, D, I, thresh):
""" select a set of results """
nq = lims.size - 1
mask = D < thresh
new_lims = np.zeros_like(lims)
for i in range(nq):
new_lims[i + 1] = new_lims[i] + mask[lims[i] : lims[i + 1]].sum()
return new_lims, D[mask], I[mask]
def range_PR(lims_ref, Iref, lims_new, Inew, mode="overall"):
"""compute the precision and recall of range search results. The
function does not take the distances into account. """
def ref_result_for(i):
return Iref[lims_ref[i]:lims_ref[i + 1]]
def new_result_for(i):
return Inew[lims_new[i]:lims_new[i + 1]]
nq = lims_ref.size - 1
assert lims_new.size - 1 == nq
ninter = np.zeros(nq, dtype="int64")
def compute_PR_for(q):
# ground truth results for this query
gt_ids = ref_result_for(q)
# results for this query
new_ids = new_result_for(q)
# there are no set functions in numpy so let's do this
inter = np.intersect1d(gt_ids, new_ids)
ninter[q] = len(inter)
# run in a thread pool, which helps in spite of the GIL
pool = ThreadPool(20)
pool.map(compute_PR_for, range(nq))
return counts_to_PR(
lims_ref[1:] - lims_ref[:-1],
lims_new[1:] - lims_new[:-1],
ninter,
mode=mode
)
def counts_to_PR(ngt, nres, ninter, mode="overall"):
""" computes a precision-recall for a ser of queries.
ngt = nb of GT results per query
nres = nb of found results per query
ninter = nb of correct results per query (smaller than nres of course)
"""
if mode == "overall":
ngt, nres, ninter = ngt.sum(), nres.sum(), ninter.sum()
if nres > 0:
precision = ninter / nres
else:
precision = 1.0
if ngt > 0:
recall = ninter / ngt
elif nres == 0:
recall = 1.0
else:
recall = 0.0
return precision, recall
elif mode == "average":
# average precision and recall over queries
mask = ngt == 0
ngt[mask] = 1
recalls = ninter / ngt
recalls[mask] = (nres[mask] == 0).astype(float)
# avoid division by 0
mask = nres == 0
assert np.all(ninter[mask] == 0)
ninter[mask] = 1
nres[mask] = 1
precisions = ninter / nres
return precisions.mean(), recalls.mean()
else:
raise AssertionError()
def sort_range_res_2(lims, D, I):
""" sort 2 arrays using the first as key """
I2 = np.empty_like(I)
D2 = np.empty_like(D)
nq = len(lims) - 1
for i in range(nq):
l0, l1 = lims[i], lims[i + 1]
ii = I[l0:l1]
di = D[l0:l1]
o = di.argsort()
I2[l0:l1] = ii[o]
D2[l0:l1] = di[o]
return I2, D2
def sort_range_res_1(lims, I):
I2 = np.empty_like(I)
nq = len(lims) - 1
for i in range(nq):
l0, l1 = lims[i], lims[i + 1]
I2[l0:l1] = I[l0:l1]
I2[l0:l1].sort()
return I2
def range_PR_multiple_thresholds(
lims_ref, Iref,
lims_new, Dnew, Inew,
thresholds,
mode="overall", do_sort="ref,new"
):
""" compute precision-recall values for range search results
for several thresholds on the "new" results.
This is to plot PR curves
"""
# ref should be sorted by ids
if "ref" in do_sort:
Iref = sort_range_res_1(lims_ref, Iref)
# new should be sorted by distances
if "new" in do_sort:
Inew, Dnew = sort_range_res_2(lims_new, Dnew, Inew)
def ref_result_for(i):
return Iref[lims_ref[i]:lims_ref[i + 1]]
def new_result_for(i):
l0, l1 = lims_new[i], lims_new[i + 1]
return Inew[l0:l1], Dnew[l0:l1]
nq = lims_ref.size - 1
assert lims_new.size - 1 == nq
nt = len(thresholds)
counts = np.zeros((nq, nt, 3), dtype="int64")
def compute_PR_for(q):
gt_ids = ref_result_for(q)
res_ids, res_dis = new_result_for(q)
counts[q, :, 0] = len(gt_ids)
if res_dis.size == 0:
# the rest remains at 0
return
# which offsets we are interested in
        nres = np.searchsorted(res_dis, thresholds)
counts[q, :, 1] = nres
if gt_ids.size == 0:
return
# find number of TPs at each stage in the result list
ii = np.searchsorted(gt_ids, res_ids)
ii[ii == len(gt_ids)] = -1
n_ok = np.cumsum(gt_ids[ii] == res_ids)
# focus on threshold points
n_ok = np.hstack(([0], n_ok))
counts[q, :, 2] = n_ok[nres]
pool = ThreadPool(20)
pool.map(compute_PR_for, range(nq))
# print(counts.transpose(2, 1, 0))
precisions = np.zeros(nt)
recalls = np.zeros(nt)
for t in range(nt):
p, r = counts_to_PR(
counts[:, t, 0], counts[:, t, 1], counts[:, t, 2],
mode=mode
)
precisions[t] = p
recalls[t] = r
return precisions, recalls
###############################################################
# Functions that compare search results with a reference result.
# They are intended for use in tests
def _cluster_tables_with_tolerance(tab1, tab2, thr):
""" for two tables, cluster them by merging values closer than thr.
Returns the cluster ids for each table element """
tab = np.hstack([tab1, tab2])
tab.sort()
n = len(tab)
diffs = np.ones(n)
diffs[1:] = tab[1:] - tab[:-1]
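    # keep one representative per run of values that are closer than thr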
unique_vals = tab[diffs > thr]
idx1 = np.searchsorted(unique_vals, tab1, side='right') - 1
idx2 = np.searchsorted(unique_vals, tab2, side='right') - 1
return idx1, idx2
def check_ref_knn_with_draws(Dref, Iref, Dnew, Inew, rtol=1e-5):
""" test that knn search results are identical, with possible ties.
Raise if not. """
np.testing.assert_allclose(Dref, Dnew, rtol=rtol)
# here we have to be careful because of draws
testcase = unittest.TestCase() # because it makes nice error messages
for i in range(len(Iref)):
if np.all(Iref[i] == Inew[i]): # easy case
continue
# otherwise collect elements per distance
r = rtol * Dref[i].max()
DrefC, DnewC = _cluster_tables_with_tolerance(Dref[i], Dnew[i], r)
for dis in np.unique(DrefC):
if dis == DrefC[-1]:
continue
mask = DrefC == dis
testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))
def check_ref_range_results(Lref, Dref, Iref,
Lnew, Dnew, Inew):
""" compare range search results wrt. a reference result,
throw if it fails """
np.testing.assert_array_equal(Lref, Lnew)
nq = len(Lref) - 1
for i in range(nq):
l0, l1 = Lref[i], Lref[i + 1]
Ii_ref = Iref[l0:l1]
Ii_new = Inew[l0:l1]
Di_ref = Dref[l0:l1]
Di_new = Dnew[l0:l1]
if np.all(Ii_ref == Ii_new): # easy
pass
else:
def sort_by_ids(I, D):
o = I.argsort()
return I[o], D[o]
# sort both
(Ii_ref, Di_ref) = sort_by_ids(Ii_ref, Di_ref)
(Ii_new, Di_new) = sort_by_ids(Ii_new, Di_new)
np.testing.assert_array_equal(Ii_ref, Ii_new)
np.testing.assert_array_almost_equal(Di_ref, Di_new, decimal=5)
###############################################################
# OperatingPoints functions
# this is the Python version of the AutoTune object in C++
class OperatingPoints:
"""
Manages a set of search parameters with associated performance and time.
Keeps the Pareto optimal points.
"""
def __init__(self):
# list of (key, perf, t)
self.operating_points = [
# (self.do_nothing_key(), 0.0, 0.0)
]
self.suboptimal_points = []
def compare_keys(self, k1, k2):
""" return -1 if k1 > k2, 1 if k2 > k1, 0 otherwise """
raise NotImplemented
def do_nothing_key(self):
""" parameters to say we do noting, takes 0 time and has 0 performance"""
raise NotImplemented
def is_pareto_optimal(self, perf_new, t_new):
for _, perf, t in self.operating_points:
if perf >= perf_new and t <= t_new:
return False
return True
def predict_bounds(self, key):
""" predicts the bound on time and performance """
min_time = 0.0
max_perf = 1.0
for key2, perf, t in self.operating_points + self.suboptimal_points:
cmp = self.compare_keys(key, key2)
            if cmp > 0:  # key >= key2
if t > min_time:
min_time = t
            if cmp < 0:  # key <= key2
if perf < max_perf:
max_perf = perf
return max_perf, min_time
def should_run_experiment(self, key):
(max_perf, min_time) = self.predict_bounds(key)
return self.is_pareto_optimal(max_perf, min_time)
def add_operating_point(self, key, perf, t):
if self.is_pareto_optimal(perf, t):
i = 0
# maybe it shadows some other operating point completely?
while i < len(self.operating_points):
op_Ls, perf2, t2 = self.operating_points[i]
if perf >= perf2 and t < t2:
self.suboptimal_points.append(
self.operating_points.pop(i))
else:
i += 1
self.operating_points.append((key, perf, t))
return True
else:
self.suboptimal_points.append((key, perf, t))
return False
class OperatingPointsWithRanges(OperatingPoints):
"""
Set of parameters that are each picked from a discrete range of values.
An increase of each parameter is assumed to make the operation slower
and more accurate.
A key = int array of indices in the ordered set of parameters.
"""
def __init__(self):
OperatingPoints.__init__(self)
# list of (name, values)
self.ranges = []
def add_range(self, name, values):
self.ranges.append((name, values))
def compare_keys(self, k1, k2):
if np.all(k1 >= k2):
return 1
if np.all(k2 >= k1):
return -1
return 0
def do_nothing_key(self):
return np.zeros(len(self.ranges), dtype=int)
def num_experiments(self):
return int(np.prod([len(values) for name, values in self.ranges]))
def sample_experiments(self, n_autotune, rs=np.random):
""" sample a set of experiments of max size n_autotune
(run all experiments in random order if n_autotune is 0)
"""
assert n_autotune == 0 or n_autotune >= 2
totex = self.num_experiments()
        rs = np.random.RandomState(123)  # fixed seed, overrides the rs argument for reproducibility
if n_autotune == 0 or totex < n_autotune:
experiments = rs.permutation(totex - 2)
else:
experiments = rs.choice(
totex - 2, size=n_autotune - 2, replace=False)
experiments = [0, totex - 1] + [int(cno) + 1 for cno in experiments]
return experiments
def cno_to_key(self, cno):
"""Convert a sequential experiment number to a key"""
k = np.zeros(len(self.ranges), dtype=int)
for i, (name, values) in enumerate(self.ranges):
k[i] = cno % len(values)
cno //= len(values)
assert cno == 0
return k
def get_parameters(self, k):
"""Convert a key to a dictionary with parameter values"""
return {
name: values[k[i]]
for i, (name, values) in enumerate(self.ranges)
}
def restrict_range(self, name, max_val):
""" remove too large values from a range"""
for name2, values in self.ranges:
if name == name2:
val2 = [v for v in values if v < max_val]
values[:] = val2
return
raise RuntimeError(f"parameter {name} not found")
###############################################################
# Timer object
class TimerIter:
def __init__(self, timer):
self.ts = []
self.runs = timer.runs
self.timer = timer
if timer.nt >= 0:
faiss.omp_set_num_threads(timer.nt)
def __next__(self):
timer = self.timer
self.runs -= 1
self.ts.append(time.time())
total_time = self.ts[-1] - self.ts[0] if len(self.ts) >= 2 else 0
if self.runs == -1 or total_time > timer.max_secs:
if timer.nt >= 0:
faiss.omp_set_num_threads(timer.remember_nt)
ts = np.array(self.ts)
times = ts[1:] - ts[:-1]
if len(times) == timer.runs:
timer.times = times[timer.warmup :]
else:
# if timeout, we use all the runs
timer.times = times[:]
raise StopIteration
class RepeatTimer:
"""
This is yet another timer object. It is adapted to Faiss by
taking a number of openmp threads to set on input. It should be called
in an explicit loop as:
timer = RepeatTimer(warmup=1, nt=1, runs=6)
for _ in timer:
# perform operation
print(f"time={timer.get_ms():.1f} ± {timer.get_ms_std():.1f} ms")
the same timer can be re-used. In that case it is reset each time it
enters a loop. It focuses on ms-scale times because for second scale
it's usually less relevant to repeat the operation.
"""
def __init__(self, warmup=0, nt=-1, runs=1, max_secs=np.inf):
assert warmup < runs
self.warmup = warmup
self.nt = nt
self.runs = runs
self.max_secs = max_secs
self.remember_nt = faiss.omp_get_max_threads()
def __iter__(self):
return TimerIter(self)
def ms(self):
return np.mean(self.times) * 1000
def ms_std(self):
return np.std(self.times) * 1000 if len(self.times) > 1 else 0.0
def nruns(self):
""" effective number of runs (may be lower than runs - warmup due to timeout)"""
return len(self.times)

View File

@@ -0,0 +1,367 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import faiss
import time
import numpy as np
import logging
LOG = logging.getLogger(__name__)
def knn_ground_truth(xq, db_iterator, k, metric_type=faiss.METRIC_L2, shard=False, ngpu=-1):
"""Computes the exact KNN search results for a dataset that possibly
does not fit in RAM but for which we have an iterator that
returns it block by block.
"""
LOG.info("knn_ground_truth queries size %s k=%d" % (xq.shape, k))
t0 = time.time()
nq, d = xq.shape
keep_max = faiss.is_similarity_metric(metric_type)
rh = faiss.ResultHeap(nq, k, keep_max=keep_max)
index = faiss.IndexFlat(d, metric_type)
if ngpu == -1:
ngpu = faiss.get_num_gpus()
if ngpu:
LOG.info('running on %d GPUs' % ngpu)
co = faiss.GpuMultipleClonerOptions()
co.shard = shard
index = faiss.index_cpu_to_all_gpus(index, co=co, ngpu=ngpu)
# compute ground-truth by blocks, and add to heaps
i0 = 0
for xbi in db_iterator:
ni = xbi.shape[0]
index.add(xbi)
D, I = index.search(xq, k)
I += i0
rh.add_result(D, I)
index.reset()
i0 += ni
LOG.info("%d db elements, %.3f s" % (i0, time.time() - t0))
rh.finalize()
LOG.info("GT time: %.3f s (%d vectors)" % (time.time() - t0, i0))
return rh.D, rh.I
# knn function used to be here
knn = faiss.knn
def range_search_gpu(xq, r2, index_gpu, index_cpu, gpu_k=1024):
"""GPU does not support range search, so we emulate it with
knn search + fallback to CPU index.
The index_cpu can either be:
- a CPU index that supports range search
- a numpy table, that will be used to construct a Flat index if needed.
- None. In that case, at most gpu_k results will be returned
"""
nq, d = xq.shape
is_binary_index = isinstance(index_gpu, faiss.IndexBinary)
keep_max = faiss.is_similarity_metric(index_gpu.metric_type)
r2 = int(r2) if is_binary_index else float(r2)
k = min(index_gpu.ntotal, gpu_k)
LOG.debug(
f"GPU search {nq} queries with {k=:} {is_binary_index=:} {keep_max=:}")
t0 = time.time()
D, I = index_gpu.search(xq, k)
t1 = time.time() - t0
if is_binary_index:
assert d * 8 < 32768 # let's compact the distance matrix
D = D.astype('int16')
t2 = 0
lim_remain = None
if index_cpu is not None:
if not keep_max:
mask = D[:, k - 1] < r2
else:
mask = D[:, k - 1] > r2
if mask.sum() > 0:
LOG.debug("CPU search remain %d" % mask.sum())
t0 = time.time()
if isinstance(index_cpu, np.ndarray):
                # it is in fact an array from which we build a Flat index
xb = index_cpu
if is_binary_index:
index_cpu = faiss.IndexBinaryFlat(d * 8)
else:
index_cpu = faiss.IndexFlat(d, index_gpu.metric_type)
index_cpu.add(xb)
lim_remain, D_remain, I_remain = index_cpu.range_search(xq[mask], r2)
if is_binary_index:
D_remain = D_remain.astype('int16')
t2 = time.time() - t0
LOG.debug("combine")
t0 = time.time()
CombinerRangeKNN = (
faiss.CombinerRangeKNNint16 if is_binary_index else
faiss.CombinerRangeKNNfloat
)
combiner = CombinerRangeKNN(nq, k, r2, keep_max)
    # the C++ CombinerRangeKNN merges the kNN and range results; the else
    # branch below is an equivalent pure-Python version, kept for reference
    if True:
sp = faiss.swig_ptr
combiner.I = sp(I)
combiner.D = sp(D)
# combiner.set_knn_result(sp(I), sp(D))
if lim_remain is not None:
combiner.mask = sp(mask)
combiner.D_remain = sp(D_remain)
combiner.lim_remain = sp(lim_remain.view("int64"))
combiner.I_remain = sp(I_remain)
# combiner.set_range_result(sp(mask), sp(lim_remain.view("int64")), sp(D_remain), sp(I_remain))
L_res = np.empty(nq + 1, dtype='int64')
combiner.compute_sizes(sp(L_res))
nres = L_res[-1]
D_res = np.empty(nres, dtype=D.dtype)
I_res = np.empty(nres, dtype='int64')
combiner.write_result(sp(D_res), sp(I_res))
else:
D_res, I_res = [], []
nr = 0
for i in range(nq):
if not mask[i]:
if index_gpu.metric_type == faiss.METRIC_L2:
nv = (D[i, :] < r2).sum()
else:
nv = (D[i, :] > r2).sum()
D_res.append(D[i, :nv])
I_res.append(I[i, :nv])
else:
l0, l1 = lim_remain[nr], lim_remain[nr + 1]
D_res.append(D_remain[l0:l1])
I_res.append(I_remain[l0:l1])
nr += 1
L_res = np.cumsum([0] + [len(di) for di in D_res])
D_res = np.hstack(D_res)
I_res = np.hstack(I_res)
t3 = time.time() - t0
LOG.debug(f"times {t1:.3f}s {t2:.3f}s {t3:.3f}s")
return L_res, D_res, I_res
def range_ground_truth(xq, db_iterator, threshold, metric_type=faiss.METRIC_L2,
shard=False, ngpu=-1):
"""Computes the range-search search results for a dataset that possibly
does not fit in RAM but for which we have an iterator that
returns it block by block.
"""
nq, d = xq.shape
t0 = time.time()
xq = np.ascontiguousarray(xq, dtype='float32')
index = faiss.IndexFlat(d, metric_type)
if ngpu == -1:
ngpu = faiss.get_num_gpus()
if ngpu:
LOG.info('running on %d GPUs' % ngpu)
co = faiss.GpuMultipleClonerOptions()
co.shard = shard
index_gpu = faiss.index_cpu_to_all_gpus(index, co=co, ngpu=ngpu)
# compute ground-truth by blocks
i0 = 0
D = [[] for _i in range(nq)]
I = [[] for _i in range(nq)]
for xbi in db_iterator:
ni = xbi.shape[0]
if ngpu > 0:
index_gpu.add(xbi)
lims_i, Di, Ii = range_search_gpu(xq, threshold, index_gpu, xbi)
index_gpu.reset()
else:
index.add(xbi)
lims_i, Di, Ii = index.range_search(xq, threshold)
index.reset()
Ii += i0
for j in range(nq):
l0, l1 = lims_i[j], lims_i[j + 1]
if l1 > l0:
D[j].append(Di[l0:l1])
I[j].append(Ii[l0:l1])
i0 += ni
LOG.info("%d db elements, %.3f s" % (i0, time.time() - t0))
empty_I = np.zeros(0, dtype='int64')
empty_D = np.zeros(0, dtype='float32')
D = [(np.hstack(i) if i != [] else empty_D) for i in D]
I = [(np.hstack(i) if i != [] else empty_I) for i in I]
sizes = [len(i) for i in I]
assert len(sizes) == nq
lims = np.zeros(nq + 1, dtype="uint64")
lims[1:] = np.cumsum(sizes)
return lims, np.hstack(D), np.hstack(I)
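# Usage sketch (illustrative, not part of the library): compute range-search
# ground truth for a dataset streamed in blocks. The data, block size and
# threshold below are assumptions for demonstration.
def _demo_range_ground_truth():
    xb = np.random.rand(100_000, 64).astype('float32')
    xq = np.random.rand(100, 64).astype('float32')
    blocks = (xb[i:i + 10_000] for i in range(0, len(xb), 10_000))
    lims, D, I = range_ground_truth(xq, blocks, threshold=1.5, ngpu=0)
    # the results for query j are I[lims[j]:lims[j + 1]]
    return lims, D, I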
def threshold_radius_nres(nres, dis, ids, thresh, keep_max=False):
""" select a set of results """
if keep_max:
mask = dis > thresh
else:
mask = dis < thresh
new_nres = np.zeros_like(nres)
o = 0
for i, nr in enumerate(nres):
nr = int(nr) # avoid issues with int64 + uint64
new_nres[i] = mask[o:o + nr].sum()
o += nr
return new_nres, dis[mask], ids[mask]
def threshold_radius(lims, dis, ids, thresh, keep_max=False):
""" restrict range-search results to those below a given radius """
if keep_max:
mask = dis > thresh
else:
mask = dis < thresh
new_lims = np.zeros_like(lims)
n = len(lims) - 1
for i in range(n):
l0, l1 = lims[i], lims[i + 1]
new_lims[i + 1] = new_lims[i] + mask[l0:l1].sum()
return new_lims, dis[mask], ids[mask]
def apply_maxres(res_batches, target_nres, keep_max=False):
"""find radius that reduces number of results to target_nres, and
applies it in-place to the result batches used in
range_search_max_results"""
alldis = np.hstack([dis for _, dis, _ in res_batches])
assert len(alldis) > target_nres
if keep_max:
alldis.partition(len(alldis) - target_nres - 1)
radius = alldis[-1 - target_nres]
else:
alldis.partition(target_nres)
radius = alldis[target_nres]
if alldis.dtype == 'float32':
radius = float(radius)
else:
radius = int(radius)
LOG.debug(' setting radius to %s' % radius)
totres = 0
for i, (nres, dis, ids) in enumerate(res_batches):
nres, dis, ids = threshold_radius_nres(
nres, dis, ids, radius, keep_max=keep_max)
totres += len(dis)
res_batches[i] = nres, dis, ids
LOG.debug(' updated previous results, new nb results %d' % totres)
return radius, totres
def range_search_max_results(index, query_iterator, radius,
max_results=None, min_results=None,
shard=False, ngpu=0, clip_to_min=False):
"""Performs a range search with many queries (given by an iterator)
and adjusts the threshold on-the-fly so that the total results
table does not grow larger than max_results.
If ngpu != 0, the function moves the index to this many GPUs to
speed up search.
"""
# TODO: all result manipulations are in python, should move to C++ if perf
# critical
is_binary_index = isinstance(index, faiss.IndexBinary)
if min_results is None:
assert max_results is not None
min_results = int(0.8 * max_results)
if max_results is None:
assert min_results is not None
max_results = int(min_results * 1.5)
if ngpu == -1:
ngpu = faiss.get_num_gpus()
if ngpu:
LOG.info('running on %d GPUs' % ngpu)
co = faiss.GpuMultipleClonerOptions()
co.shard = shard
index_gpu = faiss.index_cpu_to_all_gpus(index, co=co, ngpu=ngpu)
else:
index_gpu = None
t_start = time.time()
t_search = t_post_process = 0
qtot = totres = raw_totres = 0
res_batches = []
for xqi in query_iterator:
t0 = time.time()
LOG.debug(f"searching {len(xqi)} vectors")
if index_gpu:
lims_i, Di, Ii = range_search_gpu(xqi, radius, index_gpu, index)
else:
lims_i, Di, Ii = index.range_search(xqi, radius)
nres_i = lims_i[1:] - lims_i[:-1]
raw_totres += len(Di)
qtot += len(xqi)
t1 = time.time()
if is_binary_index:
# weird Faiss quirk that returns floats for Hamming distances
Di = Di.astype('int16')
totres += len(Di)
res_batches.append((nres_i, Di, Ii))
if max_results is not None and totres > max_results:
LOG.info('too many results %d > %d, scaling back radius' %
(totres, max_results))
radius, totres = apply_maxres(
res_batches, min_results,
keep_max=index.metric_type == faiss.METRIC_INNER_PRODUCT
)
t2 = time.time()
t_search += t1 - t0
t_post_process += t2 - t1
LOG.debug(' [%.3f s] %d queries done, %d results' % (
time.time() - t_start, qtot, totres))
LOG.info(
'search done in %.3f s + %.3f s, total %d results, end threshold %g' % (
t_search, t_post_process, totres, radius)
)
if clip_to_min and totres > min_results:
radius, totres = apply_maxres(
res_batches, min_results,
keep_max=index.metric_type == faiss.METRIC_INNER_PRODUCT
)
nres = np.hstack([nres_i for nres_i, dis_i, ids_i in res_batches])
dis = np.hstack([dis_i for nres_i, dis_i, ids_i in res_batches])
ids = np.hstack([ids_i for nres_i, dis_i, ids_i in res_batches])
lims = np.zeros(len(nres) + 1, dtype='uint64')
lims[1:] = np.cumsum(nres)
return radius, lims, dis, ids
def exponential_query_iterator(xq, start_bs=32, max_bs=20000):
""" produces batches of progressively increasing sizes. This is useful to
adjust the search radius progressively without overflowing with
intermediate results """
nq = len(xq)
bs = start_bs
i = 0
while i < nq:
xqi = xq[i:i + bs]
yield xqi
if bs < max_bs:
bs *= 2
i += len(xqi)
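# Usage sketch (illustrative): combine the exponential iterator with
# range_search_max_results so the radius is adjusted on small batches first.
# The index, queries and initial radius (for an L2 index) are assumptions.
def _demo_range_search_max_results(index, xq):
    radius, lims, dis, ids = range_search_max_results(
        index, exponential_query_iterator(xq), radius=1e10,
        max_results=10_000_000)
    return radius, lims, dis, ids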

View File

@@ -0,0 +1,149 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import faiss
import re
def get_code_size(d, indexkey):
""" size of one vector in an index in dimension d
constructed with factory string indexkey"""
if indexkey == "Flat":
return d * 4
if indexkey.endswith(",RFlat"):
return d * 4 + get_code_size(d, indexkey[:-len(",RFlat")])
mo = re.match("IVF\\d+(_HNSW32)?,(.*)$", indexkey)
if mo:
return get_code_size(d, mo.group(2))
mo = re.match("IVF\\d+\\(.*\\)?,(.*)$", indexkey)
if mo:
return get_code_size(d, mo.group(1))
mo = re.match("IMI\\d+x2,(.*)$", indexkey)
if mo:
return get_code_size(d, mo.group(1))
mo = re.match("(.*),Refine\\((.*)\\)$", indexkey)
if mo:
return get_code_size(d, mo.group(1)) + get_code_size(d, mo.group(2))
mo = re.match('PQ(\\d+)x(\\d+)(fs|fsr)?$', indexkey)
if mo:
return (int(mo.group(1)) * int(mo.group(2)) + 7) // 8
mo = re.match('PQ(\\d+)\\+(\\d+)$', indexkey)
if mo:
return (int(mo.group(1)) + int(mo.group(2)))
mo = re.match('PQ(\\d+)$', indexkey)
if mo:
return int(mo.group(1))
if indexkey == "HNSW32" or indexkey == "HNSW32,Flat":
return d * 4 + 64 * 4 # roughly
if indexkey == 'SQ8':
return d
elif indexkey == 'SQ4':
return (d + 1) // 2
elif indexkey == 'SQ6':
return (d * 6 + 7) // 8
elif indexkey == 'SQfp16':
return d * 2
elif indexkey == 'SQbf16':
return d * 2
mo = re.match('PCAR?(\\d+),(.*)$', indexkey)
if mo:
return get_code_size(int(mo.group(1)), mo.group(2))
mo = re.match('OPQ\\d+_(\\d+),(.*)$', indexkey)
if mo:
return get_code_size(int(mo.group(1)), mo.group(2))
mo = re.match('OPQ\\d+,(.*)$', indexkey)
if mo:
return get_code_size(d, mo.group(1))
mo = re.match('RR(\\d+),(.*)$', indexkey)
if mo:
return get_code_size(int(mo.group(1)), mo.group(2))
raise RuntimeError("cannot parse " + indexkey)
def get_hnsw_M(index):
return index.hnsw.cum_nneighbor_per_level.at(1) // 2
def reverse_index_factory(index):
"""
attempts to get the factory string the index was built with
"""
index = faiss.downcast_index(index)
if isinstance(index, faiss.IndexFlat):
return "Flat"
elif isinstance(index, faiss.IndexIVF):
quantizer = faiss.downcast_index(index.quantizer)
if isinstance(quantizer, faiss.IndexFlat):
prefix = f"IVF{index.nlist}"
elif isinstance(quantizer, faiss.MultiIndexQuantizer):
prefix = f"IMI{quantizer.pq.M}x{quantizer.pq.nbits}"
elif isinstance(quantizer, faiss.IndexHNSW):
prefix = f"IVF{index.nlist}_HNSW{get_hnsw_M(quantizer)}"
else:
prefix = f"IVF{index.nlist}({reverse_index_factory(quantizer)})"
if isinstance(index, faiss.IndexIVFFlat):
return prefix + ",Flat"
        if isinstance(index, faiss.IndexIVFScalarQuantizer):
            # note: assumes an 8-bit scalar quantizer
            return prefix + ",SQ8"
if isinstance(index, faiss.IndexIVFPQ):
return prefix + f",PQ{index.pq.M}x{index.pq.nbits}"
if isinstance(index, faiss.IndexIVFPQFastScan):
return prefix + f",PQ{index.pq.M}x{index.pq.nbits}fs"
elif isinstance(index, faiss.IndexPreTransform):
if index.chain.size() != 1:
raise NotImplementedError()
vt = faiss.downcast_VectorTransform(index.chain.at(0))
if isinstance(vt, faiss.OPQMatrix):
prefix = f"OPQ{vt.M}_{vt.d_out}"
elif isinstance(vt, faiss.ITQTransform):
prefix = f"ITQ{vt.itq.d_out}"
elif isinstance(vt, faiss.PCAMatrix):
assert vt.eigen_power == 0
prefix = "PCA" + ("R" if vt.random_rotation else "") + str(vt.d_out)
else:
raise NotImplementedError()
return f"{prefix},{reverse_index_factory(index.index)}"
elif isinstance(index, faiss.IndexHNSW):
return f"HNSW{get_hnsw_M(index)}"
elif isinstance(index, faiss.IndexRefine):
return f"{reverse_index_factory(index.base_index)},Refine({reverse_index_factory(index.refine_index)})"
elif isinstance(index, faiss.IndexPQFastScan):
return f"PQ{index.pq.M}x{index.pq.nbits}fs"
elif isinstance(index, faiss.IndexPQ):
return f"PQ{index.pq.M}x{index.pq.nbits}"
elif isinstance(index, faiss.IndexLSH):
return "LSH" + ("r" if index.rotate_data else "") + ("t" if index.train_thresholds else "")
elif isinstance(index, faiss.IndexScalarQuantizer):
sqtypes = {
faiss.ScalarQuantizer.QT_8bit: "8",
faiss.ScalarQuantizer.QT_4bit: "4",
faiss.ScalarQuantizer.QT_6bit: "6",
faiss.ScalarQuantizer.QT_fp16: "fp16",
faiss.ScalarQuantizer.QT_bf16: "bf16",
}
return f"SQ{sqtypes[index.sq.qtype]}"
raise NotImplementedError()
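# Usage sketch (illustrative): round-trip between a factory string and an
# index; the dimension and factory string are assumptions for demonstration.
def _demo_reverse_index_factory():
    index = faiss.index_factory(64, "IVF100,PQ8x8")
    assert reverse_index_factory(index) == "IVF100,PQ8x8"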

View File

@@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import faiss
def get_invlist(invlists, l):
""" returns the inverted lists content as a pair of (list_ids, list_codes).
The codes are reshaped to a proper size
"""
invlists = faiss.downcast_InvertedLists(invlists)
ls = invlists.list_size(l)
list_ids = np.zeros(ls, dtype='int64')
ids = codes = None
try:
ids = invlists.get_ids(l)
if ls > 0:
faiss.memcpy(faiss.swig_ptr(list_ids), ids, list_ids.nbytes)
codes = invlists.get_codes(l)
if invlists.code_size != faiss.InvertedLists.INVALID_CODE_SIZE:
list_codes = np.zeros((ls, invlists.code_size), dtype='uint8')
else:
# it's a BlockInvertedLists
npb = invlists.n_per_block
bs = invlists.block_size
ls_round = (ls + npb - 1) // npb
list_codes = np.zeros((ls_round, bs // npb, npb), dtype='uint8')
if ls > 0:
faiss.memcpy(faiss.swig_ptr(list_codes), codes, list_codes.nbytes)
finally:
if ids is not None:
invlists.release_ids(l, ids)
if codes is not None:
invlists.release_codes(l, codes)
return list_ids, list_codes
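# Usage sketch (illustrative): dump the first non-empty inverted list of a
# populated IVF index; the index is an assumption here.
def _demo_get_invlist(index_ivf):
    for l in range(index_ivf.nlist):
        if index_ivf.invlists.list_size(l) > 0:
            list_ids, list_codes = get_invlist(index_ivf.invlists, l)
            print(f"list {l}: {len(list_ids)} ids, codes of shape {list_codes.shape}")
            break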
def get_invlist_sizes(invlists):
""" return the array of sizes of the inverted lists """
return np.array([
invlists.list_size(i)
for i in range(invlists.nlist)
], dtype='int64')
def print_object_fields(obj):
""" list values all fields of an object known to SWIG """
for name in obj.__class__.__swig_getmethods__:
print(f"{name} = {getattr(obj, name)}")
def get_pq_centroids(pq):
""" return the PQ centroids as an array """
cen = faiss.vector_to_array(pq.centroids)
return cen.reshape(pq.M, pq.ksub, pq.dsub)
def get_LinearTransform_matrix(pca):
""" extract matrix + bias from the PCA object
works for any linear transform (OPQ, random rotation, etc.)
"""
b = faiss.vector_to_array(pca.b)
A = faiss.vector_to_array(pca.A).reshape(pca.d_out, pca.d_in)
return A, b
def make_LinearTransform_matrix(A, b=None):
""" make a linear transform from a matrix and a bias term (optional)"""
d_out, d_in = A.shape
if b is not None:
assert b.shape == (d_out, )
lt = faiss.LinearTransform(d_in, d_out, b is not None)
faiss.copy_array_to_vector(A.ravel(), lt.A)
if b is not None:
faiss.copy_array_to_vector(b, lt.b)
lt.is_trained = True
lt.set_is_orthonormal()
return lt
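# Usage sketch (illustrative): extract the matrix of a trained PCA and
# rebuild an equivalent LinearTransform from it; x is an assumed float32
# training matrix with at least 32 columns.
def _demo_linear_transform_roundtrip(x):
    pca = faiss.PCAMatrix(x.shape[1], 32)
    pca.train(x)
    A, b = get_LinearTransform_matrix(pca)
    lt = make_LinearTransform_matrix(A, b)
    # lt.apply(x) should match pca.apply(x) up to float rounding
    return lt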
def get_additive_quantizer_codebooks(aq):
""" return to codebooks of an additive quantizer """
codebooks = faiss.vector_to_array(aq.codebooks).reshape(-1, aq.d)
co = faiss.vector_to_array(aq.codebook_offsets)
return [
codebooks[co[i]:co[i + 1]]
for i in range(aq.M)
]
def get_flat_data(index):
""" copy and return the data matrix in an IndexFlat """
xb = faiss.vector_to_array(index.codes).view("float32")
return xb.reshape(index.ntotal, index.d)
def get_flat_codes(index_flat):
""" get the codes from an indexFlatCodes as an array """
return faiss.vector_to_array(index_flat.codes).reshape(
index_flat.ntotal, index_flat.code_size)
def get_NSG_neighbors(nsg):
""" get the neighbor list for the vectors stored in the NSG structure, as
a N-by-K matrix of indices """
graph = nsg.get_final_graph()
neighbors = np.zeros((graph.N, graph.K), dtype='int32')
faiss.memcpy(
faiss.swig_ptr(neighbors),
graph.data,
neighbors.nbytes
)
return neighbors

View File

@@ -0,0 +1,148 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import faiss
from faiss.contrib.inspect_tools import get_invlist_sizes
def add_preassigned(index_ivf, x, a, ids=None):
"""
Add elements to an IVF index, where the assignment is already computed
"""
n, d = x.shape
assert a.shape == (n, )
if isinstance(index_ivf, faiss.IndexBinaryIVF):
d *= 8
assert d == index_ivf.d
if ids is not None:
assert ids.shape == (n, )
ids = faiss.swig_ptr(ids)
index_ivf.add_core(
n, faiss.swig_ptr(x), ids, faiss.swig_ptr(a)
)
def search_preassigned(index_ivf, xq, k, list_nos, coarse_dis=None):
"""
Perform a search in the IVF index, with predefined lists to search into.
Supports indexes with pretransforms (as opposed to the
IndexIVF.search_preassigned, that cannot be applied with pretransform).
"""
if isinstance(index_ivf, faiss.IndexPreTransform):
assert index_ivf.chain.size() == 1, "chain must have only one component"
transform = faiss.downcast_VectorTransform(index_ivf.chain.at(0))
xq = transform.apply(xq)
index_ivf = faiss.downcast_index(index_ivf.index)
n, d = xq.shape
if isinstance(index_ivf, faiss.IndexBinaryIVF):
d *= 8
dis_type = "int32"
else:
dis_type = "float32"
assert d == index_ivf.d
assert list_nos.shape == (n, index_ivf.nprobe)
    # the coarse distances are used by IVFPQ with L2 distance and
    # by_residual=True; otherwise we provide dummy coarse_dis
if coarse_dis is None:
coarse_dis = np.zeros((n, index_ivf.nprobe), dtype=dis_type)
else:
assert coarse_dis.shape == (n, index_ivf.nprobe)
return index_ivf.search_preassigned(xq, k, list_nos, coarse_dis)
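# Usage sketch (illustrative): run the coarse quantizer manually, then search
# only the preassigned lists. Assumes a plain (non-pretransform) IndexIVF.
def _demo_search_preassigned(index_ivf, xq, k=10):
    coarse_dis, list_nos = index_ivf.quantizer.search(xq, index_ivf.nprobe)
    return search_preassigned(index_ivf, xq, k, list_nos, coarse_dis)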
def range_search_preassigned(index_ivf, x, radius, list_nos, coarse_dis=None):
"""
Perform a range search in the IVF index, with predefined lists to
search into
"""
n, d = x.shape
if isinstance(index_ivf, faiss.IndexBinaryIVF):
d *= 8
dis_type = "int32"
else:
dis_type = "float32"
    # the coarse distances are used by IVFPQ with L2 distance and
    # by_residual=True; otherwise we provide dummy coarse_dis
if coarse_dis is None:
coarse_dis = np.empty((n, index_ivf.nprobe), dtype=dis_type)
else:
assert coarse_dis.shape == (n, index_ivf.nprobe)
assert d == index_ivf.d
assert list_nos.shape == (n, index_ivf.nprobe)
res = faiss.RangeSearchResult(n)
sp = faiss.swig_ptr
index_ivf.range_search_preassigned_c(
n, sp(x), radius,
sp(list_nos), sp(coarse_dis),
res
)
# get pointers and copy them
lims = faiss.rev_swig_ptr(res.lims, n + 1).copy()
num_results = int(lims[-1])
dist = faiss.rev_swig_ptr(res.distances, num_results).copy()
indices = faiss.rev_swig_ptr(res.labels, num_results).copy()
return lims, dist, indices
def replace_ivf_quantizer(index_ivf, new_quantizer):
""" replace the IVF quantizer with a flat quantizer and return the
old quantizer"""
if new_quantizer.ntotal == 0:
centroids = index_ivf.quantizer.reconstruct_n()
new_quantizer.train(centroids)
new_quantizer.add(centroids)
else:
assert new_quantizer.ntotal == index_ivf.nlist
# cleanly dealloc old quantizer
old_own = index_ivf.own_fields
index_ivf.own_fields = False
old_quantizer = faiss.downcast_index(index_ivf.quantizer)
old_quantizer.this.own(old_own)
index_ivf.quantizer = new_quantizer
if hasattr(index_ivf, "referenced_objects"):
index_ivf.referenced_objects.append(new_quantizer)
else:
index_ivf.referenced_objects = [new_quantizer]
return old_quantizer
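# Usage sketch (illustrative): replace the flat quantizer of a trained IVF
# index with an HNSW quantizer built on the same centroids, to speed up
# coarse assignment. The HNSW parameter M=32 is an assumption.
def _demo_replace_quantizer_with_hnsw(index_ivf):
    hnsw_quantizer = faiss.IndexHNSWFlat(index_ivf.d, 32)
    return replace_ivf_quantizer(index_ivf, hnsw_quantizer)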
def permute_invlists(index_ivf, perm):
""" Apply some permutation to the inverted lists, and modify the quantizer
entries accordingly.
Perm is an array of size nlist, where old_index = perm[new_index]
"""
nlist, = perm.shape
assert index_ivf.nlist == nlist
quantizer = faiss.downcast_index(index_ivf.quantizer)
assert quantizer.ntotal == index_ivf.nlist
perm = np.ascontiguousarray(perm, dtype='int64')
# just make sure it's a permutation...
bc = np.bincount(perm, minlength=nlist)
assert np.all(bc == np.ones(nlist, dtype=int))
# handle quantizer
quantizer.permute_entries(perm)
# handle inverted lists
invlists = faiss.downcast_InvertedLists(index_ivf.invlists)
invlists.permute_invlists(faiss.swig_ptr(perm))
def sort_invlists_by_size(index_ivf):
invlist_sizes = get_invlist_sizes(index_ivf.invlists)
perm = np.argsort(invlist_sizes)
permute_invlists(index_ivf, perm)

View File

@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
import faiss
import logging
LOG = logging.getLogger(__name__)
def merge_ondisk(
trained_index: faiss.Index, shard_fnames: List[str], ivfdata_fname: str, shift_ids=False
) -> None:
"""Add the contents of the indexes stored in shard_fnames into the index
trained_index. The on-disk data is stored in ivfdata_fname"""
assert not isinstance(
trained_index, faiss.IndexIVFPQR
), "IndexIVFPQR is not supported as an on disk index."
    # merge the shards into an on-disk index
# first load the inverted lists
ivfs = []
for fname in shard_fnames:
# the IO_FLAG_MMAP is to avoid actually loading the data thus
# the total size of the inverted lists can exceed the
# available RAM
LOG.info("read " + fname)
index = faiss.read_index(fname, faiss.IO_FLAG_MMAP)
index_ivf = faiss.extract_index_ivf(index)
ivfs.append(index_ivf.invlists)
# avoid that the invlists get deallocated with the index
index_ivf.own_invlists = False
# construct the output index
index = trained_index
index_ivf = faiss.extract_index_ivf(index)
assert index.ntotal == 0, "works only on empty index"
# prepare the output inverted lists. They will be written
# to merged_index.ivfdata
invlists = faiss.OnDiskInvertedLists(
index_ivf.nlist, index_ivf.code_size, ivfdata_fname
)
# merge all the inverted lists
ivf_vector = faiss.InvertedListsPtrVector()
for ivf in ivfs:
ivf_vector.push_back(ivf)
LOG.info("merge %d inverted lists " % ivf_vector.size())
ntotal = invlists.merge_from_multiple(ivf_vector.data(), ivf_vector.size(), shift_ids)
# now replace the inverted lists in the output index
index.ntotal = index_ivf.ntotal = ntotal
index_ivf.replace_invlists(invlists, True)
invlists.this.disown()
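# Usage sketch (illustrative): merge sharded indexes into one on-disk index.
# The file names are assumptions for demonstration.
def _demo_merge_ondisk(trained_index_fname, shard_fnames):
    index = faiss.read_index(trained_index_fname)
    merge_ondisk(index, shard_fnames, "merged_index.ivfdata")
    faiss.write_index(index, "merged_index.index")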

View File

@@ -0,0 +1,258 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Simplistic RPC implementation.
Exposes all functions of a Server object.
This code is for demonstration purposes only, and does not include certain
security protections. It is not meant to be run on an untrusted network or
in a production environment.
"""
import importlib
import errno
import os
import pickle
import sys
import _thread
import traceback
import socket
import logging
LOG = logging.getLogger(__name__)
# default
PORT = 12032
safe_modules = {
'numpy',
'numpy.core.multiarray',
}
class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
# Only allow safe modules.
if module in safe_modules:
return getattr(importlib.import_module(module), name)
# Forbid everything else.
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
(module, name))
class FileSock:
    " wraps a socket so that it is usable by pickle "
    def __init__(self, sock):
        self.sock = sock
        self.nr = 0
    def write(self, buf):
        # send in chunks of at most 512 kiB
        bs = 512 * 1024
        ns = 0
        while ns < len(buf):
            sent = self.sock.send(buf[ns:ns + bs])
            ns += sent
    def read(self, bs=512 * 1024):
        self.nr += 1
        b = []
        nb = 0
        while nb < bs:
            rb = self.sock.recv(bs - nb)
            if not rb:
                break
            b.append(rb)
            nb += len(rb)
        return b''.join(b)
    def readline(self):
        """ read one byte at a time until a newline (may be optimized...) """
        s = bytes()
        while True:
            c = self.read(1)
            s += c
            if len(c) == 0 or chr(c[0]) == '\n':
                return s
class ClientExit(Exception):
pass
class ServerException(Exception):
pass
class Server:
"""
server protocol. Methods from classes that subclass Server can be called
transparently from a client
"""
def __init__(self, s, logf=sys.stderr, log_prefix=''):
self.logf = logf
self.log_prefix = log_prefix
# connection
self.conn = s
self.fs = FileSock(s)
def log(self, s):
self.logf.write("Sever log %s: %s\n" % (self.log_prefix, s))
def one_function(self):
"""
Executes a single function with associated I/O.
Protocol:
- the arguments and results are serialized with the pickle protocol
- client sends : (fname,args)
fname = method name to call
args = tuple of arguments
        - server sends result: (st, ret)
            st = None, or the exception raised during execution
            ret = the return value, or None if st is not None
"""
try:
(fname, args) = RestrictedUnpickler(self.fs).load()
except EOFError:
raise ClientExit("read args")
self.log("executing method %s"%(fname))
st = None
ret = None
try:
f=getattr(self,fname)
except AttributeError:
st = AttributeError("unknown method "+fname)
self.log("unknown method")
try:
ret = f(*args)
except Exception as e:
# due to a bug (in mod_python?), ServerException cannot be
# unpickled, so send the string and make the exception on the client side
#st=ServerException(
# "".join(traceback.format_tb(sys.exc_info()[2]))+
# str(e))
st="".join(traceback.format_tb(sys.exc_info()[2]))+str(e)
self.log("exception in method")
traceback.print_exc(50,self.logf)
self.logf.flush()
LOG.info("return")
try:
            pickle.dump((st, ret), self.fs, protocol=4)
except EOFError:
raise ClientExit("function return")
def exec_loop(self):
""" main execution loop. Loops and handles exit states"""
self.log("in exec_loop")
try:
while True:
self.one_function()
        except ClientExit as e:
            self.log("ClientExit %s" % e)
        except socket.error as e:
            self.log("socket error %s" % e)
            traceback.print_exc(50, self.logf)
        except EOFError:
            self.log("EOF during communication")
            traceback.print_exc(50, self.logf)
        except BaseException:
            # unexpected
            traceback.print_exc(50, sys.stderr)
            sys.exit(1)
        LOG.info("exit server")
def exec_loop_cleanup(self):
pass
###################################################################
# spying stuff
def get_ps_stats(self):
        ret = ''
        f = os.popen(
            "echo ============ `hostname` uptime:; uptime;" +
            "echo ============ self:; " +
            "ps -p %d -o pid,vsize,rss,%%cpu,nlwp,psr; " % os.getpid() +
            "echo ============ run queue:;" +
            "ps ar -o user,pid,%cpu,%mem,ni,nlwp,psr,vsz,rss,cputime,command")
        for l in f:
            ret += l
        return ret
class Client:
"""
Methods of the server object can be called transparently. Exceptions are
re-raised.
"""
def __init__(self, HOST, port=PORT, v6=False):
socktype = socket.AF_INET6 if v6 else socket.AF_INET
sock = socket.socket(socktype, socket.SOCK_STREAM)
LOG.info("connecting to %s:%d, socket type: %s", HOST, port, socktype)
sock.connect((HOST, port))
self.sock = sock
self.fs = FileSock(sock)
def generic_fun(self, fname, args):
# int "gen fun",fname
pickle.dump((fname, args), self.fs, protocol=4)
return self.get_result()
def get_result(self):
(st, ret) = RestrictedUnpickler(self.fs).load()
        if st is not None:
            raise ServerException(st)
        else:
            return ret
    def __getattr__(self, name):
        return lambda *x: self.generic_fun(name, x)
def run_server(new_handler, port=PORT, report_to_file=None, v6=False):
HOST = '' # Symbolic name meaning the local host
socktype = socket.AF_INET6 if v6 else socket.AF_INET
s = socket.socket(socktype, socket.SOCK_STREAM)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
LOG.info("bind %s:%d", HOST, port)
s.bind((HOST, port))
s.listen(5)
LOG.info("accepting connections")
if report_to_file is not None:
LOG.info('storing host+port in %s', report_to_file)
        with open(report_to_file, 'w') as f:
            f.write('%s:%d ' % (socket.gethostname(), port))
while True:
try:
conn, addr = s.accept()
        except socket.error as e:
            if e.errno == errno.EINTR:
                continue
            raise
LOG.info('Connected to %s', addr)
ibs = new_handler(conn)
        tid = _thread.start_new_thread(ibs.exec_loop, ())
LOG.debug("Thread ID: %d", tid)

View File

@@ -0,0 +1,6 @@
# The Torch contrib
This contrib directory contains a few Pytorch routines that
are useful for similarity search. They do not necessarily depend on Faiss.
The code is designed to work with CPU and GPU tensors.

View File

View File

@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
This contrib module contains Pytorch code for k-means clustering
"""
import faiss
import faiss.contrib.torch_utils
import torch
# the kmeans can produce both torch and numpy centroids
from faiss.contrib.clustering import kmeans
class DatasetAssign:
"""Wrapper for a tensor that offers a function to assign the vectors
to centroids. All other implementations offer the same interface"""
def __init__(self, x):
self.x = x
def count(self):
return self.x.shape[0]
def dim(self):
return self.x.shape[1]
def get_subset(self, indices):
return self.x[indices]
def perform_search(self, centroids):
return faiss.knn(self.x, centroids, 1)
def assign_to(self, centroids, weights=None):
D, I = self.perform_search(centroids)
I = I.ravel()
D = D.ravel()
nc, d = centroids.shape
sum_per_centroid = torch.zeros_like(centroids)
if weights is None:
sum_per_centroid.index_add_(0, I, self.x)
else:
sum_per_centroid.index_add_(0, I, self.x * weights[:, None])
# the indices are still in numpy.
return I.cpu().numpy(), D, sum_per_centroid
class DatasetAssignGPU(DatasetAssign):
def __init__(self, res, x):
DatasetAssign.__init__(self, x)
self.res = res
def perform_search(self, centroids):
return faiss.knn_gpu(self.res, self.x, centroids, 1)
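# Usage sketch (illustrative): run the pure-Python kmeans on torch tensors;
# the data shape and number of centroids are assumptions.
def _demo_torch_kmeans():
    x = torch.rand(10_000, 32)
    centroids = kmeans(256, DatasetAssign(x))
    return centroids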

View File

@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
This contrib module contains Pytorch code for quantization.
"""
import torch
import faiss
import math
from faiss.contrib.torch import clustering
# the kmeans can produce both torch and numpy centroids
class Quantizer:
def __init__(self, d, code_size):
"""
d: dimension of vectors
code_size: nb of bytes of the code (per vector)
"""
self.d = d
self.code_size = code_size
def train(self, x):
"""
        takes a n-by-d array and performs training
"""
pass
def encode(self, x):
"""
takes a n-by-d float array, encodes to an n-by-code_size uint8 array
"""
pass
def decode(self, codes):
"""
takes a n-by-code_size uint8 array, returns a n-by-d array
"""
pass
class VectorQuantizer(Quantizer):
def __init__(self, d, k):
        code_size = int(math.ceil(math.log2(k) / 8))
        Quantizer.__init__(self, d, code_size)
self.k = k
def train(self, x):
pass
class ProductQuantizer(Quantizer):
def __init__(self, d, M, nbits):
""" M: number of subvectors, d%M == 0
nbits: number of bits that each vector is encoded into
"""
assert d % M == 0
assert nbits == 8 # todo: implement other nbits values
code_size = int(math.ceil(M * nbits / 8))
Quantizer.__init__(self, d, code_size)
self.M = M
self.nbits = nbits
self.code_size = code_size
def train(self, x):
nc = 2 ** self.nbits
sd = self.d // self.M
dev = x.device
dtype = x.dtype
self.codebook = torch.zeros((self.M, nc, sd), device=dev, dtype=dtype)
for m in range(self.M):
xsub = x[:, m * self.d // self.M: (m + 1) * self.d // self.M]
data = clustering.DatasetAssign(xsub.contiguous())
self.codebook[m] = clustering.kmeans(2 ** self.nbits, data)
def encode(self, x):
codes = torch.zeros((x.shape[0], self.code_size), dtype=torch.uint8)
for m in range(self.M):
xsub = x[:, m * self.d // self.M:(m + 1) * self.d // self.M]
_, I = faiss.knn(xsub.contiguous(), self.codebook[m], 1)
codes[:, m] = I.ravel()
return codes
def decode(self, codes):
idxs = [codes[:, m].long() for m in range(self.M)]
vectors = [self.codebook[m, idxs[m], :] for m in range(self.M)]
stacked_vectors = torch.stack(vectors, dim=1)
cbd = self.codebook.shape[-1]
x_rec = stacked_vectors.reshape(-1, cbd * self.M)
return x_rec
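# Usage sketch (illustrative): train a product quantizer on torch data and
# measure the relative reconstruction error; shapes and M are assumptions.
def _demo_product_quantizer():
    x = torch.rand(10_000, 64)
    pq = ProductQuantizer(64, M=8, nbits=8)
    pq.train(x)
    codes = pq.encode(x)      # (n, 8) uint8 codes
    x_rec = pq.decode(codes)  # (n, 64) reconstruction
    return (x - x_rec).norm() / x.norm()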

View File

@@ -0,0 +1,764 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
This is a set of function wrappers that override the default numpy versions.
Interoperability functions for pytorch and Faiss: Importing this will allow
pytorch Tensors (CPU or GPU) to be used as arguments to Faiss indexes and
other functions. Torch GPU tensors can only be used with Faiss GPU indexes.
If this is imported with a package that supports Faiss GPU, the necessary
stream synchronization with the current pytorch stream will be automatically
performed.
Numpy ndarrays can continue to be used in the Faiss python interface after
importing this file. All arguments must be uniformly either numpy ndarrays
or Torch tensors; no mixing is allowed.
"""
import faiss
import torch
import contextlib
import inspect
import sys
import numpy as np
##################################################################
# Equivalent of swig_ptr for Torch tensors
##################################################################
def swig_ptr_from_UInt8Tensor(x):
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
assert x.is_contiguous()
assert x.dtype == torch.uint8
return faiss.cast_integer_to_uint8_ptr(
x.untyped_storage().data_ptr() + x.storage_offset())
def swig_ptr_from_HalfTensor(x):
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
assert x.is_contiguous()
assert x.dtype == torch.float16
# no canonical half type in C/C++
return faiss.cast_integer_to_void_ptr(
x.untyped_storage().data_ptr() + x.storage_offset() * 2)
def swig_ptr_from_FloatTensor(x):
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
assert x.is_contiguous()
assert x.dtype == torch.float32
return faiss.cast_integer_to_float_ptr(
x.untyped_storage().data_ptr() + x.storage_offset() * 4)
def swig_ptr_from_BFloat16Tensor(x):
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
assert x.is_contiguous()
assert x.dtype == torch.bfloat16
return faiss.cast_integer_to_void_ptr(
x.untyped_storage().data_ptr() + x.storage_offset() * 2)
def swig_ptr_from_IntTensor(x):
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
assert x.is_contiguous()
assert x.dtype == torch.int32, 'dtype=%s' % x.dtype
return faiss.cast_integer_to_int_ptr(
x.untyped_storage().data_ptr() + x.storage_offset() * 4)
def swig_ptr_from_IndicesTensor(x):
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
assert x.is_contiguous()
assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
return faiss.cast_integer_to_idx_t_ptr(
x.untyped_storage().data_ptr() + x.storage_offset() * 8)
##################################################################
# utilities
##################################################################
@contextlib.contextmanager
def using_stream(res, pytorch_stream=None):
""" Creates a scoping object to make Faiss GPU use the same stream
as pytorch, based on torch.cuda.current_stream().
Or, a specific pytorch stream can be passed in as a second
argument, in which case we will use that stream.
"""
if pytorch_stream is None:
pytorch_stream = torch.cuda.current_stream()
# This is the cudaStream_t that we wish to use
cuda_stream_s = faiss.cast_integer_to_cudastream_t(pytorch_stream.cuda_stream)
# So we can revert GpuResources stream state upon exit
prior_dev = torch.cuda.current_device()
prior_stream = res.getDefaultStream(torch.cuda.current_device())
res.setDefaultStream(torch.cuda.current_device(), cuda_stream_s)
# Do the user work
try:
yield
finally:
res.setDefaultStream(prior_dev, prior_stream)
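# Usage sketch (illustrative, comments only): order a raw Faiss GPU call on
# the current pytorch stream. `res`, `d`, `n` and `xb` are assumptions.
#
#   res = faiss.StandardGpuResources()
#   index = faiss.GpuIndexFlatL2(res, d)
#   with using_stream(res):
#       index.add_c(n, swig_ptr_from_FloatTensor(xb))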
def torch_replace_method(the_class, name, replacement,
ignore_missing=False, ignore_no_base=False):
try:
orig_method = getattr(the_class, name)
except AttributeError:
if ignore_missing:
return
raise
if orig_method.__name__ == 'torch_replacement_' + name:
# replacement was done in parent class
return
# We should already have the numpy replacement methods patched
assert ignore_no_base or (orig_method.__name__ == 'replacement_' + name)
setattr(the_class, name + '_numpy', orig_method)
setattr(the_class, name, replacement)
##################################################################
# Setup wrappers
##################################################################
def handle_torch_Index(the_class):
def torch_replacement_add(self, x):
if type(x) is np.ndarray:
# forward to faiss __init__.py base method
return self.add_numpy(x)
assert type(x) is torch.Tensor
n, d = x.shape
assert d == self.d
x_ptr = swig_ptr_from_FloatTensor(x)
if x.is_cuda:
assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'
# On the GPU, use proper stream ordering
with using_stream(self.getResources()):
self.add_c(n, x_ptr)
else:
# CPU torch
self.add_c(n, x_ptr)
def torch_replacement_add_with_ids(self, x, ids):
if type(x) is np.ndarray:
# forward to faiss __init__.py base method
return self.add_with_ids_numpy(x, ids)
assert type(x) is torch.Tensor
n, d = x.shape
assert d == self.d
x_ptr = swig_ptr_from_FloatTensor(x)
assert type(ids) is torch.Tensor
assert ids.shape == (n, ), 'not same number of vectors as ids'
ids_ptr = swig_ptr_from_IndicesTensor(ids)
if x.is_cuda:
assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'
# On the GPU, use proper stream ordering
with using_stream(self.getResources()):
self.add_with_ids_c(n, x_ptr, ids_ptr)
else:
# CPU torch
self.add_with_ids_c(n, x_ptr, ids_ptr)
def torch_replacement_assign(self, x, k, labels=None):
if type(x) is np.ndarray:
# forward to faiss __init__.py base method
return self.assign_numpy(x, k, labels)
assert type(x) is torch.Tensor
n, d = x.shape
assert d == self.d
x_ptr = swig_ptr_from_FloatTensor(x)
if labels is None:
labels = torch.empty(n, k, device=x.device, dtype=torch.int64)
else:
assert type(labels) is torch.Tensor
assert labels.shape == (n, k)
L_ptr = swig_ptr_from_IndicesTensor(labels)
if x.is_cuda:
assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'
# On the GPU, use proper stream ordering
with using_stream(self.getResources()):
self.assign_c(n, x_ptr, L_ptr, k)
else:
# CPU torch
self.assign_c(n, x_ptr, L_ptr, k)
return labels
def torch_replacement_train(self, x):
if type(x) is np.ndarray:
# forward to faiss __init__.py base method
return self.train_numpy(x)
assert type(x) is torch.Tensor
n, d = x.shape
assert d == self.d
x_ptr = swig_ptr_from_FloatTensor(x)
if x.is_cuda:
assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'
# On the GPU, use proper stream ordering
with using_stream(self.getResources()):
self.train_c(n, x_ptr)
else:
# CPU torch
self.train_c(n, x_ptr)
def search_methods_common(x, k, D, I):
n, d = x.shape
x_ptr = swig_ptr_from_FloatTensor(x)
if D is None:
D = torch.empty(n, k, device=x.device, dtype=torch.float32)
else:
assert type(D) is torch.Tensor
assert D.shape == (n, k)
D_ptr = swig_ptr_from_FloatTensor(D)
if I is None:
I = torch.empty(n, k, device=x.device, dtype=torch.int64)
else:
assert type(I) is torch.Tensor
assert I.shape == (n, k)
I_ptr = swig_ptr_from_IndicesTensor(I)
return x_ptr, D_ptr, I_ptr, D, I
def torch_replacement_search(self, x, k, D=None, I=None):
if type(x) is np.ndarray:
# forward to faiss __init__.py base method
return self.search_numpy(x, k, D=D, I=I)
assert type(x) is torch.Tensor
n, d = x.shape
assert d == self.d
x_ptr, D_ptr, I_ptr, D, I = search_methods_common(x, k, D, I)
if x.is_cuda:
assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'
# On the GPU, use proper stream ordering
with using_stream(self.getResources()):
self.search_c(n, x_ptr, k, D_ptr, I_ptr)
else:
# CPU torch
self.search_c(n, x_ptr, k, D_ptr, I_ptr)
return D, I
def torch_replacement_search_and_reconstruct(self, x, k, D=None, I=None, R=None):
if type(x) is np.ndarray:
# Forward to faiss __init__.py base method
return self.search_and_reconstruct_numpy(x, k, D=D, I=I, R=R)
assert type(x) is torch.Tensor
n, d = x.shape
assert d == self.d
x_ptr, D_ptr, I_ptr, D, I = search_methods_common(x, k, D, I)
if R is None:
R = torch.empty(n, k, d, device=x.device, dtype=torch.float32)
else:
assert type(R) is torch.Tensor
assert R.shape == (n, k, d)
R_ptr = swig_ptr_from_FloatTensor(R)
if x.is_cuda:
assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'
# On the GPU, use proper stream ordering
with using_stream(self.getResources()):
self.search_and_reconstruct_c(n, x_ptr, k, D_ptr, I_ptr, R_ptr)
else:
# CPU torch
self.search_and_reconstruct_c(n, x_ptr, k, D_ptr, I_ptr, R_ptr)
return D, I, R
def torch_replacement_search_preassigned(self, x, k, Iq, Dq, *, D=None, I=None):
if type(x) is np.ndarray:
# forward to faiss __init__.py base method
return self.search_preassigned_numpy(x, k, Iq, Dq, D=D, I=I)
assert type(x) is torch.Tensor
n, d = x.shape
assert d == self.d
x_ptr, D_ptr, I_ptr, D, I = search_methods_common(x, k, D, I)
assert Iq.shape == (n, self.nprobe)
Iq = Iq.contiguous()
Iq_ptr = swig_ptr_from_IndicesTensor(Iq)
if Dq is not None:
Dq = Dq.contiguous()
assert Dq.shape == Iq.shape
Dq_ptr = swig_ptr_from_FloatTensor(Dq)
else:
Dq_ptr = None
if x.is_cuda:
assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'
# On the GPU, use proper stream ordering
with using_stream(self.getResources()):
self.search_preassigned_c(n, x_ptr, k, Iq_ptr, Dq_ptr, D_ptr, I_ptr, False)
else:
# CPU torch
self.search_preassigned_c(n, x_ptr, k, Iq_ptr, Dq_ptr, D_ptr, I_ptr, False)
return D, I
def torch_replacement_remove_ids(self, x):
# Not yet implemented
assert type(x) is not torch.Tensor, 'remove_ids not yet implemented for torch'
return self.remove_ids_numpy(x)
def torch_replacement_reconstruct(self, key, x=None):
# No tensor inputs are required, but with importing this module, we
# assume that the default should be torch tensors. If we are passed a
# numpy array, however, assume that the user is overriding this default
if (x is not None) and (type(x) is np.ndarray):
# Forward to faiss __init__.py base method
return self.reconstruct_numpy(key, x)
# If the index is a CPU index, the default device is CPU, otherwise we
# produce a GPU tensor
device = torch.device('cpu')
if hasattr(self, 'getDevice'):
# same device as the index
device = torch.device('cuda', self.getDevice())
if x is None:
x = torch.empty(self.d, device=device, dtype=torch.float32)
else:
assert type(x) is torch.Tensor
assert x.shape == (self.d, )
x_ptr = swig_ptr_from_FloatTensor(x)
if x.is_cuda:
assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'
# On the GPU, use proper stream ordering
with using_stream(self.getResources()):
self.reconstruct_c(key, x_ptr)
else:
# CPU torch
self.reconstruct_c(key, x_ptr)
return x
def torch_replacement_reconstruct_n(self, n0=0, ni=-1, x=None):
if ni == -1:
ni = self.ntotal
# No tensor inputs are required, but with importing this module, we
# assume that the default should be torch tensors. If we are passed a
# numpy array, however, assume that the user is overriding this default
if (x is not None) and (type(x) is np.ndarray):
# Forward to faiss __init__.py base method
return self.reconstruct_n_numpy(n0, ni, x)
# If the index is a CPU index, the default device is CPU, otherwise we
# produce a GPU tensor
device = torch.device('cpu')
if hasattr(self, 'getDevice'):
# same device as the index
device = torch.device('cuda', self.getDevice())
if x is None:
x = torch.empty(ni, self.d, device=device, dtype=torch.float32)
else:
assert type(x) is torch.Tensor
assert x.shape == (ni, self.d)
x_ptr = swig_ptr_from_FloatTensor(x)
if x.is_cuda:
assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'
# On the GPU, use proper stream ordering
with using_stream(self.getResources()):
self.reconstruct_n_c(n0, ni, x_ptr)
else:
# CPU torch
self.reconstruct_n_c(n0, ni, x_ptr)
return x
def torch_replacement_update_vectors(self, keys, x):
if type(keys) is np.ndarray:
# Forward to faiss __init__.py base method
return self.update_vectors_numpy(keys, x)
assert type(keys) is torch.Tensor
(n, ) = keys.shape
keys_ptr = swig_ptr_from_IndicesTensor(keys)
assert type(x) is torch.Tensor
assert x.shape == (n, self.d)
x_ptr = swig_ptr_from_FloatTensor(x)
if x.is_cuda:
assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'
# On the GPU, use proper stream ordering
with using_stream(self.getResources()):
self.update_vectors_c(n, keys_ptr, x_ptr)
else:
# CPU torch
self.update_vectors_c(n, keys_ptr, x_ptr)
# Until the GPU version is implemented, we do not support pre-allocated
# output buffers
def torch_replacement_range_search(self, x, thresh):
if type(x) is np.ndarray:
# Forward to faiss __init__.py base method
return self.range_search_numpy(x, thresh)
assert type(x) is torch.Tensor
n, d = x.shape
assert d == self.d
x_ptr = swig_ptr_from_FloatTensor(x)
assert not x.is_cuda, 'Range search using GPU tensor not yet implemented'
assert not hasattr(self, 'getDevice'), 'Range search on GPU index not yet implemented'
res = faiss.RangeSearchResult(n)
self.range_search_c(n, x_ptr, thresh, res)
# get pointers and copy them
# FIXME: no rev_swig_ptr equivalent for torch.Tensor, just convert
# np to torch
# NOTE: torch does not support np.uint64, just np.int64
lims = torch.from_numpy(faiss.rev_swig_ptr(res.lims, n + 1).copy().astype('int64'))
nd = int(lims[-1])
D = torch.from_numpy(faiss.rev_swig_ptr(res.distances, nd).copy())
I = torch.from_numpy(faiss.rev_swig_ptr(res.labels, nd).copy())
return lims, D, I
def torch_replacement_sa_encode(self, x, codes=None):
if type(x) is np.ndarray:
# Forward to faiss __init__.py base method
return self.sa_encode_numpy(x, codes)
assert type(x) is torch.Tensor
n, d = x.shape
assert d == self.d
x_ptr = swig_ptr_from_FloatTensor(x)
if codes is None:
codes = torch.empty(n, self.sa_code_size(), dtype=torch.uint8)
else:
assert codes.shape == (n, self.sa_code_size())
codes_ptr = swig_ptr_from_UInt8Tensor(codes)
if x.is_cuda:
assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'
# On the GPU, use proper stream ordering
with using_stream(self.getResources()):
self.sa_encode_c(n, x_ptr, codes_ptr)
else:
# CPU torch
self.sa_encode_c(n, x_ptr, codes_ptr)
return codes
def torch_replacement_sa_decode(self, codes, x=None):
if type(codes) is np.ndarray:
# Forward to faiss __init__.py base method
return self.sa_decode_numpy(codes, x)
assert type(codes) is torch.Tensor
n, cs = codes.shape
assert cs == self.sa_code_size()
codes_ptr = swig_ptr_from_UInt8Tensor(codes)
if x is None:
x = torch.empty(n, self.d, dtype=torch.float32)
else:
assert type(x) is torch.Tensor
assert x.shape == (n, self.d)
x_ptr = swig_ptr_from_FloatTensor(x)
if codes.is_cuda:
assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'
# On the GPU, use proper stream ordering
with using_stream(self.getResources()):
self.sa_decode_c(n, codes_ptr, x_ptr)
else:
# CPU torch
self.sa_decode_c(n, codes_ptr, x_ptr)
return x
torch_replace_method(the_class, 'add', torch_replacement_add)
torch_replace_method(the_class, 'add_with_ids', torch_replacement_add_with_ids)
torch_replace_method(the_class, 'assign', torch_replacement_assign)
torch_replace_method(the_class, 'train', torch_replacement_train)
torch_replace_method(the_class, 'search', torch_replacement_search)
torch_replace_method(the_class, 'remove_ids', torch_replacement_remove_ids)
torch_replace_method(the_class, 'reconstruct', torch_replacement_reconstruct)
torch_replace_method(the_class, 'reconstruct_n', torch_replacement_reconstruct_n)
torch_replace_method(the_class, 'range_search', torch_replacement_range_search)
torch_replace_method(the_class, 'update_vectors', torch_replacement_update_vectors,
ignore_missing=True)
torch_replace_method(the_class, 'search_and_reconstruct',
torch_replacement_search_and_reconstruct, ignore_missing=True)
torch_replace_method(the_class, 'search_preassigned',
torch_replacement_search_preassigned, ignore_missing=True)
torch_replace_method(the_class, 'sa_encode', torch_replacement_sa_encode)
torch_replace_method(the_class, 'sa_decode', torch_replacement_sa_decode)
faiss_module = sys.modules['faiss']
# Re-patch anything that inherits from faiss.Index to add the torch bindings
for symbol in dir(faiss_module):
obj = getattr(faiss_module, symbol)
if inspect.isclass(obj):
the_class = obj
if issubclass(the_class, faiss.Index):
handle_torch_Index(the_class)
# allows torch tensor usage with knn
def torch_replacement_knn(xq, xb, k, metric=faiss.METRIC_L2, metric_arg=0):
if type(xb) is np.ndarray:
# Forward to faiss __init__.py base method
return faiss.knn_numpy(xq, xb, k, metric=metric, metric_arg=metric_arg)
nb, d = xb.size()
assert xb.is_contiguous()
assert xb.dtype == torch.float32
assert not xb.is_cuda, "use knn_gpu for GPU tensors"
nq, d2 = xq.size()
assert d2 == d
assert xq.is_contiguous()
assert xq.dtype == torch.float32
assert not xq.is_cuda, "use knn_gpu for GPU tensors"
D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
I_ptr = swig_ptr_from_IndicesTensor(I)
D_ptr = swig_ptr_from_FloatTensor(D)
xb_ptr = swig_ptr_from_FloatTensor(xb)
xq_ptr = swig_ptr_from_FloatTensor(xq)
if metric == faiss.METRIC_L2:
faiss.knn_L2sqr(
xq_ptr, xb_ptr,
d, nq, nb, k, D_ptr, I_ptr
)
elif metric == faiss.METRIC_INNER_PRODUCT:
faiss.knn_inner_product(
xq_ptr, xb_ptr,
d, nq, nb, k, D_ptr, I_ptr
)
else:
faiss.knn_extra_metrics(
xq_ptr, xb_ptr,
d, nq, nb, metric, metric_arg, k, D_ptr, I_ptr
)
return D, I
torch_replace_method(faiss_module, 'knn', torch_replacement_knn, True, True)
# allows torch tensor usage with bfKnn
def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRIC_L2, device=-1, use_cuvs=False):
if type(xb) is np.ndarray:
# Forward to faiss __init__.py base method
return faiss.knn_gpu_numpy(res, xq, xb, k, D, I, metric, device)
nb, d = xb.size()
if xb.is_contiguous():
xb_row_major = True
elif xb.t().is_contiguous():
xb = xb.t()
xb_row_major = False
else:
raise TypeError('matrix should be row or column-major')
if xb.dtype == torch.float32:
xb_type = faiss.DistanceDataType_F32
xb_ptr = swig_ptr_from_FloatTensor(xb)
elif xb.dtype == torch.float16:
xb_type = faiss.DistanceDataType_F16
xb_ptr = swig_ptr_from_HalfTensor(xb)
elif xb.dtype == torch.bfloat16:
xb_type = faiss.DistanceDataType_BF16
xb_ptr = swig_ptr_from_BFloat16Tensor(xb)
else:
        raise TypeError('xb must be float32, float16 or bfloat16')
nq, d2 = xq.size()
assert d2 == d
if xq.is_contiguous():
xq_row_major = True
elif xq.t().is_contiguous():
xq = xq.t()
xq_row_major = False
else:
raise TypeError('matrix should be row or column-major')
if xq.dtype == torch.float32:
xq_type = faiss.DistanceDataType_F32
xq_ptr = swig_ptr_from_FloatTensor(xq)
elif xq.dtype == torch.float16:
xq_type = faiss.DistanceDataType_F16
xq_ptr = swig_ptr_from_HalfTensor(xq)
elif xq.dtype == torch.bfloat16:
xq_type = faiss.DistanceDataType_BF16
xq_ptr = swig_ptr_from_BFloat16Tensor(xq)
else:
raise TypeError('xq must be float32, float16 or bfloat16')
if D is None:
D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
else:
assert D.shape == (nq, k)
# interface takes void*, we need to check this
assert (D.dtype == torch.float32)
if I is None:
I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
else:
assert I.shape == (nq, k)
if I.dtype == torch.int64:
I_type = faiss.IndicesDataType_I64
I_ptr = swig_ptr_from_IndicesTensor(I)
    elif I.dtype == torch.int32:
I_type = faiss.IndicesDataType_I32
I_ptr = swig_ptr_from_IntTensor(I)
else:
raise TypeError('I must be i64 or i32')
D_ptr = swig_ptr_from_FloatTensor(D)
args = faiss.GpuDistanceParams()
args.metric = metric
args.k = k
args.dims = d
args.vectors = xb_ptr
args.vectorsRowMajor = xb_row_major
args.vectorType = xb_type
args.numVectors = nb
args.queries = xq_ptr
args.queriesRowMajor = xq_row_major
args.queryType = xq_type
args.numQueries = nq
args.outDistances = D_ptr
args.outIndices = I_ptr
args.outIndicesType = I_type
args.device = device
args.use_cuvs = use_cuvs
with using_stream(res):
faiss.bfKnn(res, args)
return D, I
torch_replace_method(faiss_module, 'knn_gpu', torch_replacement_knn_gpu, True, True)
# allows torch tensor usage with bfKnn for all pairwise distances
def torch_replacement_pairwise_distance_gpu(res, xq, xb, D=None, metric=faiss.METRIC_L2, device=-1):
if type(xb) is np.ndarray:
# Forward to faiss __init__.py base method
return faiss.pairwise_distance_gpu_numpy(res, xq, xb, D, metric)
nb, d = xb.size()
if xb.is_contiguous():
xb_row_major = True
elif xb.t().is_contiguous():
xb = xb.t()
xb_row_major = False
else:
raise TypeError('xb matrix should be row or column-major')
if xb.dtype == torch.float32:
xb_type = faiss.DistanceDataType_F32
xb_ptr = swig_ptr_from_FloatTensor(xb)
elif xb.dtype == torch.float16:
xb_type = faiss.DistanceDataType_F16
xb_ptr = swig_ptr_from_HalfTensor(xb)
else:
raise TypeError('xb must be float32 or float16')
nq, d2 = xq.size()
assert d2 == d
if xq.is_contiguous():
xq_row_major = True
elif xq.t().is_contiguous():
xq = xq.t()
xq_row_major = False
else:
raise TypeError('xq matrix should be row or column-major')
if xq.dtype == torch.float32:
xq_type = faiss.DistanceDataType_F32
xq_ptr = swig_ptr_from_FloatTensor(xq)
elif xq.dtype == torch.float16:
xq_type = faiss.DistanceDataType_F16
xq_ptr = swig_ptr_from_HalfTensor(xq)
else:
raise TypeError('xq must be float32 or float16')
if D is None:
D = torch.empty(nq, nb, device=xb.device, dtype=torch.float32)
else:
assert D.shape == (nq, nb)
# interface takes void*, we need to check this
assert (D.dtype == torch.float32)
D_ptr = swig_ptr_from_FloatTensor(D)
args = faiss.GpuDistanceParams()
args.metric = metric
    args.k = -1  # -1 selects all pairwise distances
args.dims = d
args.vectors = xb_ptr
args.vectorsRowMajor = xb_row_major
args.vectorType = xb_type
args.numVectors = nb
args.queries = xq_ptr
args.queriesRowMajor = xq_row_major
args.queryType = xq_type
args.numQueries = nq
args.outDistances = D_ptr
args.device = device
with using_stream(res):
faiss.bfKnn(res, args)
return D
torch_replace_method(faiss_module, 'pairwise_distance_gpu', torch_replacement_pairwise_distance_gpu, True, True)
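# Usage sketch (illustrative): after this module is imported, torch tensors
# can be passed straight to Faiss indexes; shapes below are assumptions.
def _demo_torch_interop():
    d = 32
    index = faiss.IndexFlatL2(d)
    xb = torch.rand(1000, d)
    index.add(xb)                    # torch CPU tensor
    D, I = index.search(xb[:5], 10)  # D and I come back as torch tensors
    return D, I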

View File

@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
import numpy as np
"""
I/O functions in fvecs, bvecs, ivecs formats
definition of the formats here: http://corpus-texmex.irisa.fr/
"""
def ivecs_read(fname):
a = np.fromfile(fname, dtype='int32')
if sys.byteorder == 'big':
a.byteswap(inplace=True)
d = a[0]
return a.reshape(-1, d + 1)[:, 1:].copy()
def fvecs_read(fname):
return ivecs_read(fname).view('float32')
def ivecs_mmap(fname):
assert sys.byteorder != 'big'
a = np.memmap(fname, dtype='int32', mode='r')
d = a[0]
return a.reshape(-1, d + 1)[:, 1:]
def fvecs_mmap(fname):
return ivecs_mmap(fname).view('float32')
def bvecs_mmap(fname):
x = np.memmap(fname, dtype='uint8', mode='r')
if sys.byteorder == 'big':
da = x[:4][::-1].copy()
d = da.view('int32')[0]
else:
d = x[:4].view('int32')[0]
return x.reshape(-1, d + 4)[:, 4:]
def ivecs_write(fname, m):
n, d = m.shape
m1 = np.empty((n, d + 1), dtype='int32')
m1[:, 0] = d
m1[:, 1:] = m
if sys.byteorder == 'big':
m1.byteswap(inplace=True)
m1.tofile(fname)
def fvecs_write(fname, m):
m = m.astype('float32')
ivecs_write(fname, m.view('int32'))
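# Usage sketch (illustrative): write a matrix in fvecs format and read it
# back; the file name is an assumption for demonstration.
def _demo_fvecs_roundtrip():
    x = np.random.rand(100, 16).astype('float32')
    fvecs_write("/tmp/x.fvecs", x)
    x2 = fvecs_read("/tmp/x.fvecs")
    assert np.array_equal(x, x2)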