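"""Build a FAISS HNSW index over passage embeddings and evaluate it.

The script loads pre-computed passage embeddings, encodes a set of task
queries with a Contriever-style retriever, builds (or loads) an HNSW
inner-product index, and sweeps efSearch to measure recall against exact
FlatIP search along with search time and distance computations per query.
"""
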
import argparse
import json
import os
import pickle
import subprocess
import sys
import time

import faiss
import numpy as np
import torch
from tqdm import tqdm

# Command-line arguments
parser = argparse.ArgumentParser(description='Build and evaluate HNSW index')
parser.add_argument('--config', type=str, default="0.02per_M6-7_degree_based",
                    help='Configuration name for the index (default: 0.02per_M6-7_degree_based)')
parser.add_argument('--M', type=int, default=32,
                    help='HNSW M parameter (default: 32)')
parser.add_argument('--efConstruction', type=int, default=256,
                    help='HNSW efConstruction parameter (default: 256)')
parser.add_argument('--K_NEIGHBORS', type=int, default=3,
                    help='Number of neighbors to retrieve (default: 3)')
parser.add_argument('--max_queries', type=int, default=1000,
                    help='Maximum number of queries to load (default: 1000)')
parser.add_argument('--batch_size', type=int, default=64,
                    help='Batch size for query encoding (default: 64)')
parser.add_argument('--db_embedding_file', type=str,
                    default="/powerrag/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki_1M/1-shards/passages_00.pkl",
                    help='Path to database embedding file')
parser.add_argument('--index_saving_dir', type=str,
                    default="/powerrag/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki_1M/1-shards/indices",
                    help='Directory to save the index')
parser.add_argument('--task_name', type=str, default="nq",
                    help='Task name from TASK_CONFIGS (default: nq)')
parser.add_argument('--embedder_model', type=str, default="facebook/contriever-msmarco",
                    help='Model name for query embedding (default: facebook/contriever-msmarco)')

args = parser.parse_args()
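
# Example invocation (script name is a placeholder; values shown are the defaults):
#   python build_hnsw_index.py --task_name nq --M 32 --efConstruction 256 \
#       --K_NEIGHBORS 3 --max_queries 1000 --batch_size 64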

# Map parsed arguments to module-level constants
M = args.M
efConstruction = args.efConstruction
K_NEIGHBORS = args.K_NEIGHBORS
DB_EMBEDDING_FILE = args.db_embedding_file
INDEX_SAVING_DIR = args.index_saving_dir  # directory, not a file
TASK_NAME = args.task_name
EMBEDDER_MODEL_NAME = args.embedder_model
MAX_QUERIES_TO_LOAD = args.max_queries
QUERY_ENCODING_BATCH_SIZE = args.batch_size
CONFIG_NAME = args.config

# Make the project's demo config and the contriever package importable
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
sys.path.append(os.path.join(project_root, "demo"))
from config import SCALING_OUT_DIR, get_example_path, TASK_CONFIGS
sys.path.append(project_root)
from contriever.src.contriever import Contriever, load_retriever

# Load database passage embeddings (the ~1M-passage shard)
print(f"Loading embeddings from {DB_EMBEDDING_FILE}...")
with open(DB_EMBEDDING_FILE, 'rb') as f:
    data = pickle.load(f)

xb = data[1]  # embedding matrix (data[0] is expected to hold the passage ids)
print(f"Original dtype: {xb.dtype}")

# Faiss expects float32 vectors
if xb.dtype != np.float32:
    print("Converting embeddings to float32.")
    xb = xb.astype(np.float32)
else:
    print("Embeddings are already float32.")
print(f"Loaded database embeddings (xb), shape: {xb.shape}")
d = xb.shape[1]  # embedding dimension

# Resolve the query file for the chosen task and read up to MAX_QUERIES_TO_LOAD
# queries (one JSON record per line, with the text under the "query" key)
query_file_path = TASK_CONFIGS[TASK_NAME].query_path
print(f"Using query path from TASK_CONFIGS: {query_file_path}")

query_texts = []
print(f"Reading queries from: {query_file_path}")
with open(query_file_path, 'r') as f:
    for i, line in enumerate(f):
        if i >= MAX_QUERIES_TO_LOAD:
            print(f"Stopped loading queries at limit: {MAX_QUERIES_TO_LOAD}")
            break
        record = json.loads(line)
        query_texts.append(record["query"])
print(f"Loaded {len(query_texts)} query texts.")

print("\nInitializing retriever model for encoding queries...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model, tokenizer, _ = load_retriever(EMBEDDER_MODEL_NAME)
model.to(device)
model.eval()  # Set to evaluation mode
print("Retriever model loaded.")


def embed_queries(queries, model, tokenizer, model_name_or_path, per_gpu_batch_size=64):
    """Embed queries with the retriever model, processing them in batches."""
    model = model.half()  # fp16 for faster encoding
    model.eval()
    embeddings = []
    batch_question = []

    with torch.no_grad():
        for k, query in tqdm(enumerate(queries), desc="Encoding queries"):
            batch_question.append(query)

            # Encode when the batch is full or at the last query
            if len(batch_question) == per_gpu_batch_size or k == len(queries) - 1:
                encoded_batch = tokenizer.batch_encode_plus(
                    batch_question,
                    return_tensors="pt",
                    max_length=512,
                    padding=True,
                    truncation=True,
                )
                encoded_batch = {key: value.to(device) for key, value in encoded_batch.items()}
                output = model(**encoded_batch)

                # Contriever's forward() already returns pooled sentence embeddings;
                # for other encoders, fall back to the [CLS] token representation.
                if "contriever" not in model_name_or_path:
                    output = output.last_hidden_state[:, 0, :]

                embeddings.append(output.cpu())
                batch_question = []  # reset the batch

    embeddings = torch.cat(embeddings, dim=0).numpy()
    print(f"Query embeddings shape: {embeddings.shape}")
    return embeddings
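
# For reference only (not executed): Contriever-style mean pooling over token
# embeddings with the padding mask -- roughly what model(**encoded_batch) computes
# internally for "contriever" models above. mean_pooling is a hypothetical helper,
# not part of this script.
# def mean_pooling(token_embeddings, attention_mask):
#     mask = attention_mask.unsqueeze(-1).float()
#     return (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)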

print(f"\nEncoding {len(query_texts)} queries (batch size: {QUERY_ENCODING_BATCH_SIZE})...")
xq_full = embed_queries(query_texts, model, tokenizer, EMBEDDER_MODEL_NAME, per_gpu_batch_size=QUERY_ENCODING_BATCH_SIZE)

# Ensure float32 for Faiss compatibility after encoding
if xq_full.dtype != np.float32:
    print(f"Converting encoded queries from {xq_full.dtype} to float32.")
    xq_full = xq_full.astype(np.float32)

print(f"Encoded queries (xq_full), shape: {xq_full.shape}, dtype: {xq_full.dtype}")

# Check dimension consistency between query and database embeddings
if xq_full.shape[1] != d:
    raise ValueError(f"Query embedding dimension ({xq_full.shape[1]}) does not match database dimension ({d})")

# Build an exact (brute-force) inner-product index to obtain ground-truth neighbors
print("\nBuilding FlatIP index for ground truth...")
index_flat = faiss.IndexFlatIP(d)  # inner-product metric
index_flat.add(xb)
print(f"Searching FlatIP index with {xq_full.shape[0]} queries (k={K_NEIGHBORS})...")
D_flat, recall_idx_flat = index_flat.search(xq_full, k=K_NEIGHBORS)

# Create a directory specific to this index configuration
index_dir = f"{INDEX_SAVING_DIR}/{CONFIG_NAME}_hnsw_IP_M{M}_efC{efConstruction}"
if CONFIG_NAME == "origin":
    index_dir = f"{INDEX_SAVING_DIR}/hnsw_IP_M{M}_efC{efConstruction}"
os.makedirs(index_dir, exist_ok=True)
index_filename = f"{index_dir}/index.faiss"
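
# HNSW parameter notes (standard Faiss/HNSW semantics, for context):
#   M              -- max out-degree per node on the upper levels (level 0 uses 2*M links)
#   efConstruction -- candidate-list size while inserting points; larger = better graph, slower build
#   efSearch       -- candidate-list size at query time; larger = higher recall, more distance computations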

# Load the index if it already exists; otherwise build it and save it
if os.path.exists(index_filename):
    print(f"Found existing index at {index_filename}, loading...")
    index = faiss.read_index(index_filename)
    print("Index loaded successfully.")
else:
    print('Building HNSW index (IP)...')
    start_time = time.time()
    index = faiss.IndexHNSWFlat(d, M, faiss.METRIC_INNER_PRODUCT)
    index.hnsw.efConstruction = efConstruction
    # set_percentile_thresholds() is not part of stock Faiss; it appears to come from
    # the project's patched HNSW implementation (degree-based edge pruning).
    index.hnsw.set_percentile_thresholds()
    index.add(xb)
    end_time = time.time()
    print(f'Build time: {end_time - start_time:.2f}s')
    print('HNSW index built.')

    # Save the newly built HNSW index
    print(f"Saving index to {index_filename}...")
    faiss.write_index(index, index_filename)
    print("Index saved successfully.")

# Analyze the HNSW graph
print("\nAnalyzing HNSW index...")
print(f"Total number of nodes: {index.ntotal}")
print("Neighbor statistics:")
print(index.hnsw.print_neighbor_stats(0))  # statistics for level 0 (the base layer)

# Save the base-level degree distribution
# (save_degree_distribution is an extension from the project's patched Faiss build)
distribution_filename = f"{index_dir}/degree_distribution.txt"
print(f"Saving degree distribution to {distribution_filename}...")
index.hnsw.save_degree_distribution(0, distribution_filename)
print("Degree distribution saved successfully.")

# Plot the degree distribution with the external plotting script
plot_output_path = f"{index_dir}/degree_distribution.png"
print(f"Generating degree distribution plot to {plot_output_path}...")
try:
    subprocess.run(
        ["python", "/home/ubuntu/Power-RAG/utils/plot_degree_distribution.py", distribution_filename, "-o", plot_output_path],
        check=True,
    )
    print(f"Degree distribution plot saved to {plot_output_path}")
except subprocess.CalledProcessError as e:
    print(f"Error generating degree distribution plot: {e}")
except FileNotFoundError:
    print("Warning: could not launch the plotting script (python or the script path was not found)")

print('Searching HNSW index...')


# Debug snippet (kept for reference): per-query searches across a few efSearch values
# for efSearch in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
#     print(f'*************efSearch: {efSearch}*************')
#     for i in range(10):
#         index.hnsw.efSearch = efSearch
#         D, I = index.search(xq_full[i:i+1], K_NEIGHBORS)
# exit()

# Sweep efSearch and record recall, wall-clock search time, and distance
# computations per query for the whole query batch
recall_result_file = f"{index_dir}/recall_result.txt"
time_list = []
recall_list = []
recompute_list = []
with open(recall_result_file, 'w') as f:
    for efSearch in [2, 4, 8, 16, 24, 32, 48, 64, 96, 114, 128, 144, 160, 176, 192, 208, 224, 240, 256,
                     384, 420, 440, 460, 480, 512, 768, 1024, 1152, 1536, 1792, 2048, 2230, 2408, 2880]:
        index.hnsw.efSearch = efSearch

        # Time the batched search; hnsw_stats counts distance computations
        start_time = time.time()
        faiss.cvar.hnsw_stats.reset()
        print(f'ndis: {faiss.cvar.hnsw_stats.ndis}')  # should read 0 right after reset
        D, I = index.search(xq_full, K_NEIGHBORS)
        print('D[0]:', D[0])
        end_time = time.time()
        print(f'time: {end_time - start_time}')
        time_list.append(end_time - start_time)

        # Average number of distance computations per query
        print("recompute:", faiss.cvar.hnsw_stats.ndis / len(I))
        recompute_list.append(faiss.cvar.hnsw_stats.ndis / len(I))

        # Recall@K against the exact FlatIP results: for each query, the fraction of
        # HNSW neighbors that also appear in the exact top-K, averaged over queries
        recall = []
        for i in range(len(I)):
            acc = 0
            for j in range(len(I[i])):
                if I[i][j] in recall_idx_flat[i]:
                    acc += 1
            recall.append(acc / len(I[i]))
        recall = sum(recall) / len(recall)
        recall_list.append(recall)
        print(f'efSearch: {efSearch}')
        print(f'recall: {recall}')
        f.write(f'efSearch: {efSearch}, recall: {recall}\n')

print(f'Done and result saved to {recall_result_file}')
print(f'time_list: {time_list}')
print(f'recall_list: {recall_list}')
print(f'recompute_list: {recompute_list}')
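
# Optional, illustrative only (assumes matplotlib is installed; not part of the
# original pipeline): plot the recall vs. distance-computation trade-off collected above.
# import matplotlib.pyplot as plt
# plt.plot(recompute_list, recall_list, marker='o')
# plt.xlabel('distance computations per query (ndis)')
# plt.ylabel(f'recall@{K_NEIGHBORS}')
# plt.savefig(f'{index_dir}/recall_vs_recompute.png')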

exit()  # NOTE: everything below is skipped; remove this exit() to run the edge-stats analysis

# Analyze edge statistics
print("\nAnalyzing edge statistics...")
edge_stats_file = f"{index_dir}/edge_stats.txt"
if not os.path.exists(edge_stats_file):
    # save_edge_stats is an extension from the project's patched Faiss build
    index.save_edge_stats(edge_stats_file)
    print(f'Edge stats saved to {edge_stats_file}')
else:
    print(f'Edge stats already exist at {edge_stats_file}')


def analyze_edge_stats(filename):
    """
    Analyze edge statistics from a CSV file and print distance thresholds at various percentiles.

    Args:
        filename: Path to the edge statistics CSV file (columns: src, dst, level, distance)
    """
    if not os.path.exists(filename):
        print(f"Error: File {filename} does not exist")
        return

    print(f"Analyzing edge statistics from {filename}...")

    # Read the file, skipping the header line and any malformed rows
    distances = []
    with open(filename, 'r') as f:
        header = f.readline()  # skip header

        for line in f:
            parts = line.strip().split(',')
            if len(parts) >= 4:
                try:
                    src = int(parts[0])
                    dst = int(parts[1])
                    level = int(parts[2])
                    distance = float(parts[3])
                    distances.append(distance)
                except ValueError:
                    continue

    if not distances:
        print("No valid edges found in file")
        return

    # Sort edge distances in ascending order
    distances = np.array(distances)
    distances.sort()

    # Summary statistics
    print(f"Total edges: {len(distances)}")
    print(f"Min distance: {distances[0]:.6f}")
    print(f"Max distance: {distances[-1]:.6f}")

    # Distance threshold at each percentile of the sorted edge distances
    percentiles = [0.5, 1, 2, 3, 5, 8, 10, 15, 20, 30, 40, 50, 60, 70]
    print("\nDistance thresholds at percentiles:")
    for p in percentiles:
        idx = int(len(distances) * p / 100)
        if idx < len(distances):
            print(f"{p:.1f}%: {distances[idx]:.6f}")

    return distances


analyze_edge_stats(edge_stats_file)