#!/usr/bin/env python3
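"""Search for the nprobe values that hit given target recall rates.

This script repeatedly launches ./demo/main.py in search-only mode with
different nprobe values, parses the reported recall out of the log output,
and uses an adaptive search (binary search on monotonic recall curves,
radius-shrinking sampling otherwise) to find, for each target recall, the
nprobe value whose recall is closest to the target. Per-run logs, search
state, a recall-vs-nprobe plot, and the final optimal values are all
written under nprobe_logs/.
"""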

import subprocess
import json
import re
import matplotlib.pyplot as plt
from typing import Dict, Optional
import os
import time
import sys
import argparse
import concurrent.futures
import signal
import psutil

parser = argparse.ArgumentParser()
parser.add_argument("--path-suffix", type=str, default="", help="Path suffix for the index")
parser.add_argument("--pq-compressed", type=int, default=None,
                    help="Passed through to main.py as --diskann-search-memory-maximum")
parser.add_argument("--beam-width", type=int, default=2,
                    help="DiskANN beam width for search (controls the number of IO requests per iteration)")
parser.add_argument("--index-type", type=str, default="diskann", help="Index type to test (default: diskann)")
parser.add_argument("--task", type=str, default="nq", help="Task to run (default: nq)")
parser.add_argument("--max-workers", type=int, default=1, help="Maximum number of concurrent processes")
parser.add_argument("--timeout", type=int, default=1800, help="Timeout for each process in seconds")
parser.add_argument("--retry-count", type=int, default=2, help="Number of retries for failed runs")
parser.add_argument(
    "--target-recalls",
    type=float,
    nargs="+",
    default=[0.85, 0.90, 0.95],
    help="Target recalls to achieve (e.g., --target-recalls 0.85 0.90 0.95)",
)
args = parser.parse_args()
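
# Example invocation (the script filename here is illustrative):
#   python find_nprobe.py --index-type diskann --task nq \
#       --max-workers 2 --target-recalls 0.85 0.90 0.95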

path_suffix = args.path_suffix
pq_compressed = args.pq_compressed
beam_width = args.beam_width
max_workers = args.max_workers
timeout = args.timeout
retry_count = args.retry_count

TARGET_RECALLS = args.target_recalls

task = args.task

# Process management
running_processes = {}  # PID -> Popen object

# Based on previous data, search around these values.
if args.index_type == "diskann":
    if task == "nq":
        if pq_compressed is None:
            NPROBE_RANGES = {
                0.85: range(10, 50),
                0.90: range(62, 67),    # narrow range around 64
                0.95: range(190, 195),  # narrow range around 192
            }
        elif pq_compressed == 10:
            NPROBE_RANGES = {
                0.85: range(10, 70),
                0.90: range(90, 127),
                0.95: range(200, 384),
            }
        elif pq_compressed == 20:
            NPROBE_RANGES = {
                0.85: range(10, 50),
                0.90: range(64, 128),
                0.95: range(188, 192),
            }
        elif pq_compressed == 5:
            NPROBE_RANGES = {
                0.85: range(10, 500),
                0.90: range(768, 2000),
                0.95: range(3000, 4096),
            }
    elif task == "trivia":
        if pq_compressed is None:
            NPROBE_RANGES = {
                0.85: range(90, 150),
                0.90: range(150, 200),
                0.95: range(200, 300),
            }
    elif task == "gpqa":
        if pq_compressed is None:
            NPROBE_RANGES = {
                0.85: range(1, 30),
                0.90: range(1, 30),
                0.95: range(1, 30),
            }
    elif task == "hotpot":
        if pq_compressed is None:
            NPROBE_RANGES = {
                0.85: range(19, 160),
                0.90: range(120, 210),
                0.95: range(1000, 1200),
            }
elif args.index_type == "ivf_disk":
    if task == "nq":
        assert pq_compressed is None
        NPROBE_RANGES = {
            0.85: range(13, 16),
            0.90: range(30, 40),
            0.95: range(191, 194),
        }
    elif task == "trivia":
        assert pq_compressed is None
        NPROBE_RANGES = {
            0.85: range(13, 50),
            0.90: range(30, 100),
            0.95: range(100, 400),
        }
    elif task == "gpqa":
        assert pq_compressed is None
        NPROBE_RANGES = {
            0.85: range(1, 30),
            0.90: range(1, 30),
            0.95: range(1, 30),
        }
    elif task == "hotpot":
        NPROBE_RANGES = {
            0.85: range(13, 100),
            0.90: range(30, 200),
            0.95: range(191, 700),
        }
elif args.index_type == "hnsw":
    if task == "nq":
        NPROBE_RANGES = {
            0.85: range(130, 140),
            0.90: range(550, 666),
            0.95: range(499, 1199),
        }
    elif task == "gpqa":
        NPROBE_RANGES = {
            0.85: range(40, 70),
            0.90: range(60, 100),
            0.95: range(200, 500),
        }
    elif task == "hotpot":
        NPROBE_RANGES = {
            0.85: range(450, 480),
            0.90: range(1000, 1300),
            0.95: range(2000, 4000),
        }
    elif task == "trivia":
        NPROBE_RANGES = {
            0.85: range(100, 400),
            0.90: range(700, 1800),
            0.95: range(506, 1432),
        }
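
# Fail fast if no range table is defined for this combination of index
# type, task, and PQ setting; otherwise a NameError would only surface
# much later, inside find_optimal_nprobe_values().
if "NPROBE_RANGES" not in globals():
    sys.exit(f"No NPROBE_RANGES defined for index_type={args.index_type}, "
             f"task={task}, pq_compressed={pq_compressed}")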

# Create a directory for logs if it doesn't exist
os.makedirs("nprobe_logs", exist_ok=True)


# Set up signal handling for clean termination
def signal_handler(sig, frame):
    """Terminate every tracked process (and its children), then exit."""
    print("Received termination signal. Cleaning up running processes...")
    # Iterate over a copy, since the dict can change while we clean up.
    for pid, process in list(running_processes.items()):
        try:
            if process.poll() is None:  # Process is still running
                print(f"Terminating process {pid}...")
                process.terminate()
                time.sleep(0.5)
                if process.poll() is None:  # If still running after terminate
                    print(f"Killing process {pid}...")
                    process.kill()

            # Kill any child processes
            try:
                parent = psutil.Process(pid)
                children = parent.children(recursive=True)
                for child in children:
                    print(f"Killing child process {child.pid}...")
                    child.kill()
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                pass
        except Exception:
            pass

    print("All processes terminated. Exiting.")
    sys.exit(1)


# Register signal handlers
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)


def run_batch_demo(nprobe: int, retry: int = 0) -> Optional[float]:
    """Run main.py in batch mode with a specific nprobe value and extract the recall."""
    command = (
        f"python -u ./demo/main.py --search-only --load-indices {args.index_type}"
        f" --domain rpj_wiki --lazy-load-passages --nprobe {nprobe} --task {task}"
        f" --skip-passages"
    )
    if pq_compressed is not None:
        command += f" --diskann-search-memory-maximum {pq_compressed}"
    if beam_width is not None:
        command += f" --diskann-beam-width {beam_width}"
    if args.index_type == "hnsw":
        command += " --hnsw-old"
    # command += " --embedder intfloat/multilingual-e5-small"

    cmd = [
        "fish", "-c",
        # f"set -gx LD_PRELOAD \"/lib/x86_64-linux-gnu/libmkl_core.so /lib/x86_64-linux-gnu/libmkl_intel_lp64.so /lib/x86_64-linux-gnu/libmkl_intel_thread.so /lib/x86_64-linux-gnu/libiomp5.so\" && "
        "source ./.venv/bin/activate.fish && " + command,
    ]

    print(f"Running with nprobe={nprobe}, beam_width={beam_width}, retry={retry}/{retry_count}")
    log_file = f"nprobe_logs/nprobe_{nprobe}_beam{beam_width}_{path_suffix}_retry{retry}.log"

    try:
        # Save the command itself to the log file
        with open(log_file, "w") as f:
            f.write(f"Command: {cmd[2]}\n\n")
            f.write(f"Start time: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
            f.write("=== OUTPUT BEGINS ===\n")

        # Run the command and tee the output to both stdout and the log file
        with open(log_file, "a") as f:
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                bufsize=1,  # Line buffered
            )

            # Register the process for cleanup
            pid = process.pid
            running_processes[pid] = process

            # Process output line by line for real-time logging
            if process.stdout:  # Check that stdout is not None
                start_time = time.time()

                while process.poll() is None:
                    # Check for timeout
                    if time.time() - start_time > timeout:
                        print(f"Process timeout for nprobe={nprobe}, killing...")
                        f.write("\n\nProcess timed out, killing...\n")
                        process.terminate()
                        time.sleep(0.5)
                        if process.poll() is None:
                            process.kill()

                        # Clean up child processes
                        try:
                            parent = psutil.Process(pid)
                            children = parent.children(recursive=True)
                            for child in children:
                                child.kill()
                        except (psutil.NoSuchProcess, psutil.AccessDenied):
                            pass

                        if pid in running_processes:
                            del running_processes[pid]

                        # Retry if we have attempts left
                        if retry < retry_count:
                            print(f"Retrying nprobe={nprobe}...")
                            return run_batch_demo(nprobe, retry + 1)
                        return None

                    # Read output in small steps so we can keep checking the process
                    try:
                        line = process.stdout.readline()
                        if not line:
                            time.sleep(0.1)  # Small pause to avoid busy waiting
                            continue

                        print(line, end='')  # Print to console
                        f.write(line)        # Write to log file
                        f.flush()            # Make sure it's written immediately
                    except Exception:
                        time.sleep(0.1)

                # Drain whatever is still buffered once the process exits, so
                # the final lines (which usually contain the recall) are not lost.
                remaining = process.stdout.read()
                if remaining:
                    print(remaining, end='')
                    f.write(remaining)

            exit_code = process.wait()

            # Process complete, remove it from the running list
            if pid in running_processes:
                del running_processes[pid]

            f.write(f"\nExit code: {exit_code}\n")
            f.write(f"End time: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")

        # Re-read the log file to extract the recall rate
        with open(log_file, "r") as f:
            log_content = f.read()

        # Try multiple patterns to find the recall rate
        recall = None
        patterns = [
            fr"Avg recall rate for {args.index_type}: ([0-9.]+)",
            r"recall: ([0-9.]+)",
            fr"{args.index_type}.*?recall.*?([0-9.]+)",
            fr"recall.*?{args.index_type}.*?([0-9.]+)",
        ]

        for pattern in patterns:
            matches = re.findall(pattern, log_content, re.IGNORECASE)
            if matches:
                try:
                    recall = float(matches[-1])  # Take the last one if there are multiple matches
                    print(f"Found recall rate using pattern: {pattern}")
                    break
                except ValueError:
                    continue

        if recall is None:
            print(f"Warning: Could not extract recall rate from output log {log_file}")
            # Try to find any number that looks like a recall rate (between 0 and 1)
            possible_recalls = re.findall(r"recall.*?([0-9]+\.[0-9]+)", log_content, re.IGNORECASE)
            if possible_recalls:
                try:
                    recall_candidates = [float(r) for r in possible_recalls if 0 <= float(r) <= 1]
                    if recall_candidates:
                        recall = recall_candidates[-1]  # Take the last one
                        print(f"Guessed recall rate: {recall} (based on pattern matching)")
                except ValueError:
                    pass

        if recall is None:
            # Log this failure with more context
            with open("nprobe_logs/failed_recalls.log", "a") as f:
                f.write(f"Failed to extract recall for nprobe={nprobe} at {time.strftime('%Y-%m-%d %H:%M:%S')}\n")

            # Retry if we have attempts left
            if retry < retry_count:
                print(f"Retrying nprobe={nprobe} due to failed recall extraction...")
                return run_batch_demo(nprobe, retry + 1)

            return None

        print(f"nprobe={nprobe}, recall={recall:.4f}")
        return recall

    except subprocess.TimeoutExpired:
        print(f"Command timed out for nprobe={nprobe}")
        with open(log_file, "a") as f:
            f.write(f"\n\nCommand timed out after {timeout} seconds\n")

        # Retry if we have attempts left
        if retry < retry_count:
            print(f"Retrying nprobe={nprobe}...")
            return run_batch_demo(nprobe, retry + 1)

        return None

    except Exception as e:
        print(f"Error running command for nprobe={nprobe}: {e}")
        with open(log_file, "a") as f:
            f.write(f"\n\nError: {e}\n")

        # Retry if we have attempts left
        if retry < retry_count:
            print(f"Retrying nprobe={nprobe} due to error: {e}...")
            return run_batch_demo(nprobe, retry + 1)

        return None
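
# Illustration of the extraction above (the log line shown is made up, but
# matches the first pattern when --index-type is diskann):
#   re.findall(r"Avg recall rate for diskann: ([0-9.]+)",
#              "Avg recall rate for diskann: 0.9132")  # -> ['0.9132']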


def batch_run_nprobe_values(nprobe_values):
    """Run multiple nprobe values in parallel with a thread pool."""
    results = {}

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_nprobe = {executor.submit(run_batch_demo, nprobe): nprobe for nprobe in nprobe_values}

        for future in concurrent.futures.as_completed(future_to_nprobe):
            nprobe = future_to_nprobe[future]
            try:
                recall = future.result()
                if recall is not None:
                    results[nprobe] = recall
                    print(f"Completed nprobe={nprobe} with recall={recall:.4f}")
            except Exception as e:
                print(f"Error processing nprobe={nprobe}: {e}")

    return results


def adaptive_search_nprobe(target_recall: float, min_nprobe: int, max_nprobe: int, tolerance: float = 0.001) -> Dict:
    """
    Use an adaptive search strategy to find the optimal nprobe value for a target recall.
    Combines binary search with exploration to handle non-linear relationships.

    Args:
        target_recall: The target recall to achieve
        min_nprobe: Minimum nprobe value to start the search from
        max_nprobe: Maximum nprobe value for the search
        tolerance: How close we need to get to the target_recall

    Returns:
        Dictionary with the best nprobe, achieved recall, and other metadata
    """
    print(f"\nAdaptive searching for nprobe that achieves {target_recall*100:.1f}% recall...")
    print(f"Search range: {min_nprobe} - {max_nprobe}")

    with open(f"nprobe_logs/summary_{path_suffix}.log", "a") as f:
        f.write(f"\nAdaptive searching for nprobe that achieves {target_recall*100:.1f}% recall...\n")
        f.write(f"Search range: {min_nprobe} - {max_nprobe}\n")

    best_result = {"nprobe": None, "recall": None, "difference": float('inf')}
    all_results = {"nprobe": [], "recall": []}

    # Save an initial state file for this search
    search_results_file = f"nprobe_logs/search_results_{path_suffix}_{target_recall:.2f}.json"
    search_data = {
        "target": target_recall,
        "current_best": best_result,
        "all_results": all_results,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        "search_range": {"min": min_nprobe, "max": max_nprobe},
    }

    with open(search_results_file, "w") as f:
        json.dump(search_data, f, indent=2)

    # Start with strategic sampling to understand the recall curve.
    # Choose more points if the range is large.
    range_size = max_nprobe - min_nprobe
    if range_size > 500:
        num_initial_samples = 5
    elif range_size > 100:
        num_initial_samples = 4
    else:
        num_initial_samples = 3

    sample_points = [min_nprobe]
    step = range_size // (num_initial_samples - 1)
    for i in range(1, num_initial_samples - 1):
        sample_points.append(min_nprobe + i * step)
    sample_points.append(max_nprobe)

    # Run the initial sample points in parallel
    initial_results = batch_run_nprobe_values(sample_points)

    # Update all_results and best_result based on initial_results
    for nprobe, recall in initial_results.items():
        all_results["nprobe"].append(nprobe)
        all_results["recall"].append(recall)

        diff = abs(recall - target_recall)
        if diff < best_result["difference"]:
            best_result = {"nprobe": nprobe, "recall": recall, "difference": diff}

    # Update the search results file
    search_data["current_best"] = best_result
    search_data["all_results"] = all_results
    search_data["timestamp"] = time.strftime("%Y-%m-%d %H:%M:%S")
    with open(search_results_file, "w") as f:
        json.dump(search_data, f, indent=2)

    # Check whether we've already reached the target within tolerance
    if best_result["difference"] <= tolerance:
        print(f"Found good enough nprobe value: {best_result['nprobe']} with recall {best_result['recall']:.4f}")
        return best_result

    # Analyze the initial results to decide on the next strategy.
    # Sort the results by nprobe.
    sorted_results = sorted(zip(all_results["nprobe"], all_results["recall"]))
    nprobes, recalls = zip(*sorted_results)

    # Check whether the relationship is monotonic
    is_monotonic = all(recalls[i] <= recalls[i+1] for i in range(len(recalls)-1)) or \
                   all(recalls[i] >= recalls[i+1] for i in range(len(recalls)-1))

    if is_monotonic:
        print("Relationship appears monotonic, proceeding with binary search.")
        # Find the two closest points that bracket the target
        bracket_low, bracket_high = None, None
        for i in range(len(recalls)-1):
            if (recalls[i] <= target_recall <= recalls[i+1]) or (recalls[i] >= target_recall >= recalls[i+1]):
                bracket_low, bracket_high = nprobes[i], nprobes[i+1]
                break

        if bracket_low is None:
            # The target is outside our current range, so adjust the range
            if all(r < target_recall for r in recalls):
                # All recalls are too low, need to increase nprobe
                bracket_low = nprobes[-1]
                bracket_high = min(max_nprobe, nprobes[-1] * 2)
            else:
                # All recalls are too high, need to decrease nprobe
                bracket_low = max(min_nprobe, nprobes[0] // 2)
                bracket_high = nprobes[0]

        # Binary search between bracket_low and bracket_high
        while abs(bracket_high - bracket_low) > 3:
            mid_nprobe = (bracket_low + bracket_high) // 2
            if mid_nprobe in initial_results:
                mid_recall = initial_results[mid_nprobe]
            else:
                mid_recall = run_batch_demo(mid_nprobe)
                if mid_recall is not None:
                    all_results["nprobe"].append(mid_nprobe)
                    all_results["recall"].append(mid_recall)

                    diff = abs(mid_recall - target_recall)
                    if diff < best_result["difference"]:
                        best_result = {"nprobe": mid_nprobe, "recall": mid_recall, "difference": diff}

                    # Update the search results file
                    search_data["current_best"] = best_result
                    search_data["all_results"] = all_results
                    with open(search_results_file, "w") as f:
                        json.dump(search_data, f, indent=2)

            if mid_recall is not None:
                # Check whether we're close enough
                if abs(mid_recall - target_recall) <= tolerance:
                    break

                # Adjust the brackets
                if mid_recall < target_recall:
                    bracket_low = mid_nprobe
                else:
                    bracket_high = mid_nprobe
            else:
                # If we failed to get a result, try a different point
                bracket_high = mid_nprobe - 1

    else:
        print("Relationship appears non-monotonic, using adaptive sampling.")
        # For non-monotonic relationships, use adaptive sampling.
        # First, find the best current point.
        best_idx = recalls.index(min(recalls, key=lambda r: abs(r - target_recall)))
        best_nprobe = nprobes[best_idx]

        # Try points around the best point with a decreasing radius
        radius = max(50, (max_nprobe - min_nprobe) // 10)
        min_radius = 3

        while radius >= min_radius:
            # Try points at the current radius around best_nprobe
            test_points = []
            lower_bound = max(min_nprobe, best_nprobe - radius)
            upper_bound = min(max_nprobe, best_nprobe + radius)

            if lower_bound not in initial_results and lower_bound != best_nprobe:
                test_points.append(lower_bound)
            if upper_bound not in initial_results and upper_bound != best_nprobe:
                test_points.append(upper_bound)

            # Add a point in the middle if the range is large enough
            if upper_bound - lower_bound > 2 * radius / 3 and len(test_points) < max_workers:
                mid_point = (lower_bound + upper_bound) // 2
                if mid_point not in initial_results and mid_point != best_nprobe:
                    test_points.append(mid_point)

            # Run the tests
            if test_points:
                new_results = batch_run_nprobe_values(test_points)
                initial_results.update(new_results)

                # Update all_results and best_result
                for nprobe, recall in new_results.items():
                    all_results["nprobe"].append(nprobe)
                    all_results["recall"].append(recall)

                    diff = abs(recall - target_recall)
                    if diff < best_result["difference"]:
                        best_result = {"nprobe": nprobe, "recall": recall, "difference": diff}
                        best_nprobe = nprobe  # Update the center for the next iteration

                # Update the search results file
                search_data["current_best"] = best_result
                search_data["all_results"] = all_results
                with open(search_results_file, "w") as f:
                    json.dump(search_data, f, indent=2)

                # Check whether we're close enough
                if best_result["difference"] <= tolerance:
                    break

            # Reduce the radius for the next iteration; stop once the minimum
            # radius has been tried, otherwise max() would pin the radius at
            # min_radius and the loop would never terminate.
            if radius == min_radius:
                break
            radius = max(min_radius, radius // 2)

    # After the search, do a final fine-tuning pass around the best result
    if best_result["nprobe"] is not None:
        fine_tune_range = range(max(min_nprobe, best_result["nprobe"] - 2),
                                min(max_nprobe, best_result["nprobe"] + 3))

        fine_tune_points = [n for n in fine_tune_range if n not in all_results["nprobe"]]
        if fine_tune_points:
            fine_tune_results = batch_run_nprobe_values(fine_tune_points)

            for nprobe, recall in fine_tune_results.items():
                all_results["nprobe"].append(nprobe)
                all_results["recall"].append(recall)

                diff = abs(recall - target_recall)
                if diff < best_result["difference"]:
                    best_result = {"nprobe": nprobe, "recall": recall, "difference": diff}

    # Final update to the search results file
    search_data["current_best"] = best_result
    search_data["all_results"] = all_results
    search_data["search_range"] = {"min": min_nprobe, "max": max_nprobe, "phase": "fine_tune"}
    with open(search_results_file, "w") as f:
        json.dump(search_data, f, indent=2)

    return best_result


def find_optimal_nprobe_values():
    """Find the optimal nprobe values for the target recall rates using adaptive search."""
    # Dictionary to store the result for each target recall
    results = {}
    # Dictionary to store all nprobe-recall pairs for plotting
    all_data = {target: {"nprobe": [], "recall": []} for target in TARGET_RECALLS}

    # Create a summary file for all runs
    with open(f"nprobe_logs/summary_{path_suffix}.log", "w") as f:
        f.write(f"Find optimal nprobe values - started at {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        f.write(f"Target recalls: {TARGET_RECALLS}\n")
        f.write(f"nprobe ranges: {NPROBE_RANGES}\n\n")
        f.write(f"Max workers: {max_workers}\n")
        f.write(f"Timeout per process: {timeout}s\n")
        f.write(f"Retry count: {retry_count}\n\n")

    for target in TARGET_RECALLS:
        # Use the configured NPROBE_RANGES to determine the min and max values
        min_nprobe = min(NPROBE_RANGES[target])
        max_nprobe = max(NPROBE_RANGES[target])

        print(f"\nUsing NPROBE_RANGES for target {target*100:.1f}%: {min_nprobe} to {max_nprobe}")

        # Run the adaptive search
        best_result = adaptive_search_nprobe(
            target_recall=target,
            min_nprobe=min_nprobe,
            max_nprobe=max_nprobe,
        )

        results[target] = best_result

        # Load all tested points into all_data for plotting
        search_results_file = f"nprobe_logs/search_results_{path_suffix}_{target:.2f}.json"
        try:
            with open(search_results_file, "r") as f:
                search_data = json.load(f)
                if "all_results" in search_data:
                    all_data[target]["nprobe"] = search_data["all_results"]["nprobe"]
                    all_data[target]["recall"] = search_data["all_results"]["recall"]
        except (FileNotFoundError, json.JSONDecodeError) as e:
            print(f"Warning: Could not load search results for {target}: {e}")

        print(f"For target recall {target*100:.1f}%:")
        print(f"  Best nprobe value: {best_result['nprobe']}")
        print(f"  Achieved recall: {best_result['recall']:.4f}")
        print(f"  Difference: {best_result['difference']:.4f}")

        with open(f"nprobe_logs/summary_{path_suffix}.log", "a") as f:
            f.write(f"For target recall {target*100:.1f}%:\n")
            f.write(f"  Best nprobe value: {best_result['nprobe']}\n")
            f.write(f"  Achieved recall: {best_result['recall']:.4f}\n")
            f.write(f"  Difference: {best_result['difference']:.4f}\n")

    # Plot the results if we have data
    if all_data and any(data["nprobe"] for data in all_data.values()):
        plt.figure(figsize=(10, 6))

        # Plot each target's data
        for target in TARGET_RECALLS:
            if not all_data[target]["nprobe"]:
                continue

            nprobe_values = all_data[target]["nprobe"]
            recall_values = all_data[target]["recall"]

            # Sort the data points for better visualization
            sorted_points = sorted(zip(nprobe_values, recall_values))
            sorted_nprobe, sorted_recall = zip(*sorted_points) if sorted_points else ([], [])

            plt.plot(sorted_nprobe, sorted_recall, 'o-',
                     label=f"Target {target*100:.1f}%, Best={results[target]['nprobe']}")

            # Mark the optimal point
            opt_nprobe = results[target]["nprobe"]
            opt_recall = results[target]["recall"]
            plt.plot(opt_nprobe, opt_recall, 'r*', markersize=15)

            # Add a horizontal line at the target recall
            plt.axhline(y=target, color='gray', linestyle='--', alpha=0.5)

        plt.xlabel('nprobe value')
        plt.ylabel('Recall rate')
        plt.title(f'Recall Rate vs nprobe Value (Max Workers: {max_workers})')
        plt.grid(True)
        plt.legend()
        plt.savefig(f'nprobe_logs/nprobe_vs_recall_{path_suffix}.png')
        print(f"Plot saved to nprobe_logs/nprobe_vs_recall_{path_suffix}.png")
    else:
        print("No data to plot.")
        with open(f"nprobe_logs/summary_{path_suffix}.log", "a") as f:
            f.write("No data to plot.\n")

    # Save the final results
    with open(f"nprobe_logs/optimal_nprobe_values_{path_suffix}.json", "w") as f:
        json.dump(results, f, indent=2)

    with open(f"nprobe_logs/summary_{path_suffix}.log", "a") as f:
        f.write(f"\nFind optimal nprobe values - finished at {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        if results:
            f.write("\nOptimal nprobe values for target recall rates:\n")
            for target, data in results.items():
                f.write(f"{target*100:.1f}% recall: nprobe={data['nprobe']} (actual recall: {data['recall']:.4f})\n")
        else:
            f.write("No optimal nprobe values found.\n")

    return results


if __name__ == "__main__":
    try:
        results = find_optimal_nprobe_values()

        if not results:
            print("No optimal nprobe values found.")
            sys.exit(1)

        print("\nOptimal nprobe values for target recall rates:")
        for target, data in results.items():
            print(f"{target*100:.1f}% recall: nprobe={data['nprobe']} (actual recall: {data['recall']:.4f})")

        # Generate the command for running the latency test with the optimal nprobe values
        optimal_values = [data["nprobe"] for target, data in sorted(results.items())]
        test_cmd = (
            "source ./.venv/bin/activate.fish && cd ~ && "
            f"python ./Power-RAG/demo/test_serve.py --nprobe_values {' '.join(map(str, optimal_values))}"
        )

        print("\nRun this command to test latency with the optimal nprobe values:")
        print(test_cmd)

    except KeyboardInterrupt:
        print("\nScript interrupted by user. Cleaning up running processes...")
        signal_handler(signal.SIGINT, None)
        sys.exit(1)
    except Exception:
        # Clean up any running processes before re-raising
        import traceback
        traceback.print_exc()
        signal_handler(signal.SIGINT, None)
        raise