Initial commit

yichuan520030910320
2025-06-30 09:05:05 +00:00
commit 46f6cc100b
1231 changed files with 278432 additions and 0 deletions

@@ -0,0 +1,52 @@
# Offline IVF
This folder contains the code for the offline IVF algorithm, powered by the Faiss big batch search functionality.
Create a conda env:
`conda create --name oivf python=3.10`
`conda activate oivf`
`conda install -c pytorch/label/nightly -c nvidia faiss-gpu=1.7.4`
`conda install tqdm`
`conda install pyyaml`
`conda install -c conda-forge submitit`
## Run book
1. Optionally shard your dataset (see `create_sharded_dataset.py`) and create the corresponding yaml file `config_ssnpp.yaml`. You can use `generate_config.py` to produce it by specifying the root directory of your dataset and the files with the data shards:
`python generate_config.py`
2. Run the train index command:
`python run.py --command train_index --config config_ssnpp.yaml --xb ssnpp_1B`
3. Run the index_shard command so that it produces sharded indexes, required for the search step:
`python run.py --command index_shard --config config_ssnpp.yaml --xb ssnpp_1B`
4. Send jobs to the cluster to run search:
`python run.py --command search --config config_ssnpp.yaml --xb ssnpp_1B --cluster_run --partition <PARTITION-NAME>`
Remarks about the `search` command: by default, the database vectors are also used as the query vectors when performing the search step.
a. If the query vectors are different from the database vectors, they should be passed via the `xq` argument.
b. A new dataset needs to be prepared (step 1) before passing it to the query vectors argument `xq`:
`python run.py --command search --config config_ssnpp.yaml --xb ssnpp_1B --xq <QUERIES_DATASET_NAME>`
5. We can always run the consistency check for sanity checks:
`python run.py --command consistency_check --config config_ssnpp.yaml --xb ssnpp_1B`
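For a quick programmatic check of a configured dataset, the snippet below is a minimal sketch (not part of the run book); it assumes it is run from this folder, next to `utils.py` and `dataset.py`, and that `config_ssnpp.yaml` has already been generated:
```python
# Minimal sketch: load the yaml config and peek at the first vectors of a dataset.
import numpy as np

from utils import load_config
from dataset import create_dataset_from_oivf_config

cfg = load_config("config_ssnpp.yaml")
ds = create_dataset_from_oivf_config(cfg, "ssnpp_1B")
print(f"dataset size: {ds.size}, dimension: {ds.d}")

# get_first_n returns the first batch produced by ds.iterate(), cast to float32
xb = ds.get_first_n(10, np.float32)
print(xb.shape, xb.dtype)
```
The same `MultiFileVectorDataset` object is what `run.py` builds internally for the `--xb` and `--xq` datasets.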

@@ -0,0 +1,110 @@
d: 256
output: /checkpoint/marialomeli/offline_faiss/ssnpp
index:
prod:
- 'IVF8192,PQ128'
non-prod:
- 'IVF16384,PQ128'
- 'IVF32768,PQ128'
- 'OPQ64_128,IVF4096,PQ64'
nprobe:
prod:
- 512
non-prod:
- 256
- 128
- 1024
- 2048
- 4096
- 8192
k: 50
index_shard_size: 50000000
query_batch_size: 50000000
evaluation_sample: 10000
training_sample: 1572864
datasets:
ssnpp_1B:
root: /checkpoint/marialomeli/ssnpp_data
size: 1000000000
files:
- dtype: uint8
format: npy
name: ssnpp_0000000000.npy
size: 50000000
- dtype: uint8
format: npy
name: ssnpp_0000000001.npy
size: 50000000
- dtype: uint8
format: npy
name: ssnpp_0000000002.npy
size: 50000000
- dtype: uint8
format: npy
name: ssnpp_0000000003.npy
size: 50000000
- dtype: uint8
format: npy
name: ssnpp_0000000004.npy
size: 50000000
- dtype: uint8
format: npy
name: ssnpp_0000000005.npy
size: 50000000
- dtype: uint8
format: npy
name: ssnpp_0000000006.npy
size: 50000000
- dtype: uint8
format: npy
name: ssnpp_0000000007.npy
size: 50000000
- dtype: uint8
format: npy
name: ssnpp_0000000008.npy
size: 50000000
- dtype: uint8
format: npy
name: ssnpp_0000000009.npy
size: 50000000
- dtype: uint8
format: npy
name: ssnpp_0000000010.npy
size: 50000000
- dtype: uint8
format: npy
name: ssnpp_0000000011.npy
size: 50000000
- dtype: uint8
format: npy
name: ssnpp_0000000012.npy
size: 50000000
- dtype: uint8
format: npy
name: ssnpp_0000000013.npy
size: 50000000
- dtype: uint8
format: npy
name: ssnpp_0000000014.npy
size: 50000000
- dtype: uint8
format: npy
name: ssnpp_0000000015.npy
size: 50000000
- dtype: uint8
format: npy
name: ssnpp_0000000016.npy
size: 50000000
- dtype: uint8
format: npy
name: ssnpp_0000000017.npy
size: 50000000
- dtype: uint8
format: npy
name: ssnpp_0000000018.npy
size: 50000000
- dtype: uint8
format: npy
name: ssnpp_0000000019.npy
size: 50000000

@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import argparse
import os
def xbin_mmap(fname, dtype, maxn=-1):
"""
Code from
https://github.com/harsha-simhadri/big-ann-benchmarks/blob/main/benchmark/dataset_io.py#L94
mmap the competition file format for a given type of items
"""
n, d = map(int, np.fromfile(fname, dtype="uint32", count=2))
assert os.stat(fname).st_size == 8 + n * d * np.dtype(dtype).itemsize
if maxn > 0:
n = min(n, maxn)
return np.memmap(fname, dtype=dtype, mode="r", offset=8, shape=(n, d))
def main(args: argparse.Namespace):
ssnpp_data = xbin_mmap(fname=args.filepath, dtype="uint8")
num_batches = ssnpp_data.shape[0] // args.data_batch
assert (
ssnpp_data.shape[0] % args.data_batch == 0
), "num of embeddings per file should divide total num of embeddings"
for i in range(num_batches):
xb_batch = ssnpp_data[
i * args.data_batch:(i + 1) * args.data_batch, :
]
filename = args.output_dir + f"/ssnpp_{(i):010}.npy"
np.save(filename, xb_batch)
print(f"File {filename} is saved!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_batch",
dest="data_batch",
type=int,
default=50000000,
help="Number of embeddings per file, should be a divisor of 1B",
)
parser.add_argument(
"--filepath",
dest="filepath",
type=str,
default="/datasets01/big-ann-challenge-data/FB_ssnpp/FB_ssnpp_database.u8bin",
help="path of 1B ssnpp database vectors' original file",
)
parser.add_argument(
"--filepath",
dest="output_dir",
type=str,
default="/checkpoint/marialomeli/ssnpp_data",
help="path to put sharded files",
)
args = parser.parse_args()
main(args)

@@ -0,0 +1,174 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
import faiss
from typing import List
import random
import logging
from functools import lru_cache
def create_dataset_from_oivf_config(cfg, ds_name):
normalise = cfg["normalise"] if "normalise" in cfg else False
return MultiFileVectorDataset(
cfg["datasets"][ds_name]["root"],
[
FileDescriptor(
f["name"], f["format"], np.dtype(f["dtype"]), f["size"]
)
for f in cfg["datasets"][ds_name]["files"]
],
cfg["d"],
normalise,
cfg["datasets"][ds_name]["size"],
)
@lru_cache(maxsize=100)
def _memmap_vecs(
file_name: str, format: str, dtype: np.dtype, size: int, d: int
) -> np.ndarray:
"""
If the file is in raw format, the file size will
be divisible by the dimensionality and by the size
of the data type.
Otherwise, the file contains a header and we assume
it is of .npy type. It then returns the memmapped file.
"""
assert os.path.exists(file_name), f"file does not exist {file_name}"
if format == "raw":
fl = os.path.getsize(file_name)
nb = fl // d // dtype.itemsize
assert nb == size, f"{nb} is different than config's {size}"
assert fl == d * dtype.itemsize * nb # no header
return np.memmap(file_name, shape=(nb, d), dtype=dtype, mode="r")
elif format == "npy":
vecs = np.load(file_name, mmap_mode="r")
assert vecs.shape[0] == size, f"size:{size},shape {vecs.shape[0]}"
assert vecs.shape[1] == d
assert vecs.dtype == dtype
return vecs
else:
raise ValueError("The file cannot be loaded in the current format.")
class FileDescriptor:
def __init__(self, name: str, format: str, dtype: np.dtype, size: int):
self.name = name
self.format = format
self.dtype = dtype
self.size = size
class MultiFileVectorDataset:
def __init__(
self,
root: str,
file_descriptors: List[FileDescriptor],
d: int,
normalize: bool,
size: int,
):
assert os.path.exists(root)
self.root = root
self.file_descriptors = file_descriptors
self.d = d
self.normalize = normalize
self.size = size
self.file_offsets = [0]
t = 0
for f in self.file_descriptors:
xb = _memmap_vecs(
f"{self.root}/{f.name}", f.format, f.dtype, f.size, self.d
)
t += xb.shape[0]
self.file_offsets.append(t)
assert (
t == self.size
), "the sum of num of embeddings per file!=total num of embeddings"
def iterate(self, start: int, batch_size: int, dt: np.dtype):
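"""
Iterates over the dataset starting at global index `start` and yields
contiguous batches of shape (batch_size, d) cast to dtype `dt` (the last
batch may be smaller). Batches can span file boundaries and are
L2-normalised when the dataset is configured with normalise: true.
"""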
buffer = np.empty(shape=(batch_size, self.d), dtype=dt)
rem = 0
for f in self.file_descriptors:
if start >= f.size:
start -= f.size
continue
logging.info(f"processing: {f.name}...")
xb = _memmap_vecs(
f"{self.root}/{f.name}",
f.format,
f.dtype,
f.size,
self.d,
)
if start > 0:
xb = xb[start:]
start = 0
req = min(batch_size - rem, xb.shape[0])
buffer[rem:rem + req] = xb[:req]
rem += req
if rem == batch_size:
if self.normalize:
faiss.normalize_L2(buffer)
yield buffer.copy()
rem = 0
for i in range(req, xb.shape[0], batch_size):
j = i + batch_size
if j <= xb.shape[0]:
tmp = xb[i:j].astype(dt)
if self.normalize:
faiss.normalize_L2(tmp)
yield tmp
else:
rem = xb.shape[0] - i
buffer[:rem] = xb[i:j]
if rem > 0:
tmp = buffer[:rem]
if self.normalize:
faiss.normalize_L2(tmp)
yield tmp
def get(self, idx: List[int]):
n = len(idx)
fidx = np.searchsorted(self.file_offsets, idx, "right")
res = np.empty(shape=(len(idx), self.d), dtype=np.float32)
for r, id, fid in zip(range(n), idx, fidx):
assert fid > 0 and fid <= len(self.file_descriptors), f"{fid}"
f = self.file_descriptors[fid - 1]
# deferring normalization until after reading the vec
vecs = _memmap_vecs(
f"{self.root}/{f.name}", f.format, f.dtype, f.size, self.d
)
i = id - self.file_offsets[fid - 1]
assert i >= 0 and i < vecs.shape[0]
res[r, :] = vecs[i] # TODO: find a faster way
if self.normalize:
faiss.normalize_L2(res)
return res
def sample(self, n, idx_fn, vecs_fn):
if vecs_fn and os.path.exists(vecs_fn):
vecs = np.load(vecs_fn)
assert vecs.shape == (n, self.d)
return vecs
if idx_fn and os.path.exists(idx_fn):
idx = np.load(idx_fn)
assert idx.size == n
else:
idx = np.array(sorted(random.sample(range(self.size), n)))
if idx_fn:
np.save(idx_fn, idx)
vecs = self.get(idx)
if vecs_fn:
np.save(vecs_fn, vecs)
return vecs
def get_first_n(self, n, dt):
assert n <= self.size
return next(self.iterate(0, n, dt))

@@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import os
import yaml
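# Builds the `datasets` entry of the offline IVF config for the sharded
# ssnpp embedding files and prints it as yaml.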
# with ssnpp sharded data
root = "/checkpoint/marialomeli/ssnpp_data"
file_names = [f"ssnpp_{i:010}.npy" for i in range(20)]
d = 256
dt = np.dtype(np.uint8)
def read_embeddings(fp):
fl = os.path.getsize(fp)
nb = fl // d // dt.itemsize
print(nb)
if fl == d * dt.itemsize * nb: # no header
return ("raw", np.memmap(fp, shape=(nb, d), dtype=dt, mode="r"))
else: # assume npy
vecs = np.load(fp, mmap_mode="r")
assert vecs.shape[1] == d
assert vecs.dtype == dt
return ("npy", vecs)
cfg = {}
files = []
size = 0
for fn in file_names:
fp = f"{root}/{fn}"
assert os.path.exists(fp), f"{fp} is missing"
ft, xb = read_embeddings(fp)
files.append(
{"name": fn, "size": xb.shape[0], "dtype": dt.name, "format": ft}
)
size += xb.shape[0]
cfg["size"] = size
cfg["root"] = root
cfg["d"] = d
cfg["files"] = files
print(yaml.dump(cfg))

@@ -0,0 +1,891 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import faiss
import numpy as np
import os
from tqdm import tqdm, trange
import sys
import logging
from faiss.contrib.ondisk import merge_ondisk
from faiss.contrib.big_batch_search import big_batch_search
from faiss.contrib.exhaustive_search import knn_ground_truth
from faiss.contrib.evaluation import knn_intersection_measure
from utils import (
get_intersection_cardinality_frequencies,
margin,
is_pretransform_index,
)
from dataset import create_dataset_from_oivf_config
logging.basicConfig(
format=(
"%(asctime)s.%(msecs)03d %(levelname)-8s %(threadName)-12s %(message)s"
),
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
force=True,
)
EMBEDDINGS_BATCH_SIZE: int = 100_000
NUM_SUBSAMPLES: int = 100
SMALL_DATA_SAMPLE: int = 10000
class OfflineIVF:
def __init__(self, cfg, args, nprobe, index_factory_str):
self.input_d = cfg["d"]
self.dt = cfg["datasets"][args.xb]["files"][0]["dtype"]
assert self.input_d > 0
output_dir = cfg["output"]
assert os.path.exists(output_dir)
self.index_factory = index_factory_str
assert self.index_factory is not None
self.index_factory_fn = self.index_factory.replace(",", "_")
self.index_template_file = (
f"{output_dir}/{args.xb}/{self.index_factory_fn}.empty.faissindex"
)
logging.info(f"index template: {self.index_template_file}")
if not args.xq:
args.xq = args.xb
self.by_residual = True
if args.no_residuals:
self.by_residual = False
xb_output_dir = f"{output_dir}/{args.xb}"
if not os.path.exists(xb_output_dir):
os.makedirs(xb_output_dir)
xq_output_dir = f"{output_dir}/{args.xq}"
if not os.path.exists(xq_output_dir):
os.makedirs(xq_output_dir)
search_output_dir = f"{output_dir}/{args.xq}_in_{args.xb}"
if not os.path.exists(search_output_dir):
os.makedirs(search_output_dir)
self.knn_dir = f"{search_output_dir}/knn"
if not os.path.exists(self.knn_dir):
os.makedirs(self.knn_dir)
self.eval_dir = f"{search_output_dir}/eval"
if not os.path.exists(self.eval_dir):
os.makedirs(self.eval_dir)
self.index = {} # to keep a reference to opened indices,
self.ivls = {} # hstack inverted lists,
self.index_shards = {} # and index shards
self.index_shard_prefix = (
f"{xb_output_dir}/{self.index_factory_fn}.shard_"
)
self.xq_index_shard_prefix = (
f"{xq_output_dir}/{self.index_factory_fn}.shard_"
)
self.index_file = ( # TODO: added back temporarily for evaluate, handle name of non-sharded index file and remove.
f"{xb_output_dir}/{self.index_factory_fn}.faissindex"
)
self.xq_index_file = (
f"{xq_output_dir}/{self.index_factory_fn}.faissindex"
)
self.training_sample = cfg["training_sample"]
self.evaluation_sample = cfg["evaluation_sample"]
self.xq_ds = create_dataset_from_oivf_config(cfg, args.xq)
self.xb_ds = create_dataset_from_oivf_config(cfg, args.xb)
file_descriptors = self.xq_ds.file_descriptors
self.file_sizes = [fd.size for fd in file_descriptors]
self.shard_size = cfg["index_shard_size"] # ~100GB
self.nshards = self.xb_ds.size // self.shard_size
if self.xb_ds.size % self.shard_size != 0:
self.nshards += 1
self.xq_nshards = self.xq_ds.size // self.shard_size
if self.xq_ds.size % self.shard_size != 0:
self.xq_nshards += 1
self.nprobe = nprobe
assert self.nprobe > 0, "Invalid nprobe parameter."
if "deduper" in cfg:
self.deduper = cfg["deduper"]
self.deduper_codec_fn = [
f"{xb_output_dir}/deduper_codec_{codec.replace(',', '_')}"
for codec in self.deduper
]
self.deduper_idx_fn = [
f"{xb_output_dir}/deduper_idx_{codec.replace(',', '_')}"
for codec in self.deduper
]
else:
self.deduper = None
self.k = cfg["k"]
assert self.k > 0, "Invalid number of neighbours parameter."
self.knn_output_file_suffix = (
f"{self.index_factory_fn}_np{self.nprobe}.npy"
)
fp = 32
if self.dt == "float16":
fp = 16
self.xq_bs = cfg["query_batch_size"]
if "metric" in cfg:
self.metric = eval(f'faiss.{cfg["metric"]}')
else:
self.metric = faiss.METRIC_L2
if "evaluate_by_margin" in cfg:
self.evaluate_by_margin = cfg["evaluate_by_margin"]
else:
self.evaluate_by_margin = False
os.system("grep -m1 'model name' < /proc/cpuinfo")
os.system("grep -E 'MemTotal|MemFree' /proc/meminfo")
os.system("nvidia-smi")
os.system("nvcc --version")
self.knn_queries_memory_limit = 4 * 1024 * 1024 * 1024 # 4 GB
self.knn_vectors_memory_limit = 8 * 1024 * 1024 * 1024 # 8 GB
def input_stats(self):
"""
Logs the shape and the Faiss MatrixStats of a training-sample-sized subset of the database vectors.
"""
xb_sample = self.xb_ds.get_first_n(self.training_sample, np.float32)
logging.info(f"input shape: {xb_sample.shape}")
logging.info("running MatrixStats on training sample...")
logging.info(faiss.MatrixStats(xb_sample).comments)
logging.info("done")
def dedupe(self):
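"""
Encodes the database with each configured dedupe codec and saves, per
codec, the ids of the vectors whose codes had not been seen before,
i.e. a deduplicated subset of the database.
"""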
logging.info(self.deduper)
if self.deduper is None:
logging.info("No deduper configured")
return
codecs = []
codesets = []
idxs = []
for factory, filename in zip(self.deduper, self.deduper_codec_fn):
if os.path.exists(filename):
logging.info(f"loading trained dedupe codec: {filename}")
codec = faiss.read_index(filename)
else:
logging.info(f"training dedupe codec: {factory}")
codec = faiss.index_factory(self.input_d, factory)
xb_sample = np.unique(
self.xb_ds.get_first_n(100_000, np.float32), axis=0
)
faiss.ParameterSpace().set_index_parameter(codec, "verbose", 1)
codec.train(xb_sample)
logging.info(f"writing trained dedupe codec: {filename}")
faiss.write_index(codec, filename)
codecs.append(codec)
codesets.append(faiss.CodeSet(codec.sa_code_size()))
idxs.append(np.empty((0,), dtype=np.uint32))
bs = 1_000_000
i = 0
for buffer in tqdm(self._iterate_transformed(self.xb_ds, 0, bs, np.float32)):
for j in range(len(codecs)):
codec, codeset, idx = codecs[j], codesets[j], idxs[j]
uniq = codeset.insert(codec.sa_encode(buffer))
idxs[j] = np.append(
idx,
np.arange(i, i + buffer.shape[0], dtype=np.uint32)[uniq],
)
i += buffer.shape[0]
for idx, filename in zip(idxs, self.deduper_idx_fn):
logging.info(f"writing {filename}, shape: {idx.shape}")
np.save(filename, idx)
logging.info("done")
def train_index(self):
"""
Trains the index using a subsample of the first chunk of data in the database and saves it in the template file (with no vectors added).
"""
assert not os.path.exists(self.index_template_file), (
"The train command has been ran, the index template file already"
" exists."
)
xb_sample = np.unique(
self.xb_ds.get_first_n(self.training_sample, np.float32), axis=0
)
logging.info(f"input shape: {xb_sample.shape}")
index = faiss.index_factory(
self.input_d, self.index_factory, self.metric
)
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
index_ivf.by_residual = True
faiss.ParameterSpace().set_index_parameter(index, "verbose", 1)
logging.info("running training...")
index.train(xb_sample)
logging.info(f"writing trained index {self.index_template_file}...")
faiss.write_index(index, self.index_template_file)
logging.info("done")
def _iterate_transformed(self, ds, start, batch_size, dt):
assert os.path.exists(self.index_template_file)
index = faiss.read_index(self.index_template_file)
if is_pretransform_index(index):
vt = index.chain.at(0) # fetch pretransform
for buffer in ds.iterate(start, batch_size, dt):
yield vt.apply(buffer)
else:
for buffer in ds.iterate(start, batch_size, dt):
yield buffer
def index_shard(self):
assert os.path.exists(self.index_template_file)
index = faiss.read_index(self.index_template_file)
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
assert self.nprobe <= index_ivf.quantizer.ntotal, (
f"the number of vectors {index_ivf.quantizer.ntotal} is not enough"
f" to retrieve {self.nprobe} neighbours, check."
)
cpu_quantizer = index_ivf.quantizer
gpu_quantizer = faiss.index_cpu_to_all_gpus(cpu_quantizer)
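# The coarse quantizer is replicated on all GPUs to speed up the assignment
# of vectors to inverted lists; the CPU quantizer is swapped back in before
# each shard is written to disk.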
for i in range(0, self.nshards):
sfn = f"{self.index_shard_prefix}{i}"
try:
index.reset()
index_ivf.quantizer = gpu_quantizer
with open(sfn, "xb"):
start = i * self.shard_size
jj = 0
embeddings_batch_size = min(
EMBEDDINGS_BATCH_SIZE, self.shard_size
)
assert (
self.shard_size % embeddings_batch_size == 0
or EMBEDDINGS_BATCH_SIZE % embeddings_batch_size == 0
), (
f"the shard size {self.shard_size} and embeddings"
f" shard size {EMBEDDINGS_BATCH_SIZE} are not"
" divisible"
)
for xb_j in tqdm(
self._iterate_transformed(
self.xb_ds,
start,
embeddings_batch_size,
np.float32,
),
file=sys.stdout,
):
if is_pretransform_index(index):
assert xb_j.shape[1] == index.chain.at(0).d_out
index_ivf.add_with_ids(
xb_j,
np.arange(start + jj, start + jj + xb_j.shape[0]),
)
else:
assert xb_j.shape[1] == index.d
index.add_with_ids(
xb_j,
np.arange(start + jj, start + jj + xb_j.shape[0]),
)
jj += xb_j.shape[0]
logging.info(jj)
assert (
jj <= self.shard_size
), f"jj {jj} and shard_zide {self.shard_size}"
if jj == self.shard_size:
break
logging.info(f"writing {sfn}...")
index_ivf.quantizer = cpu_quantizer
faiss.write_index(index, sfn)
except FileExistsError:
logging.info(f"skipping shard: {i}")
continue
logging.info("done")
def merge_index(self):
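"""
Merges the on-disk index shards into a single index: the inverted lists of
all shards are written to {self.index_file}.ivfdata and the merged index
is saved to {self.index_file}.
"""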
ivf_file = f"{self.index_file}.ivfdata"
assert os.path.exists(self.index_template_file)
assert not os.path.exists(
ivf_file
), f"file with embeddings data {ivf_file} not found, check."
assert not os.path.exists(self.index_file)
index = faiss.read_index(self.index_template_file)
block_fnames = [
f"{self.index_shard_prefix}{i}" for i in range(self.nshards)
]
for fn in block_fnames:
assert os.path.exists(fn)
logging.info(block_fnames)
logging.info("merging...")
merge_ondisk(index, block_fnames, ivf_file)
logging.info("writing index...")
faiss.write_index(index, self.index_file)
logging.info("done")
def _cached_search(
self,
sample,
xq_ds,
xb_ds,
idx_file,
vecs_file,
I_file,
D_file,
index_file=None,
nprobe=None,
):
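"""
Runs a k-NN search for a cached sample of queries and memoises everything
on disk: the sampled ids and vectors, and the resulting I/D arrays. When
index_file is given, the non-sharded index is searched; otherwise exact
ground truth is computed with knn_ground_truth. On subsequent calls the
cached .npy files are loaded instead of recomputing.
"""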
if not os.path.exists(I_file):
assert not os.path.exists(I_file), f"file {I_file} already exists "
assert not os.path.exists(D_file), f"file {D_file} already exists "
xq = xq_ds.sample(sample, idx_file, vecs_file)
if index_file:
D, I = self._index_nonsharded_search(index_file, xq, nprobe)
else:
logging.info("ground truth computations")
db_iterator = xb_ds.iterate(0, 100_000, np.float32)
D, I = knn_ground_truth(
xq, db_iterator, self.k, metric_type=self.metric
)
assert np.amin(I) >= 0
np.save(I_file, I)
np.save(D_file, D)
else:
assert os.path.exists(idx_file), f"file {idx_file} does not exist "
assert os.path.exists(
vecs_file
), f"file {vecs_file} does not exist "
assert os.path.exists(I_file), f"file {I_file} does not exist "
assert os.path.exists(D_file), f"file {D_file} does not exist "
I = np.load(I_file)
D = np.load(D_file)
assert I.shape == (sample, self.k), f"{I_file} shape mismatch"
assert D.shape == (sample, self.k), f"{D_file} shape mismatch"
return (D, I)
def _index_search(self, index_shard_prefix, xq, nprobe):
assert nprobe is not None
logging.info(
f"open sharded index: {index_shard_prefix}, {self.nshards}"
)
index = self._open_sharded_index(index_shard_prefix)
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
logging.info(f"setting nprobe to {nprobe}")
index_ivf.nprobe = nprobe
return index.search(xq, self.k)
def _index_nonsharded_search(self, index_file, xq, nprobe):
assert nprobe is not None
logging.info(f"index {index_file}")
assert os.path.exists(index_file), f"file {index_file} does not exist "
index = faiss.read_index(index_file, faiss.IO_FLAG_ONDISK_SAME_DIR)
logging.info(f"index size {index.ntotal} ")
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
logging.info(f"setting nprobe to {nprobe}")
index_ivf.nprobe = nprobe
return index.search(xq, self.k)
def _refine_distances(self, xq_ds, idx, xb_ds, I):
xq = xq_ds.get(idx).repeat(self.k, axis=0)
xb = xb_ds.get(I.reshape(-1))
if self.metric == faiss.METRIC_INNER_PRODUCT:
return (xq * xb).sum(axis=1).reshape(I.shape)
elif self.metric == faiss.METRIC_L2:
return ((xq - xb) ** 2).sum(axis=1).reshape(I.shape)
else:
raise ValueError(f"metric not supported {self.metric}")
def evaluate(self):
self._evaluate(
self.index_factory_fn,
self.index_file,
self.xq_index_file,
self.nprobe,
)
def _evaluate(self, index_factory_fn, index_file, xq_index_file, nprobe):
idx_a_file = f"{self.eval_dir}/idx_a.npy"
idx_b_gt_file = f"{self.eval_dir}/idx_b_gt.npy"
idx_b_ann_file = (
f"{self.eval_dir}/idx_b_ann_{index_factory_fn}_np{nprobe}.npy"
)
vecs_a_file = f"{self.eval_dir}/vecs_a.npy"
vecs_b_gt_file = f"{self.eval_dir}/vecs_b_gt.npy"
vecs_b_ann_file = (
f"{self.eval_dir}/vecs_b_ann_{index_factory_fn}_np{nprobe}.npy"
)
D_a_gt_file = f"{self.eval_dir}/D_a_gt.npy"
D_a_ann_file = (
f"{self.eval_dir}/D_a_ann_{index_factory_fn}_np{nprobe}.npy"
)
D_a_ann_refined_file = f"{self.eval_dir}/D_a_ann_refined_{index_factory_fn}_np{nprobe}.npy"
D_b_gt_file = f"{self.eval_dir}/D_b_gt.npy"
D_b_ann_file = (
f"{self.eval_dir}/D_b_ann_{index_factory_fn}_np{nprobe}.npy"
)
D_b_ann_gt_file = (
f"{self.eval_dir}/D_b_ann_gt_{index_factory_fn}_np{nprobe}.npy"
)
I_a_gt_file = f"{self.eval_dir}/I_a_gt.npy"
I_a_ann_file = (
f"{self.eval_dir}/I_a_ann_{index_factory_fn}_np{nprobe}.npy"
)
I_b_gt_file = f"{self.eval_dir}/I_b_gt.npy"
I_b_ann_file = (
f"{self.eval_dir}/I_b_ann_{index_factory_fn}_np{nprobe}.npy"
)
I_b_ann_gt_file = (
f"{self.eval_dir}/I_b_ann_gt_{index_factory_fn}_np{nprobe}.npy"
)
margin_gt_file = f"{self.eval_dir}/margin_gt.npy"
margin_refined_file = (
f"{self.eval_dir}/margin_refined_{index_factory_fn}_np{nprobe}.npy"
)
margin_ann_file = (
f"{self.eval_dir}/margin_ann_{index_factory_fn}_np{nprobe}.npy"
)
logging.info("exact search forward")
# xq -> xb AKA a -> b
D_a_gt, I_a_gt = self._cached_search(
self.evaluation_sample,
self.xq_ds,
self.xb_ds,
idx_a_file,
vecs_a_file,
I_a_gt_file,
D_a_gt_file,
)
idx_a = np.load(idx_a_file)
logging.info("approximate search forward")
D_a_ann, I_a_ann = self._cached_search(
self.evaluation_sample,
self.xq_ds,
self.xb_ds,
idx_a_file,
vecs_a_file,
I_a_ann_file,
D_a_ann_file,
index_file,
nprobe,
)
logging.info(
"calculate refined distances on approximate search forward"
)
if os.path.exists(D_a_ann_refined_file):
D_a_ann_refined = np.load(D_a_ann_refined_file)
assert D_a_ann.shape == D_a_ann_refined.shape
else:
D_a_ann_refined = self._refine_distances(
self.xq_ds, idx_a, self.xb_ds, I_a_ann
)
np.save(D_a_ann_refined_file, D_a_ann_refined)
if self.evaluate_by_margin:
k_extract = self.k
margin_threshold = 1.05
logging.info(
"exact search backward from the k_extract NN results of"
" forward search"
)
# xb -> xq AKA b -> a
D_a_b_gt = D_a_gt[:, :k_extract].ravel()
idx_b_gt = I_a_gt[:, :k_extract].ravel()
assert len(idx_b_gt) == self.evaluation_sample * k_extract
np.save(idx_b_gt_file, idx_b_gt)
# exact search
D_b_gt, _ = self._cached_search(
len(idx_b_gt),
self.xb_ds,
self.xq_ds,
idx_b_gt_file,
vecs_b_gt_file,
I_b_gt_file,
D_b_gt_file,
) # xb and xq ^^^ are inverted
logging.info("margin on exact search")
margin_gt = margin(
self.evaluation_sample,
idx_a,
idx_b_gt,
D_a_b_gt,
D_a_gt,
D_b_gt,
self.k,
k_extract,
margin_threshold,
)
np.save(margin_gt_file, margin_gt)
logging.info(
"exact search backward from the k_extract NN results of"
" approximate forward search"
)
D_a_b_refined = D_a_ann_refined[:, :k_extract].ravel()
idx_b_ann = I_a_ann[:, :k_extract].ravel()
assert len(idx_b_ann) == self.evaluation_sample * k_extract
np.save(idx_b_ann_file, idx_b_ann)
# exact search
D_b_ann_gt, _ = self._cached_search(
len(idx_b_ann),
self.xb_ds,
self.xq_ds,
idx_b_ann_file,
vecs_b_ann_file,
I_b_ann_gt_file,
D_b_ann_gt_file,
) # xb and xq ^^^ are inverted
logging.info("refined margin on approximate search")
margin_refined = margin(
self.evaluation_sample,
idx_a,
idx_b_ann,
D_a_b_refined,
D_a_gt, # not D_a_ann_refined(!)
D_b_ann_gt,
self.k,
k_extract,
margin_threshold,
)
np.save(margin_refined_file, margin_refined)
D_b_ann, I_b_ann = self._cached_search(
len(idx_b_ann),
self.xb_ds,
self.xq_ds,
idx_b_ann_file,
vecs_b_ann_file,
I_b_ann_file,
D_b_ann_file,
xq_index_file,
nprobe,
)
D_a_b_ann = D_a_ann[:, :k_extract].ravel()
logging.info("approximate search margin")
margin_ann = margin(
self.evaluation_sample,
idx_a,
idx_b_ann,
D_a_b_ann,
D_a_ann,
D_b_ann,
self.k,
k_extract,
margin_threshold,
)
np.save(margin_ann_file, margin_ann)
logging.info("intersection")
logging.info(I_a_gt)
logging.info(I_a_ann)
for i in range(1, self.k + 1):
logging.info(
f"{i}: {knn_intersection_measure(I_a_gt[:,:i], I_a_ann[:,:i])}"
)
logging.info(f"mean of gt distances: {D_a_gt.mean()}")
logging.info(f"mean of approx distances: {D_a_ann.mean()}")
logging.info(f"mean of refined distances: {D_a_ann_refined.mean()}")
logging.info("intersection cardinality frequencies")
logging.info(get_intersection_cardinality_frequencies(I_a_ann, I_a_gt))
logging.info("done")
def _knn_function(self, xq, xb, k, metric, thread_id=None):
try:
return faiss.knn_gpu(
self.all_gpu_resources[thread_id],
xq,
xb,
k,
metric=metric,
device=thread_id,
vectorsMemoryLimit=self.knn_vectors_memory_limit,
queriesMemoryLimit=self.knn_queries_memory_limit,
)
except Exception:
logging.info(f"knn_function failed: {xq.shape}, {xb.shape}")
raise
def _coarse_quantize(self, index_ivf, xq, nprobe):
assert nprobe <= index_ivf.quantizer.ntotal
quantizer = faiss.index_cpu_to_all_gpus(index_ivf.quantizer)
bs = 100_000
nq = len(xq)
q_assign = np.empty((nq, nprobe), dtype="int32")
for i0 in trange(0, nq, bs):
i1 = min(nq, i0 + bs)
_, q_assign_i = quantizer.search(xq[i0:i1], nprobe)
q_assign[i0:i1] = q_assign_i
return q_assign
def search(self):
logging.info(f"search: {self.knn_dir}")
slurm_job_id = os.environ.get("SLURM_JOB_ID")
ngpu = faiss.get_num_gpus()
logging.info(f"number of gpus: {ngpu}")
self.all_gpu_resources = [
faiss.StandardGpuResources() for _ in range(ngpu)
]
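# warm-up / sanity-check call of the GPU knn function on dummy data
# before the main search loop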
self._knn_function(
np.zeros((10, 10), dtype=np.float16),
np.zeros((10, 10), dtype=np.float16),
self.k,
metric=self.metric,
thread_id=0,
)
index = self._open_sharded_index()
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
logging.info(f"setting nprobe to {self.nprobe}")
index_ivf.nprobe = self.nprobe
# quantizer = faiss.index_cpu_to_all_gpus(index_ivf.quantizer)
for i in range(0, self.xq_ds.size, self.xq_bs):
Ifn = f"{self.knn_dir}/I{(i):010}_{self.knn_output_file_suffix}"
Dfn = f"{self.knn_dir}/D_approx{(i):010}_{self.knn_output_file_suffix}"
CPfn = f"{self.knn_dir}/CP{(i):010}_{self.knn_output_file_suffix}"
if slurm_job_id:
worker_record = (
self.knn_dir
+ f"/record_{(i):010}_{self.knn_output_file_suffix}.txt"
)
if not os.path.exists(worker_record):
logging.info(
f"creating record file {worker_record} and saving job"
f" id: {slurm_job_id}"
)
with open(worker_record, "w") as h:
h.write(slurm_job_id)
else:
old_slurm_id = open(worker_record, "r").read()
logging.info(
f"old job slurm id {old_slurm_id} and current job id:"
f" {slurm_job_id}"
)
if old_slurm_id == slurm_job_id:
if os.path.getsize(Ifn) == 0:
logging.info(
f"cleaning up zero length files {Ifn} and"
f" {Dfn}"
)
os.remove(Ifn)
os.remove(Dfn)
try:
if is_pretransform_index(index):
d = index.chain.at(0).d_out
else:
d = self.input_d
with open(Ifn, "xb") as f, open(Dfn, "xb") as g:
xq_i = np.empty(
shape=(self.xq_bs, d), dtype=np.float16
)
q_assign = np.empty(
(self.xq_bs, self.nprobe), dtype=np.int32
)
j = 0
quantizer = faiss.index_cpu_to_all_gpus(
index_ivf.quantizer
)
for xq_i_j in tqdm(
self._iterate_transformed(
self.xq_ds, i, min(100_000, self.xq_bs), np.float16
),
file=sys.stdout,
):
xq_i[j:j + xq_i_j.shape[0]] = xq_i_j
(
_,
q_assign[j:j + xq_i_j.shape[0]],
) = quantizer.search(xq_i_j, self.nprobe)
j += xq_i_j.shape[0]
assert j <= xq_i.shape[0]
if j == xq_i.shape[0]:
break
xq_i = xq_i[:j]
q_assign = q_assign[:j]
assert q_assign.shape == (xq_i.shape[0], index_ivf.nprobe)
del quantizer
logging.info(f"computing: {Ifn}")
logging.info(f"computing: {Dfn}")
prefetch_threads = faiss.get_num_gpus()
D_ann, I = big_batch_search(
index_ivf,
xq_i,
self.k,
verbose=10,
method="knn_function",
knn=self._knn_function,
threaded=faiss.get_num_gpus() * 8,
use_float16=True,
prefetch_threads=prefetch_threads,
computation_threads=faiss.get_num_gpus(),
q_assign=q_assign,
checkpoint=CPfn,
checkpoint_freq=7200, # in seconds
)
assert (
np.amin(I) >= 0
), f"{I}, there exists negative indices, check"
logging.info(f"saving: {Ifn}")
np.save(f, I)
logging.info(f"saving: {Dfn}")
np.save(g, D_ann)
if os.path.exists(CPfn):
logging.info(f"removing: {CPfn}")
os.remove(CPfn)
except FileExistsError:
logging.info(f"skipping {Ifn}, already exists")
logging.info(f"skipping {Dfn}, already exists")
continue
def _open_index_shard(self, fn):
if fn in self.index_shards:
index_shard = self.index_shards[fn]
else:
logging.info(f"open index shard: {fn}")
index_shard = faiss.read_index(
fn, faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
)
self.index_shards[fn] = index_shard
return index_shard
def _open_sharded_index(self, index_shard_prefix=None):
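"""
Opens all index shards with the given prefix and horizontally stacks their
inverted lists onto the empty template index, so that the shards can be
searched as a single index. The resulting index is cached in self.index.
"""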
if index_shard_prefix is None:
index_shard_prefix = self.index_shard_prefix
if index_shard_prefix in self.index:
return self.index[index_shard_prefix]
assert os.path.exists(
self.index_template_file
), f"file {self.index_template_file} does not exist "
logging.info(f"open index template: {self.index_template_file}")
index = faiss.read_index(self.index_template_file)
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
ilv = faiss.InvertedListsPtrVector()
for i in range(self.nshards):
fn = f"{index_shard_prefix}{i}"
assert os.path.exists(fn), f"file {fn} does not exist "
logging.info(fn)
index_shard = self._open_index_shard(fn)
il = faiss.downcast_index(
faiss.extract_index_ivf(index_shard)
).invlists
ilv.push_back(il)
hsil = faiss.HStackInvertedLists(ilv.size(), ilv.data())
index_ivf.replace_invlists(hsil, False)
self.ivls[index_shard_prefix] = hsil
self.index[index_shard_prefix] = index
return index
def index_shard_stats(self):
for i in range(self.nshards):
fn = f"{self.index_shard_prefix}{i}"
assert os.path.exists(fn)
index = faiss.read_index(
fn, faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
)
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
il = index_ivf.invlists
il.print_stats()
def index_stats(self):
index = self._open_sharded_index()
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
il = index_ivf.invlists
list_sizes = [il.list_size(i) for i in range(il.nlist)]
logging.info(np.max(list_sizes))
logging.info(np.mean(list_sizes))
logging.info(np.argmax(list_sizes))
logging.info("index_stats:")
il.print_stats()
def consistency_check(self):
logging.info("consistency-check")
logging.info("index template...")
assert os.path.exists(self.index_template_file)
index = faiss.read_index(self.index_template_file)
offset = 0 # 2**24
assert self.shard_size > offset + SMALL_DATA_SAMPLE
logging.info("index shards...")
for i in range(self.nshards):
r = i * self.shard_size + offset
xb = next(self.xb_ds.iterate(r, SMALL_DATA_SAMPLE, np.float32))
fn = f"{self.index_shard_prefix}{i}"
assert os.path.exists(fn), f"There is no index shard file {fn}"
index = self._open_index_shard(fn)
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
index_ivf.nprobe = 1
_, I = index.search(xb, 100)
for j in range(SMALL_DATA_SAMPLE):
assert np.where(I[j] == j + r)[0].size > 0, (
f"I[j]: {I[j]}, j: {j}, i: {i}, shard_size:"
f" {self.shard_size}"
)
logging.info("merged index...")
index = self._open_sharded_index()
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
index_ivf.nprobe = 1
for i in range(self.nshards):
r = i * self.shard_size + offset
xb = next(self.xb_ds.iterate(r, SMALL_DATA_SAMPLE, np.float32))
_, I = index.search(xb, 100)
for j in range(SMALL_DATA_SAMPLE):
assert np.where(I[j] == j + r)[0].size > 0, (
f"I[j]: {I[j]}, j: {j}, i: {i}, shard_size:"
f" {self.shard_size}")
logging.info("search results...")
index_ivf.nprobe = self.nprobe
for i in range(0, self.xq_ds.size, self.xq_bs):
Ifn = f"{self.knn_dir}/I{i:010}_{self.index_factory_fn}_np{self.nprobe}.npy"
assert os.path.exists(Ifn)
assert os.path.getsize(Ifn) > 0, f"The file {Ifn} is empty."
logging.info(Ifn)
I = np.load(Ifn, mmap_mode="r")
assert I.shape[1] == self.k
assert I.shape[0] == min(self.xq_bs, self.xq_ds.size - i)
assert np.all(I[:, 1] >= 0)
Dfn = f"{self.knn_dir}/D_approx{i:010}_{self.index_factory_fn}_np{self.nprobe}.npy"
assert os.path.exists(Dfn)
assert os.path.getsize(Dfn) > 0, f"The file {Dfn} is empty."
logging.info(Dfn)
D = np.load(Dfn, mmap_mode="r")
assert D.shape == I.shape
xq = next(self.xq_ds.iterate(i, SMALL_DATA_SAMPLE, np.float32))
D_online, I_online = index.search(xq, self.k)
assert (
np.where(I[:SMALL_DATA_SAMPLE] == I_online)[0].size
/ (self.k * SMALL_DATA_SAMPLE)
> 0.95
), (
"the ratio is"
f" {np.where(I[:SMALL_DATA_SAMPLE] == I_online)[0].size / (self.k * SMALL_DATA_SAMPLE)}"
)
assert np.allclose(
D[:SMALL_DATA_SAMPLE].sum(axis=1),
D_online.sum(axis=1),
rtol=0.01,
), (
"the difference is"
f" {D[:SMALL_DATA_SAMPLE].sum(axis=1), D_online.sum(axis=1)}"
)
logging.info("done")

@@ -0,0 +1,219 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from utils import (
load_config,
add_group_args,
)
from offline_ivf import OfflineIVF
import faiss
from typing import List, Callable, Dict
import submitit
def join_lists_in_dict(poss: Dict[str, List[str]]) -> List[str]:
"""
Joins two lists of prod and non-prod values, checking if the prod value is already included.
If there is no non-prod list, it returns the prod list.
"""
if "non-prod" in poss.keys():
all_poss = poss["non-prod"]
if poss["prod"][-1] not in poss["non-prod"]:
all_poss += poss["prod"]
return all_poss
else:
return poss["prod"]
def main(
args: argparse.Namespace,
cfg: Dict[str, str],
nprobe: int,
index_factory_str: str,
) -> None:
oivf = OfflineIVF(cfg, args, nprobe, index_factory_str)
eval(f"oivf.{args.command}()")
def process_options_and_run_jobs(args: argparse.Namespace) -> None:
"""
If "--cluster_run", it launches an array of jobs to the cluster using the submitit library for all the index strings. In
the case of evaluate, it launches a job for each index string and nprobe pair. Otherwise, it launches a single job
that is run locally with the prod values for index string and nprobe.
"""
cfg = load_config(args.config)
index_strings = cfg["index"]
nprobes = cfg["nprobe"]
if args.command == "evaluate":
if args.cluster_run:
all_nprobes = join_lists_in_dict(nprobes)
all_index_strings = join_lists_in_dict(index_strings)
for index_factory_str in all_index_strings:
for nprobe in all_nprobes:
launch_job(main, args, cfg, nprobe, index_factory_str)
else:
launch_job(
main, args, cfg, nprobes["prod"][-1], index_strings["prod"][-1]
)
else:
if args.cluster_run:
all_index_strings = join_lists_in_dict(index_strings)
for index_factory_str in all_index_strings:
launch_job(
main, args, cfg, nprobes["prod"][-1], index_factory_str
)
else:
launch_job(
main, args, cfg, nprobes["prod"][-1], index_strings["prod"][-1]
)
def launch_job(
func: Callable,
args: argparse.Namespace,
cfg: Dict[str, str],
n_probe: int,
index_str: str,
) -> None:
"""
Launches an array of slurm jobs to the cluster using the submitit library.
"""
if args.cluster_run:
assert args.num_nodes >= 1
executor = submitit.AutoExecutor(folder=args.logs_dir)
executor.update_parameters(
nodes=args.num_nodes,
gpus_per_node=args.gpus_per_node,
cpus_per_task=args.cpus_per_task,
tasks_per_node=args.tasks_per_node,
name=args.job_name,
slurm_partition=args.partition,
slurm_time=70 * 60,
)
if args.slurm_constraint:
executor.update_parameters(slurm_constraint=args.slurm_constraint)
job = executor.submit(func, args, cfg, n_probe, index_str)
print(f"Job id: {job.job_id}")
else:
func(args, cfg, n_probe, index_str)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
group = parser.add_argument_group("general")
add_group_args(group, "--command", required=True, help="command to run")
add_group_args(
group,
"--config",
required=True,
help="config yaml with the dataset specs",
)
add_group_args(
group, "--nt", type=int, default=96, help="nb search threads"
)
add_group_args(
group,
"--no_residuals",
action="store_false",
help="set index.by_residual to False during train index.",
)
group = parser.add_argument_group("slurm_job")
add_group_args(
group,
"--cluster_run",
action="store_true",
help=" if True, runs in cluster",
)
add_group_args(
group,
"--job_name",
type=str,
default="oivf",
help="cluster job name",
)
add_group_args(
group,
"--num_nodes",
type=int,
default=1,
help="num of nodes per job",
)
add_group_args(
group,
"--tasks_per_node",
type=int,
default=1,
help="tasks per job",
)
add_group_args(
group,
"--gpus_per_node",
type=int,
default=8,
help="cluster job name",
)
add_group_args(
group,
"--cpus_per_task",
type=int,
default=80,
help="cluster job name",
)
add_group_args(
group,
"--logs_dir",
type=str,
default="/checkpoint/marialomeli/offline_faiss/logs",
help="cluster job name",
)
add_group_args(
group,
"--slurm_constraint",
type=str,
default=None,
help="can be volta32gb for the fair cluster",
)
add_group_args(
group,
"--partition",
type=str,
default="learnlab",
help="specify which partition to use if ran on cluster with job arrays",
choices=[
"learnfair",
"devlab",
"scavenge",
"learnlab",
"nllb",
"seamless",
"seamless_medium",
"learnaccel",
"onellm_low",
"learn",
"scavenge",
],
)
group = parser.add_argument_group("dataset")
add_group_args(group, "--xb", required=True, help="database vectors")
add_group_args(group, "--xq", help="query vectors")
args = parser.parse_args()
print("args:", args)
faiss.omp_set_num_threads(args.nt)
process_options_and_run_jobs(args=args)

@@ -0,0 +1,181 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import yaml
import numpy as np
from typing import Dict, List, Optional
OIVF_TEST_ARGS: List[str] = [
"--config",
"--xb",
"--xq",
"--command",
"--cluster_run",
"--no_residuals",
]
def get_test_parser(args) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
for arg in args:
parser.add_argument(arg)
return parser
class TestDataCreator:
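"""
Helper that writes random .npy embedding shards and a matching
config_test.yaml into a temporary directory, and builds the corresponding
command-line arguments for the offline IVF tests.
"""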
def __init__(
self,
tempdir: str,
dimension: int,
data_type: np.dtype,
index_factory: Optional[List] = ["OPQ4,IVF256,PQ4"],
training_sample: Optional[int] = 9984,
index_shard_size: Optional[int] = 1000,
query_batch_size: Optional[int] = 1000,
evaluation_sample: Optional[int] = 100,
num_files: Optional[int] = None,
file_size: Optional[int] = None,
file_sizes: Optional[List] = None,
nprobe: Optional[int] = 64,
k: Optional[int] = 10,
metric: Optional[str] = "METRIC_L2",
normalise: Optional[bool] = False,
with_queries_ds: Optional[bool] = False,
evaluate_by_margin: Optional[bool] = False,
) -> None:
self.tempdir = tempdir
self.dimension = dimension
self.data_type = np.dtype(data_type).name
self.index_factory = {"prod": index_factory}
if file_size and num_files:
self.file_sizes = [file_size for _ in range(num_files)]
elif file_sizes:
self.file_sizes = file_sizes
else:
raise ValueError("no file sizes provided")
self.num_files = len(self.file_sizes)
self.training_sample = training_sample
self.index_shard_size = index_shard_size
self.query_batch_size = query_batch_size
self.evaluation_sample = evaluation_sample
self.nprobe = {"prod": [nprobe]}
self.k = k
self.metric = metric
self.normalise = normalise
self.config_file = self.tempdir + "/config_test.yaml"
self.ds_name = "my_test_data"
self.qs_name = "my_queries_data"
self.evaluate_by_margin = evaluate_by_margin
self.with_queries_ds = with_queries_ds
def create_test_data(self) -> None:
datafiles = self._create_data_files()
files_info = []
for i, file in enumerate(datafiles):
files_info.append(
{
"dtype": self.data_type,
"format": "npy",
"name": file,
"size": self.file_sizes[i],
}
)
config_for_yaml = {
"d": self.dimension,
"output": self.tempdir,
"index": self.index_factory,
"nprobe": self.nprobe,
"k": self.k,
"normalise": self.normalise,
"metric": self.metric,
"training_sample": self.training_sample,
"evaluation_sample": self.evaluation_sample,
"index_shard_size": self.index_shard_size,
"query_batch_size": self.query_batch_size,
"datasets": {
self.ds_name: {
"root": self.tempdir,
"size": sum(self.file_sizes),
"files": files_info,
}
},
}
if self.evaluate_by_margin:
config_for_yaml["evaluate_by_margin"] = self.evaluate_by_margin
q_datafiles = self._create_data_files("my_q_data")
q_files_info = []
for i, file in enumerate(q_datafiles):
q_files_info.append(
{
"dtype": self.data_type,
"format": "npy",
"name": file,
"size": self.file_sizes[i],
}
)
if self.with_queries_ds:
config_for_yaml["datasets"][self.qs_name] = {
"root": self.tempdir,
"size": sum(self.file_sizes),
"files": q_files_info,
}
self._create_config_yaml(config_for_yaml)
def setup_cli(self, command="consistency_check") -> argparse.Namespace:
parser = get_test_parser(OIVF_TEST_ARGS)
if self.with_queries_ds:
return parser.parse_args(
[
"--xb",
self.ds_name,
"--config",
self.config_file,
"--command",
command,
"--xq",
self.qs_name,
]
)
return parser.parse_args(
[
"--xb",
self.ds_name,
"--config",
self.config_file,
"--command",
command,
]
)
def _create_data_files(self, name_of_file="my_data") -> List[str]:
"""
Creates a dataset "my_test_data" with the given number of files, using zero-padding in the file
names. If self.with_queries_ds is True, it adds an extra dataset "my_queries_data" with the same number of files
as "my_test_data". The default name for embeddings files is "my_data" + <padding>.npy.
"""
filenames = []
for i, file_size in enumerate(self.file_sizes):
# np.random.seed(i)
db_vectors = np.random.random((file_size, self.dimension)).astype(
self.data_type
)
filename = name_of_file + f"{i:02}" + ".npy"
filenames.append(filename)
np.save(self.tempdir + "/" + filename, db_vectors)
return filenames
def _create_config_yaml(self, dict_file: Dict[str, str]) -> None:
"""
Creates a yaml file in dir (can be a temporary dir for tests).
"""
filename = self.tempdir + "/config_test.yaml"
with open(filename, "w") as file:
yaml.dump(dict_file, file, default_flow_style=False)

@@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import os
from typing import Dict
import yaml
import faiss
from faiss.contrib.datasets import SyntheticDataset
def load_config(config):
assert os.path.exists(config)
with open(config, "r") as f:
return yaml.safe_load(f)
def faiss_sanity_check():
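"""
Checks that a flat CPU index and its GPU counterpart return the same
results on a small synthetic dataset.
"""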
ds = SyntheticDataset(256, 0, 100, 100)
xq = ds.get_queries()
xb = ds.get_database()
index_cpu = faiss.IndexFlat(ds.d)
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)
index_cpu.add(xb)
index_gpu.add(xb)
D_cpu, I_cpu = index_cpu.search(xq, 10)
D_gpu, I_gpu = index_gpu.search(xq, 10)
assert np.all(I_cpu == I_gpu), "faiss sanity check failed"
assert np.all(np.isclose(D_cpu, D_gpu)), "faiss sanity check failed"
def margin(sample, idx_a, idx_b, D_a_b, D_a, D_b, k, k_extract, threshold):
"""
two datasets: xa, xb; n = number of pairs
idx_a - (np,) - query vector ids in xa
idx_b - (np,) - query vector ids in xb
D_a_b - (np,) - pairwise distances between xa[idx_a] and xb[idx_b]
D_a - (np, k) - distances between vectors xa[idx_a] and corresponding nearest neighbours in xb
D_b - (np, k) - distances between vectors xb[idx_b] and corresponding nearest neighbours in xa
k - k nearest neighbours used for margin
k_extract - number of nearest neighbours of each query in xb we consider for margin calculation and filtering
threshold - margin threshold
"""
n = sample
nk = n * k_extract
assert idx_a.shape == (n,)
idx_a_k = idx_a.repeat(k_extract)
assert idx_a_k.shape == (nk,)
assert idx_b.shape == (nk,)
assert D_a_b.shape == (nk,)
assert D_a.shape == (n, k)
assert D_b.shape == (nk, k)
mean_a = np.mean(D_a, axis=1)
assert mean_a.shape == (n,)
mean_a_k = mean_a.repeat(k_extract)
assert mean_a_k.shape == (nk,)
mean_b = np.mean(D_b, axis=1)
assert mean_b.shape == (nk,)
margin = 2 * D_a_b / (mean_a_k + mean_b)
above_threshold = margin > threshold
print(np.count_nonzero(above_threshold))
print(idx_a_k[above_threshold])
print(idx_b[above_threshold])
print(margin[above_threshold])
return margin
def add_group_args(group, *args, **kwargs):
return group.add_argument(*args, **kwargs)
def get_intersection_cardinality_frequencies(
I: np.ndarray, I_gt: np.ndarray
) -> Dict[int, int]:
"""
Computes the frequencies for the cardinalities of the intersection of neighbour indices.
"""
nq = I.shape[0]
res = []
for ell in range(nq):
res.append(len(np.intersect1d(I[ell, :], I_gt[ell, :])))
values, counts = np.unique(res, return_counts=True)
return dict(zip(values, counts))
def is_pretransform_index(index):
if index.__class__ == faiss.IndexPreTransform:
assert hasattr(index, "chain")
return True
else:
assert not hasattr(index, "chain")
return False