Datastore reproduce (#3)

* fix: diskann zmq port and passages

* feat: auto discovery of packages and fix passage gen for diskann

* docs: embedding pruning

* refactor: passage structure

* feat: reproducible research datas, rpj_wiki & dpr

* refactor: chat and base searcher

* feat: chat on mps
This commit is contained in:
Andy Lee
2025-07-11 23:37:23 -07:00
committed by GitHub
parent 91a026f38b
commit eb6f504789
22 changed files with 5070 additions and 3681 deletions

View File

@@ -74,7 +74,7 @@ def main():
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
print(">>> Basic search results <<<")
for i, res in enumerate(results, 1):
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
# --- 3. Recompute search demo ---
print(f"\n[PHASE 3] Recompute search using embedding server...")
@@ -107,7 +107,7 @@ def main():
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
print(">>> Recompute search results <<<")
for i, res in enumerate(recompute_results, 1):
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
# Compare results
print(f"\n--- Result comparison ---")
@@ -116,8 +116,8 @@ def main():
print("\nBasic search vs Recompute results:")
for i in range(min(len(results), len(recompute_results))):
basic_score = results[i]['score']
recompute_score = recompute_results[i]['score']
basic_score = results[i].score
recompute_score = recompute_results[i].score
score_diff = abs(basic_score - recompute_score)
print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")

View File

@@ -1,6 +1,7 @@
import faulthandler
faulthandler.enable()
import argparse
from llama_index.core import SimpleDirectoryReader, Settings
from llama_index.core.readers.base import BaseReader
from llama_index.node_parser.docling import DoclingNodeParser
@@ -69,17 +70,30 @@ if not INDEX_DIR.exists():
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
async def main():
async def main(args):
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=INDEX_PATH)
llm_config = {
"type": args.llm,
"model": args.model,
"host": args.host
}
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
print(f"You: {query}")
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True,complexity=32,beam_width=1)
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True, complexity=32)
print(f"Leann: {chat_response}")
if __name__ == "__main__":
asyncio.run(main())
parser = argparse.ArgumentParser(description="Run Leann Chat with various LLM backends.")
parser.add_argument("--llm", type=str, default="hf", choices=["simulated", "ollama", "hf", "openai"], help="The LLM backend to use.")
parser.add_argument("--model", type=str, default='meta-llama/Llama-3.2-3B-Instruct', help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).")
parser.add_argument("--host", type=str, default="http://localhost:11434", help="The host for the Ollama API.")
args = parser.parse_args()
asyncio.run(main(args))

View File

@@ -0,0 +1,319 @@
#!/usr/bin/env python3
"""
Multi-Vector Aggregator for Fat Embeddings
==========================================
This module implements aggregation strategies for multi-vector embeddings,
similar to ColPali's approach where multiple patch vectors represent a single document.
Key features:
- MaxSim aggregation (take maximum similarity across patches)
- Voting-based aggregation (count patch matches)
- Weighted aggregation (attention-score weighted)
- Spatial clustering of matching patches
- Document-level result consolidation
"""
import numpy as np
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
from collections import defaultdict
import json
@dataclass
class PatchResult:
"""Represents a single patch search result."""
patch_id: int
image_name: str
image_path: str
coordinates: Tuple[int, int, int, int] # (x1, y1, x2, y2)
score: float
attention_score: float
scale: float
metadata: Dict[str, Any]
@dataclass
class AggregatedResult:
"""Represents an aggregated document-level result."""
image_name: str
image_path: str
doc_score: float
patch_count: int
best_patch: PatchResult
all_patches: List[PatchResult]
aggregation_method: str
spatial_clusters: Optional[List[List[PatchResult]]] = None
class MultiVectorAggregator:
"""
Aggregates multiple patch-level results into document-level results.
"""
def __init__(self,
aggregation_method: str = "maxsim",
spatial_clustering: bool = True,
cluster_distance_threshold: float = 100.0):
"""
Initialize the aggregator.
Args:
aggregation_method: "maxsim", "voting", "weighted", or "mean"
spatial_clustering: Whether to cluster spatially close patches
cluster_distance_threshold: Distance threshold for spatial clustering
"""
self.aggregation_method = aggregation_method
self.spatial_clustering = spatial_clustering
self.cluster_distance_threshold = cluster_distance_threshold
def aggregate_results(self,
search_results: List[Dict[str, Any]],
top_k: int = 10) -> List[AggregatedResult]:
"""
Aggregate patch-level search results into document-level results.
Args:
search_results: List of search results from LeannSearcher
top_k: Number of top documents to return
Returns:
List of aggregated document results
"""
# Group results by image
image_groups = defaultdict(list)
for result in search_results:
metadata = result.metadata
if "image_name" in metadata and "patch_id" in metadata:
patch_result = PatchResult(
patch_id=metadata["patch_id"],
image_name=metadata["image_name"],
image_path=metadata["image_path"],
coordinates=tuple(metadata["coordinates"]),
score=result.score,
attention_score=metadata.get("attention_score", 0.0),
scale=metadata.get("scale", 1.0),
metadata=metadata
)
image_groups[metadata["image_name"]].append(patch_result)
# Aggregate each image group
aggregated_results = []
for image_name, patches in image_groups.items():
if len(patches) == 0:
continue
agg_result = self._aggregate_image_patches(image_name, patches)
aggregated_results.append(agg_result)
# Sort by aggregated score and return top-k
aggregated_results.sort(key=lambda x: x.doc_score, reverse=True)
return aggregated_results[:top_k]
def _aggregate_image_patches(self, image_name: str, patches: List[PatchResult]) -> AggregatedResult:
"""Aggregate patches for a single image."""
if self.aggregation_method == "maxsim":
doc_score = max(patch.score for patch in patches)
best_patch = max(patches, key=lambda p: p.score)
elif self.aggregation_method == "voting":
# Count patches above threshold
threshold = np.percentile([p.score for p in patches], 75)
doc_score = sum(1 for patch in patches if patch.score >= threshold)
best_patch = max(patches, key=lambda p: p.score)
elif self.aggregation_method == "weighted":
# Weight by attention scores
total_weighted_score = sum(p.score * p.attention_score for p in patches)
total_weights = sum(p.attention_score for p in patches)
doc_score = total_weighted_score / max(total_weights, 1e-8)
best_patch = max(patches, key=lambda p: p.score * p.attention_score)
elif self.aggregation_method == "mean":
doc_score = np.mean([patch.score for patch in patches])
best_patch = max(patches, key=lambda p: p.score)
else:
raise ValueError(f"Unknown aggregation method: {self.aggregation_method}")
# Spatial clustering if enabled
spatial_clusters = None
if self.spatial_clustering:
spatial_clusters = self._cluster_patches_spatially(patches)
return AggregatedResult(
image_name=image_name,
image_path=patches[0].image_path,
doc_score=float(doc_score),
patch_count=len(patches),
best_patch=best_patch,
all_patches=sorted(patches, key=lambda p: p.score, reverse=True),
aggregation_method=self.aggregation_method,
spatial_clusters=spatial_clusters
)
def _cluster_patches_spatially(self, patches: List[PatchResult]) -> List[List[PatchResult]]:
"""Cluster patches that are spatially close to each other."""
if len(patches) <= 1:
return [patches]
clusters = []
remaining_patches = patches.copy()
while remaining_patches:
# Start new cluster with highest scoring remaining patch
seed_patch = max(remaining_patches, key=lambda p: p.score)
current_cluster = [seed_patch]
remaining_patches.remove(seed_patch)
# Add nearby patches to cluster
added_to_cluster = True
while added_to_cluster:
added_to_cluster = False
for patch in remaining_patches.copy():
if self._is_patch_nearby(patch, current_cluster):
current_cluster.append(patch)
remaining_patches.remove(patch)
added_to_cluster = True
clusters.append(current_cluster)
return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True)
def _is_patch_nearby(self, patch: PatchResult, cluster: List[PatchResult]) -> bool:
"""Check if a patch is spatially close to any patch in the cluster."""
patch_center = self._get_patch_center(patch.coordinates)
for cluster_patch in cluster:
cluster_center = self._get_patch_center(cluster_patch.coordinates)
distance = np.sqrt((patch_center[0] - cluster_center[0])**2 +
(patch_center[1] - cluster_center[1])**2)
if distance <= self.cluster_distance_threshold:
return True
return False
def _get_patch_center(self, coordinates: Tuple[int, int, int, int]) -> Tuple[float, float]:
"""Get center point of a patch."""
x1, y1, x2, y2 = coordinates
return ((x1 + x2) / 2, (y1 + y2) / 2)
def print_aggregated_results(self, results: List[AggregatedResult], max_patches_per_doc: int = 3):
"""Pretty print aggregated results."""
print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})")
print("=" * 80)
for i, result in enumerate(results):
print(f"\n{i+1}. {result.image_name}")
print(f" Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}")
print(f" Path: {result.image_path}")
# Show best patch
best = result.best_patch
print(f" 🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})")
# Show top patches
print(f" 📍 Top Patches:")
for j, patch in enumerate(result.all_patches[:max_patches_per_doc]):
print(f" {j+1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}")
# Show spatial clusters if available
if result.spatial_clusters and len(result.spatial_clusters) > 1:
print(f" 🗂️ Spatial Clusters: {len(result.spatial_clusters)}")
for j, cluster in enumerate(result.spatial_clusters[:2]): # Show top 2 clusters
cluster_score = max(p.score for p in cluster)
print(f" Cluster {j+1}: {len(cluster)} patches (best: {cluster_score:.4f})")
def demo_aggregation():
"""Demonstrate the multi-vector aggregation functionality."""
print("=== Multi-Vector Aggregation Demo ===")
# Simulate some patch-level search results
# In real usage, these would come from LeannSearcher.search()
class MockResult:
def __init__(self, score, metadata):
self.score = score
self.metadata = metadata
# Simulate results for 2 images with multiple patches each
mock_results = [
# Image 1: cats_and_kitchen.jpg - 4 patches
MockResult(0.85, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 3,
"coordinates": [100, 50, 224, 174], # Kitchen area
"attention_score": 0.92,
"scale": 1.0
}),
MockResult(0.78, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 7,
"coordinates": [200, 300, 324, 424], # Cat area
"attention_score": 0.88,
"scale": 1.0
}),
MockResult(0.72, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 12,
"coordinates": [150, 100, 274, 224], # Appliances
"attention_score": 0.75,
"scale": 1.0
}),
MockResult(0.65, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 15,
"coordinates": [50, 250, 174, 374], # Furniture
"attention_score": 0.70,
"scale": 1.0
}),
# Image 2: city_street.jpg - 3 patches
MockResult(0.68, {
"image_name": "city_street.jpg",
"image_path": "/path/to/city_street.jpg",
"patch_id": 2,
"coordinates": [300, 100, 424, 224], # Buildings
"attention_score": 0.80,
"scale": 1.0
}),
MockResult(0.62, {
"image_name": "city_street.jpg",
"image_path": "/path/to/city_street.jpg",
"patch_id": 8,
"coordinates": [100, 350, 224, 474], # Street level
"attention_score": 0.75,
"scale": 1.0
}),
MockResult(0.55, {
"image_name": "city_street.jpg",
"image_path": "/path/to/city_street.jpg",
"patch_id": 11,
"coordinates": [400, 200, 524, 324], # Sky area
"attention_score": 0.60,
"scale": 1.0
}),
]
# Test different aggregation methods
methods = ["maxsim", "voting", "weighted", "mean"]
for method in methods:
print(f"\n{'='*20} {method.upper()} AGGREGATION {'='*20}")
aggregator = MultiVectorAggregator(
aggregation_method=method,
spatial_clustering=True,
cluster_distance_threshold=100.0
)
aggregated = aggregator.aggregate_results(mock_results, top_k=5)
aggregator.print_aggregated_results(aggregated)
if __name__ == "__main__":
demo_aggregation()

157
examples/run_evaluation.py Normal file
View File

@@ -0,0 +1,157 @@
#!/usr/bin/env python3
"""
This script runs a recall evaluation on a given LEANN index.
It correctly compares results by fetching the text content for both the new search
results and the golden standard results, making the comparison robust to ID changes.
"""
import json
import argparse
import time
from pathlib import Path
import sys
import numpy as np
from typing import List, Dict, Any
import glob
import pickle
# Add project root to path to allow importing from leann
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
from leann.api import LeannSearcher
# --- Configuration ---
NQ_QUERIES_FILE = Path("/opt/dlami/nvme/scaling_out/examples/nq_open.jsonl")
# Ground truth files for different datasets
GROUND_TRUTH_FILES = {
"rpj_wiki": "/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/flat_results_nq_k3.json",
"dpr": "/opt/dlami/nvme/scaling_out/indices/dpr/facebook/contriever-msmarco/flat_results_nq_k3.json"
}
# Old passages for different datasets
OLD_PASSAGES_GLOBS = {
"rpj_wiki": "/opt/dlami/nvme/scaling_out/passages/rpj_wiki/8-shards/raw_passages-*-of-8.pkl.jsonl",
"dpr": "/opt/dlami/nvme/scaling_out/passages/dpr/1-shards/raw_passages-*-of-1.pkl.jsonl"
}
# --- Helper Class to Load Original Passages ---
class OldPassageLoader:
"""A simplified version of the old LazyPassages class to fetch golden results by ID."""
def __init__(self, passages_glob: str):
self.jsonl_paths = sorted(glob.glob(passages_glob))
self.offsets = {}
self.fps = [open(p, "r", encoding="utf-8") for p in self.jsonl_paths]
print("Building offset map for original passages...")
for i, shard_path_str in enumerate(self.jsonl_paths):
old_idx_path = Path(shard_path_str.replace(".jsonl", ".idx"))
if not old_idx_path.exists(): continue
with open(old_idx_path, 'rb') as f:
shard_offsets = pickle.load(f)
for pid, offset in shard_offsets.items():
self.offsets[str(pid)] = (i, offset)
print("Offset map for original passages is ready.")
def get_passage_by_id(self, pid: str) -> Dict[str, Any]:
pid = str(pid)
if pid not in self.offsets:
raise ValueError(f"Passage ID {pid} not found in offsets")
file_idx, offset = self.offsets[pid]
fp = self.fps[file_idx]
fp.seek(offset)
return json.loads(fp.readline())
def __del__(self):
for fp in self.fps:
fp.close()
def load_queries(file_path: Path) -> List[str]:
queries = []
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
data = json.loads(line)
queries.append(data['query'])
return queries
def main():
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
parser.add_argument("index_path", type=str, help="Path to the LEANN index to evaluate.")
parser.add_argument("--num-queries", type=int, default=10, help="Number of queries to evaluate.")
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
parser.add_argument("--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW.")
args = parser.parse_args()
print(f"--- Final, Correct Recall Evaluation (efSearch={args.ef_search}) ---")
# Detect dataset type from index path
index_path_str = str(args.index_path)
if "rpj_wiki" in index_path_str:
dataset_type = "rpj_wiki"
elif "dpr" in index_path_str:
dataset_type = "dpr"
else:
print("WARNING: Unknown dataset type, defaulting to rpj_wiki")
dataset_type = "rpj_wiki"
print(f"INFO: Detected dataset type: {dataset_type}")
try:
searcher = LeannSearcher(args.index_path)
queries = load_queries(NQ_QUERIES_FILE)
golden_results_file = GROUND_TRUTH_FILES[dataset_type]
old_passages_glob = OLD_PASSAGES_GLOBS[dataset_type]
print(f"INFO: Using ground truth file: {golden_results_file}")
print(f"INFO: Using old passages glob: {old_passages_glob}")
with open(golden_results_file, 'r') as f:
golden_results_data = json.load(f)
old_passage_loader = OldPassageLoader(old_passages_glob)
num_eval_queries = min(args.num_queries, len(queries))
queries = queries[:num_eval_queries]
print(f"\nRunning evaluation on {num_eval_queries} queries...")
recall_scores = []
search_times = []
for i in range(num_eval_queries):
start_time = time.time()
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
search_times.append(time.time() - start_time)
# Correct Recall Calculation: Based on TEXT content
new_texts = {result.text for result in new_results}
golden_ids = golden_results_data["indices"][i][:args.top_k]
golden_texts = {old_passage_loader.get_passage_by_id(str(gid))['text'] for gid in golden_ids}
overlap = len(new_texts & golden_texts)
recall = overlap / len(golden_texts) if golden_texts else 0
recall_scores.append(recall)
print("\n--- EVALUATION RESULTS ---")
print(f"Query: {queries[i]}")
print(f"New Results: {new_texts}")
print(f"Golden Results: {golden_texts}")
print(f"Overlap: {overlap}")
print(f"Recall: {recall}")
print(f"Search Time: {search_times[-1]:.4f}s")
print(f"--------------------------------")
avg_recall = np.mean(recall_scores) if recall_scores else 0
avg_time = np.mean(search_times) if search_times else 0
print(f"\n🎉 --- Evaluation Complete ---")
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
print(f"Avg. Search Time: {avg_time:.4f}s")
except Exception as e:
print(f"\n❌ An error occurred during evaluation: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()