Initial commit

yichuan520030910320
2025-06-30 09:05:05 +00:00
commit 46f6cc100b
1231 changed files with 278432 additions and 0 deletions

Binary image file added (134 KiB); not shown.

Binary image file added (89 KiB); not shown.


@@ -0,0 +1,67 @@
import os
import pickle
import faiss
import numpy as np
import time
EMBEDDING_FILE = "/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/passages_00.pkl"
INDEX_OUTPUT_DIR = "/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw" # directory where the indices are saved
M_VALUES = [30, 60] # M values to build
EF_CONSTRUCTION = 128 # fixed efConstruction
if not os.path.exists(INDEX_OUTPUT_DIR):
print(f"Creating index directory: {INDEX_OUTPUT_DIR}")
os.makedirs(INDEX_OUTPUT_DIR)
print(f"Loading embeddings from {EMBEDDING_FILE}...")
with open(EMBEDDING_FILE, 'rb') as f:
data = pickle.load(f)
# Directly assume data is a tuple and the second element is embeddings
embeddings = data[1]
print(f"Converting embeddings from {embeddings.dtype} to float32.")
embeddings = embeddings.astype(np.float32)
print(f"Loaded embeddings, shape: {embeddings.shape}")
dim = embeddings.shape[1]
# --- Build HNSW inner-product (IP) indexes ---
print("\n--- Build HNSW IP indexes ---")
# Loop through M values
for HNSW_M in M_VALUES:
efConstruction = EF_CONSTRUCTION
print(f"\nBuilding HNSW IP index: M={HNSW_M}, efConstruction={efConstruction}...")
# Define the filename and path for this index
hnsw_filename = f"hnsw_IP_M{HNSW_M}_efC{efConstruction}.index"
hnsw_filepath = os.path.join(INDEX_OUTPUT_DIR, hnsw_filename)
# Note: No longer check if the file exists, it will be overwritten if it exists
# Create the HNSW IP index
index_hnsw = faiss.IndexHNSWFlat(dim, HNSW_M, faiss.METRIC_INNER_PRODUCT)
index_hnsw.hnsw.efConstruction = efConstruction
index_hnsw.verbose = True
print(f"Adding {embeddings.shape[0]} vectors to HNSW L2 (M={HNSW_M}) index...")
start_time_build = time.time()
index_hnsw.add(embeddings)
end_time_build = time.time()
build_time_s = end_time_build - start_time_build
print(f"HNSW L2 build time: {build_time_s:.4f} seconds")
# Save L2 index (direct operation, no try-except)
print(f"Saving HNSW L2 index to {hnsw_filepath}")
faiss.write_index(index_hnsw, hnsw_filepath)
# Do not check storage size or handle save errors
print(f"Index {hnsw_filename} saved.")
del index_hnsw
print("\n--- HNSW L2 index build completed ---")
print("\nScript ended.")


@@ -0,0 +1,250 @@
import os
import pickle
import numpy as np
import time
import math
from pathlib import Path
# --- Configuration ---
DOMAIN_NAME = "rpj_wiki" # Domain name used for finding passages
EMBEDDER_NAME = "facebook/contriever-msmarco" # Used in paths
ORIGINAL_EMBEDDING_SHARD_ID = 0 # The shard ID of the embedding file we are loading
# Define the base directory
SCALING_OUT_DIR = Path("/powerrag/scaling_out").resolve()
# Original data paths (templates mirroring the project's utils)
# Assuming embeddings for rpj_wiki are in a single file despite passage sharding
# Adjust NUM_SHARDS_EMBEDDING if embeddings are also sharded
NUM_SHARDS_EMBEDDING = 1
ORIGINAL_EMBEDDING_FILE_TEMPLATE = (
SCALING_OUT_DIR
/ "embeddings/{embedder_name}/{domain_name}/{total_shards}-shards/passages_{shard_id:02d}.pkl"
)
ORIGINAL_EMBEDDING_FILE = str(ORIGINAL_EMBEDDING_FILE_TEMPLATE).format(
embedder_name=EMBEDDER_NAME,
domain_name=DOMAIN_NAME,
total_shards=NUM_SHARDS_EMBEDDING,
shard_id=ORIGINAL_EMBEDDING_SHARD_ID,
)
# Passage Paths
NUM_SHARDS_PASSAGE = 8 # As specified in the project utils (NUM_SHARDS['rpj_wiki'])
ORIGINAL_PASSAGE_FILE_TEMPLATE = (
SCALING_OUT_DIR
/ "passages/{domain_name}/{total_shards}-shards/raw_passages-{shard_id}-of-{total_shards}.pkl"
)
# New identifier for the sampled dataset
NEW_DATASET_NAME = "rpj_wiki_1M"
# Fraction to sample (1/60)
SAMPLE_FRACTION = 1 / 60
# Output Paths for the new sampled dataset
OUTPUT_EMBEDDING_DIR = SCALING_OUT_DIR / "embeddings" / EMBEDDER_NAME / NEW_DATASET_NAME / "1-shards"
OUTPUT_PASSAGE_DIR = SCALING_OUT_DIR / "passages" / NEW_DATASET_NAME / "1-shards"
OUTPUT_EMBEDDING_FILE = OUTPUT_EMBEDDING_DIR / f"passages_{ORIGINAL_EMBEDDING_SHARD_ID:02d}.pkl"
# The new passage file represents the *single* shard of the sampled data
OUTPUT_PASSAGE_FILE = OUTPUT_PASSAGE_DIR / "raw_passages-0-of-1.pkl"
# --- Directory Setup ---
print("Creating output directories if they don't exist...")
OUTPUT_EMBEDDING_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_PASSAGE_DIR.mkdir(parents=True, exist_ok=True)
print(f"Embeddings output dir: {OUTPUT_EMBEDDING_DIR}")
print(f"Passages output dir: {OUTPUT_PASSAGE_DIR}")
# --- Helper Function to Load Passages ---
def load_all_passages(domain_name, num_shards, template):
"""Loads all passage shards and creates an ID-to-content map."""
all_passages_list = []
passage_id_to_content_map = {}
print(f"Loading passages for domain '{domain_name}' from {num_shards} shards...")
total_loaded = 0
start_time = time.time()
for shard_id in range(num_shards):
shard_path_str = str(template).format(
domain_name=domain_name,
total_shards=num_shards,
shard_id=shard_id,
)
shard_path = Path(shard_path_str)
if not shard_path.exists():
print(f"Warning: Passage shard file not found, skipping: {shard_path}")
continue
try:
print(f" Loading shard {shard_id} from {shard_path}...")
with open(shard_path, 'rb') as f:
shard_passages = pickle.load(f) # Expected: list of dicts
if not isinstance(shard_passages, list):
print(f"Warning: Shard {shard_id} data is not a list.")
continue
all_passages_list.extend(shard_passages)
# Build the map, ensuring IDs are strings for consistent lookup
for passage_dict in shard_passages:
if 'id' in passage_dict:
passage_id_to_content_map[str(passage_dict['id'])] = passage_dict
else:
print(f"Warning: Passage dict in shard {shard_id} missing 'id' key.")
print(f" Loaded {len(shard_passages)} passages from shard {shard_id}.")
total_loaded += len(shard_passages)
except Exception as e:
print(f"Error loading passage shard {shard_id} from {shard_path}: {e}")
load_time = time.time() - start_time
print(f"Finished loading passages. Total passages loaded: {total_loaded} in {load_time:.2f} seconds.")
print(f"Total unique passages mapped by ID: {len(passage_id_to_content_map)}")
return all_passages_list, passage_id_to_content_map
# --- Load Original Embeddings ---
print(f"\nLoading original embeddings from {ORIGINAL_EMBEDDING_FILE}...")
start_load_time = time.time()
try:
with open(ORIGINAL_EMBEDDING_FILE, 'rb') as f:
original_embedding_data = pickle.load(f)
except FileNotFoundError:
print(f"Error: Original embedding file not found at {ORIGINAL_EMBEDDING_FILE}")
exit(1)
except Exception as e:
print(f"Error loading embedding pickle file: {e}")
exit(1)
load_time = time.time() - start_load_time
print(f"Loaded original embeddings data in {load_time:.2f} seconds.")
# --- Extract and Validate Embeddings ---
try:
if not isinstance(original_embedding_data, (list, tuple)) or len(original_embedding_data) != 2:
raise TypeError("Expected embedding data to be a list or tuple of length 2 (ids, embeddings)")
original_embedding_ids = original_embedding_data[0] # Should be a list/iterable of IDs
original_embeddings = original_embedding_data[1] # Should be a NumPy array
# Ensure IDs are in a list for easier indexing later if they aren't already
if not isinstance(original_embedding_ids, list):
print("Converting embedding IDs to list...")
original_embedding_ids = list(original_embedding_ids)
if not isinstance(original_embeddings, np.ndarray):
raise TypeError("Expected second element of embedding data to be a NumPy array")
print(f"Original data contains {len(original_embedding_ids)} embedding IDs.")
print(f"Original embeddings shape: {original_embeddings.shape}, dtype: {original_embeddings.dtype}")
if len(original_embedding_ids) != original_embeddings.shape[0]:
raise ValueError(f"Mismatch! Number of embedding IDs ({len(original_embedding_ids)}) does not match number of embeddings ({original_embeddings.shape[0]})")
except (TypeError, ValueError, IndexError) as e:
print(f"Error processing loaded embedding data: {e}")
print("Please ensure the embedding pickle file contains: (list_of_passage_ids, numpy_embedding_array)")
exit(1)
total_embeddings = original_embeddings.shape[0]
# --- Load Original Passages ---
# This might take time and memory depending on the dataset size
_, passage_id_to_content_map = load_all_passages(
DOMAIN_NAME, NUM_SHARDS_PASSAGE, ORIGINAL_PASSAGE_FILE_TEMPLATE
)
if not passage_id_to_content_map:
print("Error: No passages were loaded. Cannot proceed with sampling.")
exit(1)
# --- Calculate Sample Size ---
num_samples = math.ceil(total_embeddings * SAMPLE_FRACTION) # Use ceil to get at least 1/60th
print(f"\nTotal original embeddings: {total_embeddings}")
print(f"Sampling fraction: {SAMPLE_FRACTION:.6f} (1/60)")
print(f"Target number of samples: {num_samples}")
if num_samples > total_embeddings:
print("Warning: Calculated sample size exceeds total embeddings. Using all embeddings.")
num_samples = total_embeddings
elif num_samples <= 0:
print("Error: Calculated sample size is zero or negative.")
exit(1)
# --- Perform Random Sampling (Based on Embeddings) ---
print("\nPerforming random sampling based on embeddings...")
start_sample_time = time.time()
# Set a seed for reproducibility if needed
# np.random.seed(42)
# Generate unique random indices from the embeddings list
sampled_indices = np.random.choice(total_embeddings, size=num_samples, replace=False)
# Retrieve the corresponding IDs and embeddings using the sampled indices
sampled_embedding_ids = [original_embedding_ids[i] for i in sampled_indices]
sampled_embeddings = original_embeddings[sampled_indices]
sample_time = time.time() - start_sample_time
print(f"Sampling completed in {sample_time:.2f} seconds.")
print(f"Sampled {len(sampled_embedding_ids)} IDs and embeddings.")
print(f"Sampled embeddings shape: {sampled_embeddings.shape}")
# --- Retrieve Corresponding Passages ---
print("\nRetrieving corresponding passages for sampled IDs...")
start_passage_retrieval_time = time.time()
sampled_passages = []
missing_ids_count = 0
for i, pid in enumerate(sampled_embedding_ids):
# Convert pid to string for lookup in the map
pid_str = str(pid)
if pid_str in passage_id_to_content_map:
sampled_passages.append(passage_id_to_content_map[pid_str])
else:
# This indicates an inconsistency between embedding IDs and passage IDs
print(f"Warning: Passage ID '{pid_str}' (from embedding index {sampled_indices[i]}) not found in passage map.")
missing_ids_count += 1
passage_retrieval_time = time.time() - start_passage_retrieval_time
print(f"Retrieved {len(sampled_passages)} passages in {passage_retrieval_time:.2f} seconds.")
if missing_ids_count > 0:
print(f"Warning: Could not find passages for {missing_ids_count} sampled IDs.")
if not sampled_passages:
print("Error: No corresponding passages found for the sampled embeddings. Check ID matching.")
exit(1)
# --- Prepare Output Data ---
# Embeddings output format: tuple(list_of_ids, numpy_array_of_embeddings)
output_embedding_data = (sampled_embedding_ids, sampled_embeddings)
# Passages output format: list[dict] (matching input shard format)
output_passage_data = sampled_passages
# --- Save Sampled Embeddings ---
print(f"\nSaving sampled embeddings to {OUTPUT_EMBEDDING_FILE}...")
start_save_time = time.time()
try:
with open(OUTPUT_EMBEDDING_FILE, 'wb') as f:
pickle.dump(output_embedding_data, f, protocol=pickle.HIGHEST_PROTOCOL)
except Exception as e:
print(f"Error saving sampled embeddings: {e}")
# Continue to try saving passages if desired, or exit(1)
save_time = time.time() - start_save_time
print(f"Saved sampled embeddings in {save_time:.2f} seconds.")
# --- Save Sampled Passages ---
print(f"\nSaving sampled passages to {OUTPUT_PASSAGE_FILE}...")
start_save_time = time.time()
try:
with open(OUTPUT_PASSAGE_FILE, 'wb') as f:
pickle.dump(output_passage_data, f, protocol=pickle.HIGHEST_PROTOCOL)
except Exception as e:
print(f"Error saving sampled passages: {e}")
exit(1)
save_time = time.time() - start_save_time
print(f"Saved sampled passages in {save_time:.2f} seconds.")
print(f"\nScript finished successfully.")
print(f"Sampled embeddings saved to: {OUTPUT_EMBEDDING_FILE}")
print(f"Sampled passages saved to: {OUTPUT_PASSAGE_FILE}")

Binary image file added (138 KiB); not shown.

Binary image file added (86 KiB); not shown.


@@ -0,0 +1,354 @@
import argparse
import sys
import time
import faiss
import numpy as np
import pickle
import os
import json
import torch
from tqdm import tqdm
from pathlib import Path
import subprocess
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
sys.path.append(os.path.join(project_root, "demo"))
from config import SCALING_OUT_DIR, get_example_path, TASK_CONFIGS
sys.path.append(project_root)
from contriever.src.contriever import Contriever, load_retriever
M = 32
efConstruction = 256
K_NEIGHBORS = 3
DB_EMBEDDING_FILE = "/powerrag/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki_1M/1-shards/passages_00.pkl"
INDEX_SAVING_FILE = "/powerrag/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki_1M/1-shards/indices"
TASK_NAME = "nq"
EMBEDDER_MODEL_NAME = "facebook/contriever-msmarco"
MAX_QUERIES_TO_LOAD = 1000
QUERY_ENCODING_BATCH_SIZE = 64
# 1M samples
print(f"Loading embeddings from {DB_EMBEDDING_FILE}...")
with open(DB_EMBEDDING_FILE, 'rb') as f:
data = pickle.load(f)
xb = data[1]
print(f"Original dtype: {xb.dtype}")
if xb.dtype != np.float32:
print("Converting embeddings to float32.")
xb = xb.astype(np.float32)
else:
print("Embeddings are already float32.")
print(f"Loaded database embeddings (xb), shape: {xb.shape}")
d = xb.shape[1] # Get dimension
query_file_path = TASK_CONFIGS[TASK_NAME].query_path
print(f"Using query path from TASK_CONFIGS: {query_file_path}")
query_texts = []
print(f"Reading queries from: {query_file_path}")
with open(query_file_path, 'r') as f:
for i, line in enumerate(f):
if i >= MAX_QUERIES_TO_LOAD:
print(f"Stopped loading queries at limit: {MAX_QUERIES_TO_LOAD}")
break
record = json.loads(line)
query_texts.append(record["query"])
print(f"Loaded {len(query_texts)} query texts.")
print("\nInitializing retriever model for encoding queries...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model, tokenizer, _ = load_retriever(EMBEDDER_MODEL_NAME)
model.to(device)
model.eval() # Set to evaluation mode
print("Retriever model loaded.")
def embed_queries(queries, model, tokenizer, model_name_or_path, per_gpu_batch_size=64):
"""Embed queries using the model with batching"""
model = model.half()
model.eval()
embeddings = []
batch_question = []
with torch.no_grad():
for k, query in tqdm(enumerate(queries), desc="Encoding queries"):
batch_question.append(query)
# Process when batch is full or at the end
if len(batch_question) == per_gpu_batch_size or k == len(queries) - 1:
encoded_batch = tokenizer.batch_encode_plus(
batch_question,
return_tensors="pt",
max_length=512,
padding=True,
truncation=True,
)
encoded_batch = {k: v.to(device) for k, v in encoded_batch.items()}
output = model(**encoded_batch)
# Contriever's forward pass already returns a pooled embedding; for other models take the [CLS] token
if "contriever" not in model_name_or_path:
output = output.last_hidden_state[:, 0, :]
embeddings.append(output.cpu())
batch_question = [] # Reset batch
embeddings = torch.cat(embeddings, dim=0).numpy()
print(f"Query embeddings shape: {embeddings.shape}")
return embeddings
print(f"\nEncoding {len(query_texts)} queries (batch size: {QUERY_ENCODING_BATCH_SIZE})...")
xq_full = embed_queries(query_texts, model, tokenizer, EMBEDDER_MODEL_NAME, per_gpu_batch_size=QUERY_ENCODING_BATCH_SIZE)
# Ensure float32 for Faiss compatibility after encoding
if xq_full.dtype != np.float32:
print(f"Converting encoded queries from {xq_full.dtype} to float32.")
xq_full = xq_full.astype(np.float32)
print(f"Encoded queries (xq_full), shape: {xq_full.shape}, dtype: {xq_full.dtype}")
# Check dimension consistency
if xq_full.shape[1] != d:
raise ValueError(f"Query embedding dimension ({xq_full.shape[1]}) does not match database dimension ({d})")
# loading index_flat from cache
cache_file = "/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/flat_results_nq_k3.json"
recall_idx_flat = None
if os.path.exists(cache_file):
print(f"Loading cached FLAT index results from {cache_file}...")
start_time = time.time()
with open(cache_file, 'r') as f:
cached_results = json.load(f)
D_flat = np.array(cached_results["distances"])
recall_idx_flat = np.array(cached_results["indices"])
end_time = time.time()
print(f"Loaded cached results in {end_time - start_time:.3f} seconds")
else:
print("\nBuilding FlatIP index for ground truth...")
index_flat = faiss.IndexFlatIP(d) # Use Inner Product
index_flat.add(xb)
print(f"Searching FlatIP index with {len(xq_full)} queries (k={K_NEIGHBORS})...")
start_time = time.time()
D_flat, recall_idx_flat = index_flat.search(xq_full, k=K_NEIGHBORS)
end_time = time.time()
print(f"Time taken for FLAT index search: {end_time - start_time:.3f} seconds")
# Save results to cache
# with open(cache_file, 'w') as f:
# json.dump({
# "distances": D_flat.tolist(),
# "indices": recall_idx_flat.tolist(),
# "metadata": {
# "task": TASK_NAME,
# "k": K_NEIGHBORS,
# "timestamp": time.strftime("%Y%m%d_%H%M%S"),
# }
# }, f)
# print(f"Cached FLAT index results to {cache_file}")
# print(recall_idx_flat)
# Create a specific directory for this index configuration
# index_dir = f"{INDEX_SAVING_FILE}/hnsw_IP_M{M}_efC{efConstruction}"
# os.makedirs(index_of, exist_ok=True)
parser = argparse.ArgumentParser()
# parser.add_argument("--index-dir", type=str, default=f"{INDEX_SAVING_FILE}/hnsw_IP_M{M}_efC{efConstruction}")
parser.add_argument("--index-file", type=str, default=f"{INDEX_SAVING_FILE}/hnsw_IP_M{M}_efC{efConstruction}/index.faiss")
args = parser.parse_args()
index_filename = args.index_file
index_dir = os.path.dirname(index_filename)
os.makedirs(index_dir, exist_ok=True)
# Check if index already exists
if os.path.exists(index_filename):
print(f"Found existing index at {index_filename}, loading...")
index = faiss.read_index(index_filename)
print("Index loaded successfully.")
else:
print(f'Building {"NSG" if "nsg" in index_filename else "HNSW"} index (IP)...')
# add build time
start_time = time.time()
if 'nsg' in index_filename:
index = faiss.IndexNSGFlat(d, M, faiss.METRIC_INNER_PRODUCT)
index.verbose = True
else:
index = faiss.IndexHNSWFlat(d, M, faiss.METRIC_INNER_PRODUCT)
index.hnsw.efConstruction = efConstruction
index.hnsw.set_percentile_thresholds()
index.add(xb)
end_time = time.time()
print(f'time: {end_time - start_time}')
print(f'{"NSG" if "nsg" in index_filename else "HNSW"} index built.')
# Save the index
print(f"Saving index to {index_filename}...")
faiss.write_index(index, index_filename)
print("Index saved successfully.")
# Analyze the index
print("\nAnalyzing index...")
print(f"Total number of nodes: {index.ntotal}")
print("Neighbor statistics:")
if 'nsg' in index_filename:
print(index.nsg.print_neighbor_stats(0))
else:
print(index.hnsw.print_neighbor_stats(0))
# Save degree distribution
distribution_filename = f"{index_dir}/degree_distribution.txt"
print(f"Saving degree distribution to {distribution_filename}...")
if 'nsg' in index_filename:
index.nsg.save_degree_distribution(distribution_filename)
else:
index.hnsw.save_degree_distribution(0, distribution_filename)
print("Degree distribution saved successfully.")
# Plot the degree distribution
plot_output_path = f"{index_dir}/degree_distribution.png"
print(f"Generating degree distribution plot to {plot_output_path}...")
try:
subprocess.run(
["python", "/home/ubuntu/Power-RAG/utils/plot_degree_distribution.py", distribution_filename, "-o", plot_output_path],
check=True
)
print(f"Degree distribution plot saved to {plot_output_path}")
except subprocess.CalledProcessError as e:
print(f"Error generating degree distribution plot: {e}")
except FileNotFoundError:
print("Warning: plot_degree_distribution.py script not found in current directory")
print('Searching HNSW index...')
# for efSearch in [2, 4, 8, 16, 32, 64,128,256,512,1024]:
# print(f'*************efSearch: {efSearch}*************')
# for i in range(10):
# index.hnsw.efSearch = efSearch
# D, I = index.search(xq_full[i:i+1], K_NEIGHBORS)
# exit()
recall_result_file = f"{index_dir}/recall_result.txt"
time_list = []
recall_list = []
recompute_list = []
with open(recall_result_file, 'w') as f:
for efSearch in [2, 4, 8, 16, 24, 32, 48, 64, 96, 114, 128, 144, 160, 176, 192, 208, 224, 240, 256, 384, 420, 440, 460, 480, 512, 768, 1024, 1152, 1536, 1792, 2048, 2230, 2408, 2880]:
if 'nsg' in index_filename:
index.nsg.efSearch = efSearch
else:
index.hnsw.efSearch = efSearch
# calculate the time of searching
start_time = time.time()
if not ('nsg' in index_filename):
faiss.cvar.hnsw_stats.reset()
else:
faiss.cvar.nsg_stats.reset()
D, I = index.search(xq_full, K_NEIGHBORS)
print('D[0]:', D[0])
end_time = time.time()
print(f'time: {end_time - start_time}')
time_list.append(end_time - start_time)
if 'nsg' in index_filename:
print("recompute:", faiss.cvar.nsg_stats.ndis/len(I))
recompute_list.append(faiss.cvar.nsg_stats.ndis/len(I))
else:
print("recompute:", faiss.cvar.hnsw_stats.ndis/len(I))
recompute_list.append(faiss.cvar.hnsw_stats.ndis/len(I))
# print(I)
# Recall against the flat-index ground truth: for each query, the fraction of
# retrieved ids that appear in the flat top-k, averaged over all queries
recall=[]
for i in range(len(I)):
acc = 0
for j in range(len(I[i])):
if I[i][j] in recall_idx_flat[i]:
acc += 1
recall.append(acc / len(I[i]))
recall = sum(recall) / len(recall)
recall_list.append(recall)
print(f'efSearch: {efSearch}')
print(f'recall: {recall}')
f.write(f'efSearch: {efSearch}, recall: {recall}\n')
print(f'Done and result saved to {recall_result_file}')
print(f'time_list: {time_list}')
print(f'recall_list: {recall_list}')
print(f'recompute_list: {recompute_list}')
exit()  # stop here; the edge-stats analysis below is currently disabled
# Analyze edge stats
print("\nAnalyzing edge statistics...")
edge_stats_file = f"{index_dir}/edge_stats.txt"
if not os.path.exists(edge_stats_file):
index.save_edge_stats(edge_stats_file)
print(f'Edge stats saved to {edge_stats_file}')
else:
print(f'Edge stats already exists at {edge_stats_file}')
def analyze_edge_stats(filename):
"""
Analyze edge statistics from a CSV file and print thresholds at various percentiles.
Args:
filename: Path to the edge statistics CSV file
"""
if not os.path.exists(filename):
print(f"Error: File {filename} does not exist")
return
print(f"Analyzing edge statistics from {filename}...")
# Read the file
distances = []
with open(filename, 'r') as f:
# Skip header
header = f.readline()
# Read all edges
for line in f:
parts = line.strip().split(',')
if len(parts) >= 4:
try:
src = int(parts[0])
dst = int(parts[1])
level = int(parts[2])
distance = float(parts[3])
distances.append(distance)
except ValueError:
continue
if not distances:
print("No valid edges found in file")
return
# Sort distances
distances = np.array(distances)
distances.sort()
# Calculate and print statistics
print(f"Total edges: {len(distances)}")
print(f"Min distance: {distances[0]:.6f}")
print(f"Max distance: {distances[-1]:.6f}")
# Print thresholds at specified percentiles
percentiles = [0.5, 1, 2, 3, 5, 8, 10, 15, 20, 30, 40, 50, 60, 70]
print("\nDistance thresholds at percentiles:")
for p in percentiles:
idx = int(len(distances) * p / 100)
if idx < len(distances):
print(f"{p:.1f}%: {distances[idx]:.6f}")
return distances
analyze_edge_stats(edge_stats_file)
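The per-query recall loop used above can also be written in vectorized form; a minimal sketch assuming I and recall_idx_flat are both (num_queries, K_NEIGHBORS) integer arrays:

import numpy as np

def recall_at_k(I, I_gt):
    # True where a retrieved id also appears in that query's ground-truth top-k
    hits = (I[:, :, None] == I_gt[:, None, :]).any(axis=2)
    # Average the per-query fractions, matching the loop above
    return hits.mean(axis=1).mean()

# e.g. recall_list.append(recall_at_k(I, recall_idx_flat))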


@@ -0,0 +1,194 @@
import re
import matplotlib.pyplot as plt
import numpy as np
def extract_data_from_log(log_content):
"""Extract method names, recall lists, and recompute lists from the log file."""
# Regular expressions to find the data - modified to match the actual format
method_pattern = r"Building HNSW index with ([^\.]+)\.\.\.|Building HNSW index with ([^\n]+)..."
recall_list_pattern = r"recall_list: (\[[\d\., ]+\])"
recompute_list_pattern = r"recompute_list: (\[[\d\., ]+\])"
avg_neighbors_pattern = r"neighbors per node: ([\d\.]+)"
# Find all matches
method_matches = re.findall(method_pattern, log_content)
methods = []
for match in method_matches:
# Each match is a tuple with one empty string and one with the method name
method = match[0] if match[0] else match[1]
methods.append(method)
recall_lists_str = re.findall(recall_list_pattern, log_content)
recompute_lists_str = re.findall(recompute_list_pattern, log_content)
avg_neighbors = re.findall(avg_neighbors_pattern, log_content)
# Debug information
print(f"Found {len(methods)} methods: {methods}")
print(f"Found {len(recall_lists_str)} recall lists")
print(f"Found {len(recompute_lists_str)} recompute lists")
print(f"Found {len(avg_neighbors)} average neighbors values")
# If the regex approach fails, try a more direct approach
if len(methods) < 5:
print("Regex approach failed, trying direct extraction...")
sections = log_content.split("Building HNSW index with ")[1:]
methods = []
for section in sections:
# Extract the method name (everything up to the first newline)
method_name = section.split("\n")[0].strip()
# Remove trailing dots if present
method_name = method_name.rstrip('.')
methods.append(method_name)
print(f"Direct extraction found {len(methods)} methods: {methods}")
# Convert string representations of lists to actual lists
recall_lists = [eval(recall_list) for recall_list in recall_lists_str]
recompute_lists = [eval(recompute_list) for recompute_list in recompute_lists_str]
# Convert average neighbors to float
avg_neighbors = [float(avg) for avg in avg_neighbors]
# Make sure we have the same number of items in each list
min_length = min(len(methods), len(recall_lists), len(recompute_lists), len(avg_neighbors))
if min_length < 5:
print(f"Warning: Expected 5 methods, but only found {min_length}")
# Ensure all lists have the same length
methods = methods[:min_length]
recall_lists = recall_lists[:min_length]
recompute_lists = recompute_lists[:min_length]
avg_neighbors = avg_neighbors[:min_length]
return methods, recall_lists, recompute_lists, avg_neighbors
def plot_performance(methods, recall_lists, recompute_lists, avg_neighbors):
"""Create a plot comparing the performance of different HNSW methods."""
plt.figure(figsize=(12, 8))
colors = ['blue', 'green', 'red', 'purple', 'orange']
markers = ['o', 's', '^', 'x', 'd']
for i, method in enumerate(methods):
# Add average neighbors to the label
label = f"{method} (avg. {avg_neighbors[i]} neighbors)"
plt.plot(recompute_lists[i], recall_lists[i], label=label,
color=colors[i], marker=markers[i], markersize=8, markevery=5)
plt.xlabel('Distance Computations', fontsize=14)
plt.ylabel('Recall', fontsize=14)
plt.title('HNSW Index Performance: Recall vs. Computation Cost', fontsize=16)
plt.grid(True, linestyle='--', alpha=0.7)
plt.legend(fontsize=12)
plt.xscale('log')
plt.ylim(0, 1.0)
# Add horizontal lines for different recall levels
recall_levels = [0.90, 0.95, 0.96, 0.97, 0.98]
line_styles = [':', '--', '-.', '-', '-']
line_widths = [1, 1, 1, 1.5, 1.5]
for i, level in enumerate(recall_levels):
plt.axhline(y=level, color='gray', linestyle=line_styles[i],
alpha=0.7, linewidth=line_widths[i])
plt.text(130, level + 0.002, f'{level*100:.0f}% Recall', fontsize=10)
plt.tight_layout()
plt.savefig('faiss/demo/H_hnsw_performance_comparison.png')
plt.show()
def plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors):
"""Create a bar chart comparing computation costs at different recall levels."""
recall_levels = [0.90, 0.95, 0.96, 0.97, 0.98]
# Get computation costs for each method at each recall level
computation_costs = []
for i, method in enumerate(methods):
method_costs = []
for level in recall_levels:
# Find the first index where recall exceeds the target level
recall_idx = next((idx for idx, recall in enumerate(recall_lists[i]) if recall >= level), None)
if recall_idx is not None:
method_costs.append(recompute_lists[i][recall_idx])
else:
# If the method doesn't reach this recall level, use None
method_costs.append(None)
computation_costs.append(method_costs)
# Set up the plot
fig, ax = plt.subplots(figsize=(14, 8))
# Set width of bars
bar_width = 0.15
# Set positions of the bars on X axis
r = np.arange(len(recall_levels))
# Colors for each method
colors = ['blue', 'green', 'red', 'purple', 'orange']
# Create bars
for i, method in enumerate(methods):
# Filter out None values
valid_costs = [cost if cost is not None else 0 for cost in computation_costs[i]]
valid_positions = [pos for pos, cost in zip(r + i*bar_width, computation_costs[i]) if cost is not None]
valid_costs = [cost for cost in computation_costs[i] if cost is not None]
bars = ax.bar(valid_positions, valid_costs, width=bar_width,
color=colors[i], label=f"{method} (avg. {avg_neighbors[i]} neighbors)")
# Add value labels on top of bars
for j, bar in enumerate(bars):
height = bar.get_height()
ax.text(bar.get_x() + bar.get_width()/2., height + 500,
f'{height:.0f}', ha='center', va='bottom', rotation=0, fontsize=9)
# Add labels and title
ax.set_xlabel('Recall Level', fontsize=14)
ax.set_ylabel('Distance Computations', fontsize=14)
ax.set_title('Computation Cost Required to Achieve Different Recall Levels', fontsize=16)
# Set x-ticks
ax.set_xticks(r + bar_width * 2)
ax.set_xticklabels([f'{level*100:.0f}%' for level in recall_levels])
# Add legend
ax.legend(fontsize=12)
# Add grid
ax.grid(axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig('faiss/demo/H_hnsw_recall_comparison.png')
plt.show()
# Read the log file
with open('faiss/demo/output.log', 'r') as f:
log_content = f.read()
# Extract data
methods, recall_lists, recompute_lists, avg_neighbors = extract_data_from_log(log_content)
# Plot the results
plot_performance(methods, recall_lists, recompute_lists, avg_neighbors)
# Plot the recall comparison
plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors)
# Print a summary of the methods and their characteristics
print("\nMethod Summary:")
for i, method in enumerate(methods):
print(f"{method}:")
print(f" - Average neighbors per node: {avg_neighbors[i]:.2f}")
# Find the recompute values needed for different recall levels
recall_levels = [0.90, 0.95, 0.96, 0.97, 0.98]
for level in recall_levels:
recall_idx = next((idx for idx, recall in enumerate(recall_lists[i]) if recall >= level), None)
if recall_idx is not None:
print(f" - Computations needed for {level*100:.0f}% recall: {recompute_lists[i][recall_idx]:.2f}")
else:
print(f" - Does not reach {level*100:.0f}% recall in the test")
print()
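Because the recall and recompute lists are parsed out of free-form log text, eval on the matched strings is convenient but permissive; a safer drop-in (a sketch, not a change made to the script above) is ast.literal_eval, which only accepts Python literals:

import ast

recall_lists = [ast.literal_eval(s) for s in recall_lists_str]
recompute_lists = [ast.literal_eval(s) for s in recompute_lists_str]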


@@ -0,0 +1,202 @@
import os
import re
import matplotlib.pyplot as plt
import numpy as np
import argparse
RECALL_LEVELS = [0.85, 0.90, 0.93, 0.94, 0.95, 0.96]
def extract_data_from_log(log_content):
"""Extract method names, recall lists, and recompute lists from the log file."""
# Regular expressions to find the data - modified to match the actual format
method_pattern = r"Building HNSW index with (.*)\.\.\.|Building HNSW index with ([^\n]+)..."
recall_list_pattern = r"recall_list: (\[[\d\., ]+\])"
recompute_list_pattern = r"recompute_list: (\[[\d\., ]+\])"
avg_neighbors_pattern = r"neighbors per node: ([\d\.]+)"
# Find all matches
method_matches = re.findall(method_pattern, log_content)
methods = []
for match in method_matches:
# Each match is a tuple with one empty string and one with the method name
method = match[0] if match[0] else match[1]
methods.append(method)
recall_lists_str = re.findall(recall_list_pattern, log_content)
recompute_lists_str = re.findall(recompute_list_pattern, log_content)
avg_neighbors = re.findall(avg_neighbors_pattern, log_content)
# Debug information
print(f"Found {len(methods)} methods: {methods}")
print(f"Found {len(recall_lists_str)} recall lists")
print(f"Found {len(recompute_lists_str)} recompute lists")
print(f"Found {len(avg_neighbors)} average neighbors values")
# If the regex approach fails, try a more direct approach
if len(methods) < 5:
print("Regex approach failed, trying direct extraction...")
sections = log_content.split("Building HNSW index with ")[1:]
methods = []
for section in sections:
# Extract the method name (everything up to the first newline)
method_name = section.split("\n")[0].strip()
# Remove trailing dots if present
method_name = method_name.rstrip('.')
methods.append(method_name)
print(f"Direct extraction found {len(methods)} methods: {methods}")
# Convert string representations of lists to actual lists
recall_lists = [eval(recall_list) for recall_list in recall_lists_str]
recompute_lists = [eval(recompute_list) for recompute_list in recompute_lists_str]
# Convert average neighbors to float
avg_neighbors = [float(avg) for avg in avg_neighbors]
# Make sure we have the same number of items in each list
min_length = min(len(methods), len(recall_lists), len(recompute_lists), len(avg_neighbors))
if min_length < 5:
print(f"Warning: Expected 5 methods, but only found {min_length}")
# Ensure all lists have the same length
methods = methods[:min_length]
recall_lists = recall_lists[:min_length]
recompute_lists = recompute_lists[:min_length]
avg_neighbors = avg_neighbors[:min_length]
return methods, recall_lists, recompute_lists, avg_neighbors
def plot_performance(methods, recall_lists, recompute_lists, avg_neighbors):
"""Create a plot comparing the performance of different HNSW methods."""
plt.figure(figsize=(12, 8))
colors = ['blue', 'green', 'red', 'purple', 'orange']
markers = ['o', 's', '^', 'x', 'd']
for i, method in enumerate(methods):
# Add average neighbors to the label
label = f"{method} (avg. {avg_neighbors[i]} neighbors)"
plt.plot(recompute_lists[i], recall_lists[i], label=label,
color=colors[i], marker=markers[i], markersize=8, markevery=5)
plt.xlabel('Distance Computations', fontsize=14)
plt.ylabel('Recall', fontsize=14)
plt.title('HNSW Index Performance: Recall vs. Computation Cost', fontsize=16)
plt.grid(True, linestyle='--', alpha=0.7)
plt.legend(fontsize=12)
plt.xscale('log')
plt.ylim(0, 1.0)
# Add horizontal lines for different recall levels
line_styles = [':', '--', '-.', '-', '-', ':']
line_widths = [1, 1, 1, 1.5, 1.5, 1]
for i, level in enumerate(RECALL_LEVELS):
plt.axhline(y=level, color='gray', linestyle=line_styles[i],
alpha=0.7, linewidth=line_widths[i])
plt.text(130, level + 0.002, f'{level*100:.0f}% Recall', fontsize=10)
plt.tight_layout()
plt.savefig(os.path.join(os.path.dirname(__file__), 'H_hnsw_performance_comparison.png'))
plt.show()
def plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors):
"""Create a bar chart comparing computation costs at different recall levels."""
# Get computation costs for each method at each recall level
computation_costs = []
for i, method in enumerate(methods):
method_costs = []
for level in RECALL_LEVELS:
# Find the first index where recall exceeds the target level
recall_idx = next((idx for idx, recall in enumerate(recall_lists[i]) if recall >= level), None)
if recall_idx is not None:
method_costs.append(recompute_lists[i][recall_idx])
else:
# If the method doesn't reach this recall level, use None
method_costs.append(None)
computation_costs.append(method_costs)
# Set up the plot
fig, ax = plt.subplots(figsize=(14, 8))
# Set width of bars
bar_width = 0.15
# Set positions of the bars on X axis
r = np.arange(len(RECALL_LEVELS))
# Colors for each method
colors = ['blue', 'green', 'red', 'purple', 'orange']
# Create bars
for i, method in enumerate(methods):
# Filter out None values
valid_costs = [cost if cost is not None else 0 for cost in computation_costs[i]]
valid_positions = [pos for pos, cost in zip(r + i*bar_width, computation_costs[i]) if cost is not None]
valid_costs = [cost for cost in computation_costs[i] if cost is not None]
bars = ax.bar(valid_positions, valid_costs, width=bar_width,
color=colors[i], label=f"{method} (avg. {avg_neighbors[i]} neighbors)")
# Add value labels on top of bars
for j, bar in enumerate(bars):
height = bar.get_height()
ax.text(bar.get_x() + bar.get_width()/2., height + 500,
f'{height:.0f}', ha='center', va='bottom', rotation=0, fontsize=9)
# Add labels and title
ax.set_xlabel('Recall Level', fontsize=14)
ax.set_ylabel('Distance Computations', fontsize=14)
ax.set_title('Computation Cost Required to Achieve Different Recall Levels', fontsize=16)
# Set x-ticks
ax.set_xticks(r + bar_width * 2)
ax.set_xticklabels([f'{level*100:.0f}%' for level in RECALL_LEVELS])
# Add legend
ax.legend(fontsize=12)
# Add grid
ax.grid(axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig(os.path.join(os.path.dirname(__file__), 'H_hnsw_recall_comparison.png'))
plt.show()
# Read the log file
parser = argparse.ArgumentParser()
parser.add_argument('log_file', type=str, default='nlevel_output.log')
args = parser.parse_args()
log_file = args.log_file
log_file = os.path.join(os.path.dirname(__file__), log_file)
with open(log_file, 'r') as f:
log_content = f.read()
# Extract data
methods, recall_lists, recompute_lists, avg_neighbors = extract_data_from_log(log_content)
# Plot the results
plot_performance(methods, recall_lists, recompute_lists, avg_neighbors)
# Plot the recall comparison
plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors)
# Print a summary of the methods and their characteristics
print("\nMethod Summary:")
for i, method in enumerate(methods):
print(f"{method}:")
print(f" - Average neighbors per node: {avg_neighbors[i]:.2f}")
# Find the recompute values needed for different recall levels
for level in RECALL_LEVELS:
recall_idx = next((idx for idx, recall in enumerate(recall_lists[i]) if recall >= level), None)
if recall_idx is not None:
print(f" - Computations needed for {level*100:.0f}% recall: {recompute_lists[i][recall_idx]:.2f}")
else:
print(f" - Does not reach {level*100:.0f}% recall in the test")
print()


@@ -0,0 +1,329 @@
import sys
import time
import faiss
import numpy as np
import pickle
import os
import json
import torch
import argparse
from tqdm import tqdm
from pathlib import Path
import subprocess
# Add argument parsing
parser = argparse.ArgumentParser(description='Build and evaluate HNSW index')
parser.add_argument('--config', type=str, default="0.02per_M6-7_degree_based",
help='Configuration name for the index (default: 0.02per_M6-7_degree_based)')
parser.add_argument('--M', type=int, default=32,
help='HNSW M parameter (default: 32)')
parser.add_argument('--efConstruction', type=int, default=256,
help='HNSW efConstruction parameter (default: 256)')
parser.add_argument('--K_NEIGHBORS', type=int, default=3,
help='Number of neighbors to retrieve (default: 3)')
parser.add_argument('--max_queries', type=int, default=1000,
help='Maximum number of queries to load (default: 1000)')
parser.add_argument('--batch_size', type=int, default=64,
help='Batch size for query encoding (default: 64)')
parser.add_argument('--db_embedding_file', type=str,
default="/powerrag/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki_1M/1-shards/passages_00.pkl",
help='Path to database embedding file')
parser.add_argument('--index_saving_dir', type=str,
default="/powerrag/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki_1M/1-shards/indices",
help='Directory to save the index')
parser.add_argument('--task_name', type=str, default="nq",
help='Task name from TASK_CONFIGS (default: nq)')
parser.add_argument('--embedder_model', type=str, default="facebook/contriever-msmarco",
help='Model name for query embedding (default: facebook/contriever-msmarco)')
args = parser.parse_args()
# Replace hardcoded constants with arguments
M = args.M
efConstruction = args.efConstruction
K_NEIGHBORS = args.K_NEIGHBORS
DB_EMBEDDING_FILE = args.db_embedding_file
INDEX_SAVING_FILE = args.index_saving_dir
TASK_NAME = args.task_name
EMBEDDER_MODEL_NAME = args.embedder_model
MAX_QUERIES_TO_LOAD = args.max_queries
QUERY_ENCODING_BATCH_SIZE = args.batch_size
CONFIG_NAME = args.config
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
sys.path.append(os.path.join(project_root, "demo"))
from config import SCALING_OUT_DIR, get_example_path, TASK_CONFIGS
sys.path.append(project_root)
from contriever.src.contriever import Contriever, load_retriever
# 1M samples
print(f"Loading embeddings from {DB_EMBEDDING_FILE}...")
with open(DB_EMBEDDING_FILE, 'rb') as f:
data = pickle.load(f)
xb = data[1]
print(f"Original dtype: {xb.dtype}")
if xb.dtype != np.float32:
print("Converting embeddings to float32.")
xb = xb.astype(np.float32)
else:
print("Embeddings are already float32.")
print(f"Loaded database embeddings (xb), shape: {xb.shape}")
d = xb.shape[1] # Get dimension
query_file_path = TASK_CONFIGS[TASK_NAME].query_path
print(f"Using query path from TASK_CONFIGS: {query_file_path}")
query_texts = []
print(f"Reading queries from: {query_file_path}")
with open(query_file_path, 'r') as f:
for i, line in enumerate(f):
if i >= MAX_QUERIES_TO_LOAD:
print(f"Stopped loading queries at limit: {MAX_QUERIES_TO_LOAD}")
break
record = json.loads(line)
query_texts.append(record["query"])
print(f"Loaded {len(query_texts)} query texts.")
print("\nInitializing retriever model for encoding queries...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model, tokenizer, _ = load_retriever(EMBEDDER_MODEL_NAME)
model.to(device)
model.eval() # Set to evaluation mode
print("Retriever model loaded.")
def embed_queries(queries, model, tokenizer, model_name_or_path, per_gpu_batch_size=64):
"""Embed queries using the model with batching"""
model = model.half()
model.eval()
embeddings = []
batch_question = []
with torch.no_grad():
for k, query in tqdm(enumerate(queries), desc="Encoding queries"):
batch_question.append(query)
# Process when batch is full or at the end
if len(batch_question) == per_gpu_batch_size or k == len(queries) - 1:
encoded_batch = tokenizer.batch_encode_plus(
batch_question,
return_tensors="pt",
max_length=512,
padding=True,
truncation=True,
)
encoded_batch = {k: v.to(device) for k, v in encoded_batch.items()}
output = model(**encoded_batch)
# Contriever's forward pass already returns a pooled embedding; for other models take the [CLS] token
if "contriever" not in model_name_or_path:
output = output.last_hidden_state[:, 0, :]
embeddings.append(output.cpu())
batch_question = [] # Reset batch
embeddings = torch.cat(embeddings, dim=0).numpy()
print(f"Query embeddings shape: {embeddings.shape}")
return embeddings
print(f"\nEncoding {len(query_texts)} queries (batch size: {QUERY_ENCODING_BATCH_SIZE})...")
xq_full = embed_queries(query_texts, model, tokenizer, EMBEDDER_MODEL_NAME, per_gpu_batch_size=QUERY_ENCODING_BATCH_SIZE)
# Ensure float32 for Faiss compatibility after encoding
if xq_full.dtype != np.float32:
print(f"Converting encoded queries from {xq_full.dtype} to float32.")
xq_full = xq_full.astype(np.float32)
print(f"Encoded queries (xq_full), shape: {xq_full.shape}, dtype: {xq_full.dtype}")
# Check dimension consistency
if xq_full.shape[1] != d:
raise ValueError(f"Query embedding dimension ({xq_full.shape[1]}) does not match database dimension ({d})")
# recall_idx = []
print("\nBuilding FlatIP index for ground truth...")
index_flat = faiss.IndexFlatIP(d) # Use Inner Product
index_flat.add(xb)
print(f"Searching FlatIP index with {MAX_QUERIES_TO_LOAD} queries (k={K_NEIGHBORS})...")
D_flat, recall_idx_flat = index_flat.search(xq_full, k=K_NEIGHBORS)
# print(recall_idx_flat)
# Create a specific directory for this index configuration
index_dir = f"{INDEX_SAVING_FILE}/{CONFIG_NAME}_hnsw_IP_M{M}_efC{efConstruction}"
if CONFIG_NAME == "origin":
index_dir = f"{INDEX_SAVING_FILE}/hnsw_IP_M{M}_efC{efConstruction}"
os.makedirs(index_dir, exist_ok=True)
index_filename = f"{index_dir}/index.faiss"
# Check if index already exists
if os.path.exists(index_filename):
print(f"Found existing index at {index_filename}, loading...")
index = faiss.read_index(index_filename)
print("Index loaded successfully.")
else:
print('Building HNSW index (IP)...')
# add build time
start_time = time.time()
index = faiss.IndexHNSWFlat(d, M, faiss.METRIC_INNER_PRODUCT)
index.hnsw.efConstruction = efConstruction
index.hnsw.set_percentile_thresholds()
index.add(xb)
end_time = time.time()
print(f'time: {end_time - start_time}')
print('HNSW index built.')
# Save the HNSW index
print(f"Saving index to {index_filename}...")
faiss.write_index(index, index_filename)
print("Index saved successfully.")
# Analyze the HNSW index
print("\nAnalyzing HNSW index...")
print(f"Total number of nodes: {index.ntotal}")
print("Neighbor statistics:")
print(index.hnsw.print_neighbor_stats(0))
# Save degree distribution
distribution_filename = f"{index_dir}/degree_distribution.txt"
print(f"Saving degree distribution to {distribution_filename}...")
index.hnsw.save_degree_distribution(0, distribution_filename)
print("Degree distribution saved successfully.")
# Plot the degree distribution
plot_output_path = f"{index_dir}/degree_distribution.png"
print(f"Generating degree distribution plot to {plot_output_path}...")
try:
subprocess.run(
["python", "/home/ubuntu/Power-RAG/utils/plot_degree_distribution.py", distribution_filename, "-o", plot_output_path],
check=True
)
print(f"Degree distribution plot saved to {plot_output_path}")
except subprocess.CalledProcessError as e:
print(f"Error generating degree distribution plot: {e}")
except FileNotFoundError:
print("Warning: plot_degree_distribution.py script not found in current directory")
print('Searching HNSW index...')
# for efSearch in [2, 4, 8, 16, 32, 64,128,256,512,1024]:
# print(f'*************efSearch: {efSearch}*************')
# for i in range(10):
# index.hnsw.efSearch = efSearch
# D, I = index.search(xq_full[i:i+1], K_NEIGHBORS)
# exit()
recall_result_file = f"{index_dir}/recall_result.txt"
time_list = []
recall_list = []
recompute_list = []
with open(recall_result_file, 'w') as f:
for efSearch in [2, 4, 8, 16, 24, 32, 48, 64, 96, 114, 128, 144, 160, 176, 192, 208, 224, 240, 256, 384, 420, 440, 460, 480, 512, 768, 1024, 1152, 1536, 1792, 2048, 2230, 2408, 2880]:
index.hnsw.efSearch = efSearch
# calculate the time of searching
start_time = time.time()
faiss.cvar.hnsw_stats.reset()
# print faiss.cvar.hnsw_stats.ndis
print(f'ndis: {faiss.cvar.hnsw_stats.ndis}')
D, I = index.search(xq_full, K_NEIGHBORS)
print('D[0]:', D[0])
end_time = time.time()
print(f'time: {end_time - start_time}')
time_list.append(end_time - start_time)
print("recompute:", faiss.cvar.hnsw_stats.ndis/len(I))
recompute_list.append(faiss.cvar.hnsw_stats.ndis/len(I))
# print(I)
# Recall against the flat-index ground truth: for each query, the fraction of
# retrieved ids that appear in the flat top-k, averaged over all queries
recall=[]
for i in range(len(I)):
acc = 0
for j in range(len(I[i])):
if I[i][j] in recall_idx_flat[i]:
acc += 1
recall.append(acc / len(I[i]))
recall = sum(recall) / len(recall)
recall_list.append(recall)
print(f'efSearch: {efSearch}')
print(f'recall: {recall}')
f.write(f'efSearch: {efSearch}, recall: {recall}\n')
print(f'Done and result saved to {recall_result_file}')
print(f'time_list: {time_list}')
print(f'recall_list: {recall_list}')
print(f'recompute_list: {recompute_list}')
exit()  # stop here; the edge-stats analysis below is currently disabled
# Analyze edge stats
print("\nAnalyzing edge statistics...")
edge_stats_file = f"{index_dir}/edge_stats.txt"
if not os.path.exists(edge_stats_file):
index.save_edge_stats(edge_stats_file)
print(f'Edge stats saved to {edge_stats_file}')
else:
print(f'Edge stats already exists at {edge_stats_file}')
def analyze_edge_stats(filename):
"""
Analyze edge statistics from a CSV file and print thresholds at various percentiles.
Args:
filename: Path to the edge statistics CSV file
"""
if not os.path.exists(filename):
print(f"Error: File {filename} does not exist")
return
print(f"Analyzing edge statistics from {filename}...")
# Read the file
distances = []
with open(filename, 'r') as f:
# Skip header
header = f.readline()
# Read all edges
for line in f:
parts = line.strip().split(',')
if len(parts) >= 4:
try:
src = int(parts[0])
dst = int(parts[1])
level = int(parts[2])
distance = float(parts[3])
distances.append(distance)
except ValueError:
continue
if not distances:
print("No valid edges found in file")
return
# Sort distances
distances = np.array(distances)
distances.sort()
# Calculate and print statistics
print(f"Total edges: {len(distances)}")
print(f"Min distance: {distances[0]:.6f}")
print(f"Max distance: {distances[-1]:.6f}")
# Print thresholds at specified percentiles
percentiles = [0.5, 1, 2, 3, 5, 8, 10, 15, 20, 30, 40, 50, 60, 70]
print("\nDistance thresholds at percentiles:")
for p in percentiles:
idx = int(len(distances) * p / 100)
if idx < len(distances):
print(f"{p:.1f}%: {distances[idx]:.6f}")
return distances
analyze_edge_stats(edge_stats_file)
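The percentile thresholds printed at the end of analyze_edge_stats could also be computed with numpy's percentile routine; a minimal sketch under the same assumptions (distances is a 1-D float array), noting that np.percentile interpolates, so values may differ slightly from the index-based lookup above:

import numpy as np

percentiles = [0.5, 1, 2, 3, 5, 8, 10, 15, 20, 30, 40, 50, 60, 70]
for p, t in zip(percentiles, np.percentile(distances, percentiles)):
    print(f"{p:.1f}%: {t:.6f}")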


@@ -0,0 +1,212 @@
import sys
import time
import faiss
import numpy as np
import pickle
import os
import json
import torch
from tqdm import tqdm
from pathlib import Path
import subprocess
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
sys.path.append(os.path.join(project_root, "demo"))
from config import SCALING_OUT_DIR, get_example_path, TASK_CONFIGS, get_embedding_path
sys.path.append(project_root)
from contriever.src.contriever import Contriever, load_retriever
M = 32
efConstruction = 256
K_NEIGHBORS = 3
# Original configuration (commented out)
# DB_EMBEDDING_FILE = "/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/passages_00.pkl"
# INDEX_SAVING_FILE = "/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices"
# New configuration using DPR
DOMAIN_NAME = "dpr"
EMBEDDER_NAME = "facebook/contriever-msmarco"
TASK_NAME = "nq"
MAX_QUERIES_TO_LOAD = 1000
QUERY_ENCODING_BATCH_SIZE = 64
# Get the embedding path using the function from config
embed_path = get_embedding_path(DOMAIN_NAME, EMBEDDER_NAME, 0)
INDEX_SAVING_FILE = os.path.join(os.path.dirname(embed_path), "indices")
os.makedirs(INDEX_SAVING_FILE, exist_ok=True)
# Load embeddings
print(f"Loading embeddings from {embed_path}...")
with open(embed_path, 'rb') as f:
data = pickle.load(f)
xb = data[1]
print(f"Original dtype: {xb.dtype}")
if xb.dtype != np.float32:
print("Converting embeddings to float32.")
xb = xb.astype(np.float32)
else:
print("Embeddings are already float32.")
print(f"Loaded database embeddings (xb), shape: {xb.shape}")
d = xb.shape[1] # Get dimension
query_file_path = TASK_CONFIGS[TASK_NAME].query_path
print(f"Using query path from TASK_CONFIGS: {query_file_path}")
query_texts = []
print(f"Reading queries from: {query_file_path}")
with open(query_file_path, 'r') as f:
for i, line in enumerate(f):
if i >= MAX_QUERIES_TO_LOAD:
print(f"Stopped loading queries at limit: {MAX_QUERIES_TO_LOAD}")
break
record = json.loads(line)
query_texts.append(record["query"])
print(f"Loaded {len(query_texts)} query texts.")
print("\nInitializing retriever model for encoding queries...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model, tokenizer, _ = load_retriever(EMBEDDER_NAME)
model.to(device)
model.eval() # Set to evaluation mode
print("Retriever model loaded.")
def embed_queries(queries, model, tokenizer, model_name_or_path, per_gpu_batch_size=64):
"""Embed queries using the model with batching"""
model = model.half()
model.eval()
embeddings = []
batch_question = []
with torch.no_grad():
for k, query in tqdm(enumerate(queries), desc="Encoding queries"):
batch_question.append(query)
# Process when batch is full or at the end
if len(batch_question) == per_gpu_batch_size or k == len(queries) - 1:
encoded_batch = tokenizer.batch_encode_plus(
batch_question,
return_tensors="pt",
max_length=512,
padding=True,
truncation=True,
)
encoded_batch = {k: v.to(device) for k, v in encoded_batch.items()}
output = model(**encoded_batch)
# Contriever's forward pass already returns a pooled embedding; for other models take the [CLS] token
if "contriever" not in model_name_or_path:
output = output.last_hidden_state[:, 0, :]
embeddings.append(output.cpu())
batch_question = [] # Reset batch
embeddings = torch.cat(embeddings, dim=0).numpy()
print(f"Query embeddings shape: {embeddings.shape}")
return embeddings
print(f"\nEncoding {len(query_texts)} queries (batch size: {QUERY_ENCODING_BATCH_SIZE})...")
xq_full = embed_queries(query_texts, model, tokenizer, EMBEDDER_NAME, per_gpu_batch_size=QUERY_ENCODING_BATCH_SIZE)
# Ensure float32 for Faiss compatibility after encoding
if xq_full.dtype != np.float32:
print(f"Converting encoded queries from {xq_full.dtype} to float32.")
xq_full = xq_full.astype(np.float32)
print(f"Encoded queries (xq_full), shape: {xq_full.shape}, dtype: {xq_full.dtype}")
# Check dimension consistency
if xq_full.shape[1] != d:
raise ValueError(f"Query embedding dimension ({xq_full.shape[1]}) does not match database dimension ({d})")
# Build flat index for ground truth
print("\nBuilding FlatIP index for ground truth...")
index_flat = faiss.IndexFlatIP(d) # Use Inner Product
index_flat.add(xb)
print(f"Searching FlatIP index with {len(xq_full)} queries (k={K_NEIGHBORS})...")
D_flat, recall_idx_flat = index_flat.search(xq_full, k=K_NEIGHBORS)
# Create a specific directory for this index configuration
index_dir = f"{INDEX_SAVING_FILE}/hahahdpr_hnsw_IP_M{M}_efC{efConstruction}"
os.makedirs(index_dir, exist_ok=True)
index_filename = f"{index_dir}/index.faiss"
# Check if index already exists
if os.path.exists(index_filename):
print(f"Found existing index at {index_filename}, loading...")
index = faiss.read_index(index_filename)
print("Index loaded successfully.")
else:
print('Building HNSW index (IP)...')
# add build time
start_time = time.time()
index = faiss.IndexHNSWFlat(d, M, faiss.METRIC_INNER_PRODUCT)
index.hnsw.efConstruction = efConstruction
index.add(xb)
end_time = time.time()
print(f'time: {end_time - start_time}')
print('HNSW index built.')
# Save the HNSW index
print(f"Saving index to {index_filename}...")
faiss.write_index(index, index_filename)
print("Index saved successfully.")
# Analyze the HNSW index
print("\nAnalyzing HNSW index...")
print(f"Total number of nodes: {index.ntotal}")
print("Neighbor statistics:")
print(index.hnsw.print_neighbor_stats(0))
# Save degree distribution
distribution_filename = f"{index_dir}/degree_distribution.txt"
print(f"Saving degree distribution to {distribution_filename}...")
index.hnsw.save_degree_distribution(0, distribution_filename)
print("Degree distribution saved successfully.")
# Plot the degree distribution
plot_output_path = f"{index_dir}/degree_distribution.png"
print(f"Generating degree distribution plot to {plot_output_path}...")
try:
subprocess.run(
["python", f"{project_root}/utils/plot_degree_distribution.py", distribution_filename, "-o", plot_output_path],
check=True
)
print(f"Degree distribution plot saved to {plot_output_path}")
except subprocess.CalledProcessError as e:
print(f"Error generating degree distribution plot: {e}")
except FileNotFoundError:
print("Warning: plot_degree_distribution.py script not found in specified path")
print('Searching HNSW index...')
recall_result_file = f"{index_dir}/recall_result.txt"
with open(recall_result_file, 'w') as f:
for efSearch in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
index.hnsw.efSearch = efSearch
# calculate the time of searching
start_time = time.time()
D, I = index.search(xq_full, K_NEIGHBORS)
end_time = time.time()
print(f'time: {end_time - start_time}')
# calculate the recall using the flat index
recall = []
for i in range(len(I)):
acc = 0
for j in range(len(I[i])):
if I[i][j] in recall_idx_flat[i]:
acc += 1
recall.append(acc / len(I[i]))
recall = sum(recall) / len(recall)
print(f'efSearch: {efSearch}')
print(f'recall: {recall}')
f.write(f'efSearch: {efSearch}, recall: {recall}\n')
print(f'Done and result saved to {recall_result_file}')
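The recall_result.txt written above contains one "efSearch: X, recall: Y" line per setting; a minimal parsing sketch (the path is whatever recall_result_file pointed to when the script ran):

import re

points = []
with open("recall_result.txt") as f:
    for line in f:
        m = re.match(r"efSearch: (\d+), recall: ([\d.]+)", line)
        if m:
            points.append((int(m.group(1)), float(m.group(2))))
print(points)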


@@ -0,0 +1,212 @@
import sys
import time
import faiss
import numpy as np
import pickle
import os
import json
import torch
from tqdm import tqdm
from pathlib import Path
import subprocess
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
sys.path.append(os.path.join(project_root, "demo"))
from config import SCALING_OUT_DIR, get_example_path, TASK_CONFIGS, get_embedding_path
sys.path.append(project_root)
from contriever.src.contriever import Contriever, load_retriever
M = 20
efConstruction = 256
K_NEIGHBORS = 3
# Configuration for the rpj_wiki subset
DOMAIN_NAME = "rpj_wiki"
EMBEDDER_NAME = "facebook/contriever-msmarco"
TASK_NAME = "nq"
MAX_QUERIES_TO_LOAD = 1000
QUERY_ENCODING_BATCH_SIZE = 64
# Alternative: derive the paths via the config helpers (currently disabled)
# embed_path = get_embedding_path(DOMAIN_NAME, EMBEDDER_NAME, 0)
# INDEX_SAVING_FILE = os.path.join(os.path.dirname(embed_path), "indices")
# os.makedirs(INDEX_SAVING_FILE, exist_ok=True)
# Hard-coded paths for the 1M-passage rpj_wiki subset
embed_path = "/powerrag/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki_1M/1-shards/passages_00.pkl"
INDEX_SAVING_FILE = "/powerrag/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki_1M/1-shards/indices"
# Load embeddings
print(f"Loading embeddings from {extend_path}...")
with open(embed_path, 'rb') as f:
data = pickle.load(f)
xb = data[1]
print(f"Original dtype: {xb.dtype}")
if xb.dtype != np.float32:
print("Converting embeddings to float32.")
xb = xb.astype(np.float32)
else:
print("Embeddings are already float32.")
print(f"Loaded database embeddings (xb), shape: {xb.shape}")
d = xb.shape[1] # Get dimension
query_file_path = TASK_CONFIGS[TASK_NAME].query_path
print(f"Using query path from TASK_CONFIGS: {query_file_path}")
query_texts = []
print(f"Reading queries from: {query_file_path}")
with open(query_file_path, 'r') as f:
for i, line in enumerate(f):
if i >= MAX_QUERIES_TO_LOAD:
print(f"Stopped loading queries at limit: {MAX_QUERIES_TO_LOAD}")
break
record = json.loads(line)
query_texts.append(record["query"])
print(f"Loaded {len(query_texts)} query texts.")
print("\nInitializing retriever model for encoding queries...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model, tokenizer, _ = load_retriever(EMBEDDER_NAME)
model.to(device)
model.eval() # Set to evaluation mode
print("Retriever model loaded.")
def embed_queries(queries, model, tokenizer, model_name_or_path, per_gpu_batch_size=64):
"""Embed queries using the model with batching"""
model = model.half()
model.eval()
embeddings = []
batch_question = []
with torch.no_grad():
for k, query in tqdm(enumerate(queries), desc="Encoding queries"):
batch_question.append(query)
# Process when batch is full or at the end
if len(batch_question) == per_gpu_batch_size or k == len(queries) - 1:
encoded_batch = tokenizer.batch_encode_plus(
batch_question,
return_tensors="pt",
max_length=512,
padding=True,
truncation=True,
)
encoded_batch = {k: v.to(device) for k, v in encoded_batch.items()}
output = model(**encoded_batch)
# The Contriever wrapper already returns pooled sentence embeddings; for plain HF encoders, fall back to the [CLS] token
if "contriever" not in model_name_or_path:
output = output.last_hidden_state[:, 0, :]
embeddings.append(output.cpu())
batch_question = [] # Reset batch
embeddings = torch.cat(embeddings, dim=0).numpy()
print(f"Query embeddings shape: {embeddings.shape}")
return embeddings
print(f"\nEncoding {len(query_texts)} queries (batch size: {QUERY_ENCODING_BATCH_SIZE})...")
xq_full = embed_queries(query_texts, model, tokenizer, EMBEDDER_NAME, per_gpu_batch_size=QUERY_ENCODING_BATCH_SIZE)
# Ensure float32 for Faiss compatibility after encoding
if xq_full.dtype != np.float32:
print(f"Converting encoded queries from {xq_full.dtype} to float32.")
xq_full = xq_full.astype(np.float32)
print(f"Encoded queries (xq_full), shape: {xq_full.shape}, dtype: {xq_full.dtype}")
# Check dimension consistency
if xq_full.shape[1] != d:
raise ValueError(f"Query embedding dimension ({xq_full.shape[1]}) does not match database dimension ({d})")
# Build flat index for ground truth
print("\nBuilding FlatIP index for ground truth...")
index_flat = faiss.IndexFlatIP(d) # Use Inner Product
index_flat.add(xb)
print(f"Searching FlatIP index with {len(xq_full)} queries (k={K_NEIGHBORS})...")
D_flat, recall_idx_flat = index_flat.search(xq_full, k=K_NEIGHBORS)
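# The exact FlatIP top-K ids (recall_idx_flat) serve as ground truth for the approximate-search recall below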
# Create a specific directory for this index configuration
index_dir = f"{INDEX_SAVING_FILE}/rpj_wiki_nsg_IP_M{M}"
os.makedirs(index_dir, exist_ok=True)
index_filename = f"{index_dir}/index.faiss"
# Check if index already exists
if os.path.exists(index_filename):
print(f"Found existing index at {index_filename}, loading...")
index = faiss.read_index(index_filename)
print("Index loaded successfully.")
else:
print('Building NSG index (IP)...')
# Time the graph construction
start_time = time.time()
index = faiss.IndexNSGFlat(d, M, faiss.METRIC_INNER_PRODUCT)
index.verbose = True
index.add(xb)
end_time = time.time()
print(f'time: {end_time - start_time}')
print('NSG index built.')
# Save the NSG index
print(f"Saving index to {index_filename}...")
faiss.write_index(index, index_filename)
print("Index saved successfully.")
# Analyze the NSG index
print("\nAnalyzing NSG index...")
print(f"Total number of nodes: {index.ntotal}")
print("Neighbor statistics:")
print(index.nsg.print_neighbor_stats(0))
# Save degree distribution
distribution_filename = f"{index_dir}/degree_distribution.txt"
print(f"Saving degree distribution to {distribution_filename}...")
index.nsg.save_degree_distribution(distribution_filename)
print("Degree distribution saved successfully.")
# Plot the degree distribution
plot_output_path = f"{index_dir}/degree_distribution.png"
print(f"Generating degree distribution plot to {plot_output_path}...")
try:
subprocess.run(
["python", f"{project_root}/utils/plot_degree_distribution.py", distribution_filename, "-o", plot_output_path],
check=True
)
print(f"Degree distribution plot saved to {plot_output_path}")
except subprocess.CalledProcessError as e:
print(f"Error generating degree distribution plot: {e}")
except FileNotFoundError:
print("Warning: plot_degree_distribution.py script not found in specified path")
print('Searching NSG index...')
recall_result_file = f"{index_dir}/recall_result.txt"
with open(recall_result_file, 'w') as f:
for efSearch in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
index.nsg.efSearch = efSearch
# Time the search pass for this efSearch setting
start_time = time.time()
D, I = index.search(xq_full, K_NEIGHBORS)
end_time = time.time()
print(f'time: {end_time - start_time}')
# Compute recall@K against the exact FlatIP ground truth
recall = []
for i in range(len(I)):
acc = 0
for j in range(len(I[i])):
if I[i][j] in recall_idx_flat[i]:
acc += 1
recall.append(acc / len(I[i]))
recall = sum(recall) / len(recall)
print(f'efSearch: {efSearch}')
print(f'recall: {recall}')
f.write(f'efSearch: {efSearch}, recall: {recall}\n')
print(f'Done and result saved to {recall_result_file}')

View File

@@ -0,0 +1,22 @@
set -l index_dirs \
/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/hnsw_IP_M30_efC128.index \
/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/99_4_degree_based_hnsw_IP_M32_efC256/index.faiss \
/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/d9_hnsw_IP_M8_efC128/index.faiss \
/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/half_edges_IP_M32_efC128/index.faiss
# /opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/nsg_R16.index
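# Labels below must stay index-aligned with $index_dirs above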
set -l index_labels \
origin \
0.01per_M4_degree_based \
M8_merge_edge \
random_delete50
# nsg_R16
set -gx CUDA_VISIBLE_DEVICES 3
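# fish lists are 1-indexed; seq (count $index_dirs) keeps $index_dirs[$i] and $index_labels[$i] paired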
for i in (seq (count $index_dirs))
set -l index_file $index_dirs[$i]
set -l index_label $index_labels[$i]
echo "Building HNSW index with $index_label..." >> ./large_graph_simple_build.log
python -u large_graph_simple_build.py --index-file $index_file | tee -a ./large_graph_simple_build.log
end