150 lines
4.7 KiB
Python
150 lines
4.7 KiB
Python
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT license.
|
|
|
|
import argparse
|
|
from xml.dom.pulldom import default_bufsize
|
|
|
|
import diskannpy
|
|
import numpy as np
|
|
import utils
|
|
|
|
def build_and_search(
    metric,
    dtype_str,
    index_directory,
    indexdata_file,
    querydata_file,
    Lb,
    graph_degree,
    K,
    Ls,
    num_threads,
    gt_file,
    index_prefix,
    search_only
) -> dict[str, float]:
    """
    Build (unless ``search_only``) a static in-memory DiskANN index and batch-search it.

    :param metric: distance metric name passed through to diskannpy (e.g. "l2")
    :param dtype_str: vector dtype as a string: "float", "int8" or "uint8"
    :param index_directory: directory to write/read the index files
    :param indexdata_file: path to the binary file of base vectors to index
    :param querydata_file: path to the binary file of query vectors
    :param Lb: build-time complexity (candidate list size during construction)
    :param graph_degree: maximum graph degree (R) for the index
    :param K: number of nearest neighbors to retrieve per query
    :param Ls: search-time complexity (candidate list size during search); must be >= K
    :param num_threads: threads used for both build and search
    :param gt_file: ground-truth file for recall computation; skipped when ""
    :param index_prefix: filename prefix for the index files
    :param search_only: when truthy, skip the build step and load an existing index
    :return: Dictionary of timings. Key is the event and value is the number of seconds the event took
    in wall-clock-time.
    """
    timer_results: dict[str, float] = {}

    # Times the whole function, build included.
    method_timer: utils.Timer = utils.Timer()

    # Map the CLI dtype string onto the numpy scalar type diskannpy expects.
    if dtype_str == "float":
        dtype = np.single
    elif dtype_str == "int8":
        dtype = np.byte
    elif dtype_str == "uint8":
        dtype = np.ubyte
    else:
        raise ValueError("data_type must be float, int8 or uint8")

    # build index
    if not search_only:
        build_index_timer = utils.Timer()
        diskannpy.build_memory_index(
            data=indexdata_file,
            distance_metric=metric,
            vector_dtype=dtype,
            index_directory=index_directory,
            complexity=Lb,
            graph_degree=graph_degree,
            num_threads=num_threads,
            index_prefix=index_prefix,
            alpha=1.2,
            use_pq_build=False,
            num_pq_bytes=8,
            use_opq=False,
        )
        timer_results["build_index_seconds"] = build_index_timer.elapsed()

    # ready search object
    load_index_timer = utils.Timer()
    index = diskannpy.StaticMemoryIndex(
        distance_metric=metric,
        vector_dtype=dtype,
        index_directory=index_directory,
        num_threads=num_threads,  # this can be different at search time if you would like
        initial_search_complexity=Ls,
        index_prefix=index_prefix
    )
    timer_results["load_index_seconds"] = load_index_timer.elapsed()

    queries = utils.bin_to_numpy(dtype, querydata_file)

    query_timer = utils.Timer()
    # BUG FIX: previously hardcoded k_neighbors=10, which made the result
    # inconsistent with the recall@K computation below for any K != 10.
    ids, dists = index.batch_search(queries, K, Ls, num_threads)
    query_time = query_timer.elapsed()
    qps = round(queries.shape[0]/query_time, 1)
    print('Batch searched', queries.shape[0], 'in', query_time, 's @', qps, 'QPS')
    timer_results["query_seconds"] = query_time

    # Recall is only computed when a ground-truth file was supplied.
    if gt_file != "":
        recall_timer = utils.Timer()
        recall = utils.calculate_recall_from_gt_file(K, ids, gt_file)
        print(f"recall@{K} is {recall}")
        timer_results["recall_seconds"] = recall_timer.elapsed()

    timer_results['total_time_seconds'] = method_timer.elapsed()

    return timer_results
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="in-mem-static",
        description="Static in-memory build and search from vectors in a file",
    )

    parser.add_argument("-m", "--metric", required=False, default="l2")
    parser.add_argument("-d", "--data_type", required=True)
    parser.add_argument("-id", "--index_directory", required=False, default=".")
    parser.add_argument("-i", "--indexdata_file", required=True)
    parser.add_argument("-q", "--querydata_file", required=True)
    parser.add_argument("-Lb", "--Lbuild", default=50, type=int)
    parser.add_argument("-Ls", "--Lsearch", default=50, type=int)
    parser.add_argument("-R", "--graph_degree", default=32, type=int)
    parser.add_argument("-T", "--num_threads", default=8, type=int)
    parser.add_argument("-K", default=10, type=int)
    parser.add_argument("-G", "--gt_file", default="")
    parser.add_argument("-ip", "--index_prefix", required=False, default="ann")
    # BUG FIX: was `default=False` with no action, so any supplied value --
    # including the string "False" -- was truthy. A proper boolean flag:
    # absent -> False, `--search_only` -> True.
    parser.add_argument("--search_only", required=False, default=False, action="store_true")
    parser.add_argument("--json_timings_output", required=False, default=None, help="File to write out timings to as JSON. If not specified, timings will not be written out.")
    args = parser.parse_args()

    timings: dict[str, float] = build_and_search(
        args.metric,
        args.data_type,
        args.index_directory.strip(),
        args.indexdata_file.strip(),
        args.querydata_file.strip(),
        args.Lbuild,
        args.graph_degree,  # Build args
        args.K,
        args.Lsearch,
        args.num_threads,  # search args
        args.gt_file,
        args.index_prefix,
        args.search_only
    )

    # Optionally dump the timing dictionary as JSON for downstream tooling.
    if args.json_timings_output is not None:
        import json
        timings['log_file'] = args.json_timings_output
        with open(args.json_timings_output, "w") as f:
            json.dump(timings, f)
|