Initial commit

packages/leann-backend-hnsw/third_party/faiss/demo/simple_build.py (vendored, new file, 329 lines)

@@ -0,0 +1,329 @@
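"""Build and evaluate an HNSW index over precomputed passage embeddings.

The script loads pickled database embeddings, encodes evaluation queries with a
Contriever-style retriever, builds (or loads) a Faiss IndexHNSWFlat with the
inner-product metric, and sweeps efSearch while recording search time, recall
against an exact IndexFlatIP ground truth, and distance computations per query.

Example invocation (the paths shown are the script defaults; adjust them to your setup):

    python simple_build.py --M 32 --efConstruction 256 --K_NEIGHBORS 3 \
        --task_name nq --embedder_model facebook/contriever-msmarco \
        --db_embedding_file /powerrag/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki_1M/1-shards/passages_00.pkl \
        --index_saving_dir /powerrag/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki_1M/1-shards/indices
"""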
import sys
import time
import faiss
import numpy as np
import pickle
import os
import json
import torch
import argparse
from tqdm import tqdm
from pathlib import Path
import subprocess

# Add argument parsing
parser = argparse.ArgumentParser(description='Build and evaluate HNSW index')
parser.add_argument('--config', type=str, default="0.02per_M6-7_degree_based",
                    help='Configuration name for the index (default: 0.02per_M6-7_degree_based)')
parser.add_argument('--M', type=int, default=32,
                    help='HNSW M parameter (default: 32)')
parser.add_argument('--efConstruction', type=int, default=256,
                    help='HNSW efConstruction parameter (default: 256)')
parser.add_argument('--K_NEIGHBORS', type=int, default=3,
                    help='Number of neighbors to retrieve (default: 3)')
parser.add_argument('--max_queries', type=int, default=1000,
                    help='Maximum number of queries to load (default: 1000)')
parser.add_argument('--batch_size', type=int, default=64,
                    help='Batch size for query encoding (default: 64)')
parser.add_argument('--db_embedding_file', type=str,
                    default="/powerrag/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki_1M/1-shards/passages_00.pkl",
                    help='Path to database embedding file')
parser.add_argument('--index_saving_dir', type=str,
                    default="/powerrag/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki_1M/1-shards/indices",
                    help='Directory to save the index')
parser.add_argument('--task_name', type=str, default="nq",
                    help='Task name from TASK_CONFIGS (default: nq)')
parser.add_argument('--embedder_model', type=str, default="facebook/contriever-msmarco",
                    help='Model name for query embedding (default: facebook/contriever-msmarco)')

args = parser.parse_args()

# Map the parsed arguments onto module-level constants
M = args.M
efConstruction = args.efConstruction
K_NEIGHBORS = args.K_NEIGHBORS
DB_EMBEDDING_FILE = args.db_embedding_file
INDEX_SAVING_FILE = args.index_saving_dir
TASK_NAME = args.task_name
EMBEDDER_MODEL_NAME = args.embedder_model
MAX_QUERIES_TO_LOAD = args.max_queries
QUERY_ENCODING_BATCH_SIZE = args.batch_size
CONFIG_NAME = args.config

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
sys.path.append(os.path.join(project_root, "demo"))
from config import SCALING_OUT_DIR, get_example_path, TASK_CONFIGS
sys.path.append(project_root)
from contriever.src.contriever import Contriever, load_retriever
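# NOTE: the two sys.path appends assume this script sits two directories below the
# project root, with config.py under <project_root>/demo and a contriever package at
# <project_root>; adjust the paths above if your checkout is laid out differently.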

# 1M samples
print(f"Loading embeddings from {DB_EMBEDDING_FILE}...")
with open(DB_EMBEDDING_FILE, 'rb') as f:
    data = pickle.load(f)

xb = data[1]
print(f"Original dtype: {xb.dtype}")

if xb.dtype != np.float32:
    print("Converting embeddings to float32.")
    xb = xb.astype(np.float32)
else:
    print("Embeddings are already float32.")
print(f"Loaded database embeddings (xb), shape: {xb.shape}")
d = xb.shape[1]  # Get dimension

query_file_path = TASK_CONFIGS[TASK_NAME].query_path
print(f"Using query path from TASK_CONFIGS: {query_file_path}")

query_texts = []
print(f"Reading queries from: {query_file_path}")
with open(query_file_path, 'r') as f:
    for i, line in enumerate(f):
        if i >= MAX_QUERIES_TO_LOAD:
            print(f"Stopped loading queries at limit: {MAX_QUERIES_TO_LOAD}")
            break
        record = json.loads(line)
        query_texts.append(record["query"])
print(f"Loaded {len(query_texts)} query texts.")

print("\nInitializing retriever model for encoding queries...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model, tokenizer, _ = load_retriever(EMBEDDER_MODEL_NAME)
model.to(device)
model.eval()  # Set to evaluation mode
print("Retriever model loaded.")


def embed_queries(queries, model, tokenizer, model_name_or_path, per_gpu_batch_size=64):
    """Embed queries using the model with batching"""
    model = model.half()
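    # NOTE: casting the encoder to float16 assumes a CUDA device; on CPU, half
    # precision can be unsupported or much slower.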
    model.eval()
    embeddings = []
    batch_question = []

    with torch.no_grad():
        for k, query in tqdm(enumerate(queries), desc="Encoding queries"):
            batch_question.append(query)

            # Process when batch is full or at the end
            if len(batch_question) == per_gpu_batch_size or k == len(queries) - 1:
                encoded_batch = tokenizer.batch_encode_plus(
                    batch_question,
                    return_tensors="pt",
                    max_length=512,
                    padding=True,
                    truncation=True,
                )

                encoded_batch = {k: v.to(device) for k, v in encoded_batch.items()}
                output = model(**encoded_batch)

                # Contriever's forward pass already returns pooled sentence embeddings;
                # for other models, fall back to the [CLS] token embedding.
                if "contriever" not in model_name_or_path:
                    output = output.last_hidden_state[:, 0, :]

                embeddings.append(output.cpu())
                batch_question = []  # Reset batch

    embeddings = torch.cat(embeddings, dim=0).numpy()
    print(f"Query embeddings shape: {embeddings.shape}")
    return embeddings


print(f"\nEncoding {len(query_texts)} queries (batch size: {QUERY_ENCODING_BATCH_SIZE})...")
xq_full = embed_queries(query_texts, model, tokenizer, EMBEDDER_MODEL_NAME, per_gpu_batch_size=QUERY_ENCODING_BATCH_SIZE)

# Ensure float32 for Faiss compatibility after encoding
if xq_full.dtype != np.float32:
    print(f"Converting encoded queries from {xq_full.dtype} to float32.")
    xq_full = xq_full.astype(np.float32)

print(f"Encoded queries (xq_full), shape: {xq_full.shape}, dtype: {xq_full.dtype}")

# Check dimension consistency
if xq_full.shape[1] != d:
    raise ValueError(f"Query embedding dimension ({xq_full.shape[1]}) does not match database dimension ({d})")

# recall_idx = []

print("\nBuilding FlatIP index for ground truth...")
index_flat = faiss.IndexFlatIP(d)  # Use Inner Product
index_flat.add(xb)
print(f"Searching FlatIP index with {MAX_QUERIES_TO_LOAD} queries (k={K_NEIGHBORS})...")
D_flat, recall_idx_flat = index_flat.search(xq_full, k=K_NEIGHBORS)
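# recall_idx_flat now holds the exact top-K inner-product neighbors for every query;
# the HNSW recall figures below are measured against it.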

# print(recall_idx_flat)

# Create a specific directory for this index configuration
index_dir = f"{INDEX_SAVING_FILE}/{CONFIG_NAME}_hnsw_IP_M{M}_efC{efConstruction}"
if CONFIG_NAME == "origin":
    index_dir = f"{INDEX_SAVING_FILE}/hnsw_IP_M{M}_efC{efConstruction}"
os.makedirs(index_dir, exist_ok=True)
index_filename = f"{index_dir}/index.faiss"

# Check if index already exists
if os.path.exists(index_filename):
    print(f"Found existing index at {index_filename}, loading...")
    index = faiss.read_index(index_filename)
    print("Index loaded successfully.")
else:
    print('Building HNSW index (IP)...')
    # Time the index build
    start_time = time.time()
    index = faiss.IndexHNSWFlat(d, M, faiss.METRIC_INNER_PRODUCT)
    index.hnsw.efConstruction = efConstruction
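    # NOTE: set_percentile_thresholds(), save_degree_distribution(), and save_edge_stats()
    # used below appear to be extensions provided by this vendored Faiss fork rather than
    # upstream Faiss APIs.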
    index.hnsw.set_percentile_thresholds()
    index.add(xb)
    end_time = time.time()
    print(f'time: {end_time - start_time}')
    print('HNSW index built.')

    # Save the HNSW index
    print(f"Saving index to {index_filename}...")
    faiss.write_index(index, index_filename)
    print("Index saved successfully.")

# Analyze the HNSW index
print("\nAnalyzing HNSW index...")
print(f"Total number of nodes: {index.ntotal}")
print("Neighbor statistics:")
print(index.hnsw.print_neighbor_stats(0))

# Save degree distribution
distribution_filename = f"{index_dir}/degree_distribution.txt"
print(f"Saving degree distribution to {distribution_filename}...")
index.hnsw.save_degree_distribution(0, distribution_filename)
print("Degree distribution saved successfully.")

# Plot the degree distribution
plot_output_path = f"{index_dir}/degree_distribution.png"
print(f"Generating degree distribution plot to {plot_output_path}...")
try:
    subprocess.run(
        ["python", "/home/ubuntu/Power-RAG/utils/plot_degree_distribution.py", distribution_filename, "-o", plot_output_path],
        check=True
    )
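    # The plotting script path above is machine-specific (an absolute Power-RAG checkout
    # path); point it at your own copy of plot_degree_distribution.py if needed.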
    print(f"Degree distribution plot saved to {plot_output_path}")
except subprocess.CalledProcessError as e:
    print(f"Error generating degree distribution plot: {e}")
except FileNotFoundError:
    print("Warning: could not run plot_degree_distribution.py (interpreter or script path not found)")

print('Searching HNSW index...')


# for efSearch in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
#     print(f'*************efSearch: {efSearch}*************')
#     for i in range(10):
#         index.hnsw.efSearch = efSearch
#         D, I = index.search(xq_full[i:i+1], K_NEIGHBORS)
# exit()


recall_result_file = f"{index_dir}/recall_result.txt"
time_list = []
recall_list = []
recompute_list = []
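# Sweep efSearch and record, for each setting: wall-clock search time, recall@K against
# the flat-index ground truth, and the average number of distance computations per query.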
with open(recall_result_file, 'w') as f:
    for efSearch in [2, 4, 8, 16, 24, 32, 48, 64, 96, 114, 128, 144, 160, 176, 192, 208,
                     224, 240, 256, 384, 420, 440, 460, 480, 512, 768, 1024, 1152, 1536,
                     1792, 2048, 2230, 2408, 2880]:
        index.hnsw.efSearch = efSearch
        # Time the search at this efSearch setting
        start_time = time.time()
        faiss.cvar.hnsw_stats.reset()
        # print faiss.cvar.hnsw_stats.ndis
        print(f'ndis: {faiss.cvar.hnsw_stats.ndis}')
        D, I = index.search(xq_full, K_NEIGHBORS)
        print('D[0]:', D[0])
        end_time = time.time()
        print(f'time: {end_time - start_time}')
        time_list.append(end_time - start_time)
        print("recompute:", faiss.cvar.hnsw_stats.ndis/len(I))
        recompute_list.append(faiss.cvar.hnsw_stats.ndis/len(I))
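        # "recompute" here is hnsw_stats.ndis / number of queries, i.e. the average
        # number of distance computations performed per query at this efSearch.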
        # print(I)

        # Recall@K against the flat ground truth: for each query, count how many of the
        # K HNSW results also appear in the exact top-K (recall_idx_flat), then average.
        recall = []
        for i in range(len(I)):
            acc = 0
            for j in range(len(I[i])):
                if I[i][j] in recall_idx_flat[i]:
                    acc += 1
            recall.append(acc / len(I[i]))
        recall = sum(recall) / len(recall)
        recall_list.append(recall)
        print(f'efSearch: {efSearch}')
        print(f'recall: {recall}')
        f.write(f'efSearch: {efSearch}, recall: {recall}\n')

print(f'Done and result saved to {recall_result_file}')
print(f'time_list: {time_list}')
print(f'recall_list: {recall_list}')
print(f'recompute_list: {recompute_list}')
exit()
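
# NOTE: exit() above terminates the script here, so the edge-statistics analysis below
# only runs if that call is removed.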

# Analyze edge stats
print("\nAnalyzing edge statistics...")
edge_stats_file = f"{index_dir}/edge_stats.txt"
if not os.path.exists(edge_stats_file):
    index.save_edge_stats(edge_stats_file)
    print(f'Edge stats saved to {edge_stats_file}')
else:
    print(f'Edge stats already exist at {edge_stats_file}')


def analyze_edge_stats(filename):
    """
    Analyze edge statistics from a CSV file and print thresholds at various percentiles.

    Args:
        filename: Path to the edge statistics CSV file
    """
    if not os.path.exists(filename):
        print(f"Error: File {filename} does not exist")
        return

    print(f"Analyzing edge statistics from {filename}...")

    # Read the file
    distances = []
    with open(filename, 'r') as f:
        # Skip header
        header = f.readline()

        # Read all edges
        for line in f:
            parts = line.strip().split(',')
            if len(parts) >= 4:
                try:
                    src = int(parts[0])
                    dst = int(parts[1])
                    level = int(parts[2])
                    distance = float(parts[3])
                    distances.append(distance)
                except ValueError:
                    continue

    if not distances:
        print("No valid edges found in file")
        return

    # Sort distances
    distances = np.array(distances)
    distances.sort()

    # Calculate and print statistics
    print(f"Total edges: {len(distances)}")
    print(f"Min distance: {distances[0]:.6f}")
    print(f"Max distance: {distances[-1]:.6f}")

    # Print thresholds at specified percentiles
    percentiles = [0.5, 1, 2, 3, 5, 8, 10, 15, 20, 30, 40, 50, 60, 70]
    print("\nDistance thresholds at percentiles:")
    for p in percentiles:
        idx = int(len(distances) * p / 100)
        if idx < len(distances):
            print(f"{p:.1f}%: {distances[idx]:.6f}")

    return distances
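
# (np.percentile(distances, percentiles) would give comparable thresholds directly,
# up to the interpolation rule.)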

analyze_edge_stats(edge_stats_file)