"""
|
|
Enron Emails Benchmark Evaluation - Retrieval Recall@3 (Stages 2/3/4)
|
|
Follows the style of FinanceBench/LAION: Stage 2 recall vs FAISS baseline,
|
|
Stage 3 complexity sweep to target recall, Stage 4 index comparison.
|
|
On errors, fail fast without fallbacks.
|
|
"""
|
|
|
|
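# Example invocation (module path and file locations are illustrative; this file
# uses relative imports, so it must run as a package module):
#   python -m benchmarks.enron_email.evaluate \
#       --index indices/enron.leann --queries data/evaluation_queries.jsonl \
#       --stage all --output enron_results.json
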
import argparse
import json
import logging
import os
import pickle
from pathlib import Path

import numpy as np
from leann import LeannBuilder, LeannSearcher
from leann_backend_hnsw import faiss

from ..llm_utils import generate_hf, generate_vllm, load_hf_model, load_vllm_model

# Setup logging to reduce verbose output
logging.basicConfig(level=logging.WARNING)
logging.getLogger("leann.api").setLevel(logging.WARNING)
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)


class RecallEvaluator:
    """Stage 2: Evaluate Recall@3 (LEANN vs FAISS)"""

    def __init__(self, index_path: str, baseline_dir: str):
        self.index_path = index_path
        self.baseline_dir = baseline_dir
        self.searcher = LeannSearcher(index_path)

        baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
        metadata_path = os.path.join(baseline_dir, "metadata.pkl")

        self.faiss_index = faiss.read_index(baseline_index_path)
        with open(metadata_path, "rb") as f:
            self.passage_ids = pickle.load(f)

        print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")

        # No fallbacks here; if embedding server is needed but fails, the caller will see the error.

    def evaluate_recall_at_3(
        self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
    ) -> float:
        """Evaluate recall@3 using FAISS Flat as ground truth"""
        from leann.api import compute_embeddings

        recompute_str = "with recompute" if recompute_embeddings else "no recompute"
        print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")

        total_recall = 0.0
        for i, query in enumerate(queries):
            # Compute query embedding with the same model/mode as the index
            q_emb = compute_embeddings(
                [query],
                self.searcher.embedding_model,
                mode=self.searcher.embedding_mode,
                use_server=False,
            ).astype(np.float32)

            # Search FAISS Flat ground truth
            n = q_emb.shape[0]
            k = 3
            distances = np.zeros((n, k), dtype=np.float32)
            labels = np.zeros((n, k), dtype=np.int64)
            self.faiss_index.search(
                n,
                faiss.swig_ptr(q_emb),
                k,
                faiss.swig_ptr(distances),
                faiss.swig_ptr(labels),
            )
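            # NOTE: this is the low-level SWIG entry point exposed by this
            # faiss build. With a stock faiss install the equivalent,
            # numpy-friendly form would simply be:
            #     distances, labels = self.faiss_index.search(q_emb, k)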

            baseline_ids = {self.passage_ids[idx] for idx in labels[0]}

            # Search with LEANN (may require embedding server depending on index configuration)
            results = self.searcher.search(
                query,
                top_k=3,
                complexity=complexity,
                recompute_embeddings=recompute_embeddings,
            )
            test_ids = {r.id for r in results}
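
            # Recall@3 = |LEANN top-3 ∩ FAISS top-3| / 3; exact FAISS Flat
            # search is treated as ground truth.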
            intersection = test_ids.intersection(baseline_ids)
            recall = len(intersection) / 3.0
            total_recall += recall

            if i < 3:
                print(f" Q{i + 1}: '{query[:60]}...' -> Recall@3: {recall:.3f}")
                print(f" FAISS: {list(baseline_ids)}")
                print(f" LEANN: {list(test_ids)}")
                print(f" ∩: {list(intersection)}")

        avg = total_recall / max(1, len(queries))
        print(f"📊 Average Recall@3: {avg:.3f} ({avg * 100:.1f}%)")
        return avg

    def cleanup(self):
        if hasattr(self, "searcher"):
            self.searcher.cleanup()


class EnronEvaluator:
    def __init__(self, index_path: str):
        self.index_path = index_path
        self.searcher = LeannSearcher(index_path)

    def load_queries(self, queries_file: str) -> list[str]:
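        # Queries are JSON Lines; only the "query" field is consumed, e.g.
        # (illustrative record): {"query": "Who scheduled the Q3 review?", "answer": "..."}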
        queries: list[str] = []
        with open(queries_file, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                data = json.loads(line)
                if "query" in data:
                    queries.append(data["query"])
        print(f"📊 Loaded {len(queries)} queries from {queries_file}")
        return queries

    def cleanup(self):
        if self.searcher:
            self.searcher.cleanup()

    def analyze_index_sizes(self) -> dict:
        """Analyze index sizes (.index only), similar to LAION bench."""

        print("📏 Analyzing index sizes (.index only)...")
        index_path = Path(self.index_path)
        index_dir = index_path.parent
        index_name = index_path.stem

        sizes: dict[str, float] = {}
        index_file = index_dir / f"{index_name}.index"
        meta_file = index_dir / f"{index_path.name}.meta.json"
        passages_file = index_dir / f"{index_path.name}.passages.jsonl"
        passages_idx_file = index_dir / f"{index_path.name}.passages.idx"

        sizes["index_only_mb"] = (
            index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
        )
        sizes["metadata_mb"] = (
            meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
        )
        sizes["passages_text_mb"] = (
            passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
        )
        sizes["passages_index_mb"] = (
            passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
        )

        print(f" 📁 .index size: {sizes['index_only_mb']:.1f} MB")
        return sizes

    def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
        """Create a non-compact index for comparison using current passages and embeddings."""

        current_index_path = Path(self.index_path)
        current_index_dir = current_index_path.parent
        current_index_name = current_index_path.name

        # Read metadata to get passage source and embedding model
        meta_path = current_index_dir / f"{current_index_name}.meta.json"
        with open(meta_path, encoding="utf-8") as f:
            meta = json.load(f)

        passage_source = meta["passage_sources"][0]
        passage_file = passage_source["path"]

        # Convert relative path to absolute
        if not Path(passage_file).is_absolute():
            passage_file = current_index_dir / Path(passage_file).name

        # Load all passages and ids
        ids: list[str] = []
        texts: list[str] = []
        with open(passage_file, encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    data = json.loads(line)
                    ids.append(str(data["id"]))
                    texts.append(data["text"])

        # Compute embeddings using the same method as LEANN
        from leann.api import compute_embeddings

        embeddings = compute_embeddings(
            texts,
            meta["embedding_model"],
            mode=meta.get("embedding_mode", "sentence-transformers"),
            use_server=False,
        ).astype(np.float32)

        # Build non-compact index with same passages and embeddings
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model=meta["embedding_model"],
            embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
            is_recompute=False,
            is_compact=False,
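            # Forward the original backend kwargs, but drop is_recompute /
            # is_compact so they cannot collide with the explicit flags above
            # (duplicate keyword arguments would raise a TypeError).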
            **{
                k: v
                for k, v in meta.get("backend_kwargs", {}).items()
                if k not in ["is_recompute", "is_compact"]
            },
        )

        # Persist a pickle for build_index_from_embeddings
        pkl_path = current_index_dir / f"{Path(non_compact_index_path).stem}_embeddings.pkl"
        with open(pkl_path, "wb") as pf:
            pickle.dump((ids, embeddings), pf)

        print(
            f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
        )
        builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))

        # Analyze the non-compact index size
        temp_evaluator = EnronEvaluator(non_compact_index_path)
        non_compact_sizes = temp_evaluator.analyze_index_sizes()
        non_compact_sizes["index_type"] = "non_compact"

        return non_compact_sizes

    def compare_index_performance(
        self, non_compact_path: str, compact_path: str, test_queries: list[str], complexity: int
    ) -> dict:
        """Compare search speed for non-compact vs compact indexes."""
        import time

        results: dict = {
            "non_compact": {"search_times": []},
            "compact": {"search_times": []},
            "avg_search_times": {},
            "speed_ratio": 0.0,
            "retrieval_results": [],  # Store retrieval results for Stage 5
        }

        print("⚡ Comparing search performance between indexes...")
        # Non-compact (no recompute)
        print(" 🔍 Testing non-compact index (no recompute)...")
        non_compact_searcher = LeannSearcher(non_compact_path)
        for q in test_queries:
            t0 = time.time()
            _ = non_compact_searcher.search(
                q, top_k=3, complexity=complexity, recompute_embeddings=False
            )
            results["non_compact"]["search_times"].append(time.time() - t0)

        # Compact (with recompute). Fail fast if it cannot run.
        print(" 🔍 Testing compact index (with recompute)...")
        compact_searcher = LeannSearcher(compact_path)
        for q in test_queries:
            t0 = time.time()
            docs = compact_searcher.search(
                q, top_k=3, complexity=complexity, recompute_embeddings=True
            )
            results["compact"]["search_times"].append(time.time() - t0)

            # Store retrieval results for Stage 5
            results["retrieval_results"].append(
                {"query": q, "retrieved_docs": [{"id": doc.id, "text": doc.text} for doc in docs]}
            )
        compact_searcher.cleanup()

        if results["non_compact"]["search_times"]:
            results["avg_search_times"]["non_compact"] = sum(
                results["non_compact"]["search_times"]
            ) / len(results["non_compact"]["search_times"])
        if results["compact"]["search_times"]:
            results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
                results["compact"]["search_times"]
            )
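        # speed_ratio = avg non-compact time / avg compact time; values below
        # 1.0 mean the non-compact (no-recompute) search is faster.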
        if results["avg_search_times"].get("compact", 0) > 0:
            results["speed_ratio"] = (
                results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
            )
        else:
            results["speed_ratio"] = 0.0

        non_compact_searcher.cleanup()
        return results

    def evaluate_complexity(
        self,
        recall_eval: "RecallEvaluator",
        queries: list[str],
        target: float = 0.90,
        c_min: int = 8,
        c_max: int = 256,
        max_iters: int = 10,
        recompute: bool = False,
    ) -> dict:
        """Binary search minimal complexity achieving target recall (monotonic assumption)."""

        def round_c(x: int) -> int:
            # snap to multiple of 8 like other benches typically do
            return max(1, int((x + 7) // 8) * 8)
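        # e.g. round_c(60) -> 64, round_c(64) -> 64, round_c(1) -> 8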

        metrics: list[dict] = []

        lo = round_c(c_min)
        hi = round_c(c_max)

        print(
            f"🧪 Binary search complexity in [{lo}, {hi}] for target Recall@3>={int(target * 100)}%..."
        )

        # Ensure upper bound can reach target; expand if needed (up to a cap)
        r_lo = recall_eval.evaluate_recall_at_3(
            queries, complexity=lo, recompute_embeddings=recompute
        )
        metrics.append({"complexity": lo, "recall_at_3": r_lo})
        r_hi = recall_eval.evaluate_recall_at_3(
            queries, complexity=hi, recompute_embeddings=recompute
        )
        metrics.append({"complexity": hi, "recall_at_3": r_hi})

        cap = 1024
        while r_hi < target and hi < cap:
            lo = hi
            r_lo = r_hi
            hi = round_c(hi * 2)
            r_hi = recall_eval.evaluate_recall_at_3(
                queries, complexity=hi, recompute_embeddings=recompute
            )
            metrics.append({"complexity": hi, "recall_at_3": r_hi})

        if r_hi < target:
            print(f"⚠️ Max complexity {hi} did not reach target recall {target:.2f}.")
            print("📈 Observations:")
            for m in metrics:
                print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
            return {"metrics": metrics, "best_complexity": None, "target_recall": target}

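        # Illustrative walk with hypothetical recalls, target 0.90, range [8, 256]:
        #   mid = round_c((8 + 256) // 2) = 136  -> 0.93 >= target => hi = best = 136
        #   mid = round_c((8 + 136) // 2) = 72   -> 0.88 <  target => lo = 80
        #   mid = round_c((80 + 136) // 2) = 112 -> 0.91 >= target => hi = best = 112, ...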
        # Binary search within [lo, hi]
        best = hi
        iters = 0
        while lo < hi and iters < max_iters:
            mid = round_c((lo + hi) // 2)
            r_mid = recall_eval.evaluate_recall_at_3(
                queries, complexity=mid, recompute_embeddings=recompute
            )
            metrics.append({"complexity": mid, "recall_at_3": r_mid})
            if r_mid >= target:
                best = mid
                hi = mid
            else:
                lo = mid + 8  # move past mid, respecting multiple-of-8 step
            iters += 1

        print("📈 Binary search results (sampled points):")
        # Print unique complexity entries ordered by complexity
        for m in sorted(
            {m["complexity"]: m for m in metrics}.values(), key=lambda x: x["complexity"]
        ):
            print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
        print(f"✅ Minimal complexity achieving {int(target * 100)}% recall: {best}")
        return {"metrics": metrics, "best_complexity": best, "target_recall": target}


def main():
    parser = argparse.ArgumentParser(description="Enron Emails Benchmark Evaluation")
    parser.add_argument("--index", required=True, help="Path to LEANN index")
    parser.add_argument(
        "--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
    )
    parser.add_argument(
        "--stage",
        choices=["2", "3", "4", "5", "all", "45"],
        default="all",
        help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation, 45=stages 4 and 5)",
    )
    parser.add_argument("--complexity", type=int, default=None, help="LEANN search complexity")
    parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
    parser.add_argument(
        "--max-queries", type=int, help="Limit number of queries to evaluate", default=1000
    )
    parser.add_argument(
        "--target-recall", type=float, default=0.90, help="Target Recall@3 for Stage 3"
    )
    parser.add_argument("--output", help="Save results to JSON file")
    parser.add_argument("--llm-backend", choices=["hf", "vllm"], default="hf", help="LLM backend")
    parser.add_argument("--model-name", default="Qwen/Qwen3-8B", help="Model name")

    args = parser.parse_args()

    # Resolve queries file: if default path not found, fall back to index's directory
    if not os.path.exists(args.queries):
        idx_dir = Path(args.index).parent
        fallback_q = idx_dir / "evaluation_queries.jsonl"
        if fallback_q.exists():
            args.queries = str(fallback_q)

    baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
    if not os.path.exists(baseline_index_path):
        print(f"❌ FAISS baseline not found at {baseline_index_path}")
        print("💡 Please run setup_enron_emails.py first to build the baseline")
        raise SystemExit(1)

    results_out: dict = {}

if args.stage in ("2", "all"):
|
|
print("🚀 Starting Stage 2: Recall@3 evaluation")
|
|
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
|
|
|
enron_eval = EnronEvaluator(args.index)
|
|
queries = enron_eval.load_queries(args.queries)
|
|
queries = queries[:10]
|
|
print(f"🧪 Using first {len(queries)} queries")
|
|
|
|
complexity = args.complexity or 64
|
|
r = evaluator.evaluate_recall_at_3(queries, complexity)
|
|
results_out["stage2"] = {"complexity": complexity, "recall_at_3": r}
|
|
evaluator.cleanup()
|
|
enron_eval.cleanup()
|
|
print("✅ Stage 2 completed!\n")
|
|
|

    if args.stage in ("3", "all"):
        print("🚀 Starting Stage 3: Binary search for target recall (no recompute)")
        enron_eval = EnronEvaluator(args.index)
        queries = enron_eval.load_queries(args.queries)
        queries = queries[: args.max_queries]
        print(f"🧪 Using first {len(queries)} queries")

        # Build non-compact index for fast binary search (recompute_embeddings=False)
        index_path = Path(args.index)
        non_compact_index_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
        enron_eval.create_non_compact_index_for_comparison(non_compact_index_path)

        # Use non-compact evaluator for binary search with recompute=False
        evaluator_nc = RecallEvaluator(non_compact_index_path, args.baseline_dir)
        sweep = enron_eval.evaluate_complexity(
            evaluator_nc, queries, target=args.target_recall, recompute=False
        )
        results_out["stage3"] = sweep
        # Persist default stage 3 results near the index for Stage 4 auto-pickup
        default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
        with open(default_stage3_path, "w", encoding="utf-8") as f:
            json.dump({"stage3": sweep}, f, indent=2)
        print(f"📝 Saved Stage 3 summary to {default_stage3_path}")
        evaluator_nc.cleanup()
        enron_eval.cleanup()
        print("✅ Stage 3 completed!\n")

    if args.stage in ("4", "all", "45"):
        print("🚀 Starting Stage 4: Index size + performance comparison")
        evaluator = RecallEvaluator(args.index, args.baseline_dir)
        enron_eval = EnronEvaluator(args.index)
        queries = enron_eval.load_queries(args.queries)
        test_q = queries[: min(args.max_queries, len(queries))]

        current_sizes = enron_eval.analyze_index_sizes()
        # Build non-compact index for comparison (no fallback)
        index_path = Path(args.index)
        non_compact_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
        non_compact_sizes = enron_eval.create_non_compact_index_for_comparison(non_compact_path)
        nc_eval = EnronEvaluator(non_compact_path)

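        # Storage saving (%) = 100 * (1 - compact_size / non_compact_size),
        # clamped at 0; both sizes cover the .index file only.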
        if (
            current_sizes.get("index_only_mb", 0) > 0
            and non_compact_sizes.get("index_only_mb", 0) > 0
        ):
            storage_saving_percent = max(
                0.0,
                100.0 * (1.0 - current_sizes["index_only_mb"] / non_compact_sizes["index_only_mb"]),
            )
        else:
            storage_saving_percent = 0.0

        if args.complexity is None:
            # Prefer in-session Stage 3 result
            if "stage3" in results_out and results_out["stage3"].get("best_complexity") is not None:
                complexity = results_out["stage3"]["best_complexity"]
                print(f"📥 Using best complexity from Stage 3 in-session: {complexity}")
            else:
                # Try to load last saved Stage 3 result near index
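                # The saved file mirrors what Stage 3 writes, e.g. (values illustrative):
                #   {"stage3": {"metrics": [...], "best_complexity": 104, "target_recall": 0.9}}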
                default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
                if default_stage3_path.exists():
                    with open(default_stage3_path, encoding="utf-8") as f:
                        prev = json.load(f)
                    complexity = prev.get("stage3", {}).get("best_complexity")
                    if complexity is None:
                        raise SystemExit(
                            "❌ Stage 4: No --complexity and no best_complexity found in saved Stage 3 results"
                        )
                    print(f"📥 Using best complexity from saved Stage 3: {complexity}")
                else:
                    raise SystemExit(
                        "❌ Stage 4 requires --complexity if Stage 3 hasn't been run. Run stage 3 first or pass --complexity."
                    )
        else:
            complexity = args.complexity

        comp = enron_eval.compare_index_performance(
            non_compact_path, args.index, test_q, complexity=complexity
        )
        results_out["stage4"] = {
            "current_index": current_sizes,
            "non_compact_index": non_compact_sizes,
            "storage_saving_percent": storage_saving_percent,
            "performance_comparison": comp,
        }
        nc_eval.cleanup()
        evaluator.cleanup()
        enron_eval.cleanup()
        print("✅ Stage 4 completed!\n")

    if args.stage in ("5", "all", "45"):
        print("🚀 Starting Stage 5: Generation evaluation with Qwen3-8B")

        # Check if Stage 4 results exist
        if "stage4" not in results_out or "performance_comparison" not in results_out["stage4"]:
            print("❌ Stage 5 requires Stage 4 retrieval results")
            print("💡 Run Stage 4 first or use --stage all")
            raise SystemExit(1)

        retrieval_results = results_out["stage4"]["performance_comparison"]["retrieval_results"]
        if not retrieval_results:
            print("❌ No retrieval results found from Stage 4")
            raise SystemExit(1)

        print(f"📁 Using {len(retrieval_results)} retrieval results from Stage 4")

        # Load LLM
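        # Both backends are wrapped behind the same `prompt -> response` callable
        # (llm_func), so the generation loop below is backend-agnostic.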
        try:
            if args.llm_backend == "hf":
                tokenizer, model = load_hf_model(args.model_name)

                def llm_func(prompt):
                    return generate_hf(tokenizer, model, prompt)
            else:  # vllm
                llm, sampling_params = load_vllm_model(args.model_name)

                def llm_func(prompt):
                    return generate_vllm(llm, sampling_params, prompt)

            # Run generation using stored retrieval results
            import time

            from ..llm_utils import create_prompt

            generation_times = []
            responses = []

            print("🤖 Running generation on pre-retrieved results...")
            for i, item in enumerate(retrieval_results):
                query = item["query"]
                retrieved_docs = item["retrieved_docs"]

                # Prepare context from retrieved docs
                context = "\n\n".join([doc["text"] for doc in retrieved_docs])
                prompt = create_prompt(context, query, "emails")

                # Time generation only
                gen_start = time.time()
                response = llm_func(prompt)
                gen_time = time.time() - gen_start

                generation_times.append(gen_time)
                responses.append(response)

                if i < 3:
                    print(f" Q{i + 1}: Gen={gen_time:.3f}s")

            avg_gen_time = sum(generation_times) / len(generation_times)

            print("\n📊 Generation Results:")
            print(f" Total Queries: {len(retrieval_results)}")
            print(f" Avg Generation Time: {avg_gen_time:.3f}s")
            print(" (Search time from Stage 4)")

            results_out["stage5"] = {
                "total_queries": len(retrieval_results),
                "avg_generation_time": avg_gen_time,
                "generation_times": generation_times,
                "responses": responses,
            }

            # Show sample results
            print("\n📝 Sample Results:")
            for i in range(min(3, len(retrieval_results))):
                query = retrieval_results[i]["query"]
                response = responses[i]
                print(f" Q{i + 1}: {query[:60]}...")
                print(f" A{i + 1}: {response[:100]}...")
                print()

        except Exception as e:
            print(f"❌ Generation evaluation failed: {e}")
            print("💡 Make sure transformers/vllm is installed and model is available")

        print("✅ Stage 5 completed!\n")

    if args.output and results_out:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(results_out, f, indent=2)
        print(f"📝 Saved results to {args.output}")


if __name__ == "__main__":
    main()