Initial commit
packages/leann-backend-diskann/third_party/DiskANN/python/apps/insert-in-clustered-order.py (vendored, normal file, 103 lines)
@@ -0,0 +1,103 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import argparse

import diskannpy
import numpy as np
import utils


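# Build a dynamic in-memory DiskANN index by inserting points cluster by cluster,
# then batch-search the query set and optionally report recall against a ground-truth file.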
def insert_and_search(
    dtype_str,
    indexdata_file,
    querydata_file,
    Lb,
    graph_degree,
    num_clusters,
    num_insert_threads,
    K,
    Ls,
    num_search_threads,
    gt_file,
):
    npts, ndims = utils.get_bin_metadata(indexdata_file)

    if dtype_str == "float":
        dtype = np.float32
    elif dtype_str == "int8":
        dtype = np.int8
    elif dtype_str == "uint8":
        dtype = np.uint8
    else:
        raise ValueError("data_type must be float, int8 or uint8")

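    # Create an empty dynamic index sized for the whole dataset; "complexity" (Lb)
    # and "graph_degree" control the cost/quality trade-off of graph construction.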
    index = diskannpy.DynamicMemoryIndex(
        distance_metric="l2",
        vector_dtype=dtype,
        dimensions=ndims,
        max_vectors=npts,
        complexity=Lb,
        graph_degree=graph_degree,
    )
    queries = diskannpy.vectors_from_file(querydata_file, dtype)
    data = diskannpy.vectors_from_file(indexdata_file, dtype)

    offsets, permutation = utils.cluster_and_permute(
        dtype_str, npts, ndims, data, num_clusters
    )

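    # Insert one spatial cluster at a time so that consecutive inserts are spatially
    # close; this is the clustered, ingest-optimized order the script is named for.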
    i = 0
    timer = utils.Timer()
    for c in range(num_clusters):
        cluster_index_range = range(offsets[c], offsets[c + 1])
        cluster_indices = np.array(permutation[cluster_index_range], dtype=np.uint32)
        cluster_data = data[cluster_indices, :]
        index.batch_insert(cluster_data, cluster_indices + 1, num_insert_threads)
        print('Inserted cluster', c, 'in', timer.elapsed(), 's')
    tags, dists = index.batch_search(queries, K, Ls, num_search_threads)
    print('Batch searched', queries.shape[0], 'queries in', timer.elapsed(), 's')
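    # Tags were assigned as index + 1 at insert time, so shift results back to
    # 0-based ids before comparing against the ground-truth file.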
    res_ids = tags - 1

    if gt_file != "":
        recall = utils.calculate_recall_from_gt_file(K, res_ids, gt_file)
        print(f"recall@{K} is {recall}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="in-mem-dynamic",
        description="Inserts points dynamically in a clustered order and searches from vectors in a file.",
    )

parser.add_argument("-d", "--data_type", required=True)
|
||||
parser.add_argument("-i", "--indexdata_file", required=True)
|
||||
parser.add_argument("-q", "--querydata_file", required=True)
|
||||
parser.add_argument("-Lb", "--Lbuild", default=50, type=int)
|
||||
parser.add_argument("-Ls", "--Lsearch", default=50, type=int)
|
||||
parser.add_argument("-R", "--graph_degree", default=32, type=int)
|
||||
parser.add_argument("-TI", "--num_insert_threads", default=8, type=int)
|
||||
parser.add_argument("-TS", "--num_search_threads", default=8, type=int)
|
||||
parser.add_argument("-C", "--num_clusters", default=32, type=int)
|
||||
parser.add_argument("-K", default=10, type=int)
|
||||
parser.add_argument("--gt_file", default="")
|
||||
args = parser.parse_args()
|
||||
|
||||
    insert_and_search(
        args.data_type,
        args.indexdata_file,
        args.querydata_file,
        args.Lbuild,
        args.graph_degree,  # build args
        args.num_clusters,
        args.num_insert_threads,
        args.K,
        args.Lsearch,
        args.num_search_threads,  # search args
        args.gt_file,
    )

# An ingest-optimized example with SIFT1M:
# python3 ~/DiskANN/python/apps/insert-in-clustered-order.py -d float \
#     -i sift_base.fbin -q sift_query.fbin --gt_file gt100_base \
#     -Lb 10 -R 30 -Ls 200 -C 32