import json
import os
import pickle
import subprocess
import sys
import time
from pathlib import Path

import faiss
import numpy as np
import torch
from tqdm import tqdm

# Make the project's demo/ directory and the project root importable.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
sys.path.append(os.path.join(project_root, "demo"))
from config import SCALING_OUT_DIR, get_example_path, TASK_CONFIGS, get_embedding_path

sys.path.append(project_root)
from contriever.src.contriever import Contriever, load_retriever

# HNSW hyperparameters: M is the maximum number of links per node in each
# graph layer; efConstruction is the candidate-list size used when inserting
# nodes at build time.
M = 32
efConstruction = 256
K_NEIGHBORS = 3  # neighbors to retrieve per query

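# Note (general HNSW behavior): larger M and efConstruction produce a denser,
# higher-recall graph at the cost of build time and memory; efSearch, swept
# at the bottom of this script, trades query latency for recall.
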
# Original configuration (commented out)
# DB_EMBEDDING_FILE = "/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/passages_00.pkl"
# INDEX_SAVING_FILE = "/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices"

# New configuration using DPR
DOMAIN_NAME = "dpr"
EMBEDDER_NAME = "facebook/contriever-msmarco"
TASK_NAME = "nq"
MAX_QUERIES_TO_LOAD = 1000
QUERY_ENCODING_BATCH_SIZE = 64

# Get the embedding path using the function from config
embed_path = get_embedding_path(DOMAIN_NAME, EMBEDDER_NAME, 0)
INDEX_SAVING_FILE = os.path.join(os.path.dirname(embed_path), "indices")
os.makedirs(INDEX_SAVING_FILE, exist_ok=True)

# Load database embeddings
print(f"Loading embeddings from {embed_path}...")
with open(embed_path, 'rb') as f:
    data = pickle.load(f)

# data[1] is the embedding matrix; data[0] presumably holds the passage ids.
xb = data[1]
print(f"Original dtype: {xb.dtype}")

# Faiss expects float32 input arrays.
if xb.dtype != np.float32:
    print("Converting embeddings to float32.")
    xb = xb.astype(np.float32)
else:
    print("Embeddings are already float32.")
print(f"Loaded database embeddings (xb), shape: {xb.shape}")
d = xb.shape[1]  # embedding dimension

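# Memory note: IndexFlatIP and IndexHNSWFlat below each keep their own full
# copy of the vectors, so peak RAM is roughly xb.nbytes per index on top of
# the matrix itself (plus the HNSW graph's per-node link storage).
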
query_file_path = TASK_CONFIGS[TASK_NAME].query_path
print(f"Using query path from TASK_CONFIGS: {query_file_path}")

# Read query texts from a JSONL file, one record per line.
query_texts = []
print(f"Reading queries from: {query_file_path}")
with open(query_file_path, 'r') as f:
    for i, line in enumerate(f):
        if i >= MAX_QUERIES_TO_LOAD:
            print(f"Stopped loading queries at limit: {MAX_QUERIES_TO_LOAD}")
            break
        record = json.loads(line)
        query_texts.append(record["query"])
print(f"Loaded {len(query_texts)} query texts.")

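# Each line is expected to be a JSON object with a "query" field; other
# fields are ignored. Hypothetical example record:
#   {"query": "who wrote the declaration of independence", "answers": [...]}
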
print("\nInitializing retriever model for encoding queries...")
|
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
print(f"Using device: {device}")
|
|
model, tokenizer, _ = load_retriever(EMBEDDER_NAME)
|
|
model.to(device)
|
|
model.eval() # Set to evaluation mode
|
|
print("Retriever model loaded.")
|
|
|
|
|
|
def embed_queries(queries, model, tokenizer, model_name_or_path, per_gpu_batch_size=64):
    """Embed queries with the model, batching for GPU efficiency.

    Uses the module-level `device`. The model is cast to fp16 to halve GPU
    memory; the caller converts the outputs back to float32 for Faiss.
    """
    model = model.half()
    model.eval()
    embeddings = []
    batch_question = []

    with torch.no_grad():
        for k, query in tqdm(enumerate(queries), total=len(queries), desc="Encoding queries"):
            batch_question.append(query)

            # Process when the batch is full or at the last query.
            if len(batch_question) == per_gpu_batch_size or k == len(queries) - 1:
                encoded_batch = tokenizer.batch_encode_plus(
                    batch_question,
                    return_tensors="pt",
                    max_length=512,
                    padding=True,
                    truncation=True,
                )
                encoded_batch = {key: value.to(device) for key, value in encoded_batch.items()}
                output = model(**encoded_batch)

                # Contriever's forward pass already returns pooled sentence
                # embeddings; for other encoders, take the [CLS] token.
                if "contriever" not in model_name_or_path:
                    output = output.last_hidden_state[:, 0, :]

                embeddings.append(output.cpu())
                batch_question = []  # reset batch

    embeddings = torch.cat(embeddings, dim=0).numpy()
    print(f"Query embeddings shape: {embeddings.shape}")
    return embeddings

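# Note: per the reference implementation, Contriever mean-pools token
# embeddings inside its forward pass, which is why no extra pooling is
# applied above when "contriever" appears in the model name.
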
print(f"\nEncoding {len(query_texts)} queries (batch size: {QUERY_ENCODING_BATCH_SIZE})...")
|
|
xq_full = embed_queries(query_texts, model, tokenizer, EMBEDDER_NAME, per_gpu_batch_size=QUERY_ENCODING_BATCH_SIZE)
|
|
|
|
# Ensure float32 for Faiss compatibility after encoding
if xq_full.dtype != np.float32:
    print(f"Converting encoded queries from {xq_full.dtype} to float32.")
    xq_full = xq_full.astype(np.float32)

print(f"Encoded queries (xq_full), shape: {xq_full.shape}, dtype: {xq_full.dtype}")

# Check dimension consistency
if xq_full.shape[1] != d:
    raise ValueError(f"Query embedding dimension ({xq_full.shape[1]}) does not match database dimension ({d})")

# Build an exact FlatIP index to produce ground-truth neighbors for the
# recall evaluation (inner-product metric, matching the HNSW index below).
print("\nBuilding FlatIP index for ground truth...")
index_flat = faiss.IndexFlatIP(d)
index_flat.add(xb)
print(f"Searching FlatIP index with {len(xq_full)} queries (k={K_NEIGHBORS})...")
D_flat, recall_idx_flat = index_flat.search(xq_full, k=K_NEIGHBORS)

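# recall_idx_flat[i] holds the exact top-K_NEIGHBORS row indices into xb for
# query i; the efSearch sweep below scores HNSW results by the fraction of
# these indices they recover (recall@k against exact search).
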
# Create a specific directory for this index configuration
index_dir = f"{INDEX_SAVING_FILE}/{DOMAIN_NAME}_hnsw_IP_M{M}_efC{efConstruction}"
os.makedirs(index_dir, exist_ok=True)
index_filename = f"{index_dir}/index.faiss"

# Load the index if it already exists; otherwise build and save it.
if os.path.exists(index_filename):
    print(f"Found existing index at {index_filename}, loading...")
    index = faiss.read_index(index_filename)
    print("Index loaded successfully.")
else:
    print('Building HNSW index (IP)...')
    # Time the build.
    start_time = time.time()
    index = faiss.IndexHNSWFlat(d, M, faiss.METRIC_INNER_PRODUCT)
    index.hnsw.efConstruction = efConstruction
    index.add(xb)
    end_time = time.time()
    print(f'Build time: {end_time - start_time:.2f}s')
    print('HNSW index built.')

    # Save the newly built HNSW index for reuse.
    print(f"Saving index to {index_filename}...")
    faiss.write_index(index, index_filename)
    print("Index saved successfully.")

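# faiss.write_index serializes the complete index (stored vectors plus the
# HNSW graph) into a single file, and faiss.read_index restores it without
# rebuilding the graph.
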
# Analyze the HNSW index
print("\nAnalyzing HNSW index...")
print(f"Total number of nodes: {index.ntotal}")
print("Neighbor statistics:")
index.hnsw.print_neighbor_stats(0)  # prints to stdout itself; returns None

# Save degree distribution
distribution_filename = f"{index_dir}/degree_distribution.txt"
print(f"Saving degree distribution to {distribution_filename}...")
index.hnsw.save_degree_distribution(0, distribution_filename)
print("Degree distribution saved successfully.")

# Plot the degree distribution via the helper script
plot_output_path = f"{index_dir}/degree_distribution.png"
print(f"Generating degree distribution plot to {plot_output_path}...")
try:
    subprocess.run(
        ["python", f"{project_root}/utils/plot_degree_distribution.py", distribution_filename, "-o", plot_output_path],
        check=True,
    )
    print(f"Degree distribution plot saved to {plot_output_path}")
except subprocess.CalledProcessError as e:
    print(f"Error generating degree distribution plot: {e}")
except FileNotFoundError:
    # subprocess raises FileNotFoundError when the interpreter itself is
    # missing; a missing script surfaces as CalledProcessError instead.
    print("Warning: could not launch plot_degree_distribution.py")

print('Searching HNSW index...')

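# efSearch sets the size of the candidate queue explored per query: larger
# values visit more of the graph, increasing both recall and latency.
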
recall_result_file = f"{index_dir}/recall_result.txt"
with open(recall_result_file, 'w') as f:
    for efSearch in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
        index.hnsw.efSearch = efSearch

        # Time the search.
        start_time = time.time()
        D, I = index.search(xq_full, K_NEIGHBORS)
        end_time = time.time()
        print(f'Search time: {end_time - start_time:.2f}s')

        # Recall@k: fraction of HNSW results that appear in the exact top-k
        # returned by the flat index.
        recall = []
        for i in range(len(I)):
            acc = 0
            for j in range(len(I[i])):
                if I[i][j] in recall_idx_flat[i]:
                    acc += 1
            recall.append(acc / len(I[i]))
        recall = sum(recall) / len(recall)
        print(f'efSearch: {efSearch}')
        print(f'recall: {recall}')
        f.write(f'efSearch: {efSearch}, recall: {recall}\n')

print(f'Done, results saved to {recall_result_file}')
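
# Note: for large query sets the per-query recall loop above could be
# vectorized, e.g. an equivalent numpy sketch:
#   recall = np.mean([np.isin(I[i], recall_idx_flat[i]).mean() for i in range(len(I))])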