1220 lines
42 KiB
Python
1220 lines
42 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from operator import itemgetter
|
|
from statistics import mean, median
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import faiss # @manual=//faiss/python:pyfaiss
|
|
|
|
import numpy as np
|
|
|
|
from scipy.optimize import curve_fit
|
|
|
|
from .benchmark_io import BenchmarkIO
|
|
|
|
from .descriptors import (
|
|
CodecDescriptor,
|
|
DatasetDescriptor,
|
|
IndexDescriptor,
|
|
IndexDescriptorClassic,
|
|
KnnDescriptor,
|
|
)
|
|
|
|
from .index import Index, IndexFromCodec, IndexFromFactory
|
|
|
|
from .utils import dict_merge
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def range_search_pr_curve(
|
|
dist_ann: np.ndarray, metric_score: np.ndarray, gt_rsm: float
|
|
):
|
|
assert dist_ann.shape == metric_score.shape
|
|
assert dist_ann.ndim == 1
|
|
l = len(dist_ann)
|
|
if l == 0:
|
|
return {
|
|
"dist_ann": [],
|
|
"metric_score_sample": [],
|
|
"cum_score": [],
|
|
"precision": [],
|
|
"recall": [],
|
|
"unique_key": [],
|
|
}
|
|
sort_by_dist_ann = dist_ann.argsort()
|
|
dist_ann = dist_ann[sort_by_dist_ann]
|
|
metric_score = metric_score[sort_by_dist_ann]
|
|
cum_score = np.cumsum(metric_score)
|
|
precision = cum_score / np.arange(1, len(cum_score) + 1)
|
|
recall = cum_score / gt_rsm
|
|
unique_key = np.round(precision * 100) * 100 + np.round(recall * 100)
|
|
tbl = np.vstack(
|
|
[dist_ann, metric_score, cum_score, precision, recall, unique_key]
|
|
)
|
|
group_by_dist_max_cum_score = np.empty(l, bool)
|
|
group_by_dist_max_cum_score[-1] = True
|
|
group_by_dist_max_cum_score[:-1] = dist_ann[1:] != dist_ann[:-1]
|
|
tbl = tbl[:, group_by_dist_max_cum_score]
|
|
_, unique_key_idx = np.unique(tbl[5], return_index=True)
|
|
dist_ann, metric_score, cum_score, precision, recall, unique_key = tbl[
|
|
:, np.sort(unique_key_idx)
|
|
].tolist()
|
|
return {
|
|
"dist_ann": dist_ann,
|
|
"metric_score_sample": metric_score,
|
|
"cum_score": cum_score,
|
|
"precision": precision,
|
|
"recall": recall,
|
|
"unique_key": unique_key,
|
|
}
|
|
|
|
|
|
def optimizer(op, search, cost_metric, perf_metric):
|
|
totex = op.num_experiments()
|
|
rs = np.random.RandomState(123)
|
|
if totex > 1:
|
|
experiments = rs.permutation(totex - 2) + 1
|
|
experiments = [0, totex - 1] + list(experiments)
|
|
else:
|
|
experiments = [0]
|
|
|
|
print(f"total nb experiments {totex}, running {len(experiments)}")
|
|
|
|
for cno in experiments:
|
|
key = op.cno_to_key(cno)
|
|
parameters = op.get_parameters(key)
|
|
|
|
(max_perf, min_cost) = op.predict_bounds(key)
|
|
if not op.is_pareto_optimal(max_perf, min_cost):
|
|
logger.info(
|
|
f"{cno=:4d} {str(parameters):50}: SKIP, {max_perf=:.3f} {min_cost=:.3f}",
|
|
)
|
|
continue
|
|
|
|
logger.info(f"{cno=:4d} {str(parameters):50}: RUN")
|
|
cost, perf, requires = search(
|
|
parameters,
|
|
cost_metric,
|
|
perf_metric,
|
|
)
|
|
if requires is not None:
|
|
return requires
|
|
logger.info(
|
|
f"{cno=:4d} {str(parameters):50}: DONE, {cost=:.3f} {perf=:.3f}"
|
|
)
|
|
op.add_operating_point(key, perf, cost)
|
|
return None
|
|
|
|
|
|
# range_metric possible values:
|
|
#
|
|
# radius
|
|
# [0..radius) -> 1
|
|
# [radius..inf) -> 0
|
|
#
|
|
# [[radius1, score1], ...]
|
|
# [0..radius1) -> score1
|
|
# [radius1..radius2) -> score2
|
|
#
|
|
# [[radius1_from, radius1_to, score1], ...]
|
|
# [radius1_from, radius1_to) -> score1,
|
|
# [radius2_from, radius2_to) -> score2
|
|
def get_range_search_metric_function(range_metric, D, R):
|
|
if D is not None:
|
|
assert R is not None
|
|
assert D.shape == R.shape
|
|
if isinstance(range_metric, list):
|
|
aradius, ascore, aradius_from, aradius_to = [], [], [], []
|
|
radius_to = 0
|
|
for rsd in range_metric:
|
|
assert isinstance(rsd, list)
|
|
if len(rsd) == 3:
|
|
radius_from, radius_to, score = rsd
|
|
elif len(rsd) == 2:
|
|
radius_from = radius_to
|
|
radius_to, score = rsd
|
|
else:
|
|
raise AssertionError(f"invalid range definition {rsd}")
|
|
# radius_from and radius_to are compressed distances,
|
|
# we need to convert them to real embedding distances.
|
|
if D is not None:
|
|
sample_idxs = np.argwhere((D <= radius_to) & (D > radius_from))
|
|
assert len(sample_idxs) > 0
|
|
real_radius = np.mean(R[sample_idxs]).item()
|
|
else:
|
|
real_radius = mean([radius_from, radius_to])
|
|
logger.info(
|
|
f"range_search_metric_function {radius_from=} {radius_to=} {real_radius=} {score=}"
|
|
)
|
|
aradius.append(real_radius)
|
|
ascore.append(score)
|
|
aradius_from.append(radius_from)
|
|
aradius_to.append(radius_to)
|
|
|
|
def sigmoid(x, a, b, c):
|
|
return a / (1 + np.exp(b * x - c))
|
|
|
|
cutoff = max(aradius)
|
|
popt, _ = curve_fit(sigmoid, aradius, ascore, [1, 5, 5])
|
|
|
|
for r in np.arange(0, cutoff + 0.05, 0.05):
|
|
logger.info(
|
|
f"range_search_metric_function {r=} {sigmoid(r, *popt)=}"
|
|
)
|
|
|
|
assert isinstance(cutoff, float)
|
|
return (
|
|
cutoff,
|
|
lambda x: np.where(x < cutoff, sigmoid(x, *popt), 0),
|
|
popt.tolist(),
|
|
list(zip(aradius, ascore, aradius_from, aradius_to, strict=True)),
|
|
)
|
|
else:
|
|
# Assuming that the range_metric is a float,
|
|
# so the range is [0..range_metric).
|
|
# D is the result of a range_search with a radius of range_metric,
|
|
# but both range_metric and D may be compressed distances.
|
|
# We approximate the real embedding distance as max(R).
|
|
if R is not None:
|
|
real_range = np.max(R).item()
|
|
else:
|
|
real_range = range_metric
|
|
logger.info(
|
|
f"range_search_metric_function {range_metric=} {real_range=}"
|
|
)
|
|
assert isinstance(real_range, float)
|
|
return real_range * 2, lambda x: np.where(x < real_range, 1, 0), [], []
|
|
|
|
|
|
@dataclass
|
|
class IndexOperator:
|
|
num_threads: int
|
|
distance_metric: str
|
|
|
|
def __post_init__(self):
|
|
if self.distance_metric == "IP":
|
|
self.distance_metric_type = faiss.METRIC_INNER_PRODUCT
|
|
elif self.distance_metric == "L2":
|
|
self.distance_metric_type = faiss.METRIC_L2
|
|
else:
|
|
raise ValueError
|
|
|
|
def set_io(self, benchmark_io: BenchmarkIO):
|
|
self.io = benchmark_io
|
|
self.io.distance_metric = self.distance_metric
|
|
self.io.distance_metric_type = self.distance_metric_type
|
|
|
|
|
|
@dataclass
|
|
class TrainOperator(IndexOperator):
|
|
codec_descs: List[CodecDescriptor] = field(default_factory=lambda: [])
|
|
assemble_opaque: bool = True
|
|
|
|
def get_desc(self, name: str) -> Optional[CodecDescriptor]:
|
|
for desc in self.codec_descs:
|
|
if desc.get_name() == name:
|
|
return desc
|
|
elif desc.factory == name:
|
|
return desc
|
|
return None
|
|
|
|
def get_flat_desc(self, name=None) -> Optional[CodecDescriptor]:
|
|
for desc in self.codec_descs:
|
|
desc_name = desc.get_name()
|
|
if desc_name == name:
|
|
return desc
|
|
if desc_name.startswith("Flat"):
|
|
return desc
|
|
return None
|
|
|
|
def build_index_wrapper(self, codec_desc: CodecDescriptor):
|
|
if hasattr(codec_desc, "index"):
|
|
return
|
|
|
|
if codec_desc.factory is not None:
|
|
assert (
|
|
codec_desc.factory == "Flat" or codec_desc.training_vectors is not None
|
|
)
|
|
index = IndexFromFactory(
|
|
num_threads=self.num_threads,
|
|
d=codec_desc.d,
|
|
metric=self.distance_metric,
|
|
construction_params=codec_desc.construction_params,
|
|
factory=codec_desc.factory,
|
|
training_vectors=codec_desc.training_vectors,
|
|
codec_name=codec_desc.get_name(),
|
|
assemble_opaque=self.assemble_opaque,
|
|
)
|
|
index.set_io(self.io)
|
|
codec_desc.index = index
|
|
else:
|
|
assert codec_desc.is_trained()
|
|
|
|
def train_one(
|
|
self, codec_desc: CodecDescriptor, results: Dict[str, Any], dry_run=False
|
|
):
|
|
faiss.omp_set_num_threads(codec_desc.num_threads)
|
|
self.build_index_wrapper(codec_desc)
|
|
if codec_desc.is_trained():
|
|
return results, None
|
|
|
|
if dry_run:
|
|
meta, requires = codec_desc.index.fetch_meta(dry_run=dry_run)
|
|
else:
|
|
codec_desc.index.get_codec()
|
|
meta, requires = codec_desc.index.fetch_meta(dry_run=dry_run)
|
|
assert requires is None
|
|
|
|
if requires is None:
|
|
results["indices"][codec_desc.get_name()] = meta
|
|
return results, requires
|
|
|
|
def train(self, results, dry_run=False):
|
|
for desc in self.codec_descs:
|
|
results, requires = self.train_one(desc, results, dry_run=dry_run)
|
|
if dry_run:
|
|
if requires is None:
|
|
continue
|
|
return results, requires
|
|
assert requires is None
|
|
return results, None
|
|
|
|
|
|
@dataclass
|
|
class BuildOperator(IndexOperator):
|
|
index_descs: List[IndexDescriptor] = field(default_factory=lambda: [])
|
|
serialize_index: bool = False
|
|
|
|
def get_desc(self, name: str) -> Optional[IndexDescriptor]:
|
|
for desc in self.index_descs:
|
|
if desc.get_name() == name:
|
|
return desc
|
|
return None
|
|
|
|
def get_flat_desc(self, name=None) -> Optional[IndexDescriptor]:
|
|
for desc in self.index_descs:
|
|
desc_name = desc.get_name()
|
|
if desc_name == name:
|
|
return desc
|
|
if desc_name.startswith("Flat"):
|
|
return desc
|
|
return None
|
|
|
|
def build_index_wrapper(self, index_desc: IndexDescriptor):
|
|
if hasattr(index_desc, "index"):
|
|
return
|
|
|
|
if hasattr(index_desc.codec_desc, "index"):
|
|
index_desc.index = index_desc.codec_desc.index
|
|
index_desc.index.database_vectors = index_desc.database_desc
|
|
index_desc.index.index_name = index_desc.get_name()
|
|
return
|
|
|
|
if index_desc.codec_desc is not None:
|
|
index = IndexFromCodec(
|
|
num_threads=self.num_threads,
|
|
d=index_desc.d,
|
|
metric=self.distance_metric,
|
|
database_vectors=index_desc.database_desc,
|
|
bucket=index_desc.codec_desc.bucket,
|
|
path=index_desc.codec_desc.path,
|
|
index_name=index_desc.get_name(),
|
|
codec_name=index_desc.codec_desc.get_name(),
|
|
serialize_full_index=self.serialize_index,
|
|
)
|
|
index.set_io(self.io)
|
|
index_desc.index = index
|
|
else:
|
|
assert index_desc.is_built()
|
|
|
|
def build_one(self, index_desc: IndexDescriptor, results: Dict[str, Any]):
|
|
faiss.omp_set_num_threads(index_desc.num_threads)
|
|
self.build_index_wrapper(index_desc)
|
|
if index_desc.is_built():
|
|
return
|
|
index_desc.index.get_index()
|
|
|
|
def build(self, results: Dict[str, Any]):
|
|
# TODO: add support for dry_run
|
|
for index_desc in self.index_descs:
|
|
self.build_one(index_desc, results)
|
|
return results, None
|
|
|
|
|
|
@dataclass
|
|
class SearchOperator(IndexOperator):
|
|
knn_descs: List[KnnDescriptor] = field(default_factory=lambda: [])
|
|
range: bool = False
|
|
compute_gt: bool = True
|
|
|
|
def get_desc(self, name: str) -> Optional[KnnDescriptor]:
|
|
for desc in self.knn_descs:
|
|
if desc.get_name() == name:
|
|
return desc
|
|
return None
|
|
|
|
def get_flat_desc(self, name=None) -> Optional[KnnDescriptor]:
|
|
for desc in self.knn_descs:
|
|
if desc.get_name().startswith("Flat"):
|
|
return desc
|
|
return None
|
|
|
|
def build_index_wrapper(self, knn_desc: KnnDescriptor):
|
|
if hasattr(knn_desc, "index"):
|
|
return
|
|
|
|
assert knn_desc.index_desc is not None
|
|
if hasattr(knn_desc.index_desc, "index"):
|
|
knn_desc.index = knn_desc.index_desc.index
|
|
knn_desc.index.knn_name = knn_desc.get_name()
|
|
knn_desc.index.search_params = knn_desc.search_params
|
|
else:
|
|
index = Index(
|
|
num_threads=self.num_threads,
|
|
d=knn_desc.d,
|
|
metric=self.distance_metric,
|
|
bucket=knn_desc.index_desc.bucket,
|
|
index_path=knn_desc.index_desc.path,
|
|
index_name=knn_desc.index_desc.get_name(),
|
|
# knn_name=knn_desc.get_name(),
|
|
search_params=knn_desc.search_params,
|
|
)
|
|
index.set_io(self.io)
|
|
knn_desc.index = index
|
|
|
|
knn_desc.index.get_index()
|
|
|
|
def range_search_reference(self, index, parameters, range_metric, query_dataset):
|
|
logger.info("range_search_reference: begin")
|
|
if isinstance(range_metric, list):
|
|
assert len(range_metric) > 0
|
|
m_radius = (
|
|
max(rm[-2] for rm in range_metric)
|
|
if self.distance_metric_type == faiss.METRIC_L2
|
|
else min(rm[-2] for rm in range_metric)
|
|
)
|
|
else:
|
|
m_radius = range_metric
|
|
|
|
lims, D, I, R, P, _ = self.range_search(
|
|
False,
|
|
index,
|
|
parameters,
|
|
radius=m_radius,
|
|
query_dataset=query_dataset,
|
|
)
|
|
flat = index.is_flat_index()
|
|
(
|
|
gt_radius,
|
|
range_search_metric_function,
|
|
coefficients,
|
|
coefficients_training_data,
|
|
) = get_range_search_metric_function(
|
|
range_metric,
|
|
D if not flat else None,
|
|
R if not flat else None,
|
|
)
|
|
logger.info("range_search_reference: end")
|
|
return (
|
|
gt_radius,
|
|
range_search_metric_function,
|
|
coefficients,
|
|
coefficients_training_data,
|
|
)
|
|
|
|
def estimate_range(self, index, parameters, range_scoring_radius, query_dataset):
|
|
D, I, R, P, _ = index.knn_search(
|
|
False,
|
|
parameters,
|
|
query_dataset,
|
|
self.k,
|
|
)
|
|
samples = []
|
|
for i, j in np.argwhere(R < range_scoring_radius):
|
|
samples.append((R[i, j].item(), D[i, j].item()))
|
|
if len(samples) > 0: # estimate range
|
|
samples.sort(key=itemgetter(0))
|
|
return median(r for _, r in samples[-3:])
|
|
else: # ensure at least one result
|
|
i, j = np.argwhere(R.min() == R)[0]
|
|
return D[i, j].item()
|
|
|
|
def range_search(
|
|
self,
|
|
dry_run,
|
|
index: Index,
|
|
search_parameters: Optional[Dict[str, int]],
|
|
query_dataset: DatasetDescriptor,
|
|
radius: Optional[float] = None,
|
|
gt_radius: Optional[float] = None,
|
|
range_search_metric_function=None,
|
|
gt_rsm=None,
|
|
):
|
|
logger.info("range_search: begin")
|
|
if radius is None:
|
|
assert gt_radius is not None
|
|
radius = (
|
|
gt_radius
|
|
if index.is_flat()
|
|
else self.estimate_range(
|
|
index, search_parameters, gt_radius, query_dataset
|
|
)
|
|
)
|
|
logger.info(f"Radius={radius}")
|
|
lims, D, I, R, P, requires = index.range_search(
|
|
dry_run=dry_run,
|
|
search_parameters=search_parameters,
|
|
query_vectors=query_dataset,
|
|
radius=radius,
|
|
)
|
|
if requires is not None:
|
|
return None, None, None, None, None, requires
|
|
if range_search_metric_function is not None:
|
|
range_search_metric = range_search_metric_function(R)
|
|
range_search_pr = range_search_pr_curve(D, range_search_metric, gt_rsm)
|
|
range_score_sum = np.sum(range_search_metric).item()
|
|
P |= {
|
|
"range_score_sum": range_score_sum,
|
|
"range_score_max_recall": range_score_sum / gt_rsm,
|
|
"range_search_pr": range_search_pr,
|
|
}
|
|
return lims, D, I, R, P, requires
|
|
|
|
def range_ground_truth(
|
|
self, gt_radius, range_search_metric_function, flat_desc=None
|
|
):
|
|
logger.info("range_ground_truth: begin")
|
|
if flat_desc is None:
|
|
flat_desc = self.get_flat_desc()
|
|
lims, D, I, R, P, _ = self.range_search(
|
|
False,
|
|
flat_desc.index,
|
|
search_parameters=None,
|
|
radius=gt_radius,
|
|
query_dataset=flat_desc.query_dataset,
|
|
)
|
|
gt_rsm = np.sum(range_search_metric_function(R)).item()
|
|
logger.info("range_ground_truth: end")
|
|
return gt_rsm
|
|
|
|
def knn_ground_truth(self, flat_desc=None):
|
|
logger.info("knn_ground_truth: begin")
|
|
if flat_desc is None:
|
|
flat_desc = self.get_flat_desc()
|
|
self.build_index_wrapper(flat_desc)
|
|
# TODO(kuarora): Consider moving gt results(gt_knn_D, gt_knn_I) to the index as there can be multiple ground truths.
|
|
(
|
|
self.gt_knn_D,
|
|
self.gt_knn_I,
|
|
_,
|
|
_,
|
|
requires,
|
|
) = flat_desc.index.knn_search(
|
|
dry_run=False,
|
|
search_parameters=None,
|
|
query_vectors=flat_desc.query_dataset,
|
|
k=flat_desc.k,
|
|
)
|
|
assert requires is None
|
|
logger.info("knn_ground_truth: end")
|
|
|
|
def search_benchmark(
|
|
self,
|
|
name,
|
|
search_func,
|
|
key_func,
|
|
cost_metrics,
|
|
perf_metrics,
|
|
results: Dict[str, Any],
|
|
index: Index,
|
|
):
|
|
index_name = index.get_index_name()
|
|
logger.info(f"{name}_benchmark: begin {index_name}")
|
|
|
|
def experiment(parameters, cost_metric, perf_metric):
|
|
nonlocal results
|
|
key = key_func(parameters)
|
|
if key in results["experiments"]:
|
|
metrics = results["experiments"][key]
|
|
else:
|
|
metrics, requires = search_func(parameters)
|
|
if requires is not None:
|
|
return None, None, requires
|
|
results["experiments"][key] = metrics
|
|
return metrics[cost_metric], metrics[perf_metric], None
|
|
|
|
requires = None
|
|
for cost_metric in cost_metrics:
|
|
for perf_metric in perf_metrics:
|
|
op = index.get_operating_points()
|
|
requires = optimizer(
|
|
op,
|
|
experiment,
|
|
cost_metric,
|
|
perf_metric,
|
|
)
|
|
if requires is not None:
|
|
break
|
|
logger.info(f"{name}_benchmark: end")
|
|
return results, requires
|
|
|
|
def knn_search_benchmark(
|
|
self, dry_run, results: Dict[str, Any], knn_desc: KnnDescriptor
|
|
):
|
|
gt_knn_D = None
|
|
gt_knn_I = None
|
|
if hasattr(self, "gt_knn_D"):
|
|
gt_knn_D = self.gt_knn_D
|
|
gt_knn_I = self.gt_knn_I
|
|
|
|
assert hasattr(knn_desc, "index")
|
|
if not knn_desc.index.is_flat_index() and gt_knn_I is None:
|
|
key = knn_desc.index.get_knn_search_name(
|
|
search_parameters=knn_desc.search_params,
|
|
query_vectors=knn_desc.query_dataset,
|
|
k=knn_desc.k,
|
|
reconstruct=False,
|
|
)
|
|
metrics, requires = knn_desc.index.knn_search(
|
|
dry_run,
|
|
knn_desc.search_params,
|
|
knn_desc.query_dataset,
|
|
knn_desc.k,
|
|
)[3:]
|
|
if requires is not None:
|
|
return results, requires
|
|
results["experiments"][key] = metrics
|
|
return results, requires
|
|
|
|
return self.search_benchmark(
|
|
name="knn_search",
|
|
search_func=lambda parameters: knn_desc.index.knn_search(
|
|
dry_run,
|
|
parameters,
|
|
knn_desc.query_dataset,
|
|
knn_desc.k,
|
|
gt_knn_I,
|
|
gt_knn_D,
|
|
)[3:],
|
|
key_func=lambda parameters: knn_desc.index.get_knn_search_name(
|
|
search_parameters=parameters,
|
|
query_vectors=knn_desc.query_dataset,
|
|
k=knn_desc.k,
|
|
reconstruct=False,
|
|
),
|
|
cost_metrics=["time"],
|
|
perf_metrics=["knn_intersection", "distance_ratio"],
|
|
results=results,
|
|
index=knn_desc.index,
|
|
)
|
|
|
|
def reconstruct_benchmark(
|
|
self, dry_run, results: Dict[str, Any], knn_desc: KnnDescriptor
|
|
):
|
|
return self.search_benchmark(
|
|
name="reconstruct",
|
|
search_func=lambda parameters: knn_desc.index.reconstruct(
|
|
dry_run,
|
|
parameters,
|
|
knn_desc.query_dataset,
|
|
knn_desc.k,
|
|
self.gt_knn_I,
|
|
),
|
|
key_func=lambda parameters: knn_desc.index.get_knn_search_name(
|
|
search_parameters=parameters,
|
|
query_vectors=knn_desc.query_dataset,
|
|
k=knn_desc.k,
|
|
reconstruct=True,
|
|
),
|
|
cost_metrics=["encode_time"],
|
|
perf_metrics=["sym_recall"],
|
|
results=results,
|
|
index=knn_desc.index,
|
|
)
|
|
|
|
def range_search_benchmark(
|
|
self,
|
|
dry_run,
|
|
results: Dict[str, Any],
|
|
index: Index,
|
|
metric_key: str,
|
|
radius: float,
|
|
gt_radius: float,
|
|
range_search_metric_function,
|
|
gt_rsm: float,
|
|
query_dataset: DatasetDescriptor,
|
|
):
|
|
return self.search_benchmark(
|
|
name="range_search",
|
|
search_func=lambda parameters: self.range_search(
|
|
dry_run=dry_run,
|
|
index=index,
|
|
search_parameters=parameters,
|
|
radius=radius,
|
|
gt_radius=gt_radius,
|
|
range_search_metric_function=range_search_metric_function,
|
|
gt_rsm=gt_rsm,
|
|
query_dataset=query_dataset,
|
|
)[4:],
|
|
key_func=lambda parameters: index.get_range_search_name(
|
|
search_parameters=parameters,
|
|
query_vectors=query_dataset,
|
|
radius=radius,
|
|
)
|
|
+ metric_key,
|
|
cost_metrics=["time"],
|
|
perf_metrics=["range_score_max_recall"],
|
|
results=results,
|
|
index=index,
|
|
)
|
|
|
|
def search_one(
|
|
self,
|
|
knn_desc: KnnDescriptor,
|
|
results: Dict[str, Any],
|
|
dry_run=False,
|
|
range=False,
|
|
):
|
|
faiss.omp_set_num_threads(knn_desc.num_threads)
|
|
|
|
self.build_index_wrapper(knn_desc)
|
|
# results, requires = self.reconstruct_benchmark(
|
|
# dry_run=True,
|
|
# results=results,
|
|
# index=index_desc.index,
|
|
# )
|
|
# if reconstruct and requires is not None:
|
|
# if dry_run:
|
|
# return results, requires
|
|
# else:
|
|
# results, requires = self.reconstruct_benchmark(
|
|
# dry_run=False,
|
|
# results=results,
|
|
# index=index_desc.index,
|
|
# )
|
|
# assert requires is None
|
|
results, requires = self.knn_search_benchmark(
|
|
dry_run=True,
|
|
results=results,
|
|
knn_desc=knn_desc,
|
|
)
|
|
if requires is not None:
|
|
if dry_run:
|
|
return results, requires
|
|
else:
|
|
results, requires = self.knn_search_benchmark(
|
|
dry_run=False,
|
|
results=results,
|
|
knn_desc=knn_desc,
|
|
)
|
|
assert requires is None
|
|
|
|
if (
|
|
knn_desc.range_ref_index_desc is None or
|
|
not knn_desc.index.supports_range_search()
|
|
):
|
|
return results, None
|
|
|
|
ref_index_desc = self.get_desc(knn_desc.range_ref_index_desc)
|
|
if ref_index_desc is None:
|
|
raise ValueError(
|
|
f"{knn_desc.get_name()}: Unknown range index {knn_desc.range_ref_index_desc}"
|
|
)
|
|
if ref_index_desc.range_metrics is None:
|
|
raise ValueError(
|
|
f"Range index {ref_index_desc.factory} has no radius_score"
|
|
)
|
|
for metric_key, range_metric in ref_index_desc.range_metrics.items():
|
|
(
|
|
gt_radius,
|
|
range_search_metric_function,
|
|
coefficients,
|
|
coefficients_training_data,
|
|
) = self.range_search_reference(
|
|
ref_index_desc.index,
|
|
ref_index_desc.search_params,
|
|
range_metric,
|
|
query_dataset=knn_desc.query_dataset,
|
|
)
|
|
gt_rsm = None
|
|
if self.compute_gt:
|
|
gt_rsm = self.range_ground_truth(
|
|
gt_radius, range_search_metric_function
|
|
)
|
|
results, requires = self.range_search_benchmark(
|
|
dry_run=True,
|
|
results=results,
|
|
index=knn_desc.index,
|
|
metric_key=metric_key,
|
|
radius=knn_desc.radius,
|
|
gt_radius=gt_radius,
|
|
range_search_metric_function=range_search_metric_function,
|
|
gt_rsm=gt_rsm,
|
|
query_dataset=knn_desc.query_dataset,
|
|
)
|
|
if range and requires is not None:
|
|
if dry_run:
|
|
return results, requires
|
|
else:
|
|
results, requires = self.range_search_benchmark(
|
|
dry_run=False,
|
|
results=results,
|
|
index=knn_desc.index,
|
|
metric_key=metric_key,
|
|
radius=knn_desc.radius,
|
|
gt_radius=gt_radius,
|
|
range_search_metric_function=range_search_metric_function,
|
|
gt_rsm=gt_rsm,
|
|
query_dataset=knn_desc.query_dataset,
|
|
)
|
|
assert requires is None
|
|
|
|
return results, None
|
|
|
|
def search(
|
|
self,
|
|
results: Dict[str, Any],
|
|
dry_run: bool = False,):
|
|
for knn_desc in self.knn_descs:
|
|
results, requires = self.search_one(
|
|
knn_desc=knn_desc,
|
|
results=results,
|
|
dry_run=dry_run,
|
|
range=self.range)
|
|
if dry_run:
|
|
if requires is None:
|
|
continue
|
|
return results, requires
|
|
|
|
assert requires is None
|
|
return results, None
|
|
|
|
|
|
@dataclass
|
|
class ExecutionOperator:
|
|
distance_metric: str = "L2"
|
|
num_threads: int = 1
|
|
train_op: Optional[TrainOperator] = None
|
|
build_op: Optional[BuildOperator] = None
|
|
search_op: Optional[SearchOperator] = None
|
|
compute_gt: bool = True
|
|
|
|
def __post_init__(self):
|
|
if self.distance_metric == "IP":
|
|
self.distance_metric_type = faiss.METRIC_INNER_PRODUCT
|
|
elif self.distance_metric == "L2":
|
|
self.distance_metric_type = faiss.METRIC_L2
|
|
else:
|
|
raise ValueError
|
|
|
|
if self.search_op is not None:
|
|
self.search_op.compute_gt = self.compute_gt
|
|
|
|
def set_io(self, io: BenchmarkIO):
|
|
self.io = io
|
|
self.io.distance_metric = self.distance_metric
|
|
self.io.distance_metric_type = self.distance_metric_type
|
|
if self.train_op:
|
|
self.train_op.set_io(io)
|
|
if self.build_op:
|
|
self.build_op.set_io(io)
|
|
if self.search_op:
|
|
self.search_op.set_io(io)
|
|
|
|
def create_gt_codec(
|
|
self, codec_desc, results, train=True
|
|
) -> Optional[CodecDescriptor]:
|
|
gt_codec_desc = None
|
|
if self.train_op:
|
|
gt_codec_desc = self.train_op.get_flat_desc(codec_desc.flat_name())
|
|
if gt_codec_desc is None:
|
|
gt_codec_desc = CodecDescriptor(
|
|
factory="Flat",
|
|
d=codec_desc.d,
|
|
metric=codec_desc.metric,
|
|
num_threads=self.num_threads,
|
|
)
|
|
self.train_op.codec_descs.insert(0, gt_codec_desc)
|
|
if train:
|
|
self.train_op.train_one(gt_codec_desc, results, dry_run=False)
|
|
|
|
return gt_codec_desc
|
|
|
|
def create_gt_index(
|
|
self, index_desc: IndexDescriptor, results: Dict[str, Any], build=True
|
|
) -> Optional[IndexDescriptor]:
|
|
gt_index_desc = None
|
|
if self.build_op:
|
|
gt_index_desc = self.build_op.get_flat_desc(index_desc.flat_name())
|
|
if gt_index_desc is None:
|
|
gt_codec_desc = self.train_op.get_flat_desc(
|
|
index_desc.codec_desc.flat_name()
|
|
)
|
|
assert gt_codec_desc is not None
|
|
gt_index_desc = IndexDescriptor(
|
|
d=index_desc.d,
|
|
metric=index_desc.metric,
|
|
num_threads=self.num_threads,
|
|
codec_desc=gt_codec_desc,
|
|
database_desc=index_desc.database_desc,
|
|
)
|
|
self.build_op.index_descs.insert(0, gt_index_desc)
|
|
if build:
|
|
self.build_op.build_one(gt_index_desc, results)
|
|
|
|
return gt_index_desc
|
|
|
|
def create_gt_knn(self, knn_desc, search=True) -> Optional[KnnDescriptor]:
|
|
gt_knn_desc = None
|
|
if self.search_op:
|
|
gt_knn_desc = self.search_op.get_flat_desc(knn_desc.flat_name())
|
|
if gt_knn_desc is None:
|
|
if knn_desc.gt_index_desc is not None:
|
|
gt_index_desc = knn_desc.gt_index_desc
|
|
else:
|
|
gt_index_desc = self.build_op.get_flat_desc(
|
|
knn_desc.index_desc.flat_name()
|
|
)
|
|
knn_desc.gt_index_desc = gt_index_desc
|
|
|
|
assert gt_index_desc is not None
|
|
gt_knn_desc = KnnDescriptor(
|
|
d=knn_desc.d,
|
|
metric=knn_desc.metric,
|
|
num_threads=self.num_threads,
|
|
index_desc=gt_index_desc,
|
|
query_dataset=knn_desc.query_dataset,
|
|
k=knn_desc.k,
|
|
)
|
|
self.search_op.knn_descs.insert(0, gt_knn_desc)
|
|
if search:
|
|
self.search_op.build_index_wrapper(gt_knn_desc)
|
|
self.search_op.knn_ground_truth(gt_knn_desc)
|
|
|
|
return gt_knn_desc
|
|
|
|
def create_range_ref_knn(self, knn_desc):
|
|
if (
|
|
knn_desc.range_ref_index_desc is None or
|
|
not knn_desc.index.supports_range_search()
|
|
):
|
|
return
|
|
|
|
if knn_desc.range_ref_index_desc is not None:
|
|
ref_index_desc = (
|
|
self.search_op.get_desc(knn_desc.range_ref_index_desc)
|
|
)
|
|
if ref_index_desc is None:
|
|
raise ValueError(f"Unknown range index {knn_desc.range_ref_index_desc}")
|
|
if ref_index_desc.range_metrics is None:
|
|
raise ValueError(
|
|
f"Range index {knn_desc.get_name()} has no radius_score"
|
|
)
|
|
results["metrics"] = {}
|
|
self.build_index_wrapper(ref_index_desc)
|
|
for metric_key, range_metric in ref_index_desc.range_metrics.items():
|
|
(
|
|
knn_desc.gt_radius,
|
|
range_search_metric_function,
|
|
coefficients,
|
|
coefficients_training_data,
|
|
) = self.search_op.range_search_reference(
|
|
knn_desc.index, knn_desc.search_params, range_metric
|
|
)
|
|
results["metrics"][metric_key] = {
|
|
"coefficients": coefficients,
|
|
"training_data": coefficients_training_data,
|
|
}
|
|
knn_desc.gt_rsm = self.search_op.range_ground_truth(
|
|
knn_desc.gt_radius, range_search_metric_function
|
|
)
|
|
|
|
def create_ground_truths(self, results: Dict[str, Any]):
|
|
# TODO: Create all ground truth descriptors and
|
|
# put them in index descriptor as reference
|
|
if self.train_op is not None:
|
|
for codec_desc in self.train_op.codec_descs:
|
|
self.create_gt_codec(codec_desc, results)
|
|
|
|
if self.build_op is not None:
|
|
for index_desc in self.build_op.index_descs:
|
|
self.create_gt_index(
|
|
index_desc, results
|
|
) # may need to pass results in future
|
|
|
|
if self.search_op is not None:
|
|
for knn_desc in self.search_op.knn_descs:
|
|
self.create_gt_knn(knn_desc, results)
|
|
self.create_range_ref_knn(knn_desc)
|
|
|
|
def prepare_gt_or_range_knn(self, results: Dict[str, Any]):
|
|
if self.search_op is not None:
|
|
for knn_desc in self.search_op.knn_descs:
|
|
self.create_gt_knn(knn_desc, results)
|
|
self.create_range_ref_knn(knn_desc)
|
|
|
|
def execute(self, results: Dict[str, Any], dry_run: bool = False):
|
|
faiss.omp_set_num_threads(self.num_threads)
|
|
if self.train_op is not None:
|
|
results, requires = (
|
|
self.train_op.train(results=results, dry_run=dry_run)
|
|
)
|
|
if dry_run and requires:
|
|
return results, requires
|
|
|
|
if self.build_op is not None:
|
|
self.build_op.build(results)
|
|
|
|
if self.search_op is not None:
|
|
if not dry_run and self.compute_gt:
|
|
self.prepare_gt_or_range_knn(results)
|
|
|
|
results, requires = (
|
|
self.search_op.search(results=results, dry_run=dry_run)
|
|
)
|
|
if dry_run and requires:
|
|
return results, requires
|
|
return results, None
|
|
|
|
def execute_2(self, result_file=None):
|
|
results = {"indices": {}, "experiments": {}}
|
|
results, requires = self.execute(results=results)
|
|
assert requires is None
|
|
if result_file is not None:
|
|
self.io.write_json(results, result_file, overwrite=True)
|
|
|
|
def add_index_descs(self, codec_desc, index_desc, knn_desc):
|
|
if codec_desc is not None:
|
|
self.train_op.codec_descs.append(codec_desc)
|
|
if index_desc is not None:
|
|
self.build_op.index_descs.append(index_desc)
|
|
if knn_desc is not None:
|
|
self.search_op.knn_descs.append(knn_desc)
|
|
|
|
|
|
@dataclass
|
|
class Benchmark:
|
|
num_threads: int
|
|
training_vectors: Optional[DatasetDescriptor] = None
|
|
database_vectors: Optional[DatasetDescriptor] = None
|
|
query_vectors: Optional[DatasetDescriptor] = None
|
|
index_descs: Optional[List[IndexDescriptorClassic]] = None
|
|
range_ref_index_desc: Optional[str] = None
|
|
k: int = 1
|
|
distance_metric: str = "L2"
|
|
|
|
def set_io(self, benchmark_io):
|
|
self.io = benchmark_io
|
|
|
|
def get_embedding_dimension(self):
|
|
if self.training_vectors is not None:
|
|
xt = self.io.get_dataset(self.training_vectors)
|
|
return xt.shape[1]
|
|
if self.database_vectors is not None:
|
|
xb = self.io.get_dataset(self.database_vectors)
|
|
return xb.shape[1]
|
|
if self.query_vectors is not None:
|
|
xq = self.io.get_dataset(self.query_vectors)
|
|
return xq.shape[1]
|
|
raise ValueError("Failed to determine dimension of dataset")
|
|
|
|
def create_descriptors(
|
|
self, ci_desc: IndexDescriptorClassic, train, build, knn, reconstruct, range
|
|
):
|
|
codec_desc = None
|
|
index_desc = None
|
|
knn_desc = None
|
|
dim = self.get_embedding_dimension()
|
|
if train and ci_desc.factory is not None:
|
|
codec_desc = CodecDescriptor(
|
|
d=dim,
|
|
metric=self.distance_metric,
|
|
num_threads=self.num_threads,
|
|
factory=ci_desc.factory,
|
|
construction_params=ci_desc.construction_params,
|
|
training_vectors=self.training_vectors,
|
|
)
|
|
if build:
|
|
if codec_desc is None:
|
|
assert ci_desc.path is not None
|
|
codec_desc = CodecDescriptor(
|
|
d=dim,
|
|
metric=self.distance_metric,
|
|
num_threads=self.num_threads,
|
|
bucket=ci_desc.bucket,
|
|
path=ci_desc.path,
|
|
)
|
|
index_desc = IndexDescriptor(
|
|
d=codec_desc.d,
|
|
metric=self.distance_metric,
|
|
num_threads=self.num_threads,
|
|
codec_desc=codec_desc,
|
|
database_desc=self.database_vectors,
|
|
)
|
|
if knn or range:
|
|
if index_desc is None:
|
|
assert ci_desc.path is not None
|
|
index_desc = IndexDescriptor(
|
|
d=dim,
|
|
metric=self.distance_metric,
|
|
num_threads=self.num_threads,
|
|
bucket=ci_desc.bucket,
|
|
path=ci_desc.path,
|
|
)
|
|
knn_desc = KnnDescriptor(
|
|
d=dim,
|
|
metric=self.distance_metric,
|
|
num_threads=self.num_threads,
|
|
index_desc=index_desc,
|
|
query_dataset=self.query_vectors,
|
|
search_params=ci_desc.search_params,
|
|
range_metrics=ci_desc.range_metrics,
|
|
radius=ci_desc.radius,
|
|
k=self.k,
|
|
)
|
|
|
|
return codec_desc, index_desc, knn_desc
|
|
|
|
def create_execution_operator(
|
|
self,
|
|
train,
|
|
build,
|
|
knn,
|
|
reconstruct,
|
|
range,
|
|
) -> ExecutionOperator:
|
|
# all operators are created, as ground truth are always created in benchmarking
|
|
train_op = TrainOperator(
|
|
num_threads=self.num_threads, distance_metric=self.distance_metric
|
|
)
|
|
build_op = BuildOperator(
|
|
num_threads=self.num_threads, distance_metric=self.distance_metric
|
|
)
|
|
search_op = SearchOperator(
|
|
num_threads=self.num_threads, distance_metric=self.distance_metric
|
|
)
|
|
search_op.range = range
|
|
|
|
exec_op = ExecutionOperator(
|
|
train_op=train_op,
|
|
build_op=build_op,
|
|
search_op=search_op,
|
|
num_threads=self.num_threads,
|
|
)
|
|
assert hasattr(self, "io")
|
|
exec_op.set_io(self.io)
|
|
|
|
# iterate over classic descriptors
|
|
for ci_desc in self.index_descs:
|
|
codec_desc, index_desc, knn_desc = self.create_descriptors(
|
|
ci_desc, train, build, knn, reconstruct, range
|
|
)
|
|
exec_op.add_index_descs(codec_desc, index_desc, knn_desc)
|
|
|
|
return exec_op
|
|
|
|
def clone_one(self, index_desc):
|
|
benchmark = Benchmark(
|
|
num_threads=self.num_threads,
|
|
training_vectors=self.training_vectors,
|
|
database_vectors=self.database_vectors,
|
|
query_vectors=self.query_vectors,
|
|
# index_descs=[self.get_flat_desc("Flat"), index_desc],
|
|
index_descs=[index_desc], # Should automatically find flat descriptors
|
|
range_ref_index_desc=self.range_ref_index_desc,
|
|
k=self.k,
|
|
distance_metric=self.distance_metric,
|
|
)
|
|
benchmark.set_io(self.io.clone())
|
|
return benchmark
|
|
|
|
def benchmark(
|
|
self,
|
|
result_file=None,
|
|
local=False,
|
|
train=False,
|
|
reconstruct=False,
|
|
knn=False,
|
|
range=False,
|
|
):
|
|
logger.info("begin evaluate")
|
|
results = {"indices": {}, "experiments": {}}
|
|
faiss.omp_set_num_threads(self.num_threads)
|
|
exec_op = self.create_execution_operator(
|
|
train=train,
|
|
build=knn or range,
|
|
knn=knn,
|
|
reconstruct=reconstruct,
|
|
range=range,
|
|
)
|
|
exec_op.create_ground_truths(results)
|
|
|
|
todo = self.index_descs
|
|
for index_desc in self.index_descs:
|
|
index_desc.requires = None
|
|
|
|
queued = set()
|
|
while todo:
|
|
current_todo = []
|
|
next_todo = []
|
|
for index_desc in todo:
|
|
results, requires = exec_op.execute(results, dry_run=False)
|
|
if requires is None:
|
|
continue
|
|
if requires in queued:
|
|
if index_desc.requires != requires:
|
|
index_desc.requires = requires
|
|
next_todo.append(index_desc)
|
|
else:
|
|
queued.add(requires)
|
|
index_desc.requires = requires
|
|
current_todo.append(index_desc)
|
|
|
|
if current_todo:
|
|
results_one = {"indices": {}, "experiments": {}}
|
|
params = [
|
|
(
|
|
index_desc,
|
|
self.clone_one(index_desc),
|
|
results_one,
|
|
train,
|
|
reconstruct,
|
|
knn,
|
|
range,
|
|
)
|
|
for index_desc in current_todo
|
|
]
|
|
for result in self.io.launch_jobs(
|
|
run_benchmark_one, params, local=local
|
|
):
|
|
dict_merge(results, result)
|
|
|
|
todo = next_todo
|
|
|
|
if result_file is not None:
|
|
self.io.write_json(results, result_file, overwrite=True)
|
|
logger.info("end evaluate")
|
|
return results
|
|
|
|
|
|
def run_benchmark_one(params):
|
|
logger.info(params)
|
|
index_desc, benchmark, results, train, reconstruct, knn, range = params
|
|
exec_op = benchmark.create_execution_operator(
|
|
train=train,
|
|
build=knn,
|
|
knn=knn,
|
|
reconstruct=reconstruct,
|
|
range=range,
|
|
)
|
|
results, requires = exec_op.execute(results=results, dry_run=False)
|
|
assert requires is None
|
|
assert results is not None
|
|
return results
|