feat: enron email bench
116  benchmarks/enron_emails/README.md  Normal file
@@ -0,0 +1,116 @@
# Enron Emails Benchmark

A retrieval-only benchmark for evaluating LEANN search on the Enron email corpus. It mirrors the structure and CLI of the existing FinanceBench and LAION benchmarks, using stage-based evaluation focused on Recall@3.

- Dataset: Enron email CSV (e.g., Kaggle `wcukierski/enron-email-dataset`) for passages
- Queries: `corbt/enron_emails_sample_questions` (filtered for realistic questions)
- Metric: Recall@3 vs. a FAISS Flat baseline
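The Recall@3 computation is a simple set overlap between LEANN's top-3 result IDs and the FAISS Flat top-3 IDs. A minimal sketch of that metric (the helper name is illustrative, not the benchmark's API):

```python
def recall_at_3(leann_ids: set[str], baseline_ids: set[str]) -> float:
    """Fraction of the FAISS Flat top-3 that LEANN also retrieved."""
    return len(leann_ids & baseline_ids) / 3.0

# Two of three ground-truth passages retrieved:
score = recall_at_3({"a", "b", "x"}, {"a", "b", "c"})  # 2/3
```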
## Layout

benchmarks/enron_emails/

- setup_enron_emails.py: Prepare passages, build the LEANN index, build the FAISS baseline
- evaluate_enron_emails.py: Evaluate retrieval recall (Stage 2)
- data/: Generated passages, queries, and embedding-related files
- baseline/: FAISS Flat baseline files
## Quickstart

1) Prepare the data and index

cd benchmarks/enron_emails
python setup_enron_emails.py --data-dir data

Notes:

- If `--emails-csv` is omitted, the script attempts to download the Kaggle dataset `wcukierski/enron-email-dataset` via the Kaggle API (requires `KAGGLE_USERNAME` and `KAGGLE_KEY`). Alternatively, pass a local path to `--emails-csv`.
- The script parses emails, chunks header/body into passages, builds a compact LEANN index, and then builds a FAISS Flat baseline from the same passages and embedding model.
- Optionally, it also creates evaluation queries from the HuggingFace dataset `corbt/enron_emails_sample_questions`.
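The chunking step splits each email's header and body into fixed-size word windows. A simplified sketch of that idea (the function name and defaults are illustrative; the script's actual chunker also supports keep-last options):

```python
def split_fixed_words(text: str, chunk_words: int = 256) -> list[str]:
    """Split text into consecutive chunks of at most `chunk_words` words."""
    words = text.split()
    return [
        " ".join(words[i : i + chunk_words])
        for i in range(0, len(words), chunk_words)
    ]

chunks = split_fixed_words("one two three four five", chunk_words=2)
# ["one two", "three four", "five"]
```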
2) Run recall evaluation (Stage 2)

python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2

3) Complexity sweep (Stage 3)

python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 3 --target-recall 0.90 --max-queries 200

Stage 3 uses binary search over complexity to find the minimal value achieving the target Recall@3 (assuming recall is non-decreasing with complexity). The search expands the upper bound as needed and snaps complexity to multiples of 8.
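The sweep is a standard lower-bound binary search with multiple-of-8 snapping. A self-contained sketch of the search logic, using a synthetic recall curve in place of the real evaluator (names are illustrative):

```python
def round_c(x: int) -> int:
    # Snap complexity up to a multiple of 8
    return max(1, (x + 7) // 8 * 8)

def min_complexity(recall_fn, target: float, lo: int = 8, hi: int = 256):
    """Smallest multiple-of-8 complexity c with recall_fn(c) >= target.

    Assumes recall_fn is non-decreasing; returns None if the target is
    unreachable within [lo, hi].
    """
    lo, hi = round_c(lo), round_c(hi)
    best = None
    while lo <= hi:
        mid = round_c((lo + hi) // 2)
        if recall_fn(mid) >= target:
            best = mid
            hi = mid - 8
        else:
            lo = mid + 8
    return best

# Synthetic recall curve that saturates with complexity:
c = min_complexity(lambda c: min(0.95, 0.5 + c / 400), target=0.90)  # 160
```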
4) Index comparison (Stage 4)

python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 4 --max-queries 100 --output results.json

Optional flags:

- --queries data/evaluation_queries.jsonl (custom queries file)
- --baseline-dir baseline (where the FAISS baseline lives)
- --complexity 64 (LEANN complexity parameter)

Notes:

- Minimal CLI: you can run from the repo root with only `--index`; defaults match the FinanceBench/LAION patterns:
  - `--stage` defaults to `all` (runs stages 2, 3, and 4)
  - `--baseline-dir` defaults to `baseline`
  - `--queries` defaults to `data/evaluation_queries.jsonl` (falling back to the index directory)
- Fail-fast behavior: no silent fallbacks. If the compact index cannot run with recompute, the script errors out.
## Files Produced

- data/enron_passages_preview.jsonl: Small preview of the passages used (for inspection)
- data/enron_index_hnsw.leann.*: LEANN index files
- baseline/faiss_flat.index + baseline/metadata.pkl: FAISS baseline with passage IDs
- data/evaluation_queries.jsonl: Query file (id + query; includes ground-truth IDs for reference)
## Notes

- We evaluate only retrieval Recall@3 (no generation). This matches the other benchmarks' style and stage flow.
- The emails CSV must contain a column named "message" (the raw RFC 822 email) and a column named "file" (a source identifier). Message-ID headers, when present, are parsed as canonical message IDs.
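Message-ID extraction relies on the standard library email parser; a minimal sketch of the idea (the helper name is illustrative):

```python
from email import message_from_string
from email.policy import default

def extract_message_id(raw_email: str) -> str:
    """Parse the Message-ID header, stripping the surrounding angle brackets."""
    msg = message_from_string(raw_email, policy=default)
    val = msg.get("Message-ID", "") or ""
    if val.startswith("<") and val.endswith(">"):
        val = val[1:-1]
    return str(val)

raw = "Message-ID: <123.456@enron.com>\nFrom: a@enron.com\n\nHello"
mid = extract_message_id(raw)  # "123.456@enron.com"
```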
## Stages Summary

- Stage 2 (Recall@3):
  - Compares LEANN against the FAISS Flat baseline on Recall@3.
  - The compact index runs with `recompute_embeddings=True`.

- Stage 3 (Binary Search for Complexity):
  - Builds a non-compact index (`<index>_noncompact.leann`) and runs binary search with `recompute_embeddings=False` to find the minimal complexity achieving the target Recall@3 (default 90%).

- Stage 4 (Index Comparison):
  - Reports .index-only sizes for compact vs. non-compact.
  - Measures timings on 100 queries by default: non-compact (no recompute) vs. compact (with recompute).
  - Fails fast if compact recompute cannot run.
  - If `--complexity` is not provided, the script tries to use the best complexity from Stage 3:
    - first from the current run (when running `--stage all`), otherwise
    - from `enron_stage3_results.json` saved next to the index during the last Stage 3 run.
  - If neither exists, Stage 4 errors and asks you to run Stage 3 or pass `--complexity`.
## Example Results

These are sample results obtained on a subset of the Enron data using all-mpnet-base-v2.

- Stage 3 (Binary Search):
  - Minimal complexity achieving 90% Recall@3: 88
  - Sampled points:
    - C=8 → 59.9% Recall@3
    - C=72 → 89.4% Recall@3
    - C=88 → 90.2% Recall@3
    - C=96 → 90.7% Recall@3
    - C=112 → 91.1% Recall@3
    - C=136 → 91.3% Recall@3
    - C=256 → 92.0% Recall@3

- Stage 4 (Index Sizes, .index only):
  - Compact: ~2.17 MB
  - Non-compact: ~82.03 MB
  - Storage saving from compact: ~97.35%

- Stage 4 (Timing, 100 queries, complexity=88):
  - Non-compact (no recompute): ~0.0074 s avg per query
  - Compact (with recompute): ~1.947 s avg per query
  - Speed ratio (non-compact/compact): ~0.0038x

The full JSON output for Stage 4 is saved by the script (see `--output`), e.g. `benchmarks/enron_emails/results_enron_stage4.json`.
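The storage-saving percentage is simply the relative size reduction of the compact .index file; a quick sanity check of the sample numbers above:

```python
compact_mb = 2.17
non_compact_mb = 82.03

# 100 * (1 - compact / non-compact), as computed in Stage 4
saving = 100.0 * (1.0 - compact_mb / non_compact_mb)
print(f"{saving:.2f}%")  # 97.35%
```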
1  benchmarks/enron_emails/data/.gitignore  vendored  Normal file
@@ -0,0 +1 @@
downloads/

509  benchmarks/enron_emails/evaluate_enron_emails.py  Normal file
@@ -0,0 +1,509 @@
"""
|
||||||
|
Enron Emails Benchmark Evaluation - Retrieval Recall@3 (Stages 2/3/4)
|
||||||
|
Follows the style of FinanceBench/LAION: Stage 2 recall vs FAISS baseline,
|
||||||
|
Stage 3 complexity sweep to target recall, Stage 4 index comparison.
|
||||||
|
On errors, fail fast without fallbacks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann import LeannBuilder, LeannSearcher
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
|
||||||
|
class RecallEvaluator:
    """Stage 2: Evaluate Recall@3 (LEANN vs FAISS)"""

    def __init__(self, index_path: str, baseline_dir: str):
        self.index_path = index_path
        self.baseline_dir = baseline_dir
        self.searcher = LeannSearcher(index_path)

        baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
        metadata_path = os.path.join(baseline_dir, "metadata.pkl")

        self.faiss_index = faiss.read_index(baseline_index_path)
        with open(metadata_path, "rb") as f:
            self.passage_ids = pickle.load(f)

        print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")

        # No fallbacks here; if the embedding server is needed but fails, the caller will see the error.

    def evaluate_recall_at_3(
        self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
    ) -> float:
        """Evaluate Recall@3 using FAISS Flat as ground truth."""
        from leann.api import compute_embeddings

        recompute_str = "with recompute" if recompute_embeddings else "no recompute"
        print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")

        total_recall = 0.0
        for i, query in enumerate(queries):
            # Compute the query embedding with the same model/mode as the index
            q_emb = compute_embeddings(
                [query],
                self.searcher.embedding_model,
                mode=self.searcher.embedding_mode,
                use_server=False,
            ).astype(np.float32)

            # Search the FAISS Flat ground truth
            n = q_emb.shape[0]
            k = 3
            distances = np.zeros((n, k), dtype=np.float32)
            labels = np.zeros((n, k), dtype=np.int64)
            self.faiss_index.search(
                n,
                faiss.swig_ptr(q_emb),
                k,
                faiss.swig_ptr(distances),
                faiss.swig_ptr(labels),
            )

            baseline_ids = {self.passage_ids[idx] for idx in labels[0]}

            # Search with LEANN (may require the embedding server depending on index configuration)
            results = self.searcher.search(
                query,
                top_k=3,
                complexity=complexity,
                recompute_embeddings=recompute_embeddings,
            )
            test_ids = {r.id for r in results}

            intersection = test_ids.intersection(baseline_ids)
            recall = len(intersection) / 3.0
            total_recall += recall

            if i < 3:
                print(f"  Q{i + 1}: '{query[:60]}...' -> Recall@3: {recall:.3f}")
                print(f"    FAISS: {list(baseline_ids)}")
                print(f"    LEANN: {list(test_ids)}")
                print(f"    ∩: {list(intersection)}")

        avg = total_recall / max(1, len(queries))
        print(f"📊 Average Recall@3: {avg:.3f} ({avg * 100:.1f}%)")
        return avg

    def cleanup(self):
        if hasattr(self, "searcher"):
            self.searcher.cleanup()


class EnronEvaluator:
    def __init__(self, index_path: str):
        self.index_path = index_path
        self.searcher = LeannSearcher(index_path)

    def load_queries(self, queries_file: str) -> list[str]:
        queries: list[str] = []
        with open(queries_file, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                data = json.loads(line)
                if "query" in data:
                    queries.append(data["query"])
        print(f"📊 Loaded {len(queries)} queries from {queries_file}")
        return queries

    def cleanup(self):
        if self.searcher:
            self.searcher.cleanup()

    def analyze_index_sizes(self) -> dict:
        """Analyze index sizes (.index only), similar to the LAION bench."""
        from pathlib import Path

        print("📏 Analyzing index sizes (.index only)...")
        index_path = Path(self.index_path)
        index_dir = index_path.parent
        index_name = index_path.stem

        sizes: dict[str, float] = {}
        index_file = index_dir / f"{index_name}.index"
        meta_file = index_dir / f"{index_path.name}.meta.json"
        passages_file = index_dir / f"{index_path.name}.passages.jsonl"
        passages_idx_file = index_dir / f"{index_path.name}.passages.idx"

        sizes["index_only_mb"] = (
            index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
        )
        sizes["metadata_mb"] = (
            meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
        )
        sizes["passages_text_mb"] = (
            passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
        )
        sizes["passages_index_mb"] = (
            passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
        )

        print(f"  📁 .index size: {sizes['index_only_mb']:.1f} MB")
        return sizes

    def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
        """Create a non-compact index for comparison using the current passages and embeddings."""
        from pathlib import Path

        current_index_path = Path(self.index_path)
        current_index_dir = current_index_path.parent
        current_index_name = current_index_path.name

        # Read metadata to get the passage source and embedding model
        meta_path = current_index_dir / f"{current_index_name}.meta.json"
        with open(meta_path, encoding="utf-8") as f:
            meta = json.load(f)

        passage_source = meta["passage_sources"][0]
        passage_file = passage_source["path"]

        # Convert a relative path to an absolute one
        if not Path(passage_file).is_absolute():
            passage_file = current_index_dir / Path(passage_file).name

        # Load all passage ids and texts
        ids: list[str] = []
        texts: list[str] = []
        with open(passage_file, encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    data = json.loads(line)
                    ids.append(str(data["id"]))
                    texts.append(data["text"])

        # Compute embeddings using the same method as LEANN
        from leann.api import compute_embeddings

        embeddings = compute_embeddings(
            texts,
            meta["embedding_model"],
            mode=meta.get("embedding_mode", "sentence-transformers"),
            use_server=False,
        ).astype(np.float32)

        # Build a non-compact index with the same passages and embeddings
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model=meta["embedding_model"],
            embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
            is_recompute=False,
            is_compact=False,
            **{
                k: v
                for k, v in meta.get("backend_kwargs", {}).items()
                if k not in ["is_recompute", "is_compact"]
            },
        )

        # Persist a pickle for build_index_from_embeddings
        pkl_path = current_index_dir / f"{Path(non_compact_index_path).stem}_embeddings.pkl"
        with open(pkl_path, "wb") as pf:
            pickle.dump((ids, embeddings), pf)

        print(
            f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
        )
        builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))

        # Analyze the non-compact index size
        temp_evaluator = EnronEvaluator(non_compact_index_path)
        non_compact_sizes = temp_evaluator.analyze_index_sizes()
        non_compact_sizes["index_type"] = "non_compact"

        return non_compact_sizes

    def compare_index_performance(
        self, non_compact_path: str, compact_path: str, test_queries: list[str], complexity: int
    ) -> dict:
        """Compare search speed for the non-compact vs compact indexes."""
        import time

        results: dict = {
            "non_compact": {"search_times": []},
            "compact": {"search_times": []},
            "avg_search_times": {},
            "speed_ratio": 0.0,
        }

        print("⚡ Comparing search performance between indexes...")
        # Non-compact (no recompute)
        print("  🔍 Testing non-compact index (no recompute)...")
        non_compact_searcher = LeannSearcher(non_compact_path)
        for q in test_queries:
            t0 = time.time()
            _ = non_compact_searcher.search(
                q, top_k=3, complexity=complexity, recompute_embeddings=False
            )
            results["non_compact"]["search_times"].append(time.time() - t0)

        # Compact (with recompute). Fail fast if it cannot run.
        print("  🔍 Testing compact index (with recompute)...")
        compact_searcher = LeannSearcher(compact_path)
        for q in test_queries:
            t0 = time.time()
            _ = compact_searcher.search(
                q, top_k=3, complexity=complexity, recompute_embeddings=True
            )
            results["compact"]["search_times"].append(time.time() - t0)
        compact_searcher.cleanup()

        if results["non_compact"]["search_times"]:
            results["avg_search_times"]["non_compact"] = sum(
                results["non_compact"]["search_times"]
            ) / len(results["non_compact"]["search_times"])
        if results["compact"]["search_times"]:
            results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
                results["compact"]["search_times"]
            )
        if results["avg_search_times"].get("compact", 0) > 0:
            results["speed_ratio"] = (
                results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
            )
        else:
            results["speed_ratio"] = 0.0

        non_compact_searcher.cleanup()
        return results

    def evaluate_complexity(
        self,
        recall_eval: "RecallEvaluator",
        queries: list[str],
        target: float = 0.90,
        c_min: int = 8,
        c_max: int = 256,
        max_iters: int = 10,
        recompute: bool = False,
    ) -> dict:
        """Binary search for the minimal complexity achieving the target recall (monotonic assumption)."""

        def round_c(x: int) -> int:
            # Snap to a multiple of 8, like the other benches typically do
            return max(1, int((x + 7) // 8) * 8)

        metrics: list[dict] = []

        lo = round_c(c_min)
        hi = round_c(c_max)

        print(
            f"🧪 Binary search complexity in [{lo}, {hi}] for target Recall@3>={int(target * 100)}%..."
        )

        # Ensure the upper bound can reach the target; expand if needed (up to a cap)
        r_lo = recall_eval.evaluate_recall_at_3(
            queries, complexity=lo, recompute_embeddings=recompute
        )
        metrics.append({"complexity": lo, "recall_at_3": r_lo})
        r_hi = recall_eval.evaluate_recall_at_3(
            queries, complexity=hi, recompute_embeddings=recompute
        )
        metrics.append({"complexity": hi, "recall_at_3": r_hi})

        cap = 1024
        while r_hi < target and hi < cap:
            lo = hi
            r_lo = r_hi
            hi = round_c(hi * 2)
            r_hi = recall_eval.evaluate_recall_at_3(
                queries, complexity=hi, recompute_embeddings=recompute
            )
            metrics.append({"complexity": hi, "recall_at_3": r_hi})

        if r_hi < target:
            print(f"⚠️ Max complexity {hi} did not reach target recall {target:.2f}.")
            print("📈 Observations:")
            for m in metrics:
                print(f"  C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
            return {"metrics": metrics, "best_complexity": None, "target_recall": target}

        # Binary search within [lo, hi]
        best = hi
        iters = 0
        while lo < hi and iters < max_iters:
            mid = round_c((lo + hi) // 2)
            r_mid = recall_eval.evaluate_recall_at_3(
                queries, complexity=mid, recompute_embeddings=recompute
            )
            metrics.append({"complexity": mid, "recall_at_3": r_mid})
            if r_mid >= target:
                best = mid
                hi = mid
            else:
                lo = mid + 8  # move past mid, respecting the multiple-of-8 step
            iters += 1

        print("📈 Binary search results (sampled points):")
        # Print unique complexity entries ordered by complexity
        for m in sorted(
            {m["complexity"]: m for m in metrics}.values(), key=lambda x: x["complexity"]
        ):
            print(f"  C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
        print(f"✅ Minimal complexity achieving {int(target * 100)}% recall: {best}")
        return {"metrics": metrics, "best_complexity": best, "target_recall": target}


def main():
    parser = argparse.ArgumentParser(description="Enron Emails Benchmark Evaluation")
    parser.add_argument("--index", required=True, help="Path to LEANN index")
    parser.add_argument(
        "--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
    )
    parser.add_argument(
        "--stage",
        choices=["2", "3", "4", "all"],
        default="all",
        help="Which stage to run (2=recall, 3=complexity, 4=index comparison)",
    )
    parser.add_argument("--complexity", type=int, default=None, help="LEANN search complexity")
    parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
    parser.add_argument(
        "--max-queries", type=int, default=1000, help="Limit the number of queries to evaluate"
    )
    parser.add_argument(
        "--target-recall", type=float, default=0.90, help="Target Recall@3 for Stage 3"
    )
    parser.add_argument("--output", help="Save results to a JSON file")

    args = parser.parse_args()

    # Resolve the queries file: if the default path is not found, fall back to the index's directory
    if not os.path.exists(args.queries):
        from pathlib import Path

        idx_dir = Path(args.index).parent
        fallback_q = idx_dir / "evaluation_queries.jsonl"
        if fallback_q.exists():
            args.queries = str(fallback_q)

    baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
    if not os.path.exists(baseline_index_path):
        print(f"❌ FAISS baseline not found at {baseline_index_path}")
        print("💡 Please run setup_enron_emails.py first to build the baseline")
        raise SystemExit(1)

    results_out: dict = {}

    if args.stage in ("2", "all"):
        print("🚀 Starting Stage 2: Recall@3 evaluation")
        evaluator = RecallEvaluator(args.index, args.baseline_dir)

        enron_eval = EnronEvaluator(args.index)
        queries = enron_eval.load_queries(args.queries)
        queries = queries[:10]
        print(f"🧪 Using first {len(queries)} queries")

        complexity = args.complexity or 64
        r = evaluator.evaluate_recall_at_3(queries, complexity)
        results_out["stage2"] = {"complexity": complexity, "recall_at_3": r}
        evaluator.cleanup()
        enron_eval.cleanup()
        print("✅ Stage 2 completed!\n")

    if args.stage in ("3", "all"):
        print("🚀 Starting Stage 3: Binary search for target recall (no recompute)")
        enron_eval = EnronEvaluator(args.index)
        queries = enron_eval.load_queries(args.queries)
        queries = queries[: args.max_queries]
        print(f"🧪 Using first {len(queries)} queries")

        # Build a non-compact index for fast binary search (recompute_embeddings=False)
        from pathlib import Path

        index_path = Path(args.index)
        non_compact_index_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
        enron_eval.create_non_compact_index_for_comparison(non_compact_index_path)

        # Use the non-compact evaluator for binary search with recompute=False
        evaluator_nc = RecallEvaluator(non_compact_index_path, args.baseline_dir)
        sweep = enron_eval.evaluate_complexity(
            evaluator_nc, queries, target=args.target_recall, recompute=False
        )
        results_out["stage3"] = sweep
        # Persist default Stage 3 results near the index for Stage 4 auto-pickup
        default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
        with open(default_stage3_path, "w", encoding="utf-8") as f:
            json.dump({"stage3": sweep}, f, indent=2)
        print(f"📝 Saved Stage 3 summary to {default_stage3_path}")
        evaluator_nc.cleanup()
        enron_eval.cleanup()
        print("✅ Stage 3 completed!\n")

    if args.stage in ("4", "all"):
        print("🚀 Starting Stage 4: Index size + performance comparison")
        evaluator = RecallEvaluator(args.index, args.baseline_dir)
        enron_eval = EnronEvaluator(args.index)
        queries = enron_eval.load_queries(args.queries)
        test_q = queries[: min(args.max_queries, len(queries))]

        current_sizes = enron_eval.analyze_index_sizes()
        # Build a non-compact index for comparison (no fallback)
        from pathlib import Path

        index_path = Path(args.index)
        non_compact_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
        non_compact_sizes = enron_eval.create_non_compact_index_for_comparison(non_compact_path)
        nc_eval = EnronEvaluator(non_compact_path)

        if (
            current_sizes.get("index_only_mb", 0) > 0
            and non_compact_sizes.get("index_only_mb", 0) > 0
        ):
            storage_saving_percent = max(
                0.0,
                100.0 * (1.0 - current_sizes["index_only_mb"] / non_compact_sizes["index_only_mb"]),
            )
        else:
            storage_saving_percent = 0.0

        if args.complexity is None:
            # Prefer the in-session Stage 3 result
            if "stage3" in results_out and results_out["stage3"].get("best_complexity") is not None:
                complexity = results_out["stage3"]["best_complexity"]
                print(f"📥 Using best complexity from Stage 3 in-session: {complexity}")
            else:
                # Try to load the last saved Stage 3 result near the index
                default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
                if default_stage3_path.exists():
                    with open(default_stage3_path, encoding="utf-8") as f:
                        prev = json.load(f)
                    complexity = prev.get("stage3", {}).get("best_complexity")
                    if complexity is None:
                        raise SystemExit(
                            "❌ Stage 4: No --complexity and no best_complexity found in saved Stage 3 results"
                        )
                    print(f"📥 Using best complexity from saved Stage 3: {complexity}")
                else:
                    raise SystemExit(
                        "❌ Stage 4 requires --complexity if Stage 3 hasn't been run. Run Stage 3 first or pass --complexity."
                    )
        else:
            complexity = args.complexity

        comp = enron_eval.compare_index_performance(
            non_compact_path, args.index, test_q, complexity=complexity
        )
        results_out["stage4"] = {
            "current_index": current_sizes,
            "non_compact_index": non_compact_sizes,
            "storage_saving_percent": storage_saving_percent,
            "performance_comparison": comp,
        }
        nc_eval.cleanup()
        evaluator.cleanup()
        enron_eval.cleanup()
        print("✅ Stage 4 completed!\n")

    if args.output and results_out:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(results_out, f, indent=2)
        print(f"📝 Saved results to {args.output}")


if __name__ == "__main__":
    main()
359  benchmarks/enron_emails/setup_enron_emails.py  Normal file
@@ -0,0 +1,359 @@
"""
|
||||||
|
Enron Emails Benchmark Setup Script
|
||||||
|
Prepares passages from emails.csv, builds LEANN index, and FAISS Flat baseline
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from email import message_from_string
|
||||||
|
from email.policy import default
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from leann import LeannBuilder
|
||||||
|
|
||||||
|
|
||||||
|
class EnronSetup:
|
||||||
|
def __init__(self, data_dir: str = "data"):
|
||||||
|
self.data_dir = Path(data_dir)
|
||||||
|
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
self.passages_preview = self.data_dir / "enron_passages_preview.jsonl"
|
||||||
|
self.index_path = self.data_dir / "enron_index_hnsw.leann"
|
||||||
|
self.queries_file = self.data_dir / "evaluation_queries.jsonl"
|
||||||
|
self.downloads_dir = self.data_dir / "downloads"
|
||||||
|
self.downloads_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Dataset acquisition
|
||||||
|
# ----------------------------
|
||||||
|
def ensure_emails_csv(self, emails_csv: Optional[str]) -> str:
|
||||||
|
"""Return a path to emails.csv, downloading from Kaggle if needed."""
|
||||||
|
if emails_csv:
|
||||||
|
p = Path(emails_csv)
|
||||||
|
if not p.exists():
|
||||||
|
raise FileNotFoundError(f"emails.csv not found: {emails_csv}")
|
||||||
|
return str(p)
|
||||||
|
|
||||||
|
print(
|
||||||
|
"📥 Trying to download Enron emails.csv from Kaggle (wcukierski/enron-email-dataset)..."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
from kaggle.api.kaggle_api_extended import KaggleApi
|
||||||
|
|
||||||
|
api = KaggleApi()
|
||||||
|
api.authenticate()
|
||||||
|
api.dataset_download_files(
|
||||||
|
"wcukierski/enron-email-dataset", path=str(self.downloads_dir), unzip=True
|
||||||
|
)
|
||||||
|
candidate = self.downloads_dir / "emails.csv"
|
||||||
|
if candidate.exists():
|
||||||
|
print(f"✅ Downloaded emails.csv: {candidate}")
|
||||||
|
return str(candidate)
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"emails.csv was not found in {self.downloads_dir} after Kaggle download"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(
|
||||||
|
"❌ Could not download via Kaggle automatically. Provide --emails-csv or configure Kaggle API."
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
" Set KAGGLE_USERNAME and KAGGLE_KEY env vars, or place emails.csv locally and pass --emails-csv."
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Data preparation
|
||||||
|
# ----------------------------
|
||||||
|
@staticmethod
|
||||||
|
def _extract_message_id(raw_email: str) -> str:
|
||||||
|
msg = message_from_string(raw_email, policy=default)
|
||||||
|
val = msg.get("Message-ID", "")
|
||||||
|
if val.startswith("<") and val.endswith(">"):
|
||||||
|
val = val[1:-1]
|
||||||
|
return val or ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_header_body(raw_email: str) -> tuple[str, str]:
|
||||||
|
parts = raw_email.split("\n\n", 1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
return parts[0].strip(), parts[1].strip()
|
||||||
|
# Heuristic fallback
|
||||||
|
first_lines = raw_email.splitlines()
|
||||||
|
if first_lines and ":" in first_lines[0]:
|
||||||
|
return raw_email.strip(), ""
|
||||||
|
return "", raw_email.strip()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_fixed_words(text: str, chunk_words: int, keep_last: bool) -> list[str]:
|
||||||
|
text = (text or "").strip()
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
if chunk_words <= 0:
|
||||||
|
return [text]
|
||||||
|
words = text.split()
|
||||||
|
if not words:
|
||||||
|
return []
|
||||||
|
limit = len(words)
|
||||||
|
if not keep_last:
|
||||||
|
limit = (len(words) // chunk_words) * chunk_words
|
||||||
|
if limit == 0:
|
||||||
|
return []
|
||||||
|
chunks = [" ".join(words[i : i + chunk_words]) for i in range(0, limit, chunk_words)]
|
||||||
|
return [c for c in (s.strip() for s in chunks) if c]
|
||||||
|
|
||||||
|
    def _iter_passages_from_csv(
        self,
        emails_csv: Path,
        chunk_words: int = 256,
        keep_last_header: bool = True,
        keep_last_body: bool = True,
        max_emails: int | None = None,
    ) -> Iterable[dict]:
        """Yield passage dicts (text + metadata) chunked from the Enron emails CSV."""
        with open(emails_csv, encoding="utf-8") as f:
            reader = csv.DictReader(f)
            count = 0
            for i, row in enumerate(reader):
                if max_emails is not None and count >= max_emails:
                    break

                raw_message = row.get("message", "")
                email_file_id = row.get("file", "")

                if not raw_message.strip():
                    continue

                message_id = self._extract_message_id(raw_message)
                if not message_id:
                    # Fallback ID based on CSV position and file path
                    safe_file = re.sub(r"[^A-Za-z0-9_.-]", "_", email_file_id)
                    message_id = f"enron_{i}_{safe_file}"

                header, body = self._split_header_body(raw_message)

                # Header chunks
                for chunk in self._split_fixed_words(header, chunk_words, keep_last_header):
                    yield {
                        "text": chunk,
                        "metadata": {
                            "message_id": message_id,
                            "is_header": True,
                            "email_file_id": email_file_id,
                        },
                    }

                # Body chunks
                for chunk in self._split_fixed_words(body, chunk_words, keep_last_body):
                    yield {
                        "text": chunk,
                        "metadata": {
                            "message_id": message_id,
                            "is_header": False,
                            "email_file_id": email_file_id,
                        },
                    }

                count += 1
    # ----------------------------
    # Build LEANN index and FAISS baseline
    # ----------------------------
    def build_leann_index(
        self,
        emails_csv: Optional[str],
        backend: str = "hnsw",
        embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
        chunk_words: int = 256,
        max_emails: int | None = None,
    ) -> str:
        """Build a compact LEANN index from the Enron CSV and return the index path."""
        emails_csv_path = self.ensure_emails_csv(emails_csv)
        print(f"🏗️ Building LEANN index from {emails_csv_path}...")

        builder = LeannBuilder(
            backend_name=backend,
            embedding_model=embedding_model,
            embedding_mode="sentence-transformers",
            graph_degree=32,
            complexity=64,
            is_recompute=True,
            is_compact=True,
            num_threads=4,
        )

        # Stream passages into the builder, keeping a small JSONL preview for inspection
        preview_written = 0
        with open(self.passages_preview, "w", encoding="utf-8") as preview_out:
            for p in self._iter_passages_from_csv(
                Path(emails_csv_path), chunk_words=chunk_words, max_emails=max_emails
            ):
                builder.add_text(p["text"], metadata=p["metadata"])
                if preview_written < 200:
                    preview_out.write(json.dumps({"text": p["text"][:200], **p["metadata"]}) + "\n")
                    preview_written += 1

        print(f"🔨 Building index at {self.index_path}...")
        builder.build_index(str(self.index_path))
        print("✅ LEANN index built!")
        return str(self.index_path)
    def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline") -> str:
        """Embed the same passages with the same model and save an exact FAISS Flat baseline."""
        print("🔨 Building FAISS Flat baseline from LEANN passages...")

        import pickle

        import numpy as np
        from leann.api import compute_embeddings
        from leann_backend_hnsw import faiss

        os.makedirs(output_dir, exist_ok=True)
        baseline_path = os.path.join(output_dir, "faiss_flat.index")
        metadata_path = os.path.join(output_dir, "metadata.pkl")

        if os.path.exists(baseline_path) and os.path.exists(metadata_path):
            print(f"✅ Baseline already exists at {baseline_path}")
            return baseline_path

        # Read meta for passage source and embedding model
        meta_path = f"{index_path}.meta.json"
        with open(meta_path, encoding="utf-8") as f:
            meta = json.load(f)

        embedding_model = meta["embedding_model"]
        passage_source = meta["passage_sources"][0]
        passage_file = passage_source["path"]

        if not os.path.isabs(passage_file):
            index_dir = os.path.dirname(index_path)
            passage_file = os.path.join(index_dir, os.path.basename(passage_file))

        # Load passages from builder output so IDs match LEANN
        passages: list[str] = []
        passage_ids: list[str] = []
        with open(passage_file, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                data = json.loads(line)
                passages.append(data["text"])
                passage_ids.append(data["id"])  # builder-assigned ID

        print(f"📄 Loaded {len(passages)} passages for baseline")
        print(f"🤖 Embedding model: {embedding_model}")

        embeddings = compute_embeddings(
            passages,
            embedding_model,
            mode="sentence-transformers",
            use_server=False,
        )

        # Build an exact inner-product index (FAISS IndexFlatIP)
        dim = embeddings.shape[1]
        index = faiss.IndexFlatIP(dim)
        emb_f32 = embeddings.astype(np.float32)
        index.add(emb_f32.shape[0], faiss.swig_ptr(emb_f32))

        faiss.write_index(index, baseline_path)
        with open(metadata_path, "wb") as pf:
            pickle.dump(passage_ids, pf)

        print(f"✅ FAISS baseline saved: {baseline_path}")
        print(f"✅ Metadata saved: {metadata_path}")
        print(f"📊 Total vectors: {index.ntotal}")
        return baseline_path
    # ----------------------------
    # Queries (optional): prepare evaluation queries file
    # ----------------------------
    def prepare_queries(self, min_realism: float = 0.85) -> Path:
        """Write a JSONL queries file, keeping only questions with high realism scores."""
        print(
            "📝 Preparing evaluation queries from HuggingFace dataset corbt/enron_emails_sample_questions ..."
        )
        try:
            from datasets import load_dataset

            ds = load_dataset("corbt/enron_emails_sample_questions", split="train")
        except Exception as e:
            print(f"⚠️ Failed to load dataset: {e}")
            return self.queries_file

        kept = 0
        with open(self.queries_file, "w", encoding="utf-8") as out:
            for i, item in enumerate(ds):
                how_realistic = float(item.get("how_realistic", 0.0))
                if how_realistic < min_realism:
                    continue
                qid = str(item.get("id", f"enron_q_{i}"))
                query = item.get("question", "")
                if not query:
                    continue
                record = {
                    "id": qid,
                    "query": query,
                    # For reference only, not used in the recall metric below
                    "gt_message_ids": item.get("message_ids", []),
                }
                out.write(json.dumps(record) + "\n")
                kept += 1
        print(f"✅ Wrote {kept} queries to {self.queries_file}")
        return self.queries_file
def main():
    parser = argparse.ArgumentParser(description="Setup Enron Emails Benchmark")
    parser.add_argument(
        "--emails-csv",
        help="Path to emails.csv (Enron dataset). If omitted, attempt Kaggle download.",
    )
    parser.add_argument("--data-dir", default="data", help="Data directory")
    parser.add_argument("--backend", choices=["hnsw", "diskann"], default="hnsw")
    parser.add_argument(
        "--embedding-model",
        default="sentence-transformers/all-mpnet-base-v2",
        help="Embedding model for LEANN",
    )
    parser.add_argument("--chunk-words", type=int, default=256, help="Fixed word chunk size")
    parser.add_argument("--max-emails", type=int, help="Limit number of emails to process")
    parser.add_argument("--skip-queries", action="store_true", help="Skip creating queries file")
    parser.add_argument("--skip-build", action="store_true", help="Skip building LEANN index")

    args = parser.parse_args()

    setup = EnronSetup(args.data_dir)

    # Build index
    if not args.skip_build:
        index_path = setup.build_leann_index(
            emails_csv=args.emails_csv,
            backend=args.backend,
            embedding_model=args.embedding_model,
            chunk_words=args.chunk_words,
            max_emails=args.max_emails,
        )

        # Build FAISS baseline from the same passages & embeddings
        setup.build_faiss_flat_baseline(index_path)
    else:
        print("⏭️ Skipping LEANN index build and baseline")

    # Queries file (optional)
    if not args.skip_queries:
        setup.prepare_queries()
    else:
        print("⏭️ Skipping query preparation")

    print("\n🎉 Enron Emails setup completed!")
    print(f"📁 Data directory: {setup.data_dir.absolute()}")
    print("Next steps:")
    print(
        "1) Evaluate recall: python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2"
    )


if __name__ == "__main__":
    main()