Initial commit
BIN
packages/leann-backend-hnsw/third_party/faiss/demo/H_hnsw_performance_comparison.png
vendored
Normal file
Binary file not shown.
Size: 134 KiB
BIN
packages/leann-backend-hnsw/third_party/faiss/demo/H_hnsw_recall_comparison.png
vendored
Normal file
Binary file not shown.
Size: 89 KiB
67
packages/leann-backend-hnsw/third_party/faiss/demo/build_demo.py
vendored
Normal file
@@ -0,0 +1,67 @@
import os
import pickle
import faiss
import numpy as np
import time

EMBEDDING_FILE = "/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/passages_00.pkl"
INDEX_OUTPUT_DIR = "/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw"  # directory where the indices are saved
M_VALUES_FOR_L2 = [30, 60]  # M values to sweep (note: the index below uses inner product, despite the variable name)
EF_CONSTRUCTION_FOR_L2 = 128  # fixed efConstruction

if not os.path.exists(INDEX_OUTPUT_DIR):
    print(f"Creating index directory: {INDEX_OUTPUT_DIR}")
    os.makedirs(INDEX_OUTPUT_DIR)

print(f"Loading embeddings from {EMBEDDING_FILE}...")
with open(EMBEDDING_FILE, 'rb') as f:
    data = pickle.load(f)
# Directly assume data is a tuple and the second element is embeddings
embeddings = data[1]

print(f"Converting embeddings from {embeddings.dtype} to float32.")
embeddings = embeddings.astype(np.float32)
print(f"Loaded embeddings, shape: {embeddings.shape}")
dim = embeddings.shape[1]

# --- Build HNSW inner-product (IP) index ---
print("\n--- Build HNSW IP index ---")

# Loop through M values
for HNSW_M in M_VALUES_FOR_L2:
    efConstruction = EF_CONSTRUCTION_FOR_L2

    print(f"\nBuilding HNSW IP index: M={HNSW_M}, efConstruction={efConstruction}...")

    # Define the filename and path for the index
    hnsw_filename = f"hnsw_IP_M{HNSW_M}_efC{efConstruction}.index"
    hnsw_filepath = os.path.join(INDEX_OUTPUT_DIR, hnsw_filename)

    # Note: the file is not checked for existence; it is overwritten if it already exists

    # Create the HNSW index with the inner-product metric
    index_hnsw = faiss.IndexHNSWFlat(dim, HNSW_M, faiss.METRIC_INNER_PRODUCT)
    index_hnsw.hnsw.efConstruction = efConstruction

    index_hnsw.verbose = True

    print(f"Adding {embeddings.shape[0]} vectors to HNSW IP (M={HNSW_M}) index...")
    start_time_build = time.time()

    index_hnsw.add(embeddings)

    end_time_build = time.time()
    build_time_s = end_time_build - start_time_build
    print(f"HNSW IP build time: {build_time_s:.4f} seconds")

    # Save the index (direct operation, no try-except)
    print(f"Saving HNSW IP index to {hnsw_filepath}")
    faiss.write_index(index_hnsw, hnsw_filepath)
    # Do not check storage size or handle save errors

    print(f"Index {hnsw_filename} saved.")

    del index_hnsw

print("\n--- HNSW IP index build completed ---")
print("\nScript ended.")
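A minimal usage sketch for the indices written by build_demo.py: it assumes one of the files above exists on disk and that the query vectors are float32 embeddings of the same dimension; the file name, efSearch value, and random queries are illustrative only, not part of the commit.

# Hypothetical quick check of a saved index (path and efSearch are assumptions).
import faiss
import numpy as np

index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/hnsw_IP_M30_efC128.index")
index.hnsw.efSearch = 64                                   # search-time breadth; tune for recall vs. latency
queries = np.random.rand(5, index.d).astype(np.float32)    # stand-in for real query embeddings
D, I = index.search(queries, 10)                           # inner-product scores and neighbor ids
print(D.shape, I.shape)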
250
packages/leann-backend-hnsw/third_party/faiss/demo/build_demo_sample.py
vendored
Normal file
@@ -0,0 +1,250 @@
import os
import pickle
import numpy as np
import time
import math
from pathlib import Path

# --- Configuration ---
DOMAIN_NAME = "rpj_wiki"  # Domain name used for finding passages
EMBEDDER_NAME = "facebook/contriever-msmarco"  # Used in paths
ORIGINAL_EMBEDDING_SHARD_ID = 0  # The shard ID of the embedding file we are loading

# Define the base directory
SCALING_OUT_DIR = Path("/powerrag/scaling_out").resolve()

# Original Data Paths (using functions similar to your utils)
# Assuming embeddings for rpj_wiki are in a single file despite passage sharding
# Adjust NUM_SHARDS_EMBEDDING if embeddings are also sharded
NUM_SHARDS_EMBEDDING = 1
ORIGINAL_EMBEDDING_FILE_TEMPLATE = (
    SCALING_OUT_DIR
    / "embeddings/{embedder_name}/{domain_name}/{total_shards}-shards/passages_{shard_id:02d}.pkl"
)
ORIGINAL_EMBEDDING_FILE = str(ORIGINAL_EMBEDDING_FILE_TEMPLATE).format(
    embedder_name=EMBEDDER_NAME,
    domain_name=DOMAIN_NAME,
    total_shards=NUM_SHARDS_EMBEDDING,
    shard_id=ORIGINAL_EMBEDDING_SHARD_ID,
)

# Passage Paths
NUM_SHARDS_PASSAGE = 8  # As specified in your original utils (NUM_SHARDS['rpj_wiki'])
ORIGINAL_PASSAGE_FILE_TEMPLATE = (
    SCALING_OUT_DIR
    / "passages/{domain_name}/{total_shards}-shards/raw_passages-{shard_id}-of-{total_shards}.pkl"
)

# New identifier for the sampled dataset
NEW_DATASET_NAME = "rpj_wiki_1M"

# Fraction to sample (1/60)
SAMPLE_FRACTION = 1 / 60

# Output Paths for the new sampled dataset
OUTPUT_EMBEDDING_DIR = SCALING_OUT_DIR / "embeddings" / EMBEDDER_NAME / NEW_DATASET_NAME / "1-shards"
OUTPUT_PASSAGE_DIR = SCALING_OUT_DIR / "passages" / NEW_DATASET_NAME / "1-shards"

OUTPUT_EMBEDDING_FILE = OUTPUT_EMBEDDING_DIR / f"passages_{ORIGINAL_EMBEDDING_SHARD_ID:02d}.pkl"
# The new passage file represents the *single* shard of the sampled data
OUTPUT_PASSAGE_FILE = OUTPUT_PASSAGE_DIR / "raw_passages-0-of-1.pkl"

# --- Directory Setup ---
print("Creating output directories if they don't exist...")
OUTPUT_EMBEDDING_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_PASSAGE_DIR.mkdir(parents=True, exist_ok=True)
print(f"Embeddings output dir: {OUTPUT_EMBEDDING_DIR}")
print(f"Passages output dir: {OUTPUT_PASSAGE_DIR}")


# --- Helper Function to Load Passages ---
def load_all_passages(domain_name, num_shards, template):
    """Loads all passage shards and creates an ID-to-content map."""
    all_passages_list = []
    passage_id_to_content_map = {}
    print(f"Loading passages for domain '{domain_name}' from {num_shards} shards...")
    total_loaded = 0
    start_time = time.time()

    for shard_id in range(num_shards):
        shard_path_str = str(template).format(
            domain_name=domain_name,
            total_shards=num_shards,
            shard_id=shard_id,
        )
        shard_path = Path(shard_path_str)

        if not shard_path.exists():
            print(f"Warning: Passage shard file not found, skipping: {shard_path}")
            continue

        try:
            print(f"  Loading shard {shard_id} from {shard_path}...")
            with open(shard_path, 'rb') as f:
                shard_passages = pickle.load(f)  # Expected: list of dicts
            if not isinstance(shard_passages, list):
                print(f"Warning: Shard {shard_id} data is not a list.")
                continue

            all_passages_list.extend(shard_passages)
            # Build the map, ensuring IDs are strings for consistent lookup
            for passage_dict in shard_passages:
                if 'id' in passage_dict:
                    passage_id_to_content_map[str(passage_dict['id'])] = passage_dict
                else:
                    print(f"Warning: Passage dict in shard {shard_id} missing 'id' key.")
            print(f"  Loaded {len(shard_passages)} passages from shard {shard_id}.")
            total_loaded += len(shard_passages)

        except Exception as e:
            print(f"Error loading passage shard {shard_id} from {shard_path}: {e}")

    load_time = time.time() - start_time
    print(f"Finished loading passages. Total passages loaded: {total_loaded} in {load_time:.2f} seconds.")
    print(f"Total unique passages mapped by ID: {len(passage_id_to_content_map)}")
    return all_passages_list, passage_id_to_content_map


# --- Load Original Embeddings ---
print(f"\nLoading original embeddings from {ORIGINAL_EMBEDDING_FILE}...")
start_load_time = time.time()
try:
    with open(ORIGINAL_EMBEDDING_FILE, 'rb') as f:
        original_embedding_data = pickle.load(f)
except FileNotFoundError:
    print(f"Error: Original embedding file not found at {ORIGINAL_EMBEDDING_FILE}")
    exit(1)
except Exception as e:
    print(f"Error loading embedding pickle file: {e}")
    exit(1)
load_time = time.time() - start_load_time
print(f"Loaded original embeddings data in {load_time:.2f} seconds.")

# --- Extract and Validate Embeddings ---
try:
    if not isinstance(original_embedding_data, (list, tuple)) or len(original_embedding_data) != 2:
        raise TypeError("Expected embedding data to be a list or tuple of length 2 (ids, embeddings)")

    original_embedding_ids = original_embedding_data[0]  # Should be a list/iterable of IDs
    original_embeddings = original_embedding_data[1]  # Should be a NumPy array

    # Ensure IDs are in a list for easier indexing later if they aren't already
    if not isinstance(original_embedding_ids, list):
        print("Converting embedding IDs to list...")
        original_embedding_ids = list(original_embedding_ids)

    if not isinstance(original_embeddings, np.ndarray):
        raise TypeError("Expected second element of embedding data to be a NumPy array")

    print(f"Original data contains {len(original_embedding_ids)} embedding IDs.")
    print(f"Original embeddings shape: {original_embeddings.shape}, dtype: {original_embeddings.dtype}")

    if len(original_embedding_ids) != original_embeddings.shape[0]:
        raise ValueError(f"Mismatch! Number of embedding IDs ({len(original_embedding_ids)}) does not match number of embeddings ({original_embeddings.shape[0]})")

except (TypeError, ValueError, IndexError) as e:
    print(f"Error processing loaded embedding data: {e}")
    print("Please ensure the embedding pickle file contains: (list_of_passage_ids, numpy_embedding_array)")
    exit(1)

total_embeddings = original_embeddings.shape[0]

# --- Load Original Passages ---
# This might take time and memory depending on the dataset size
_, passage_id_to_content_map = load_all_passages(
    DOMAIN_NAME, NUM_SHARDS_PASSAGE, ORIGINAL_PASSAGE_FILE_TEMPLATE
)

if not passage_id_to_content_map:
    print("Error: No passages were loaded. Cannot proceed with sampling.")
    exit(1)

# --- Calculate Sample Size ---
num_samples = math.ceil(total_embeddings * SAMPLE_FRACTION)  # Use ceil to get at least 1/60th
print(f"\nTotal original embeddings: {total_embeddings}")
print(f"Sampling fraction: {SAMPLE_FRACTION:.6f} (1/60)")
print(f"Target number of samples: {num_samples}")

if num_samples > total_embeddings:
    print("Warning: Calculated sample size exceeds total embeddings. Using all embeddings.")
    num_samples = total_embeddings
elif num_samples <= 0:
    print("Error: Calculated sample size is zero or negative.")
    exit(1)

# --- Perform Random Sampling (Based on Embeddings) ---
print("\nPerforming random sampling based on embeddings...")
start_sample_time = time.time()

# Set a seed for reproducibility if needed
# np.random.seed(42)

# Generate unique random indices from the embeddings list
sampled_indices = np.random.choice(total_embeddings, size=num_samples, replace=False)

# Retrieve the corresponding IDs and embeddings using the sampled indices
sampled_embedding_ids = [original_embedding_ids[i] for i in sampled_indices]
sampled_embeddings = original_embeddings[sampled_indices]

sample_time = time.time() - start_sample_time
print(f"Sampling completed in {sample_time:.2f} seconds.")
print(f"Sampled {len(sampled_embedding_ids)} IDs and embeddings.")
print(f"Sampled embeddings shape: {sampled_embeddings.shape}")

# --- Retrieve Corresponding Passages ---
print("\nRetrieving corresponding passages for sampled IDs...")
start_passage_retrieval_time = time.time()
sampled_passages = []
missing_ids_count = 0
for i, pid in enumerate(sampled_embedding_ids):
    # Convert pid to string for lookup in the map
    pid_str = str(pid)
    if pid_str in passage_id_to_content_map:
        sampled_passages.append(passage_id_to_content_map[pid_str])
    else:
        # This indicates an inconsistency between embedding IDs and passage IDs
        print(f"Warning: Passage ID '{pid_str}' (from embedding index {sampled_indices[i]}) not found in passage map.")
        missing_ids_count += 1

passage_retrieval_time = time.time() - start_passage_retrieval_time
print(f"Retrieved {len(sampled_passages)} passages in {passage_retrieval_time:.2f} seconds.")
if missing_ids_count > 0:
    print(f"Warning: Could not find passages for {missing_ids_count} sampled IDs.")

if not sampled_passages:
    print("Error: No corresponding passages found for the sampled embeddings. Check ID matching.")
    exit(1)

# --- Prepare Output Data ---
# Embeddings output format: tuple(list_of_ids, numpy_array_of_embeddings)
output_embedding_data = (sampled_embedding_ids, sampled_embeddings)
# Passages output format: list[dict] (matching input shard format)
output_passage_data = sampled_passages

# --- Save Sampled Embeddings ---
print(f"\nSaving sampled embeddings to {OUTPUT_EMBEDDING_FILE}...")
start_save_time = time.time()
try:
    with open(OUTPUT_EMBEDDING_FILE, 'wb') as f:
        pickle.dump(output_embedding_data, f, protocol=pickle.HIGHEST_PROTOCOL)
except Exception as e:
    print(f"Error saving sampled embeddings: {e}")
    # Continue to try saving passages if desired, or exit(1)
save_time = time.time() - start_save_time
print(f"Saved sampled embeddings in {save_time:.2f} seconds.")

# --- Save Sampled Passages ---
print(f"\nSaving sampled passages to {OUTPUT_PASSAGE_FILE}...")
start_save_time = time.time()
try:
    with open(OUTPUT_PASSAGE_FILE, 'wb') as f:
        pickle.dump(output_passage_data, f, protocol=pickle.HIGHEST_PROTOCOL)
except Exception as e:
    print(f"Error saving sampled passages: {e}")
    exit(1)
save_time = time.time() - start_save_time
print(f"Saved sampled passages in {save_time:.2f} seconds.")

print("\nScript finished successfully.")
print(f"Sampled embeddings saved to: {OUTPUT_EMBEDDING_FILE}")
print(f"Sampled passages saved to: {OUTPUT_PASSAGE_FILE}")
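For reference, a minimal sketch of how a downstream script might re-open the sampled shard written above. The (ids, embeddings) tuple layout matches what this script dumps; the specific path and the assertion are illustrative additions, not part of the commit.

# Hypothetical sanity check of the sampled shard produced by build_demo_sample.py.
import pickle
import numpy as np

with open("/powerrag/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki_1M/1-shards/passages_00.pkl", "rb") as f:
    ids, embs = pickle.load(f)  # tuple: (list of passage ids, np.ndarray of vectors)

assert isinstance(embs, np.ndarray) and len(ids) == embs.shape[0]
print(f"{len(ids)} sampled vectors of dimension {embs.shape[1]}, dtype {embs.dtype}")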
BIN
packages/leann-backend-hnsw/third_party/faiss/demo/hnsw_performance_comparison.png
vendored
Normal file
Binary file not shown.
Size: 138 KiB
BIN
packages/leann-backend-hnsw/third_party/faiss/demo/hnsw_recall_comparison.png
vendored
Normal file
Binary file not shown.
Size: 86 KiB
354
packages/leann-backend-hnsw/third_party/faiss/demo/large_graph_simple_build.py
vendored
Normal file
@@ -0,0 +1,354 @@
import argparse
import sys
import time

import faiss
import numpy as np
import pickle
import os
import json
import torch
from tqdm import tqdm
from pathlib import Path
import subprocess

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
sys.path.append(os.path.join(project_root, "demo"))
from config import SCALING_OUT_DIR, get_example_path, TASK_CONFIGS
sys.path.append(project_root)
from contriever.src.contriever import Contriever, load_retriever

M = 32
efConstruction = 256
K_NEIGHBORS = 3

DB_EMBEDDING_FILE = "/powerrag/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki_1M/1-shards/passages_00.pkl"
INDEX_SAVING_FILE = "/powerrag/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki_1M/1-shards/indices"
TASK_NAME = "nq"
EMBEDDER_MODEL_NAME = "facebook/contriever-msmarco"
MAX_QUERIES_TO_LOAD = 1000
QUERY_ENCODING_BATCH_SIZE = 64

# 1M samples
print(f"Loading embeddings from {DB_EMBEDDING_FILE}...")
with open(DB_EMBEDDING_FILE, 'rb') as f:
    data = pickle.load(f)

xb = data[1]
print(f"Original dtype: {xb.dtype}")

if xb.dtype != np.float32:
    print("Converting embeddings to float32.")
    xb = xb.astype(np.float32)
else:
    print("Embeddings are already float32.")
print(f"Loaded database embeddings (xb), shape: {xb.shape}")
d = xb.shape[1]  # Get dimension

query_file_path = TASK_CONFIGS[TASK_NAME].query_path
print(f"Using query path from TASK_CONFIGS: {query_file_path}")

query_texts = []
print(f"Reading queries from: {query_file_path}")
with open(query_file_path, 'r') as f:
    for i, line in enumerate(f):
        if i >= MAX_QUERIES_TO_LOAD:
            print(f"Stopped loading queries at limit: {MAX_QUERIES_TO_LOAD}")
            break
        record = json.loads(line)
        query_texts.append(record["query"])
print(f"Loaded {len(query_texts)} query texts.")

print("\nInitializing retriever model for encoding queries...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model, tokenizer, _ = load_retriever(EMBEDDER_MODEL_NAME)
model.to(device)
model.eval()  # Set to evaluation mode
print("Retriever model loaded.")


def embed_queries(queries, model, tokenizer, model_name_or_path, per_gpu_batch_size=64):
    """Embed queries using the model with batching"""
    model = model.half()
    model.eval()
    embeddings = []
    batch_question = []

    with torch.no_grad():
        for k, query in tqdm(enumerate(queries), desc="Encoding queries"):
            batch_question.append(query)

            # Process when batch is full or at the end
            if len(batch_question) == per_gpu_batch_size or k == len(queries) - 1:
                encoded_batch = tokenizer.batch_encode_plus(
                    batch_question,
                    return_tensors="pt",
                    max_length=512,
                    padding=True,
                    truncation=True,
                )

                encoded_batch = {k: v.to(device) for k, v in encoded_batch.items()}
                output = model(**encoded_batch)

                # Contriever typically uses output.last_hidden_state pooling or something specialized
                if "contriever" not in model_name_or_path:
                    output = output.last_hidden_state[:, 0, :]

                embeddings.append(output.cpu())
                batch_question = []  # Reset batch

    embeddings = torch.cat(embeddings, dim=0).numpy()
    print(f"Query embeddings shape: {embeddings.shape}")
    return embeddings


print(f"\nEncoding {len(query_texts)} queries (batch size: {QUERY_ENCODING_BATCH_SIZE})...")
xq_full = embed_queries(query_texts, model, tokenizer, EMBEDDER_MODEL_NAME, per_gpu_batch_size=QUERY_ENCODING_BATCH_SIZE)

# Ensure float32 for Faiss compatibility after encoding
if xq_full.dtype != np.float32:
    print(f"Converting encoded queries from {xq_full.dtype} to float32.")
    xq_full = xq_full.astype(np.float32)

print(f"Encoded queries (xq_full), shape: {xq_full.shape}, dtype: {xq_full.dtype}")

# Check dimension consistency
if xq_full.shape[1] != d:
    raise ValueError(f"Query embedding dimension ({xq_full.shape[1]}) does not match database dimension ({d})")

# Loading flat-index ground truth from cache
cache_file = "/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/flat_results_nq_k3.json"

recall_idx_flat = None
if os.path.exists(cache_file):
    print(f"Loading cached FLAT index results from {cache_file}...")
    start_time = time.time()
    with open(cache_file, 'r') as f:
        cached_results = json.load(f)
    D_flat = np.array(cached_results["distances"])
    recall_idx_flat = np.array(cached_results["indices"])
    end_time = time.time()
    print(f"Loaded cached results in {end_time - start_time:.3f} seconds")
else:
    print("\nBuilding FlatIP index for ground truth...")
    index_flat = faiss.IndexFlatIP(d)  # Use Inner Product
    index_flat.add(xb)
    print(f"Searching FlatIP index with {len(xq_full)} queries (k={K_NEIGHBORS})...")
    start_time = time.time()
    D_flat, recall_idx_flat = index_flat.search(xq_full, k=K_NEIGHBORS)
    end_time = time.time()
    print(f"Time taken for FLAT index search: {end_time - start_time:.3f} seconds")

    # Save results to cache
    # with open(cache_file, 'w') as f:
    #     json.dump({
    #         "distances": D_flat.tolist(),
    #         "indices": recall_idx_flat.tolist(),
    #         "metadata": {
    #             "task": TASK_NAME,
    #             "k": K_NEIGHBORS,
    #             "timestamp": time.strftime("%Y%m%d_%H%M%S"),
    #         }
    #     }, f)
    # print(f"Cached FLAT index results to {cache_file}")

# print(recall_idx_flat)

# Create a specific directory for this index configuration
# index_dir = f"{INDEX_SAVING_FILE}/hnsw_IP_M{M}_efC{efConstruction}"
# os.makedirs(index_dir, exist_ok=True)
parser = argparse.ArgumentParser()
# parser.add_argument("--index-dir", type=str, default=f"{INDEX_SAVING_FILE}/hnsw_IP_M{M}_efC{efConstruction}")
parser.add_argument("--index-file", type=str, default=f"{INDEX_SAVING_FILE}/hnsw_IP_M{M}_efC{efConstruction}/index.faiss")
args = parser.parse_args()
index_filename = args.index_file
index_dir = os.path.dirname(index_filename)
os.makedirs(index_dir, exist_ok=True)

# Check if index already exists
if os.path.exists(index_filename):
    print(f"Found existing index at {index_filename}, loading...")
    index = faiss.read_index(index_filename)
    print("Index loaded successfully.")
else:
    assert False, "Index does not exist"
    # NOTE: the build path below is currently unreachable because of the assert above
    print(f'Building {"NSG" if "nsg" in index_filename else "HNSW"} index (IP)...')
    # add build time
    start_time = time.time()
    if 'nsg' in index_filename:
        index = faiss.IndexNSGFlat(d, M, faiss.METRIC_INNER_PRODUCT)
        index.verbose = True
    else:
        index = faiss.IndexHNSWFlat(d, M, faiss.METRIC_INNER_PRODUCT)
        index.hnsw.efConstruction = efConstruction
        index.hnsw.set_percentile_thresholds()
    index.add(xb)
    end_time = time.time()
    print(f'time: {end_time - start_time}')
    print(f'{"NSG" if "nsg" in index_filename else "HNSW"} index built.')

    # Save the index
    print(f"Saving index to {index_filename}...")
    faiss.write_index(index, index_filename)
    print("Index saved successfully.")

# Analyze the index
print("\nAnalyzing index...")
print(f"Total number of nodes: {index.ntotal}")
print("Neighbor statistics:")
if 'nsg' in index_filename:
    print(index.nsg.print_neighbor_stats(0))
else:
    print(index.hnsw.print_neighbor_stats(0))

# Save degree distribution
distribution_filename = f"{index_dir}/degree_distribution.txt"
print(f"Saving degree distribution to {distribution_filename}...")
if 'nsg' in index_filename:
    index.nsg.save_degree_distribution(distribution_filename)
else:
    index.hnsw.save_degree_distribution(0, distribution_filename)
print("Degree distribution saved successfully.")

# Plot the degree distribution
plot_output_path = f"{index_dir}/degree_distribution.png"
print(f"Generating degree distribution plot to {plot_output_path}...")
try:
    subprocess.run(
        ["python", "/home/ubuntu/Power-RAG/utils/plot_degree_distribution.py", distribution_filename, "-o", plot_output_path],
        check=True
    )
    print(f"Degree distribution plot saved to {plot_output_path}")
except subprocess.CalledProcessError as e:
    print(f"Error generating degree distribution plot: {e}")
except FileNotFoundError:
    print("Warning: plot_degree_distribution.py script not found in current directory")

print('Searching HNSW index...')


# for efSearch in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
#     print(f'*************efSearch: {efSearch}*************')
#     for i in range(10):
#         index.hnsw.efSearch = efSearch
#         D, I = index.search(xq_full[i:i+1], K_NEIGHBORS)
# exit()


recall_result_file = f"{index_dir}/recall_result.txt"
time_list = []
recall_list = []
recompute_list = []
with open(recall_result_file, 'w') as f:
    for efSearch in [2, 4, 8, 16, 24, 32, 48, 64, 96, 114, 128, 144, 160, 176, 192, 208, 224, 240, 256, 384, 420, 440, 460, 480, 512, 768, 1024, 1152, 1536, 1792, 2048, 2230, 2408, 2880]:
        if 'nsg' in index_filename:
            index.nsg.efSearch = efSearch
        else:
            index.hnsw.efSearch = efSearch
        # calculate the time of searching
        start_time = time.time()
        if not ('nsg' in index_filename):
            faiss.cvar.hnsw_stats.reset()
        else:
            faiss.cvar.nsg_stats.reset()
        D, I = index.search(xq_full, K_NEIGHBORS)
        print('D[0]:', D[0])
        end_time = time.time()
        print(f'time: {end_time - start_time}')
        time_list.append(end_time - start_time)
        if 'nsg' in index_filename:
            print("recompute:", faiss.cvar.nsg_stats.ndis / len(I))
            recompute_list.append(faiss.cvar.nsg_stats.ndis / len(I))
        else:
            print("recompute:", faiss.cvar.hnsw_stats.ndis / len(I))
            recompute_list.append(faiss.cvar.hnsw_stats.ndis / len(I))
        # print(I)

        # calculate the recall against the flat index using the formula:
        # recall = sum(recall_idx == recall_idx_flat) / len(recall_idx)
        recall = []
        for i in range(len(I)):
            acc = 0
            for j in range(len(I[i])):
                if I[i][j] in recall_idx_flat[i]:
                    acc += 1
            recall.append(acc / len(I[i]))
        recall = sum(recall) / len(recall)
        recall_list.append(recall)
        print(f'efSearch: {efSearch}')
        print(f'recall: {recall}')
        f.write(f'efSearch: {efSearch}, recall: {recall}\n')
print(f'Done and result saved to {recall_result_file}')
print(f'time_list: {time_list}')
print(f'recall_list: {recall_list}')
print(f'recompute_list: {recompute_list}')
exit()
# Analyze edge stats (unreachable after the exit() above)
print("\nAnalyzing edge statistics...")
edge_stats_file = f"{index_dir}/edge_stats.txt"
if not os.path.exists(edge_stats_file):
    index.save_edge_stats(edge_stats_file)
    print(f'Edge stats saved to {edge_stats_file}')
else:
    print(f'Edge stats already exists at {edge_stats_file}')


def analyze_edge_stats(filename):
    """
    Analyze edge statistics from a CSV file and print thresholds at various percentiles.

    Args:
        filename: Path to the edge statistics CSV file
    """
    if not os.path.exists(filename):
        print(f"Error: File {filename} does not exist")
        return

    print(f"Analyzing edge statistics from {filename}...")

    # Read the file
    distances = []
    with open(filename, 'r') as f:
        # Skip header
        header = f.readline()

        # Read all edges
        for line in f:
            parts = line.strip().split(',')
            if len(parts) >= 4:
                try:
                    src = int(parts[0])
                    dst = int(parts[1])
                    level = int(parts[2])
                    distance = float(parts[3])
                    distances.append(distance)
                except ValueError:
                    continue

    if not distances:
        print("No valid edges found in file")
        return

    # Sort distances
    distances = np.array(distances)
    distances.sort()

    # Calculate and print statistics
    print(f"Total edges: {len(distances)}")
    print(f"Min distance: {distances[0]:.6f}")
    print(f"Max distance: {distances[-1]:.6f}")

    # Print thresholds at specified percentiles
    percentiles = [0.5, 1, 2, 3, 5, 8, 10, 15, 20, 30, 40, 50, 60, 70]
    print("\nDistance thresholds at percentiles:")
    for p in percentiles:
        idx = int(len(distances) * p / 100)
        if idx < len(distances):
            print(f"{p:.1f}%: {distances[idx]:.6f}")

    return distances


analyze_edge_stats(edge_stats_file)
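The recall computation in the efSearch sweep above compares each approximate result list against the flat-index ground truth. A vectorized NumPy equivalent of that per-query loop, offered only as an illustrative sketch and not part of the commit, looks like this:

# Sketch: recall@k of approximate results I against exact results recall_idx_flat,
# both arrays of shape (num_queries, k). Matches the per-query loop in the script.
import numpy as np

def recall_at_k(I: np.ndarray, I_exact: np.ndarray) -> float:
    hits = np.array([np.isin(approx, exact).sum() for approx, exact in zip(I, I_exact)])
    return float(hits.mean() / I.shape[1])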
194
packages/leann-backend-hnsw/third_party/faiss/demo/plot_graph_struct.py
vendored
Normal file
@@ -0,0 +1,194 @@
import re
import matplotlib.pyplot as plt
import numpy as np

def extract_data_from_log(log_content):
    """Extract method names, recall lists, and recompute lists from the log file."""

    # Regular expressions to find the data - modified to match the actual format
    method_pattern = r"Building HNSW index with ([^\.]+)\.\.\.|Building HNSW index with ([^\n]+)..."
    recall_list_pattern = r"recall_list: (\[[\d\., ]+\])"
    recompute_list_pattern = r"recompute_list: (\[[\d\., ]+\])"
    avg_neighbors_pattern = r"neighbors per node: ([\d\.]+)"

    # Find all matches
    method_matches = re.findall(method_pattern, log_content)
    methods = []
    for match in method_matches:
        # Each match is a tuple with one empty string and one with the method name
        method = match[0] if match[0] else match[1]
        methods.append(method)

    recall_lists_str = re.findall(recall_list_pattern, log_content)
    recompute_lists_str = re.findall(recompute_list_pattern, log_content)
    avg_neighbors = re.findall(avg_neighbors_pattern, log_content)

    # Debug information
    print(f"Found {len(methods)} methods: {methods}")
    print(f"Found {len(recall_lists_str)} recall lists")
    print(f"Found {len(recompute_lists_str)} recompute lists")
    print(f"Found {len(avg_neighbors)} average neighbors values")

    # If the regex approach fails, try a more direct approach
    if len(methods) < 5:
        print("Regex approach failed, trying direct extraction...")
        sections = log_content.split("Building HNSW index with ")[1:]
        methods = []
        for section in sections:
            # Extract the method name (everything up to the first newline)
            method_name = section.split("\n")[0].strip()
            # Remove trailing dots if present
            method_name = method_name.rstrip('.')
            methods.append(method_name)
        print(f"Direct extraction found {len(methods)} methods: {methods}")

    # Convert string representations of lists to actual lists
    recall_lists = [eval(recall_list) for recall_list in recall_lists_str]
    recompute_lists = [eval(recompute_list) for recompute_list in recompute_lists_str]

    # Convert average neighbors to float
    avg_neighbors = [float(avg) for avg in avg_neighbors]

    # Make sure we have the same number of items in each list
    min_length = min(len(methods), len(recall_lists), len(recompute_lists), len(avg_neighbors))
    if min_length < 5:
        print(f"Warning: Expected 5 methods, but only found {min_length}")

    # Ensure all lists have the same length
    methods = methods[:min_length]
    recall_lists = recall_lists[:min_length]
    recompute_lists = recompute_lists[:min_length]
    avg_neighbors = avg_neighbors[:min_length]

    return methods, recall_lists, recompute_lists, avg_neighbors


def plot_performance(methods, recall_lists, recompute_lists, avg_neighbors):
    """Create a plot comparing the performance of different HNSW methods."""

    plt.figure(figsize=(12, 8))

    colors = ['blue', 'green', 'red', 'purple', 'orange']
    markers = ['o', 's', '^', 'x', 'd']

    for i, method in enumerate(methods):
        # Add average neighbors to the label
        label = f"{method} (avg. {avg_neighbors[i]} neighbors)"
        plt.plot(recompute_lists[i], recall_lists[i], label=label,
                 color=colors[i], marker=markers[i], markersize=8, markevery=5)

    plt.xlabel('Distance Computations', fontsize=14)
    plt.ylabel('Recall', fontsize=14)
    plt.title('HNSW Index Performance: Recall vs. Computation Cost', fontsize=16)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend(fontsize=12)
    plt.xscale('log')
    plt.ylim(0, 1.0)

    # Add horizontal lines for different recall levels
    recall_levels = [0.90, 0.95, 0.96, 0.97, 0.98]
    line_styles = [':', '--', '-.', '-', '-']
    line_widths = [1, 1, 1, 1.5, 1.5]

    for i, level in enumerate(recall_levels):
        plt.axhline(y=level, color='gray', linestyle=line_styles[i],
                    alpha=0.7, linewidth=line_widths[i])
        plt.text(130, level + 0.002, f'{level*100:.0f}% Recall', fontsize=10)

    plt.tight_layout()
    plt.savefig('faiss/demo/H_hnsw_performance_comparison.png')
    plt.show()


def plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors):
    """Create a bar chart comparing computation costs at different recall levels."""

    recall_levels = [0.90, 0.95, 0.96, 0.97, 0.98]

    # Get computation costs for each method at each recall level
    computation_costs = []
    for i, method in enumerate(methods):
        method_costs = []
        for level in recall_levels:
            # Find the first index where recall exceeds the target level
            recall_idx = next((idx for idx, recall in enumerate(recall_lists[i]) if recall >= level), None)
            if recall_idx is not None:
                method_costs.append(recompute_lists[i][recall_idx])
            else:
                # If the method doesn't reach this recall level, use None
                method_costs.append(None)
        computation_costs.append(method_costs)

    # Set up the plot
    fig, ax = plt.subplots(figsize=(14, 8))

    # Set width of bars
    bar_width = 0.15

    # Set positions of the bars on X axis
    r = np.arange(len(recall_levels))

    # Colors for each method
    colors = ['blue', 'green', 'red', 'purple', 'orange']

    # Create bars
    for i, method in enumerate(methods):
        # Keep only the recall levels this method actually reaches
        valid_positions = [pos for pos, cost in zip(r + i*bar_width, computation_costs[i]) if cost is not None]
        valid_costs = [cost for cost in computation_costs[i] if cost is not None]

        bars = ax.bar(valid_positions, valid_costs, width=bar_width,
                      color=colors[i], label=f"{method} (avg. {avg_neighbors[i]} neighbors)")

        # Add value labels on top of bars
        for j, bar in enumerate(bars):
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height + 500,
                    f'{height:.0f}', ha='center', va='bottom', rotation=0, fontsize=9)

    # Add labels and title
    ax.set_xlabel('Recall Level', fontsize=14)
    ax.set_ylabel('Distance Computations', fontsize=14)
    ax.set_title('Computation Cost Required to Achieve Different Recall Levels', fontsize=16)

    # Set x-ticks
    ax.set_xticks(r + bar_width * 2)
    ax.set_xticklabels([f'{level*100:.0f}%' for level in recall_levels])

    # Add legend
    ax.legend(fontsize=12)

    # Add grid
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.savefig('faiss/demo/H_hnsw_recall_comparison.png')
    plt.show()


# Read the log file
with open('faiss/demo/output.log', 'r') as f:
    log_content = f.read()

# Extract data
methods, recall_lists, recompute_lists, avg_neighbors = extract_data_from_log(log_content)

# Plot the results
plot_performance(methods, recall_lists, recompute_lists, avg_neighbors)

# Plot the recall comparison
plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors)

# Print a summary of the methods and their characteristics
print("\nMethod Summary:")
for i, method in enumerate(methods):
    print(f"{method}:")
    print(f" - Average neighbors per node: {avg_neighbors[i]:.2f}")

    # Find the recompute values needed for different recall levels
    recall_levels = [0.90, 0.95, 0.96, 0.97, 0.98]
    for level in recall_levels:
        recall_idx = next((idx for idx, recall in enumerate(recall_lists[i]) if recall >= level), None)
        if recall_idx is not None:
            print(f" - Computations needed for {level*100:.0f}% recall: {recompute_lists[i][recall_idx]:.2f}")
        else:
            print(f" - Does not reach {level*100:.0f}% recall in the test")
    print()
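One hedged note on the parsing above: since the recall and recompute lists appear in the log as plain Python list literals, ast.literal_eval is a safer drop-in for eval. A self-contained sketch (illustrative only, the sample string is invented):

# Sketch: parse "[0.1, 0.2, 0.3]"-style strings without eval's arbitrary-code-execution risk.
import ast

sample = "recall_list: [0.71, 0.84, 0.9201]"
list_str = sample.split("recall_list: ", 1)[1]
values = ast.literal_eval(list_str)  # -> [0.71, 0.84, 0.9201]
print(values)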
202
packages/leann-backend-hnsw/third_party/faiss/demo/plot_graph_struct_big.py
vendored
Normal file
@@ -0,0 +1,202 @@
import os
import re
import matplotlib.pyplot as plt
import numpy as np
import argparse

RECALL_LEVELS = [0.85, 0.90, 0.93, 0.94, 0.95, 0.96]

def extract_data_from_log(log_content):
    """Extract method names, recall lists, and recompute lists from the log file."""

    # Regular expressions to find the data - modified to match the actual format
    method_pattern = r"Building HNSW index with (.*)\.\.\.|Building HNSW index with ([^\n]+)..."
    recall_list_pattern = r"recall_list: (\[[\d\., ]+\])"
    recompute_list_pattern = r"recompute_list: (\[[\d\., ]+\])"
    avg_neighbors_pattern = r"neighbors per node: ([\d\.]+)"

    # Find all matches
    method_matches = re.findall(method_pattern, log_content)
    methods = []
    for match in method_matches:
        # Each match is a tuple with one empty string and one with the method name
        method = match[0] if match[0] else match[1]
        methods.append(method)

    recall_lists_str = re.findall(recall_list_pattern, log_content)
    recompute_lists_str = re.findall(recompute_list_pattern, log_content)
    avg_neighbors = re.findall(avg_neighbors_pattern, log_content)

    # Debug information
    print(f"Found {len(methods)} methods: {methods}")
    print(f"Found {len(recall_lists_str)} recall lists")
    print(f"Found {len(recompute_lists_str)} recompute lists")
    print(f"Found {len(avg_neighbors)} average neighbors values")

    # If the regex approach fails, try a more direct approach
    if len(methods) < 5:
        print("Regex approach failed, trying direct extraction...")
        sections = log_content.split("Building HNSW index with ")[1:]
        methods = []
        for section in sections:
            # Extract the method name (everything up to the first newline)
            method_name = section.split("\n")[0].strip()
            # Remove trailing dots if present
            method_name = method_name.rstrip('.')
            methods.append(method_name)
        print(f"Direct extraction found {len(methods)} methods: {methods}")

    # Convert string representations of lists to actual lists
    recall_lists = [eval(recall_list) for recall_list in recall_lists_str]
    recompute_lists = [eval(recompute_list) for recompute_list in recompute_lists_str]

    # Convert average neighbors to float
    avg_neighbors = [float(avg) for avg in avg_neighbors]

    # Make sure we have the same number of items in each list
    min_length = min(len(methods), len(recall_lists), len(recompute_lists), len(avg_neighbors))
    if min_length < 5:
        print(f"Warning: Expected 5 methods, but only found {min_length}")

    # Ensure all lists have the same length
    methods = methods[:min_length]
    recall_lists = recall_lists[:min_length]
    recompute_lists = recompute_lists[:min_length]
    avg_neighbors = avg_neighbors[:min_length]

    return methods, recall_lists, recompute_lists, avg_neighbors


def plot_performance(methods, recall_lists, recompute_lists, avg_neighbors):
    """Create a plot comparing the performance of different HNSW methods."""

    plt.figure(figsize=(12, 8))

    colors = ['blue', 'green', 'red', 'purple', 'orange']
    markers = ['o', 's', '^', 'x', 'd']

    for i, method in enumerate(methods):
        # Add average neighbors to the label
        label = f"{method} (avg. {avg_neighbors[i]} neighbors)"
        plt.plot(recompute_lists[i], recall_lists[i], label=label,
                 color=colors[i], marker=markers[i], markersize=8, markevery=5)

    plt.xlabel('Distance Computations', fontsize=14)
    plt.ylabel('Recall', fontsize=14)
    plt.title('HNSW Index Performance: Recall vs. Computation Cost', fontsize=16)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend(fontsize=12)
    plt.xscale('log')
    plt.ylim(0, 1.0)

    # Add horizontal lines for different recall levels
    line_styles = [':', '--', '-.', '-', '-', ':']
    line_widths = [1, 1, 1, 1.5, 1.5, 1]

    for i, level in enumerate(RECALL_LEVELS):
        plt.axhline(y=level, color='gray', linestyle=line_styles[i],
                    alpha=0.7, linewidth=line_widths[i])
        plt.text(130, level + 0.002, f'{level*100:.0f}% Recall', fontsize=10)

    plt.tight_layout()
    plt.savefig(os.path.join(os.path.dirname(__file__), 'H_hnsw_performance_comparison.png'))
    plt.show()


def plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors):
    """Create a bar chart comparing computation costs at different recall levels."""

    # Get computation costs for each method at each recall level
    computation_costs = []
    for i, method in enumerate(methods):
        method_costs = []
        for level in RECALL_LEVELS:
            # Find the first index where recall exceeds the target level
            recall_idx = next((idx for idx, recall in enumerate(recall_lists[i]) if recall >= level), None)
            if recall_idx is not None:
                method_costs.append(recompute_lists[i][recall_idx])
            else:
                # If the method doesn't reach this recall level, use None
                method_costs.append(None)
        computation_costs.append(method_costs)

    # Set up the plot
    fig, ax = plt.subplots(figsize=(14, 8))

    # Set width of bars
    bar_width = 0.15

    # Set positions of the bars on X axis
    r = np.arange(len(RECALL_LEVELS))

    # Colors for each method
    colors = ['blue', 'green', 'red', 'purple', 'orange']

    # Create bars
    for i, method in enumerate(methods):
        # Keep only the recall levels this method actually reaches
        valid_positions = [pos for pos, cost in zip(r + i*bar_width, computation_costs[i]) if cost is not None]
        valid_costs = [cost for cost in computation_costs[i] if cost is not None]

        bars = ax.bar(valid_positions, valid_costs, width=bar_width,
                      color=colors[i], label=f"{method} (avg. {avg_neighbors[i]} neighbors)")

        # Add value labels on top of bars
        for j, bar in enumerate(bars):
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height + 500,
                    f'{height:.0f}', ha='center', va='bottom', rotation=0, fontsize=9)

    # Add labels and title
    ax.set_xlabel('Recall Level', fontsize=14)
    ax.set_ylabel('Distance Computations', fontsize=14)
    ax.set_title('Computation Cost Required to Achieve Different Recall Levels', fontsize=16)

    # Set x-ticks
    ax.set_xticks(r + bar_width * 2)
    ax.set_xticklabels([f'{level*100:.0f}%' for level in RECALL_LEVELS])

    # Add legend
    ax.legend(fontsize=12)

    # Add grid
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.savefig(os.path.join(os.path.dirname(__file__), 'H_hnsw_recall_comparison.png'))
    plt.show()


# Read the log file
parser = argparse.ArgumentParser()
# nargs='?' so the default is actually used when no positional argument is given
parser.add_argument('log_file', type=str, nargs='?', default='nlevel_output.log')
args = parser.parse_args()

log_file = args.log_file
log_file = os.path.join(os.path.dirname(__file__), log_file)

with open(log_file, 'r') as f:
    log_content = f.read()

# Extract data
methods, recall_lists, recompute_lists, avg_neighbors = extract_data_from_log(log_content)

# Plot the results
plot_performance(methods, recall_lists, recompute_lists, avg_neighbors)

# Plot the recall comparison
plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors)

# Print a summary of the methods and their characteristics
print("\nMethod Summary:")
for i, method in enumerate(methods):
    print(f"{method}:")
    print(f" - Average neighbors per node: {avg_neighbors[i]:.2f}")

    # Find the recompute values needed for different recall levels
    for level in RECALL_LEVELS:
        recall_idx = next((idx for idx, recall in enumerate(recall_lists[i]) if recall >= level), None)
        if recall_idx is not None:
            print(f" - Computations needed for {level*100:.0f}% recall: {recompute_lists[i][recall_idx]:.2f}")
        else:
            print(f" - Does not reach {level*100:.0f}% recall in the test")
    print()
329
packages/leann-backend-hnsw/third_party/faiss/demo/simple_build.py
vendored
Normal file
@@ -0,0 +1,329 @@
|
||||
import sys
|
||||
import time
|
||||
import faiss
|
||||
import numpy as np
|
||||
import pickle
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import torch
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
|
||||
# Add argument parsing
|
||||
parser = argparse.ArgumentParser(description='Build and evaluate HNSW index')
|
||||
parser.add_argument('--config', type=str, default="0.02per_M6-7_degree_based",
|
||||
help='Configuration name for the index (default: 0.02per_M6-7_degree_based)')
|
||||
parser.add_argument('--M', type=int, default=32,
|
||||
help='HNSW M parameter (default: 32)')
|
||||
parser.add_argument('--efConstruction', type=int, default=256,
|
||||
help='HNSW efConstruction parameter (default: 256)')
|
||||
parser.add_argument('--K_NEIGHBORS', type=int, default=3,
|
||||
help='Number of neighbors to retrieve (default: 3)')
|
||||
parser.add_argument('--max_queries', type=int, default=1000,
|
||||
help='Maximum number of queries to load (default: 1000)')
|
||||
parser.add_argument('--batch_size', type=int, default=64,
|
||||
help='Batch size for query encoding (default: 64)')
|
||||
parser.add_argument('--db_embedding_file', type=str,
|
||||
default="/powerrag/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki_1M/1-shards/passages_00.pkl",
|
||||
help='Path to database embedding file')
|
||||
parser.add_argument('--index_saving_dir', type=str,
|
||||
default="/powerrag/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki_1M/1-shards/indices",
|
||||
help='Directory to save the index')
|
||||
parser.add_argument('--task_name', type=str, default="nq",
|
||||
help='Task name from TASK_CONFIGS (default: nq)')
|
||||
parser.add_argument('--embedder_model', type=str, default="facebook/contriever-msmarco",
|
||||
help='Model name for query embedding (default: facebook/contriever-msmarco)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Replace hardcoded constants with arguments
|
||||
M = args.M
|
||||
efConstruction = args.efConstruction
|
||||
K_NEIGHBORS = args.K_NEIGHBORS
|
||||
DB_EMBEDDING_FILE = args.db_embedding_file
|
||||
INDEX_SAVING_FILE = args.index_saving_dir
|
||||
TASK_NAME = args.task_name
|
||||
EMBEDDER_MODEL_NAME = args.embedder_model
|
||||
MAX_QUERIES_TO_LOAD = args.max_queries
|
||||
QUERY_ENCODING_BATCH_SIZE = args.batch_size
|
||||
CONFIG_NAME = args.config
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
|
||||
sys.path.append(os.path.join(project_root, "demo"))
|
||||
from config import SCALING_OUT_DIR, get_example_path, TASK_CONFIGS
|
||||
sys.path.append(project_root)
|
||||
from contriever.src.contriever import Contriever, load_retriever
|
||||
|
||||
# 1M samples
|
||||
print(f"Loading embeddings from {DB_EMBEDDING_FILE}...")
|
||||
with open(DB_EMBEDDING_FILE, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
xb = data[1]
|
||||
print(f"Original dtype: {xb.dtype}")
|
||||
|
||||
if xb.dtype != np.float32:
|
||||
print("Converting embeddings to float32.")
|
||||
xb = xb.astype(np.float32)
|
||||
else:
|
||||
print("Embeddings are already float32.")
|
||||
print(f"Loaded database embeddings (xb), shape: {xb.shape}")
|
||||
d = xb.shape[1] # Get dimension
|
||||
|
||||
query_file_path = TASK_CONFIGS[TASK_NAME].query_path
|
||||
print(f"Using query path from TASK_CONFIGS: {query_file_path}")
|
||||
|
||||
query_texts = []
|
||||
print(f"Reading queries from: {query_file_path}")
|
||||
with open(query_file_path, 'r') as f:
|
||||
for i, line in enumerate(f):
|
||||
if i >= MAX_QUERIES_TO_LOAD:
|
||||
print(f"Stopped loading queries at limit: {MAX_QUERIES_TO_LOAD}")
|
||||
break
|
||||
record = json.loads(line)
|
||||
query_texts.append(record["query"])
|
||||
print(f"Loaded {len(query_texts)} query texts.")
|
||||
|
||||
print("\nInitializing retriever model for encoding queries...")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"Using device: {device}")
|
||||
model, tokenizer, _ = load_retriever(EMBEDDER_MODEL_NAME)
|
||||
model.to(device)
|
||||
model.eval() # Set to evaluation mode
|
||||
print("Retriever model loaded.")
|
||||
|
||||
|
||||
def embed_queries(queries, model, tokenizer, model_name_or_path, per_gpu_batch_size=64):
|
||||
"""Embed queries using the model with batching"""
|
||||
model = model.half()
|
||||
model.eval()
|
||||
embeddings = []
|
||||
batch_question = []
|
||||
|
||||
with torch.no_grad():
|
||||
for k, query in tqdm(enumerate(queries), desc="Encoding queries"):
|
||||
batch_question.append(query)
|
||||
|
||||
# Process when batch is full or at the end
|
||||
if len(batch_question) == per_gpu_batch_size or k == len(queries) - 1:
|
||||
encoded_batch = tokenizer.batch_encode_plus(
|
||||
batch_question,
|
||||
return_tensors="pt",
|
||||
max_length=512,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
encoded_batch = {k: v.to(device) for k, v in encoded_batch.items()}
|
||||
output = model(**encoded_batch)
|
||||
|
||||
# Contriever typically uses output.last_hidden_state pooling or something specialized
|
||||
if "contriever" not in model_name_or_path:
|
||||
output = output.last_hidden_state[:, 0, :]
|
||||
|
||||
embeddings.append(output.cpu())
|
||||
batch_question = [] # Reset batch
|
||||
|
||||
embeddings = torch.cat(embeddings, dim=0).numpy()
|
||||
print(f"Query embeddings shape: {embeddings.shape}")
|
||||
return embeddings
|
||||
|
||||
print(f"\nEncoding {len(query_texts)} queries (batch size: {QUERY_ENCODING_BATCH_SIZE})...")
|
||||
xq_full = embed_queries(query_texts, model, tokenizer, EMBEDDER_MODEL_NAME, per_gpu_batch_size=QUERY_ENCODING_BATCH_SIZE)
|
||||
|
||||
# Ensure float32 for Faiss compatibility after encoding
|
||||
if xq_full.dtype != np.float32:
|
||||
print(f"Converting encoded queries from {xq_full.dtype} to float32.")
|
||||
xq_full = xq_full.astype(np.float32)
|
||||
|
||||
print(f"Encoded queries (xq_full), shape: {xq_full.shape}, dtype: {xq_full.dtype}")
|
||||
|
||||
# Check dimension consistency
|
||||
if xq_full.shape[1] != d:
|
||||
raise ValueError(f"Query embedding dimension ({xq_full.shape[1]}) does not match database dimension ({d})")
|
||||
|
||||
# recall_idx = []
|
||||
|
||||
print("\nBuilding FlatIP index for ground truth...")
|
||||
index_flat = faiss.IndexFlatIP(d) # Use Inner Product
|
||||
index_flat.add(xb)
|
||||
print(f"Searching FlatIP index with {MAX_QUERIES_TO_LOAD} queries (k={K_NEIGHBORS})...")
|
||||
D_flat, recall_idx_flat = index_flat.search(xq_full, k=K_NEIGHBORS)
|
||||
|
||||
# print(recall_idx_flat)
|
||||
|
||||
# Create a specific directory for this index configuration
|
||||
index_dir = f"{INDEX_SAVING_FILE}/{CONFIG_NAME}_hnsw_IP_M{M}_efC{efConstruction}"
|
||||
if CONFIG_NAME == "origin":
|
||||
index_dir = f"{INDEX_SAVING_FILE}/hnsw_IP_M{M}_efC{efConstruction}"
|
||||
os.makedirs(index_dir, exist_ok=True)
|
||||
index_filename = f"{index_dir}/index.faiss"
|
||||
|
||||
# Check if index already exists
|
||||
if os.path.exists(index_filename):
|
||||
print(f"Found existing index at {index_filename}, loading...")
|
||||
index = faiss.read_index(index_filename)
|
||||
print("Index loaded successfully.")
|
||||
else:
|
||||
print('Building HNSW index (IP)...')
|
||||
# add build time
|
||||
start_time = time.time()
|
||||
index = faiss.IndexHNSWFlat(d, M, faiss.METRIC_INNER_PRODUCT)
|
||||
index.hnsw.efConstruction = efConstruction
|
||||
index.hnsw.set_percentile_thresholds()
|
||||
index.add(xb)
|
||||
end_time = time.time()
|
||||
print(f'time: {end_time - start_time}')
|
||||
print('HNSW index built.')
|
||||
|
||||
# Save the HNSW index
|
||||
print(f"Saving index to {index_filename}...")
|
||||
faiss.write_index(index, index_filename)
|
||||
print("Index saved successfully.")
|
||||
# Analyze the HNSW index
|
||||
print("\nAnalyzing HNSW index...")
|
||||
print(f"Total number of nodes: {index.ntotal}")
|
||||
print("Neighbor statistics:")
|
||||
print(index.hnsw.print_neighbor_stats(0))
|
||||
|
||||
# Save degree distribution
|
||||
distribution_filename = f"{index_dir}/degree_distribution.txt"
|
||||
print(f"Saving degree distribution to {distribution_filename}...")
|
||||
index.hnsw.save_degree_distribution(0, distribution_filename)
|
||||
print("Degree distribution saved successfully.")
|
||||
|
||||
# Plot the degree distribution
|
||||
plot_output_path = f"{index_dir}/degree_distribution.png"
|
||||
print(f"Generating degree distribution plot to {plot_output_path}...")
|
||||
try:
|
||||
subprocess.run(
|
||||
["python", "/home/ubuntu/Power-RAG/utils/plot_degree_distribution.py", distribution_filename, "-o", plot_output_path],
|
||||
check=True
|
||||
)
|
||||
print(f"Degree distribution plot saved to {plot_output_path}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error generating degree distribution plot: {e}")
|
||||
except FileNotFoundError:
|
||||
print("Warning: plot_degree_distribution.py script not found in current directory")
|
||||
|
||||

print('Searching HNSW index...')

# for efSearch in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
#     print(f'*************efSearch: {efSearch}*************')
#     for i in range(10):
#         index.hnsw.efSearch = efSearch
#         D, I = index.search(xq_full[i:i+1], K_NEIGHBORS)
# exit()

recall_result_file = f"{index_dir}/recall_result.txt"
time_list = []
recall_list = []
recompute_list = []
with open(recall_result_file, 'w') as f:
    for efSearch in [2, 4, 8, 16, 24, 32, 48, 64, 96, 114, 128, 144, 160, 176, 192, 208, 224, 240, 256, 384, 420, 440, 460, 480, 512, 768, 1024, 1152, 1536, 1792, 2048, 2230, 2408, 2880]:
        index.hnsw.efSearch = efSearch
        # Time the batch search
        start_time = time.time()
        faiss.cvar.hnsw_stats.reset()
        print(f'ndis: {faiss.cvar.hnsw_stats.ndis}')
        D, I = index.search(xq_full, K_NEIGHBORS)
        print('D[0]:', D[0])
        end_time = time.time()
        print(f'time: {end_time - start_time}')
        time_list.append(end_time - start_time)
        # Average number of distance computations per query at this efSearch
        print("recompute:", faiss.cvar.hnsw_stats.ndis / len(I))
        recompute_list.append(faiss.cvar.hnsw_stats.ndis / len(I))
        # print(I)

        # Compute recall against the flat-index ground truth:
        # recall = |approximate top-k ∩ exact top-k| / k, averaged over queries
        recall = []
        for i in range(len(I)):
            acc = 0
            for j in range(len(I[i])):
                if I[i][j] in recall_idx_flat[i]:
                    acc += 1
            recall.append(acc / len(I[i]))
        recall = sum(recall) / len(recall)
        recall_list.append(recall)
        print(f'efSearch: {efSearch}')
        print(f'recall: {recall}')
        f.write(f'efSearch: {efSearch}, recall: {recall}\n')
print(f'Done and result saved to {recall_result_file}')
print(f'time_list: {time_list}')
print(f'recall_list: {recall_list}')
print(f'recompute_list: {recompute_list}')
exit()
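
# An equivalent, more compact formulation of the recall computation above
# (a sketch; assumes I and recall_idx_flat are [n_queries, K_NEIGHBORS] integer
# arrays with no duplicate ids within a row, as returned by Faiss):
def recall_at_k(approx_ids, exact_ids):
    """Mean fraction of the exact top-k neighbors recovered by the approximate search."""
    hits = [len(np.intersect1d(a, e)) for a, e in zip(approx_ids, exact_ids)]
    return sum(hits) / (len(approx_ids) * approx_ids.shape[1])
# e.g. recall_at_k(I, recall_idx_flat) reproduces the loop-based value for the last efSearch.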

# Analyze edge stats
print("\nAnalyzing edge statistics...")
edge_stats_file = f"{index_dir}/edge_stats.txt"
if not os.path.exists(edge_stats_file):
    index.save_edge_stats(edge_stats_file)
    print(f'Edge stats saved to {edge_stats_file}')
else:
    print(f'Edge stats already exist at {edge_stats_file}')


def analyze_edge_stats(filename):
    """
    Analyze edge statistics from a CSV file and print thresholds at various percentiles.

    Args:
        filename: Path to the edge statistics CSV file
    """
    if not os.path.exists(filename):
        print(f"Error: File {filename} does not exist")
        return

    print(f"Analyzing edge statistics from {filename}...")

    # Read the file
    distances = []
    with open(filename, 'r') as f:
        # Skip header
        header = f.readline()

        # Read all edges: each row is src, dst, level, distance
        for line in f:
            parts = line.strip().split(',')
            if len(parts) >= 4:
                try:
                    src = int(parts[0])
                    dst = int(parts[1])
                    level = int(parts[2])
                    distance = float(parts[3])
                    distances.append(distance)
                except ValueError:
                    continue

    if not distances:
        print("No valid edges found in file")
        return

    # Sort distances
    distances = np.array(distances)
    distances.sort()

    # Calculate and print statistics
    print(f"Total edges: {len(distances)}")
    print(f"Min distance: {distances[0]:.6f}")
    print(f"Max distance: {distances[-1]:.6f}")

    # Print thresholds at specified percentiles
    percentiles = [0.5, 1, 2, 3, 5, 8, 10, 15, 20, 30, 40, 50, 60, 70]
    print("\nDistance thresholds at percentiles:")
    for p in percentiles:
        idx = int(len(distances) * p / 100)
        if idx < len(distances):
            print(f"{p:.1f}%: {distances[idx]:.6f}")

    return distances


analyze_edge_stats(edge_stats_file)
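
# A NumPy-based variant of the percentile report above (a sketch; np.percentile
# interpolates between samples, so its values can differ slightly from the
# index-based lookup used in analyze_edge_stats):
def print_percentile_thresholds(distances, percentiles=(0.5, 1, 2, 3, 5, 8, 10, 15, 20, 30, 40, 50, 60, 70)):
    thresholds = np.percentile(np.asarray(distances), percentiles)
    for p, t in zip(percentiles, thresholds):
        print(f"{p:.1f}%: {t:.6f}")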
212
packages/leann-backend-hnsw/third_party/faiss/demo/simple_build_dpr.py
vendored
Normal file
212
packages/leann-backend-hnsw/third_party/faiss/demo/simple_build_dpr.py
vendored
Normal file
@@ -0,0 +1,212 @@
import sys
import time
import faiss
import numpy as np
import pickle
import os
import json
import torch
from tqdm import tqdm
from pathlib import Path
import subprocess

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
sys.path.append(os.path.join(project_root, "demo"))
from config import SCALING_OUT_DIR, get_example_path, TASK_CONFIGS, get_embedding_path
sys.path.append(project_root)
from contriever.src.contriever import Contriever, load_retriever

M = 32
efConstruction = 256
K_NEIGHBORS = 3

# Original configuration (commented out)
# DB_EMBEDDING_FILE = "/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/passages_00.pkl"
# INDEX_SAVING_FILE = "/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices"

# New configuration using DPR
DOMAIN_NAME = "dpr"
EMBEDDER_NAME = "facebook/contriever-msmarco"
TASK_NAME = "nq"
MAX_QUERIES_TO_LOAD = 1000
QUERY_ENCODING_BATCH_SIZE = 64

# Get the embedding path using the function from config
embed_path = get_embedding_path(DOMAIN_NAME, EMBEDDER_NAME, 0)
INDEX_SAVING_FILE = os.path.join(os.path.dirname(embed_path), "indices")
os.makedirs(INDEX_SAVING_FILE, exist_ok=True)

# Load embeddings
print(f"Loading embeddings from {embed_path}...")
with open(embed_path, 'rb') as f:
    data = pickle.load(f)

xb = data[1]
print(f"Original dtype: {xb.dtype}")

if xb.dtype != np.float32:
    print("Converting embeddings to float32.")
    xb = xb.astype(np.float32)
else:
    print("Embeddings are already float32.")
print(f"Loaded database embeddings (xb), shape: {xb.shape}")
d = xb.shape[1]  # Get dimension

query_file_path = TASK_CONFIGS[TASK_NAME].query_path
print(f"Using query path from TASK_CONFIGS: {query_file_path}")

query_texts = []
print(f"Reading queries from: {query_file_path}")
with open(query_file_path, 'r') as f:
    for i, line in enumerate(f):
        if i >= MAX_QUERIES_TO_LOAD:
            print(f"Stopped loading queries at limit: {MAX_QUERIES_TO_LOAD}")
            break
        record = json.loads(line)
        query_texts.append(record["query"])
print(f"Loaded {len(query_texts)} query texts.")

print("\nInitializing retriever model for encoding queries...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model, tokenizer, _ = load_retriever(EMBEDDER_NAME)
model.to(device)
model.eval()  # Set to evaluation mode
print("Retriever model loaded.")


def embed_queries(queries, model, tokenizer, model_name_or_path, per_gpu_batch_size=64):
    """Embed queries using the model with batching"""
    model = model.half()
    model.eval()
    embeddings = []
    batch_question = []

    with torch.no_grad():
        for k, query in tqdm(enumerate(queries), desc="Encoding queries"):
            batch_question.append(query)

            # Process when the batch is full or at the end
            if len(batch_question) == per_gpu_batch_size or k == len(queries) - 1:
                encoded_batch = tokenizer.batch_encode_plus(
                    batch_question,
                    return_tensors="pt",
                    max_length=512,
                    padding=True,
                    truncation=True,
                )

                encoded_batch = {k: v.to(device) for k, v in encoded_batch.items()}
                output = model(**encoded_batch)

                # Contriever's forward pass already returns pooled sentence embeddings;
                # for other models, fall back to the [CLS] token representation.
                if "contriever" not in model_name_or_path:
                    output = output.last_hidden_state[:, 0, :]

                embeddings.append(output.cpu())
                batch_question = []  # Reset batch

    embeddings = torch.cat(embeddings, dim=0).numpy()
    print(f"Query embeddings shape: {embeddings.shape}")
    return embeddings
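
# Note: embed_queries runs the model in fp16 (model.half()) to speed up encoding;
# the resulting embeddings are cast back to float32 below, since Faiss expects
# float32 input vectors.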
print(f"\nEncoding {len(query_texts)} queries (batch size: {QUERY_ENCODING_BATCH_SIZE})...")
|
||||
xq_full = embed_queries(query_texts, model, tokenizer, EMBEDDER_NAME, per_gpu_batch_size=QUERY_ENCODING_BATCH_SIZE)
|
||||
|
||||
# Ensure float32 for Faiss compatibility after encoding
|
||||
if xq_full.dtype != np.float32:
|
||||
print(f"Converting encoded queries from {xq_full.dtype} to float32.")
|
||||
xq_full = xq_full.astype(np.float32)
|
||||
|
||||
print(f"Encoded queries (xq_full), shape: {xq_full.shape}, dtype: {xq_full.dtype}")
|
||||
|
||||
# Check dimension consistency
|
||||
if xq_full.shape[1] != d:
|
||||
raise ValueError(f"Query embedding dimension ({xq_full.shape[1]}) does not match database dimension ({d})")
|
||||
|
||||
# Build flat index for ground truth
|
||||
print("\nBuilding FlatIP index for ground truth...")
|
||||
index_flat = faiss.IndexFlatIP(d) # Use Inner Product
|
||||
index_flat.add(xb)
|
||||
print(f"Searching FlatIP index with {len(xq_full)} queries (k={K_NEIGHBORS})...")
|
||||
D_flat, recall_idx_flat = index_flat.search(xq_full, k=K_NEIGHBORS)
|
||||
|
||||
# Create a specific directory for this index configuration
|
||||
index_dir = f"{INDEX_SAVING_FILE}/hahahdpr_hnsw_IP_M{M}_efC{efConstruction}"
|
||||
os.makedirs(index_dir, exist_ok=True)
|
||||
index_filename = f"{index_dir}/index.faiss"
|
||||
|
||||
# Check if index already exists
|
||||
if os.path.exists(index_filename):
|
||||
print(f"Found existing index at {index_filename}, loading...")
|
||||
index = faiss.read_index(index_filename)
|
||||
print("Index loaded successfully.")
|
||||
else:
|
||||
print('Building HNSW index (IP)...')
|
||||
# add build time
|
||||
start_time = time.time()
|
||||
index = faiss.IndexHNSWFlat(d, M, faiss.METRIC_INNER_PRODUCT)
|
||||
index.hnsw.efConstruction = efConstruction
|
||||
index.add(xb)
|
||||
end_time = time.time()
|
||||
print(f'time: {end_time - start_time}')
|
||||
print('HNSW index built.')
|
||||
|
||||
# Save the HNSW index
|
||||
print(f"Saving index to {index_filename}...")
|
||||
faiss.write_index(index, index_filename)
|
||||
print("Index saved successfully.")
|
||||
|
||||
# Analyze the HNSW index
|
||||
print("\nAnalyzing HNSW index...")
|
||||
print(f"Total number of nodes: {index.ntotal}")
|
||||
print("Neighbor statistics:")
|
||||
print(index.hnsw.print_neighbor_stats(0))
|
||||
|
||||
# Save degree distribution
|
||||
distribution_filename = f"{index_dir}/degree_distribution.txt"
|
||||
print(f"Saving degree distribution to {distribution_filename}...")
|
||||
index.hnsw.save_degree_distribution(0, distribution_filename)
|
||||
print("Degree distribution saved successfully.")
|
||||
|
||||
# Plot the degree distribution
|
||||
plot_output_path = f"{index_dir}/degree_distribution.png"
|
||||
print(f"Generating degree distribution plot to {plot_output_path}...")
|
||||
try:
|
||||
subprocess.run(
|
||||
["python", f"{project_root}/utils/plot_degree_distribution.py", distribution_filename, "-o", plot_output_path],
|
||||
check=True
|
||||
)
|
||||
print(f"Degree distribution plot saved to {plot_output_path}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error generating degree distribution plot: {e}")
|
||||
except FileNotFoundError:
|
||||
print("Warning: plot_degree_distribution.py script not found in specified path")
|
||||
|
||||
print('Searching HNSW index...')
|
||||
|
||||
recall_result_file = f"{index_dir}/recall_result.txt"
|
||||
with open(recall_result_file, 'w') as f:
|
||||
for efSearch in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
|
||||
index.hnsw.efSearch = efSearch
|
||||
# calculate the time of searching
|
||||
start_time = time.time()
|
||||
|
||||
D, I = index.search(xq_full, K_NEIGHBORS)
|
||||
end_time = time.time()
|
||||
print(f'time: {end_time - start_time}')
|
||||
|
||||
# calculate the recall using the flat index
|
||||
recall = []
|
||||
for i in range(len(I)):
|
||||
acc = 0
|
||||
for j in range(len(I[i])):
|
||||
if I[i][j] in recall_idx_flat[i]:
|
||||
acc += 1
|
||||
recall.append(acc / len(I[i]))
|
||||
recall = sum(recall) / len(recall)
|
||||
print(f'efSearch: {efSearch}')
|
||||
print(f'recall: {recall}')
|
||||
f.write(f'efSearch: {efSearch}, recall: {recall}\n')
|
||||
print(f'Done and result saved to {recall_result_file}')
|
||||
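
# Throughput for a given efSearch setting follows from the timed batch search above,
# e.g. queries per second = len(xq_full) / (end_time - start_time); the script itself
# only logs the raw wall-clock time.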
212
packages/leann-backend-hnsw/third_party/faiss/demo/simple_build_nsg.py
vendored
Normal file
212
packages/leann-backend-hnsw/third_party/faiss/demo/simple_build_nsg.py
vendored
Normal file
@@ -0,0 +1,212 @@
import sys
import time
import faiss
import numpy as np
import pickle
import os
import json
import torch
from tqdm import tqdm
from pathlib import Path
import subprocess

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
sys.path.append(os.path.join(project_root, "demo"))
from config import SCALING_OUT_DIR, get_example_path, TASK_CONFIGS, get_embedding_path
sys.path.append(project_root)
from contriever.src.contriever import Contriever, load_retriever

M = 20
efConstruction = 256
K_NEIGHBORS = 3

# Configuration for the rpj_wiki corpus
DOMAIN_NAME = "rpj_wiki"
EMBEDDER_NAME = "facebook/contriever-msmarco"
TASK_NAME = "nq"
MAX_QUERIES_TO_LOAD = 1000
QUERY_ENCODING_BATCH_SIZE = 64

# Config-derived paths (commented out)
# embed_path = get_embedding_path(DOMAIN_NAME, EMBEDDER_NAME, 0)
# INDEX_SAVING_FILE = os.path.join(os.path.dirname(embed_path), "indices")
# os.makedirs(INDEX_SAVING_FILE, exist_ok=True)

# Hard-coded paths for the 1M-passage rpj_wiki shard
embed_path = "/powerrag/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki_1M/1-shards/passages_00.pkl"
INDEX_SAVING_FILE = "/powerrag/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki_1M/1-shards/indices"

# Load embeddings
print(f"Loading embeddings from {embed_path}...")
with open(embed_path, 'rb') as f:
    data = pickle.load(f)

xb = data[1]
print(f"Original dtype: {xb.dtype}")

if xb.dtype != np.float32:
    print("Converting embeddings to float32.")
    xb = xb.astype(np.float32)
else:
    print("Embeddings are already float32.")
print(f"Loaded database embeddings (xb), shape: {xb.shape}")
d = xb.shape[1]  # Get dimension

query_file_path = TASK_CONFIGS[TASK_NAME].query_path
print(f"Using query path from TASK_CONFIGS: {query_file_path}")

query_texts = []
print(f"Reading queries from: {query_file_path}")
with open(query_file_path, 'r') as f:
    for i, line in enumerate(f):
        if i >= MAX_QUERIES_TO_LOAD:
            print(f"Stopped loading queries at limit: {MAX_QUERIES_TO_LOAD}")
            break
        record = json.loads(line)
        query_texts.append(record["query"])
print(f"Loaded {len(query_texts)} query texts.")

print("\nInitializing retriever model for encoding queries...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model, tokenizer, _ = load_retriever(EMBEDDER_NAME)
model.to(device)
model.eval()  # Set to evaluation mode
print("Retriever model loaded.")


def embed_queries(queries, model, tokenizer, model_name_or_path, per_gpu_batch_size=64):
    """Embed queries using the model with batching"""
    model = model.half()
    model.eval()
    embeddings = []
    batch_question = []

    with torch.no_grad():
        for k, query in tqdm(enumerate(queries), desc="Encoding queries"):
            batch_question.append(query)

            # Process when the batch is full or at the end
            if len(batch_question) == per_gpu_batch_size or k == len(queries) - 1:
                encoded_batch = tokenizer.batch_encode_plus(
                    batch_question,
                    return_tensors="pt",
                    max_length=512,
                    padding=True,
                    truncation=True,
                )

                encoded_batch = {k: v.to(device) for k, v in encoded_batch.items()}
                output = model(**encoded_batch)

                # Contriever's forward pass already returns pooled sentence embeddings;
                # for other models, fall back to the [CLS] token representation.
                if "contriever" not in model_name_or_path:
                    output = output.last_hidden_state[:, 0, :]

                embeddings.append(output.cpu())
                batch_question = []  # Reset batch

    embeddings = torch.cat(embeddings, dim=0).numpy()
    print(f"Query embeddings shape: {embeddings.shape}")
    return embeddings
print(f"\nEncoding {len(query_texts)} queries (batch size: {QUERY_ENCODING_BATCH_SIZE})...")
|
||||
xq_full = embed_queries(query_texts, model, tokenizer, EMBEDDER_NAME, per_gpu_batch_size=QUERY_ENCODING_BATCH_SIZE)
|
||||
|
||||
# Ensure float32 for Faiss compatibility after encoding
|
||||
if xq_full.dtype != np.float32:
|
||||
print(f"Converting encoded queries from {xq_full.dtype} to float32.")
|
||||
xq_full = xq_full.astype(np.float32)
|
||||
|
||||
print(f"Encoded queries (xq_full), shape: {xq_full.shape}, dtype: {xq_full.dtype}")
|
||||
|
||||
# Check dimension consistency
|
||||
if xq_full.shape[1] != d:
|
||||
raise ValueError(f"Query embedding dimension ({xq_full.shape[1]}) does not match database dimension ({d})")
|
||||
|
||||
# Build flat index for ground truth
|
||||
print("\nBuilding FlatIP index for ground truth...")
|
||||
index_flat = faiss.IndexFlatIP(d) # Use Inner Product
|
||||
index_flat.add(xb)
|
||||
print(f"Searching FlatIP index with {len(xq_full)} queries (k={K_NEIGHBORS})...")
|
||||
D_flat, recall_idx_flat = index_flat.search(xq_full, k=K_NEIGHBORS)
|
||||
|
||||
# Create a specific directory for this index configuration
|
||||
index_dir = f"{INDEX_SAVING_FILE}/rpj_wiki_nsg_IP_M{M}"
|
||||
os.makedirs(index_dir, exist_ok=True)
|
||||
index_filename = f"{index_dir}/index.faiss"
|
||||
|
||||
# Check if index already exists
|
||||
if os.path.exists(index_filename):
|
||||
print(f"Found existing index at {index_filename}, loading...")
|
||||
index = faiss.read_index(index_filename)
|
||||
print("Index loaded successfully.")
|
||||
else:
|
||||
print('Building HNSW index (IP)...')
|
||||
# add build time
|
||||
start_time = time.time()
|
||||
index = faiss.IndexNSGFlat(d, M, faiss.METRIC_INNER_PRODUCT)
|
||||
index.verbose = True
|
||||
index.add(xb)
|
||||
end_time = time.time()
|
||||
print(f'time: {end_time - start_time}')
|
||||
print('HNSW index built.')
|
||||
|
||||
# Save the HNSW index
|
||||
print(f"Saving index to {index_filename}...")
|
||||
faiss.write_index(index, index_filename)
|
||||
print("Index saved successfully.")
|
||||
|
||||
# Analyze the HNSW index
|
||||
print("\nAnalyzing HNSW index...")
|
||||
print(f"Total number of nodes: {index.ntotal}")
|
||||
print("Neighbor statistics:")
|
||||
print(index.nsg.print_neighbor_stats(0))
|
||||
|
||||
# Save degree distribution
|
||||
distribution_filename = f"{index_dir}/degree_distribution.txt"
|
||||
print(f"Saving degree distribution to {distribution_filename}...")
|
||||
index.nsg.save_degree_distribution(distribution_filename)
|
||||
print("Degree distribution saved successfully.")
|
||||
|
||||
# Plot the degree distribution
|
||||
plot_output_path = f"{index_dir}/degree_distribution.png"
|
||||
print(f"Generating degree distribution plot to {plot_output_path}...")
|
||||
try:
|
||||
subprocess.run(
|
||||
["python", f"{project_root}/utils/plot_degree_distribution.py", distribution_filename, "-o", plot_output_path],
|
||||
check=True
|
||||
)
|
||||
print(f"Degree distribution plot saved to {plot_output_path}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error generating degree distribution plot: {e}")
|
||||
except FileNotFoundError:
|
||||
print("Warning: plot_degree_distribution.py script not found in specified path")
|
||||
|
||||
print('Searching HNSW index...')
|
||||
|
||||
recall_result_file = f"{index_dir}/recall_result.txt"
|
||||
with open(recall_result_file, 'w') as f:
|
||||
for efSearch in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
|
||||
index.nsg.efSearch = efSearch
|
||||
# calculate the time of searching
|
||||
start_time = time.time()
|
||||
|
||||
D, I = index.search(xq_full, K_NEIGHBORS)
|
||||
end_time = time.time()
|
||||
print(f'time: {end_time - start_time}')
|
||||
|
||||
# calculate the recall using the flat index
|
||||
recall = []
|
||||
for i in range(len(I)):
|
||||
acc = 0
|
||||
for j in range(len(I[i])):
|
||||
if I[i][j] in recall_idx_flat[i]:
|
||||
acc += 1
|
||||
recall.append(acc / len(I[i]))
|
||||
recall = sum(recall) / len(recall)
|
||||
print(f'efSearch: {efSearch}')
|
||||
print(f'recall: {recall}')
|
||||
f.write(f'efSearch: {efSearch}, recall: {recall}\n')
|
||||
print(f'Done and result saved to {recall_result_file}')
|
||||
22
packages/leann-backend-hnsw/third_party/faiss/demo/simple_search.fish
vendored
Normal file
22
packages/leann-backend-hnsw/third_party/faiss/demo/simple_search.fish
vendored
Normal file
@@ -0,0 +1,22 @@
set -l index_dirs \
    /opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/hnsw_IP_M30_efC128.index \
    /opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/99_4_degree_based_hnsw_IP_M32_efC256/index.faiss \
    /opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/d9_hnsw_IP_M8_efC128/index.faiss \
    /opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/half_edges_IP_M32_efC128/index.faiss
    # /opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/nsg_R16.index

set -l index_labels \
    origin \
    0.01per_M4_degree_based \
    M8_merge_edge \
    random_delete50
    # nsg_R16

set -gx CUDA_VISIBLE_DEVICES 3

for i in (seq (count $index_dirs))
    set -l index_file $index_dirs[$i]
    set -l index_label $index_labels[$i]
    echo "Building HNSW index with $index_label..." >> ./large_graph_simple_build.log
    python -u large_graph_simple_build.py --index-file $index_file | tee -a ./large_graph_simple_build.log
end