Datastore reproduce (#3)
* fix: diskann zmq port and passages
* feat: auto discovery of packages and fix passage gen for diskann
* docs: embedding pruning
* refactor: passage structure
* feat: reproducible research data, rpj_wiki & dpr
* refactor: chat and base searcher
* feat: chat on mps
@@ -74,7 +74,7 @@ def main():
     print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
     print(">>> Basic search results <<<")
     for i, res in enumerate(results, 1):
-        print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
+        print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")

     # --- 3. Recompute search demo ---
     print(f"\n[PHASE 3] Recompute search using embedding server...")
@@ -107,7 +107,7 @@ def main():
     print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
     print(">>> Recompute search results <<<")
     for i, res in enumerate(recompute_results, 1):
-        print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
+        print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")

     # Compare results
     print(f"\n--- Result comparison ---")
@@ -116,8 +116,8 @@ def main():

     print("\nBasic search vs Recompute results:")
     for i in range(min(len(results), len(recompute_results))):
-        basic_score = results[i]['score']
-        recompute_score = recompute_results[i]['score']
+        basic_score = results[i].score
+        recompute_score = recompute_results[i].score
         score_diff = abs(basic_score - recompute_score)
         print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")

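The three hunks above migrate the demo from dict-style result access to attribute access. A minimal sketch of the pattern, assuming a dataclass-like result type (names are illustrative, not taken from this diff):

    from dataclasses import dataclass, field

    @dataclass
    class SearchResult:
        id: str
        score: float
        text: str
        metadata: dict = field(default_factory=dict)

    res = SearchResult(id="p1", score=0.91, text="...")
    print(res.score)   # attribute access, as in the updated demo
    # res["score"] now raises TypeError: dataclass instances are not subscriptable.
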
@@ -1,6 +1,7 @@
 import faulthandler
 faulthandler.enable()

+import argparse
 from llama_index.core import SimpleDirectoryReader, Settings
 from llama_index.core.readers.base import BaseReader
 from llama_index.node_parser.docling import DoclingNodeParser
@@ -69,17 +70,30 @@ if not INDEX_DIR.exists():
 else:
     print(f"--- Using existing index at {INDEX_DIR} ---")

-async def main():
+async def main(args):
     print(f"\n[PHASE 2] Starting Leann chat session...")
-    chat = LeannChat(index_path=INDEX_PATH)
+
+    llm_config = {
+        "type": args.llm,
+        "model": args.model,
+        "host": args.host
+    }
+
+    chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)

     query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explores to achieve the fairness and efficiency trade-off?"
     query = "What is the main idea of RL, and give me 5 examples of classic RL algorithms?"
+    query = "What is the Pangu large model, what dark sides came up during Pangu's development, and in which city are task orders usually issued?"

     print(f"You: {query}")
-    chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True,complexity=32,beam_width=1)
+    chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True, complexity=32)
     print(f"Leann: {chat_response}")

 if __name__ == "__main__":
-    asyncio.run(main())
+    parser = argparse.ArgumentParser(description="Run Leann Chat with various LLM backends.")
+    parser.add_argument("--llm", type=str, default="hf", choices=["simulated", "ollama", "hf", "openai"], help="The LLM backend to use.")
+    parser.add_argument("--model", type=str, default='meta-llama/Llama-3.2-3B-Instruct', help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).")
+    parser.add_argument("--host", type=str, default="http://localhost:11434", help="The host for the Ollama API.")
+    args = parser.parse_args()
+
+    asyncio.run(main(args))
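With the argparse wiring above, the chat example can switch LLM backends from the shell; a hypothetical invocation (the script's file name is not shown in this diff, so it is assumed here):

    python chat_demo.py --llm ollama --model llama3:8b --host http://localhost:11434
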
319
examples/multi_vector_aggregator.py
Normal file
@@ -0,0 +1,319 @@
#!/usr/bin/env python3
"""
Multi-Vector Aggregator for Fat Embeddings
==========================================

This module implements aggregation strategies for multi-vector embeddings,
similar to ColPali's approach where multiple patch vectors represent a single document.

Key features:
- MaxSim aggregation (take maximum similarity across patches)
- Voting-based aggregation (count patch matches)
- Weighted aggregation (attention-score weighted)
- Spatial clustering of matching patches
- Document-level result consolidation
"""

import numpy as np
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
from collections import defaultdict
import json


@dataclass
class PatchResult:
    """Represents a single patch search result."""
    patch_id: int
    image_name: str
    image_path: str
    coordinates: Tuple[int, int, int, int]  # (x1, y1, x2, y2)
    score: float
    attention_score: float
    scale: float
    metadata: Dict[str, Any]


@dataclass
class AggregatedResult:
    """Represents an aggregated document-level result."""
    image_name: str
    image_path: str
    doc_score: float
    patch_count: int
    best_patch: PatchResult
    all_patches: List[PatchResult]
    aggregation_method: str
    spatial_clusters: Optional[List[List[PatchResult]]] = None


class MultiVectorAggregator:
    """
    Aggregates multiple patch-level results into document-level results.
    """

    def __init__(self,
                 aggregation_method: str = "maxsim",
                 spatial_clustering: bool = True,
                 cluster_distance_threshold: float = 100.0):
        """
        Initialize the aggregator.

        Args:
            aggregation_method: "maxsim", "voting", "weighted", or "mean"
            spatial_clustering: Whether to cluster spatially close patches
            cluster_distance_threshold: Distance threshold for spatial clustering
        """
        self.aggregation_method = aggregation_method
        self.spatial_clustering = spatial_clustering
        self.cluster_distance_threshold = cluster_distance_threshold

    def aggregate_results(self,
                          search_results: List[Dict[str, Any]],
                          top_k: int = 10) -> List[AggregatedResult]:
        """
        Aggregate patch-level search results into document-level results.

        Args:
            search_results: List of search results from LeannSearcher
            top_k: Number of top documents to return

        Returns:
            List of aggregated document results
        """
        # Group results by image
        image_groups = defaultdict(list)

        for result in search_results:
            metadata = result.metadata
            if "image_name" in metadata and "patch_id" in metadata:
                patch_result = PatchResult(
                    patch_id=metadata["patch_id"],
                    image_name=metadata["image_name"],
                    image_path=metadata["image_path"],
                    coordinates=tuple(metadata["coordinates"]),
                    score=result.score,
                    attention_score=metadata.get("attention_score", 0.0),
                    scale=metadata.get("scale", 1.0),
                    metadata=metadata
                )
                image_groups[metadata["image_name"]].append(patch_result)

        # Aggregate each image group
        aggregated_results = []
        for image_name, patches in image_groups.items():
            if len(patches) == 0:
                continue

            agg_result = self._aggregate_image_patches(image_name, patches)
            aggregated_results.append(agg_result)

        # Sort by aggregated score and return top-k
        aggregated_results.sort(key=lambda x: x.doc_score, reverse=True)
        return aggregated_results[:top_k]

    def _aggregate_image_patches(self, image_name: str, patches: List[PatchResult]) -> AggregatedResult:
        """Aggregate patches for a single image."""

        if self.aggregation_method == "maxsim":
            doc_score = max(patch.score for patch in patches)
            best_patch = max(patches, key=lambda p: p.score)

        elif self.aggregation_method == "voting":
            # Count patches above threshold
            threshold = np.percentile([p.score for p in patches], 75)
            doc_score = sum(1 for patch in patches if patch.score >= threshold)
            best_patch = max(patches, key=lambda p: p.score)

        elif self.aggregation_method == "weighted":
            # Weight by attention scores
            total_weighted_score = sum(p.score * p.attention_score for p in patches)
            total_weights = sum(p.attention_score for p in patches)
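            # The epsilon in the denominator below keeps the division safe
            # when every attention score is zero.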
            doc_score = total_weighted_score / max(total_weights, 1e-8)
            best_patch = max(patches, key=lambda p: p.score * p.attention_score)

        elif self.aggregation_method == "mean":
            doc_score = np.mean([patch.score for patch in patches])
            best_patch = max(patches, key=lambda p: p.score)

        else:
            raise ValueError(f"Unknown aggregation method: {self.aggregation_method}")

        # Spatial clustering if enabled
        spatial_clusters = None
        if self.spatial_clustering:
            spatial_clusters = self._cluster_patches_spatially(patches)

        return AggregatedResult(
            image_name=image_name,
            image_path=patches[0].image_path,
            doc_score=float(doc_score),
            patch_count=len(patches),
            best_patch=best_patch,
            all_patches=sorted(patches, key=lambda p: p.score, reverse=True),
            aggregation_method=self.aggregation_method,
            spatial_clusters=spatial_clusters
        )

    def _cluster_patches_spatially(self, patches: List[PatchResult]) -> List[List[PatchResult]]:
        """Cluster patches that are spatially close to each other."""
        if len(patches) <= 1:
            return [patches]

        clusters = []
        remaining_patches = patches.copy()

        while remaining_patches:
            # Start new cluster with highest scoring remaining patch
            seed_patch = max(remaining_patches, key=lambda p: p.score)
            current_cluster = [seed_patch]
            remaining_patches.remove(seed_patch)

            # Add nearby patches to cluster
            added_to_cluster = True
            while added_to_cluster:
                added_to_cluster = False
                for patch in remaining_patches.copy():
                    if self._is_patch_nearby(patch, current_cluster):
                        current_cluster.append(patch)
                        remaining_patches.remove(patch)
                        added_to_cluster = True

            clusters.append(current_cluster)

        return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True)

    def _is_patch_nearby(self, patch: PatchResult, cluster: List[PatchResult]) -> bool:
        """Check if a patch is spatially close to any patch in the cluster."""
        patch_center = self._get_patch_center(patch.coordinates)

        for cluster_patch in cluster:
            cluster_center = self._get_patch_center(cluster_patch.coordinates)
            distance = np.sqrt((patch_center[0] - cluster_center[0])**2 +
                               (patch_center[1] - cluster_center[1])**2)

            if distance <= self.cluster_distance_threshold:
                return True

        return False

    def _get_patch_center(self, coordinates: Tuple[int, int, int, int]) -> Tuple[float, float]:
        """Get center point of a patch."""
        x1, y1, x2, y2 = coordinates
        return ((x1 + x2) / 2, (y1 + y2) / 2)

    def print_aggregated_results(self, results: List[AggregatedResult], max_patches_per_doc: int = 3):
        """Pretty print aggregated results."""
        print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})")
        print("=" * 80)

        for i, result in enumerate(results):
            print(f"\n{i+1}. {result.image_name}")
            print(f"   Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}")
            print(f"   Path: {result.image_path}")

            # Show best patch
            best = result.best_patch
            print(f"   🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})")

            # Show top patches
            print("   📍 Top Patches:")
            for j, patch in enumerate(result.all_patches[:max_patches_per_doc]):
                print(f"      {j+1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}")

            # Show spatial clusters if available
            if result.spatial_clusters and len(result.spatial_clusters) > 1:
                print(f"   🗂️ Spatial Clusters: {len(result.spatial_clusters)}")
                for j, cluster in enumerate(result.spatial_clusters[:2]):  # Show top 2 clusters
                    cluster_score = max(p.score for p in cluster)
                    print(f"      Cluster {j+1}: {len(cluster)} patches (best: {cluster_score:.4f})")


def demo_aggregation():
    """Demonstrate the multi-vector aggregation functionality."""
    print("=== Multi-Vector Aggregation Demo ===")

    # Simulate some patch-level search results
    # In real usage, these would come from LeannSearcher.search()

    class MockResult:
        def __init__(self, score, metadata):
            self.score = score
            self.metadata = metadata

    # Simulate results for 2 images with multiple patches each
    mock_results = [
        # Image 1: cats_and_kitchen.jpg - 4 patches
        MockResult(0.85, {
            "image_name": "cats_and_kitchen.jpg",
            "image_path": "/path/to/cats_and_kitchen.jpg",
            "patch_id": 3,
            "coordinates": [100, 50, 224, 174],  # Kitchen area
            "attention_score": 0.92,
            "scale": 1.0
        }),
        MockResult(0.78, {
            "image_name": "cats_and_kitchen.jpg",
            "image_path": "/path/to/cats_and_kitchen.jpg",
            "patch_id": 7,
            "coordinates": [200, 300, 324, 424],  # Cat area
            "attention_score": 0.88,
            "scale": 1.0
        }),
        MockResult(0.72, {
            "image_name": "cats_and_kitchen.jpg",
            "image_path": "/path/to/cats_and_kitchen.jpg",
            "patch_id": 12,
            "coordinates": [150, 100, 274, 224],  # Appliances
            "attention_score": 0.75,
            "scale": 1.0
        }),
        MockResult(0.65, {
            "image_name": "cats_and_kitchen.jpg",
            "image_path": "/path/to/cats_and_kitchen.jpg",
            "patch_id": 15,
            "coordinates": [50, 250, 174, 374],  # Furniture
            "attention_score": 0.70,
            "scale": 1.0
        }),

        # Image 2: city_street.jpg - 3 patches
        MockResult(0.68, {
            "image_name": "city_street.jpg",
            "image_path": "/path/to/city_street.jpg",
            "patch_id": 2,
            "coordinates": [300, 100, 424, 224],  # Buildings
            "attention_score": 0.80,
            "scale": 1.0
        }),
        MockResult(0.62, {
            "image_name": "city_street.jpg",
            "image_path": "/path/to/city_street.jpg",
            "patch_id": 8,
            "coordinates": [100, 350, 224, 474],  # Street level
            "attention_score": 0.75,
            "scale": 1.0
        }),
        MockResult(0.55, {
            "image_name": "city_street.jpg",
            "image_path": "/path/to/city_street.jpg",
            "patch_id": 11,
            "coordinates": [400, 200, 524, 324],  # Sky area
            "attention_score": 0.60,
            "scale": 1.0
        }),
    ]

    # Test different aggregation methods
    methods = ["maxsim", "voting", "weighted", "mean"]

    for method in methods:
        print(f"\n{'='*20} {method.upper()} AGGREGATION {'='*20}")

        aggregator = MultiVectorAggregator(
            aggregation_method=method,
            spatial_clustering=True,
            cluster_distance_threshold=100.0
        )

        aggregated = aggregator.aggregate_results(mock_results, top_k=5)
        aggregator.print_aggregated_results(aggregated)


if __name__ == "__main__":
    demo_aggregation()
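The demo above feeds the aggregator mocked patch results; the intended wiring is a patch-level search followed by aggregation. A hypothetical sketch, assuming an index whose entries carry per-patch metadata shaped like the mock entries:

    from leann.api import LeannSearcher

    searcher = LeannSearcher("/path/to/patch_index")  # placeholder path
    patch_results = searcher.search("a cat in a kitchen", top_k=100)  # over-fetch patches
    aggregator = MultiVectorAggregator(aggregation_method="maxsim")
    for doc in aggregator.aggregate_results(patch_results, top_k=5):
        print(doc.image_name, f"{doc.doc_score:.4f}", f"{doc.patch_count} patches")
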
157
examples/run_evaluation.py
Normal file
@@ -0,0 +1,157 @@
#!/usr/bin/env python3
"""
This script runs a recall evaluation on a given LEANN index.
It correctly compares results by fetching the text content for both the new search
results and the golden standard results, making the comparison robust to ID changes.
"""

import json
import argparse
import time
from pathlib import Path
import sys
import numpy as np
from typing import List, Dict, Any
import glob
import pickle

# Add project root to path to allow importing from leann
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))

from leann.api import LeannSearcher

# --- Configuration ---
NQ_QUERIES_FILE = Path("/opt/dlami/nvme/scaling_out/examples/nq_open.jsonl")

# Ground truth files for different datasets
GROUND_TRUTH_FILES = {
    "rpj_wiki": "/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/flat_results_nq_k3.json",
    "dpr": "/opt/dlami/nvme/scaling_out/indices/dpr/facebook/contriever-msmarco/flat_results_nq_k3.json"
}

# Old passages for different datasets
OLD_PASSAGES_GLOBS = {
    "rpj_wiki": "/opt/dlami/nvme/scaling_out/passages/rpj_wiki/8-shards/raw_passages-*-of-8.pkl.jsonl",
    "dpr": "/opt/dlami/nvme/scaling_out/passages/dpr/1-shards/raw_passages-*-of-1.pkl.jsonl"
}

# --- Helper Class to Load Original Passages ---
class OldPassageLoader:
    """A simplified version of the old LazyPassages class to fetch golden results by ID."""
    def __init__(self, passages_glob: str):
        self.jsonl_paths = sorted(glob.glob(passages_glob))
        self.offsets = {}
        self.fps = [open(p, "r", encoding="utf-8") for p in self.jsonl_paths]
        print("Building offset map for original passages...")
        for i, shard_path_str in enumerate(self.jsonl_paths):
            old_idx_path = Path(shard_path_str.replace(".jsonl", ".idx"))
            if not old_idx_path.exists():
                continue
            with open(old_idx_path, 'rb') as f:
                shard_offsets = pickle.load(f)
            for pid, offset in shard_offsets.items():
                self.offsets[str(pid)] = (i, offset)
        print("Offset map for original passages is ready.")

    def get_passage_by_id(self, pid: str) -> Dict[str, Any]:
        pid = str(pid)
        if pid not in self.offsets:
            raise ValueError(f"Passage ID {pid} not found in offsets")
        file_idx, offset = self.offsets[pid]
        fp = self.fps[file_idx]
        fp.seek(offset)
        return json.loads(fp.readline())

    def __del__(self):
        for fp in self.fps:
            fp.close()


def load_queries(file_path: Path) -> List[str]:
    queries = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            queries.append(data['query'])
    return queries


def main():
    parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
    parser.add_argument("index_path", type=str, help="Path to the LEANN index to evaluate.")
    parser.add_argument("--num-queries", type=int, default=10, help="Number of queries to evaluate.")
    parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
    parser.add_argument("--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW.")
    args = parser.parse_args()

    print(f"--- Final, Correct Recall Evaluation (efSearch={args.ef_search}) ---")

    # Detect dataset type from index path
    index_path_str = str(args.index_path)
    if "rpj_wiki" in index_path_str:
        dataset_type = "rpj_wiki"
    elif "dpr" in index_path_str:
        dataset_type = "dpr"
    else:
        print("WARNING: Unknown dataset type, defaulting to rpj_wiki")
        dataset_type = "rpj_wiki"

    print(f"INFO: Detected dataset type: {dataset_type}")

    try:
        searcher = LeannSearcher(args.index_path)
        queries = load_queries(NQ_QUERIES_FILE)

        golden_results_file = GROUND_TRUTH_FILES[dataset_type]
        old_passages_glob = OLD_PASSAGES_GLOBS[dataset_type]

        print(f"INFO: Using ground truth file: {golden_results_file}")
        print(f"INFO: Using old passages glob: {old_passages_glob}")

        with open(golden_results_file, 'r') as f:
            golden_results_data = json.load(f)

        old_passage_loader = OldPassageLoader(old_passages_glob)

        num_eval_queries = min(args.num_queries, len(queries))
        queries = queries[:num_eval_queries]

        print(f"\nRunning evaluation on {num_eval_queries} queries...")
        recall_scores = []
        search_times = []

        for i in range(num_eval_queries):
            start_time = time.time()
            new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
            search_times.append(time.time() - start_time)

            # Correct Recall Calculation: Based on TEXT content
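            # recall@k = |new_texts ∩ golden_texts| / |golden_texts|; matching
            # on text keeps the metric valid even when passage ids changed
            # between the old and new index builds.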
            new_texts = {result.text for result in new_results}
            golden_ids = golden_results_data["indices"][i][:args.top_k]
            golden_texts = {old_passage_loader.get_passage_by_id(str(gid))['text'] for gid in golden_ids}

            overlap = len(new_texts & golden_texts)
            recall = overlap / len(golden_texts) if golden_texts else 0
            recall_scores.append(recall)

            print("\n--- EVALUATION RESULTS ---")
            print(f"Query: {queries[i]}")
            print(f"New Results: {new_texts}")
            print(f"Golden Results: {golden_texts}")
            print(f"Overlap: {overlap}")
            print(f"Recall: {recall}")
            print(f"Search Time: {search_times[-1]:.4f}s")
            print("--------------------------------")

        avg_recall = np.mean(recall_scores) if recall_scores else 0
        avg_time = np.mean(search_times) if search_times else 0

        print("\n🎉 --- Evaluation Complete ---")
        print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
        print(f"Avg. Search Time: {avg_time:.4f}s")

    except Exception as e:
        print(f"\n❌ An error occurred during evaluation: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
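A typical invocation of the new evaluation script (the index path is a placeholder):

    python examples/run_evaluation.py /path/to/index.leann --num-queries 50 --top-k 3 --ef-search 120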