feat: fix financebench

This commit is contained in:
Andy Lee
2025-08-22 13:32:23 -07:00
parent 772894012e
commit 069bce558b
5 changed files with 1415 additions and 280 deletions

View File

File diff suppressed because it is too large Load Diff

View File

@@ -5,6 +5,7 @@ Downloads all PDFs and builds full LEANN datastore
"""
import argparse
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -139,14 +140,15 @@ class FinanceBenchSetup:
start_time = time.time()
# Initialize builder
# Initialize builder with standard compact configuration
builder = LeannBuilder(
backend_name=backend,
embedding_model=embedding_model,
embedding_mode="sentence-transformers",
graph_degree=32,
complexity=64,
is_recompute=False, # Store embeddings for speed
is_recompute=True, # Enable recompute (no stored embeddings)
is_compact=True, # Enable compact storage (pruned)
num_threads=4,
)
@@ -185,6 +187,87 @@ class FinanceBenchSetup:
return str(index_path)
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline"):
"""Build FAISS flat baseline using the same embeddings as LEANN index"""
print("🔨 Building FAISS Flat baseline...")
import os
import pickle
import numpy as np
from leann.api import compute_embeddings
from leann_backend_hnsw import faiss
os.makedirs(output_dir, exist_ok=True)
baseline_path = os.path.join(output_dir, "faiss_flat.index")
metadata_path = os.path.join(output_dir, "metadata.pkl")
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
print(f"✅ Baseline already exists at {baseline_path}")
return baseline_path
# Read metadata from the built index
meta_path = f"{index_path}.meta.json"
with open(meta_path) as f:
import json
meta = json.loads(f.read())
embedding_model = meta["embedding_model"]
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not os.path.isabs(passage_file):
index_dir = os.path.dirname(index_path)
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
print(f"📊 Loading passages from {passage_file}...")
print(f"🤖 Using embedding model: {embedding_model}")
# Load all passages for baseline
passages = []
passage_ids = []
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
passages.append(data["text"])
passage_ids.append(data["id"])
print(f"📄 Loaded {len(passages)} passages")
# Compute embeddings using the same method as LEANN
print("🧮 Computing embeddings...")
embeddings = compute_embeddings(
passages,
embedding_model,
mode="sentence-transformers",
use_server=False,
)
print(f"📐 Embedding shape: {embeddings.shape}")
# Build FAISS flat index
print("🏗️ Building FAISS IndexFlatIP...")
dimension = embeddings.shape[1]
index = faiss.IndexFlatIP(dimension)
# Add embeddings to flat index
embeddings_f32 = embeddings.astype(np.float32)
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
# Save index and metadata
faiss.write_index(index, baseline_path)
with open(metadata_path, "wb") as f:
pickle.dump(passage_ids, f)
print(f"✅ FAISS baseline saved to {baseline_path}")
print(f"✅ Metadata saved to {metadata_path}")
print(f"📊 Total vectors: {index.ntotal}")
return baseline_path
def extract_pdf_text(self, pdf_path: Path) -> list[dict]:
"""Extract and chunk text from a PDF file"""
chunks = []
@@ -300,6 +383,11 @@ def main():
parser.add_argument("--max-workers", type=int, default=5, help="Parallel download workers")
parser.add_argument("--skip-download", action="store_true", help="Skip PDF download")
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
parser.add_argument(
"--build-baseline-only",
action="store_true",
help="Only build FAISS baseline from existing index",
)
args = parser.parse_args()
@@ -309,6 +397,24 @@ def main():
setup = FinanceBenchSetup(args.data_dir)
try:
if args.build_baseline_only:
# Only build baseline from existing index
index_path = setup.index_dir / f"financebench_full_{args.backend}"
index_file = f"{index_path}.index"
meta_file = f"{index_path}.leann.meta.json"
if not os.path.exists(index_file) or not os.path.exists(meta_file):
print("❌ Index files not found:")
print(f" Index: {index_file}")
print(f" Meta: {meta_file}")
print("💡 Run without --build-baseline-only to build the index first")
exit(1)
print(f"🔨 Building baseline from existing index: {index_path}")
baseline_path = setup.build_faiss_flat_baseline(str(index_path))
print(f"✅ Baseline built at {baseline_path}")
return
# Step 1: Download dataset
setup.download_dataset()
@@ -324,7 +430,12 @@ def main():
backend=args.backend, embedding_model=args.embedding_model
)
# Step 4: Verify setup
# Step 4: Build FAISS flat baseline
print("\n🔨 Building FAISS flat baseline...")
baseline_path = setup.build_faiss_flat_baseline(index_path)
print(f"✅ Baseline built at {baseline_path}")
# Step 5: Verify setup
setup.verify_setup(index_path)
else:
print("⏭️ Skipping index building")

View File

@@ -0,0 +1,214 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.9"
# dependencies = [
# "faiss-cpu",
# "numpy",
# "sentence-transformers",
# "torch",
# "tqdm",
# ]
# ///
"""
Independent recall verification script using standard FAISS.
Creates two indexes (HNSW and Flat) and compares recall@3 at different complexities.
"""
import json
import time
from pathlib import Path
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
def compute_embeddings_direct(chunks: list[str], model_name: str) -> np.ndarray:
"""
Direct embedding computation using sentence-transformers.
Copied logic to avoid dependency issues.
"""
print(f"Loading model: {model_name}")
model = SentenceTransformer(model_name)
print(f"Computing embeddings for {len(chunks)} chunks...")
embeddings = model.encode(
chunks,
show_progress_bar=True,
batch_size=32,
convert_to_numpy=True,
normalize_embeddings=False,
)
return embeddings.astype(np.float32)
def load_financebench_queries(dataset_path: str, max_queries: int = 200) -> list[str]:
"""Load FinanceBench queries from dataset"""
queries = []
with open(dataset_path, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
queries.append(data["question"])
if len(queries) >= max_queries:
break
return queries
def load_passages_from_leann_index(index_path: str) -> tuple[list[str], list[str]]:
"""Load passages from LEANN index structure"""
meta_path = f"{index_path}.meta.json"
with open(meta_path) as f:
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
index_dir = Path(index_path).parent
passage_file = index_dir / Path(passage_file).name
print(f"Loading passages from {passage_file}")
passages = []
passage_ids = []
with open(passage_file, encoding="utf-8") as f:
for line in tqdm(f, desc="Loading passages"):
if line.strip():
data = json.loads(line)
passages.append(data["text"])
passage_ids.append(data["id"])
print(f"Loaded {len(passages)} passages")
return passages, passage_ids
def build_faiss_indexes(embeddings: np.ndarray) -> tuple[faiss.Index, faiss.Index]:
"""Build FAISS indexes: Flat (ground truth) and HNSW"""
dimension = embeddings.shape[1]
# Build Flat index (ground truth)
print("Building FAISS IndexFlatIP (ground truth)...")
flat_index = faiss.IndexFlatIP(dimension)
flat_index.add(embeddings)
# Build HNSW index
print("Building FAISS IndexHNSWFlat...")
M = 32 # Same as LEANN default
hnsw_index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_INNER_PRODUCT)
hnsw_index.hnsw.efConstruction = 200 # Same as LEANN default
hnsw_index.add(embeddings)
print(f"Built indexes with {flat_index.ntotal} vectors, dimension {dimension}")
return flat_index, hnsw_index
def evaluate_recall_at_k(
query_embeddings: np.ndarray,
flat_index: faiss.Index,
hnsw_index: faiss.Index,
passage_ids: list[str],
k: int = 3,
ef_search: int = 64,
) -> float:
"""Evaluate recall@k comparing HNSW vs Flat"""
# Set search parameters for HNSW
hnsw_index.hnsw.efSearch = ef_search
total_recall = 0.0
num_queries = query_embeddings.shape[0]
for i in range(num_queries):
query = query_embeddings[i : i + 1] # Keep 2D shape
# Get ground truth from Flat index (standard FAISS API)
flat_distances, flat_indices = flat_index.search(query, k)
ground_truth_ids = {passage_ids[idx] for idx in flat_indices[0]}
# Get results from HNSW index (standard FAISS API)
hnsw_distances, hnsw_indices = hnsw_index.search(query, k)
hnsw_ids = {passage_ids[idx] for idx in hnsw_indices[0]}
# Calculate recall
intersection = ground_truth_ids.intersection(hnsw_ids)
recall = len(intersection) / k
total_recall += recall
if i < 3: # Show first few examples
print(f" Query {i + 1}: Recall@{k} = {recall:.3f}")
print(f" Flat: {list(ground_truth_ids)}")
print(f" HNSW: {list(hnsw_ids)}")
print(f" Intersection: {list(intersection)}")
avg_recall = total_recall / num_queries
return avg_recall
def main():
# Configuration
dataset_path = "data/financebench_merged.jsonl"
index_path = "data/index/financebench_full_hnsw.leann"
embedding_model = "sentence-transformers/all-mpnet-base-v2"
print("🔍 FAISS Recall Verification")
print("=" * 50)
# Check if files exist
if not Path(dataset_path).exists():
print(f"❌ Dataset not found: {dataset_path}")
return
if not Path(f"{index_path}.meta.json").exists():
print(f"❌ Index metadata not found: {index_path}.meta.json")
return
# Load data
print("📖 Loading FinanceBench queries...")
queries = load_financebench_queries(dataset_path, max_queries=50)
print(f"Loaded {len(queries)} queries")
print("📄 Loading passages from LEANN index...")
passages, passage_ids = load_passages_from_leann_index(index_path)
# Compute embeddings
print("🧮 Computing passage embeddings...")
passage_embeddings = compute_embeddings_direct(passages, embedding_model)
print("🧮 Computing query embeddings...")
query_embeddings = compute_embeddings_direct(queries, embedding_model)
# Build FAISS indexes
print("🏗️ Building FAISS indexes...")
flat_index, hnsw_index = build_faiss_indexes(passage_embeddings)
# Test different efSearch values (equivalent to LEANN complexity)
print("\n📊 Evaluating Recall@3 at different efSearch values...")
ef_search_values = [16, 32, 64, 128, 256]
for ef_search in ef_search_values:
print(f"\n🧪 Testing efSearch = {ef_search}")
start_time = time.time()
recall = evaluate_recall_at_k(
query_embeddings, flat_index, hnsw_index, passage_ids, k=3, ef_search=ef_search
)
elapsed = time.time() - start_time
print(
f"📈 efSearch {ef_search}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%) in {elapsed:.2f}s"
)
print("\n✅ Verification completed!")
print("\n📋 Summary:")
print(" - Built independent FAISS Flat and HNSW indexes")
print(" - Compared recall@3 at different efSearch values")
print(" - Used same embedding model as LEANN")
print(" - This validates LEANN's recall measurements")
if __name__ == "__main__":
main()