Initial commit
This commit is contained in:
82
packages/leann-backend-hnsw/third_party/faiss/benchs/bench_scalar_quantizer.py
vendored
Normal file
82
packages/leann-backend-hnsw/third_party/faiss/benchs/bench_scalar_quantizer.py
vendored
Normal file
@@ -0,0 +1,82 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import time
|
||||
import numpy as np
|
||||
import faiss
|
||||
from datasets import load_sift1M
|
||||
|
||||
|
||||
print("load data")
|
||||
|
||||
xb, xq, xt, gt = load_sift1M()
|
||||
nq, d = xq.shape
|
||||
|
||||
ncent = 256
|
||||
|
||||
variants = [(name, getattr(faiss.ScalarQuantizer, name))
|
||||
for name in dir(faiss.ScalarQuantizer)
|
||||
if name.startswith('QT_')]
|
||||
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
# quantizer.add(np.zeros((1, d), dtype='float32'))
|
||||
|
||||
if False:
|
||||
for name, qtype in [('flat', 0)] + variants:
|
||||
|
||||
print("============== test", name)
|
||||
t0 = time.time()
|
||||
|
||||
if name == 'flat':
|
||||
index = faiss.IndexIVFFlat(quantizer, d, ncent,
|
||||
faiss.METRIC_L2)
|
||||
else:
|
||||
index = faiss.IndexIVFScalarQuantizer(quantizer, d, ncent,
|
||||
qtype, faiss.METRIC_L2)
|
||||
|
||||
index.nprobe = 16
|
||||
print("[%.3f s] train" % (time.time() - t0))
|
||||
index.train(xt)
|
||||
print("[%.3f s] add" % (time.time() - t0))
|
||||
index.add(xb)
|
||||
print("[%.3f s] search" % (time.time() - t0))
|
||||
D, I = index.search(xq, 100)
|
||||
print("[%.3f s] eval" % (time.time() - t0))
|
||||
|
||||
for rank in 1, 10, 100:
|
||||
n_ok = (I[:, :rank] == gt[:, :1]).sum()
|
||||
print("%.4f" % (n_ok / float(nq)), end=' ')
|
||||
print()
|
||||
|
||||
if True:
|
||||
for name, qtype in variants:
|
||||
|
||||
print("============== test", name)
|
||||
|
||||
for rsname, vals in [('RS_minmax',
|
||||
[-0.4, -0.2, -0.1, -0.05, 0.0, 0.1, 0.5]),
|
||||
('RS_meanstd', [0.8, 1.0, 1.5, 2.0, 3.0, 5.0, 10.0]),
|
||||
('RS_quantiles', [0.02, 0.05, 0.1, 0.15]),
|
||||
('RS_optim', [0.0])]:
|
||||
for val in vals:
|
||||
print("%-15s %5g " % (rsname, val), end=' ')
|
||||
index = faiss.IndexIVFScalarQuantizer(quantizer, d, ncent,
|
||||
qtype, faiss.METRIC_L2)
|
||||
index.nprobe = 16
|
||||
index.sq.rangestat = getattr(faiss.ScalarQuantizer,
|
||||
rsname)
|
||||
|
||||
index.rangestat_arg = val
|
||||
|
||||
index.train(xt)
|
||||
index.add(xb)
|
||||
t0 = time.time()
|
||||
D, I = index.search(xq, 100)
|
||||
t1 = time.time()
|
||||
|
||||
for rank in 1, 10, 100:
|
||||
n_ok = (I[:, :rank] == gt[:, :1]).sum()
|
||||
print("%.4f" % (n_ok / float(nq)), end=' ')
|
||||
print(" %.3f s" % (t1 - t0))
|
||||
Reference in New Issue
Block a user