Initial commit
152  packages/leann-backend-diskann/third_party/DiskANN/python/apps/cli/__main__.py  (vendored, normal file)
@@ -0,0 +1,152 @@
import diskannpy as dap
import numpy as np
import numpy.typing as npt

import fire

from contextlib import contextmanager
from time import perf_counter

from typing import Tuple


def _basic_setup(
    dtype: str,
    query_vectors_file: str
) -> Tuple[dap.VectorDType, npt.NDArray[dap.VectorDType]]:
    _dtype = dap.valid_dtype(dtype)
    vectors_to_query = dap.vectors_from_binary(query_vectors_file, dtype=_dtype)
    return _dtype, vectors_to_query


def dynamic(
    dtype: str,
    index_vectors_file: str,
    query_vectors_file: str,
    build_complexity: int,
    graph_degree: int,
    K: int,
    search_complexity: int,
    num_insert_threads: int,
    num_search_threads: int,
    gt_file: str = "",
):
    _dtype, vectors_to_query = _basic_setup(dtype, query_vectors_file)
    vectors_to_index = dap.vectors_from_binary(index_vectors_file, dtype=_dtype)

    npts, ndims = vectors_to_index.shape
    index = dap.DynamicMemoryIndex(
        "l2", _dtype, ndims, npts, build_complexity, graph_degree
    )

    tags = np.arange(1, npts + 1, dtype=np.uintc)
    timer = Timer()

    with timer.time("batch insert"):
        index.batch_insert(vectors_to_index, tags, num_insert_threads)

    delete_tags = np.random.choice(
        np.array(range(1, npts + 1, 1), dtype=np.uintc),
        size=int(0.5 * npts),
        replace=False
    )
    with timer.time("mark deletion"):
        for tag in delete_tags:
            index.mark_deleted(tag)

    with timer.time("consolidation"):
        index.consolidate_delete()

    deleted_data = vectors_to_index[delete_tags - 1, :]

    with timer.time("re-insertion"):
        index.batch_insert(deleted_data, delete_tags, num_insert_threads)

    with timer.time("batch searched"):
        tags, dists = index.batch_search(vectors_to_query, K, search_complexity, num_search_threads)

    # res_ids = tags - 1
    # if gt_file != "":
    #     recall = utils.calculate_recall_from_gt_file(K, res_ids, gt_file)
    #     print(f"recall@{K} is {recall}")


def static(
    dtype: str,
    index_directory: str,
    index_vectors_file: str,
    query_vectors_file: str,
    build_complexity: int,
    graph_degree: int,
    K: int,
    search_complexity: int,
    num_threads: int,
    gt_file: str = "",
    index_prefix: str = "ann"
):
    _dtype, vectors_to_query = _basic_setup(dtype, query_vectors_file)
    timer = Timer()
    with timer.time("build static index"):
        # build index
        dap.build_memory_index(
            data=index_vectors_file,
            metric="l2",
            vector_dtype=_dtype,
            index_directory=index_directory,
            complexity=build_complexity,
            graph_degree=graph_degree,
            num_threads=num_threads,
            index_prefix=index_prefix,
            alpha=1.2,
            use_pq_build=False,
            num_pq_bytes=8,
            use_opq=False,
        )

    with timer.time("load static index"):
        # ready search object
        index = dap.StaticMemoryIndex(
            metric="l2",
            vector_dtype=_dtype,
            data_path=index_vectors_file,
            index_directory=index_directory,
            num_threads=num_threads,  # this can be different at search time if you would like
            initial_search_complexity=search_complexity,
            index_prefix=index_prefix
        )

    ids, dists = index.batch_search(vectors_to_query, K, search_complexity, num_threads)

    # if gt_file != "":
    #     recall = utils.calculate_recall_from_gt_file(K, ids, gt_file)
    #     print(f"recall@{K} is {recall}")


def dynamic_clustered():
    pass


def generate_clusters():
    pass


class Timer:
    def __init__(self):
        self._start = -1

    @contextmanager
    def time(self, message: str):
        start = perf_counter()
        if self._start == -1:
            self._start = start
        yield
        now = perf_counter()
        print(f"Operation {message} completed in {(now - start):.3f}s, total: {(now - self._start):.3f}s")


if __name__ == "__main__":
    fire.Fire({
        "in-mem-dynamic": dynamic,
        "in-mem-static": static,
        "in-mem-dynamic-clustered": dynamic_clustered,
        "generate-clusters": generate_clusters
    }, name="cli")
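Usage sketch (not part of the vendored file): the recall check at the end of `dynamic` above is commented out, and this module never imports `utils`. A minimal sketch of how it could be wired up, assuming `python/apps` (which contains the `utils.py` added later in this commit) is on `sys.path`:

    import numpy as np

    import utils  # assumption: python/apps is on sys.path so utils.py (later in this commit) resolves


    def report_recall(tags: np.ndarray, K: int, gt_file: str) -> None:
        # Tags were assigned as row index + 1 at insert time, so subtracting 1
        # recovers 0-based row ids comparable to the ground-truth file.
        if gt_file == "":
            return
        res_ids = tags - 1
        recall = utils.calculate_recall_from_gt_file(K, res_ids, gt_file)
        print(f"recall@{K} is {recall}")

Calling `report_recall(tags, K, gt_file)` after the batch search would reproduce the commented-out block.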
28  packages/leann-backend-diskann/third_party/DiskANN/python/apps/cluster.py  (vendored, normal file)
@@ -0,0 +1,28 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import argparse

import utils


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="cluster", description="kmeans cluster points in a file"
    )

    parser.add_argument("-d", "--data_type", required=True)
    parser.add_argument("-i", "--indexdata_file", required=True)
    parser.add_argument("-k", "--num_clusters", type=int, required=True)
    args = parser.parse_args()

    npts, ndims = utils.get_bin_metadata(args.indexdata_file)

    data = utils.bin_to_numpy(args.data_type, args.indexdata_file)

    offsets, permutation = utils.cluster_and_permute(
        args.data_type, npts, ndims, data, args.num_clusters
    )

    permuted_data = data[permutation]

    utils.numpy_to_bin(permuted_data, args.indexdata_file + ".cluster")
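For reference (not part of the vendored file): `cluster_and_permute`, defined in `utils.py` later in this commit, returns `offsets` and `permutation` such that `data[permutation]` groups rows by cluster and `offsets[c]:offsets[c + 1]` delimits cluster `c`. A small self-contained sketch of that layout with synthetic labels (the data, labels, and cluster count below are illustrative only):

    import numpy as np

    # Six synthetic points and pretend kmeans labels; stand-ins for the real script inputs.
    data = np.arange(12, dtype=np.float32).reshape(6, 2)
    labels = np.array([1, 0, 1, 0, 1, 0])
    num_clusters = 2

    # The same bookkeeping cluster_and_permute performs: counts -> offsets -> permutation.
    counts = np.bincount(labels, minlength=num_clusters)
    offsets = np.concatenate(([0], np.cumsum(counts)))
    permutation = np.argsort(labels, kind="stable")

    permuted_data = data[permutation]
    for c in range(num_clusters):
        print("cluster", c, "rows:\n", permuted_data[offsets[c]:offsets[c + 1]])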
161  packages/leann-backend-diskann/third_party/DiskANN/python/apps/in-mem-dynamic.py  (vendored, normal file)
@@ -0,0 +1,161 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import argparse

import diskannpy
import numpy as np
import utils


def insert_and_search(
    dtype_str,
    indexdata_file,
    querydata_file,
    Lb,
    graph_degree,
    K,
    Ls,
    num_insert_threads,
    num_search_threads,
    gt_file,
) -> dict[str, float]:
    """

    :param dtype_str:
    :param indexdata_file:
    :param querydata_file:
    :param Lb:
    :param graph_degree:
    :param K:
    :param Ls:
    :param num_insert_threads:
    :param num_search_threads:
    :param gt_file:
    :return: Dictionary of timings. Key is the event and value is the number of seconds the event took
    """
    timer_results: dict[str, float] = {}

    method_timer: utils.Timer = utils.Timer()

    npts, ndims = utils.get_bin_metadata(indexdata_file)

    if dtype_str == "float":
        dtype = np.float32
    elif dtype_str == "int8":
        dtype = np.int8
    elif dtype_str == "uint8":
        dtype = np.uint8
    else:
        raise ValueError("data_type must be float, int8 or uint8")

    index = diskannpy.DynamicMemoryIndex(
        distance_metric="l2",
        vector_dtype=dtype,
        dimensions=ndims,
        max_vectors=npts,
        complexity=Lb,
        graph_degree=graph_degree
    )
    queries = diskannpy.vectors_from_file(querydata_file, dtype)
    data = diskannpy.vectors_from_file(indexdata_file, dtype)

    tags = np.zeros(npts, dtype=np.uintc)
    timer = utils.Timer()
    for i in range(npts):
        tags[i] = i + 1
    index.batch_insert(data, tags, num_insert_threads)
    compute_seconds = timer.elapsed()
    print('batch_insert complete in', compute_seconds, 's')
    timer_results["batch_insert_seconds"] = compute_seconds

    delete_tags = np.random.choice(
        np.array(range(1, npts + 1, 1), dtype=np.uintc),
        size=int(0.5 * npts),
        replace=False
    )

    timer.reset()
    for tag in delete_tags:
        index.mark_deleted(tag)
    compute_seconds = timer.elapsed()
    timer_results['mark_deletion_seconds'] = compute_seconds
    print('mark deletion completed in', compute_seconds, 's')

    timer.reset()
    index.consolidate_delete()
    compute_seconds = timer.elapsed()
    print('consolidation completed in', compute_seconds, 's')
    timer_results['consolidation_completed_seconds'] = compute_seconds

    deleted_data = data[delete_tags - 1, :]

    timer.reset()
    index.batch_insert(deleted_data, delete_tags, num_insert_threads)
    compute_seconds = timer.elapsed()
    print('re-insertion completed in', compute_seconds, 's')
    timer_results['re-insertion_seconds'] = compute_seconds

    timer.reset()
    tags, dists = index.batch_search(queries, K, Ls, num_search_threads)
    compute_seconds = timer.elapsed()
    print('Batch searched', queries.shape[0], 'queries in', compute_seconds, 's')
    timer_results['batch_searched_seconds'] = compute_seconds

    res_ids = tags - 1
    if gt_file != "":
        timer.reset()
        recall = utils.calculate_recall_from_gt_file(K, res_ids, gt_file)
        print(f"recall@{K} is {recall}")
        timer_results['recall_computed_seconds'] = timer.elapsed()

    timer_results['total_time_seconds'] = method_timer.elapsed()

    return timer_results


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="in-mem-dynamic",
        description="Inserts points dynamically and searches from vectors in a file.",
    )

    parser.add_argument("-d", "--data_type", required=True)
    parser.add_argument("-i", "--indexdata_file", required=True)
    parser.add_argument("-q", "--querydata_file", required=True)
    parser.add_argument("-Lb", "--Lbuild", default=50, type=int)
    parser.add_argument("-Ls", "--Lsearch", default=50, type=int)
    parser.add_argument("-R", "--graph_degree", default=32, type=int)
    parser.add_argument("-TI", "--num_insert_threads", default=8, type=int)
    parser.add_argument("-TS", "--num_search_threads", default=8, type=int)
    parser.add_argument("-K", default=10, type=int)
    parser.add_argument("--gt_file", default="")
    parser.add_argument("--json_timings_output", required=False, default=None, help="File to write out timings to as JSON. If not specified, timings will not be written out.")
    args = parser.parse_args()

    timings = insert_and_search(
        args.data_type,
        args.indexdata_file,
        args.querydata_file,
        args.Lbuild,
        args.graph_degree,  # Build args
        args.K,
        args.Lsearch,
        args.num_insert_threads,
        args.num_search_threads,  # search args
        args.gt_file,
    )

    if args.json_timings_output is not None:
        import json
        timings['log_file'] = args.json_timings_output
        with open(args.json_timings_output, "w") as f:
            json.dump(timings, f)

"""
An ingest optimized example with SIFT1M
source venv/bin/activate
python python/apps/in-mem-dynamic.py -d float \
-i "$HOME/data/sift/sift_base.fbin" -q "$HOME/data/sift/sift_query.fbin" --gt_file "$HOME/data/sift/gt100_base" \
-Lb 10 -R 30 -Ls 200
"""
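Usage sketch (not part of the vendored file): consuming the timings dictionary written via `--json_timings_output`. The file name below is a placeholder for whatever path was passed to the flag:

    import json

    # Assumption: in-mem-dynamic.py was run with --json_timings_output timings.json
    with open("timings.json") as f:
        timings = json.load(f)

    # Every entry except the echoed 'log_file' path is a wall-clock duration in seconds.
    for event, value in timings.items():
        if event != "log_file":
            print(f"{event}: {value:.3f}s")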
149  packages/leann-backend-diskann/third_party/DiskANN/python/apps/in-mem-static.py  (vendored, normal file)
@@ -0,0 +1,149 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import argparse

import diskannpy
import numpy as np
import utils


def build_and_search(
    metric,
    dtype_str,
    index_directory,
    indexdata_file,
    querydata_file,
    Lb,
    graph_degree,
    K,
    Ls,
    num_threads,
    gt_file,
    index_prefix,
    search_only
) -> dict[str, float]:
    """

    :param metric:
    :param dtype_str:
    :param index_directory:
    :param indexdata_file:
    :param querydata_file:
    :param Lb:
    :param graph_degree:
    :param K:
    :param Ls:
    :param num_threads:
    :param gt_file:
    :param index_prefix:
    :param search_only:
    :return: Dictionary of timings. Key is the event and value is the number of seconds the event took
        in wall-clock time.
    """
    timer_results: dict[str, float] = {}

    method_timer: utils.Timer = utils.Timer()

    if dtype_str == "float":
        dtype = np.single
    elif dtype_str == "int8":
        dtype = np.byte
    elif dtype_str == "uint8":
        dtype = np.ubyte
    else:
        raise ValueError("data_type must be float, int8 or uint8")

    # build index
    if not search_only:
        build_index_timer = utils.Timer()
        diskannpy.build_memory_index(
            data=indexdata_file,
            distance_metric=metric,
            vector_dtype=dtype,
            index_directory=index_directory,
            complexity=Lb,
            graph_degree=graph_degree,
            num_threads=num_threads,
            index_prefix=index_prefix,
            alpha=1.2,
            use_pq_build=False,
            num_pq_bytes=8,
            use_opq=False,
        )
        timer_results["build_index_seconds"] = build_index_timer.elapsed()

    # ready search object
    load_index_timer = utils.Timer()
    index = diskannpy.StaticMemoryIndex(
        distance_metric=metric,
        vector_dtype=dtype,
        index_directory=index_directory,
        num_threads=num_threads,  # this can be different at search time if you would like
        initial_search_complexity=Ls,
        index_prefix=index_prefix
    )
    timer_results["load_index_seconds"] = load_index_timer.elapsed()

    queries = utils.bin_to_numpy(dtype, querydata_file)

    query_timer = utils.Timer()
    # search K neighbors so the recall@K check below sees enough columns
    ids, dists = index.batch_search(queries, K, Ls, num_threads)
    query_time = query_timer.elapsed()
    qps = round(queries.shape[0] / query_time, 1)
    print('Batch searched', queries.shape[0], 'in', query_time, 's @', qps, 'QPS')
    timer_results["query_seconds"] = query_time

    if gt_file != "":
        recall_timer = utils.Timer()
        recall = utils.calculate_recall_from_gt_file(K, ids, gt_file)
        print(f"recall@{K} is {recall}")
        timer_results["recall_seconds"] = recall_timer.elapsed()

    timer_results['total_time_seconds'] = method_timer.elapsed()

    return timer_results


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="in-mem-static",
        description="Static in-memory build and search from vectors in a file",
    )

    parser.add_argument("-m", "--metric", required=False, default="l2")
    parser.add_argument("-d", "--data_type", required=True)
    parser.add_argument("-id", "--index_directory", required=False, default=".")
    parser.add_argument("-i", "--indexdata_file", required=True)
    parser.add_argument("-q", "--querydata_file", required=True)
    parser.add_argument("-Lb", "--Lbuild", default=50, type=int)
    parser.add_argument("-Ls", "--Lsearch", default=50, type=int)
    parser.add_argument("-R", "--graph_degree", default=32, type=int)
    parser.add_argument("-T", "--num_threads", default=8, type=int)
    parser.add_argument("-K", default=10, type=int)
    parser.add_argument("-G", "--gt_file", default="")
    parser.add_argument("-ip", "--index_prefix", required=False, default="ann")
    parser.add_argument("--search_only", required=False, default=False, action="store_true")
    parser.add_argument("--json_timings_output", required=False, default=None, help="File to write out timings to as JSON. If not specified, timings will not be written out.")
    args = parser.parse_args()

    timings: dict[str, float] = build_and_search(
        args.metric,
        args.data_type,
        args.index_directory.strip(),
        args.indexdata_file.strip(),
        args.querydata_file.strip(),
        args.Lbuild,
        args.graph_degree,  # Build args
        args.K,
        args.Lsearch,
        args.num_threads,  # search args
        args.gt_file,
        args.index_prefix,
        args.search_only
    )

    if args.json_timings_output is not None:
        import json
        timings['log_file'] = args.json_timings_output
        with open(args.json_timings_output, "w") as f:
            json.dump(timings, f)
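Note (not part of the vendored file): this script maps the dtype strings to `np.single` / `np.byte` / `np.ubyte`, while `in-mem-dynamic.py` uses `np.float32` / `np.int8` / `np.uint8`. The two spellings resolve to the same dtypes, as this small check shows:

    import numpy as np

    # The aliases used by in-mem-static.py and in-mem-dynamic.py are interchangeable.
    assert np.dtype(np.single) == np.dtype(np.float32)
    assert np.dtype(np.byte) == np.dtype(np.int8)
    assert np.dtype(np.ubyte) == np.dtype(np.uint8)
    print("float ->", np.dtype(np.single), "| int8 ->", np.dtype(np.byte), "| uint8 ->", np.dtype(np.ubyte))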
103  packages/leann-backend-diskann/third_party/DiskANN/python/apps/insert-in-clustered-order.py  (vendored, normal file)
@@ -0,0 +1,103 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import argparse

import diskannpy
import numpy as np
import utils


def insert_and_search(
    dtype_str,
    indexdata_file,
    querydata_file,
    Lb,
    graph_degree,
    num_clusters,
    num_insert_threads,
    K,
    Ls,
    num_search_threads,
    gt_file,
):
    npts, ndims = utils.get_bin_metadata(indexdata_file)

    if dtype_str == "float":
        dtype = np.float32
    elif dtype_str == "int8":
        dtype = np.int8
    elif dtype_str == "uint8":
        dtype = np.uint8
    else:
        raise ValueError("data_type must be float, int8 or uint8")

    index = diskannpy.DynamicMemoryIndex(
        distance_metric="l2",
        vector_dtype=dtype,
        dimensions=ndims,
        max_vectors=npts,
        complexity=Lb,
        graph_degree=graph_degree
    )
    queries = diskannpy.vectors_from_file(querydata_file, dtype)
    data = diskannpy.vectors_from_file(indexdata_file, dtype)

    offsets, permutation = utils.cluster_and_permute(
        dtype_str, npts, ndims, data, num_clusters
    )

    timer = utils.Timer()
    for c in range(num_clusters):
        cluster_index_range = range(offsets[c], offsets[c + 1])
        cluster_indices = np.array(permutation[cluster_index_range], dtype=np.uint32)
        cluster_data = data[cluster_indices, :]
        index.batch_insert(cluster_data, cluster_indices + 1, num_insert_threads)
        print('Inserted cluster', c, 'in', timer.elapsed(), 's')
    tags, dists = index.batch_search(queries, K, Ls, num_search_threads)
    print('Batch searched', queries.shape[0], 'queries in', timer.elapsed(), 's')
    res_ids = tags - 1

    if gt_file != "":
        recall = utils.calculate_recall_from_gt_file(K, res_ids, gt_file)
        print(f"recall@{K} is {recall}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="insert-in-clustered-order",
        description="Inserts points dynamically in a clustered order and searches from vectors in a file.",
    )

    parser.add_argument("-d", "--data_type", required=True)
    parser.add_argument("-i", "--indexdata_file", required=True)
    parser.add_argument("-q", "--querydata_file", required=True)
    parser.add_argument("-Lb", "--Lbuild", default=50, type=int)
    parser.add_argument("-Ls", "--Lsearch", default=50, type=int)
    parser.add_argument("-R", "--graph_degree", default=32, type=int)
    parser.add_argument("-TI", "--num_insert_threads", default=8, type=int)
    parser.add_argument("-TS", "--num_search_threads", default=8, type=int)
    parser.add_argument("-C", "--num_clusters", default=32, type=int)
    parser.add_argument("-K", default=10, type=int)
    parser.add_argument("--gt_file", default="")
    args = parser.parse_args()

    insert_and_search(
        args.data_type,
        args.indexdata_file,
        args.querydata_file,
        args.Lbuild,
        args.graph_degree,  # Build args
        args.num_clusters,
        args.num_insert_threads,
        args.K,
        args.Lsearch,
        args.num_search_threads,  # search args
        args.gt_file,
    )

# An ingest optimized example with SIFT1M
# python3 ~/DiskANN/python/apps/insert-in-clustered-order.py -d float \
# -i sift_base.fbin -q sift_query.fbin --gt_file gt100_base \
# -Lb 10 -R 30 -Ls 200 -C 32
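Sketch (not part of the vendored file): the insertion loop above relies on `offsets` partitioning `permutation`, so each original row is inserted exactly once and tagged with its row index + 1. A small check of that invariant with synthetic values (not real clustering output):

    import numpy as np

    # Synthetic stand-ins for utils.cluster_and_permute output on 6 points, 2 clusters.
    offsets = np.array([0, 4, 6])
    permutation = np.array([1, 3, 5, 0, 2, 4])

    inserted_tags = []
    for c in range(len(offsets) - 1):
        cluster_indices = np.array(permutation[offsets[c]:offsets[c + 1]], dtype=np.uint32)
        inserted_tags.extend((cluster_indices + 1).tolist())  # same 1-based tag scheme as the script

    # Every row id 1..npts is used as a tag exactly once across all clusters.
    assert sorted(inserted_tags) == list(range(1, len(permutation) + 1))
    print("all", len(permutation), "rows tagged exactly once")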
120  packages/leann-backend-diskann/third_party/DiskANN/python/apps/utils.py  (vendored, normal file)
@@ -0,0 +1,120 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import numpy as np
from scipy.cluster.vq import vq, kmeans2
from typing import Tuple
from time import perf_counter


def get_bin_metadata(bin_file) -> Tuple[int, int]:
    array = np.fromfile(file=bin_file, dtype=np.uint32, count=2)
    return array[0], array[1]


def bin_to_numpy(dtype, bin_file) -> np.ndarray:
    npts, ndims = get_bin_metadata(bin_file)
    return np.fromfile(file=bin_file, dtype=dtype, offset=8).reshape(npts, ndims)


class Timer:
    def __init__(self):
        self.last = perf_counter()

    def reset(self):
        new = perf_counter()
        self.last = new

    def elapsed(self, round_digit: int = 3):
        new = perf_counter()
        elapsed_time = new - self.last
        self.last = new
        return round(elapsed_time, round_digit)


def numpy_to_bin(array, out_file):
    shape = np.shape(array)
    npts = np.uint32(shape[0])
    ndims = np.uint32(shape[1])
    with open(out_file, "wb") as f:
        f.write(npts.tobytes())
        f.write(ndims.tobytes())
        f.write(array.tobytes())


def read_gt_file(gt_file) -> Tuple[np.ndarray[int], np.ndarray[float]]:
    """
    Return ids and distances to queries
    """
    nq, K = get_bin_metadata(gt_file)
    ids = np.fromfile(file=gt_file, dtype=np.uint32, offset=8, count=nq * K).reshape(
        nq, K
    )
    dists = np.fromfile(
        file=gt_file, dtype=np.float32, offset=8 + nq * K * 4, count=nq * K
    ).reshape(nq, K)
    return ids, dists


def calculate_recall(
    result_set_indices: np.ndarray[int],
    truth_set_indices: np.ndarray[int],
    recall_at: int = 5,
) -> float:
    """
    result_set_indices and truth_set_indices correspond by row index. The columns in each row contain the indices of
    the nearest neighbors, with result_set_indices being the approximate nearest neighbor results and truth_set_indices
    being the brute force nearest neighbor calculation via sklearn's NearestNeighbor class.
    :param result_set_indices:
    :param truth_set_indices:
    :param recall_at:
    :return:
    """
    found = 0
    for i in range(0, result_set_indices.shape[0]):
        result_set_set = set(result_set_indices[i][0:recall_at])
        truth_set_set = set(truth_set_indices[i][0:recall_at])
        found += len(result_set_set.intersection(truth_set_set))
    return found / (result_set_indices.shape[0] * recall_at)


def calculate_recall_from_gt_file(K: int, ids: np.ndarray[int], gt_file: str) -> float:
    """
    Calculate recall from ids returned from search and those read from file
    """
    gt_ids, gt_dists = read_gt_file(gt_file)
    return calculate_recall(ids, gt_ids, K)


def cluster_and_permute(
    dtype_str, npts, ndims, data, num_clusters
) -> Tuple[np.ndarray[int], np.ndarray[int]]:
    """
    Cluster the data and return permutation of row indices
    that would group indices of the same cluster together
    """
    sample_size = min(100000, npts)
    sample_indices = np.random.choice(range(npts), size=sample_size, replace=False)
    sampled_data = data[sample_indices, :]
    centroids, sample_labels = kmeans2(sampled_data, num_clusters, minit="++", iter=10)
    labels, dist = vq(data, centroids)

    count = np.zeros(num_clusters)
    for i in range(npts):
        count[labels[i]] += 1
    print("Cluster counts")
    print(count)

    offsets = np.zeros(num_clusters + 1, dtype=int)
    for i in range(0, num_clusters, 1):
        offsets[i + 1] = offsets[i] + count[i]

    permutation = np.zeros(npts, dtype=int)
    counters = np.zeros(num_clusters, dtype=int)
    for i in range(npts):
        label = labels[i]
        row = offsets[label] + counters[label]
        counters[label] += 1
        permutation[row] = i

    return offsets, permutation
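Usage sketch (not part of the vendored file): a minimal round trip through the binary helpers above. It writes a temporary .fbin file and reads it back; it assumes it is run from `python/apps` so that `import utils` resolves to this module:

    import os
    import tempfile

    import numpy as np

    import utils  # assumption: executed with python/apps as the working directory

    data = np.random.default_rng(0).random((5, 4)).astype(np.float32)

    path = os.path.join(tempfile.mkdtemp(), "example.fbin")
    utils.numpy_to_bin(data, path)

    npts, ndims = utils.get_bin_metadata(path)       # header: two uint32 values (points, dims)
    restored = utils.bin_to_numpy(np.float32, path)  # payload reshaped to (npts, ndims)

    assert (npts, ndims) == data.shape
    assert np.array_equal(restored, data)
    print("round-trip ok:", npts, "x", ndims)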