Initial commit
This commit is contained in:
77
packages/leann-backend-hnsw/third_party/faiss/demos/demo_qinco.py
vendored
Normal file
77
packages/leann-backend-hnsw/third_party/faiss/demos/demo_qinco.py
vendored
Normal file
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
This demonstrates how to reproduce the QINCo paper results using the Faiss
|
||||
QINCo implementation. The code loads the reference model because training
|
||||
is not implemented in Faiss.
|
||||
|
||||
Prepare the data with
|
||||
|
||||
cd /tmp
|
||||
|
||||
# get the reference qinco code
|
||||
git clone https://github.com/facebookresearch/Qinco.git
|
||||
|
||||
# get the data
|
||||
wget https://dl.fbaipublicfiles.com/QINCo/datasets/bigann/bigann1M.bvecs
|
||||
|
||||
# get the model
|
||||
wget https://dl.fbaipublicfiles.com/QINCo/models/bigann_8x8_L2.pt
|
||||
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from faiss.contrib.vecs_io import bvecs_mmap
|
||||
import sys
|
||||
import time
|
||||
import torch
|
||||
import faiss
|
||||
|
||||
# make sure pickle deserialization will work
|
||||
sys.path.append("/tmp/Qinco")
|
||||
import model_qinco
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
qinco = torch.load("/tmp/bigann_8x8_L2.pt", weights_only=False)
|
||||
qinco.eval()
|
||||
# print(qinco)
|
||||
if True:
|
||||
torch.set_num_threads(1)
|
||||
faiss.omp_set_num_threads(1)
|
||||
|
||||
x_base = bvecs_mmap("/tmp/bigann1M.bvecs")[:1000].astype('float32')
|
||||
x_scaled = torch.from_numpy(x_base) / qinco.db_scale
|
||||
|
||||
t0 = time.time()
|
||||
codes, _ = qinco.encode(x_scaled)
|
||||
x_decoded_scaled = qinco.decode(codes)
|
||||
print(f"Pytorch encode {time.time() - t0:.3f} s")
|
||||
# multi-thread: 1.13s, single-thread: 7.744
|
||||
|
||||
x_decoded = x_decoded_scaled.numpy() * qinco.db_scale
|
||||
|
||||
err = ((x_decoded - x_base) ** 2).sum(1).mean()
|
||||
print("MSE=", err) # = 14211.956, near the L=2 result in Fig 4 of the paper
|
||||
|
||||
qinco2 = faiss.QINCo(qinco)
|
||||
t0 = time.time()
|
||||
codes2 = qinco2.encode(faiss.Tensor2D(x_scaled))
|
||||
x_decoded2 = qinco2.decode(codes2).numpy() * qinco.db_scale
|
||||
print(f"Faiss encode {time.time() - t0:.3f} s")
|
||||
# multi-thread: 3.2s, single thread: 7.019
|
||||
|
||||
# these tests don't work because there are outlier encodings
|
||||
# np.testing.assert_array_equal(codes.numpy(), codes2.numpy())
|
||||
# np.testing.assert_allclose(x_decoded, x_decoded2)
|
||||
|
||||
ndiff = (codes.numpy() != codes2.numpy()).sum() / codes.numel()
|
||||
assert ndiff < 0.01
|
||||
ndiff = (((x_decoded - x_decoded2) ** 2).sum(1) > 1e-5).sum()
|
||||
assert ndiff / len(x_base) < 0.01
|
||||
|
||||
err = ((x_decoded2 - x_base) ** 2).sum(1).mean()
|
||||
print("MSE=", err) # = 14213.551
|
||||
Reference in New Issue
Block a user