diff --git a/benchmarks/enron_emails/README.md b/benchmarks/enron_emails/README.md
new file mode 100644
index 0000000..16d2271
--- /dev/null
+++ b/benchmarks/enron_emails/README.md
@@ -0,0 +1,116 @@
+# Enron Emails Benchmark
+
+A retrieval-only benchmark for evaluating LEANN search on the Enron email corpus. It mirrors the structure and CLI of the existing FinanceBench and LAION benches, using stage-based evaluation focused on Recall@3.
+
+- Dataset: Enron email CSV (e.g., Kaggle wcukierski/enron-email-dataset) for passages
+- Queries: corbt/enron_emails_sample_questions (filtered for realistic questions)
+- Metric: Recall@3 vs FAISS Flat baseline
+
+## Layout
+
+benchmarks/enron_emails/
+- setup_enron_emails.py: Prepare passages, build the LEANN index, build the FAISS baseline
+- evaluate_enron_emails.py: Evaluate retrieval recall and compare indexes (Stages 2-4)
+- data/: Generated passages, queries, embedding-related files
+- baseline/: FAISS Flat baseline files
+
+## Quickstart
+
+1) Prepare the data and index
+
+cd benchmarks/enron_emails
+python setup_enron_emails.py --data-dir data
+
+Notes:
+- If `--emails-csv` is omitted, the script attempts to download the Kaggle dataset `wcukierski/enron-email-dataset` via the Kaggle API (requires `KAGGLE_USERNAME` and `KAGGLE_KEY`). Alternatively, pass a local path to `--emails-csv`.
+- The script parses emails, chunks header/body into passages, builds a compact LEANN index, and then builds a FAISS Flat baseline from the same passages and embedding model.
+- Optionally, it also creates evaluation queries from the HuggingFace dataset `corbt/enron_emails_sample_questions`.
+
+2) Run recall evaluation (Stage 2)
+
+python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2
+
+3) Complexity sweep (Stage 3)
+
+python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 3 --target-recall 0.90 --max-queries 200
+
+Stage 3 uses binary search over complexity to find the minimal value achieving the target Recall@3 (assuming recall is non-decreasing with complexity). The search expands the upper bound as needed and snaps complexity to multiples of 8.
+
+4) Index comparison (Stage 4)
+
+python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 4 --max-queries 100 --output results.json
+
+Notes:
+- Minimal CLI: you can run from the repo root with only `--index`; the defaults match the financebench/laion patterns:
+  - `--stage` defaults to `all` (runs 2, 3, 4)
+  - `--baseline-dir` defaults to `baseline`
+  - `--queries` defaults to `data/evaluation_queries.jsonl` (or falls back to the index directory)
+- Fail-fast behavior: no silent fallbacks. If the compact index cannot run with recompute, it errors out.
+
+Optional flags:
+- --queries data/evaluation_queries.jsonl (custom queries file)
+- --baseline-dir baseline (where the FAISS baseline lives)
+- --complexity 64 (LEANN complexity parameter)
+
+## Files Produced
+- data/enron_passages_preview.jsonl: Small preview of the passages used (for inspection)
+- data/enron_index_hnsw.leann.*: LEANN index files
+- baseline/faiss_flat.index + baseline/metadata.pkl: FAISS baseline with passage IDs
+- data/evaluation_queries.jsonl: Query file (id + query; includes GT IDs for reference)
+
+## Notes
+- We only evaluate retrieval Recall@3 (no generation). This matches the other benches’ style and stage flow.
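+  As a point of reference, a minimal sketch of how Recall@3 is computed per query; the passage IDs below are purely illustrative:
+
+  ```python
+  # Recall@3 = |LEANN top-3 ∩ FAISS Flat top-3| / 3, averaged over all queries
+  leann_top3 = {"passage_12", "passage_7", "passage_3"}  # hypothetical LEANN result IDs
+  faiss_top3 = {"passage_12", "passage_9", "passage_3"}  # hypothetical ground-truth IDs
+  recall_at_3 = len(leann_top3 & faiss_top3) / 3.0       # -> 0.667 for this query
+  ```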
+- The emails CSV must contain a column named "message" (raw RFC822 email) and a column named "file" for source identifier. Message-ID headers are parsed as canonical message IDs when present. + +## Stages Summary + +- Stage 2 (Recall@3): + - Compares LEANN vs FAISS Flat baseline on Recall@3. + - Compact index runs with `recompute_embeddings=True`. + +- Stage 3 (Binary Search for Complexity): + - Builds a non-compact index (`_noncompact.leann`) and runs binary search with `recompute_embeddings=False` to find the minimal complexity achieving target Recall@3 (default 90%). + +- Stage 4 (Index Comparison): + - Reports .index-only sizes for compact vs non-compact. + - Measures timings on 100 queries by default: non-compact (no recompute) vs compact (with recompute). + - Fails fast if compact recompute cannot run. + - If `--complexity` is not provided, the script tries to use the best complexity from Stage 3: + - First from the current run (when running `--stage all`), otherwise + - From `enron_stage3_results.json` saved next to the index during the last Stage 3 run. + - If neither exists, Stage 4 will error and ask you to run Stage 3 or pass `--complexity`. + +## Example Results + +These are sample results obtained on a subset of Enron data using all-mpnet-base-v2. + +- Stage 3 (Binary Search): + - Minimal complexity achieving 90% Recall@3: 88 + - Sampled points: + - C=8 → 59.9% Recall@3 + - C=72 → 89.4% Recall@3 + - C=88 → 90.2% Recall@3 + - C=96 → 90.7% Recall@3 + - C=112 → 91.1% Recall@3 + - C=136 → 91.3% Recall@3 + - C=256 → 92.0% Recall@3 + +- Stage 4 (Index Sizes, .index only): + - Compact: ~2.17 MB + - Non-compact: ~82.03 MB + - Storage saving by compact: ~97.35% + +- Stage 4 (Timing, 100 queries, complexity=88): + - Non-compact (no recompute): ~0.0074 s avg per query + - Compact (with recompute): ~1.947 s avg per query + - Speed ratio (non-compact/compact): ~0.0038x + +Full JSON output for Stage 4 is saved by the script (see `--output`), e.g.: +`benchmarks/enron_emails/results_enron_stage4.json`. diff --git a/benchmarks/enron_emails/data/.gitignore b/benchmarks/enron_emails/data/.gitignore new file mode 100644 index 0000000..361ae0f --- /dev/null +++ b/benchmarks/enron_emails/data/.gitignore @@ -0,0 +1 @@ +downloads/ \ No newline at end of file diff --git a/benchmarks/enron_emails/evaluate_enron_emails.py b/benchmarks/enron_emails/evaluate_enron_emails.py new file mode 100644 index 0000000..d780c59 --- /dev/null +++ b/benchmarks/enron_emails/evaluate_enron_emails.py @@ -0,0 +1,509 @@ +""" +Enron Emails Benchmark Evaluation - Retrieval Recall@3 (Stages 2/3/4) +Follows the style of FinanceBench/LAION: Stage 2 recall vs FAISS baseline, +Stage 3 complexity sweep to target recall, Stage 4 index comparison. +On errors, fail fast without fallbacks. 
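+
+Example invocations (paths relative to benchmarks/enron_emails/, after running setup_enron_emails.py; see README.md):
+    python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2
+    python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 3 --target-recall 0.90
+    python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 4 --max-queries 100 --output results.json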
+""" + +import argparse +import json +import os +import pickle + +import numpy as np +from leann import LeannBuilder, LeannSearcher +from leann_backend_hnsw import faiss + + +class RecallEvaluator: + """Stage 2: Evaluate Recall@3 (LEANN vs FAISS)""" + + def __init__(self, index_path: str, baseline_dir: str): + self.index_path = index_path + self.baseline_dir = baseline_dir + self.searcher = LeannSearcher(index_path) + + baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index") + metadata_path = os.path.join(baseline_dir, "metadata.pkl") + + self.faiss_index = faiss.read_index(baseline_index_path) + with open(metadata_path, "rb") as f: + self.passage_ids = pickle.load(f) + + print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors") + + # No fallbacks here; if embedding server is needed but fails, the caller will see the error. + + def evaluate_recall_at_3( + self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True + ) -> float: + """Evaluate recall@3 using FAISS Flat as ground truth""" + from leann.api import compute_embeddings + + recompute_str = "with recompute" if recompute_embeddings else "no recompute" + print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...") + + total_recall = 0.0 + for i, query in enumerate(queries): + # Compute query embedding with the same model/mode as the index + q_emb = compute_embeddings( + [query], + self.searcher.embedding_model, + mode=self.searcher.embedding_mode, + use_server=False, + ).astype(np.float32) + + # Search FAISS Flat ground truth + n = q_emb.shape[0] + k = 3 + distances = np.zeros((n, k), dtype=np.float32) + labels = np.zeros((n, k), dtype=np.int64) + self.faiss_index.search( + n, + faiss.swig_ptr(q_emb), + k, + faiss.swig_ptr(distances), + faiss.swig_ptr(labels), + ) + + baseline_ids = {self.passage_ids[idx] for idx in labels[0]} + + # Search with LEANN (may require embedding server depending on index configuration) + results = self.searcher.search( + query, + top_k=3, + complexity=complexity, + recompute_embeddings=recompute_embeddings, + ) + test_ids = {r.id for r in results} + + intersection = test_ids.intersection(baseline_ids) + recall = len(intersection) / 3.0 + total_recall += recall + + if i < 3: + print(f" Q{i + 1}: '{query[:60]}...' 
-> Recall@3: {recall:.3f}") + print(f" FAISS: {list(baseline_ids)}") + print(f" LEANN: {list(test_ids)}") + print(f" ∩: {list(intersection)}") + + avg = total_recall / max(1, len(queries)) + print(f"📊 Average Recall@3: {avg:.3f} ({avg * 100:.1f}%)") + return avg + + def cleanup(self): + if hasattr(self, "searcher"): + self.searcher.cleanup() + + +class EnronEvaluator: + def __init__(self, index_path: str): + self.index_path = index_path + self.searcher = LeannSearcher(index_path) + + def load_queries(self, queries_file: str) -> list[str]: + queries: list[str] = [] + with open(queries_file, encoding="utf-8") as f: + for line in f: + if not line.strip(): + continue + data = json.loads(line) + if "query" in data: + queries.append(data["query"]) + print(f"📊 Loaded {len(queries)} queries from {queries_file}") + return queries + + def cleanup(self): + if self.searcher: + self.searcher.cleanup() + + def analyze_index_sizes(self) -> dict: + """Analyze index sizes (.index only), similar to LAION bench.""" + from pathlib import Path + + print("📏 Analyzing index sizes (.index only)...") + index_path = Path(self.index_path) + index_dir = index_path.parent + index_name = index_path.stem + + sizes: dict[str, float] = {} + index_file = index_dir / f"{index_name}.index" + meta_file = index_dir / f"{index_path.name}.meta.json" + passages_file = index_dir / f"{index_path.name}.passages.jsonl" + passages_idx_file = index_dir / f"{index_path.name}.passages.idx" + + sizes["index_only_mb"] = ( + index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0 + ) + sizes["metadata_mb"] = ( + meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0 + ) + sizes["passages_text_mb"] = ( + passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0 + ) + sizes["passages_index_mb"] = ( + passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0 + ) + + print(f" 📁 .index size: {sizes['index_only_mb']:.1f} MB") + return sizes + + def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict: + """Create a non-compact index for comparison using current passages and embeddings.""" + from pathlib import Path + + current_index_path = Path(self.index_path) + current_index_dir = current_index_path.parent + current_index_name = current_index_path.name + + # Read metadata to get passage source and embedding model + meta_path = current_index_dir / f"{current_index_name}.meta.json" + with open(meta_path, encoding="utf-8") as f: + meta = json.load(f) + + passage_source = meta["passage_sources"][0] + passage_file = passage_source["path"] + + # Convert relative path to absolute + if not Path(passage_file).is_absolute(): + passage_file = current_index_dir / Path(passage_file).name + + # Load all passages and ids + ids: list[str] = [] + texts: list[str] = [] + with open(passage_file, encoding="utf-8") as f: + for line in f: + if line.strip(): + data = json.loads(line) + ids.append(str(data["id"])) + texts.append(data["text"]) + + # Compute embeddings using the same method as LEANN + from leann.api import compute_embeddings + + embeddings = compute_embeddings( + texts, + meta["embedding_model"], + mode=meta.get("embedding_mode", "sentence-transformers"), + use_server=False, + ).astype(np.float32) + + # Build non-compact index with same passages and embeddings + builder = LeannBuilder( + backend_name="hnsw", + embedding_model=meta["embedding_model"], + embedding_mode=meta.get("embedding_mode", "sentence-transformers"), + 
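+            # Non-compact, non-recompute build: the raw embeddings stay in the index,
+            # so Stage 3/4 searches on this index can run with recompute_embeddings=False.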
is_recompute=False, + is_compact=False, + **{ + k: v + for k, v in meta.get("backend_kwargs", {}).items() + if k not in ["is_recompute", "is_compact"] + }, + ) + + # Persist a pickle for build_index_from_embeddings + pkl_path = current_index_dir / f"{Path(non_compact_index_path).stem}_embeddings.pkl" + with open(pkl_path, "wb") as pf: + pickle.dump((ids, embeddings), pf) + + print( + f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..." + ) + builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path)) + + # Analyze the non-compact index size + temp_evaluator = EnronEvaluator(non_compact_index_path) + non_compact_sizes = temp_evaluator.analyze_index_sizes() + non_compact_sizes["index_type"] = "non_compact" + + return non_compact_sizes + + def compare_index_performance( + self, non_compact_path: str, compact_path: str, test_queries: list[str], complexity: int + ) -> dict: + """Compare search speed for non-compact vs compact indexes.""" + import time + + results: dict = { + "non_compact": {"search_times": []}, + "compact": {"search_times": []}, + "avg_search_times": {}, + "speed_ratio": 0.0, + } + + print("⚡ Comparing search performance between indexes...") + # Non-compact (no recompute) + print(" 🔍 Testing non-compact index (no recompute)...") + non_compact_searcher = LeannSearcher(non_compact_path) + for q in test_queries: + t0 = time.time() + _ = non_compact_searcher.search( + q, top_k=3, complexity=complexity, recompute_embeddings=False + ) + results["non_compact"]["search_times"].append(time.time() - t0) + + # Compact (with recompute). Fail fast if it cannot run. + print(" 🔍 Testing compact index (with recompute)...") + compact_searcher = LeannSearcher(compact_path) + for q in test_queries: + t0 = time.time() + _ = compact_searcher.search( + q, top_k=3, complexity=complexity, recompute_embeddings=True + ) + results["compact"]["search_times"].append(time.time() - t0) + compact_searcher.cleanup() + + if results["non_compact"]["search_times"]: + results["avg_search_times"]["non_compact"] = sum( + results["non_compact"]["search_times"] + ) / len(results["non_compact"]["search_times"]) + if results["compact"]["search_times"]: + results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len( + results["compact"]["search_times"] + ) + if results["avg_search_times"].get("compact", 0) > 0: + results["speed_ratio"] = ( + results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"] + ) + else: + results["speed_ratio"] = 0.0 + + non_compact_searcher.cleanup() + return results + + def evaluate_complexity( + self, + recall_eval: "RecallEvaluator", + queries: list[str], + target: float = 0.90, + c_min: int = 8, + c_max: int = 256, + max_iters: int = 10, + recompute: bool = False, + ) -> dict: + """Binary search minimal complexity achieving target recall (monotonic assumption).""" + + def round_c(x: int) -> int: + # snap to multiple of 8 like other benches typically do + return max(1, int((x + 7) // 8) * 8) + + metrics: list[dict] = [] + + lo = round_c(c_min) + hi = round_c(c_max) + + print( + f"🧪 Binary search complexity in [{lo}, {hi}] for target Recall@3>={int(target * 100)}%..." 
+ ) + + # Ensure upper bound can reach target; expand if needed (up to a cap) + r_lo = recall_eval.evaluate_recall_at_3( + queries, complexity=lo, recompute_embeddings=recompute + ) + metrics.append({"complexity": lo, "recall_at_3": r_lo}) + r_hi = recall_eval.evaluate_recall_at_3( + queries, complexity=hi, recompute_embeddings=recompute + ) + metrics.append({"complexity": hi, "recall_at_3": r_hi}) + + cap = 1024 + while r_hi < target and hi < cap: + lo = hi + r_lo = r_hi + hi = round_c(hi * 2) + r_hi = recall_eval.evaluate_recall_at_3( + queries, complexity=hi, recompute_embeddings=recompute + ) + metrics.append({"complexity": hi, "recall_at_3": r_hi}) + + if r_hi < target: + print(f"⚠️ Max complexity {hi} did not reach target recall {target:.2f}.") + print("📈 Observations:") + for m in metrics: + print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%") + return {"metrics": metrics, "best_complexity": None, "target_recall": target} + + # Binary search within [lo, hi] + best = hi + iters = 0 + while lo < hi and iters < max_iters: + mid = round_c((lo + hi) // 2) + r_mid = recall_eval.evaluate_recall_at_3( + queries, complexity=mid, recompute_embeddings=recompute + ) + metrics.append({"complexity": mid, "recall_at_3": r_mid}) + if r_mid >= target: + best = mid + hi = mid + else: + lo = mid + 8 # move past mid, respecting multiple-of-8 step + iters += 1 + + print("📈 Binary search results (sampled points):") + # Print unique complexity entries ordered by complexity + for m in sorted( + {m["complexity"]: m for m in metrics}.values(), key=lambda x: x["complexity"] + ): + print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%") + print(f"✅ Minimal complexity achieving {int(target * 100)}% recall: {best}") + return {"metrics": metrics, "best_complexity": best, "target_recall": target} + + +def main(): + parser = argparse.ArgumentParser(description="Enron Emails Benchmark Evaluation") + parser.add_argument("--index", required=True, help="Path to LEANN index") + parser.add_argument( + "--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries" + ) + parser.add_argument( + "--stage", + choices=["2", "3", "4", "all"], + default="all", + help="Which stage to run (2=recall, 3=complexity, 4=index comparison)", + ) + parser.add_argument("--complexity", type=int, default=None, help="LEANN search complexity") + parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory") + parser.add_argument( + "--max-queries", type=int, help="Limit number of queries to evaluate", default=1000 + ) + parser.add_argument( + "--target-recall", type=float, default=0.90, help="Target Recall@3 for Stage 3" + ) + parser.add_argument("--output", help="Save results to JSON file") + + args = parser.parse_args() + + # Resolve queries file: if default path not found, fall back to index's directory + if not os.path.exists(args.queries): + from pathlib import Path + + idx_dir = Path(args.index).parent + fallback_q = idx_dir / "evaluation_queries.jsonl" + if fallback_q.exists(): + args.queries = str(fallback_q) + + baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index") + if not os.path.exists(baseline_index_path): + print(f"❌ FAISS baseline not found at {baseline_index_path}") + print("💡 Please run setup_enron_emails.py first to build the baseline") + raise SystemExit(1) + + results_out: dict = {} + + if args.stage in ("2", "all"): + print("🚀 Starting Stage 2: Recall@3 evaluation") + evaluator = RecallEvaluator(args.index, 
args.baseline_dir) + + enron_eval = EnronEvaluator(args.index) + queries = enron_eval.load_queries(args.queries) + queries = queries[:10] + print(f"🧪 Using first {len(queries)} queries") + + complexity = args.complexity or 64 + r = evaluator.evaluate_recall_at_3(queries, complexity) + results_out["stage2"] = {"complexity": complexity, "recall_at_3": r} + evaluator.cleanup() + enron_eval.cleanup() + print("✅ Stage 2 completed!\n") + + if args.stage in ("3", "all"): + print("🚀 Starting Stage 3: Binary search for target recall (no recompute)") + enron_eval = EnronEvaluator(args.index) + queries = enron_eval.load_queries(args.queries) + queries = queries[: args.max_queries] + print(f"🧪 Using first {len(queries)} queries") + + # Build non-compact index for fast binary search (recompute_embeddings=False) + from pathlib import Path + index_path = Path(args.index) + non_compact_index_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann") + enron_eval.create_non_compact_index_for_comparison(non_compact_index_path) + + # Use non-compact evaluator for binary search with recompute=False + evaluator_nc = RecallEvaluator(non_compact_index_path, args.baseline_dir) + sweep = enron_eval.evaluate_complexity( + evaluator_nc, queries, target=args.target_recall, recompute=False + ) + results_out["stage3"] = sweep + # Persist default stage 3 results near the index for Stage 4 auto-pickup + from pathlib import Path + default_stage3_path = Path(args.index).parent / "enron_stage3_results.json" + with open(default_stage3_path, "w", encoding="utf-8") as f: + json.dump({"stage3": sweep}, f, indent=2) + print(f"📝 Saved Stage 3 summary to {default_stage3_path}") + evaluator_nc.cleanup() + enron_eval.cleanup() + print("✅ Stage 3 completed!\n") + + if args.stage in ("4", "all"): + print("🚀 Starting Stage 4: Index size + performance comparison") + evaluator = RecallEvaluator(args.index, args.baseline_dir) + enron_eval = EnronEvaluator(args.index) + queries = enron_eval.load_queries(args.queries) + test_q = queries[: min(args.max_queries, len(queries))] + + current_sizes = enron_eval.analyze_index_sizes() + # Build non-compact index for comparison (no fallback) + from pathlib import Path + + index_path = Path(args.index) + non_compact_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann") + non_compact_sizes = enron_eval.create_non_compact_index_for_comparison(non_compact_path) + nc_eval = EnronEvaluator(non_compact_path) + + if ( + current_sizes.get("index_only_mb", 0) > 0 + and non_compact_sizes.get("index_only_mb", 0) > 0 + ): + storage_saving_percent = max( + 0.0, + 100.0 * (1.0 - current_sizes["index_only_mb"] / non_compact_sizes["index_only_mb"]), + ) + else: + storage_saving_percent = 0.0 + + if args.complexity is None: + # Prefer in-session Stage 3 result + if "stage3" in results_out and results_out["stage3"].get("best_complexity") is not None: + complexity = results_out["stage3"]["best_complexity"] + print(f"📥 Using best complexity from Stage 3 in-session: {complexity}") + else: + # Try to load last saved Stage 3 result near index + default_stage3_path = Path(args.index).parent / "enron_stage3_results.json" + if default_stage3_path.exists(): + with open(default_stage3_path, encoding="utf-8") as f: + prev = json.load(f) + complexity = prev.get("stage3", {}).get("best_complexity") + if complexity is None: + raise SystemExit("❌ Stage 4: No --complexity and no best_complexity found in saved Stage 3 results") + print(f"📥 Using best complexity from saved Stage 3: {complexity}") + else: + 
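+                    # Neither an in-session Stage 3 result nor a saved
+                    # enron_stage3_results.json next to the index is available: fail fast.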
raise SystemExit( + "❌ Stage 4 requires --complexity if Stage 3 hasn't been run. Run stage 3 first or pass --complexity." + ) + else: + complexity = args.complexity + + comp = enron_eval.compare_index_performance( + non_compact_path, args.index, test_q, complexity=complexity + ) + results_out["stage4"] = { + "current_index": current_sizes, + "non_compact_index": non_compact_sizes, + "storage_saving_percent": storage_saving_percent, + "performance_comparison": comp, + } + nc_eval.cleanup() + evaluator.cleanup() + enron_eval.cleanup() + print("✅ Stage 4 completed!\n") + + if args.output and results_out: + with open(args.output, "w", encoding="utf-8") as f: + json.dump(results_out, f, indent=2) + print(f"📝 Saved results to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/enron_emails/setup_enron_emails.py b/benchmarks/enron_emails/setup_enron_emails.py new file mode 100644 index 0000000..ca88748 --- /dev/null +++ b/benchmarks/enron_emails/setup_enron_emails.py @@ -0,0 +1,359 @@ +""" +Enron Emails Benchmark Setup Script +Prepares passages from emails.csv, builds LEANN index, and FAISS Flat baseline +""" + +import argparse +import csv +import json +import os +import re +from collections.abc import Iterable +from email import message_from_string +from email.policy import default +from pathlib import Path +from typing import Optional + +from leann import LeannBuilder + + +class EnronSetup: + def __init__(self, data_dir: str = "data"): + self.data_dir = Path(data_dir) + self.data_dir.mkdir(parents=True, exist_ok=True) + + self.passages_preview = self.data_dir / "enron_passages_preview.jsonl" + self.index_path = self.data_dir / "enron_index_hnsw.leann" + self.queries_file = self.data_dir / "evaluation_queries.jsonl" + self.downloads_dir = self.data_dir / "downloads" + self.downloads_dir.mkdir(parents=True, exist_ok=True) + + # ---------------------------- + # Dataset acquisition + # ---------------------------- + def ensure_emails_csv(self, emails_csv: Optional[str]) -> str: + """Return a path to emails.csv, downloading from Kaggle if needed.""" + if emails_csv: + p = Path(emails_csv) + if not p.exists(): + raise FileNotFoundError(f"emails.csv not found: {emails_csv}") + return str(p) + + print( + "📥 Trying to download Enron emails.csv from Kaggle (wcukierski/enron-email-dataset)..." + ) + try: + from kaggle.api.kaggle_api_extended import KaggleApi + + api = KaggleApi() + api.authenticate() + api.dataset_download_files( + "wcukierski/enron-email-dataset", path=str(self.downloads_dir), unzip=True + ) + candidate = self.downloads_dir / "emails.csv" + if candidate.exists(): + print(f"✅ Downloaded emails.csv: {candidate}") + return str(candidate) + else: + raise FileNotFoundError( + f"emails.csv was not found in {self.downloads_dir} after Kaggle download" + ) + except Exception as e: + print( + "❌ Could not download via Kaggle automatically. Provide --emails-csv or configure Kaggle API." + ) + print( + " Set KAGGLE_USERNAME and KAGGLE_KEY env vars, or place emails.csv locally and pass --emails-csv." 
+ ) + raise e + + # ---------------------------- + # Data preparation + # ---------------------------- + @staticmethod + def _extract_message_id(raw_email: str) -> str: + msg = message_from_string(raw_email, policy=default) + val = msg.get("Message-ID", "") + if val.startswith("<") and val.endswith(">"): + val = val[1:-1] + return val or "" + + @staticmethod + def _split_header_body(raw_email: str) -> tuple[str, str]: + parts = raw_email.split("\n\n", 1) + if len(parts) == 2: + return parts[0].strip(), parts[1].strip() + # Heuristic fallback + first_lines = raw_email.splitlines() + if first_lines and ":" in first_lines[0]: + return raw_email.strip(), "" + return "", raw_email.strip() + + @staticmethod + def _split_fixed_words(text: str, chunk_words: int, keep_last: bool) -> list[str]: + text = (text or "").strip() + if not text: + return [] + if chunk_words <= 0: + return [text] + words = text.split() + if not words: + return [] + limit = len(words) + if not keep_last: + limit = (len(words) // chunk_words) * chunk_words + if limit == 0: + return [] + chunks = [" ".join(words[i : i + chunk_words]) for i in range(0, limit, chunk_words)] + return [c for c in (s.strip() for s in chunks) if c] + + def _iter_passages_from_csv( + self, + emails_csv: Path, + chunk_words: int = 256, + keep_last_header: bool = True, + keep_last_body: bool = True, + max_emails: int | None = None, + ) -> Iterable[dict]: + with open(emails_csv, encoding="utf-8") as f: + reader = csv.DictReader(f) + count = 0 + for i, row in enumerate(reader): + if max_emails is not None and count >= max_emails: + break + + raw_message = row.get("message", "") + email_file_id = row.get("file", "") + + if not raw_message.strip(): + continue + + message_id = self._extract_message_id(raw_message) + if not message_id: + # Fallback ID based on CSV position and file path + safe_file = re.sub(r"[^A-Za-z0-9_.-]", "_", email_file_id) + message_id = f"enron_{i}_{safe_file}" + + header, body = self._split_header_body(raw_message) + + # Header chunks + for chunk in self._split_fixed_words(header, chunk_words, keep_last_header): + yield { + "text": chunk, + "metadata": { + "message_id": message_id, + "is_header": True, + "email_file_id": email_file_id, + }, + } + + # Body chunks + for chunk in self._split_fixed_words(body, chunk_words, keep_last_body): + yield { + "text": chunk, + "metadata": { + "message_id": message_id, + "is_header": False, + "email_file_id": email_file_id, + }, + } + + count += 1 + + # ---------------------------- + # Build LEANN index and FAISS baseline + # ---------------------------- + def build_leann_index( + self, + emails_csv: Optional[str], + backend: str = "hnsw", + embedding_model: str = "sentence-transformers/all-mpnet-base-v2", + chunk_words: int = 256, + max_emails: int | None = None, + ) -> str: + emails_csv_path = self.ensure_emails_csv(emails_csv) + print(f"🏗️ Building LEANN index from {emails_csv_path}...") + + builder = LeannBuilder( + backend_name=backend, + embedding_model=embedding_model, + embedding_mode="sentence-transformers", + graph_degree=32, + complexity=64, + is_recompute=True, + is_compact=True, + num_threads=4, + ) + + # Stream passages and add to builder + preview_written = 0 + with open(self.passages_preview, "w", encoding="utf-8") as preview_out: + for p in self._iter_passages_from_csv( + Path(emails_csv_path), chunk_words=chunk_words, max_emails=max_emails + ): + builder.add_text(p["text"], metadata=p["metadata"]) + if preview_written < 200: + preview_out.write(json.dumps({"text": 
p["text"][:200], **p["metadata"]}) + "\n") + preview_written += 1 + + print(f"🔨 Building index at {self.index_path}...") + builder.build_index(str(self.index_path)) + print("✅ LEANN index built!") + return str(self.index_path) + + def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline") -> str: + print("🔨 Building FAISS Flat baseline from LEANN passages...") + + import pickle + + import numpy as np + from leann.api import compute_embeddings + from leann_backend_hnsw import faiss + + os.makedirs(output_dir, exist_ok=True) + baseline_path = os.path.join(output_dir, "faiss_flat.index") + metadata_path = os.path.join(output_dir, "metadata.pkl") + + if os.path.exists(baseline_path) and os.path.exists(metadata_path): + print(f"✅ Baseline already exists at {baseline_path}") + return baseline_path + + # Read meta for passage source and embedding model + meta_path = f"{index_path}.meta.json" + with open(meta_path, encoding="utf-8") as f: + meta = json.load(f) + + embedding_model = meta["embedding_model"] + passage_source = meta["passage_sources"][0] + passage_file = passage_source["path"] + + if not os.path.isabs(passage_file): + index_dir = os.path.dirname(index_path) + passage_file = os.path.join(index_dir, os.path.basename(passage_file)) + + # Load passages from builder output so IDs match LEANN + passages: list[str] = [] + passage_ids: list[str] = [] + with open(passage_file, encoding="utf-8") as f: + for line in f: + if not line.strip(): + continue + data = json.loads(line) + passages.append(data["text"]) + passage_ids.append(data["id"]) # builder-assigned ID + + print(f"📄 Loaded {len(passages)} passages for baseline") + print(f"🤖 Embedding model: {embedding_model}") + + embeddings = compute_embeddings( + passages, + embedding_model, + mode="sentence-transformers", + use_server=False, + ) + + # Build FAISS IndexFlatIP + dim = embeddings.shape[1] + index = faiss.IndexFlatIP(dim) + emb_f32 = embeddings.astype(np.float32) + index.add(emb_f32.shape[0], faiss.swig_ptr(emb_f32)) + + faiss.write_index(index, baseline_path) + with open(metadata_path, "wb") as pf: + pickle.dump(passage_ids, pf) + + print(f"✅ FAISS baseline saved: {baseline_path}") + print(f"✅ Metadata saved: {metadata_path}") + print(f"📊 Total vectors: {index.ntotal}") + return baseline_path + + # ---------------------------- + # Queries (optional): prepare evaluation queries file + # ---------------------------- + def prepare_queries(self, min_realism: float = 0.85) -> Path: + print( + "📝 Preparing evaluation queries from HuggingFace dataset corbt/enron_emails_sample_questions ..." 
+ ) + try: + from datasets import load_dataset + + ds = load_dataset("corbt/enron_emails_sample_questions", split="train") + except Exception as e: + print(f"⚠️ Failed to load dataset: {e}") + return self.queries_file + + kept = 0 + with open(self.queries_file, "w", encoding="utf-8") as out: + for i, item in enumerate(ds): + how_realistic = float(item.get("how_realistic", 0.0)) + if how_realistic < min_realism: + continue + qid = str(item.get("id", f"enron_q_{i}")) + query = item.get("question", "") + if not query: + continue + record = { + "id": qid, + "query": query, + # For reference only, not used in recall metric below + "gt_message_ids": item.get("message_ids", []), + } + out.write(json.dumps(record) + "\n") + kept += 1 + print(f"✅ Wrote {kept} queries to {self.queries_file}") + return self.queries_file + + +def main(): + parser = argparse.ArgumentParser(description="Setup Enron Emails Benchmark") + parser.add_argument( + "--emails-csv", + help="Path to emails.csv (Enron dataset). If omitted, attempt Kaggle download.", + ) + parser.add_argument("--data-dir", default="data", help="Data directory") + parser.add_argument("--backend", choices=["hnsw", "diskann"], default="hnsw") + parser.add_argument( + "--embedding-model", + default="sentence-transformers/all-mpnet-base-v2", + help="Embedding model for LEANN", + ) + parser.add_argument("--chunk-words", type=int, default=256, help="Fixed word chunk size") + parser.add_argument("--max-emails", type=int, help="Limit number of emails to process") + parser.add_argument("--skip-queries", action="store_true", help="Skip creating queries file") + parser.add_argument("--skip-build", action="store_true", help="Skip building LEANN index") + + args = parser.parse_args() + + setup = EnronSetup(args.data_dir) + + # Build index + if not args.skip_build: + index_path = setup.build_leann_index( + emails_csv=args.emails_csv, + backend=args.backend, + embedding_model=args.embedding_model, + chunk_words=args.chunk_words, + max_emails=args.max_emails, + ) + + # Build FAISS baseline from the same passages & embeddings + setup.build_faiss_flat_baseline(index_path) + else: + print("⏭️ Skipping LEANN index build and baseline") + + # Queries file (optional) + if not args.skip_queries: + setup.prepare_queries() + else: + print("⏭️ Skipping query preparation") + + print("\n🎉 Enron Emails setup completed!") + print(f"📁 Data directory: {setup.data_dir.absolute()}") + print("Next steps:") + print( + "1) Evaluate recall: python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2" + ) + + +if __name__ == "__main__": + main()