feat: finance bench
82  benchmarks/data/.gitattributes  vendored
@@ -1,82 +0,0 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.lz4 filter=lfs diff=lfs merge=lfs -text
*.mds filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
# Audio files - uncompressed
*.pcm filter=lfs diff=lfs merge=lfs -text
*.sam filter=lfs diff=lfs merge=lfs -text
*.raw filter=lfs diff=lfs merge=lfs -text
# Audio files - compressed
*.aac filter=lfs diff=lfs merge=lfs -text
*.flac filter=lfs diff=lfs merge=lfs -text
*.mp3 filter=lfs diff=lfs merge=lfs -text
*.ogg filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text
# Image files - uncompressed
*.bmp filter=lfs diff=lfs merge=lfs -text
*.gif filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.tiff filter=lfs diff=lfs merge=lfs -text
# Image files - compressed
*.jpg filter=lfs diff=lfs merge=lfs -text
*.jpeg filter=lfs diff=lfs merge=lfs -text
*.webp filter=lfs diff=lfs merge=lfs -text
# Video files - compressed
*.mp4 filter=lfs diff=lfs merge=lfs -text
*.webm filter=lfs diff=lfs merge=lfs -text
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text
84  benchmarks/financebench/README.md  Normal file
@@ -0,0 +1,84 @@
# FinanceBench Benchmark for LEANN-RAG

FinanceBench is a benchmark for evaluating retrieval-augmented generation (RAG) systems on financial document question-answering tasks.

## Dataset

- **Source**: [PatronusAI/financebench](https://huggingface.co/datasets/PatronusAI/financebench)
- **Questions**: 150 financial Q&A examples
- **Documents**: 368 PDF files (10-K, 10-Q, 8-K, earnings reports)
- **Companies**: Major public companies (3M, Apple, Microsoft, Amazon, etc.)
- **Paper**: [FinanceBench: A New Benchmark for Financial Question Answering](https://arxiv.org/abs/2311.11944)

## Structure

```
benchmarks/financebench/
├── setup_financebench.py        # Downloads PDFs and builds index
├── evaluate_financebench.py     # Intelligent evaluation script
├── data/
│   ├── financebench_merged.jsonl  # Q&A dataset
│   ├── pdfs/                      # Downloaded financial documents
│   └── index/                     # LEANN indexes
│       └── financebench_full_hnsw.leann
└── README.md
```

## Usage

### 1. Setup (Download & Build Index)

```bash
cd benchmarks/financebench
python setup_financebench.py
```

This will:
- Download the 150 Q&A examples
- Download all 368 PDF documents (parallel processing)
- Build a LEANN index from 53K+ text chunks
- Verify the setup with a test query

### 2. Evaluation

```bash
# Basic retrieval evaluation
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann

# Include QA evaluation with OpenAI
export OPENAI_API_KEY="your-key"
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann --qa-samples 20
```

## Evaluation Methods

### Retrieval Evaluation
Uses intelligent matching with three strategies (see the sketch below):
1. **Exact text overlap** - Direct substring matches
2. **Number matching** - Key financial figures ($1,577, 1.2B, etc.)
3. **Semantic similarity** - Word overlap with a 20% threshold

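The logic behind the three strategies, condensed from the matching helpers in `evaluate_financebench.py` (a sketch, not the exact code):

```python
import re

def exact_overlap(evidence: str, retrieved: str) -> bool:
    # Strategy 1: direct substring match in either direction.
    return evidence.lower() in retrieved.lower() or retrieved.lower() in evidence.lower()

def number_match(evidence: str, retrieved: str, answer: str) -> bool:
    # Strategy 2: a key figure from the evidence or gold answer (e.g. "$1,577", "1.2")
    # also appears in the retrieved chunk.
    pattern = r"\$?[\d,]+\.?\d*"
    key_numbers = set(re.findall(pattern, evidence)) | set(re.findall(pattern, answer))
    return bool(key_numbers & set(re.findall(pattern, retrieved)))

def semantic_similarity(evidence: str, retrieved: str, threshold: float = 0.2) -> bool:
    # Strategy 3: fraction of evidence words that reappear in the retrieved chunk.
    words = set(evidence.lower().split())
    return bool(words) and len(words & set(retrieved.lower().split())) / len(words) >= threshold
```
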
### QA Evaluation
LLM-based answer evaluation using GPT-4o:
- Handles numerical rounding and equivalent representations
- Considers fractions, percentages, and decimal equivalents
- Evaluates semantic meaning rather than exact text match

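At its core the grader is a single GPT-4o call that returns a boolean verdict; a condensed sketch of `_evaluate_answer_with_llm` (the full prompt in `evaluate_financebench.py` spells out the equivalence rules):

```python
from openai import OpenAI

def grade_answer(client: OpenAI, question: str, golden: str, generated: str) -> bool:
    # Ask GPT-4o for a True/False verdict on whether the generated answer
    # matches the golden answer (rounding and equivalent forms allowed).
    prompt = (
        "Decide whether the AI-generated answer correctly answers the question, "
        "judged against the golden answer. Ignore rounding differences and treat "
        "fractions, percentages, and decimals as equivalent.\n\n"
        f"Question: {question}\n"
        f"AI-Generated Answer: {generated}\n"
        f"Golden Answer: {golden}\n\n"
        "Answer ONLY `True` or `False`."
    )
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        max_tokens=10,
    )
    return "true" in response.choices[0].message.content.strip().lower()
```
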
## Expected Results

Previous runs show:
- **Question Coverage**: ~65-75% (questions with relevant docs retrieved)
- **Index Size**: 53,985 chunks from 368 PDFs
- **Search Time**: ~0.1-0.2s per query
- **Build Time**: ~5-10 minutes with sentence-transformers/all-mpnet-base-v2

## Options

```bash
# Use different backends
python setup_financebench.py --backend diskann
python evaluate_financebench.py --index data/index/financebench_full_diskann.leann

# Use different embedding models
python setup_financebench.py --embedding-model facebook/contriever
```

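The evaluator can also be driven programmatically rather than via the CLI; a minimal sketch, assuming it is run from `benchmarks/financebench/` so the module is importable:

```python
import os

from evaluate_financebench import FinanceBenchEvaluator

# Without an API key only retrieval metrics are computed; with one, QA grading runs too.
evaluator = FinanceBenchEvaluator(
    "data/index/financebench_full_hnsw.leann",
    openai_api_key=os.getenv("OPENAI_API_KEY"),
)
results = evaluator.run_evaluation(top_k=10, qa_samples=20)
evaluator.cleanup()
```
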
432  benchmarks/financebench/evaluate_financebench.py  Executable file
@@ -0,0 +1,432 @@
#!/usr/bin/env python3
"""
FinanceBench Evaluation Script
Uses intelligent evaluation similar to the VectifyAI/Mafin approach
"""

import argparse
import json
import os
import re
import time
from typing import Optional

import openai
from leann import LeannChat, LeannSearcher


class FinanceBenchEvaluator:
    def __init__(self, index_path: str, openai_api_key: Optional[str] = None):
        self.index_path = index_path
        self.openai_client = None

        if openai_api_key:
            self.openai_client = openai.OpenAI(api_key=openai_api_key)

        # Load LEANN
        self.searcher = LeannSearcher(index_path)
        self.chat = LeannChat(index_path) if openai_api_key else None

    def load_dataset(self, dataset_path: str = "data/financebench_merged.jsonl"):
        """Load the FinanceBench dataset"""
        data = []
        with open(dataset_path, encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    data.append(json.loads(line))

        print(f"📊 Loaded {len(data)} FinanceBench examples")
        return data

    def evaluate_retrieval_intelligent(self, data: list[dict], top_k: int = 10) -> dict:
        """
        Intelligent retrieval evaluation
        Uses semantic similarity instead of strict word overlap
        """
        print(f"🔍 Evaluating retrieval performance (top_k={top_k})...")

        metrics = {
            "total_questions": 0,
            "questions_with_relevant_retrieved": 0,
            "exact_matches": 0,
            "semantic_matches": 0,
            "number_matches": 0,
            "search_times": [],
            "detailed_results": [],
        }

        for item in data:
            question = item["question"]
            evidence_texts = [ev["evidence_text"] for ev in item.get("evidence", [])]
            expected_answer = item["answer"]

            if not evidence_texts:
                continue

            metrics["total_questions"] += 1

            # Search for relevant documents
            start_time = time.time()
            results = self.searcher.search(question, top_k=top_k, complexity=64)
            search_time = time.time() - start_time
            metrics["search_times"].append(search_time)

            # Evaluate retrieved results
            found_relevant = False
            match_types = []

            for evidence_text in evidence_texts:
                for i, result in enumerate(results):
                    retrieved_text = result.text

                    # Method 1: Exact substring match
                    if self._has_exact_overlap(evidence_text, retrieved_text):
                        found_relevant = True
                        match_types.append(f"exact_match_rank_{i + 1}")
                        metrics["exact_matches"] += 1
                        break

                    # Method 2: Key numbers match
                    elif self._has_number_match(evidence_text, retrieved_text, expected_answer):
                        found_relevant = True
                        match_types.append(f"number_match_rank_{i + 1}")
                        metrics["number_matches"] += 1
                        break

                    # Method 3: Semantic similarity (word overlap with lower threshold)
                    elif self._has_semantic_similarity(
                        evidence_text, retrieved_text, threshold=0.2
                    ):
                        found_relevant = True
                        match_types.append(f"semantic_match_rank_{i + 1}")
                        metrics["semantic_matches"] += 1
                        break

            if found_relevant:
                metrics["questions_with_relevant_retrieved"] += 1

            # Store detailed result
            metrics["detailed_results"].append(
                {
                    "question": question,
                    "expected_answer": expected_answer,
                    "found_relevant": found_relevant,
                    "match_types": match_types,
                    "search_time": search_time,
                    "top_results": [
                        {"text": r.text[:200] + "...", "score": r.score, "metadata": r.metadata}
                        for r in results[:3]
                    ],
                }
            )

        # Calculate metrics
        if metrics["total_questions"] > 0:
            metrics["question_coverage"] = (
                metrics["questions_with_relevant_retrieved"] / metrics["total_questions"]
            )
            metrics["avg_search_time"] = sum(metrics["search_times"]) / len(metrics["search_times"])

            # Match type breakdown
            metrics["exact_match_rate"] = metrics["exact_matches"] / metrics["total_questions"]
            metrics["number_match_rate"] = metrics["number_matches"] / metrics["total_questions"]
            metrics["semantic_match_rate"] = (
                metrics["semantic_matches"] / metrics["total_questions"]
            )

        return metrics

    def evaluate_qa_intelligent(self, data: list[dict], max_samples: Optional[int] = None) -> dict:
        """
        Intelligent QA evaluation using LLM-based answer comparison
        Similar to VectifyAI/Mafin approach
        """
        if not self.chat or not self.openai_client:
            print("⚠️ Skipping QA evaluation (no OpenAI API key provided)")
            return {"accuracy": 0.0, "total_questions": 0}

        print("🤖 Evaluating QA performance...")

        if max_samples:
            data = data[:max_samples]
            print(f"📝 Using first {max_samples} samples for QA evaluation")

        results = []
        correct_answers = 0

        for i, item in enumerate(data):
            question = item["question"]
            expected_answer = item["answer"]

            print(f"Question {i + 1}/{len(data)}: {question[:80]}...")

            try:
                # Get answer from LEANN
                start_time = time.time()
                generated_answer = self.chat.ask(question)
                qa_time = time.time() - start_time

                # Intelligent evaluation using LLM
                is_correct = self._evaluate_answer_with_llm(
                    question, expected_answer, generated_answer
                )

                if is_correct:
                    correct_answers += 1

                results.append(
                    {
                        "question": question,
                        "expected_answer": expected_answer,
                        "generated_answer": generated_answer,
                        "is_correct": is_correct,
                        "qa_time": qa_time,
                    }
                )

                print(f" ✅ {'Correct' if is_correct else '❌ Incorrect'}")

            except Exception as e:
                print(f" ❌ Error: {e}")
                results.append(
                    {
                        "question": question,
                        "expected_answer": expected_answer,
                        "generated_answer": f"ERROR: {e}",
                        "is_correct": False,
                        "qa_time": 0.0,
                    }
                )

        metrics = {
            "total_questions": len(data),
            "correct_answers": correct_answers,
            "accuracy": correct_answers / len(data) if data else 0.0,
            "avg_qa_time": sum(r["qa_time"] for r in results) / len(results) if results else 0.0,
            "detailed_results": results,
        }

        return metrics

    def _has_exact_overlap(self, evidence_text: str, retrieved_text: str) -> bool:
        """Check for exact substring overlap"""
        # Check if evidence is contained in retrieved text or vice versa
        return (
            evidence_text.lower() in retrieved_text.lower()
            or retrieved_text.lower() in evidence_text.lower()
        )

    def _has_number_match(
        self, evidence_text: str, retrieved_text: str, expected_answer: str
    ) -> bool:
        """Check if key numbers from evidence/answer appear in retrieved text"""
        # Extract numbers from evidence and expected answer
        evidence_numbers = set(re.findall(r"\$?[\d,]+\.?\d*", evidence_text))
        answer_numbers = set(re.findall(r"\$?[\d,]+\.?\d*", expected_answer))
        retrieved_numbers = set(re.findall(r"\$?[\d,]+\.?\d*", retrieved_text))

        # Check if any key numbers match
        key_numbers = evidence_numbers.union(answer_numbers)
        return bool(key_numbers.intersection(retrieved_numbers))

    def _has_semantic_similarity(
        self, evidence_text: str, retrieved_text: str, threshold: float = 0.2
    ) -> bool:
        """Check semantic similarity using word overlap"""
        words1 = set(evidence_text.lower().split())
        words2 = set(retrieved_text.lower().split())

        if len(words1) == 0:
            return False

        overlap = len(words1.intersection(words2))
        similarity = overlap / len(words1)

        return similarity >= threshold

    def _evaluate_answer_with_llm(
        self, question: str, expected_answer: str, generated_answer: str
    ) -> bool:
        """
        Use an LLM to evaluate answer equivalence
        Based on the VectifyAI/Mafin approach
        """
        prompt = f"""You are an expert evaluator for AI-generated responses to financial questions. Your task is to determine whether the AI-generated answer correctly answers the question based on the golden answer provided by a human expert.

Evaluation Criteria:
- Numerical Accuracy: Rounding differences should be ignored if they don't meaningfully change the conclusion. Numbers like 1.2 and 1.23 are considered similar.
- Fractions, percentages, and decimals could be considered similar. For example: "11 of 14" ≈ "79%" ≈ "0.79".
- If the golden answer or any of its equivalents can be inferred from the AI answer, then the AI answer is correct.
- The AI answer is correct if it conveys the same meaning, conclusion, or rationale as the golden answer.
- If the AI answer is a superset of the golden answer, it is also considered correct.
- Subjective judgments are correct as long as they are reasonable and justifiable.

Question: {question}
AI-Generated Answer: {generated_answer}
Golden Answer: {expected_answer}

Your output should be ONLY a boolean value: `True` or `False`, nothing else."""

        try:
            response = self.openai_client.chat.completions.create(
                model="gpt-4o",
                messages=[{"role": "user", "content": prompt}],
                temperature=0,
                max_tokens=10,
            )

            result = response.choices[0].message.content.strip().lower()
            return "true" in result

        except Exception as e:
            print(f"LLM evaluation error: {e}")
            # Fall back to simple number matching
            return self._simple_answer_check(expected_answer, generated_answer)

    def _simple_answer_check(self, expected: str, generated: str) -> bool:
        """Simple fallback evaluation"""
        # Extract numbers and check for matches
        expected_numbers = re.findall(r"\$?[\d,]+\.?\d*", expected.lower())
        generated_numbers = re.findall(r"\$?[\d,]+\.?\d*", generated.lower())

        # Check if main numbers match
        for exp_num in expected_numbers:
            if exp_num in generated_numbers:
                return True

        # Check for key phrase overlap
        key_words = set(expected.lower().split()) - {
            "the",
            "a",
            "an",
            "is",
            "are",
            "was",
            "were",
            "for",
            "of",
            "in",
            "on",
            "at",
            "to",
            "and",
            "or",
            "but",
        }
        gen_words = set(generated.lower().split())

        if len(key_words) > 0:
            overlap = len(key_words.intersection(gen_words))
            return overlap / len(key_words) >= 0.3

        return False

    def run_evaluation(
        self,
        dataset_path: str = "data/financebench_merged.jsonl",
        top_k: int = 10,
        qa_samples: Optional[int] = None,
    ) -> dict:
        """Run complete FinanceBench evaluation"""
        print("🏦 FinanceBench Evaluation with LEANN")
        print("=" * 50)
        print(f"📁 Index: {self.index_path}")
        print(f"🔍 Top-k: {top_k}")
        if qa_samples:
            print(f"🤖 QA samples: {qa_samples}")
        print()

        # Load dataset
        data = self.load_dataset(dataset_path)

        # Run retrieval evaluation
        retrieval_metrics = self.evaluate_retrieval_intelligent(data, top_k=top_k)

        # Run QA evaluation
        qa_metrics = self.evaluate_qa_intelligent(data, max_samples=qa_samples)

        # Print results
        self._print_results(retrieval_metrics, qa_metrics)

        return {
            "retrieval": retrieval_metrics,
            "qa": qa_metrics,
        }

    def _print_results(self, retrieval_metrics: dict, qa_metrics: dict):
        """Print evaluation results"""
        print("\n🎯 EVALUATION RESULTS")
        print("=" * 50)

        print("\n📊 Retrieval Metrics:")
        print(f" Question Coverage: {retrieval_metrics.get('question_coverage', 0):.1%}")
        print(f" Exact Match Rate: {retrieval_metrics.get('exact_match_rate', 0):.1%}")
        print(f" Number Match Rate: {retrieval_metrics.get('number_match_rate', 0):.1%}")
        print(f" Semantic Match Rate: {retrieval_metrics.get('semantic_match_rate', 0):.1%}")
        print(f" Avg Search Time: {retrieval_metrics.get('avg_search_time', 0):.3f}s")

        if qa_metrics.get("total_questions", 0) > 0:
            print("\n🤖 QA Metrics:")
            print(f" Accuracy: {qa_metrics.get('accuracy', 0):.1%}")
            print(f" Questions Evaluated: {qa_metrics.get('total_questions', 0)}")
            print(f" Avg QA Time: {qa_metrics.get('avg_qa_time', 0):.3f}s")

        # Show some example results
        print("\n📝 Example Results:")
        for i, result in enumerate(retrieval_metrics.get("detailed_results", [])[:3]):
            print(f"\n Example {i + 1}:")
            print(f" Q: {result['question'][:80]}...")
            print(f" Found relevant: {'✅' if result['found_relevant'] else '❌'}")
            if result["match_types"]:
                print(f" Match types: {', '.join(result['match_types'])}")

    def cleanup(self):
        """Cleanup resources"""
        if self.searcher:
            self.searcher.cleanup()


def main():
    parser = argparse.ArgumentParser(description="Evaluate FinanceBench with LEANN")
    parser.add_argument("--index", required=True, help="Path to LEANN index")
    parser.add_argument("--dataset", default="data/financebench_merged.jsonl", help="Dataset path")
    parser.add_argument("--top-k", type=int, default=10, help="Number of documents to retrieve")
    parser.add_argument("--qa-samples", type=int, default=None, help="Limit QA evaluation samples")
    parser.add_argument("--openai-api-key", help="OpenAI API key for QA evaluation")
    parser.add_argument("--output", help="Save results to JSON file")

    args = parser.parse_args()

    # Get OpenAI API key
    api_key = args.openai_api_key or os.getenv("OPENAI_API_KEY")
    if not api_key and args.qa_samples != 0:
        print("⚠️ No OpenAI API key provided. QA evaluation will be skipped.")
        print(" Set OPENAI_API_KEY environment variable or use --openai-api-key")

    try:
        # Run evaluation
        evaluator = FinanceBenchEvaluator(args.index, api_key)
        results = evaluator.run_evaluation(
            dataset_path=args.dataset, top_k=args.top_k, qa_samples=args.qa_samples
        )

        # Save results if requested
        if args.output:
            with open(args.output, "w") as f:
                json.dump(results, f, indent=2, default=str)
            print(f"\n💾 Results saved to {args.output}")

        evaluator.cleanup()

        print("\n✅ Evaluation completed!")

    except KeyboardInterrupt:
        print("\n⚠️ Evaluation interrupted by user")
        exit(1)
    except Exception as e:
        print(f"\n❌ Evaluation failed: {e}")
        exit(1)


if __name__ == "__main__":
    main()
351  benchmarks/financebench/setup_financebench.py  Executable file
@@ -0,0 +1,351 @@
#!/usr/bin/env python3
"""
FinanceBench Complete Setup Script
Downloads all PDFs and builds full LEANN datastore
"""

import argparse
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from threading import Lock

import pymupdf
import requests
from leann import LeannBuilder, LeannSearcher
from tqdm import tqdm


class FinanceBenchSetup:
    def __init__(self, data_dir: str = "data"):
        self.base_dir = Path(__file__).parent  # benchmarks/financebench/
        self.data_dir = self.base_dir / data_dir
        self.pdf_dir = self.data_dir / "pdfs"
        self.dataset_file = self.data_dir / "financebench_merged.jsonl"
        self.index_dir = self.data_dir / "index"
        self.download_lock = Lock()

    def download_dataset(self):
        """Download the main FinanceBench dataset"""
        print("📊 Downloading FinanceBench dataset...")
        self.data_dir.mkdir(parents=True, exist_ok=True)

        if self.dataset_file.exists():
            print(f"✅ Dataset already exists: {self.dataset_file}")
            return

        url = "https://huggingface.co/datasets/PatronusAI/financebench/raw/main/financebench_merged.jsonl"
        response = requests.get(url, stream=True)
        response.raise_for_status()

        with open(self.dataset_file, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)

        print(f"✅ Dataset downloaded: {self.dataset_file}")

    def get_pdf_list(self):
        """Get list of all PDF files from GitHub"""
        print("📋 Fetching PDF list from GitHub...")

        response = requests.get(
            "https://api.github.com/repos/patronus-ai/financebench/contents/pdfs"
        )
        response.raise_for_status()
        pdf_files = response.json()

        print(f"Found {len(pdf_files)} PDF files")
        return pdf_files

    def download_single_pdf(self, pdf_info, position):
        """Download a single PDF file"""
        pdf_name = pdf_info["name"]
        pdf_path = self.pdf_dir / pdf_name

        # Skip if already downloaded
        if pdf_path.exists() and pdf_path.stat().st_size > 0:
            return f"✅ {pdf_name} (cached)"

        try:
            # Download PDF
            response = requests.get(pdf_info["download_url"], timeout=60)
            response.raise_for_status()

            # Write to file
            with self.download_lock:
                with open(pdf_path, "wb") as f:
                    f.write(response.content)

            return f"✅ {pdf_name} ({len(response.content) // 1024}KB)"

        except Exception as e:
            return f"❌ {pdf_name}: {e!s}"

    def download_all_pdfs(self, max_workers: int = 5):
        """Download all PDF files with parallel processing"""
        self.pdf_dir.mkdir(parents=True, exist_ok=True)

        pdf_files = self.get_pdf_list()

        print(f"📥 Downloading {len(pdf_files)} PDFs with {max_workers} workers...")

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit all download tasks
            future_to_pdf = {
                executor.submit(self.download_single_pdf, pdf_info, i): pdf_info["name"]
                for i, pdf_info in enumerate(pdf_files)
            }

            # Process completed downloads with progress bar
            with tqdm(total=len(pdf_files), desc="Downloading PDFs") as pbar:
                for future in as_completed(future_to_pdf):
                    result = future.result()
                    pbar.set_postfix_str(result.split()[-1] if "✅" in result else "Error")
                    pbar.update(1)

        # Verify downloads
        downloaded_pdfs = list(self.pdf_dir.glob("*.pdf"))
        print(f"✅ Successfully downloaded {len(downloaded_pdfs)}/{len(pdf_files)} PDFs")

        # Show any failures
        missing_pdfs = []
        for pdf_info in pdf_files:
            pdf_path = self.pdf_dir / pdf_info["name"]
            if not pdf_path.exists() or pdf_path.stat().st_size == 0:
                missing_pdfs.append(pdf_info["name"])

        if missing_pdfs:
            print(f"⚠️ Failed to download {len(missing_pdfs)} PDFs:")
            for pdf in missing_pdfs[:5]:  # Show first 5
                print(f" - {pdf}")
            if len(missing_pdfs) > 5:
                print(f" ... and {len(missing_pdfs) - 5} more")

    def build_leann_index(
        self,
        backend: str = "hnsw",
        embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
    ):
        """Build LEANN index from all PDFs"""
        print(f"🏗️ Building LEANN index with {backend} backend...")

        # Check if we have PDFs
        pdf_files = list(self.pdf_dir.glob("*.pdf"))
        if not pdf_files:
            raise RuntimeError("No PDF files found! Run download first.")

        print(f"Found {len(pdf_files)} PDF files to process")

        start_time = time.time()

        # Initialize builder
        builder = LeannBuilder(
            backend_name=backend,
            embedding_model=embedding_model,
            embedding_mode="sentence-transformers",
            graph_degree=32,
            complexity=64,
            is_recompute=False,  # Store embeddings for speed
            num_threads=4,
        )

        # Process PDFs and extract text
        total_chunks = 0
        failed_pdfs = []

        for pdf_path in tqdm(pdf_files, desc="Processing PDFs"):
            try:
                chunks = self.extract_pdf_text(pdf_path)
                for chunk in chunks:
                    builder.add_text(chunk["text"], metadata=chunk["metadata"])
                    total_chunks += 1

            except Exception as e:
                print(f"❌ Failed to process {pdf_path.name}: {e}")
                failed_pdfs.append(pdf_path.name)
                continue

        # Build index in index directory
        self.index_dir.mkdir(parents=True, exist_ok=True)
        index_path = self.index_dir / f"financebench_full_{backend}.leann"
        print(f"🔨 Building index: {index_path}")
        builder.build_index(str(index_path))

        build_time = time.time() - start_time

        print("✅ Index built successfully!")
        print(f" 📁 Index path: {index_path}")
        print(f" 📊 Total chunks: {total_chunks:,}")
        print(f" 📄 Processed PDFs: {len(pdf_files) - len(failed_pdfs)}/{len(pdf_files)}")
        print(f" ⏱️ Build time: {build_time:.1f}s")

        if failed_pdfs:
            print(f" ⚠️ Failed PDFs: {failed_pdfs}")

        return str(index_path)

    def extract_pdf_text(self, pdf_path: Path) -> list[dict]:
        """Extract and chunk text from a PDF file"""
        chunks = []
        doc = pymupdf.open(pdf_path)

        for page_num in range(len(doc)):
            page = doc[page_num]
            text = page.get_text()  # type: ignore

            if not text.strip():
                continue

            # Create metadata
            metadata = {
                "source_file": pdf_path.name,
                "page_number": page_num + 1,
                "document_type": "10K" if "10K" in pdf_path.name else "10Q",
                "company": pdf_path.name.split("_")[0],
                "doc_period": self.extract_year_from_filename(pdf_path.name),
            }

            # Use recursive character splitting like LangChain
            if len(text.split()) > 500:
                # Split by double newlines (paragraphs)
                paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]

                current_chunk = ""
                for para in paragraphs:
                    # If adding this paragraph would make chunk too long, save current chunk
                    if current_chunk and len((current_chunk + " " + para).split()) > 300:
                        if current_chunk.strip():
                            chunks.append(
                                {
                                    "text": current_chunk.strip(),
                                    "metadata": {
                                        **metadata,
                                        "chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
                                    },
                                }
                            )
                        current_chunk = para
                    else:
                        current_chunk = (current_chunk + " " + para).strip()

                # Add the last chunk
                if current_chunk.strip():
                    chunks.append(
                        {
                            "text": current_chunk.strip(),
                            "metadata": {
                                **metadata,
                                "chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
                            },
                        }
                    )
            else:
                # Page is short enough, use as single chunk
                chunks.append(
                    {
                        "text": text.strip(),
                        "metadata": {**metadata, "chunk_id": f"page_{page_num + 1}"},
                    }
                )

        doc.close()
        return chunks

    def extract_year_from_filename(self, filename: str) -> str:
        """Extract year from PDF filename"""
        # Try to find 4-digit year in filename

        match = re.search(r"(\d{4})", filename)
        return match.group(1) if match else "unknown"

    def verify_setup(self, index_path: str):
        """Verify the setup by testing a simple query"""
        print("🧪 Verifying setup with test query...")

        try:
            searcher = LeannSearcher(index_path)

            # Test query
            test_query = "What is the capital expenditure for 3M in 2018?"
            results = searcher.search(test_query, top_k=3)

            print(f"✅ Test query successful! Found {len(results)} results:")
            for i, result in enumerate(results, 1):
                company = result.metadata.get("company", "Unknown")
                year = result.metadata.get("doc_period", "Unknown")
                page = result.metadata.get("page_number", "Unknown")
                print(f" {i}. {company} {year} (page {page}) - Score: {result.score:.3f}")
                print(f" {result.text[:100]}...")

            searcher.cleanup()
            print("✅ Setup verification completed successfully!")

        except Exception as e:
            print(f"❌ Setup verification failed: {e}")
            raise


def main():
    parser = argparse.ArgumentParser(description="Setup FinanceBench with full PDF datastore")
    parser.add_argument("--data-dir", default="data", help="Data directory")
    parser.add_argument(
        "--backend", choices=["hnsw", "diskann"], default="hnsw", help="LEANN backend"
    )
    parser.add_argument(
        "--embedding-model",
        default="sentence-transformers/all-mpnet-base-v2",
        help="Embedding model",
    )
    parser.add_argument("--max-workers", type=int, default=5, help="Parallel download workers")
    parser.add_argument("--skip-download", action="store_true", help="Skip PDF download")
    parser.add_argument("--skip-build", action="store_true", help="Skip index building")

    args = parser.parse_args()

    print("🏦 FinanceBench Complete Setup")
    print("=" * 50)

    setup = FinanceBenchSetup(args.data_dir)

    try:
        # Step 1: Download dataset
        setup.download_dataset()

        # Step 2: Download PDFs
        if not args.skip_download:
            setup.download_all_pdfs(max_workers=args.max_workers)
        else:
            print("⏭️ Skipping PDF download")

        # Step 3: Build LEANN index
        if not args.skip_build:
            index_path = setup.build_leann_index(
                backend=args.backend, embedding_model=args.embedding_model
            )

            # Step 4: Verify setup
            setup.verify_setup(index_path)
        else:
            print("⏭️ Skipping index building")

        print("\n🎉 FinanceBench setup completed!")
        print(f"📁 Data directory: {setup.data_dir.absolute()}")
        print("\nNext steps:")
        print(
            "1. Run evaluation: python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann"
        )
        print(
            "2. Or test manually: python -c \"from leann import LeannSearcher; s = LeannSearcher('data/index/financebench_full_hnsw.leann'); print(s.search('3M capital expenditure 2018'))\""
        )

    except KeyboardInterrupt:
        print("\n⚠️ Setup interrupted by user")
        exit(1)
    except Exception as e:
        print(f"\n❌ Setup failed: {e}")
        exit(1)


if __name__ == "__main__":
    main()