""" Enron Emails Benchmark Evaluation - Retrieval Recall@3 (Stages 2/3/4) Follows the style of FinanceBench/LAION: Stage 2 recall vs FAISS baseline, Stage 3 complexity sweep to target recall, Stage 4 index comparison. On errors, fail fast without fallbacks. """ import argparse import json import logging import os import pickle from pathlib import Path import numpy as np from leann import LeannBuilder, LeannSearcher from leann_backend_hnsw import faiss from ..llm_utils import generate_hf, generate_vllm, load_hf_model, load_vllm_model # Setup logging to reduce verbose output logging.basicConfig(level=logging.WARNING) logging.getLogger("leann.api").setLevel(logging.WARNING) logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING) class RecallEvaluator: """Stage 2: Evaluate Recall@3 (LEANN vs FAISS)""" def __init__(self, index_path: str, baseline_dir: str): self.index_path = index_path self.baseline_dir = baseline_dir self.searcher = LeannSearcher(index_path) baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index") metadata_path = os.path.join(baseline_dir, "metadata.pkl") self.faiss_index = faiss.read_index(baseline_index_path) with open(metadata_path, "rb") as f: self.passage_ids = pickle.load(f) print(f"๐Ÿ“š Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors") # No fallbacks here; if embedding server is needed but fails, the caller will see the error. def evaluate_recall_at_3( self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True ) -> float: """Evaluate recall@3 using FAISS Flat as ground truth""" from leann.api import compute_embeddings recompute_str = "with recompute" if recompute_embeddings else "no recompute" print(f"๐Ÿ” Evaluating recall@3 with complexity={complexity} ({recompute_str})...") total_recall = 0.0 for i, query in enumerate(queries): # Compute query embedding with the same model/mode as the index q_emb = compute_embeddings( [query], self.searcher.embedding_model, mode=self.searcher.embedding_mode, use_server=False, ).astype(np.float32) # Search FAISS Flat ground truth n = q_emb.shape[0] k = 3 distances = np.zeros((n, k), dtype=np.float32) labels = np.zeros((n, k), dtype=np.int64) self.faiss_index.search( n, faiss.swig_ptr(q_emb), k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), ) baseline_ids = {self.passage_ids[idx] for idx in labels[0]} # Search with LEANN (may require embedding server depending on index configuration) results = self.searcher.search( query, top_k=3, complexity=complexity, recompute_embeddings=recompute_embeddings, ) test_ids = {r.id for r in results} intersection = test_ids.intersection(baseline_ids) recall = len(intersection) / 3.0 total_recall += recall if i < 3: print(f" Q{i + 1}: '{query[:60]}...' 
class EnronEvaluator:
    def __init__(self, index_path: str):
        self.index_path = index_path
        self.searcher = LeannSearcher(index_path)

    def load_queries(self, queries_file: str) -> list[str]:
        queries: list[str] = []
        with open(queries_file, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                data = json.loads(line)
                if "query" in data:
                    queries.append(data["query"])
        print(f"📊 Loaded {len(queries)} queries from {queries_file}")
        return queries

    def cleanup(self):
        if self.searcher:
            self.searcher.cleanup()

    def analyze_index_sizes(self) -> dict:
        """Analyze index sizes (.index only), similar to LAION bench."""
        print("📁 Analyzing index sizes (.index only)...")
        index_path = Path(self.index_path)
        index_dir = index_path.parent
        index_name = index_path.stem

        sizes: dict[str, float] = {}
        index_file = index_dir / f"{index_name}.index"
        meta_file = index_dir / f"{index_path.name}.meta.json"
        passages_file = index_dir / f"{index_path.name}.passages.jsonl"
        passages_idx_file = index_dir / f"{index_path.name}.passages.idx"

        sizes["index_only_mb"] = (
            index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
        )
        sizes["metadata_mb"] = (
            meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
        )
        sizes["passages_text_mb"] = (
            passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
        )
        sizes["passages_index_mb"] = (
            passages_idx_file.stat().st_size / (1024 * 1024)
            if passages_idx_file.exists()
            else 0.0
        )

        print(f"  📁 .index size: {sizes['index_only_mb']:.1f} MB")
        return sizes

    def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
        """Create a non-compact index for comparison using current passages and embeddings."""
        current_index_path = Path(self.index_path)
        current_index_dir = current_index_path.parent
        current_index_name = current_index_path.name

        # Read metadata to get passage source and embedding model
        meta_path = current_index_dir / f"{current_index_name}.meta.json"
        with open(meta_path, encoding="utf-8") as f:
            meta = json.load(f)

        passage_source = meta["passage_sources"][0]
        passage_file = passage_source["path"]
        # Convert relative path to absolute
        if not Path(passage_file).is_absolute():
            passage_file = current_index_dir / Path(passage_file).name

        # Load all passages and ids
        ids: list[str] = []
        texts: list[str] = []
        with open(passage_file, encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    data = json.loads(line)
                    ids.append(str(data["id"]))
                    texts.append(data["text"])

        # Compute embeddings using the same method as LEANN
        from leann.api import compute_embeddings

        embeddings = compute_embeddings(
            texts,
            meta["embedding_model"],
            mode=meta.get("embedding_mode", "sentence-transformers"),
            use_server=False,
        ).astype(np.float32)

        # Build non-compact index with same passages and embeddings
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model=meta["embedding_model"],
            embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
            is_recompute=False,
            is_compact=False,
            **{
                k: v
                for k, v in meta.get("backend_kwargs", {}).items()
                if k not in ["is_recompute", "is_compact"]
            },
        )

        # Persist a pickle for build_index_from_embeddings
        pkl_path = current_index_dir / f"{Path(non_compact_index_path).stem}_embeddings.pkl"
        with open(pkl_path, "wb") as pf:
            pickle.dump((ids, embeddings), pf)

        print(
            f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
        )
        builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))

        # Analyze the non-compact index size
        temp_evaluator = EnronEvaluator(non_compact_index_path)
        non_compact_sizes = temp_evaluator.analyze_index_sizes()
        non_compact_sizes["index_type"] = "non_compact"
        return non_compact_sizes
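
    # Note: the compact index recomputes embeddings at query time, while the
    # non-compact index built above keeps every vector on disk; that
    # storage-vs-latency trade-off is exactly what Stage 4 measures below.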
f"{Path(non_compact_index_path).stem}_embeddings.pkl" with open(pkl_path, "wb") as pf: pickle.dump((ids, embeddings), pf) print( f"๐Ÿ”จ Building non-compact index at {non_compact_index_path} from precomputed embeddings..." ) builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path)) # Analyze the non-compact index size temp_evaluator = EnronEvaluator(non_compact_index_path) non_compact_sizes = temp_evaluator.analyze_index_sizes() non_compact_sizes["index_type"] = "non_compact" return non_compact_sizes def compare_index_performance( self, non_compact_path: str, compact_path: str, test_queries: list[str], complexity: int ) -> dict: """Compare search speed for non-compact vs compact indexes.""" import time results: dict = { "non_compact": {"search_times": []}, "compact": {"search_times": []}, "avg_search_times": {}, "speed_ratio": 0.0, "retrieval_results": [], # Store retrieval results for Stage 5 } print("โšก Comparing search performance between indexes...") # Non-compact (no recompute) print(" ๐Ÿ” Testing non-compact index (no recompute)...") non_compact_searcher = LeannSearcher(non_compact_path) for q in test_queries: t0 = time.time() _ = non_compact_searcher.search( q, top_k=3, complexity=complexity, recompute_embeddings=False ) results["non_compact"]["search_times"].append(time.time() - t0) # Compact (with recompute). Fail fast if it cannot run. print(" ๐Ÿ” Testing compact index (with recompute)...") compact_searcher = LeannSearcher(compact_path) for q in test_queries: t0 = time.time() docs = compact_searcher.search( q, top_k=3, complexity=complexity, recompute_embeddings=True ) results["compact"]["search_times"].append(time.time() - t0) # Store retrieval results for Stage 5 results["retrieval_results"].append( {"query": q, "retrieved_docs": [{"id": doc.id, "text": doc.text} for doc in docs]} ) compact_searcher.cleanup() if results["non_compact"]["search_times"]: results["avg_search_times"]["non_compact"] = sum( results["non_compact"]["search_times"] ) / len(results["non_compact"]["search_times"]) if results["compact"]["search_times"]: results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len( results["compact"]["search_times"] ) if results["avg_search_times"].get("compact", 0) > 0: results["speed_ratio"] = ( results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"] ) else: results["speed_ratio"] = 0.0 non_compact_searcher.cleanup() return results def evaluate_complexity( self, recall_eval: "RecallEvaluator", queries: list[str], target: float = 0.90, c_min: int = 8, c_max: int = 256, max_iters: int = 10, recompute: bool = False, ) -> dict: """Binary search minimal complexity achieving target recall (monotonic assumption).""" def round_c(x: int) -> int: # snap to multiple of 8 like other benches typically do return max(1, int((x + 7) // 8) * 8) metrics: list[dict] = [] lo = round_c(c_min) hi = round_c(c_max) print( f"๐Ÿงช Binary search complexity in [{lo}, {hi}] for target Recall@3>={int(target * 100)}%..." 
def main():
    parser = argparse.ArgumentParser(description="Enron Emails Benchmark Evaluation")
    parser.add_argument("--index", required=True, help="Path to LEANN index")
    parser.add_argument(
        "--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
    )
    parser.add_argument(
        "--stage",
        choices=["2", "3", "4", "5", "all", "45"],
        default="all",
        help="Which stage to run (2=recall, 3=complexity, 4=index comparison, "
        "5=generation, 45=stages 4+5)",
    )
    parser.add_argument("--complexity", type=int, default=None, help="LEANN search complexity")
    parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
    parser.add_argument(
        "--max-queries", type=int, help="Limit number of queries to evaluate", default=1000
    )
    parser.add_argument(
        "--target-recall", type=float, default=0.90, help="Target Recall@3 for Stage 3"
    )
    parser.add_argument("--output", help="Save results to JSON file")
    parser.add_argument("--llm-backend", choices=["hf", "vllm"], default="hf", help="LLM backend")
    parser.add_argument("--model-name", default="Qwen/Qwen3-8B", help="Model name")
    args = parser.parse_args()
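
    # Example invocations (hypothetical module path; this file uses relative
    # imports, so it is run with `python -m` from the package root):
    #   python -m benchmarks.enron.evaluate --index indexes/enron.leann --stage 2
    #   python -m benchmarks.enron.evaluate --index indexes/enron.leann --stage 45 --complexity 64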
    # Resolve queries file: if the default path is not found, fall back to the index's directory
    if not os.path.exists(args.queries):
        idx_dir = Path(args.index).parent
        fallback_q = idx_dir / "evaluation_queries.jsonl"
        if fallback_q.exists():
            args.queries = str(fallback_q)

    baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
    if not os.path.exists(baseline_index_path):
        print(f"❌ FAISS baseline not found at {baseline_index_path}")
        print("💡 Please run setup_enron_emails.py first to build the baseline")
        raise SystemExit(1)

    results_out: dict = {}

    if args.stage in ("2", "all"):
        print("🚀 Starting Stage 2: Recall@3 evaluation")
        evaluator = RecallEvaluator(args.index, args.baseline_dir)
        enron_eval = EnronEvaluator(args.index)
        queries = enron_eval.load_queries(args.queries)
        queries = queries[:10]
        print(f"🧪 Using first {len(queries)} queries")
        complexity = args.complexity or 64
        r = evaluator.evaluate_recall_at_3(queries, complexity)
        results_out["stage2"] = {"complexity": complexity, "recall_at_3": r}
        evaluator.cleanup()
        enron_eval.cleanup()
        print("✅ Stage 2 completed!\n")

    if args.stage in ("3", "all"):
        print("🚀 Starting Stage 3: Binary search for target recall (no recompute)")
        enron_eval = EnronEvaluator(args.index)
        queries = enron_eval.load_queries(args.queries)
        queries = queries[: args.max_queries]
        print(f"🧪 Using first {len(queries)} queries")

        # Build a non-compact index for fast binary search (recompute_embeddings=False)
        index_path = Path(args.index)
        non_compact_index_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
        enron_eval.create_non_compact_index_for_comparison(non_compact_index_path)

        # Use the non-compact evaluator for binary search with recompute=False
        evaluator_nc = RecallEvaluator(non_compact_index_path, args.baseline_dir)
        sweep = enron_eval.evaluate_complexity(
            evaluator_nc, queries, target=args.target_recall, recompute=False
        )
        results_out["stage3"] = sweep

        # Persist default Stage 3 results near the index for Stage 4 auto-pickup
        default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
        with open(default_stage3_path, "w", encoding="utf-8") as f:
            json.dump({"stage3": sweep}, f, indent=2)
        print(f"📝 Saved Stage 3 summary to {default_stage3_path}")

        evaluator_nc.cleanup()
        enron_eval.cleanup()
        print("✅ Stage 3 completed!\n")
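
    # When --complexity is omitted, Stage 4 picks up the tuned complexity from
    # the JSON saved above; its shape is (illustrative values):
    #   {"stage3": {"metrics": [{"complexity": 64, "recall_at_3": 0.91}, ...],
    #               "best_complexity": 64, "target_recall": 0.9}}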
    if args.stage in ("4", "all", "45"):
        print("🚀 Starting Stage 4: Index size + performance comparison")
        evaluator = RecallEvaluator(args.index, args.baseline_dir)
        enron_eval = EnronEvaluator(args.index)
        queries = enron_eval.load_queries(args.queries)
        test_q = queries[: min(args.max_queries, len(queries))]

        current_sizes = enron_eval.analyze_index_sizes()

        # Build a non-compact index for comparison (no fallback)
        index_path = Path(args.index)
        non_compact_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
        non_compact_sizes = enron_eval.create_non_compact_index_for_comparison(non_compact_path)
        nc_eval = EnronEvaluator(non_compact_path)

        if (
            current_sizes.get("index_only_mb", 0) > 0
            and non_compact_sizes.get("index_only_mb", 0) > 0
        ):
            storage_saving_percent = max(
                0.0,
                100.0
                * (1.0 - current_sizes["index_only_mb"] / non_compact_sizes["index_only_mb"]),
            )
        else:
            storage_saving_percent = 0.0

        if args.complexity is None:
            # Prefer the in-session Stage 3 result
            if (
                "stage3" in results_out
                and results_out["stage3"].get("best_complexity") is not None
            ):
                complexity = results_out["stage3"]["best_complexity"]
                print(f"📥 Using best complexity from Stage 3 in-session: {complexity}")
            else:
                # Try to load the last saved Stage 3 result near the index
                default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
                if default_stage3_path.exists():
                    with open(default_stage3_path, encoding="utf-8") as f:
                        prev = json.load(f)
                    complexity = prev.get("stage3", {}).get("best_complexity")
                    if complexity is None:
                        raise SystemExit(
                            "❌ Stage 4: No --complexity and no best_complexity found in saved Stage 3 results"
                        )
                    print(f"📥 Using best complexity from saved Stage 3: {complexity}")
                else:
                    raise SystemExit(
                        "❌ Stage 4 requires --complexity if Stage 3 hasn't been run. "
                        "Run stage 3 first or pass --complexity."
                    )
        else:
            complexity = args.complexity

        comp = enron_eval.compare_index_performance(
            non_compact_path, args.index, test_q, complexity=complexity
        )
        results_out["stage4"] = {
            "current_index": current_sizes,
            "non_compact_index": non_compact_sizes,
            "storage_saving_percent": storage_saving_percent,
            "performance_comparison": comp,
        }
        nc_eval.cleanup()
        evaluator.cleanup()
        enron_eval.cleanup()
        print("✅ Stage 4 completed!\n")

    if args.stage in ("5", "45", "all"):
        print(f"🚀 Starting Stage 5: Generation evaluation with {args.model_name}")

        # Check that Stage 4 retrieval results exist
        if "stage4" not in results_out or "performance_comparison" not in results_out["stage4"]:
            print("❌ Stage 5 requires Stage 4 retrieval results")
            print("💡 Run Stage 4 first or use --stage all")
            raise SystemExit(1)

        retrieval_results = results_out["stage4"]["performance_comparison"]["retrieval_results"]
        if not retrieval_results:
            print("❌ No retrieval results found from Stage 4")
            raise SystemExit(1)

        print(f"📝 Using {len(retrieval_results)} retrieval results from Stage 4")

        # Load LLM
        try:
            if args.llm_backend == "hf":
                tokenizer, model = load_hf_model(args.model_name)

                def llm_func(prompt):
                    return generate_hf(tokenizer, model, prompt)
            else:  # vllm
                llm, sampling_params = load_vllm_model(args.model_name)

                def llm_func(prompt):
                    return generate_vllm(llm, sampling_params, prompt)

            # Run generation using stored retrieval results
            import time

            from ..llm_utils import create_prompt

            generation_times = []
            responses = []

            print("🤖 Running generation on pre-retrieved results...")
            for i, item in enumerate(retrieval_results):
                query = item["query"]
                retrieved_docs = item["retrieved_docs"]

                # Prepare context from retrieved docs
                context = "\n\n".join([doc["text"] for doc in retrieved_docs])
                prompt = create_prompt(context, query, "emails")

                # Time generation only
                gen_start = time.time()
                response = llm_func(prompt)
                gen_time = time.time() - gen_start

                generation_times.append(gen_time)
                responses.append(response)

                if i < 3:
                    print(f"  Q{i + 1}: Gen={gen_time:.3f}s")

            avg_gen_time = sum(generation_times) / len(generation_times)

            print("\n📊 Generation Results:")
            print(f"  Total Queries: {len(retrieval_results)}")
            print(f"  Avg Generation Time: {avg_gen_time:.3f}s")
            print("  (Search time from Stage 4)")

            results_out["stage5"] = {
                "total_queries": len(retrieval_results),
                "avg_generation_time": avg_gen_time,
                "generation_times": generation_times,
                "responses": responses,
            }

            # Show sample results
            print("\n📝 Sample Results:")
            for i in range(min(3, len(retrieval_results))):
                query = retrieval_results[i]["query"]
                response = responses[i]
                print(f"  Q{i + 1}: {query[:60]}...")
                print(f"  A{i + 1}: {response[:100]}...")
                print()
        except Exception as e:
            print(f"❌ Generation evaluation failed: {e}")
            print("💡 Make sure transformers/vllm is installed and the model is available")

        print("✅ Stage 5 completed!\n")

    if args.output and results_out:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(results_out, f, indent=2)
        print(f"📝 Saved results to {args.output}")


if __name__ == "__main__":
    main()