Initial commit
52
packages/leann-backend-hnsw/third_party/faiss/demos/offline_ivf/README.md
vendored
Normal file
@@ -0,0 +1,52 @@
# Offline IVF

This folder contains the code for the offline IVF algorithm powered by faiss big batch search.

Create a conda env:

`conda create --name oivf python=3.10`

`conda activate oivf`

`conda install -c pytorch/label/nightly -c nvidia faiss-gpu=1.7.4`

`conda install tqdm`

`conda install pyyaml`

`conda install -c conda-forge submitit`

## Run book

1. Optionally shard your dataset (see `create_sharded_ssnpp_files.py`) and create the corresponding yaml file `config_ssnpp.yaml`. You can use `generate_config.py` by specifying the root directory of your dataset and the files with the data shards:

`python generate_config.py`

2. Run the train index command:

`python run.py --command train_index --config config_ssnpp.yaml --xb ssnpp_1B`

3. Run the index-shard command so it produces sharded indexes, required for the search step:

`python run.py --command index_shard --config config_ssnpp.yaml --xb ssnpp_1B`

4. Send jobs to the cluster to run search:

`python run.py --command search --config config_ssnpp.yaml --xb ssnpp_1B --cluster_run --partition <PARTITION-NAME>`

Remarks about the `search` command: it is assumed that the database vectors are the query vectors when performing the search step.
a. If the query vectors are different from the database vectors, they should be passed in the `--xq` argument.
b. A new dataset needs to be prepared (step 1) before passing it to the query vectors argument `--xq`:

`python run.py --command search --config config_ssnpp.yaml --xb ssnpp_1B --xq <QUERIES_DATASET_NAME>`

5. We can always run the consistency-check for sanity checks!

`python run.py --command consistency_check --config config_ssnpp.yaml --xb ssnpp_1B`
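For readers who want to drive the pipeline from Python instead of the `run.py` CLI, the following is a minimal sketch of what the run book commands do under the hood, based on `run.py` and `offline_ivf.py` in this commit. The prod `nprobe` (512) and index string (`IVF8192,PQ128`) are taken from `config_ssnpp.yaml`; the `argparse.Namespace` stands in for the parsed CLI arguments, and the `output` directory named in the config is assumed to exist.

```python
# Hedged sketch of what `python run.py --command train_index ...` does,
# using the classes added in this commit.
import argparse

from utils import load_config
from offline_ivf import OfflineIVF

cfg = load_config("config_ssnpp.yaml")
args = argparse.Namespace(xb="ssnpp_1B", xq=None, no_residuals=False)
oivf = OfflineIVF(cfg, args, nprobe=512, index_factory_str="IVF8192,PQ128")
oivf.train_index()          # step 2 of the run book
# oivf.index_shard()        # step 3
# oivf.search()             # step 4 (needs GPUs)
# oivf.consistency_check()  # step 5
```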
0
packages/leann-backend-hnsw/third_party/faiss/demos/offline_ivf/__init__.py
vendored
Normal file
110
packages/leann-backend-hnsw/third_party/faiss/demos/offline_ivf/config_ssnpp.yaml
vendored
Normal file
@@ -0,0 +1,110 @@
d: 256
output: /checkpoint/marialomeli/offline_faiss/ssnpp
index:
  prod:
    - 'IVF8192,PQ128'
  non-prod:
    - 'IVF16384,PQ128'
    - 'IVF32768,PQ128'
    - 'OPQ64_128,IVF4096,PQ64'
nprobe:
  prod:
    - 512
  non-prod:
    - 256
    - 128
    - 1024
    - 2048
    - 4096
    - 8192
k: 50
index_shard_size: 50000000
query_batch_size: 50000000
evaluation_sample: 10000
training_sample: 1572864
datasets:
  ssnpp_1B:
    root: /checkpoint/marialomeli/ssnpp_data
    size: 1000000000
    files:
      - dtype: uint8
        format: npy
        name: ssnpp_0000000000.npy
        size: 50000000
      - dtype: uint8
        format: npy
        name: ssnpp_0000000001.npy
        size: 50000000
      - dtype: uint8
        format: npy
        name: ssnpp_0000000002.npy
        size: 50000000
      - dtype: uint8
        format: npy
        name: ssnpp_0000000003.npy
        size: 50000000
      - dtype: uint8
        format: npy
        name: ssnpp_0000000004.npy
        size: 50000000
      - dtype: uint8
        format: npy
        name: ssnpp_0000000005.npy
        size: 50000000
      - dtype: uint8
        format: npy
        name: ssnpp_0000000006.npy
        size: 50000000
      - dtype: uint8
        format: npy
        name: ssnpp_0000000007.npy
        size: 50000000
      - dtype: uint8
        format: npy
        name: ssnpp_0000000008.npy
        size: 50000000
      - dtype: uint8
        format: npy
        name: ssnpp_0000000009.npy
        size: 50000000
      - dtype: uint8
        format: npy
        name: ssnpp_0000000010.npy
        size: 50000000
      - dtype: uint8
        format: npy
        name: ssnpp_0000000011.npy
        size: 50000000
      - dtype: uint8
        format: npy
        name: ssnpp_0000000012.npy
        size: 50000000
      - dtype: uint8
        format: npy
        name: ssnpp_0000000013.npy
        size: 50000000
      - dtype: uint8
        format: npy
        name: ssnpp_0000000014.npy
        size: 50000000
      - dtype: uint8
        format: npy
        name: ssnpp_0000000015.npy
        size: 50000000
      - dtype: uint8
        format: npy
        name: ssnpp_0000000016.npy
        size: 50000000
      - dtype: uint8
        format: npy
        name: ssnpp_0000000017.npy
        size: 50000000
      - dtype: uint8
        format: npy
        name: ssnpp_0000000018.npy
        size: 50000000
      - dtype: uint8
        format: npy
        name: ssnpp_0000000019.npy
        size: 50000000
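For context, here is a minimal sketch of how a config like the one above is consumed by the code added later in this commit (`utils.load_config` and `dataset.create_dataset_from_oivf_config`); the dataset paths are the ones listed above and would need to exist locally.

```python
# Hedged sketch: load the yaml above and stream the ssnpp_1B dataset it
# describes, using helpers introduced later in this commit.
import numpy as np

from utils import load_config
from dataset import create_dataset_from_oivf_config

cfg = load_config("config_ssnpp.yaml")
ds = create_dataset_from_oivf_config(cfg, "ssnpp_1B")
batches = ds.iterate(0, 100_000, np.float32)  # yields (100000, 256) float32 arrays
first = next(batches)
print(first.shape, first.dtype)
```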
64
packages/leann-backend-hnsw/third_party/faiss/demos/offline_ivf/create_sharded_ssnpp_files.py
vendored
Normal file
@@ -0,0 +1,64 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
||||
def xbin_mmap(fname, dtype, maxn=-1):
|
||||
"""
|
||||
Code from
|
||||
https://github.com/harsha-simhadri/big-ann-benchmarks/blob/main/benchmark/dataset_io.py#L94
|
||||
mmap the competition file format for a given type of items
|
||||
"""
|
||||
n, d = map(int, np.fromfile(fname, dtype="uint32", count=2))
|
||||
assert os.stat(fname).st_size == 8 + n * d * np.dtype(dtype).itemsize
|
||||
if maxn > 0:
|
||||
n = min(n, maxn)
|
||||
return np.memmap(fname, dtype=dtype, mode="r", offset=8, shape=(n, d))
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
ssnpp_data = xbin_mmap(fname=args.filepath, dtype="uint8")
|
||||
num_batches = ssnpp_data.shape[0] // args.data_batch
|
||||
assert (
|
||||
ssnpp_data.shape[0] % args.data_batch == 0
|
||||
), "num of embeddings per file should divide total num of embeddings"
|
||||
for i in range(num_batches):
|
||||
xb_batch = ssnpp_data[
|
||||
i * args.data_batch:(i + 1) * args.data_batch, :
|
||||
]
|
||||
filename = args.output_dir + f"/ssnpp_{(i):010}.npy"
|
||||
np.save(filename, xb_batch)
|
||||
print(f"File {filename} is saved!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--data_batch",
|
||||
dest="data_batch",
|
||||
type=int,
|
||||
default=50000000,
|
||||
help="Number of embeddings per file, should be a divisor of 1B",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--filepath",
|
||||
dest="filepath",
|
||||
type=str,
|
||||
default="/datasets01/big-ann-challenge-data/FB_ssnpp/FB_ssnpp_database.u8bin",
|
||||
help="path of 1B ssnpp database vectors' original file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--filepath",
|
||||
dest="output_dir",
|
||||
type=str,
|
||||
default="/checkpoint/marialomeli/ssnpp_data",
|
||||
help="path to put sharded files",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
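The core of the script above is a simple equal-size split; a self-contained miniature of that logic (with made-up sizes) looks like this:

```python
# Hedged miniature of the sharding step: split an (N, d) uint8 array into
# equal .npy shards, mirroring the data_batch logic of the script above.
import numpy as np

data = np.arange(12 * 4, dtype=np.uint8).reshape(12, 4)  # 12 vectors, d=4
data_batch = 3                                           # vectors per shard
assert data.shape[0] % data_batch == 0, "shard size must divide the total"
for i in range(data.shape[0] // data_batch):
    shard = data[i * data_batch:(i + 1) * data_batch, :]
    np.save(f"ssnpp_{i:010}.npy", shard)  # ssnpp_0000000000.npy, ...
```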
174
packages/leann-backend-hnsw/third_party/faiss/demos/offline_ivf/dataset.py
vendored
Normal file
@@ -0,0 +1,174 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import faiss
|
||||
from typing import List
|
||||
import random
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
def create_dataset_from_oivf_config(cfg, ds_name):
|
||||
normalise = cfg["normalise"] if "normalise" in cfg else False
|
||||
return MultiFileVectorDataset(
|
||||
cfg["datasets"][ds_name]["root"],
|
||||
[
|
||||
FileDescriptor(
|
||||
f["name"], f["format"], np.dtype(f["dtype"]), f["size"]
|
||||
)
|
||||
for f in cfg["datasets"][ds_name]["files"]
|
||||
],
|
||||
cfg["d"],
|
||||
normalise,
|
||||
cfg["datasets"][ds_name]["size"],
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=100)
|
||||
def _memmap_vecs(
|
||||
file_name: str, format: str, dtype: np.dtype, size: int, d: int
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
If the file is in raw format, the file size will
|
||||
be divisible by the dimensionality and by the size
|
||||
of the data type.
|
||||
Otherwise, the file contains a header and we assume
|
||||
it is of .npy type. It then returns the memmapped file.
|
||||
"""
|
||||
|
||||
assert os.path.exists(file_name), f"file does not exist {file_name}"
|
||||
if format == "raw":
|
||||
fl = os.path.getsize(file_name)
|
||||
nb = fl // d // dtype.itemsize
|
||||
assert nb == size, f"{nb} is different than config's {size}"
|
||||
assert fl == d * dtype.itemsize * nb # no header
|
||||
return np.memmap(file_name, shape=(nb, d), dtype=dtype, mode="r")
|
||||
elif format == "npy":
|
||||
vecs = np.load(file_name, mmap_mode="r")
|
||||
assert vecs.shape[0] == size, f"size:{size},shape {vecs.shape[0]}"
|
||||
assert vecs.shape[1] == d
|
||||
assert vecs.dtype == dtype
|
||||
return vecs
|
||||
else:
|
||||
ValueError("The file cannot be loaded in the current format.")
|
||||
|
||||
|
||||
class FileDescriptor:
|
||||
def __init__(self, name: str, format: str, dtype: np.dtype, size: int):
|
||||
self.name = name
|
||||
self.format = format
|
||||
self.dtype = dtype
|
||||
self.size = size
|
||||
|
||||
|
||||
class MultiFileVectorDataset:
|
||||
def __init__(
|
||||
self,
|
||||
root: str,
|
||||
file_descriptors: List[FileDescriptor],
|
||||
d: int,
|
||||
normalize: bool,
|
||||
size: int,
|
||||
):
|
||||
assert os.path.exists(root)
|
||||
self.root = root
|
||||
self.file_descriptors = file_descriptors
|
||||
self.d = d
|
||||
self.normalize = normalize
|
||||
self.size = size
|
||||
self.file_offsets = [0]
|
||||
t = 0
|
||||
for f in self.file_descriptors:
|
||||
xb = _memmap_vecs(
|
||||
f"{self.root}/{f.name}", f.format, f.dtype, f.size, self.d
|
||||
)
|
||||
t += xb.shape[0]
|
||||
self.file_offsets.append(t)
|
||||
assert (
|
||||
t == self.size
|
||||
), "the sum of num of embeddings per file!=total num of embeddings"
|
||||
|
||||
def iterate(self, start: int, batch_size: int, dt: np.dtype):
|
||||
buffer = np.empty(shape=(batch_size, self.d), dtype=dt)
|
||||
rem = 0
|
||||
for f in self.file_descriptors:
|
||||
if start >= f.size:
|
||||
start -= f.size
|
||||
continue
|
||||
logging.info(f"processing: {f.name}...")
|
||||
xb = _memmap_vecs(
|
||||
f"{self.root}/{f.name}",
|
||||
f.format,
|
||||
f.dtype,
|
||||
f.size,
|
||||
self.d,
|
||||
)
|
||||
if start > 0:
|
||||
xb = xb[start:]
|
||||
start = 0
|
||||
req = min(batch_size - rem, xb.shape[0])
|
||||
buffer[rem:rem + req] = xb[:req]
|
||||
rem += req
|
||||
if rem == batch_size:
|
||||
if self.normalize:
|
||||
faiss.normalize_L2(buffer)
|
||||
yield buffer.copy()
|
||||
rem = 0
|
||||
for i in range(req, xb.shape[0], batch_size):
|
||||
j = i + batch_size
|
||||
if j <= xb.shape[0]:
|
||||
tmp = xb[i:j].astype(dt)
|
||||
if self.normalize:
|
||||
faiss.normalize_L2(tmp)
|
||||
yield tmp
|
||||
else:
|
||||
rem = xb.shape[0] - i
|
||||
buffer[:rem] = xb[i:j]
|
||||
if rem > 0:
|
||||
tmp = buffer[:rem]
|
||||
if self.normalize:
|
||||
faiss.normalize_L2(tmp)
|
||||
yield tmp
|
||||
|
||||
def get(self, idx: List[int]):
|
||||
n = len(idx)
|
||||
fidx = np.searchsorted(self.file_offsets, idx, "right")
|
||||
res = np.empty(shape=(len(idx), self.d), dtype=np.float32)
|
||||
for r, id, fid in zip(range(n), idx, fidx):
|
||||
assert fid > 0 and fid <= len(self.file_descriptors), f"{fid}"
|
||||
f = self.file_descriptors[fid - 1]
|
||||
# deferring normalization until after reading the vec
|
||||
vecs = _memmap_vecs(
|
||||
f"{self.root}/{f.name}", f.format, f.dtype, f.size, self.d
|
||||
)
|
||||
i = id - self.file_offsets[fid - 1]
|
||||
assert i >= 0 and i < vecs.shape[0]
|
||||
res[r, :] = vecs[i] # TODO: find a faster way
|
||||
if self.normalize:
|
||||
faiss.normalize_L2(res)
|
||||
return res
|
||||
|
||||
def sample(self, n, idx_fn, vecs_fn):
|
||||
if vecs_fn and os.path.exists(vecs_fn):
|
||||
vecs = np.load(vecs_fn)
|
||||
assert vecs.shape == (n, self.d)
|
||||
return vecs
|
||||
if idx_fn and os.path.exists(idx_fn):
|
||||
idx = np.load(idx_fn)
|
||||
assert idx.size == n
|
||||
else:
|
||||
idx = np.array(sorted(random.sample(range(self.size), n)))
|
||||
if idx_fn:
|
||||
np.save(idx_fn, idx)
|
||||
vecs = self.get(idx)
|
||||
if vecs_fn:
|
||||
np.save(vecs_fn, vecs)
|
||||
return vecs
|
||||
|
||||
def get_first_n(self, n, dt):
|
||||
assert n <= self.size
|
||||
return next(self.iterate(0, n, dt))
|
||||
46
packages/leann-backend-hnsw/third_party/faiss/demos/offline_ivf/generate_config.py
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import yaml
|
||||
|
||||
# with ssnpp sharded data
|
||||
root = "/checkpoint/marialomeli/ssnpp_data"
|
||||
file_names = [f"ssnpp_{i:010}.npy" for i in range(20)]
|
||||
d = 256
|
||||
dt = np.dtype(np.uint8)
|
||||
|
||||
|
||||
def read_embeddings(fp):
|
||||
fl = os.path.getsize(fp)
|
||||
nb = fl // d // dt.itemsize
|
||||
print(nb)
|
||||
if fl == d * dt.itemsize * nb: # no header
|
||||
return ("raw", np.memmap(fp, shape=(nb, d), dtype=dt, mode="r"))
|
||||
else: # assume npy
|
||||
vecs = np.load(fp, mmap_mode="r")
|
||||
assert vecs.shape[1] == d
|
||||
assert vecs.dtype == dt
|
||||
return ("npy", vecs)
|
||||
|
||||
|
||||
cfg = {}
|
||||
files = []
|
||||
size = 0
|
||||
for fn in file_names:
|
||||
fp = f"{root}/{fn}"
|
||||
assert os.path.exists(fp), f"{fp} is missing"
|
||||
ft, xb = read_embeddings(fp)
|
||||
files.append(
|
||||
{"name": fn, "size": xb.shape[0], "dtype": dt.name, "format": ft}
|
||||
)
|
||||
size += xb.shape[0]
|
||||
|
||||
cfg["size"] = size
|
||||
cfg["root"] = root
|
||||
cfg["d"] = d
|
||||
cfg["files"] = files
|
||||
print(yaml.dump(cfg))
|
||||
891
packages/leann-backend-hnsw/third_party/faiss/demos/offline_ivf/offline_ivf.py
vendored
Normal file
@@ -0,0 +1,891 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
import os
|
||||
from tqdm import tqdm, trange
|
||||
import sys
|
||||
import logging
|
||||
from faiss.contrib.ondisk import merge_ondisk
|
||||
from faiss.contrib.big_batch_search import big_batch_search
|
||||
from faiss.contrib.exhaustive_search import knn_ground_truth
|
||||
from faiss.contrib.evaluation import knn_intersection_measure
|
||||
from utils import (
|
||||
get_intersection_cardinality_frequencies,
|
||||
margin,
|
||||
is_pretransform_index,
|
||||
)
|
||||
from dataset import create_dataset_from_oivf_config
|
||||
|
||||
logging.basicConfig(
|
||||
format=(
|
||||
"%(asctime)s.%(msecs)03d %(levelname)-8s %(threadName)-12s %(message)s"
|
||||
),
|
||||
level=logging.INFO,
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
force=True,
|
||||
)
|
||||
|
||||
EMBEDDINGS_BATCH_SIZE: int = 100_000
|
||||
NUM_SUBSAMPLES: int = 100
|
||||
SMALL_DATA_SAMPLE: int = 10000
|
||||
|
||||
|
||||
class OfflineIVF:
|
||||
def __init__(self, cfg, args, nprobe, index_factory_str):
|
||||
self.input_d = cfg["d"]
|
||||
self.dt = cfg["datasets"][args.xb]["files"][0]["dtype"]
|
||||
assert self.input_d > 0
|
||||
output_dir = cfg["output"]
|
||||
assert os.path.exists(output_dir)
|
||||
self.index_factory = index_factory_str
|
||||
assert self.index_factory is not None
|
||||
self.index_factory_fn = self.index_factory.replace(",", "_")
|
||||
self.index_template_file = (
|
||||
f"{output_dir}/{args.xb}/{self.index_factory_fn}.empty.faissindex"
|
||||
)
|
||||
logging.info(f"index template: {self.index_template_file}")
|
||||
|
||||
if not args.xq:
|
||||
args.xq = args.xb
|
||||
|
||||
self.by_residual = True
|
||||
if args.no_residuals:
|
||||
self.by_residual = False
|
||||
|
||||
xb_output_dir = f"{output_dir}/{args.xb}"
|
||||
if not os.path.exists(xb_output_dir):
|
||||
os.makedirs(xb_output_dir)
|
||||
xq_output_dir = f"{output_dir}/{args.xq}"
|
||||
if not os.path.exists(xq_output_dir):
|
||||
os.makedirs(xq_output_dir)
|
||||
search_output_dir = f"{output_dir}/{args.xq}_in_{args.xb}"
|
||||
if not os.path.exists(search_output_dir):
|
||||
os.makedirs(search_output_dir)
|
||||
self.knn_dir = f"{search_output_dir}/knn"
|
||||
if not os.path.exists(self.knn_dir):
|
||||
os.makedirs(self.knn_dir)
|
||||
self.eval_dir = f"{search_output_dir}/eval"
|
||||
if not os.path.exists(self.eval_dir):
|
||||
os.makedirs(self.eval_dir)
|
||||
self.index = {} # to keep a reference to opened indices,
|
||||
self.ivls = {} # hstack inverted lists,
|
||||
self.index_shards = {} # and index shards
|
||||
self.index_shard_prefix = (
|
||||
f"{xb_output_dir}/{self.index_factory_fn}.shard_"
|
||||
)
|
||||
self.xq_index_shard_prefix = (
|
||||
f"{xq_output_dir}/{self.index_factory_fn}.shard_"
|
||||
)
|
||||
self.index_file = ( # TODO: added back temporarily for evaluate, handle name of non-sharded index file and remove.
|
||||
f"{xb_output_dir}/{self.index_factory_fn}.faissindex"
|
||||
)
|
||||
self.xq_index_file = (
|
||||
f"{xq_output_dir}/{self.index_factory_fn}.faissindex"
|
||||
)
|
||||
self.training_sample = cfg["training_sample"]
|
||||
self.evaluation_sample = cfg["evaluation_sample"]
|
||||
self.xq_ds = create_dataset_from_oivf_config(cfg, args.xq)
|
||||
self.xb_ds = create_dataset_from_oivf_config(cfg, args.xb)
|
||||
file_descriptors = self.xq_ds.file_descriptors
|
||||
self.file_sizes = [fd.size for fd in file_descriptors]
|
||||
self.shard_size = cfg["index_shard_size"] # ~100GB
|
||||
self.nshards = self.xb_ds.size // self.shard_size
|
||||
if self.xb_ds.size % self.shard_size != 0:
|
||||
self.nshards += 1
|
||||
self.xq_nshards = self.xq_ds.size // self.shard_size
|
||||
if self.xq_ds.size % self.shard_size != 0:
|
||||
self.xq_nshards += 1
|
||||
self.nprobe = nprobe
|
||||
assert self.nprobe > 0, "Invalid nprobe parameter."
|
||||
if "deduper" in cfg:
|
||||
self.deduper = cfg["deduper"]
|
||||
self.deduper_codec_fn = [
|
||||
f"{xb_output_dir}/deduper_codec_{codec.replace(',', '_')}"
|
||||
for codec in self.deduper
|
||||
]
|
||||
self.deduper_idx_fn = [
|
||||
f"{xb_output_dir}/deduper_idx_{codec.replace(',', '_')}"
|
||||
for codec in self.deduper
|
||||
]
|
||||
else:
|
||||
self.deduper = None
|
||||
self.k = cfg["k"]
|
||||
assert self.k > 0, "Invalid number of neighbours parameter."
|
||||
self.knn_output_file_suffix = (
|
||||
f"{self.index_factory_fn}_np{self.nprobe}.npy"
|
||||
)
|
||||
|
||||
fp = 32
|
||||
if self.dt == "float16":
|
||||
fp = 16
|
||||
|
||||
self.xq_bs = cfg["query_batch_size"]
|
||||
if "metric" in cfg:
|
||||
self.metric = eval(f'faiss.{cfg["metric"]}')
|
||||
else:
|
||||
self.metric = faiss.METRIC_L2
|
||||
|
||||
if "evaluate_by_margin" in cfg:
|
||||
self.evaluate_by_margin = cfg["evaluate_by_margin"]
|
||||
else:
|
||||
self.evaluate_by_margin = False
|
||||
|
||||
os.system("grep -m1 'model name' < /proc/cpuinfo")
|
||||
os.system("grep -E 'MemTotal|MemFree' /proc/meminfo")
|
||||
os.system("nvidia-smi")
|
||||
os.system("nvcc --version")
|
||||
|
||||
self.knn_queries_memory_limit = 4 * 1024 * 1024 * 1024 # 4 GB
|
||||
self.knn_vectors_memory_limit = 8 * 1024 * 1024 * 1024 # 8 GB
|
||||
|
||||
def input_stats(self):
|
||||
"""
|
||||
Computes and logs faiss.MatrixStats on a training-sample-sized subset of the first chunk of the database vectors.
|
||||
"""
|
||||
xb_sample = self.xb_ds.get_first_n(self.training_sample, np.float32)
|
||||
logging.info(f"input shape: {xb_sample.shape}")
|
||||
logging.info("running MatrixStats on training sample...")
|
||||
logging.info(faiss.MatrixStats(xb_sample).comments)
|
||||
logging.info("done")
|
||||
|
||||
def dedupe(self):
|
||||
logging.info(self.deduper)
|
||||
if self.deduper is None:
|
||||
logging.info("No deduper configured")
|
||||
return
|
||||
codecs = []
|
||||
codesets = []
|
||||
idxs = []
|
||||
for factory, filename in zip(self.deduper, self.deduper_codec_fn):
|
||||
if os.path.exists(filename):
|
||||
logging.info(f"loading trained dedupe codec: {filename}")
|
||||
codec = faiss.read_index(filename)
|
||||
else:
|
||||
logging.info(f"training dedupe codec: {factory}")
|
||||
codec = faiss.index_factory(self.input_d, factory)
|
||||
xb_sample = np.unique(
|
||||
self.xb_ds.get_first_n(100_000, np.float32), axis=0
|
||||
)
|
||||
faiss.ParameterSpace().set_index_parameter(codec, "verbose", 1)
|
||||
codec.train(xb_sample)
|
||||
logging.info(f"writing trained dedupe codec: {filename}")
|
||||
faiss.write_index(codec, filename)
|
||||
codecs.append(codec)
|
||||
codesets.append(faiss.CodeSet(codec.sa_code_size()))
|
||||
idxs.append(np.empty((0,), dtype=np.uint32))
|
||||
bs = 1_000_000
|
||||
i = 0
|
||||
for buffer in tqdm(self._iterate_transformed(self.xb_ds, 0, bs, np.float32)):
|
||||
for j in range(len(codecs)):
|
||||
codec, codeset, idx = codecs[j], codesets[j], idxs[j]
|
||||
uniq = codeset.insert(codec.sa_encode(buffer))
|
||||
idxs[j] = np.append(
|
||||
idx,
|
||||
np.arange(i, i + buffer.shape[0], dtype=np.uint32)[uniq],
|
||||
)
|
||||
i += buffer.shape[0]
|
||||
for idx, filename in zip(idxs, self.deduper_idx_fn):
|
||||
logging.info(f"writing {filename}, shape: {idx.shape}")
|
||||
np.save(filename, idx)
|
||||
logging.info("done")
|
||||
|
||||
def train_index(self):
|
||||
"""
|
||||
Trains the index using a subsample of the first chunk of data in the database and saves it in the template file (with no vectors added).
|
||||
"""
|
||||
assert not os.path.exists(self.index_template_file), (
|
||||
"The train command has been ran, the index template file already"
|
||||
" exists."
|
||||
)
|
||||
xb_sample = np.unique(
|
||||
self.xb_ds.get_first_n(self.training_sample, np.float32), axis=0
|
||||
)
|
||||
logging.info(f"input shape: {xb_sample.shape}")
|
||||
index = faiss.index_factory(
|
||||
self.input_d, self.index_factory, self.metric
|
||||
)
|
||||
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
|
||||
index_ivf.by_residual = True
|
||||
faiss.ParameterSpace().set_index_parameter(index, "verbose", 1)
|
||||
logging.info("running training...")
|
||||
index.train(xb_sample)
|
||||
logging.info(f"writing trained index {self.index_template_file}...")
|
||||
faiss.write_index(index, self.index_template_file)
|
||||
logging.info("done")
|
||||
|
||||
def _iterate_transformed(self, ds, start, batch_size, dt):
|
||||
assert os.path.exists(self.index_template_file)
|
||||
index = faiss.read_index(self.index_template_file)
|
||||
if is_pretransform_index(index):
|
||||
vt = index.chain.at(0) # fetch pretransform
|
||||
for buffer in ds.iterate(start, batch_size, dt):
|
||||
yield vt.apply(buffer)
|
||||
else:
|
||||
for buffer in ds.iterate(start, batch_size, dt):
|
||||
yield buffer
|
||||
|
||||
def index_shard(self):
|
||||
assert os.path.exists(self.index_template_file)
|
||||
index = faiss.read_index(self.index_template_file)
|
||||
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
|
||||
assert self.nprobe <= index_ivf.quantizer.ntotal, (
|
||||
f"the number of vectors {index_ivf.quantizer.ntotal} is not enough"
|
||||
f" to retrieve {self.nprobe} neighbours, check."
|
||||
)
|
||||
cpu_quantizer = index_ivf.quantizer
|
||||
gpu_quantizer = faiss.index_cpu_to_all_gpus(cpu_quantizer)
|
||||
|
||||
for i in range(0, self.nshards):
|
||||
sfn = f"{self.index_shard_prefix}{i}"
|
||||
try:
|
||||
index.reset()
|
||||
index_ivf.quantizer = gpu_quantizer
|
||||
with open(sfn, "xb"):
|
||||
start = i * self.shard_size
|
||||
jj = 0
|
||||
embeddings_batch_size = min(
|
||||
EMBEDDINGS_BATCH_SIZE, self.shard_size
|
||||
)
|
||||
assert (
|
||||
self.shard_size % embeddings_batch_size == 0
|
||||
or EMBEDDINGS_BATCH_SIZE % embeddings_batch_size == 0
|
||||
), (
|
||||
f"the shard size {self.shard_size} and embeddings"
|
||||
f" shard size {EMBEDDINGS_BATCH_SIZE} are not"
|
||||
" divisible"
|
||||
)
|
||||
|
||||
for xb_j in tqdm(
|
||||
self._iterate_transformed(
|
||||
self.xb_ds,
|
||||
start,
|
||||
embeddings_batch_size,
|
||||
np.float32,
|
||||
),
|
||||
file=sys.stdout,
|
||||
):
|
||||
if is_pretransform_index(index):
|
||||
assert xb_j.shape[1] == index.chain.at(0).d_out
|
||||
index_ivf.add_with_ids(
|
||||
xb_j,
|
||||
np.arange(start + jj, start + jj + xb_j.shape[0]),
|
||||
)
|
||||
else:
|
||||
assert xb_j.shape[1] == index.d
|
||||
index.add_with_ids(
|
||||
xb_j,
|
||||
np.arange(start + jj, start + jj + xb_j.shape[0]),
|
||||
)
|
||||
jj += xb_j.shape[0]
|
||||
logging.info(jj)
|
||||
assert (
|
||||
jj <= self.shard_size
|
||||
), f"jj {jj} and shard_zide {self.shard_size}"
|
||||
if jj == self.shard_size:
|
||||
break
|
||||
logging.info(f"writing {sfn}...")
|
||||
index_ivf.quantizer = cpu_quantizer
|
||||
faiss.write_index(index, sfn)
|
||||
except FileExistsError:
|
||||
logging.info(f"skipping shard: {i}")
|
||||
continue
|
||||
logging.info("done")
|
||||
|
||||
def merge_index(self):
|
||||
ivf_file = f"{self.index_file}.ivfdata"
|
||||
|
||||
assert os.path.exists(self.index_template_file)
|
||||
assert not os.path.exists(
|
||||
ivf_file
|
||||
), f"file with embeddings data {ivf_file} not found, check."
|
||||
assert not os.path.exists(self.index_file)
|
||||
index = faiss.read_index(self.index_template_file)
|
||||
block_fnames = [
|
||||
f"{self.index_shard_prefix}{i}" for i in range(self.nshards)
|
||||
]
|
||||
for fn in block_fnames:
|
||||
assert os.path.exists(fn)
|
||||
logging.info(block_fnames)
|
||||
logging.info("merging...")
|
||||
merge_ondisk(index, block_fnames, ivf_file)
|
||||
logging.info("writing index...")
|
||||
faiss.write_index(index, self.index_file)
|
||||
logging.info("done")
|
||||
|
||||
def _cached_search(
|
||||
self,
|
||||
sample,
|
||||
xq_ds,
|
||||
xb_ds,
|
||||
idx_file,
|
||||
vecs_file,
|
||||
I_file,
|
||||
D_file,
|
||||
index_file=None,
|
||||
nprobe=None,
|
||||
):
|
||||
if not os.path.exists(I_file):
|
||||
assert not os.path.exists(I_file), f"file {I_file} already exists "
|
||||
assert not os.path.exists(D_file), f"file {D_file} already exists "
|
||||
xq = xq_ds.sample(sample, idx_file, vecs_file)
|
||||
|
||||
if index_file:
|
||||
D, I = self._index_nonsharded_search(index_file, xq, nprobe)
|
||||
else:
|
||||
logging.info("ground truth computations")
|
||||
db_iterator = xb_ds.iterate(0, 100_000, np.float32)
|
||||
D, I = knn_ground_truth(
|
||||
xq, db_iterator, self.k, metric_type=self.metric
|
||||
)
|
||||
assert np.amin(I) >= 0
|
||||
|
||||
np.save(I_file, I)
|
||||
np.save(D_file, D)
|
||||
else:
|
||||
assert os.path.exists(idx_file), f"file {idx_file} does not exist "
|
||||
assert os.path.exists(
|
||||
vecs_file
|
||||
), f"file {vecs_file} does not exist "
|
||||
assert os.path.exists(I_file), f"file {I_file} does not exist "
|
||||
assert os.path.exists(D_file), f"file {D_file} does not exist "
|
||||
I = np.load(I_file)
|
||||
D = np.load(D_file)
|
||||
assert I.shape == (sample, self.k), f"{I_file} shape mismatch"
|
||||
assert D.shape == (sample, self.k), f"{D_file} shape mismatch"
|
||||
return (D, I)
|
||||
|
||||
def _index_search(self, index_shard_prefix, xq, nprobe):
|
||||
assert nprobe is not None
|
||||
logging.info(
|
||||
f"open sharded index: {index_shard_prefix}, {self.nshards}"
|
||||
)
|
||||
index = self._open_sharded_index(index_shard_prefix)
|
||||
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
|
||||
logging.info(f"setting nprobe to {nprobe}")
|
||||
index_ivf.nprobe = nprobe
|
||||
return index.search(xq, self.k)
|
||||
|
||||
def _index_nonsharded_search(self, index_file, xq, nprobe):
|
||||
assert nprobe is not None
|
||||
logging.info(f"index {index_file}")
|
||||
assert os.path.exists(index_file), f"file {index_file} does not exist "
|
||||
index = faiss.read_index(index_file, faiss.IO_FLAG_ONDISK_SAME_DIR)
|
||||
logging.info(f"index size {index.ntotal} ")
|
||||
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
|
||||
logging.info(f"setting nprobe to {nprobe}")
|
||||
index_ivf.nprobe = nprobe
|
||||
return index.search(xq, self.k)
|
||||
|
||||
def _refine_distances(self, xq_ds, idx, xb_ds, I):
|
||||
xq = xq_ds.get(idx).repeat(self.k, axis=0)
|
||||
xb = xb_ds.get(I.reshape(-1))
|
||||
if self.metric == faiss.METRIC_INNER_PRODUCT:
|
||||
return (xq * xb).sum(axis=1).reshape(I.shape)
|
||||
elif self.metric == faiss.METRIC_L2:
|
||||
return ((xq - xb) ** 2).sum(axis=1).reshape(I.shape)
|
||||
else:
|
||||
raise ValueError(f"metric not supported {self.metric}")
|
||||
|
||||
def evaluate(self):
|
||||
self._evaluate(
|
||||
self.index_factory_fn,
|
||||
self.index_file,
|
||||
self.xq_index_file,
|
||||
self.nprobe,
|
||||
)
|
||||
|
||||
def _evaluate(self, index_factory_fn, index_file, xq_index_file, nprobe):
|
||||
idx_a_file = f"{self.eval_dir}/idx_a.npy"
|
||||
idx_b_gt_file = f"{self.eval_dir}/idx_b_gt.npy"
|
||||
idx_b_ann_file = (
|
||||
f"{self.eval_dir}/idx_b_ann_{index_factory_fn}_np{nprobe}.npy"
|
||||
)
|
||||
vecs_a_file = f"{self.eval_dir}/vecs_a.npy"
|
||||
vecs_b_gt_file = f"{self.eval_dir}/vecs_b_gt.npy"
|
||||
vecs_b_ann_file = (
|
||||
f"{self.eval_dir}/vecs_b_ann_{index_factory_fn}_np{nprobe}.npy"
|
||||
)
|
||||
D_a_gt_file = f"{self.eval_dir}/D_a_gt.npy"
|
||||
D_a_ann_file = (
|
||||
f"{self.eval_dir}/D_a_ann_{index_factory_fn}_np{nprobe}.npy"
|
||||
)
|
||||
D_a_ann_refined_file = f"{self.eval_dir}/D_a_ann_refined_{index_factory_fn}_np{nprobe}.npy"
|
||||
D_b_gt_file = f"{self.eval_dir}/D_b_gt.npy"
|
||||
D_b_ann_file = (
|
||||
f"{self.eval_dir}/D_b_ann_{index_factory_fn}_np{nprobe}.npy"
|
||||
)
|
||||
D_b_ann_gt_file = (
|
||||
f"{self.eval_dir}/D_b_ann_gt_{index_factory_fn}_np{nprobe}.npy"
|
||||
)
|
||||
I_a_gt_file = f"{self.eval_dir}/I_a_gt.npy"
|
||||
I_a_ann_file = (
|
||||
f"{self.eval_dir}/I_a_ann_{index_factory_fn}_np{nprobe}.npy"
|
||||
)
|
||||
I_b_gt_file = f"{self.eval_dir}/I_b_gt.npy"
|
||||
I_b_ann_file = (
|
||||
f"{self.eval_dir}/I_b_ann_{index_factory_fn}_np{nprobe}.npy"
|
||||
)
|
||||
I_b_ann_gt_file = (
|
||||
f"{self.eval_dir}/I_b_ann_gt_{index_factory_fn}_np{nprobe}.npy"
|
||||
)
|
||||
margin_gt_file = f"{self.eval_dir}/margin_gt.npy"
|
||||
margin_refined_file = (
|
||||
f"{self.eval_dir}/margin_refined_{index_factory_fn}_np{nprobe}.npy"
|
||||
)
|
||||
margin_ann_file = (
|
||||
f"{self.eval_dir}/margin_ann_{index_factory_fn}_np{nprobe}.npy"
|
||||
)
|
||||
|
||||
logging.info("exact search forward")
|
||||
# xq -> xb AKA a -> b
|
||||
D_a_gt, I_a_gt = self._cached_search(
|
||||
self.evaluation_sample,
|
||||
self.xq_ds,
|
||||
self.xb_ds,
|
||||
idx_a_file,
|
||||
vecs_a_file,
|
||||
I_a_gt_file,
|
||||
D_a_gt_file,
|
||||
)
|
||||
idx_a = np.load(idx_a_file)
|
||||
|
||||
logging.info("approximate search forward")
|
||||
D_a_ann, I_a_ann = self._cached_search(
|
||||
self.evaluation_sample,
|
||||
self.xq_ds,
|
||||
self.xb_ds,
|
||||
idx_a_file,
|
||||
vecs_a_file,
|
||||
I_a_ann_file,
|
||||
D_a_ann_file,
|
||||
index_file,
|
||||
nprobe,
|
||||
)
|
||||
|
||||
logging.info(
|
||||
"calculate refined distances on approximate search forward"
|
||||
)
|
||||
if os.path.exists(D_a_ann_refined_file):
|
||||
D_a_ann_refined = np.load(D_a_ann_refined_file)
|
||||
assert D_a_ann.shape == D_a_ann_refined.shape
|
||||
else:
|
||||
D_a_ann_refined = self._refine_distances(
|
||||
self.xq_ds, idx_a, self.xb_ds, I_a_ann
|
||||
)
|
||||
np.save(D_a_ann_refined_file, D_a_ann_refined)
|
||||
|
||||
if self.evaluate_by_margin:
|
||||
k_extract = self.k
|
||||
margin_threshold = 1.05
|
||||
logging.info(
|
||||
"exact search backward from the k_extract NN results of"
|
||||
" forward search"
|
||||
)
|
||||
# xb -> xq AKA b -> a
|
||||
D_a_b_gt = D_a_gt[:, :k_extract].ravel()
|
||||
idx_b_gt = I_a_gt[:, :k_extract].ravel()
|
||||
assert len(idx_b_gt) == self.evaluation_sample * k_extract
|
||||
np.save(idx_b_gt_file, idx_b_gt)
|
||||
# exact search
|
||||
D_b_gt, _ = self._cached_search(
|
||||
len(idx_b_gt),
|
||||
self.xb_ds,
|
||||
self.xq_ds,
|
||||
idx_b_gt_file,
|
||||
vecs_b_gt_file,
|
||||
I_b_gt_file,
|
||||
D_b_gt_file,
|
||||
) # xb and xq ^^^ are inverted
|
||||
|
||||
logging.info("margin on exact search")
|
||||
margin_gt = margin(
|
||||
self.evaluation_sample,
|
||||
idx_a,
|
||||
idx_b_gt,
|
||||
D_a_b_gt,
|
||||
D_a_gt,
|
||||
D_b_gt,
|
||||
self.k,
|
||||
k_extract,
|
||||
margin_threshold,
|
||||
)
|
||||
np.save(margin_gt_file, margin_gt)
|
||||
|
||||
logging.info(
|
||||
"exact search backward from the k_extract NN results of"
|
||||
" approximate forward search"
|
||||
)
|
||||
D_a_b_refined = D_a_ann_refined[:, :k_extract].ravel()
|
||||
idx_b_ann = I_a_ann[:, :k_extract].ravel()
|
||||
assert len(idx_b_ann) == self.evaluation_sample * k_extract
|
||||
np.save(idx_b_ann_file, idx_b_ann)
|
||||
# exact search
|
||||
D_b_ann_gt, _ = self._cached_search(
|
||||
len(idx_b_ann),
|
||||
self.xb_ds,
|
||||
self.xq_ds,
|
||||
idx_b_ann_file,
|
||||
vecs_b_ann_file,
|
||||
I_b_ann_gt_file,
|
||||
D_b_ann_gt_file,
|
||||
) # xb and xq ^^^ are inverted
|
||||
|
||||
logging.info("refined margin on approximate search")
|
||||
margin_refined = margin(
|
||||
self.evaluation_sample,
|
||||
idx_a,
|
||||
idx_b_ann,
|
||||
D_a_b_refined,
|
||||
D_a_gt, # not D_a_ann_refined(!)
|
||||
D_b_ann_gt,
|
||||
self.k,
|
||||
k_extract,
|
||||
margin_threshold,
|
||||
)
|
||||
np.save(margin_refined_file, margin_refined)
|
||||
|
||||
D_b_ann, I_b_ann = self._cached_search(
|
||||
len(idx_b_ann),
|
||||
self.xb_ds,
|
||||
self.xq_ds,
|
||||
idx_b_ann_file,
|
||||
vecs_b_ann_file,
|
||||
I_b_ann_file,
|
||||
D_b_ann_file,
|
||||
xq_index_file,
|
||||
nprobe,
|
||||
)
|
||||
|
||||
D_a_b_ann = D_a_ann[:, :k_extract].ravel()
|
||||
|
||||
logging.info("approximate search margin")
|
||||
|
||||
margin_ann = margin(
|
||||
self.evaluation_sample,
|
||||
idx_a,
|
||||
idx_b_ann,
|
||||
D_a_b_ann,
|
||||
D_a_ann,
|
||||
D_b_ann,
|
||||
self.k,
|
||||
k_extract,
|
||||
margin_threshold,
|
||||
)
|
||||
np.save(margin_ann_file, margin_ann)
|
||||
|
||||
logging.info("intersection")
|
||||
logging.info(I_a_gt)
|
||||
logging.info(I_a_ann)
|
||||
|
||||
for i in range(1, self.k + 1):
|
||||
logging.info(
|
||||
f"{i}: {knn_intersection_measure(I_a_gt[:,:i], I_a_ann[:,:i])}"
|
||||
)
|
||||
|
||||
logging.info(f"mean of gt distances: {D_a_gt.mean()}")
|
||||
logging.info(f"mean of approx distances: {D_a_ann.mean()}")
|
||||
logging.info(f"mean of refined distances: {D_a_ann_refined.mean()}")
|
||||
|
||||
logging.info("intersection cardinality frequencies")
|
||||
logging.info(get_intersection_cardinality_frequencies(I_a_ann, I_a_gt))
|
||||
|
||||
logging.info("done")
|
||||
pass
|
||||
|
||||
def _knn_function(self, xq, xb, k, metric, thread_id=None):
|
||||
try:
|
||||
return faiss.knn_gpu(
|
||||
self.all_gpu_resources[thread_id],
|
||||
xq,
|
||||
xb,
|
||||
k,
|
||||
metric=metric,
|
||||
device=thread_id,
|
||||
vectorsMemoryLimit=self.knn_vectors_memory_limit,
|
||||
queriesMemoryLimit=self.knn_queries_memory_limit,
|
||||
)
|
||||
except Exception:
|
||||
logging.info(f"knn_function failed: {xq.shape}, {xb.shape}")
|
||||
raise
|
||||
|
||||
def _coarse_quantize(self, index_ivf, xq, nprobe):
|
||||
assert nprobe <= index_ivf.quantizer.ntotal
|
||||
quantizer = faiss.index_cpu_to_all_gpus(index_ivf.quantizer)
|
||||
bs = 100_000
|
||||
nq = len(xq)
|
||||
q_assign = np.empty((nq, nprobe), dtype="int32")
|
||||
for i0 in trange(0, nq, bs):
|
||||
i1 = min(nq, i0 + bs)
|
||||
_, q_assign_i = quantizer.search(xq[i0:i1], nprobe)
|
||||
q_assign[i0:i1] = q_assign_i
|
||||
return q_assign
|
||||
|
||||
def search(self):
|
||||
logging.info(f"search: {self.knn_dir}")
|
||||
slurm_job_id = os.environ.get("SLURM_JOB_ID")
|
||||
|
||||
ngpu = faiss.get_num_gpus()
|
||||
logging.info(f"number of gpus: {ngpu}")
|
||||
self.all_gpu_resources = [
|
||||
faiss.StandardGpuResources() for _ in range(ngpu)
|
||||
]
|
||||
self._knn_function(
|
||||
np.zeros((10, 10), dtype=np.float16),
|
||||
np.zeros((10, 10), dtype=np.float16),
|
||||
self.k,
|
||||
metric=self.metric,
|
||||
thread_id=0,
|
||||
)
|
||||
|
||||
index = self._open_sharded_index()
|
||||
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
|
||||
logging.info(f"setting nprobe to {self.nprobe}")
|
||||
index_ivf.nprobe = self.nprobe
|
||||
# quantizer = faiss.index_cpu_to_all_gpus(index_ivf.quantizer)
|
||||
for i in range(0, self.xq_ds.size, self.xq_bs):
|
||||
Ifn = f"{self.knn_dir}/I{(i):010}_{self.knn_output_file_suffix}"
|
||||
Dfn = f"{self.knn_dir}/D_approx{(i):010}_{self.knn_output_file_suffix}"
|
||||
CPfn = f"{self.knn_dir}/CP{(i):010}_{self.knn_output_file_suffix}"
|
||||
|
||||
if slurm_job_id:
|
||||
worker_record = (
|
||||
self.knn_dir
|
||||
+ f"/record_{(i):010}_{self.knn_output_file_suffix}.txt"
|
||||
)
|
||||
if not os.path.exists(worker_record):
|
||||
logging.info(
|
||||
f"creating record file {worker_record} and saving job"
|
||||
f" id: {slurm_job_id}"
|
||||
)
|
||||
with open(worker_record, "w") as h:
|
||||
h.write(slurm_job_id)
|
||||
else:
|
||||
old_slurm_id = open(worker_record, "r").read()
|
||||
logging.info(
|
||||
f"old job slurm id {old_slurm_id} and current job id:"
|
||||
f" {slurm_job_id}"
|
||||
)
|
||||
if old_slurm_id == slurm_job_id:
|
||||
if os.path.getsize(Ifn) == 0:
|
||||
logging.info(
|
||||
f"cleaning up zero length files {Ifn} and"
|
||||
f" {Dfn}"
|
||||
)
|
||||
os.remove(Ifn)
|
||||
os.remove(Dfn)
|
||||
|
||||
try:
|
||||
if is_pretransform_index(index):
|
||||
d = index.chain.at(0).d_out
|
||||
else:
|
||||
d = self.input_d
|
||||
with open(Ifn, "xb") as f, open(Dfn, "xb") as g:
|
||||
xq_i = np.empty(
|
||||
shape=(self.xq_bs, d), dtype=np.float16
|
||||
)
|
||||
q_assign = np.empty(
|
||||
(self.xq_bs, self.nprobe), dtype=np.int32
|
||||
)
|
||||
j = 0
|
||||
quantizer = faiss.index_cpu_to_all_gpus(
|
||||
index_ivf.quantizer
|
||||
)
|
||||
for xq_i_j in tqdm(
|
||||
self._iterate_transformed(
|
||||
self.xq_ds, i, min(100_000, self.xq_bs), np.float16
|
||||
),
|
||||
file=sys.stdout,
|
||||
):
|
||||
xq_i[j:j + xq_i_j.shape[0]] = xq_i_j
|
||||
(
|
||||
_,
|
||||
q_assign[j:j + xq_i_j.shape[0]],
|
||||
) = quantizer.search(xq_i_j, self.nprobe)
|
||||
j += xq_i_j.shape[0]
|
||||
assert j <= xq_i.shape[0]
|
||||
if j == xq_i.shape[0]:
|
||||
break
|
||||
xq_i = xq_i[:j]
|
||||
q_assign = q_assign[:j]
|
||||
|
||||
assert q_assign.shape == (xq_i.shape[0], index_ivf.nprobe)
|
||||
del quantizer
|
||||
logging.info(f"computing: {Ifn}")
|
||||
logging.info(f"computing: {Dfn}")
|
||||
prefetch_threads = faiss.get_num_gpus()
|
||||
D_ann, I = big_batch_search(
|
||||
index_ivf,
|
||||
xq_i,
|
||||
self.k,
|
||||
verbose=10,
|
||||
method="knn_function",
|
||||
knn=self._knn_function,
|
||||
threaded=faiss.get_num_gpus() * 8,
|
||||
use_float16=True,
|
||||
prefetch_threads=prefetch_threads,
|
||||
computation_threads=faiss.get_num_gpus(),
|
||||
q_assign=q_assign,
|
||||
checkpoint=CPfn,
|
||||
checkpoint_freq=7200, # in seconds
|
||||
)
|
||||
assert (
|
||||
np.amin(I) >= 0
|
||||
), f"{I}, there exists negative indices, check"
|
||||
logging.info(f"saving: {Ifn}")
|
||||
np.save(f, I)
|
||||
logging.info(f"saving: {Dfn}")
|
||||
np.save(g, D_ann)
|
||||
|
||||
if os.path.exists(CPfn):
|
||||
logging.info(f"removing: {CPfn}")
|
||||
os.remove(CPfn)
|
||||
|
||||
except FileExistsError:
|
||||
logging.info(f"skipping {Ifn}, already exists")
|
||||
logging.info(f"skipping {Dfn}, already exists")
|
||||
continue
|
||||
|
||||
def _open_index_shard(self, fn):
|
||||
if fn in self.index_shards:
|
||||
index_shard = self.index_shards[fn]
|
||||
else:
|
||||
logging.info(f"open index shard: {fn}")
|
||||
index_shard = faiss.read_index(
|
||||
fn, faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
|
||||
)
|
||||
self.index_shards[fn] = index_shard
|
||||
return index_shard
|
||||
|
||||
def _open_sharded_index(self, index_shard_prefix=None):
|
||||
if index_shard_prefix is None:
|
||||
index_shard_prefix = self.index_shard_prefix
|
||||
if index_shard_prefix in self.index:
|
||||
return self.index[index_shard_prefix]
|
||||
assert os.path.exists(
|
||||
self.index_template_file
|
||||
), f"file {self.index_template_file} does not exist "
|
||||
logging.info(f"open index template: {self.index_template_file}")
|
||||
index = faiss.read_index(self.index_template_file)
|
||||
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
|
||||
ilv = faiss.InvertedListsPtrVector()
|
||||
for i in range(self.nshards):
|
||||
fn = f"{index_shard_prefix}{i}"
|
||||
assert os.path.exists(fn), f"file {fn} does not exist "
|
||||
logging.info(fn)
|
||||
index_shard = self._open_index_shard(fn)
|
||||
il = faiss.downcast_index(
|
||||
faiss.extract_index_ivf(index_shard)
|
||||
).invlists
|
||||
ilv.push_back(il)
|
||||
hsil = faiss.HStackInvertedLists(ilv.size(), ilv.data())
|
||||
index_ivf.replace_invlists(hsil, False)
|
||||
self.ivls[index_shard_prefix] = hsil
|
||||
self.index[index_shard_prefix] = index
|
||||
return index
|
||||
|
||||
def index_shard_stats(self):
|
||||
for i in range(self.nshards):
|
||||
fn = f"{self.index_shard_prefix}{i}"
|
||||
assert os.path.exists(fn)
|
||||
index = faiss.read_index(
|
||||
fn, faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
|
||||
)
|
||||
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
|
||||
il = index_ivf.invlists
|
||||
il.print_stats()
|
||||
|
||||
def index_stats(self):
|
||||
index = self._open_sharded_index()
|
||||
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
|
||||
il = index_ivf.invlists
|
||||
list_sizes = [il.list_size(i) for i in range(il.nlist)]
|
||||
logging.info(np.max(list_sizes))
|
||||
logging.info(np.mean(list_sizes))
|
||||
logging.info(np.argmax(list_sizes))
|
||||
logging.info("index_stats:")
|
||||
il.print_stats()
|
||||
|
||||
def consistency_check(self):
|
||||
logging.info("consistency-check")
|
||||
|
||||
logging.info("index template...")
|
||||
|
||||
assert os.path.exists(self.index_template_file)
|
||||
index = faiss.read_index(self.index_template_file)
|
||||
|
||||
offset = 0 # 2**24
|
||||
assert self.shard_size > offset + SMALL_DATA_SAMPLE
|
||||
|
||||
logging.info("index shards...")
|
||||
for i in range(self.nshards):
|
||||
r = i * self.shard_size + offset
|
||||
xb = next(self.xb_ds.iterate(r, SMALL_DATA_SAMPLE, np.float32))
|
||||
fn = f"{self.index_shard_prefix}{i}"
|
||||
assert os.path.exists(fn), f"There is no index shard file {fn}"
|
||||
index = self._open_index_shard(fn)
|
||||
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
|
||||
index_ivf.nprobe = 1
|
||||
_, I = index.search(xb, 100)
|
||||
for j in range(SMALL_DATA_SAMPLE):
|
||||
assert np.where(I[j] == j + r)[0].size > 0, (
|
||||
f"I[j]: {I[j]}, j: {j}, i: {i}, shard_size:"
|
||||
f" {self.shard_size}"
|
||||
)
|
||||
|
||||
logging.info("merged index...")
|
||||
index = self._open_sharded_index()
|
||||
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
|
||||
index_ivf.nprobe = 1
|
||||
for i in range(self.nshards):
|
||||
r = i * self.shard_size + offset
|
||||
xb = next(self.xb_ds.iterate(r, SMALL_DATA_SAMPLE, np.float32))
|
||||
_, I = index.search(xb, 100)
|
||||
for j in range(SMALL_DATA_SAMPLE):
|
||||
assert np.where(I[j] == j + r)[0].size > 0, (
|
||||
f"I[j]: {I[j]}, j: {j}, i: {i}, shard_size:"
|
||||
f" {self.shard_size}")
|
||||
|
||||
logging.info("search results...")
|
||||
index_ivf.nprobe = self.nprobe
|
||||
for i in range(0, self.xq_ds.size, self.xq_bs):
|
||||
Ifn = f"{self.knn_dir}/I{i:010}_{self.index_factory_fn}_np{self.nprobe}.npy"
|
||||
assert os.path.exists(Ifn)
|
||||
assert os.path.getsize(Ifn) > 0, f"The file {Ifn} is empty."
|
||||
logging.info(Ifn)
|
||||
I = np.load(Ifn, mmap_mode="r")
|
||||
|
||||
assert I.shape[1] == self.k
|
||||
assert I.shape[0] == min(self.xq_bs, self.xq_ds.size - i)
|
||||
assert np.all(I[:, 1] >= 0)
|
||||
|
||||
Dfn = f"{self.knn_dir}/D_approx{i:010}_{self.index_factory_fn}_np{self.nprobe}.npy"
|
||||
assert os.path.exists(Dfn)
|
||||
assert os.path.getsize(Dfn) > 0, f"The file {Dfn} is empty."
|
||||
logging.info(Dfn)
|
||||
D = np.load(Dfn, mmap_mode="r")
|
||||
assert D.shape == I.shape
|
||||
|
||||
xq = next(self.xq_ds.iterate(i, SMALL_DATA_SAMPLE, np.float32))
|
||||
D_online, I_online = index.search(xq, self.k)
|
||||
assert (
|
||||
np.where(I[:SMALL_DATA_SAMPLE] == I_online)[0].size
|
||||
/ (self.k * SMALL_DATA_SAMPLE)
|
||||
> 0.95
|
||||
), (
|
||||
"the ratio is"
|
||||
f" {np.where(I[:SMALL_DATA_SAMPLE] == I_online)[0].size / (self.k * SMALL_DATA_SAMPLE)}"
|
||||
)
|
||||
assert np.allclose(
|
||||
D[:SMALL_DATA_SAMPLE].sum(axis=1),
|
||||
D_online.sum(axis=1),
|
||||
rtol=0.01,
|
||||
), (
|
||||
"the difference is"
|
||||
f" {D[:SMALL_DATA_SAMPLE].sum(axis=1), D_online.sum(axis=1)}"
|
||||
)
|
||||
|
||||
logging.info("done")
|
||||
219
packages/leann-backend-hnsw/third_party/faiss/demos/offline_ivf/run.py
vendored
Normal file
@@ -0,0 +1,219 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
from utils import (
|
||||
load_config,
|
||||
add_group_args,
|
||||
)
|
||||
from offline_ivf import OfflineIVF
|
||||
import faiss
|
||||
from typing import List, Callable, Dict
|
||||
import submitit
|
||||
|
||||
|
||||
def join_lists_in_dict(poss: Dict[str, List[str]]) -> List[str]:
|
||||
"""
|
||||
Joins two lists of prod and non-prod values, checking if the prod value is already included.
|
||||
If there is no non-prod list, it returns the prod list.
|
||||
"""
|
||||
if "non-prod" in poss.keys():
|
||||
all_poss = poss["non-prod"]
|
||||
if poss["prod"][-1] not in poss["non-prod"]:
|
||||
all_poss += poss["prod"]
|
||||
return all_poss
|
||||
else:
|
||||
return poss["prod"]
|
||||
|
||||
|
||||
def main(
|
||||
args: argparse.Namespace,
|
||||
cfg: Dict[str, str],
|
||||
nprobe: int,
|
||||
index_factory_str: str,
|
||||
) -> None:
|
||||
oivf = OfflineIVF(cfg, args, nprobe, index_factory_str)
|
||||
eval(f"oivf.{args.command}()")
|
||||
|
||||
|
||||
def process_options_and_run_jobs(args: argparse.Namespace) -> None:
|
||||
"""
|
||||
If "--cluster_run", it launches an array of jobs to the cluster using the submitit library for all the index strings. In
|
||||
the case of evaluate, it launches a job for each index string and nprobe pair. Otherwise, it launches a single job
|
||||
that is run locally with the prod values for index string and nprobe.
|
||||
"""
|
||||
|
||||
cfg = load_config(args.config)
|
||||
index_strings = cfg["index"]
|
||||
nprobes = cfg["nprobe"]
|
||||
if args.command == "evaluate":
|
||||
if args.cluster_run:
|
||||
all_nprobes = join_lists_in_dict(nprobes)
|
||||
all_index_strings = join_lists_in_dict(index_strings)
|
||||
for index_factory_str in all_index_strings:
|
||||
for nprobe in all_nprobes:
|
||||
launch_job(main, args, cfg, nprobe, index_factory_str)
|
||||
else:
|
||||
launch_job(
|
||||
main, args, cfg, nprobes["prod"][-1], index_strings["prod"][-1]
|
||||
)
|
||||
else:
|
||||
if args.cluster_run:
|
||||
all_index_strings = join_lists_in_dict(index_strings)
|
||||
for index_factory_str in all_index_strings:
|
||||
launch_job(
|
||||
main, args, cfg, nprobes["prod"][-1], index_factory_str
|
||||
)
|
||||
else:
|
||||
launch_job(
|
||||
main, args, cfg, nprobes["prod"][-1], index_strings["prod"][-1]
|
||||
)
|
||||
|
||||
|
||||
def launch_job(
|
||||
func: Callable,
|
||||
args: argparse.Namespace,
|
||||
cfg: Dict[str, str],
|
||||
n_probe: int,
|
||||
index_str: str,
|
||||
) -> None:
|
||||
"""
|
||||
Launches an array of slurm jobs to the cluster using the submitit library.
|
||||
"""
|
||||
|
||||
if args.cluster_run:
|
||||
assert args.num_nodes >= 1
|
||||
executor = submitit.AutoExecutor(folder=args.logs_dir)
|
||||
|
||||
executor.update_parameters(
|
||||
nodes=args.num_nodes,
|
||||
gpus_per_node=args.gpus_per_node,
|
||||
cpus_per_task=args.cpus_per_task,
|
||||
tasks_per_node=args.tasks_per_node,
|
||||
name=args.job_name,
|
||||
slurm_partition=args.partition,
|
||||
slurm_time=70 * 60,
|
||||
)
|
||||
if args.slurm_constraint:
|
||||
executor.update_parameters(slurm_constraint=args.slurm_constraint)
|
||||
|
||||
job = executor.submit(func, args, cfg, n_probe, index_str)
|
||||
print(f"Job id: {job.job_id}")
|
||||
else:
|
||||
func(args, cfg, n_probe, index_str)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
group = parser.add_argument_group("general")
|
||||
|
||||
add_group_args(group, "--command", required=True, help="command to run")
|
||||
add_group_args(
|
||||
group,
|
||||
"--config",
|
||||
required=True,
|
||||
help="config yaml with the dataset specs",
|
||||
)
|
||||
add_group_args(
|
||||
group, "--nt", type=int, default=96, help="nb search threads"
|
||||
)
|
||||
add_group_args(
|
||||
group,
|
||||
"--no_residuals",
|
||||
action="store_false",
|
||||
help="set index.by_residual to False during train index.",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group("slurm_job")
|
||||
|
||||
add_group_args(
|
||||
group,
|
||||
"--cluster_run",
|
||||
action="store_true",
|
||||
help=" if True, runs in cluster",
|
||||
)
|
||||
add_group_args(
|
||||
group,
|
||||
"--job_name",
|
||||
type=str,
|
||||
default="oivf",
|
||||
help="cluster job name",
|
||||
)
|
||||
add_group_args(
|
||||
group,
|
||||
"--num_nodes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="num of nodes per job",
|
||||
)
|
||||
add_group_args(
|
||||
group,
|
||||
"--tasks_per_node",
|
||||
type=int,
|
||||
default=1,
|
||||
help="tasks per job",
|
||||
)
|
||||
|
||||
add_group_args(
|
||||
group,
|
||||
"--gpus_per_node",
|
||||
type=int,
|
||||
default=8,
|
||||
help="cluster job name",
|
||||
)
|
||||
add_group_args(
|
||||
group,
|
||||
"--cpus_per_task",
|
||||
type=int,
|
||||
default=80,
|
||||
help="cluster job name",
|
||||
)
|
||||
|
||||
add_group_args(
|
||||
group,
|
||||
"--logs_dir",
|
||||
type=str,
|
||||
default="/checkpoint/marialomeli/offline_faiss/logs",
|
||||
help="cluster job name",
|
||||
)
|
||||
|
||||
add_group_args(
|
||||
group,
|
||||
"--slurm_constraint",
|
||||
type=str,
|
||||
default=None,
|
||||
help="can be volta32gb for the fair cluster",
|
||||
)
|
||||
|
||||
add_group_args(
|
||||
group,
|
||||
"--partition",
|
||||
type=str,
|
||||
default="learnlab",
|
||||
help="specify which partition to use if ran on cluster with job arrays",
|
||||
choices=[
|
||||
"learnfair",
|
||||
"devlab",
|
||||
"scavenge",
|
||||
"learnlab",
|
||||
"nllb",
|
||||
"seamless",
|
||||
"seamless_medium",
|
||||
"learnaccel",
|
||||
"onellm_low",
|
||||
"learn",
|
||||
"scavenge",
|
||||
],
|
||||
)
|
||||
|
||||
group = parser.add_argument_group("dataset")
|
||||
|
||||
add_group_args(group, "--xb", required=True, help="database vectors")
|
||||
add_group_args(group, "--xq", help="query vectors")
|
||||
|
||||
args = parser.parse_args()
|
||||
print("args:", args)
|
||||
faiss.omp_set_num_threads(args.nt)
|
||||
process_options_and_run_jobs(args=args)
|
||||
181
packages/leann-backend-hnsw/third_party/faiss/demos/offline_ivf/tests/testing_utils.py
vendored
Normal file
@@ -0,0 +1,181 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import yaml
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
OIVF_TEST_ARGS: List[str] = [
|
||||
"--config",
|
||||
"--xb",
|
||||
"--xq",
|
||||
"--command",
|
||||
"--cluster_run",
|
||||
"--no_residuals",
|
||||
]
|
||||
|
||||
|
||||
def get_test_parser(args) -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
for arg in args:
|
||||
parser.add_argument(arg)
|
||||
return parser
|
||||
|
||||
|
||||
class TestDataCreator:
|
||||
def __init__(
|
||||
self,
|
||||
tempdir: str,
|
||||
dimension: int,
|
||||
data_type: np.dtype,
|
||||
index_factory: Optional[List] = ["OPQ4,IVF256,PQ4"],
|
||||
training_sample: Optional[int] = 9984,
|
||||
index_shard_size: Optional[int] = 1000,
|
||||
query_batch_size: Optional[int] = 1000,
|
||||
evaluation_sample: Optional[int] = 100,
|
||||
num_files: Optional[int] = None,
|
||||
file_size: Optional[int] = None,
|
||||
file_sizes: Optional[List] = None,
|
||||
nprobe: Optional[int] = 64,
|
||||
k: Optional[int] = 10,
|
||||
metric: Optional[str] = "METRIC_L2",
|
||||
normalise: Optional[bool] = False,
|
||||
with_queries_ds: Optional[bool] = False,
|
||||
evaluate_by_margin: Optional[bool] = False,
|
||||
) -> None:
|
||||
self.tempdir = tempdir
|
||||
self.dimension = dimension
|
||||
self.data_type = np.dtype(data_type).name
|
||||
self.index_factory = {"prod": index_factory}
|
||||
if file_size and num_files:
|
||||
self.file_sizes = [file_size for _ in range(num_files)]
|
||||
elif file_sizes:
|
||||
self.file_sizes = file_sizes
|
||||
else:
|
||||
raise ValueError("no file sizes provided")
|
||||
self.num_files = len(self.file_sizes)
|
||||
self.training_sample = training_sample
|
||||
self.index_shard_size = index_shard_size
|
||||
self.query_batch_size = query_batch_size
|
||||
self.evaluation_sample = evaluation_sample
|
||||
self.nprobe = {"prod": [nprobe]}
|
||||
self.k = k
|
||||
self.metric = metric
|
||||
self.normalise = normalise
|
||||
self.config_file = self.tempdir + "/config_test.yaml"
|
||||
self.ds_name = "my_test_data"
|
||||
self.qs_name = "my_queries_data"
|
||||
self.evaluate_by_margin = evaluate_by_margin
|
||||
self.with_queries_ds = with_queries_ds
|
||||
|
||||
def create_test_data(self) -> None:
|
||||
datafiles = self._create_data_files()
|
||||
files_info = []
|
||||
|
||||
for i, file in enumerate(datafiles):
|
||||
files_info.append(
|
||||
{
|
||||
"dtype": self.data_type,
|
||||
"format": "npy",
|
||||
"name": file,
|
||||
"size": self.file_sizes[i],
|
||||
}
|
||||
)
|
||||
|
||||
config_for_yaml = {
|
||||
"d": self.dimension,
|
||||
"output": self.tempdir,
|
||||
"index": self.index_factory,
|
||||
"nprobe": self.nprobe,
|
||||
"k": self.k,
|
||||
"normalise": self.normalise,
|
||||
"metric": self.metric,
|
||||
"training_sample": self.training_sample,
|
||||
"evaluation_sample": self.evaluation_sample,
|
||||
"index_shard_size": self.index_shard_size,
|
||||
"query_batch_size": self.query_batch_size,
|
||||
"datasets": {
|
||||
self.ds_name: {
|
||||
"root": self.tempdir,
|
||||
"size": sum(self.file_sizes),
|
||||
"files": files_info,
|
||||
}
|
||||
},
|
||||
}
|
||||
if self.evaluate_by_margin:
|
||||
config_for_yaml["evaluate_by_margin"] = self.evaluate_by_margin
|
||||
q_datafiles = self._create_data_files("my_q_data")
|
||||
q_files_info = []
|
||||
|
||||
for i, file in enumerate(q_datafiles):
|
||||
q_files_info.append(
|
||||
{
|
||||
"dtype": self.data_type,
|
||||
"format": "npy",
|
||||
"name": file,
|
||||
"size": self.file_sizes[i],
|
||||
}
|
||||
)
|
||||
if self.with_queries_ds:
|
||||
config_for_yaml["datasets"][self.qs_name] = {
|
||||
"root": self.tempdir,
|
||||
"size": sum(self.file_sizes),
|
||||
"files": q_files_info,
|
||||
}
|
||||
|
||||
self._create_config_yaml(config_for_yaml)
|
||||
|
||||
def setup_cli(self, command="consistency_check") -> argparse.Namespace:
|
||||
parser = get_test_parser(OIVF_TEST_ARGS)
|
||||
|
||||
if self.with_queries_ds:
|
||||
return parser.parse_args(
|
||||
[
|
||||
"--xb",
|
||||
self.ds_name,
|
||||
"--config",
|
||||
self.config_file,
|
||||
"--command",
|
||||
command,
|
||||
"--xq",
|
||||
self.qs_name,
|
||||
]
|
||||
)
|
||||
return parser.parse_args(
|
||||
[
|
||||
"--xb",
|
||||
self.ds_name,
|
||||
"--config",
|
||||
self.config_file,
|
||||
"--command",
|
||||
command,
|
||||
]
|
||||
)
|
||||
|
||||
def _create_data_files(self, name_of_file="my_data") -> List[str]:
|
||||
"""
|
||||
Creates a dataset "my_test_data" with number of files (num_files), using padding in the files
|
||||
name. If self.with_queries_ds is True, it adds an extra dataset "my_queries_data" with the same number of files
|
||||
as the "my_test_data". The default name for embeddings files is "my_data" + <padding>.npy.
|
||||
"""
|
||||
filenames = []
|
||||
for i, file_size in enumerate(self.file_sizes):
|
||||
# np.random.seed(i)
|
||||
db_vectors = np.random.random((file_size, self.dimension)).astype(
|
||||
self.data_type
|
||||
)
|
||||
filename = name_of_file + f"{i:02}" + ".npy"
|
||||
filenames.append(filename)
|
||||
np.save(self.tempdir + "/" + filename, db_vectors)
|
||||
return filenames
|
||||
|
||||
def _create_config_yaml(self, dict_file: Dict[str, str]) -> None:
|
||||
"""
|
||||
Creates a yaml file in dir (can be a temporary dir for tests).
|
||||
"""
|
||||
filename = self.tempdir + "/config_test.yaml"
|
||||
with open(filename, "w") as file:
|
||||
yaml.dump(dict_file, file, default_flow_style=False)
|
||||
95
packages/leann-backend-hnsw/third_party/faiss/demos/offline_ivf/utils.py
vendored
Normal file
@@ -0,0 +1,95 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
from typing import Dict
|
||||
import yaml
|
||||
import faiss
|
||||
from faiss.contrib.datasets import SyntheticDataset
|
||||
|
||||
|
||||
def load_config(config):
|
||||
assert os.path.exists(config)
|
||||
with open(config, "r") as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
|
||||
def faiss_sanity_check():
|
||||
ds = SyntheticDataset(256, 0, 100, 100)
|
||||
xq = ds.get_queries()
|
||||
xb = ds.get_database()
|
||||
index_cpu = faiss.IndexFlat(ds.d)
|
||||
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)
|
||||
index_cpu.add(xb)
|
||||
index_gpu.add(xb)
|
||||
D_cpu, I_cpu = index_cpu.search(xq, 10)
|
||||
D_gpu, I_gpu = index_gpu.search(xq, 10)
|
||||
assert np.all(I_cpu == I_gpu), "faiss sanity check failed"
|
||||
assert np.all(np.isclose(D_cpu, D_gpu)), "faiss sanity check failed"
|
||||
|
||||
|
||||
def margin(sample, idx_a, idx_b, D_a_b, D_a, D_b, k, k_extract, threshold):
|
||||
"""
|
||||
two datasets: xa, xb; n = number of pairs
|
||||
idx_a - (np,) - query vector ids in xa
|
||||
idx_b - (np,) - query vector ids in xb
|
||||
D_a_b - (np,) - pairwise distances between xa[idx_a] and xb[idx_b]
|
||||
D_a - (np, k) - distances between vectors xa[idx_a] and corresponding nearest neighbours in xb
|
||||
D_b - (np, k) - distances between vectors xb[idx_b] and corresponding nearest neighbours in xa
|
||||
k - k nearest neighbours used for margin
|
||||
k_extract - number of nearest neighbours of each query in xb we consider for margin calculation and filtering
|
||||
threshold - margin threshold
|
||||
"""
|
||||
|
||||
n = sample
|
||||
nk = n * k_extract
|
||||
assert idx_a.shape == (n,)
|
||||
idx_a_k = idx_a.repeat(k_extract)
|
||||
assert idx_a_k.shape == (nk,)
|
||||
assert idx_b.shape == (nk,)
|
||||
assert D_a_b.shape == (nk,)
|
||||
assert D_a.shape == (n, k)
|
||||
assert D_b.shape == (nk, k)
|
||||
mean_a = np.mean(D_a, axis=1)
|
||||
assert mean_a.shape == (n,)
|
||||
mean_a_k = mean_a.repeat(k_extract)
|
||||
assert mean_a_k.shape == (nk,)
|
||||
mean_b = np.mean(D_b, axis=1)
|
||||
assert mean_b.shape == (nk,)
|
||||
margin = 2 * D_a_b / (mean_a_k + mean_b)
|
||||
above_threshold = margin > threshold
|
||||
print(np.count_nonzero(above_threshold))
|
||||
print(idx_a_k[above_threshold])
|
||||
print(idx_b[above_threshold])
|
||||
print(margin[above_threshold])
|
||||
return margin
|
||||
|
||||
|
||||
def add_group_args(group, *args, **kwargs):
|
||||
return group.add_argument(*args, **kwargs)
|
||||
|
||||
|
||||
def get_intersection_cardinality_frequencies(
|
||||
I: np.ndarray, I_gt: np.ndarray
|
||||
) -> Dict[int, int]:
|
||||
"""
|
||||
Computes the frequencies for the cardinalities of the intersection of neighbour indices.
|
||||
"""
|
||||
nq = I.shape[0]
|
||||
res = []
|
||||
for ell in range(nq):
|
||||
res.append(len(np.intersect1d(I[ell, :], I_gt[ell, :])))
|
||||
values, counts = np.unique(res, return_counts=True)
|
||||
return dict(zip(values, counts))
|
||||
|
||||
|
||||
def is_pretransform_index(index):
|
||||
if index.__class__ == faiss.IndexPreTransform:
|
||||
assert hasattr(index, "chain")
|
||||
return True
|
||||
else:
|
||||
assert not hasattr(index, "chain")
|
||||
return False
|
||||
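To make the margin formula above concrete, here is a small hedged toy computation (arrays are made up; k = k_extract = 2):

```python
# Toy check of margin(a, b) = 2 * d(a, b) / (mean_a + mean_b) as implemented above.
import numpy as np

from utils import margin

n, k = 2, 2
idx_a = np.array([0, 1])                  # query ids in xa
idx_b = np.array([10, 11, 12, 13])        # k_extract neighbour ids per query
D_a_b = np.array([1.0, 2.0, 1.5, 0.5])    # pairwise distances xa[idx_a] -> xb[idx_b]
D_a = np.array([[1.0, 3.0], [1.0, 1.0]])  # kNN distances on the xa side
D_b = np.array([[2.0, 2.0]] * 4)          # kNN distances on the xb side
m = margin(n, idx_a, idx_b, D_a_b, D_a, D_b, k, k, threshold=1.05)
# first pair: 2 * 1.0 / (mean([1, 3]) + mean([2, 2])) = 2 / 4 = 0.5
```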