docs: data updated
@@ -7,13 +7,22 @@ On errors, fail fast without fallbacks.
import argparse
import json
import logging
import os
import pickle
from pathlib import Path

import numpy as np
from leann import LeannBuilder, LeannSearcher
from leann_backend_hnsw import faiss

from ..llm_utils import generate_hf, generate_vllm, load_hf_model, load_vllm_model

# Setup logging to reduce verbose output
logging.basicConfig(level=logging.WARNING)
logging.getLogger("leann.api").setLevel(logging.WARNING)
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)


class RecallEvaluator:
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS)"""
@@ -119,7 +128,6 @@ class EnronEvaluator:
def analyze_index_sizes(self) -> dict:
"""Analyze index sizes (.index only), similar to LAION bench."""
from pathlib import Path

print("📏 Analyzing index sizes (.index only)...")
index_path = Path(self.index_path)
@@ -150,7 +158,6 @@ class EnronEvaluator:
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
"""Create a non-compact index for comparison using current passages and embeddings."""
from pathlib import Path

current_index_path = Path(self.index_path)
current_index_dir = current_index_path.parent
@@ -230,6 +237,7 @@ class EnronEvaluator:
"compact": {"search_times": []},
"avg_search_times": {},
"speed_ratio": 0.0,
"retrieval_results": [],  # Store retrieval results for Stage 5
}

print("⚡ Comparing search performance between indexes...")
@@ -248,10 +256,15 @@ class EnronEvaluator:
compact_searcher = LeannSearcher(compact_path)
for q in test_queries:
t0 = time.time()
_ = compact_searcher.search(
docs = compact_searcher.search(
q, top_k=3, complexity=complexity, recompute_embeddings=True
)
results["compact"]["search_times"].append(time.time() - t0)

# Store retrieval results for Stage 5
results["retrieval_results"].append(
{"query": q, "retrieved_docs": [{"id": doc.id, "text": doc.text} for doc in docs]}
)
compact_searcher.cleanup()

if results["non_compact"]["search_times"]:
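Note: each record appended to results["retrieval_results"] above has the shape sketched below. The values are made up for illustration; only the key layout and the doc.id / doc.text attributes come from the diff.

entry = {
    "query": "example query",
    "retrieved_docs": [
        {"id": "42", "text": "text of the first retrieved passage"},
        {"id": "7", "text": "text of the second retrieved passage"},
    ],
}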
@@ -358,9 +371,9 @@ def main():
)
parser.add_argument(
"--stage",
choices=["2", "3", "4", "all"],
choices=["2", "3", "4", "5", "all", "45"],
default="all",
help="Which stage to run (2=recall, 3=complexity, 4=index comparison)",
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
)
parser.add_argument("--complexity", type=int, default=None, help="LEANN search complexity")
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
@@ -371,6 +384,8 @@ def main():
"--target-recall", type=float, default=0.90, help="Target Recall@3 for Stage 3"
)
parser.add_argument("--output", help="Save results to JSON file")
parser.add_argument("--llm-backend", choices=["hf", "vllm"], default="hf", help="LLM backend")
parser.add_argument("--model-name", default="Qwen/Qwen3-8B", help="Model name")

args = parser.parse_args()
@@ -438,7 +453,7 @@ def main():
enron_eval.cleanup()
print("✅ Stage 3 completed!\n")

if args.stage in ("4", "all"):
if args.stage in ("4", "all", "45"):
print("🚀 Starting Stage 4: Index size + performance comparison")
evaluator = RecallEvaluator(args.index, args.baseline_dir)
enron_eval = EnronEvaluator(args.index)
@@ -503,6 +518,92 @@ def main():
enron_eval.cleanup()
print("✅ Stage 4 completed!\n")

if args.stage in ("5", "all"):
print("🚀 Starting Stage 5: Generation evaluation with Qwen3-8B")

# Check if Stage 4 results exist
if "stage4" not in results_out or "performance_comparison" not in results_out["stage4"]:
print("❌ Stage 5 requires Stage 4 retrieval results")
print("💡 Run Stage 4 first or use --stage all")
raise SystemExit(1)

retrieval_results = results_out["stage4"]["performance_comparison"]["retrieval_results"]
if not retrieval_results:
print("❌ No retrieval results found from Stage 4")
raise SystemExit(1)

print(f"📁 Using {len(retrieval_results)} retrieval results from Stage 4")

# Load LLM
try:
if args.llm_backend == "hf":
tokenizer, model = load_hf_model(args.model_name)

def llm_func(prompt):
return generate_hf(tokenizer, model, prompt)
else:  # vllm
llm, sampling_params = load_vllm_model(args.model_name)

def llm_func(prompt):
return generate_vllm(llm, sampling_params, prompt)

# Run generation using stored retrieval results
import time
from ..llm_utils import create_prompt
generation_times = []
responses = []

print("🤖 Running generation on pre-retrieved results...")
for i, item in enumerate(retrieval_results):
query = item["query"]
retrieved_docs = item["retrieved_docs"]

# Prepare context from retrieved docs
context = "\n\n".join([doc["text"] for doc in retrieved_docs])
prompt = create_prompt(context, query, "emails")

# Time generation only
gen_start = time.time()
response = llm_func(prompt)
gen_time = time.time() - gen_start

generation_times.append(gen_time)
responses.append(response)

if i < 3:
print(f" Q{i + 1}: Gen={gen_time:.3f}s")

avg_gen_time = sum(generation_times) / len(generation_times)

print("\n📊 Generation Results:")
print(f" Total Queries: {len(retrieval_results)}")
print(f" Avg Generation Time: {avg_gen_time:.3f}s")
print(" (Search time from Stage 4)")

results_out["stage5"] = {
"total_queries": len(retrieval_results),
"avg_generation_time": avg_gen_time,
"generation_times": generation_times,
"responses": responses,
}

# Show sample results
print("\n📝 Sample Results:")
for i in range(min(3, len(retrieval_results))):
query = retrieval_results[i]["query"]
response = responses[i]
print(f" Q{i + 1}: {query[:60]}...")
print(f" A{i + 1}: {response[:100]}...")
print()

except Exception as e:
print(f"❌ Generation evaluation failed: {e}")
print("💡 Make sure transformers/vllm is installed and model is available")

print("✅ Stage 5 completed!\n")

if args.output and results_out:
with open(args.output, "w", encoding="utf-8") as f:
json.dump(results_out, f, indent=2)
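Note: the Stage 5 block added above replays the records stored during Stage 4 so that only generation is timed (search time was already measured in Stage 4). Below is a minimal self-contained sketch of that replay loop, using made-up retrieval data and a stub in place of the generate_hf/generate_vllm helpers, with create_prompt approximated by a plain f-string.

import time


def llm_func(prompt: str) -> str:
    # Stub standing in for the HF/vLLM generate helpers used in the script.
    return "stub answer"


# Pre-retrieved results in the same shape Stage 4 stores (values made up).
retrieval_results = [
    {"query": "example question", "retrieved_docs": [{"id": "1", "text": "passage text"}]},
]

generation_times = []
for item in retrieval_results:
    context = "\n\n".join(doc["text"] for doc in item["retrieved_docs"])
    prompt = f"Context:\n{context}\n\nQuestion: {item['query']}"  # stand-in for create_prompt(...)
    t0 = time.time()
    _ = llm_func(prompt)
    generation_times.append(time.time() - t0)

print(f"Avg generation time: {sum(generation_times) / len(generation_times):.3f}s")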