Initial commit
packages/leann-backend-diskann/third_party/DiskANN/python/apps/cli/__main__.py (vendored, new file, 152 lines)
@@ -0,0 +1,152 @@
import diskannpy as dap
import numpy as np
import numpy.typing as npt

import fire

from contextlib import contextmanager
from time import perf_counter

from typing import Tuple

def _basic_setup(
    dtype: str,
    query_vectors_file: str
) -> Tuple[dap.VectorDType, npt.NDArray[dap.VectorDType]]:
    _dtype = dap.valid_dtype(dtype)
    vectors_to_query = dap.vectors_from_binary(query_vectors_file, dtype=_dtype)
    return _dtype, vectors_to_query


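# The `dynamic` command below exercises the full mutable-index lifecycle:
# batch-insert every vector, mark a random half of them deleted by tag,
# consolidate the deletions, re-insert the deleted vectors, and batch-search.
# Tags are 1-based, hence the `delete_tags - 1` when indexing back into the
# original vector array.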
def dynamic(
    dtype: str,
    index_vectors_file: str,
    query_vectors_file: str,
    build_complexity: int,
    graph_degree: int,
    K: int,
    search_complexity: int,
    num_insert_threads: int,
    num_search_threads: int,
    gt_file: str = "",
):
    _dtype, vectors_to_query = _basic_setup(dtype, query_vectors_file)
    vectors_to_index = dap.vectors_from_binary(index_vectors_file, dtype=_dtype)

    npts, ndims = vectors_to_index.shape
    index = dap.DynamicMemoryIndex(
        "l2", _dtype, ndims, npts, build_complexity, graph_degree
    )

    tags = np.arange(1, npts + 1, dtype=np.uintc)
    timer = Timer()

    with timer.time("batch insert"):
        index.batch_insert(vectors_to_index, tags, num_insert_threads)

    delete_tags = np.random.choice(
        np.arange(1, npts + 1, dtype=np.uintc),
        size=int(0.5 * npts),
        replace=False
    )
    with timer.time("mark deletion"):
        for tag in delete_tags:
            index.mark_deleted(tag)

    with timer.time("consolidation"):
        index.consolidate_delete()

    deleted_data = vectors_to_index[delete_tags - 1, :]

    with timer.time("re-insertion"):
        index.batch_insert(deleted_data, delete_tags, num_insert_threads)

    with timer.time("batch search"):
        tags, dists = index.batch_search(
            vectors_to_query, K, search_complexity, num_search_threads
        )

    # res_ids = tags - 1
    # if gt_file != "":
    #     recall = utils.calculate_recall_from_gt_file(K, res_ids, gt_file)
    #     print(f"recall@{K} is {recall}")


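# The `static` command covers the immutable path: build an index on disk with
# dap.build_memory_index, reload it as a StaticMemoryIndex, and run a single
# (untimed) batch search over the query set.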
def static(
    dtype: str,
    index_directory: str,
    index_vectors_file: str,
    query_vectors_file: str,
    build_complexity: int,
    graph_degree: int,
    K: int,
    search_complexity: int,
    num_threads: int,
    gt_file: str = "",
    index_prefix: str = "ann"
):
    _dtype, vectors_to_query = _basic_setup(dtype, query_vectors_file)

    timer = Timer()
    with timer.time("build static index"):
        # build index
        dap.build_memory_index(
            data=index_vectors_file,
            metric="l2",
            vector_dtype=_dtype,
            index_directory=index_directory,
            complexity=build_complexity,
            graph_degree=graph_degree,
            num_threads=num_threads,
            index_prefix=index_prefix,
            alpha=1.2,
            use_pq_build=False,
            num_pq_bytes=8,
            use_opq=False,
        )

    with timer.time("load static index"):
        # ready search object
        index = dap.StaticMemoryIndex(
            metric="l2",
            vector_dtype=_dtype,
            data_path=index_vectors_file,
            index_directory=index_directory,
            num_threads=num_threads,  # this can be different at search time if you would like
            initial_search_complexity=search_complexity,
            index_prefix=index_prefix
        )

    ids, dists = index.batch_search(vectors_to_query, K, search_complexity, num_threads)

    # if gt_file != "":
    #     recall = utils.calculate_recall_from_gt_file(K, ids, gt_file)
    #     print(f"recall@{K} is {recall}")


def dynamic_clustered():
    pass


def generate_clusters():
    pass


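# Timer reports per-operation wall-clock time via a context manager, plus the
# cumulative time since the first timed operation.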
class Timer:
    def __init__(self):
        self._start = -1

    @contextmanager
    def time(self, message: str):
        start = perf_counter()
        if self._start == -1:
            self._start = start
        yield
        now = perf_counter()
        print(f"Operation {message} completed in {(now - start):.3f}s, total: {(now - self._start):.3f}s")


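# fire.Fire maps each sub-command name to its function and exposes that
# function's parameters as command-line flags (e.g. --dtype, --graph_degree).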
if __name__ == "__main__":
|
||||
fire.Fire({
|
||||
"in-mem-dynamic": dynamic,
|
||||
"in-mem-static": static,
|
||||
"in-mem-dynamic-clustered": dynamic_clustered,
|
||||
"generate-clusters": generate_clusters
|
||||
}, name="cli")