# File: LEANN/packages/leann-backend-hnsw/third_party/faiss/demo/simple_build_dpr.py
# Commit: 46f6cc100b "Initial commit" (yichuan520030910320), 2025-06-30 09:05:05 +00:00
# Original size: 212 lines, 7.6 KiB (Python)

import sys
import time
import faiss
import numpy as np
import pickle
import os
import json
import time
import torch
from tqdm import tqdm
from pathlib import Path
import subprocess
# Make the project's `demo/` directory and the project root importable so the
# local `config` module and the `contriever` package can be found.
# Order matters: the `config` import comes right after `demo/` is appended,
# and the `contriever` import comes right after the root is appended.
# NOTE(review): project_root is two directories above this file's directory;
# confirm that is where `demo/config.py` and `contriever/` actually live.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
sys.path.append(os.path.join(project_root, "demo"))
from config import SCALING_OUT_DIR, get_example_path, TASK_CONFIGS, get_embedding_path
sys.path.append(project_root)
from contriever.src.contriever import Contriever, load_retriever
# --- HNSW build / evaluation hyperparameters ---------------------------------
M = 32                 # `M` argument to faiss.IndexHNSWFlat (graph connectivity)
efConstruction = 256   # assigned to index.hnsw.efConstruction before building
K_NEIGHBORS = 3        # neighbors retrieved per query (ground truth and HNSW)

# --- Dataset / model selection -----------------------------------------------
# Earlier setup (rpj_wiki, hard-coded paths), kept for reference:
#   DB_EMBEDDING_FILE = "/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/passages_00.pkl"
#   INDEX_SAVING_FILE = "/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices"
# The current setup uses DPR and resolves paths through config instead.
DOMAIN_NAME, EMBEDDER_NAME, TASK_NAME = "dpr", "facebook/contriever-msmarco", "nq"
MAX_QUERIES_TO_LOAD = 1000        # cap on queries read from the task's query file
QUERY_ENCODING_BATCH_SIZE = 64    # queries per forward pass when encoding

# The trailing 0 presumably selects the first embedding shard -- confirm in
# config.get_embedding_path. Indices go in an "indices" directory next to that
# file (despite its name, INDEX_SAVING_FILE is a directory).
embed_path = get_embedding_path(DOMAIN_NAME, EMBEDDER_NAME, 0)
INDEX_SAVING_FILE = os.path.join(os.path.split(embed_path)[0], "indices")
os.makedirs(INDEX_SAVING_FILE, exist_ok=True)
# Load the database (passage) embeddings into `xb` as float32 and record the
# vector dimension in `d`.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point embed_path at trusted, locally produced data.
# The unpickled object is indexed with [1]; presumably it is an (ids, embeddings)
# pair -- confirm against the script that writes these files.
print(f"Loading embeddings from {embed_path}...")
with open(embed_path, 'rb') as f:
    data = pickle.load(f)
xb = data[1]
print(f"Original dtype: {xb.dtype}")
if xb.dtype != np.float32:
    print("Converting embeddings to float32.")
    xb = xb.astype(np.float32)
else:
    print("Embeddings are already float32.")
# Drop the module-level reference to the unpickled container. Otherwise it
# (including the pre-conversion array when astype() made a copy) stays alive for
# the whole run, on top of the flat and HNSW indices built below.
del data
print(f"Loaded database embeddings (xb), shape: {xb.shape}")
d = xb.shape[1]  # vector dimension shared by database and queries
# Read up to MAX_QUERIES_TO_LOAD query strings from the task's query file.
# The file is parsed line by line with json.loads, and each record's "query"
# field is kept.
query_file_path = TASK_CONFIGS[TASK_NAME].query_path
print(f"Using query path from TASK_CONFIGS: {query_file_path}")
query_texts = []
print(f"Reading queries from: {query_file_path}")
# JSON text is UTF-8; don't depend on the platform's locale encoding.
with open(query_file_path, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            # Tolerate stray empty lines instead of crashing json.loads.
            continue
        if len(query_texts) >= MAX_QUERIES_TO_LOAD:
            print(f"Stopped loading queries at limit: {MAX_QUERIES_TO_LOAD}")
            break
        record = json.loads(line)
        query_texts.append(record["query"])
print(f"Loaded {len(query_texts)} query texts.")
print("\nInitializing retriever model for encoding queries...")
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = "cpu" if not torch.cuda.is_available() else "cuda"
print(f"Using device: {device}")
# load_retriever returns a 3-tuple; only the model and tokenizer are used here.
model, tokenizer, _ = load_retriever(EMBEDDER_NAME)
# nn.Module.to() returns the module itself, so this moves it and switches it
# to inference mode (e.g. disables dropout) in one expression.
model.to(device).eval()
print("Retriever model loaded.")
def embed_queries(queries, model, tokenizer, model_name_or_path, per_gpu_batch_size=64,
                  max_length=512, device=None):
    """Encode query strings into a NumPy embedding matrix, in batches.

    Args:
        queries: Sequence of query strings; must be non-empty.
        model: Retriever model. When ``model_name_or_path`` contains
            "contriever", the model's forward output is used as the embedding
            as-is (presumably already pooled). Otherwise the first token of
            ``output.last_hidden_state`` (CLS-style pooling) is used.
        tokenizer: Tokenizer exposing ``batch_encode_plus``.
        model_name_or_path: Name used only to choose the pooling rule above.
        per_gpu_batch_size: Number of queries per forward pass (>= 1).
        max_length: Tokenizer truncation length (512 by default, as before).
        device: Device to move tokenized inputs to. Defaults to the device of
            the model's first parameter (CPU if the model has none).

    Returns:
        np.ndarray with one row per query. On CUDA the model is cast to half
        precision, so rows are float16. Otherwise they use the model's existing
        dtype.

    Raises:
        ValueError: If ``queries`` is empty or ``per_gpu_batch_size`` < 1.

    Note:
        On CUDA the passed-in model is converted to half precision in place
        (``nn.Module.half`` mutates the module).
    """
    queries = list(queries)
    if not queries:
        raise ValueError("embed_queries() needs at least one query")
    if per_gpu_batch_size < 1:
        raise ValueError(f"per_gpu_batch_size must be >= 1, got {per_gpu_batch_size}")

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    device = torch.device(device)

    # fp16 saves memory and speeds up GPU inference, but many CPU kernels
    # (e.g. LayerNorm on older torch builds) lack a Half implementation, so
    # only downcast on CUDA.
    if device.type == "cuda":
        model = model.half()
    model.eval()

    batch_outputs = []
    with torch.no_grad():
        for start in tqdm(range(0, len(queries), per_gpu_batch_size), desc="Encoding queries"):
            batch_question = queries[start:start + per_gpu_batch_size]
            encoded_batch = tokenizer.batch_encode_plus(
                batch_question,
                return_tensors="pt",
                max_length=max_length,
                padding=True,
                truncation=True,
            )
            encoded_batch = {name: tensor.to(device) for name, tensor in encoded_batch.items()}
            output = model(**encoded_batch)
            if "contriever" not in model_name_or_path:
                output = output.last_hidden_state[:, 0, :]
            batch_outputs.append(output.cpu())

    embeddings = torch.cat(batch_outputs, dim=0).numpy()
    print(f"Query embeddings shape: {embeddings.shape}")
    return embeddings
print(f"\nEncoding {len(query_texts)} queries (batch size: {QUERY_ENCODING_BATCH_SIZE})...")
xq_full = embed_queries(
    query_texts,
    model,
    tokenizer,
    EMBEDDER_NAME,
    per_gpu_batch_size=QUERY_ENCODING_BATCH_SIZE,
)
# Faiss's Python API works on float32; upcast (e.g. from fp16) when needed.
if xq_full.dtype != np.float32:
    print(f"Converting encoded queries from {xq_full.dtype} to float32.")
    xq_full = np.asarray(xq_full, dtype=np.float32)
print(f"Encoded queries (xq_full), shape: {xq_full.shape}, dtype: {xq_full.dtype}")
# Queries and database vectors must live in the same space to be compared.
if d != xq_full.shape[1]:
    raise ValueError(
        f"Query embedding dimension ({xq_full.shape[1]}) does not match database dimension ({d})"
    )
# Exact (brute-force) inner-product search over every database vector. Its
# top-K ids are the reference that the HNSW recall below is scored against.
print("\nBuilding FlatIP index for ground truth...")
index_flat = faiss.IndexFlatIP(d)
index_flat.add(xb)
print(f"Searching FlatIP index with {len(xq_full)} queries (k={K_NEIGHBORS})...")
D_flat, recall_idx_flat = index_flat.search(xq_full, K_NEIGHBORS)
# Load the cached HNSW (inner-product) index for this M/efConstruction, or
# build and save it.
# NOTE(review): the "hahah" prefix looks accidental, but it is part of the
# on-disk cache path; renaming it would orphan indices that are already built.
index_dir = f"{INDEX_SAVING_FILE}/hahahdpr_hnsw_IP_M{M}_efC{efConstruction}"
os.makedirs(index_dir, exist_ok=True)
index_filename = f"{index_dir}/index.faiss"

index = None
if os.path.exists(index_filename):
    print(f"Found existing index at {index_filename}, loading...")
    index = faiss.read_index(index_filename)
    print("Index loaded successfully.")
    # The cache path encodes only M and efConstruction. If the embeddings have
    # since changed, a stale index would silently produce meaningless recall.
    if index.d != d or index.ntotal != xb.shape[0]:
        print(
            f"Cached index (d={index.d}, ntotal={index.ntotal}) does not match current "
            f"embeddings (d={d}, ntotal={xb.shape[0]}); rebuilding."
        )
        index = None

if index is None:
    print('Building HNSW index (IP)...')
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    start_time = time.perf_counter()
    index = faiss.IndexHNSWFlat(d, M, faiss.METRIC_INNER_PRODUCT)
    index.hnsw.efConstruction = efConstruction
    index.add(xb)
    end_time = time.perf_counter()
    print(f'time: {end_time - start_time}')
    print('HNSW index built.')
    print(f"Saving index to {index_filename}...")
    faiss.write_index(index, index_filename)
    print("Index saved successfully.")
# Report graph statistics and dump the degree distribution for level 0 of
# the HNSW graph.
print("\nAnalyzing HNSW index...")
print(f"Total number of nodes: {index.ntotal}")
print("Neighbor statistics:")
# HNSW::print_neighbor_stats prints from native code and returns None, so
# wrapping it in print() only added a stray "None". Flush Python's buffer
# first so the header above is emitted before the native output.
sys.stdout.flush()
index.hnsw.print_neighbor_stats(0)
distribution_filename = f"{index_dir}/degree_distribution.txt"
print(f"Saving degree distribution to {distribution_filename}...")
# NOTE(review): save_degree_distribution is not part of upstream Faiss; it is
# presumably provided by this project's Faiss fork (level 0 as well).
index.hnsw.save_degree_distribution(0, distribution_filename)
print("Degree distribution saved successfully.")
# Render the degree distribution with the project's plotting helper. This step
# is best-effort: failures are reported, never raised.
plot_output_path = f"{index_dir}/degree_distribution.png"
plot_script = os.path.join(project_root, "utils", "plot_degree_distribution.py")
print(f"Generating degree distribution plot to {plot_output_path}...")
if not os.path.isfile(plot_script):
    # A missing script used to surface as CalledProcessError (the interpreter
    # exits non-zero), never as FileNotFoundError, so check for it explicitly.
    print("Warning: plot_degree_distribution.py script not found in specified path")
else:
    try:
        # Run the helper with the same interpreter as this script rather than
        # whatever "python" resolves to on PATH.
        subprocess.run(
            [sys.executable or "python", plot_script, distribution_filename, "-o", plot_output_path],
            check=True,
        )
        print(f"Degree distribution plot saved to {plot_output_path}")
    except subprocess.CalledProcessError as e:
        print(f"Error generating degree distribution plot: {e}")
    except FileNotFoundError as e:
        # Raised by subprocess when the *interpreter* cannot be launched.
        print(f"Warning: could not launch Python interpreter for plotting: {e}")
# Sweep efSearch (the HNSW search-time candidate-list size). For each value,
# print the search time and write recall@K_NEIGHBORS against the exact
# FlatIP ground truth to recall_result.txt.
EF_SEARCH_VALUES = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
print('Searching HNSW index...')
recall_result_file = f"{index_dir}/recall_result.txt"
with open(recall_result_file, 'w', encoding='utf-8') as f:
    for efSearch in EF_SEARCH_VALUES:
        index.hnsw.efSearch = efSearch
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        D, I = index.search(xq_full, K_NEIGHBORS)
        end_time = time.perf_counter()
        print(f'time: {end_time - start_time}')
        # Per query: fraction of HNSW ids that also appear in that query's
        # exact top-K; then average over queries. Broadcasting compares every
        # HNSW id with every ground-truth id of the same query in one shot.
        hits = (I[:, :, None] == recall_idx_flat[:, None, :]).any(axis=2)
        recall = float(hits.mean(axis=1).mean())
        print(f'efSearch: {efSearch}')
        print(f'recall: {recall}')
        f.write(f'efSearch: {efSearch}, recall: {recall}\n')
print(f'Done and result saved to {recall_result_file}')