#!/usr/bin/env python3 import subprocess import json import re import numpy as np import matplotlib.pyplot as plt from typing import List, Dict, Any, Tuple, Optional import os import time import sys import argparse import concurrent.futures import signal import psutil parser = argparse.ArgumentParser() parser.add_argument("--path-suffix", type=str, default="", help="Path suffix for the index") parser.add_argument("--pq-compressed", type=int, default=None) parser.add_argument("--beam-width", type=int, default=2, help="DiskANN beam width for search (controls number of IO requests per iteration)") parser.add_argument("--index-type", type=str, default="diskann", help="Index type to test (default: diskann)") parser.add_argument("--task", type=str, default="nq", help="Task to run (default: nq)") parser.add_argument("--max-workers", type=int, default=1, help="Maximum number of concurrent processes") parser.add_argument("--timeout", type=int, default=1800, help="Timeout for each process in seconds") parser.add_argument("--retry-count", type=int, default=2, help="Number of retries for failed runs") parser.add_argument( "--target-recalls", type=float, nargs='+', default=[0.85, 0.90, 0.95], help="Target recalls to achieve (e.g., --target-recalls 0.85 0.90 0.95)" ) args = parser.parse_args() path_suffix = args.path_suffix pq_compressed = args.pq_compressed beam_width = args.beam_width max_workers = args.max_workers timeout = args.timeout retry_count = args.retry_count TARGET_RECALLS = args.target_recalls task = args.task # Process management running_processes = {} # PID -> Process object # Based on previous data, search around these values if args.index_type == "diskann": if task == "nq": if pq_compressed is None: NPROBE_RANGES = { 0.85: range(10, 50), 0.90: range(62, 67), # Narrow range around 64 (63, 64, 65, 66) 0.95: range(190, 195) # Narrow range around 192 (190, 191, 192, 193, 194) } elif pq_compressed == 10: NPROBE_RANGES = { 0.85: range(10, 70), 0.90: range(90, 127), # Narrow range around 64 (63, 64, 65, 66) 0.95: range(200, 384) # Narrow range around 192 (190, 191, 192, 193, 194) } elif pq_compressed == 20: NPROBE_RANGES = { 0.85: range(10, 50), 0.90: range(64, 128), # Narrow range around 64 (63, 64, 65, 66) 0.95: range(188, 192) # Narrow range around 192 (190, 191, 192, 193, 194) } elif pq_compressed == 5: NPROBE_RANGES = { 0.85: range(10, 500), 0.90: range(768, 2000), # Narrow range around 64 (63, 64, 65, 66) 0.95: range(3000, 4096) # Narrow range around 192 (190, 191, 192, 193, 194) } elif task == "trivia": if pq_compressed is None: NPROBE_RANGES = { 0.85: range(90, 150), 0.90: range(150, 200), # Narrow range around 64 (63, 64, 65, 66) 0.95: range(200, 300) # Narrow range around 192 (190, 191, 192, 193, 194) } elif task == "gpqa": if pq_compressed is None: NPROBE_RANGES = { 0.85: range(1, 30), 0.90: range(1, 30), # Narrow range around 64 (63, 64, 65, 66) 0.95: range(1, 30) # Narrow range around 192 (190, 191, 192, 193, 194) } elif task == "hotpot": if pq_compressed is None: NPROBE_RANGES = { 0.85: range(19, 160), 0.90: range(120, 210), # Narrow range around 64 (63, 64, 65, 66) 0.95: range(1000, 1200) # Narrow range around 192 (190, 191, 192, 193, 194) } elif args.index_type == "ivf_disk": if task == "nq": assert pq_compressed is None NPROBE_RANGES = { 0.85: range(13, 16), 0.90: range(30,40), 0.95: range(191, 194) } elif task == "trivia": assert pq_compressed is None NPROBE_RANGES = { 0.85: range(13, 50), 0.90: range(30, 100), 0.95: range(100, 400) } elif task == "gpqa": assert pq_compressed is None NPROBE_RANGES = { 0.85: range(1, 30), 0.90: range(1, 30), # Narrow range around 64 (63, 64, 65, 66) 0.95: range(1, 30) # Narrow range around 192 (190, 191, 192, 193, 194) } elif task == "hotpot": NPROBE_RANGES = { 0.85: range(13, 100), 0.90: range(30, 200), 0.95: range(191, 700) } elif args.index_type == "hnsw": if task == "nq": NPROBE_RANGES = { 0.85: range(130, 140), 0.90: range(550, 666), 0.95: range(499, 1199), } if task == "gpqa": NPROBE_RANGES = { 0.85: range(40, 70), 0.90: range(60, 100), 0.95: range(200, 500), } elif task == "hotpot": NPROBE_RANGES = { 0.85: range(450, 480), 0.90: range(1000, 1300), 0.95: range(2000, 4000), } elif task == "trivia": NPROBE_RANGES = { 0.85: range(100, 400), 0.90: range(700, 1800), 0.95: range(506, 1432) } # Create a directory for logs if it doesn't exist os.makedirs("nprobe_logs", exist_ok=True) # Set up signal handling for clean termination def signal_handler(sig, frame): print("Received termination signal. Cleaning up running processes...") for pid, process in running_processes.items(): try: if process.poll() is None: # Process is still running print(f"Terminating process {pid}...") process.terminate() time.sleep(0.5) if process.poll() is None: # If still running after terminate print(f"Killing process {pid}...") process.kill() # Kill any child processes try: parent = psutil.Process(pid) children = parent.children(recursive=True) for child in children: print(f"Killing child process {child.pid}...") child.kill() except (psutil.NoSuchProcess, psutil.AccessDenied): pass except: pass print("All processes terminated. Exiting.") sys.exit(1) # Register signal handlers signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) def run_batch_demo(nprobe: int, retry: int = 0) -> Optional[float]: """Run main.py in batch mode with a specific nprobe value and extract the recall.""" command = f"python -u ./demo/main.py --search-only --load-indices {args.index_type} --domain rpj_wiki --lazy-load-passages --nprobe {nprobe} --task {task} --skip-passages" if pq_compressed is not None: command += f" --diskann-search-memory-maximum {pq_compressed}" if beam_width is not None: command += f" --diskann-beam-width {beam_width}" if args.index_type == "hnsw": command += f" --hnsw-old" # command += " --embedder intfloat/multilingual-e5-small" cmd = [ "fish", "-c", # f"set -gx LD_PRELOAD \"/lib/x86_64-linux-gnu/libmkl_core.so /lib/x86_64-linux-gnu/libmkl_intel_lp64.so /lib/x86_64-linux-gnu/libmkl_intel_thread.so /lib/x86_64-linux-gnu/libiomp5.so\" && " "source ./.venv/bin/activate.fish &&" + command ] print(f"Running with nprobe={nprobe}, beam_width={beam_width}, retry={retry}/{retry_count}") log_file = f"nprobe_logs/nprobe_{nprobe}_beam{beam_width}_{path_suffix}_retry{retry}.log" try: # Also save the command to the log file with open(log_file, "w") as f: f.write(f"Command: {cmd[1]}\n\n") f.write(f"Start time: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n") f.write("=== OUTPUT BEGINS ===\n") # Run the command and tee the output to both stdout and the log file with open(log_file, "a") as f: process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1 # Line buffered ) # Register the process for cleanup pid = process.pid running_processes[pid] = process # Process output line by line for real-time logging if process.stdout: # Check if stdout is not None # Set a timeout start_time = time.time() current_output = "" while process.poll() is None: # Check for timeout if time.time() - start_time > timeout: print(f"Process timeout for nprobe={nprobe}, killing...") f.write("\n\nProcess timed out, killing...\n") process.terminate() time.sleep(0.5) if process.poll() is None: process.kill() # Clean up child processes try: parent = psutil.Process(pid) children = parent.children(recursive=True) for child in children: child.kill() except (psutil.NoSuchProcess, psutil.AccessDenied): pass if pid in running_processes: del running_processes[pid] # Retry if we have attempts left if retry < retry_count: print(f"Retrying nprobe={nprobe}...") return run_batch_demo(nprobe, retry + 1) return None # Read output with a small timeout to allow for process checking try: line = process.stdout.readline() if not line: time.sleep(0.1) # Small pause to avoid busy waiting continue print(line, end='') # Print to console f.write(line) # Write to log file f.flush() # Make sure it's written immediately except: time.sleep(0.1) exit_code = process.wait() # Process complete, remove from running list if pid in running_processes: del running_processes[pid] f.write(f"\nExit code: {exit_code}\n") f.write(f"End time: {time.strftime('%Y-%m-%d %H:%M:%S')}\n") # Re-read the log file to extract recall rate with open(log_file, "r") as f: log_content = f.read() # Try multiple patterns to find recall rate recall = None patterns = [ fr"Avg recall rate for {args.index_type}: ([0-9.]+)", r"recall: ([0-9.]+)", fr"{args.index_type}.*?recall.*?([0-9.]+)", fr"recall.*?{args.index_type}.*?([0-9.]+)" ] for pattern in patterns: matches = re.findall(pattern, log_content, re.IGNORECASE) if matches: try: recall = float(matches[-1]) # Take the last one if multiple matches print(f"Found recall rate using pattern: {pattern}") break except ValueError: continue if recall is None: print(f"Warning: Could not extract recall rate from output log {log_file}") # Try to find any number that looks like a recall rate (between 0 and 1) possible_recalls = re.findall(r"recall.*?([0-9]+\.[0-9]+)", log_content, re.IGNORECASE) if possible_recalls: try: recall_candidates = [float(r) for r in possible_recalls if 0 <= float(r) <= 1] if recall_candidates: recall = recall_candidates[-1] # Take the last one print(f"Guessed recall rate: {recall} (based on pattern matching)") except ValueError: pass if recall is None: # Log this failure with more context with open("nprobe_logs/failed_recalls.log", "a") as f: f.write(f"Failed to extract recall for nprobe={nprobe} at {time.strftime('%Y-%m-%d %H:%M:%S')}\n") # Retry if we have attempts left if retry < retry_count: print(f"Retrying nprobe={nprobe} due to failed recall extraction...") return run_batch_demo(nprobe, retry + 1) return None print(f"nprobe={nprobe}, recall={recall:.4f}") return recall except subprocess.TimeoutExpired: print(f"Command timed out for nprobe={nprobe}") with open(log_file, "a") as f: f.write("\n\nCommand timed out after 1800 seconds\n") # Retry if we have attempts left if retry < retry_count: print(f"Retrying nprobe={nprobe}...") return run_batch_demo(nprobe, retry + 1) return None except Exception as e: print(f"Error running command for nprobe={nprobe}: {e}") with open(log_file, "a") as f: f.write(f"\n\nError: {e}\n") # Retry if we have attempts left if retry < retry_count: print(f"Retrying nprobe={nprobe} due to error: {e}...") return run_batch_demo(nprobe, retry + 1) return None def batch_run_nprobe_values(nprobe_values): """Run multiple nprobe values in parallel with a thread pool.""" results = {} with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: future_to_nprobe = {executor.submit(run_batch_demo, nprobe): nprobe for nprobe in nprobe_values} for future in concurrent.futures.as_completed(future_to_nprobe): nprobe = future_to_nprobe[future] try: recall = future.result() if recall is not None: results[nprobe] = recall print(f"Completed nprobe={nprobe} with recall={recall:.4f}") except Exception as e: print(f"Error processing nprobe={nprobe}: {e}") return results def adaptive_search_nprobe(target_recall: float, min_nprobe: int, max_nprobe: int, tolerance: float = 0.001) -> Dict: """ Use an adaptive search strategy to find the optimal nprobe value for a target recall. Combines binary search with exploration to handle non-linear relationships. Args: target_recall: The target recall to achieve min_nprobe: Minimum nprobe value to start search max_nprobe: Maximum nprobe value for search tolerance: How close we need to get to the target_recall Returns: Dictionary with the best nprobe, achieved recall, and other metadata """ print(f"\nAdaptive searching for nprobe that achieves {target_recall*100:.1f}% recall...") print(f"Search range: {min_nprobe} - {max_nprobe}") with open(f"nprobe_logs/summary_{path_suffix}.log", "a") as f: f.write(f"\nAdaptive searching for nprobe that achieves {target_recall*100:.1f}% recall...\n") f.write(f"Search range: {min_nprobe} - {max_nprobe}\n") best_result = {"nprobe": None, "recall": None, "difference": float('inf')} all_results = {"nprobe": [], "recall": []} # Save initial file for this search search_results_file = f"nprobe_logs/search_results_{path_suffix}_{target_recall:.2f}.json" search_data = { "target": target_recall, "current_best": best_result, "all_results": all_results, "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "search_range": {"min": min_nprobe, "max": max_nprobe} } with open(search_results_file, "w") as f: json.dump(search_data, f, indent=2) # Start with a strategic sampling to understand the recall curve # Choose more points if the range is large range_size = max_nprobe - min_nprobe if range_size > 500: num_initial_samples = 5 elif range_size > 100: num_initial_samples = 4 else: num_initial_samples = 3 sample_points = [min_nprobe] step = range_size // (num_initial_samples - 1) for i in range(1, num_initial_samples - 1): sample_points.append(min_nprobe + i * step) sample_points.append(max_nprobe) # Run initial sample points in parallel initial_results = batch_run_nprobe_values(sample_points) # Update all_results and best_result based on initial_results for nprobe, recall in initial_results.items(): all_results["nprobe"].append(nprobe) all_results["recall"].append(recall) diff = abs(recall - target_recall) if diff < best_result["difference"]: best_result = {"nprobe": nprobe, "recall": recall, "difference": diff} # Update search results file search_data = { "target": target_recall, "current_best": best_result, "all_results": all_results, "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "search_range": {"min": min_nprobe, "max": max_nprobe} } with open(search_results_file, "w") as f: json.dump(search_data, f, indent=2) # Check if we've already reached target within tolerance if best_result["difference"] <= tolerance: print(f"Found good enough nprobe value: {best_result['nprobe']} with recall {best_result['recall']:.4f}") return best_result # Analyze initial results to decide on next strategy # Sort results by nprobe sorted_results = sorted([(n, r) for n, r in zip(all_results["nprobe"], all_results["recall"])]) nprobes, recalls = zip(*sorted_results) # Check if the relationship is monotonic is_monotonic = all(recalls[i] <= recalls[i+1] for i in range(len(recalls)-1)) or \ all(recalls[i] >= recalls[i+1] for i in range(len(recalls)-1)) if is_monotonic: print("Relationship appears monotonic, proceeding with binary search.") # Find the two closest points that bracket the target bracket_low, bracket_high = None, None for i in range(len(recalls)-1): if (recalls[i] <= target_recall <= recalls[i+1]) or (recalls[i] >= target_recall >= recalls[i+1]): bracket_low, bracket_high = nprobes[i], nprobes[i+1] break if bracket_low is None: # Target is outside our current range, adjust range if all(r < target_recall for r in recalls): # All recalls are too low, need to increase nprobe bracket_low = nprobes[-1] bracket_high = min(max_nprobe, nprobes[-1] * 2) else: # All recalls are too high, need to decrease nprobe bracket_low = max(min_nprobe, nprobes[0] // 2) bracket_high = nprobes[0] # Binary search between bracket_low and bracket_high while abs(bracket_high - bracket_low) > 3: mid_nprobe = (bracket_low + bracket_high) // 2 if mid_nprobe in initial_results: mid_recall = initial_results[mid_nprobe] else: mid_recall = run_batch_demo(mid_nprobe) if mid_recall is not None: all_results["nprobe"].append(mid_nprobe) all_results["recall"].append(mid_recall) diff = abs(mid_recall - target_recall) if diff < best_result["difference"]: best_result = {"nprobe": mid_nprobe, "recall": mid_recall, "difference": diff} # Update search results file search_data["current_best"] = best_result search_data["all_results"] = all_results with open(search_results_file, "w") as f: json.dump(search_data, f, indent=2) # Check if we're close enough if mid_recall is not None: if abs(mid_recall - target_recall) <= tolerance: break # Adjust brackets if mid_recall < target_recall: bracket_low = mid_nprobe else: bracket_high = mid_nprobe else: # If we failed to get a result, try a different point bracket_high = mid_nprobe - 1 else: print("Relationship appears non-monotonic, using adaptive sampling.") # For non-monotonic relationships, we'll use adaptive sampling # First, find the best current point best_idx = recalls.index(min(recalls, key=lambda r: abs(r - target_recall))) best_nprobe = nprobes[best_idx] # Try points around the best point with decreasing radius radius = max(50, (max_nprobe - min_nprobe) // 10) min_radius = 3 while radius >= min_radius: # Try points at current radius around best_nprobe test_points = [] lower_bound = max(min_nprobe, best_nprobe - radius) upper_bound = min(max_nprobe, best_nprobe + radius) if lower_bound not in initial_results and lower_bound != best_nprobe: test_points.append(lower_bound) if upper_bound not in initial_results and upper_bound != best_nprobe: test_points.append(upper_bound) # Add a point in the middle if range is large enough if upper_bound - lower_bound > 2*radius/3 and len(test_points) < max_workers: mid_point = (lower_bound + upper_bound) // 2 if mid_point not in initial_results and mid_point != best_nprobe: test_points.append(mid_point) # Run tests if test_points: new_results = batch_run_nprobe_values(test_points) initial_results.update(new_results) # Update all_results and best_result for nprobe, recall in new_results.items(): all_results["nprobe"].append(nprobe) all_results["recall"].append(recall) diff = abs(recall - target_recall) if diff < best_result["difference"]: best_result = {"nprobe": nprobe, "recall": recall, "difference": diff} best_nprobe = nprobe # Update the center for next iteration # Update search results file search_data["current_best"] = best_result search_data["all_results"] = all_results with open(search_results_file, "w") as f: json.dump(search_data, f, indent=2) # Check if we're close enough if best_result["difference"] <= tolerance: break # Reduce radius for next iteration radius = max(min_radius, radius // 2) # After search, do a final fine-tuning around the best result if best_result["nprobe"] is not None: fine_tune_range = range(max(min_nprobe, best_result["nprobe"] - 2), min(max_nprobe, best_result["nprobe"] + 3)) fine_tune_points = [n for n in fine_tune_range if n not in all_results["nprobe"]] if fine_tune_points: fine_tune_results = batch_run_nprobe_values(fine_tune_points) for nprobe, recall in fine_tune_results.items(): all_results["nprobe"].append(nprobe) all_results["recall"].append(recall) diff = abs(recall - target_recall) if diff < best_result["difference"]: best_result = {"nprobe": nprobe, "recall": recall, "difference": diff} # Final update to search results file search_data["current_best"] = best_result search_data["all_results"] = all_results search_data["search_range"] = {"min": min_nprobe, "max": max_nprobe, "phase": "fine_tune"} with open(search_results_file, "w") as f: json.dump(search_data, f, indent=2) return best_result def find_optimal_nprobe_values(): """Find the optimal nprobe values for target recall rates using adaptive search.""" # Dictionary to store results for each target recall results = {} # Dictionary to store all nprobe-recall pairs for plotting all_data = {target: {"nprobe": [], "recall": []} for target in TARGET_RECALLS} # Create a summary file for all runs with open(f"nprobe_logs/summary_{path_suffix}.log", "w") as f: f.write(f"Find optimal nprobe values - started at {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n") f.write(f"Target recalls: {TARGET_RECALLS}\n") f.write(f"nprobe ranges: {NPROBE_RANGES}\n\n") f.write(f"Max workers: {max_workers}\n") f.write(f"Timeout per process: {timeout}s\n") f.write(f"Retry count: {retry_count}\n\n") for target in TARGET_RECALLS: # Use the existing NPROBE_RANGES to determine min and max values min_nprobe = min(NPROBE_RANGES[target]) max_nprobe = max(NPROBE_RANGES[target]) print(f"\nUsing NPROBE_RANGES for target {target*100:.1f}%: {min_nprobe} to {max_nprobe}") # Run adaptive search instead of binary search best_result = adaptive_search_nprobe( target_recall=target, min_nprobe=min_nprobe, max_nprobe=max_nprobe ) results[target] = best_result # Save all tested points to all_data for plotting search_results_file = f"nprobe_logs/search_results_{path_suffix}_{target:.2f}.json" try: with open(search_results_file, "r") as f: search_data = json.load(f) if "all_results" in search_data: all_data[target]["nprobe"] = search_data["all_results"]["nprobe"] all_data[target]["recall"] = search_data["all_results"]["recall"] except (FileNotFoundError, json.JSONDecodeError) as e: print(f"Warning: Could not load search results for {target}: {e}") print(f"For target recall {target*100:.1f}%:") print(f" Best nprobe value: {best_result['nprobe']}") print(f" Achieved recall: {best_result['recall']:.4f}") print(f" Difference: {best_result['difference']:.4f}") with open(f"nprobe_logs/summary_{path_suffix}.log", "a") as f: f.write(f"For target recall {target*100:.1f}%:\n") f.write(f" Best nprobe value: {best_result['nprobe']}\n") f.write(f" Achieved recall: {best_result['recall']:.4f}\n") f.write(f" Difference: {best_result['difference']:.4f}\n") # Plot the results if we have data if all_data and any(data["nprobe"] for data in all_data.values()): plt.figure(figsize=(10, 6)) # Plot each target's data for target in TARGET_RECALLS: if not all_data[target]["nprobe"]: continue nprobe_values = all_data[target]["nprobe"] recall_values = all_data[target]["recall"] # Sort data points for better visualization sorted_points = sorted(zip(nprobe_values, recall_values)) sorted_nprobe, sorted_recall = zip(*sorted_points) if sorted_points else ([], []) plt.plot(sorted_nprobe, sorted_recall, 'o-', label=f"Target {target*100:.1f}%, Best={results[target]['nprobe']}") # Mark the optimal point opt_nprobe = results[target]["nprobe"] opt_recall = results[target]["recall"] plt.plot(opt_nprobe, opt_recall, 'r*', markersize=15) # Add a horizontal line at the target recall plt.axhline(y=target, color='gray', linestyle='--', alpha=0.5) plt.xlabel('nprobe value') plt.ylabel('Recall rate') plt.title(f'Recall Rate vs nprobe Value (Max Workers: {max_workers})') plt.grid(True) plt.legend() plt.savefig(f'nprobe_logs/nprobe_vs_recall_{path_suffix}.png') print(f"Plot saved to nprobe_logs/nprobe_vs_recall_{path_suffix}.png") else: print("No data to plot.") with open(f"nprobe_logs/summary_{path_suffix}.log", "a") as f: f.write("No data to plot.\n") # Save final results with open(f"nprobe_logs/optimal_nprobe_values_{path_suffix}.json", "w") as f: json.dump(results, f, indent=2) with open(f"nprobe_logs/summary_{path_suffix}.log", "a") as f: f.write(f"\nFind optimal nprobe values - finished at {time.strftime('%Y-%m-%d %H:%M:%S')}\n") if results: f.write("\nOptimal nprobe values for target recall rates:\n") for target, data in results.items(): f.write(f"{target*100:.1f}% recall: nprobe={data['nprobe']} (actual recall: {data['recall']:.4f})\n") else: f.write("No optimal nprobe values found.\n") return results if __name__ == "__main__": try: results = find_optimal_nprobe_values() if not results: print("No optimal nprobe values found.") sys.exit(1) print("\nOptimal nprobe values for target recall rates:") for target, data in results.items(): print(f"{target*100:.1f}% recall: nprobe={data['nprobe']} (actual recall: {data['recall']:.4f})") # Generate the command for running the latency test with the optimal nprobe values optimal_values = [data["nprobe"] for target, data in sorted(results.items())] test_cmd = f"source ./.venv/bin/activate.fish && cd ~ && python ./Power-RAG/demo/test_serve.py --nprobe_values {' '.join(map(str, optimal_values))}" print("\nRun this command to test latency with the optimal nprobe values:") print(test_cmd) except KeyboardInterrupt: print("\nScript interrupted by user. Cleaning up running processes...") signal_handler(signal.SIGINT, None) sys.exit(1) except Exception as e: # Clean up any running processes before re-raising import traceback traceback.print_exc() signal_handler(signal.SIGINT, None) raise e