Compare commits
11 Commits
feat/multi
...
embed-laun
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ed15776564 | ||
|
|
8d202b8b0e | ||
|
|
9ac9eab48d | ||
|
|
cd1d853a46 | ||
|
|
253680043a | ||
|
|
36c44b8806 | ||
|
|
66c6aad3e4 | ||
|
|
29ef3c95dc | ||
|
|
469dce0045 | ||
|
|
0ac676f9cb | ||
|
|
97c9f39704 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -91,8 +91,7 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
||||
|
||||
*.meta.json
|
||||
*.passages.json
|
||||
*.npy
|
||||
*.db
|
||||
|
||||
batchtest.py
|
||||
tests/__pytest_cache__/
|
||||
tests/__pycache__/
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Simple test script to test colqwen2 forward pass with a single image."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the current directory to path to import leann_multi_vector
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
import torch
|
||||
from leann_multi_vector import _embed_images, _ensure_repo_paths_importable, _load_colvision
|
||||
from PIL import Image
|
||||
|
||||
# Ensure repo paths are importable
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# Set environment variable
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
def create_test_image():
|
||||
"""Create a simple test image."""
|
||||
# Create a simple RGB image (800x600)
|
||||
img = Image.new("RGB", (800, 600), color="white")
|
||||
return img
|
||||
|
||||
|
||||
def load_test_image_from_file():
|
||||
"""Try to load an image from the indexes directory if available."""
|
||||
# Try to find an existing image in the indexes directory
|
||||
indexes_dir = Path(__file__).parent / "indexes"
|
||||
|
||||
# Look for images in common locations
|
||||
possible_paths = [
|
||||
indexes_dir / "vidore_fastplaid" / "images",
|
||||
indexes_dir / "colvision_large.leann.images",
|
||||
indexes_dir / "colvision.leann.images",
|
||||
]
|
||||
|
||||
for img_dir in possible_paths:
|
||||
if img_dir.exists():
|
||||
# Find first image file
|
||||
for ext in [".png", ".jpg", ".jpeg"]:
|
||||
for img_file in img_dir.glob(f"*{ext}"):
|
||||
print(f"Loading test image from: {img_file}")
|
||||
return Image.open(img_file)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("Testing ColQwen2 Forward Pass")
|
||||
print("=" * 60)
|
||||
|
||||
# Step 1: Load or create test image
|
||||
print("\n[Step 1] Loading test image...")
|
||||
test_image = load_test_image_from_file()
|
||||
if test_image is None:
|
||||
print("No existing image found, creating a simple test image...")
|
||||
test_image = create_test_image()
|
||||
else:
|
||||
print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")
|
||||
|
||||
# Convert to RGB if needed
|
||||
if test_image.mode != "RGB":
|
||||
test_image = test_image.convert("RGB")
|
||||
print(f"✓ Converted to RGB: {test_image.size}")
|
||||
|
||||
# Step 2: Load model
|
||||
print("\n[Step 2] Loading ColQwen2 model...")
|
||||
try:
|
||||
model_name, model, processor, device_str, device, dtype = _load_colvision("colqwen2")
|
||||
print(f"✓ Model loaded: {model_name}")
|
||||
print(f"✓ Device: {device_str}, dtype: {dtype}")
|
||||
|
||||
# Print model info
|
||||
if hasattr(model, "device"):
|
||||
print(f"✓ Model device: {model.device}")
|
||||
if hasattr(model, "dtype"):
|
||||
print(f"✓ Model dtype: {model.dtype}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error loading model: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
# Step 3: Test forward pass
|
||||
print("\n[Step 3] Running forward pass...")
|
||||
try:
|
||||
# Use the _embed_images function which handles batching and forward pass
|
||||
images = [test_image]
|
||||
print(f"Processing {len(images)} image(s)...")
|
||||
|
||||
doc_vecs = _embed_images(model, processor, images)
|
||||
|
||||
print("✓ Forward pass completed!")
|
||||
print(f"✓ Number of embeddings: {len(doc_vecs)}")
|
||||
|
||||
if len(doc_vecs) > 0:
|
||||
emb = doc_vecs[0]
|
||||
print(f"✓ Embedding shape: {emb.shape}")
|
||||
print(f"✓ Embedding dtype: {emb.dtype}")
|
||||
print("✓ Embedding stats:")
|
||||
print(f" - Min: {emb.min().item():.4f}")
|
||||
print(f" - Max: {emb.max().item():.4f}")
|
||||
print(f" - Mean: {emb.mean().item():.4f}")
|
||||
print(f" - Std: {emb.std().item():.4f}")
|
||||
|
||||
# Check for NaN or Inf
|
||||
if torch.isnan(emb).any():
|
||||
print("⚠ Warning: Embedding contains NaN values!")
|
||||
if torch.isinf(emb).any():
|
||||
print("⚠ Warning: Embedding contains Inf values!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error during forward pass: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Test completed successfully!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -3,7 +3,6 @@ import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
@@ -195,7 +194,7 @@ def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[Image.Image](images),
|
||||
batch_size=32,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_images(x),
|
||||
)
|
||||
@@ -219,47 +218,32 @@ def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
|
||||
|
||||
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
|
||||
import torch
|
||||
from colpali_engine.utils.torch_utils import ListDataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
model.eval()
|
||||
|
||||
# Match MTEB's exact query processing from ColPaliEngineWrapper.get_text_embeddings:
|
||||
# 1. MTEB receives batch["text"] which already includes instruction/prompt (from _combine_queries_with_instruction_text)
|
||||
# 2. Manually adds: query_prefix + text + query_augmentation_token * 10
|
||||
# 3. Calls processor.process_queries(batch) where batch is now a list of strings
|
||||
# 4. process_queries adds: query_prefix + text + suffix (suffix = query_augmentation_token * 10)
|
||||
#
|
||||
# This results in duplicate addition: query_prefix is added twice, query_augmentation_token * 20 total
|
||||
# We need to match this exactly to reproduce MTEB results
|
||||
|
||||
all_embeds = []
|
||||
batch_size = 32 # Match MTEB's default batch_size
|
||||
|
||||
with torch.no_grad():
|
||||
for i in tqdm(range(0, len(queries), batch_size), desc="Embedding queries"):
|
||||
batch_queries = queries[i : i + batch_size]
|
||||
|
||||
# Match MTEB: manually add query_prefix + text + query_augmentation_token * 10
|
||||
# Then process_queries will add them again (resulting in 20 augmentation tokens total)
|
||||
batch = [
|
||||
processor.query_prefix + t + processor.query_augmentation_token * 10
|
||||
for t in batch_queries
|
||||
]
|
||||
inputs = processor.process_queries(batch)
|
||||
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[str](queries),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_queries(x),
|
||||
)
|
||||
|
||||
q_vecs: list[Any] = []
|
||||
for batch_query in tqdm(dataloader, desc="Embedding queries"):
|
||||
with torch.no_grad():
|
||||
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||
if model.device.type == "cuda":
|
||||
with torch.autocast(
|
||||
device_type="cuda",
|
||||
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||
):
|
||||
outs = model(**inputs)
|
||||
embeddings_query = model(**batch_query)
|
||||
else:
|
||||
outs = model(**inputs)
|
||||
|
||||
# Match MTEB: convert to float32 on CPU
|
||||
all_embeds.extend(list(torch.unbind(outs.cpu().to(torch.float32))))
|
||||
|
||||
return all_embeds
|
||||
embeddings_query = model(**batch_query)
|
||||
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||
return q_vecs
|
||||
|
||||
|
||||
def _build_index(
|
||||
@@ -299,279 +283,6 @@ def _load_retriever_if_index_exists(index_path: str) -> Optional[Any]:
|
||||
return None
|
||||
|
||||
|
||||
def _build_fast_plaid_index(
|
||||
index_path: str,
|
||||
doc_vecs: list[Any],
|
||||
filepaths: list[str],
|
||||
images: list[Image.Image],
|
||||
) -> tuple[Any, float]:
|
||||
"""
|
||||
Build a Fast-Plaid index from document embeddings.
|
||||
|
||||
Args:
|
||||
index_path: Path to save the Fast-Plaid index
|
||||
doc_vecs: List of document embeddings (each is a tensor with shape [num_tokens, embedding_dim])
|
||||
filepaths: List of filepath identifiers for each document
|
||||
images: List of PIL Images corresponding to each document
|
||||
|
||||
Returns:
|
||||
Tuple of (FastPlaid index object, build_time_in_seconds)
|
||||
"""
|
||||
import torch
|
||||
from fast_plaid import search as fast_plaid_search
|
||||
|
||||
print(f" Preparing {len(doc_vecs)} document embeddings for Fast-Plaid...")
|
||||
_t0 = time.perf_counter()
|
||||
|
||||
# Convert doc_vecs to list of tensors
|
||||
documents_embeddings = []
|
||||
for i, vec in enumerate(doc_vecs):
|
||||
if i % 1000 == 0:
|
||||
print(f" Converting embedding {i}/{len(doc_vecs)}...")
|
||||
if not isinstance(vec, torch.Tensor):
|
||||
vec = (
|
||||
torch.tensor(vec)
|
||||
if isinstance(vec, np.ndarray)
|
||||
else torch.from_numpy(np.array(vec))
|
||||
)
|
||||
# Ensure float32 for Fast-Plaid
|
||||
if vec.dtype != torch.float32:
|
||||
vec = vec.float()
|
||||
documents_embeddings.append(vec)
|
||||
|
||||
print(f" Converted {len(documents_embeddings)} embeddings")
|
||||
if len(documents_embeddings) > 0:
|
||||
print(f" First embedding shape: {documents_embeddings[0].shape}")
|
||||
print(f" First embedding dtype: {documents_embeddings[0].dtype}")
|
||||
|
||||
# Prepare metadata for Fast-Plaid
|
||||
print(f" Preparing metadata for {len(filepaths)} documents...")
|
||||
metadata_list = []
|
||||
for i, filepath in enumerate(filepaths):
|
||||
metadata_list.append(
|
||||
{
|
||||
"filepath": filepath,
|
||||
"index": i,
|
||||
}
|
||||
)
|
||||
|
||||
# Create Fast-Plaid index
|
||||
print(f" Creating FastPlaid object with index path: {index_path}")
|
||||
try:
|
||||
fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path)
|
||||
print(" FastPlaid object created successfully")
|
||||
except Exception as e:
|
||||
print(f" Error creating FastPlaid object: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
print(f" Calling fast_plaid_index.create() with {len(documents_embeddings)} documents...")
|
||||
try:
|
||||
fast_plaid_index.create(
|
||||
documents_embeddings=documents_embeddings,
|
||||
metadata=metadata_list,
|
||||
)
|
||||
print(" Fast-Plaid index created successfully")
|
||||
except Exception as e:
|
||||
print(f" Error creating Fast-Plaid index: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
build_secs = time.perf_counter() - _t0
|
||||
|
||||
# Save images separately (Fast-Plaid doesn't store images)
|
||||
print(f" Saving {len(images)} images...")
|
||||
images_dir = Path(index_path) / "images"
|
||||
images_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i, img in enumerate(tqdm(images, desc="Saving images")):
|
||||
img_path = images_dir / f"doc_{i}.png"
|
||||
img.save(str(img_path))
|
||||
|
||||
return fast_plaid_index, build_secs
|
||||
|
||||
|
||||
def _fast_plaid_index_exists(index_path: str) -> bool:
|
||||
"""
|
||||
Check if Fast-Plaid index exists by checking for key files.
|
||||
This avoids creating the FastPlaid object which may trigger memory allocation.
|
||||
|
||||
Args:
|
||||
index_path: Path to the Fast-Plaid index
|
||||
|
||||
Returns:
|
||||
True if index appears to exist, False otherwise
|
||||
"""
|
||||
index_path_obj = Path(index_path)
|
||||
if not index_path_obj.exists() or not index_path_obj.is_dir():
|
||||
return False
|
||||
|
||||
# Fast-Plaid creates a SQLite database file for metadata
|
||||
# Check for metadata.db as the most reliable indicator
|
||||
metadata_db = index_path_obj / "metadata.db"
|
||||
if metadata_db.exists() and metadata_db.stat().st_size > 0:
|
||||
return True
|
||||
|
||||
# Also check if directory has any files (might be incomplete index)
|
||||
try:
|
||||
if any(index_path_obj.iterdir()):
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _load_fast_plaid_index_if_exists(index_path: str) -> Optional[Any]:
|
||||
"""
|
||||
Load Fast-Plaid index if it exists.
|
||||
First checks if index files exist, then creates the FastPlaid object.
|
||||
The actual index data loading happens lazily when search is called.
|
||||
|
||||
Args:
|
||||
index_path: Path to the Fast-Plaid index
|
||||
|
||||
Returns:
|
||||
FastPlaid index object if exists, None otherwise
|
||||
"""
|
||||
try:
|
||||
from fast_plaid import search as fast_plaid_search
|
||||
|
||||
# First check if index files exist without creating the object
|
||||
if not _fast_plaid_index_exists(index_path):
|
||||
return None
|
||||
|
||||
# Now try to create FastPlaid object
|
||||
# This may trigger some memory allocation, but the full index loading is deferred
|
||||
fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path)
|
||||
return fast_plaid_index
|
||||
except ImportError:
|
||||
# fast-plaid not installed
|
||||
return None
|
||||
except Exception as e:
|
||||
# Any error (including memory errors from Rust backend) - return None
|
||||
# The error will be caught and index will be rebuilt
|
||||
print(f"Warning: Could not load Fast-Plaid index: {type(e).__name__}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _search_fast_plaid(
|
||||
fast_plaid_index: Any,
|
||||
query_vec: Any,
|
||||
top_k: int,
|
||||
) -> tuple[list[tuple[float, int]], float]:
|
||||
"""
|
||||
Search Fast-Plaid index with a query embedding.
|
||||
|
||||
Args:
|
||||
fast_plaid_index: FastPlaid index object
|
||||
query_vec: Query embedding tensor with shape [num_tokens, embedding_dim]
|
||||
top_k: Number of top results to return
|
||||
|
||||
Returns:
|
||||
Tuple of (results_list, search_time_in_seconds)
|
||||
results_list: List of (score, doc_id) tuples
|
||||
"""
|
||||
import torch
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
|
||||
# Ensure query is a torch tensor
|
||||
if not isinstance(query_vec, torch.Tensor):
|
||||
q_vec_tensor = (
|
||||
torch.tensor(query_vec)
|
||||
if isinstance(query_vec, np.ndarray)
|
||||
else torch.from_numpy(np.array(query_vec))
|
||||
)
|
||||
else:
|
||||
q_vec_tensor = query_vec
|
||||
|
||||
# Fast-Plaid expects shape [num_queries, num_tokens, embedding_dim]
|
||||
if q_vec_tensor.dim() == 2:
|
||||
q_vec_tensor = q_vec_tensor.unsqueeze(0) # [1, num_tokens, embedding_dim]
|
||||
|
||||
# Perform search
|
||||
scores = fast_plaid_index.search(
|
||||
queries_embeddings=q_vec_tensor,
|
||||
top_k=top_k,
|
||||
show_progress=True,
|
||||
)
|
||||
|
||||
search_secs = time.perf_counter() - _t0
|
||||
|
||||
# Convert Fast-Plaid results to same format as LEANN: list of (score, doc_id) tuples
|
||||
results = []
|
||||
if scores and len(scores) > 0:
|
||||
query_results = scores[0]
|
||||
# Fast-Plaid returns (doc_id, score), convert to (score, doc_id) to match LEANN format
|
||||
results = [(float(score), int(doc_id)) for doc_id, score in query_results]
|
||||
|
||||
return results, search_secs
|
||||
|
||||
|
||||
def _get_fast_plaid_image(index_path: str, doc_id: int) -> Optional[Image.Image]:
|
||||
"""
|
||||
Retrieve image for a document from Fast-Plaid index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the Fast-Plaid index
|
||||
doc_id: Document ID returned by Fast-Plaid search
|
||||
|
||||
Returns:
|
||||
PIL Image if found, None otherwise
|
||||
|
||||
Note: Uses metadata['index'] to get the actual file index, as Fast-Plaid
|
||||
doc_id may differ from the file naming index.
|
||||
"""
|
||||
# First get metadata to find the actual index used for file naming
|
||||
metadata = _get_fast_plaid_metadata(index_path, doc_id)
|
||||
if metadata is None:
|
||||
# Fallback: try using doc_id directly
|
||||
file_index = doc_id
|
||||
else:
|
||||
# Use the 'index' field from metadata, which matches the file naming
|
||||
file_index = metadata.get("index", doc_id)
|
||||
|
||||
images_dir = Path(index_path) / "images"
|
||||
image_path = images_dir / f"doc_{file_index}.png"
|
||||
|
||||
if image_path.exists():
|
||||
return Image.open(image_path)
|
||||
|
||||
# If not found with index, try doc_id as fallback
|
||||
if file_index != doc_id:
|
||||
fallback_path = images_dir / f"doc_{doc_id}.png"
|
||||
if fallback_path.exists():
|
||||
return Image.open(fallback_path)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_fast_plaid_metadata(index_path: str, doc_id: int) -> Optional[dict]:
|
||||
"""
|
||||
Retrieve metadata for a document from Fast-Plaid index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the Fast-Plaid index
|
||||
doc_id: Document ID
|
||||
|
||||
Returns:
|
||||
Dictionary with metadata if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
from fast_plaid import filtering
|
||||
|
||||
metadata_list = filtering.get(index=index_path, subset=[doc_id])
|
||||
if metadata_list and len(metadata_list) > 0:
|
||||
return metadata_list[0]
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _generate_similarity_map(
|
||||
model,
|
||||
processor,
|
||||
@@ -967,15 +678,11 @@ class LeannMultiVector:
|
||||
return (float(score), doc_id)
|
||||
|
||||
scores: list[tuple[float, int]] = []
|
||||
# load and core time
|
||||
start_time = time.time()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids]
|
||||
for fut in concurrent.futures.as_completed(futures):
|
||||
scores.append(fut.result())
|
||||
end_time = time.time()
|
||||
print(f"Number of candidate doc ids: {len(candidate_doc_ids)}")
|
||||
print(f"Time taken in load and core time: {end_time - start_time} seconds")
|
||||
|
||||
scores.sort(key=lambda x: x[0], reverse=True)
|
||||
return scores[:topk] if len(scores) >= topk else scores
|
||||
|
||||
@@ -1003,6 +710,7 @@ class LeannMultiVector:
|
||||
emb_path = self._embeddings_path()
|
||||
if not emb_path.exists():
|
||||
return self.search(data, topk)
|
||||
|
||||
all_embeddings = np.load(emb_path, mmap_mode="r")
|
||||
if all_embeddings.dtype != np.float32:
|
||||
all_embeddings = all_embeddings.astype(np.float32)
|
||||
@@ -1010,29 +718,23 @@ class LeannMultiVector:
|
||||
assert self._docid_to_indices is not None
|
||||
candidate_doc_ids = list(self._docid_to_indices.keys())
|
||||
|
||||
def _score_one(doc_id: int, _all_embeddings=all_embeddings) -> tuple[float, int]:
|
||||
def _score_one(doc_id: int) -> tuple[float, int]:
|
||||
token_indices = self._docid_to_indices.get(doc_id, [])
|
||||
if not token_indices:
|
||||
return (0.0, doc_id)
|
||||
doc_vecs = np.asarray(_all_embeddings[token_indices], dtype=np.float32)
|
||||
doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
|
||||
sim = np.dot(data, doc_vecs.T)
|
||||
sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
|
||||
score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
|
||||
return (float(score), doc_id)
|
||||
|
||||
scores: list[tuple[float, int]] = []
|
||||
# load and core time
|
||||
start_time = time.time()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
futures = [ex.submit(_score_one, d) for d in candidate_doc_ids]
|
||||
for fut in concurrent.futures.as_completed(futures):
|
||||
scores.append(fut.result())
|
||||
end_time = time.time()
|
||||
# print number of candidate doc ids
|
||||
print(f"Number of candidate doc ids: {len(candidate_doc_ids)}")
|
||||
print(f"Time taken in load and core time: {end_time - start_time} seconds")
|
||||
|
||||
scores.sort(key=lambda x: x[0], reverse=True)
|
||||
del all_embeddings
|
||||
return scores[:topk] if len(scores) >= topk else scores
|
||||
|
||||
def get_image(self, doc_id: int) -> Optional[Image.Image]:
|
||||
@@ -1076,259 +778,3 @@ class LeannMultiVector:
|
||||
"image_path": meta.get("image_path", ""),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
class ViDoReBenchmarkEvaluator:
|
||||
"""
|
||||
A reusable class for evaluating ViDoRe benchmarks (v1 and v2).
|
||||
This class encapsulates common functionality for building indexes, searching, and evaluating.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
use_fast_plaid: bool = False,
|
||||
top_k: int = 100,
|
||||
first_stage_k: int = 500,
|
||||
k_values: Optional[list[int]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the evaluator.
|
||||
|
||||
Args:
|
||||
model_name: Model name ("colqwen2" or "colpali")
|
||||
use_fast_plaid: Whether to use Fast-Plaid instead of LEANN
|
||||
top_k: Top-k results to retrieve
|
||||
first_stage_k: First stage k for LEANN search
|
||||
k_values: List of k values for evaluation metrics
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.use_fast_plaid = use_fast_plaid
|
||||
self.top_k = top_k
|
||||
self.first_stage_k = first_stage_k
|
||||
self.k_values = k_values if k_values is not None else [1, 3, 5, 10, 100]
|
||||
|
||||
# Load model once (can be reused across tasks)
|
||||
self._model = None
|
||||
self._processor = None
|
||||
self._model_name_actual = None
|
||||
|
||||
def _load_model_if_needed(self):
|
||||
"""Lazy load the model."""
|
||||
if self._model is None:
|
||||
print(f"\nLoading model: {self.model_name}")
|
||||
self._model_name_actual, self._model, self._processor, _, _, _ = _load_colvision(
|
||||
self.model_name
|
||||
)
|
||||
print(f"Model loaded: {self._model_name_actual}")
|
||||
|
||||
def build_index_from_corpus(
|
||||
self,
|
||||
corpus: dict[str, Image.Image],
|
||||
index_path: str,
|
||||
rebuild: bool = False,
|
||||
) -> tuple[Any, list[str]]:
|
||||
"""
|
||||
Build index from corpus images.
|
||||
|
||||
Args:
|
||||
corpus: dict mapping corpus_id to PIL Image
|
||||
index_path: Path to save/load the index
|
||||
rebuild: Whether to rebuild even if index exists
|
||||
|
||||
Returns:
|
||||
tuple: (retriever or fast_plaid_index object, list of corpus_ids in order)
|
||||
"""
|
||||
self._load_model_if_needed()
|
||||
|
||||
# Ensure consistent ordering
|
||||
corpus_ids = sorted(corpus.keys())
|
||||
images = [corpus[cid] for cid in corpus_ids]
|
||||
|
||||
if self.use_fast_plaid:
|
||||
# Check if Fast-Plaid index exists
|
||||
if not rebuild and _load_fast_plaid_index_if_exists(index_path) is not None:
|
||||
print(f"Fast-Plaid index already exists at {index_path}")
|
||||
return _load_fast_plaid_index_if_exists(index_path), corpus_ids
|
||||
|
||||
print(f"Building Fast-Plaid index at {index_path}...")
|
||||
print("Embedding images...")
|
||||
doc_vecs = _embed_images(self._model, self._processor, images)
|
||||
|
||||
fast_plaid_index, build_time = _build_fast_plaid_index(
|
||||
index_path, doc_vecs, corpus_ids, images
|
||||
)
|
||||
print(f"Fast-Plaid index built in {build_time:.2f}s")
|
||||
return fast_plaid_index, corpus_ids
|
||||
else:
|
||||
# Check if LEANN index exists
|
||||
if not rebuild:
|
||||
retriever = _load_retriever_if_index_exists(index_path)
|
||||
if retriever is not None:
|
||||
print(f"LEANN index already exists at {index_path}")
|
||||
return retriever, corpus_ids
|
||||
|
||||
print(f"Building LEANN index at {index_path}...")
|
||||
print("Embedding images...")
|
||||
doc_vecs = _embed_images(self._model, self._processor, images)
|
||||
|
||||
retriever = _build_index(index_path, doc_vecs, corpus_ids, images)
|
||||
print("LEANN index built")
|
||||
return retriever, corpus_ids
|
||||
|
||||
def search_queries(
|
||||
self,
|
||||
queries: dict[str, str],
|
||||
corpus_ids: list[str],
|
||||
index_or_retriever: Any,
|
||||
fast_plaid_index_path: Optional[str] = None,
|
||||
task_prompt: Optional[dict[str, str]] = None,
|
||||
) -> dict[str, dict[str, float]]:
|
||||
"""
|
||||
Search queries against the index.
|
||||
|
||||
Args:
|
||||
queries: dict mapping query_id to query text
|
||||
corpus_ids: list of corpus_ids in the same order as the index
|
||||
index_or_retriever: index or retriever object
|
||||
fast_plaid_index_path: path to Fast-Plaid index (for metadata)
|
||||
task_prompt: Optional dict with prompt for query (e.g., {"query": "..."})
|
||||
|
||||
Returns:
|
||||
results: dict mapping query_id to dict of {corpus_id: score}
|
||||
"""
|
||||
self._load_model_if_needed()
|
||||
|
||||
print(f"Searching {len(queries)} queries (top_k={self.top_k})...")
|
||||
|
||||
query_ids = list(queries.keys())
|
||||
query_texts = [queries[qid] for qid in query_ids]
|
||||
|
||||
# Note: ColPaliEngineWrapper does NOT use task prompt from metadata
|
||||
# It uses query_prefix + text + query_augmentation_token (handled in _embed_queries)
|
||||
# So we don't append task_prompt here to match MTEB behavior
|
||||
|
||||
# Embed queries
|
||||
print("Embedding queries...")
|
||||
query_vecs = _embed_queries(self._model, self._processor, query_texts)
|
||||
|
||||
results = {}
|
||||
|
||||
for query_id, query_vec in zip(tqdm(query_ids, desc="Searching"), query_vecs):
|
||||
if self.use_fast_plaid:
|
||||
# Fast-Plaid search
|
||||
search_results, _ = _search_fast_plaid(index_or_retriever, query_vec, self.top_k)
|
||||
query_results = {}
|
||||
for score, doc_id in search_results:
|
||||
if doc_id < len(corpus_ids):
|
||||
corpus_id = corpus_ids[doc_id]
|
||||
query_results[corpus_id] = float(score)
|
||||
else:
|
||||
# LEANN search
|
||||
import torch
|
||||
|
||||
query_np = (
|
||||
query_vec.float().numpy() if isinstance(query_vec, torch.Tensor) else query_vec
|
||||
)
|
||||
search_results = index_or_retriever.search_exact(query_np, topk=self.top_k)
|
||||
query_results = {}
|
||||
for score, doc_id in search_results:
|
||||
if doc_id < len(corpus_ids):
|
||||
corpus_id = corpus_ids[doc_id]
|
||||
query_results[corpus_id] = float(score)
|
||||
|
||||
results[query_id] = query_results
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def evaluate_results(
|
||||
results: dict[str, dict[str, float]],
|
||||
qrels: dict[str, dict[str, int]],
|
||||
k_values: Optional[list[int]] = None,
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Evaluate retrieval results using NDCG and other metrics.
|
||||
|
||||
Args:
|
||||
results: dict mapping query_id to dict of {corpus_id: score}
|
||||
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
|
||||
k_values: List of k values for evaluation metrics
|
||||
|
||||
Returns:
|
||||
Dictionary of metric scores
|
||||
"""
|
||||
try:
|
||||
from mteb._evaluators.retrieval_metrics import (
|
||||
calculate_retrieval_scores,
|
||||
make_score_dict,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"pytrec_eval is required for evaluation. Install with: pip install pytrec-eval"
|
||||
)
|
||||
|
||||
if k_values is None:
|
||||
k_values = [1, 3, 5, 10, 100]
|
||||
|
||||
# Check if we have any queries to evaluate
|
||||
if len(results) == 0:
|
||||
print("Warning: No queries to evaluate. Returning zero scores.")
|
||||
scores = {}
|
||||
for k in k_values:
|
||||
scores[f"ndcg_at_{k}"] = 0.0
|
||||
scores[f"map_at_{k}"] = 0.0
|
||||
scores[f"recall_at_{k}"] = 0.0
|
||||
scores[f"precision_at_{k}"] = 0.0
|
||||
scores[f"mrr_at_{k}"] = 0.0
|
||||
return scores
|
||||
|
||||
print(f"Evaluating results with k_values={k_values}...")
|
||||
print(f"Before filtering: {len(results)} results, {len(qrels)} qrels")
|
||||
|
||||
# Filter to ensure qrels and results have the same query set
|
||||
# This matches MTEB behavior: only evaluate queries that exist in both
|
||||
# pytrec_eval only evaluates queries in qrels, so we need to ensure
|
||||
# results contains all queries in qrels, and filter out queries not in qrels
|
||||
results_filtered = {qid: res for qid, res in results.items() if qid in qrels}
|
||||
qrels_filtered = {
|
||||
qid: rel_docs for qid, rel_docs in qrels.items() if qid in results_filtered
|
||||
}
|
||||
|
||||
print(f"After filtering: {len(results_filtered)} results, {len(qrels_filtered)} qrels")
|
||||
|
||||
if len(results_filtered) != len(qrels_filtered):
|
||||
print(
|
||||
f"Warning: Mismatch between results ({len(results_filtered)}) and qrels ({len(qrels_filtered)}) queries"
|
||||
)
|
||||
missing_in_results = set(qrels.keys()) - set(results.keys())
|
||||
if missing_in_results:
|
||||
print(f"Queries in qrels but not in results: {len(missing_in_results)} queries")
|
||||
print(f"First 5 missing queries: {list(missing_in_results)[:5]}")
|
||||
|
||||
# Convert qrels to pytrec_eval format
|
||||
qrels_pytrec = {}
|
||||
for qid, rel_docs in qrels_filtered.items():
|
||||
qrels_pytrec[qid] = dict(rel_docs.items())
|
||||
|
||||
# Evaluate
|
||||
eval_result = calculate_retrieval_scores(
|
||||
results=results_filtered,
|
||||
qrels=qrels_pytrec,
|
||||
k_values=k_values,
|
||||
)
|
||||
|
||||
# Format scores
|
||||
scores = make_score_dict(
|
||||
ndcg=eval_result.ndcg,
|
||||
_map=eval_result.map,
|
||||
recall=eval_result.recall,
|
||||
precision=eval_result.precision,
|
||||
mrr=eval_result.mrr,
|
||||
naucs=eval_result.naucs,
|
||||
naucs_mrr=eval_result.naucs_mrr,
|
||||
cv_recall=eval_result.cv_recall,
|
||||
task_scores={},
|
||||
)
|
||||
|
||||
return scores
|
||||
|
||||
@@ -1,19 +1,12 @@
|
||||
## Jupyter-style notebook script
|
||||
# %%
|
||||
# uv pip install matplotlib qwen_vl_utils
|
||||
import argparse
|
||||
import faulthandler
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
# Enable faulthandler to get stack trace on segfault
|
||||
faulthandler.enable()
|
||||
|
||||
|
||||
from leann_multi_vector import ( # utility functions/classes
|
||||
_ensure_repo_paths_importable,
|
||||
@@ -25,11 +18,6 @@ from leann_multi_vector import ( # utility functions/classes
|
||||
_build_index,
|
||||
_load_retriever_if_index_exists,
|
||||
_generate_similarity_map,
|
||||
_build_fast_plaid_index,
|
||||
_load_fast_plaid_index_if_exists,
|
||||
_search_fast_plaid,
|
||||
_get_fast_plaid_image,
|
||||
_get_fast_plaid_metadata,
|
||||
QwenVL,
|
||||
)
|
||||
|
||||
@@ -43,33 +31,8 @@ MODEL: str = "colqwen2" # "colpali" or "colqwen2"
|
||||
|
||||
# Data source: set to True to use the Hugging Face dataset example (recommended)
|
||||
USE_HF_DATASET: bool = True
|
||||
# Single dataset name (used when DATASET_NAMES is None)
|
||||
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
|
||||
# Multiple datasets to combine (if provided, DATASET_NAME is ignored)
|
||||
# Can be:
|
||||
# - List of strings: ["dataset1", "dataset2"]
|
||||
# - List of tuples: [("dataset1", "config1"), ("dataset2", None)] # None = no config needed
|
||||
# - Mixed: ["dataset1", ("dataset2", "config2")]
|
||||
#
|
||||
# Some potential datasets with images (may need IMAGE_FIELD_NAME adjustment):
|
||||
# - "weaviate/arXiv-AI-papers-multi-vector" (current, has "page_image" field)
|
||||
# - ("lmms-lab/DocVQA", "DocVQA") (has "image" field, document images, needs config)
|
||||
# - ("lmms-lab/DocVQA", "InfographicVQA") (has "image" field, infographic images)
|
||||
# - "pixparse/arxiv-papers" (if available, arXiv papers)
|
||||
# - "allenai/ai2d" (AI2D diagram dataset, has "image" field)
|
||||
# - "huggingface/document-images" (if available)
|
||||
# Note: Check dataset structure first - some may need IMAGE_FIELD_NAME specified
|
||||
# DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = None
|
||||
DATASET_NAMES = [
|
||||
"weaviate/arXiv-AI-papers-multi-vector",
|
||||
("lmms-lab/DocVQA", "DocVQA"), # Specify config name for datasets with multiple configs
|
||||
]
|
||||
# Load multiple splits to get more data (e.g., ["train", "test", "validation"])
|
||||
# Set to None to try loading all available splits automatically
|
||||
DATASET_SPLITS: Optional[list[str]] = ["train", "test"] # None = auto-detect all splits
|
||||
# Image field name in the dataset (auto-detect if None)
|
||||
# Common names: "page_image", "image", "images", "img"
|
||||
IMAGE_FIELD_NAME: Optional[str] = None # None = auto-detect
|
||||
DATASET_SPLIT: str = "train"
|
||||
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
|
||||
|
||||
# Local pages (used when USE_HF_DATASET == False)
|
||||
@@ -77,13 +40,10 @@ PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
|
||||
PAGES_DIR: str = "./pages"
|
||||
|
||||
# Index + retrieval settings
|
||||
# Use a different index path for larger dataset to avoid overwriting existing index
|
||||
INDEX_PATH: str = "./indexes/colvision_large.leann"
|
||||
# Fast-Plaid index settings (alternative to LEANN index)
|
||||
# These are now command-line arguments (see CLI overrides section)
|
||||
INDEX_PATH: str = "./indexes/colvision.leann"
|
||||
TOPK: int = 3
|
||||
FIRST_STAGE_K: int = 500
|
||||
REBUILD_INDEX: bool = True
|
||||
REBUILD_INDEX: bool = False
|
||||
|
||||
# Artifacts
|
||||
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
|
||||
@@ -94,310 +54,38 @@ ANSWER: bool = True
|
||||
MAX_NEW_TOKENS: int = 1024
|
||||
|
||||
|
||||
# %%
|
||||
# CLI overrides
|
||||
parser = argparse.ArgumentParser(description="Multi-vector LEANN similarity map demo")
|
||||
parser.add_argument(
|
||||
"--search-method",
|
||||
type=str,
|
||||
choices=["ann", "exact", "exact-all"],
|
||||
default="ann",
|
||||
help="Which search method to use: 'ann' (fast ANN), 'exact' (ANN + exact rerank), or 'exact-all' (exact over all docs).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=QUERY,
|
||||
help=f"Query string to search for. Default: '{QUERY}'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-fast-plaid",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Set to True to use fast-plaid instead of LEANN. Default: False",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-plaid-index-path",
|
||||
type=str,
|
||||
default="./indexes/colvision_fastplaid",
|
||||
help="Path to the Fast-Plaid index. Default: './indexes/colvision_fastplaid'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--topk",
|
||||
type=int,
|
||||
default=TOPK,
|
||||
help=f"Number of top results to retrieve. Default: {TOPK}",
|
||||
)
|
||||
cli_args, _unknown = parser.parse_known_args()
|
||||
SEARCH_METHOD: str = cli_args.search_method
|
||||
QUERY = cli_args.query # Override QUERY with CLI argument if provided
|
||||
USE_FAST_PLAID: bool = cli_args.use_fast_plaid
|
||||
FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
|
||||
TOPK: int = cli_args.topk # Override TOPK with CLI argument if provided
|
||||
|
||||
# %%
|
||||
|
||||
# Step 1: Check if we can skip data loading (index already exists)
|
||||
retriever: Optional[Any] = None
|
||||
fast_plaid_index: Optional[Any] = None
|
||||
need_to_build_index = REBUILD_INDEX
|
||||
|
||||
if USE_FAST_PLAID:
|
||||
# Fast-Plaid index handling
|
||||
if not REBUILD_INDEX:
|
||||
try:
|
||||
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
|
||||
if fast_plaid_index is not None:
|
||||
print(f"✓ Fast-Plaid index found at {FAST_PLAID_INDEX_PATH}")
|
||||
need_to_build_index = False
|
||||
else:
|
||||
print(f"Fast-Plaid index not found, will build new index")
|
||||
need_to_build_index = True
|
||||
except Exception as e:
|
||||
# If loading fails (e.g., memory error, corrupted index), rebuild
|
||||
print(f"Warning: Failed to load Fast-Plaid index: {e}")
|
||||
print("Will rebuild the index...")
|
||||
need_to_build_index = True
|
||||
fast_plaid_index = None
|
||||
if not REBUILD_INDEX:
|
||||
retriever = _load_retriever_if_index_exists(INDEX_PATH)
|
||||
if retriever is not None:
|
||||
print(f"✓ Index loaded from {INDEX_PATH}")
|
||||
print(f"✓ Images available at: {retriever._images_dir_path()}")
|
||||
need_to_build_index = False
|
||||
else:
|
||||
print(f"REBUILD_INDEX=True, will rebuild Fast-Plaid index")
|
||||
need_to_build_index = True
|
||||
else:
|
||||
# Original LEANN index handling
|
||||
if not REBUILD_INDEX:
|
||||
retriever = _load_retriever_if_index_exists(INDEX_PATH)
|
||||
if retriever is not None:
|
||||
print(f"✓ Index loaded from {INDEX_PATH}")
|
||||
print(f"✓ Images available at: {retriever._images_dir_path()}")
|
||||
need_to_build_index = False
|
||||
else:
|
||||
print(f"Index not found, will build new index")
|
||||
need_to_build_index = True
|
||||
else:
|
||||
print(f"REBUILD_INDEX=True, will rebuild index")
|
||||
print(f"Index not found, will build new index")
|
||||
need_to_build_index = True
|
||||
|
||||
# Step 2: Load data only if we need to build the index
|
||||
if need_to_build_index:
|
||||
print("Loading dataset...")
|
||||
if USE_HF_DATASET:
|
||||
from datasets import load_dataset, concatenate_datasets, DatasetDict
|
||||
from datasets import load_dataset
|
||||
|
||||
# Determine which datasets to load
|
||||
if DATASET_NAMES is not None:
|
||||
dataset_names_to_load = DATASET_NAMES
|
||||
print(f"Loading {len(dataset_names_to_load)} datasets: {dataset_names_to_load}")
|
||||
else:
|
||||
dataset_names_to_load = [DATASET_NAME]
|
||||
print(f"Loading single dataset: {DATASET_NAME}")
|
||||
|
||||
# Load and combine datasets
|
||||
all_datasets_to_concat = []
|
||||
|
||||
for dataset_entry in dataset_names_to_load:
|
||||
# Handle both string and tuple formats
|
||||
if isinstance(dataset_entry, tuple):
|
||||
dataset_name, config_name = dataset_entry
|
||||
else:
|
||||
dataset_name = dataset_entry
|
||||
config_name = None
|
||||
|
||||
print(f"\nProcessing dataset: {dataset_name}" + (f" (config: {config_name})" if config_name else ""))
|
||||
|
||||
# Load dataset to check available splits
|
||||
# If config_name is provided, use it; otherwise try without config
|
||||
try:
|
||||
if config_name:
|
||||
dataset_dict = load_dataset(dataset_name, config_name)
|
||||
else:
|
||||
dataset_dict = load_dataset(dataset_name)
|
||||
except ValueError as e:
|
||||
if "Config name is missing" in str(e):
|
||||
# Try to get available configs and suggest
|
||||
from datasets import get_dataset_config_names
|
||||
try:
|
||||
available_configs = get_dataset_config_names(dataset_name)
|
||||
raise ValueError(
|
||||
f"Dataset '{dataset_name}' requires a config name. "
|
||||
f"Available configs: {available_configs}. "
|
||||
f"Please specify as: ('{dataset_name}', 'config_name')"
|
||||
) from e
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Dataset '{dataset_name}' requires a config name. "
|
||||
f"Please specify as: ('{dataset_name}', 'config_name')"
|
||||
) from e
|
||||
raise
|
||||
|
||||
# Determine which splits to load
|
||||
if DATASET_SPLITS is None:
|
||||
# Auto-detect: try to load all available splits
|
||||
available_splits = list(dataset_dict.keys())
|
||||
print(f" Auto-detected splits: {available_splits}")
|
||||
splits_to_load = available_splits
|
||||
else:
|
||||
splits_to_load = DATASET_SPLITS
|
||||
|
||||
# Load and concatenate multiple splits for this dataset
|
||||
datasets_to_concat = []
|
||||
for split in splits_to_load:
|
||||
if split not in dataset_dict:
|
||||
print(f" Warning: Split '{split}' not found in dataset. Available splits: {list(dataset_dict.keys())}")
|
||||
continue
|
||||
split_dataset = dataset_dict[split]
|
||||
print(f" Loaded split '{split}': {len(split_dataset)} pages")
|
||||
datasets_to_concat.append(split_dataset)
|
||||
|
||||
if not datasets_to_concat:
|
||||
print(f" Warning: No valid splits found for {dataset_name}. Skipping.")
|
||||
continue
|
||||
|
||||
# Concatenate splits for this dataset
|
||||
if len(datasets_to_concat) > 1:
|
||||
combined_dataset = concatenate_datasets(datasets_to_concat)
|
||||
print(f" Concatenated {len(datasets_to_concat)} splits into {len(combined_dataset)} pages")
|
||||
else:
|
||||
combined_dataset = datasets_to_concat[0]
|
||||
|
||||
all_datasets_to_concat.append(combined_dataset)
|
||||
|
||||
if not all_datasets_to_concat:
|
||||
raise RuntimeError("No valid datasets or splits found.")
|
||||
|
||||
# Concatenate all datasets
|
||||
if len(all_datasets_to_concat) > 1:
|
||||
dataset = concatenate_datasets(all_datasets_to_concat)
|
||||
print(f"\nConcatenated {len(all_datasets_to_concat)} datasets into {len(dataset)} total pages")
|
||||
else:
|
||||
dataset = all_datasets_to_concat[0]
|
||||
|
||||
# Apply MAX_DOCS limit if specified
|
||||
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
|
||||
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
|
||||
if N < len(dataset):
|
||||
print(f"Limiting to {N} pages (from {len(dataset)} total)")
|
||||
dataset = dataset.select(range(N))
|
||||
|
||||
# Auto-detect image field name if not specified
|
||||
if IMAGE_FIELD_NAME is None:
|
||||
# Check multiple samples to find the most common image field
|
||||
# (useful when datasets are merged and may have different field names)
|
||||
possible_image_fields = ["page_image", "image", "images", "img", "page", "document_image"]
|
||||
field_counts = {}
|
||||
|
||||
# Check first few samples to find image fields
|
||||
num_samples_to_check = min(10, len(dataset))
|
||||
for sample_idx in range(num_samples_to_check):
|
||||
sample = dataset[sample_idx]
|
||||
for field in possible_image_fields:
|
||||
if field in sample and sample[field] is not None:
|
||||
value = sample[field]
|
||||
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
|
||||
field_counts[field] = field_counts.get(field, 0) + 1
|
||||
|
||||
# Choose the most common field, or first found if tied
|
||||
if field_counts:
|
||||
image_field = max(field_counts.items(), key=lambda x: x[1])[0]
|
||||
print(f"Auto-detected image field: '{image_field}' (found in {field_counts[image_field]}/{num_samples_to_check} samples)")
|
||||
else:
|
||||
# Fallback: check first sample only
|
||||
sample = dataset[0]
|
||||
image_field = None
|
||||
for field in possible_image_fields:
|
||||
if field in sample:
|
||||
value = sample[field]
|
||||
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
|
||||
image_field = field
|
||||
break
|
||||
if image_field is None:
|
||||
raise RuntimeError(
|
||||
f"Could not auto-detect image field. Available fields: {list(sample.keys())}. "
|
||||
f"Please specify IMAGE_FIELD_NAME manually."
|
||||
)
|
||||
print(f"Auto-detected image field: '{image_field}'")
|
||||
else:
|
||||
image_field = IMAGE_FIELD_NAME
|
||||
if image_field not in dataset[0]:
|
||||
raise RuntimeError(
|
||||
f"Image field '{image_field}' not found. Available fields: {list(dataset[0].keys())}"
|
||||
)
|
||||
|
||||
filepaths: list[str] = []
|
||||
images: list[Image.Image] = []
|
||||
for i in tqdm(range(len(dataset)), desc="Loading dataset", total=len(dataset)):
|
||||
for i in tqdm(range(N), desc="Loading dataset", total=N):
|
||||
p = dataset[i]
|
||||
# Try to compose a descriptive identifier
|
||||
# Handle different dataset structures
|
||||
identifier_parts = []
|
||||
|
||||
# Helper function to safely get field value
|
||||
def safe_get(field_name, default=None):
|
||||
if field_name in p and p[field_name] is not None:
|
||||
return p[field_name]
|
||||
return default
|
||||
|
||||
# Try to get various identifier fields
|
||||
if safe_get("paper_arxiv_id"):
|
||||
identifier_parts.append(f"arXiv:{p['paper_arxiv_id']}")
|
||||
if safe_get("paper_title"):
|
||||
identifier_parts.append(f"title:{p['paper_title']}")
|
||||
if safe_get("page_number") is not None:
|
||||
try:
|
||||
identifier_parts.append(f"page:{int(p['page_number'])}")
|
||||
except (ValueError, TypeError):
|
||||
# If conversion fails, use the raw value or skip
|
||||
if p['page_number']:
|
||||
identifier_parts.append(f"page:{p['page_number']}")
|
||||
if safe_get("page_id"):
|
||||
identifier_parts.append(f"id:{p['page_id']}")
|
||||
elif safe_get("questionId"):
|
||||
identifier_parts.append(f"qid:{p['questionId']}")
|
||||
elif safe_get("docId"):
|
||||
identifier_parts.append(f"docId:{p['docId']}")
|
||||
elif safe_get("id"):
|
||||
identifier_parts.append(f"id:{p['id']}")
|
||||
|
||||
# If no identifier parts found, create one from index
|
||||
if identifier_parts:
|
||||
identifier = "|".join(identifier_parts)
|
||||
else:
|
||||
# Create identifier from available fields or index
|
||||
fallback_parts = []
|
||||
# Try common fields that might exist
|
||||
for field in ["ucsf_document_id", "docId", "questionId", "id"]:
|
||||
if safe_get(field):
|
||||
fallback_parts.append(f"{field}:{p[field]}")
|
||||
break
|
||||
if fallback_parts:
|
||||
identifier = "|".join(fallback_parts) + f"|idx:{i}"
|
||||
else:
|
||||
identifier = f"doc_{i}"
|
||||
|
||||
# Compose a descriptive identifier for printing later
|
||||
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
|
||||
filepaths.append(identifier)
|
||||
|
||||
# Get image - try detected field first, then fallback to other common fields
|
||||
img = None
|
||||
if image_field in p and p[image_field] is not None:
|
||||
img = p[image_field]
|
||||
else:
|
||||
# Fallback: try other common image field names
|
||||
for fallback_field in ["image", "page_image", "images", "img"]:
|
||||
if fallback_field in p and p[fallback_field] is not None:
|
||||
img = p[fallback_field]
|
||||
break
|
||||
|
||||
if img is None:
|
||||
raise RuntimeError(
|
||||
f"No image found for sample {i}. Available fields: {list(p.keys())}. "
|
||||
f"Expected field: {image_field}"
|
||||
)
|
||||
|
||||
# Ensure it's a PIL Image
|
||||
if not isinstance(img, Image.Image):
|
||||
if hasattr(img, 'convert'):
|
||||
img = img.convert('RGB')
|
||||
else:
|
||||
img = Image.fromarray(img) if hasattr(img, '__array__') else Image.open(img)
|
||||
images.append(img)
|
||||
images.append(p["page_image"]) # PIL Image
|
||||
else:
|
||||
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
|
||||
filepaths, images = _load_images_from_dir(PAGES_DIR)
|
||||
@@ -406,19 +94,6 @@ if need_to_build_index:
|
||||
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
|
||||
)
|
||||
print(f"Loaded {len(images)} images")
|
||||
|
||||
# Memory check before loading model
|
||||
try:
|
||||
import psutil
|
||||
import torch
|
||||
process = psutil.Process(os.getpid())
|
||||
mem_info = process.memory_info()
|
||||
print(f"Memory usage after loading images: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
|
||||
if torch.cuda.is_available():
|
||||
print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
|
||||
print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
print("Skipping dataset loading (using existing index)")
|
||||
filepaths = [] # Not needed when using existing index
|
||||
@@ -427,181 +102,46 @@ else:
|
||||
|
||||
# %%
|
||||
# Step 3: Load model and processor (only if we need to build index or perform search)
|
||||
print("Step 3: Loading model and processor...")
|
||||
print(f" Model: {MODEL}")
|
||||
try:
|
||||
import sys
|
||||
print(f" Python version: {sys.version}")
|
||||
print(f" Python executable: {sys.executable}")
|
||||
|
||||
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
|
||||
print(f"✓ Using model={model_name}, device={device_str}, dtype={dtype}")
|
||||
|
||||
# Memory check after loading model
|
||||
try:
|
||||
import psutil
|
||||
import torch
|
||||
process = psutil.Process(os.getpid())
|
||||
mem_info = process.memory_info()
|
||||
print(f" Memory usage after loading model: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
|
||||
if torch.cuda.is_available():
|
||||
print(f" GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
|
||||
print(f" GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"✗ Error loading model: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
|
||||
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
# %%
|
||||
# Step 4: Build index if needed
|
||||
if need_to_build_index:
|
||||
print("Step 4: Building index...")
|
||||
print(f" Number of images: {len(images)}")
|
||||
print(f" Number of filepaths: {len(filepaths)}")
|
||||
if need_to_build_index and retriever is None:
|
||||
print("Building index...")
|
||||
doc_vecs = _embed_images(model, processor, images)
|
||||
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
|
||||
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
|
||||
# Clear memory
|
||||
del images, filepaths, doc_vecs
|
||||
|
||||
try:
|
||||
print(" Embedding images...")
|
||||
doc_vecs = _embed_images(model, processor, images)
|
||||
print(f" Embedded {len(doc_vecs)} documents")
|
||||
print(f" First doc vec shape: {doc_vecs[0].shape if len(doc_vecs) > 0 else 'N/A'}")
|
||||
except Exception as e:
|
||||
print(f"Error embedding images: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
if USE_FAST_PLAID:
|
||||
# Build Fast-Plaid index
|
||||
print(" Building Fast-Plaid index...")
|
||||
try:
|
||||
fast_plaid_index, build_secs = _build_fast_plaid_index(
|
||||
FAST_PLAID_INDEX_PATH, doc_vecs, filepaths, images
|
||||
)
|
||||
from pathlib import Path
|
||||
print(f"✓ Fast-Plaid index built in {build_secs:.3f}s")
|
||||
print(f"✓ Index saved to: {FAST_PLAID_INDEX_PATH}")
|
||||
print(f"✓ Images saved to: {Path(FAST_PLAID_INDEX_PATH) / 'images'}")
|
||||
except Exception as e:
|
||||
print(f"Error building Fast-Plaid index: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
# Clear memory
|
||||
print(" Clearing memory...")
|
||||
del images, filepaths, doc_vecs
|
||||
else:
|
||||
# Build original LEANN index
|
||||
try:
|
||||
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
|
||||
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
|
||||
except Exception as e:
|
||||
print(f"Error building LEANN index: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
# Clear memory
|
||||
print(" Clearing memory...")
|
||||
del images, filepaths, doc_vecs
|
||||
|
||||
# Note: Images are now stored separately, retriever/fast_plaid_index will reference them
|
||||
# Note: Images are now stored in the index, retriever will load them on-demand from disk
|
||||
|
||||
|
||||
# %%
|
||||
# Step 5: Embed query and search
|
||||
_t0 = time.perf_counter()
|
||||
q_vec = _embed_queries(model, processor, [QUERY])[0]
|
||||
query_embed_secs = time.perf_counter() - _t0
|
||||
|
||||
print(f"[Search] Method: {SEARCH_METHOD}")
|
||||
print(f"[Timing] Query embedding: {query_embed_secs:.3f}s")
|
||||
|
||||
# Run the selected search method and time it
|
||||
if USE_FAST_PLAID:
|
||||
# Fast-Plaid search
|
||||
if fast_plaid_index is None:
|
||||
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
|
||||
if fast_plaid_index is None:
|
||||
raise RuntimeError(f"Fast-Plaid index not found at {FAST_PLAID_INDEX_PATH}")
|
||||
|
||||
results, search_secs = _search_fast_plaid(fast_plaid_index, q_vec, TOPK)
|
||||
print(f"[Timing] Fast-Plaid Search: {search_secs:.3f}s")
|
||||
else:
|
||||
# Original LEANN search
|
||||
query_np = q_vec.float().numpy()
|
||||
|
||||
if SEARCH_METHOD == "ann":
|
||||
results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (ANN): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
|
||||
elif SEARCH_METHOD == "exact":
|
||||
results = retriever.search_exact(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (Exact rerank): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
|
||||
elif SEARCH_METHOD == "exact-all":
|
||||
results = retriever.search_exact_all(query_np, topk=TOPK)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (Exact all): {search_secs:.3f}s")
|
||||
else:
|
||||
results = []
|
||||
results = retriever.search(q_vec.float().numpy(), topk=TOPK)
|
||||
if not results:
|
||||
print("No results found.")
|
||||
else:
|
||||
print(f'Top {len(results)} results for query: "{QUERY}"')
|
||||
print("\n[DEBUG] Retrieval details:")
|
||||
top_images: list[Image.Image] = []
|
||||
image_hashes = {} # Track image hashes to detect duplicates
|
||||
|
||||
for rank, (score, doc_id) in enumerate(results, start=1):
|
||||
# Retrieve image and metadata based on index type
|
||||
if USE_FAST_PLAID:
|
||||
# Fast-Plaid: load image and get metadata
|
||||
image = _get_fast_plaid_image(FAST_PLAID_INDEX_PATH, doc_id)
|
||||
if image is None:
|
||||
print(f"Warning: Could not find image for doc_id {doc_id}")
|
||||
continue
|
||||
# Retrieve image from index instead of memory
|
||||
image = retriever.get_image(doc_id)
|
||||
if image is None:
|
||||
print(f"Warning: Could not retrieve image for doc_id {doc_id}")
|
||||
continue
|
||||
|
||||
metadata = _get_fast_plaid_metadata(FAST_PLAID_INDEX_PATH, doc_id)
|
||||
path = metadata.get("filepath", f"doc_{doc_id}") if metadata else f"doc_{doc_id}"
|
||||
top_images.append(image)
|
||||
else:
|
||||
# Original LEANN: retrieve from retriever
|
||||
image = retriever.get_image(doc_id)
|
||||
if image is None:
|
||||
print(f"Warning: Could not retrieve image for doc_id {doc_id}")
|
||||
continue
|
||||
|
||||
metadata = retriever.get_metadata(doc_id)
|
||||
path = metadata.get("filepath", "unknown") if metadata else "unknown"
|
||||
top_images.append(image)
|
||||
|
||||
# Calculate image hash to detect duplicates
|
||||
import hashlib
|
||||
import io
|
||||
# Convert image to bytes for hashing
|
||||
img_bytes = io.BytesIO()
|
||||
image.save(img_bytes, format='PNG')
|
||||
image_bytes = img_bytes.getvalue()
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()[:8]
|
||||
|
||||
# Check if this image was already seen
|
||||
duplicate_info = ""
|
||||
if image_hash in image_hashes:
|
||||
duplicate_info = f" [DUPLICATE of rank {image_hashes[image_hash]}]"
|
||||
else:
|
||||
image_hashes[image_hash] = rank
|
||||
|
||||
# Print detailed information
|
||||
print(f"{rank}) doc_id={doc_id}, MaxSim={score:.4f}, Page={path}, ImageHash={image_hash}{duplicate_info}")
|
||||
if metadata:
|
||||
print(f" Metadata: {metadata}")
|
||||
metadata = retriever.get_metadata(doc_id)
|
||||
path = metadata.get("filepath", "unknown") if metadata else "unknown"
|
||||
# For HF dataset, path is a descriptive identifier, not a real file path
|
||||
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
|
||||
top_images.append(image)
|
||||
|
||||
if SAVE_TOP_IMAGE:
|
||||
from pathlib import Path as _Path
|
||||
@@ -664,9 +204,6 @@ if results and SIMILARITY_MAP:
|
||||
# Step 7: Optional answer generation
|
||||
if results and ANSWER:
|
||||
qwen = QwenVL(device=device_str)
|
||||
_t0 = time.perf_counter()
|
||||
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
|
||||
gen_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Generation: {gen_secs:.3f}s")
|
||||
print("\nAnswer:")
|
||||
print(response)
|
||||
|
||||
@@ -1,399 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Modular script to reproduce NDCG results for ViDoRe v1 benchmark.
|
||||
|
||||
This script uses the interface from leann_multi_vector.py to:
|
||||
1. Download ViDoRe v1 datasets
|
||||
2. Build indexes (LEANN or Fast-Plaid)
|
||||
3. Perform retrieval
|
||||
4. Evaluate using NDCG metrics
|
||||
|
||||
Usage:
|
||||
# Evaluate all ViDoRe v1 tasks
|
||||
python vidore_v1_benchmark.py --model colqwen2 --tasks all
|
||||
|
||||
# Evaluate specific task
|
||||
python vidore_v1_benchmark.py --model colqwen2 --task VidoreArxivQARetrieval
|
||||
|
||||
# Use Fast-Plaid index
|
||||
python vidore_v1_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
|
||||
|
||||
# Rebuild index
|
||||
python vidore_v1_benchmark.py --model colqwen2 --rebuild-index
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from leann_multi_vector import (
|
||||
ViDoReBenchmarkEvaluator,
|
||||
_ensure_repo_paths_importable,
|
||||
)
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# ViDoRe v1 task configurations
|
||||
# Prompts match MTEB task metadata prompts
|
||||
VIDORE_V1_TASKS = {
|
||||
"VidoreArxivQARetrieval": {
|
||||
"dataset_path": "vidore/arxivqa_test_subsampled_beir",
|
||||
"revision": "7d94d570960eac2408d3baa7a33f9de4822ae3e4",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreDocVQARetrieval": {
|
||||
"dataset_path": "vidore/docvqa_test_subsampled_beir",
|
||||
"revision": "162ba2fc1a8437eda8b6c37b240bc1c0f0deb092",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreInfoVQARetrieval": {
|
||||
"dataset_path": "vidore/infovqa_test_subsampled_beir",
|
||||
"revision": "b802cc5fd6c605df2d673a963667d74881d2c9a4",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreTabfquadRetrieval": {
|
||||
"dataset_path": "vidore/tabfquad_test_subsampled_beir",
|
||||
"revision": "61a2224bcd29b7b261a4892ff4c8bea353527a31",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreTatdqaRetrieval": {
|
||||
"dataset_path": "vidore/tatdqa_test_beir",
|
||||
"revision": "5feb5630fdff4d8d189ffedb2dba56862fdd45c0",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreShiftProjectRetrieval": {
|
||||
"dataset_path": "vidore/shiftproject_test_beir",
|
||||
"revision": "84a382e05c4473fed9cff2bbae95fe2379416117",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAAIRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_artificial_intelligence_test_beir",
|
||||
"revision": "2d9ebea5a1c6e9ef4a3b902a612f605dca11261c",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAEnergyRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_energy_test_beir",
|
||||
"revision": "9935aadbad5c8deec30910489db1b2c7133ae7a7",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAGovernmentReportsRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_government_reports_test_beir",
|
||||
"revision": "b4909afa930f81282fd20601e860668073ad02aa",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAHealthcareIndustryRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_healthcare_industry_test_beir",
|
||||
"revision": "f9e25d5b6e13e1ad9f5c3cce202565031b3ab164",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def load_vidore_v1_data(
|
||||
dataset_path: str,
|
||||
revision: Optional[str] = None,
|
||||
split: str = "test",
|
||||
):
|
||||
"""
|
||||
Load ViDoRe v1 dataset.
|
||||
|
||||
Returns:
|
||||
corpus: dict mapping corpus_id to PIL Image
|
||||
queries: dict mapping query_id to query text
|
||||
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
|
||||
"""
|
||||
print(f"Loading dataset: {dataset_path} (split={split})")
|
||||
|
||||
# Load queries
|
||||
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
|
||||
|
||||
queries = {}
|
||||
for row in query_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
queries[query_id] = row["query"]
|
||||
|
||||
# Load corpus (images)
|
||||
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
|
||||
|
||||
corpus = {}
|
||||
for row in corpus_ds:
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
# Extract image from the dataset row
|
||||
if "image" in row:
|
||||
corpus[corpus_id] = row["image"]
|
||||
elif "page_image" in row:
|
||||
corpus[corpus_id] = row["page_image"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No image field found in corpus. Available fields: {list(row.keys())}"
|
||||
)
|
||||
|
||||
# Load qrels (relevance judgments)
|
||||
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
|
||||
|
||||
qrels = {}
|
||||
for row in qrels_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
if query_id not in qrels:
|
||||
qrels[query_id] = {}
|
||||
qrels[query_id][corpus_id] = int(row["score"])
|
||||
|
||||
print(
|
||||
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||
)
|
||||
|
||||
# Filter qrels to only include queries that exist
|
||||
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
|
||||
|
||||
# Filter out queries without any relevant documents (matching MTEB behavior)
|
||||
# This is important for correct NDCG calculation
|
||||
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
|
||||
queries_filtered = {
|
||||
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
|
||||
}
|
||||
|
||||
print(
|
||||
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
|
||||
)
|
||||
|
||||
return corpus, queries_filtered, qrels_filtered
|
||||
|
||||
|
||||
def evaluate_task(
|
||||
task_name: str,
|
||||
model_name: str,
|
||||
index_path: str,
|
||||
use_fast_plaid: bool = False,
|
||||
fast_plaid_index_path: Optional[str] = None,
|
||||
rebuild_index: bool = False,
|
||||
top_k: int = 1000,
|
||||
first_stage_k: int = 500,
|
||||
k_values: Optional[list[int]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Evaluate a single ViDoRe v1 task.
|
||||
"""
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Evaluating task: {task_name}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Get task config
|
||||
if task_name not in VIDORE_V1_TASKS:
|
||||
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
|
||||
|
||||
task_config = VIDORE_V1_TASKS[task_name]
|
||||
dataset_path = task_config["dataset_path"]
|
||||
revision = task_config["revision"]
|
||||
|
||||
# Load data
|
||||
corpus, queries, qrels = load_vidore_v1_data(
|
||||
dataset_path=dataset_path,
|
||||
revision=revision,
|
||||
split="test",
|
||||
)
|
||||
|
||||
# Initialize k_values if not provided
|
||||
if k_values is None:
|
||||
k_values = [1, 3, 5, 10, 20, 100, 1000]
|
||||
|
||||
# Check if we have any queries
|
||||
if len(queries) == 0:
|
||||
print(f"\nWarning: No queries found for task {task_name}. Skipping evaluation.")
|
||||
# Return zero scores
|
||||
scores = {}
|
||||
for k in k_values:
|
||||
scores[f"ndcg_at_{k}"] = 0.0
|
||||
scores[f"map_at_{k}"] = 0.0
|
||||
scores[f"recall_at_{k}"] = 0.0
|
||||
scores[f"precision_at_{k}"] = 0.0
|
||||
scores[f"mrr_at_{k}"] = 0.0
|
||||
return scores
|
||||
|
||||
# Initialize evaluator
|
||||
evaluator = ViDoReBenchmarkEvaluator(
|
||||
model_name=model_name,
|
||||
use_fast_plaid=use_fast_plaid,
|
||||
top_k=top_k,
|
||||
first_stage_k=first_stage_k,
|
||||
k_values=k_values,
|
||||
)
|
||||
|
||||
# Build or load index
|
||||
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
||||
if index_path_full is None:
|
||||
index_path_full = f"./indexes/{task_name}_{model_name}"
|
||||
if use_fast_plaid:
|
||||
index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
|
||||
|
||||
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||
corpus=corpus,
|
||||
index_path=index_path_full,
|
||||
rebuild=rebuild_index,
|
||||
)
|
||||
|
||||
# Search queries
|
||||
task_prompt = task_config.get("prompt")
|
||||
results = evaluator.search_queries(
|
||||
queries=queries,
|
||||
corpus_ids=corpus_ids_ordered,
|
||||
index_or_retriever=index_or_retriever,
|
||||
fast_plaid_index_path=fast_plaid_index_path,
|
||||
task_prompt=task_prompt,
|
||||
)
|
||||
|
||||
# Evaluate
|
||||
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||
|
||||
# Print results
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Results for {task_name}:")
|
||||
print(f"{'=' * 80}")
|
||||
for metric, value in scores.items():
|
||||
if isinstance(value, (int, float)):
|
||||
print(f" {metric}: {value:.5f}")
|
||||
|
||||
# Save results
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
results_file = os.path.join(output_dir, f"{task_name}_results.json")
|
||||
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
|
||||
|
||||
with open(results_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print(f"\nSaved results to: {results_file}")
|
||||
|
||||
with open(scores_file, "w") as f:
|
||||
json.dump(scores, f, indent=2)
|
||||
print(f"Saved scores to: {scores_file}")
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Evaluate ViDoRe v1 benchmark using LEANN/Fast-Plaid indexing"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="colqwen2",
|
||||
choices=["colqwen2", "colpali"],
|
||||
help="Model to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specific task to evaluate (or 'all' for all tasks)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tasks",
|
||||
type=str,
|
||||
default="all",
|
||||
help="Tasks to evaluate: 'all' or comma-separated list",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to LEANN index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-fast-plaid",
|
||||
action="store_true",
|
||||
help="Use Fast-Plaid instead of LEANN",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-plaid-index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Fast-Plaid index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rebuild-index",
|
||||
action="store_true",
|
||||
help="Rebuild index even if it exists",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Top-k results to retrieve (MTEB default is max(k_values)=1000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first-stage-k",
|
||||
type=int,
|
||||
default=500,
|
||||
help="First stage k for LEANN search",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--k-values",
|
||||
type=str,
|
||||
default="1,3,5,10,20,100,1000",
|
||||
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="./vidore_v1_results",
|
||||
help="Output directory for results",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse k_values
|
||||
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
||||
|
||||
# Determine tasks to evaluate
|
||||
if args.task:
|
||||
tasks_to_eval = [args.task]
|
||||
elif args.tasks.lower() == "all":
|
||||
tasks_to_eval = list(VIDORE_V1_TASKS.keys())
|
||||
else:
|
||||
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
|
||||
|
||||
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||
|
||||
# Evaluate each task
|
||||
all_scores = {}
|
||||
for task_name in tasks_to_eval:
|
||||
try:
|
||||
scores = evaluate_task(
|
||||
task_name=task_name,
|
||||
model_name=args.model,
|
||||
index_path=args.index_path,
|
||||
use_fast_plaid=args.use_fast_plaid,
|
||||
fast_plaid_index_path=args.fast_plaid_index_path,
|
||||
rebuild_index=args.rebuild_index,
|
||||
top_k=args.top_k,
|
||||
first_stage_k=args.first_stage_k,
|
||||
k_values=k_values,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
all_scores[task_name] = scores
|
||||
except Exception as e:
|
||||
print(f"\nError evaluating {task_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# Print summary
|
||||
if all_scores:
|
||||
print(f"\n{'=' * 80}")
|
||||
print("SUMMARY")
|
||||
print(f"{'=' * 80}")
|
||||
for task_name, scores in all_scores.items():
|
||||
print(f"\n{task_name}:")
|
||||
# Print main metrics
|
||||
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
|
||||
if metric in scores:
|
||||
print(f" {metric}: {scores[metric]:.5f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,439 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Modular script to reproduce NDCG results for ViDoRe v2 benchmark.
|
||||
|
||||
This script uses the interface from leann_multi_vector.py to:
|
||||
1. Download ViDoRe v2 datasets
|
||||
2. Build indexes (LEANN or Fast-Plaid)
|
||||
3. Perform retrieval
|
||||
4. Evaluate using NDCG metrics
|
||||
|
||||
Usage:
|
||||
# Evaluate all ViDoRe v2 tasks
|
||||
python vidore_v2_benchmark.py --model colqwen2 --tasks all
|
||||
|
||||
# Evaluate specific task
|
||||
python vidore_v2_benchmark.py --model colqwen2 --task Vidore2ESGReportsRetrieval
|
||||
|
||||
# Use Fast-Plaid index
|
||||
python vidore_v2_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
|
||||
|
||||
# Rebuild index
|
||||
python vidore_v2_benchmark.py --model colqwen2 --rebuild-index
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from leann_multi_vector import (
|
||||
ViDoReBenchmarkEvaluator,
|
||||
_ensure_repo_paths_importable,
|
||||
)
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# Language name to dataset language field value mapping
|
||||
# Dataset uses ISO 639-3 + ISO 15924 format (e.g., "eng-Latn")
|
||||
LANGUAGE_MAPPING = {
|
||||
"english": "eng-Latn",
|
||||
"french": "fra-Latn",
|
||||
"spanish": "spa-Latn",
|
||||
"german": "deu-Latn",
|
||||
}
|
||||
|
||||
# ViDoRe v2 task configurations
|
||||
# Prompts match MTEB task metadata prompts
|
||||
VIDORE_V2_TASKS = {
|
||||
"Vidore2ESGReportsRetrieval": {
|
||||
"dataset_path": "vidore/esg_reports_v2",
|
||||
"revision": "0542c0d03da0ec1c8cbc517c8d78e7e95c75d3d3",
|
||||
"languages": ["french", "spanish", "english", "german"],
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"Vidore2EconomicsReportsRetrieval": {
|
||||
"dataset_path": "vidore/economics_reports_v2",
|
||||
"revision": "b3e3a04b07fbbaffe79be49dabf92f691fbca252",
|
||||
"languages": ["french", "spanish", "english", "german"],
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"Vidore2BioMedicalLecturesRetrieval": {
|
||||
"dataset_path": "vidore/biomedical_lectures_v2",
|
||||
"revision": "a29202f0da409034d651614d87cd8938d254e2ea",
|
||||
"languages": ["french", "spanish", "english", "german"],
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"Vidore2ESGReportsHLRetrieval": {
|
||||
"dataset_path": "vidore/esg_reports_human_labeled_v2",
|
||||
"revision": "6d467dedb09a75144ede1421747e47cf036857dd",
|
||||
# Note: This dataset doesn't have language filtering - all queries are English
|
||||
"languages": None, # No language filtering needed
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def load_vidore_v2_data(
|
||||
dataset_path: str,
|
||||
revision: Optional[str] = None,
|
||||
split: str = "test",
|
||||
language: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Load ViDoRe v2 dataset.
|
||||
|
||||
Returns:
|
||||
corpus: dict mapping corpus_id to PIL Image
|
||||
queries: dict mapping query_id to query text
|
||||
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
|
||||
"""
|
||||
print(f"Loading dataset: {dataset_path} (split={split}, language={language})")
|
||||
|
||||
# Load queries
|
||||
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
|
||||
|
||||
# Check if dataset has language field before filtering
|
||||
has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
|
||||
|
||||
if language and has_language_field:
|
||||
# Map language name to dataset language field value (e.g., "english" -> "eng-Latn")
|
||||
dataset_language = LANGUAGE_MAPPING.get(language, language)
|
||||
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language)
|
||||
# Check if filtering resulted in empty dataset
|
||||
if len(query_ds_filtered) == 0:
|
||||
print(
|
||||
f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}')."
|
||||
)
|
||||
# Try with original language value (dataset might use simple names like 'english')
|
||||
print(f"Trying with original language value '{language}'...")
|
||||
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language)
|
||||
if len(query_ds_filtered) == 0:
|
||||
# Try to get a sample to see actual language values
|
||||
try:
|
||||
sample_ds = load_dataset(
|
||||
dataset_path, "queries", split=split, revision=revision
|
||||
)
|
||||
if len(sample_ds) > 0 and "language" in sample_ds.column_names:
|
||||
sample_langs = set(sample_ds["language"])
|
||||
print(f"Available language values in dataset: {sample_langs}")
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
print(
|
||||
f"Found {len(query_ds_filtered)} queries using original language value '{language}'"
|
||||
)
|
||||
query_ds = query_ds_filtered
|
||||
|
||||
queries = {}
|
||||
for row in query_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
queries[query_id] = row["query"]
|
||||
|
||||
# Load corpus (images)
|
||||
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
|
||||
|
||||
corpus = {}
|
||||
for row in corpus_ds:
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
# Extract image from the dataset row
|
||||
if "image" in row:
|
||||
corpus[corpus_id] = row["image"]
|
||||
elif "page_image" in row:
|
||||
corpus[corpus_id] = row["page_image"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No image field found in corpus. Available fields: {list(row.keys())}"
|
||||
)
|
||||
|
||||
# Load qrels (relevance judgments)
|
||||
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
|
||||
|
||||
qrels = {}
|
||||
for row in qrels_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
if query_id not in qrels:
|
||||
qrels[query_id] = {}
|
||||
qrels[query_id][corpus_id] = int(row["score"])
|
||||
|
||||
print(
|
||||
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||
)
|
||||
|
||||
# Filter qrels to only include queries that exist
|
||||
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
|
||||
|
||||
# Filter out queries without any relevant documents (matching MTEB behavior)
|
||||
# This is important for correct NDCG calculation
|
||||
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
|
||||
queries_filtered = {
|
||||
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
|
||||
}
|
||||
|
||||
print(
|
||||
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
|
||||
)
|
||||
|
||||
return corpus, queries_filtered, qrels_filtered
|
||||
|
||||
|
||||
def evaluate_task(
|
||||
task_name: str,
|
||||
model_name: str,
|
||||
index_path: str,
|
||||
use_fast_plaid: bool = False,
|
||||
fast_plaid_index_path: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
rebuild_index: bool = False,
|
||||
top_k: int = 100,
|
||||
first_stage_k: int = 500,
|
||||
k_values: Optional[list[int]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Evaluate a single ViDoRe v2 task.
|
||||
"""
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Evaluating task: {task_name}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Get task config
|
||||
if task_name not in VIDORE_V2_TASKS:
|
||||
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")
|
||||
|
||||
task_config = VIDORE_V2_TASKS[task_name]
|
||||
dataset_path = task_config["dataset_path"]
|
||||
revision = task_config["revision"]
|
||||
|
||||
# Determine language
|
||||
if language is None:
|
||||
# Use first language if multiple available
|
||||
languages = task_config.get("languages")
|
||||
if languages is None:
|
||||
# Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval)
|
||||
language = None
|
||||
elif len(languages) == 1:
|
||||
language = languages[0]
|
||||
else:
|
||||
language = None
|
||||
|
||||
# Initialize k_values if not provided
|
||||
if k_values is None:
|
||||
k_values = [1, 3, 5, 10, 100]
|
||||
|
||||
# Load data
|
||||
corpus, queries, qrels = load_vidore_v2_data(
|
||||
dataset_path=dataset_path,
|
||||
revision=revision,
|
||||
split="test",
|
||||
language=language,
|
||||
)
|
||||
|
||||
# Check if we have any queries
|
||||
if len(queries) == 0:
|
||||
print(
|
||||
f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation."
|
||||
)
|
||||
# Return zero scores
|
||||
scores = {}
|
||||
for k in k_values:
|
||||
scores[f"ndcg_at_{k}"] = 0.0
|
||||
scores[f"map_at_{k}"] = 0.0
|
||||
scores[f"recall_at_{k}"] = 0.0
|
||||
scores[f"precision_at_{k}"] = 0.0
|
||||
scores[f"mrr_at_{k}"] = 0.0
|
||||
return scores
|
||||
|
||||
# Initialize evaluator
|
||||
evaluator = ViDoReBenchmarkEvaluator(
|
||||
model_name=model_name,
|
||||
use_fast_plaid=use_fast_plaid,
|
||||
top_k=top_k,
|
||||
first_stage_k=first_stage_k,
|
||||
k_values=k_values,
|
||||
)
|
||||
|
||||
# Build or load index
|
||||
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
||||
if index_path_full is None:
|
||||
index_path_full = f"./indexes/{task_name}_{model_name}"
|
||||
if use_fast_plaid:
|
||||
index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
|
||||
|
||||
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||
corpus=corpus,
|
||||
index_path=index_path_full,
|
||||
rebuild=rebuild_index,
|
||||
)
|
||||
|
||||
# Search queries
|
||||
task_prompt = task_config.get("prompt")
|
||||
results = evaluator.search_queries(
|
||||
queries=queries,
|
||||
corpus_ids=corpus_ids_ordered,
|
||||
index_or_retriever=index_or_retriever,
|
||||
fast_plaid_index_path=fast_plaid_index_path,
|
||||
task_prompt=task_prompt,
|
||||
)
|
||||
|
||||
# Evaluate
|
||||
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||
|
||||
# Print results
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Results for {task_name}:")
|
||||
print(f"{'=' * 80}")
|
||||
for metric, value in scores.items():
|
||||
if isinstance(value, (int, float)):
|
||||
print(f" {metric}: {value:.5f}")
|
||||
|
||||
# Save results
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
results_file = os.path.join(output_dir, f"{task_name}_results.json")
|
||||
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
|
||||
|
||||
with open(results_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print(f"\nSaved results to: {results_file}")
|
||||
|
||||
with open(scores_file, "w") as f:
|
||||
json.dump(scores, f, indent=2)
|
||||
print(f"Saved scores to: {scores_file}")
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Evaluate ViDoRe v2 benchmark using LEANN/Fast-Plaid indexing"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="colqwen2",
|
||||
choices=["colqwen2", "colpali"],
|
||||
help="Model to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specific task to evaluate (or 'all' for all tasks)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tasks",
|
||||
type=str,
|
||||
default="all",
|
||||
help="Tasks to evaluate: 'all' or comma-separated list",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to LEANN index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-fast-plaid",
|
||||
action="store_true",
|
||||
help="Use Fast-Plaid instead of LEANN",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-plaid-index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Fast-Plaid index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rebuild-index",
|
||||
action="store_true",
|
||||
help="Rebuild index even if it exists",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--language",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Language to evaluate (default: first available)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Top-k results to retrieve",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first-stage-k",
|
||||
type=int,
|
||||
default=500,
|
||||
help="First stage k for LEANN search",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--k-values",
|
||||
type=str,
|
||||
default="1,3,5,10,100",
|
||||
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="./vidore_v2_results",
|
||||
help="Output directory for results",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse k_values
|
||||
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
||||
|
||||
# Determine tasks to evaluate
|
||||
if args.task:
|
||||
tasks_to_eval = [args.task]
|
||||
elif args.tasks.lower() == "all":
|
||||
tasks_to_eval = list(VIDORE_V2_TASKS.keys())
|
||||
else:
|
||||
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
|
||||
|
||||
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||
|
||||
# Evaluate each task
|
||||
all_scores = {}
|
||||
for task_name in tasks_to_eval:
|
||||
try:
|
||||
scores = evaluate_task(
|
||||
task_name=task_name,
|
||||
model_name=args.model,
|
||||
index_path=args.index_path,
|
||||
use_fast_plaid=args.use_fast_plaid,
|
||||
fast_plaid_index_path=args.fast_plaid_index_path,
|
||||
language=args.language,
|
||||
rebuild_index=args.rebuild_index,
|
||||
top_k=args.top_k,
|
||||
first_stage_k=args.first_stage_k,
|
||||
k_values=k_values,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
all_scores[task_name] = scores
|
||||
except Exception as e:
|
||||
print(f"\nError evaluating {task_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# Print summary
|
||||
if all_scores:
|
||||
print(f"\n{'=' * 80}")
|
||||
print("SUMMARY")
|
||||
print(f"{'=' * 80}")
|
||||
for task_name, scores in all_scores.items():
|
||||
print(f"\n{task_name}:")
|
||||
# Print main metrics
|
||||
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
|
||||
if metric in scores:
|
||||
print(f" {metric}: {scores[metric]:.5f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
98
benchmarks/issue_159.py
Normal file
98
benchmarks/issue_159.py
Normal file
@@ -0,0 +1,98 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to reproduce issue #159: Slow search performance
|
||||
Configuration:
|
||||
- GPU: A10
|
||||
- embedding_model: BAAI/bge-large-zh-v1.5
|
||||
- data size: 180M text (~90K chunks)
|
||||
- backend: hnsw
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
os.environ["LEANN_LOG_LEVEL"] = "DEBUG"
|
||||
|
||||
# Configuration matching the issue
|
||||
INDEX_PATH = "./test_issue_159.leann"
|
||||
EMBEDDING_MODEL = "BAAI/bge-large-zh-v1.5"
|
||||
BACKEND_NAME = "hnsw"
|
||||
|
||||
|
||||
def generate_test_data(num_chunks=90000, chunk_size=2000):
|
||||
"""Generate test data similar to 180MB text (~90K chunks)"""
|
||||
# Each chunk is approximately 2000 characters
|
||||
# 90K chunks * 2000 chars ≈ 180MB
|
||||
chunks = []
|
||||
base_text = (
|
||||
"这是一个测试文档。LEANN是一个创新的向量数据库, 通过图基选择性重计算实现97%的存储节省。"
|
||||
)
|
||||
|
||||
for i in range(num_chunks):
|
||||
chunk = f"{base_text} 文档编号: {i}. " * (chunk_size // len(base_text) + 1)
|
||||
chunks.append(chunk[:chunk_size])
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def test_search_performance():
|
||||
"""Test search performance with different configurations"""
|
||||
print("=" * 80)
|
||||
print("Testing LEANN Search Performance (Issue #159)")
|
||||
print("=" * 80)
|
||||
|
||||
meta_path = Path(f"{INDEX_PATH}.meta.json")
|
||||
if meta_path.exists():
|
||||
print(f"\n✓ Index already exists at {INDEX_PATH}")
|
||||
print(" Skipping build phase. Delete the index to rebuild.")
|
||||
else:
|
||||
print("\n📦 Building index...")
|
||||
print(f" Backend: {BACKEND_NAME}")
|
||||
print(f" Embedding Model: {EMBEDDING_MODEL}")
|
||||
print(" Generating test data (~90K chunks, ~180MB)...")
|
||||
|
||||
chunks = generate_test_data(num_chunks=90000)
|
||||
print(f" Generated {len(chunks)} chunks")
|
||||
print(f" Total text size: {sum(len(c) for c in chunks) / (1024 * 1024):.2f} MB")
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=BACKEND_NAME,
|
||||
embedding_model=EMBEDDING_MODEL,
|
||||
)
|
||||
|
||||
print(" Adding chunks to builder...")
|
||||
start_time = time.time()
|
||||
for i, chunk in enumerate(chunks):
|
||||
builder.add_text(chunk)
|
||||
if (i + 1) % 10000 == 0:
|
||||
print(f" Added {i + 1}/{len(chunks)} chunks...")
|
||||
|
||||
print(" Building index...")
|
||||
build_start = time.time()
|
||||
builder.build_index(INDEX_PATH)
|
||||
build_time = time.time() - build_start
|
||||
print(f" ✓ Index built in {build_time:.2f} seconds")
|
||||
|
||||
# Test search with different complexity values
|
||||
print("\n🔍 Testing search performance...")
|
||||
searcher = LeannSearcher(INDEX_PATH)
|
||||
|
||||
test_query = "LEANN向量数据库存储优化"
|
||||
|
||||
# Test with minimal complexity (8)
|
||||
print("\n Test 4: Minimal complexity (8)")
|
||||
print(f" Query: '{test_query}'")
|
||||
start_time = time.time()
|
||||
results = searcher.search(test_query, top_k=10, complexity=8)
|
||||
search_time = time.time() - start_time
|
||||
print(f" ✓ Search completed in {search_time:.2f} seconds")
|
||||
print(f" Results: {len(results)} items")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_search_performance()
|
||||
@@ -143,8 +143,6 @@ def create_hnsw_embedding_server(
|
||||
pass
|
||||
return str(nid)
|
||||
|
||||
# (legacy ZMQ thread removed; using shutdown-capable server only)
|
||||
|
||||
def zmq_server_thread_with_shutdown(shutdown_event):
|
||||
"""ZMQ server thread that respects shutdown signal.
|
||||
|
||||
@@ -158,225 +156,238 @@ def create_hnsw_embedding_server(
|
||||
rep_socket.bind(f"tcp://*:{zmq_port}")
|
||||
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
|
||||
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
|
||||
# Keep sends from blocking during shutdown; fail fast and drop on close
|
||||
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
|
||||
rep_socket.setsockopt(zmq.LINGER, 0)
|
||||
|
||||
# Track last request type/length for shape-correct fallbacks
|
||||
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
|
||||
last_request_type = "unknown"
|
||||
last_request_length = 0
|
||||
|
||||
def _build_safe_fallback():
|
||||
if last_request_type == "distance":
|
||||
large_distance = 1e9
|
||||
fallback_len = max(0, int(last_request_length))
|
||||
return [[large_distance] * fallback_len]
|
||||
if last_request_type == "embedding":
|
||||
bsz = max(0, int(last_request_length))
|
||||
dim = max(0, int(embedding_dim))
|
||||
if dim > 0:
|
||||
return [[bsz, dim], [0.0] * (bsz * dim)]
|
||||
return [[0, 0], []]
|
||||
if last_request_type == "text":
|
||||
return []
|
||||
return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
|
||||
|
||||
def _handle_text_embedding(request: list[str]) -> None:
|
||||
nonlocal last_request_type, last_request_length
|
||||
|
||||
e2e_start = time.time()
|
||||
last_request_type = "text"
|
||||
last_request_length = len(request)
|
||||
embeddings = compute_embeddings(
|
||||
request,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
|
||||
def _handle_distance_request(request: list[Any]) -> None:
|
||||
nonlocal last_request_type, last_request_length
|
||||
|
||||
e2e_start = time.time()
|
||||
node_ids = request[0]
|
||||
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
||||
node_ids = node_ids[0]
|
||||
query_vector = np.array(request[1], dtype=np.float32)
|
||||
last_request_type = "distance"
|
||||
last_request_length = len(node_ids)
|
||||
|
||||
logger.debug("Distance calculation request received")
|
||||
logger.debug(f" Node IDs: {node_ids}")
|
||||
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||
|
||||
texts: list[str] = []
|
||||
found_indices: list[int] = []
|
||||
for idx, nid in enumerate(node_ids):
|
||||
try:
|
||||
passage_id = _map_node_id(nid)
|
||||
passage_data = passages.get_passage(passage_id)
|
||||
txt = passage_data.get("text", "")
|
||||
if isinstance(txt, str) and len(txt) > 0:
|
||||
texts.append(txt)
|
||||
found_indices.append(idx)
|
||||
else:
|
||||
logger.error(f"Empty text for passage ID {passage_id}")
|
||||
except KeyError:
|
||||
logger.error(f"Passage ID {nid} not found")
|
||||
except Exception as exc:
|
||||
logger.error(f"Exception looking up passage ID {nid}: {exc}")
|
||||
|
||||
large_distance = 1e9
|
||||
response_distances = [large_distance] * len(node_ids)
|
||||
|
||||
if texts:
|
||||
try:
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
if distance_metric == "l2":
|
||||
partial = np.sum(
|
||||
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
||||
)
|
||||
else:
|
||||
partial = -np.dot(embeddings, query_vector)
|
||||
|
||||
for pos, dval in zip(found_indices, partial.flatten().tolist()):
|
||||
response_distances[pos] = float(dval)
|
||||
except Exception as exc:
|
||||
logger.error(f"Distance computation error, using sentinels: {exc}")
|
||||
|
||||
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
|
||||
def _handle_embedding_by_id(request: Any) -> None:
|
||||
nonlocal last_request_type, last_request_length
|
||||
|
||||
if isinstance(request, list) and len(request) == 1 and isinstance(request[0], list):
|
||||
node_ids = request[0]
|
||||
elif isinstance(request, list):
|
||||
node_ids = request
|
||||
else:
|
||||
node_ids = []
|
||||
|
||||
e2e_start = time.time()
|
||||
last_request_type = "embedding"
|
||||
last_request_length = len(node_ids)
|
||||
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
|
||||
|
||||
if embedding_dim <= 0:
|
||||
dims = [0, 0]
|
||||
flat_data: list[float] = []
|
||||
else:
|
||||
dims = [len(node_ids), embedding_dim]
|
||||
flat_data = [0.0] * (dims[0] * dims[1])
|
||||
|
||||
texts: list[str] = []
|
||||
found_indices: list[int] = []
|
||||
for idx, nid in enumerate(node_ids):
|
||||
try:
|
||||
passage_id = _map_node_id(nid)
|
||||
passage_data = passages.get_passage(passage_id)
|
||||
txt = passage_data.get("text", "")
|
||||
if isinstance(txt, str) and len(txt) > 0:
|
||||
texts.append(txt)
|
||||
found_indices.append(idx)
|
||||
else:
|
||||
logger.error(f"Empty text for passage ID {passage_id}")
|
||||
except KeyError:
|
||||
logger.error(f"Passage with ID {nid} not found")
|
||||
except Exception as exc:
|
||||
logger.error(f"Exception looking up passage ID {nid}: {exc}")
|
||||
|
||||
if texts:
|
||||
try:
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
|
||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||
logger.error(
|
||||
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
||||
)
|
||||
dims = [0, embedding_dim]
|
||||
flat_data = []
|
||||
else:
|
||||
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||
flat = emb_f32.flatten().tolist()
|
||||
for j, pos in enumerate(found_indices):
|
||||
start = pos * embedding_dim
|
||||
end = start + embedding_dim
|
||||
if end <= len(flat_data):
|
||||
flat_data[start:end] = flat[
|
||||
j * embedding_dim : (j + 1) * embedding_dim
|
||||
]
|
||||
except Exception as exc:
|
||||
logger.error(f"Embedding computation error, returning zeros: {exc}")
|
||||
|
||||
response_payload = [dims, flat_data]
|
||||
rep_socket.send(msgpack.packb(response_payload, use_single_float=True))
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ Fallback Embed by Id E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
|
||||
try:
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
e2e_start = time.time()
|
||||
logger.debug("🔍 Waiting for ZMQ message...")
|
||||
request_bytes = rep_socket.recv()
|
||||
except zmq.Again:
|
||||
continue
|
||||
|
||||
# Rest of the processing logic (same as original)
|
||||
try:
|
||||
request = msgpack.unpackb(request_bytes)
|
||||
except Exception as exc:
|
||||
if shutdown_event.is_set():
|
||||
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||
break
|
||||
logger.error(f"Error unpacking ZMQ message: {exc}")
|
||||
try:
|
||||
safe = _build_safe_fallback()
|
||||
rep_socket.send(msgpack.packb(safe, use_single_float=True))
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
if len(request) == 1 and request[0] == "__QUERY_MODEL__":
|
||||
response_bytes = msgpack.packb([model_name])
|
||||
rep_socket.send(response_bytes)
|
||||
continue
|
||||
|
||||
# Handle direct text embedding request
|
||||
try:
|
||||
# Model query
|
||||
if (
|
||||
isinstance(request, list)
|
||||
and len(request) == 1
|
||||
and request[0] == "__QUERY_MODEL__"
|
||||
):
|
||||
rep_socket.send(msgpack.packb([model_name]))
|
||||
# Direct text embedding
|
||||
elif (
|
||||
isinstance(request, list)
|
||||
and request
|
||||
and all(isinstance(item, str) for item in request)
|
||||
):
|
||||
last_request_type = "text"
|
||||
last_request_length = len(request)
|
||||
embeddings = compute_embeddings(
|
||||
request,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
continue
|
||||
|
||||
# Handle distance calculation request: [[ids], [query_vector]]
|
||||
if (
|
||||
_handle_text_embedding(request)
|
||||
# Distance calculation: [[ids], [query_vector]]
|
||||
elif (
|
||||
isinstance(request, list)
|
||||
and len(request) == 2
|
||||
and isinstance(request[0], list)
|
||||
and isinstance(request[1], list)
|
||||
):
|
||||
node_ids = request[0]
|
||||
# Handle nested [[ids]] shape defensively
|
||||
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
||||
node_ids = node_ids[0]
|
||||
query_vector = np.array(request[1], dtype=np.float32)
|
||||
last_request_type = "distance"
|
||||
last_request_length = len(node_ids)
|
||||
|
||||
logger.debug("Distance calculation request received")
|
||||
logger.debug(f" Node IDs: {node_ids}")
|
||||
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||
|
||||
# Gather texts for found ids
|
||||
texts: list[str] = []
|
||||
found_indices: list[int] = []
|
||||
for idx, nid in enumerate(node_ids):
|
||||
try:
|
||||
passage_id = _map_node_id(nid)
|
||||
passage_data = passages.get_passage(passage_id)
|
||||
txt = passage_data.get("text", "")
|
||||
if isinstance(txt, str) and len(txt) > 0:
|
||||
texts.append(txt)
|
||||
found_indices.append(idx)
|
||||
else:
|
||||
logger.error(f"Empty text for passage ID {passage_id}")
|
||||
except KeyError:
|
||||
logger.error(f"Passage ID {nid} not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||
|
||||
# Prepare full-length response with large sentinel values
|
||||
large_distance = 1e9
|
||||
response_distances = [large_distance] * len(node_ids)
|
||||
|
||||
if texts:
|
||||
try:
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
if distance_metric == "l2":
|
||||
partial = np.sum(
|
||||
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
||||
)
|
||||
else: # mips or cosine
|
||||
partial = -np.dot(embeddings, query_vector)
|
||||
|
||||
for pos, dval in zip(found_indices, partial.flatten().tolist()):
|
||||
response_distances[pos] = float(dval)
|
||||
except Exception as e:
|
||||
logger.error(f"Distance computation error, using sentinels: {e}")
|
||||
|
||||
# Send response in expected shape [[distances]]
|
||||
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
continue
|
||||
|
||||
# Fallback: treat as embedding-by-id request
|
||||
if (
|
||||
isinstance(request, list)
|
||||
and len(request) == 1
|
||||
and isinstance(request[0], list)
|
||||
):
|
||||
node_ids = request[0]
|
||||
elif isinstance(request, list):
|
||||
node_ids = request
|
||||
else:
|
||||
node_ids = []
|
||||
last_request_type = "embedding"
|
||||
last_request_length = len(node_ids)
|
||||
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
|
||||
|
||||
# Preallocate zero-filled flat data for robustness
|
||||
if embedding_dim <= 0:
|
||||
dims = [0, 0]
|
||||
flat_data: list[float] = []
|
||||
else:
|
||||
dims = [len(node_ids), embedding_dim]
|
||||
flat_data = [0.0] * (dims[0] * dims[1])
|
||||
|
||||
# Collect texts for found ids
|
||||
texts: list[str] = []
|
||||
found_indices: list[int] = []
|
||||
for idx, nid in enumerate(node_ids):
|
||||
try:
|
||||
passage_id = _map_node_id(nid)
|
||||
passage_data = passages.get_passage(passage_id)
|
||||
txt = passage_data.get("text", "")
|
||||
if isinstance(txt, str) and len(txt) > 0:
|
||||
texts.append(txt)
|
||||
found_indices.append(idx)
|
||||
else:
|
||||
logger.error(f"Empty text for passage ID {passage_id}")
|
||||
except KeyError:
|
||||
logger.error(f"Passage with ID {nid} not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||
|
||||
if texts:
|
||||
try:
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
|
||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||
logger.error(
|
||||
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
||||
)
|
||||
dims = [0, embedding_dim]
|
||||
flat_data = []
|
||||
else:
|
||||
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||
flat = emb_f32.flatten().tolist()
|
||||
for j, pos in enumerate(found_indices):
|
||||
start = pos * embedding_dim
|
||||
end = start + embedding_dim
|
||||
if end <= len(flat_data):
|
||||
flat_data[start:end] = flat[
|
||||
j * embedding_dim : (j + 1) * embedding_dim
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Embedding computation error, returning zeros: {e}")
|
||||
|
||||
response_payload = [dims, flat_data]
|
||||
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
||||
|
||||
rep_socket.send(response_bytes)
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
|
||||
except zmq.Again:
|
||||
# Timeout - check shutdown_event and continue
|
||||
continue
|
||||
except Exception as e:
|
||||
if not shutdown_event.is_set():
|
||||
logger.error(f"Error in ZMQ server loop: {e}")
|
||||
# Shape-correct fallback
|
||||
try:
|
||||
if last_request_type == "distance":
|
||||
large_distance = 1e9
|
||||
fallback_len = max(0, int(last_request_length))
|
||||
safe = [[large_distance] * fallback_len]
|
||||
elif last_request_type == "embedding":
|
||||
bsz = max(0, int(last_request_length))
|
||||
dim = max(0, int(embedding_dim))
|
||||
safe = (
|
||||
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
|
||||
)
|
||||
elif last_request_type == "text":
|
||||
safe = [] # direct text embeddings expectation is a flat list
|
||||
else:
|
||||
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
|
||||
rep_socket.send(msgpack.packb(safe, use_single_float=True))
|
||||
except Exception:
|
||||
pass
|
||||
_handle_distance_request(request)
|
||||
# Embedding-by-id fallback
|
||||
else:
|
||||
_handle_embedding_by_id(request)
|
||||
except Exception as exc:
|
||||
if shutdown_event.is_set():
|
||||
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||
break
|
||||
logger.error(f"Error in ZMQ server loop: {exc}")
|
||||
try:
|
||||
safe = _build_safe_fallback()
|
||||
rep_socket.send(msgpack.packb(safe, use_single_float=True))
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
rep_socket.close(0)
|
||||
|
||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: e2d243c40d...301bf24f14
@@ -864,7 +864,13 @@ class LeannBuilder:
|
||||
|
||||
|
||||
class LeannSearcher:
|
||||
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
index_path: str,
|
||||
enable_warmup: bool = True,
|
||||
recompute_embeddings: bool = True,
|
||||
**backend_kwargs,
|
||||
):
|
||||
# Fix path resolution for Colab and other environments
|
||||
if not Path(index_path).is_absolute():
|
||||
index_path = str(Path(index_path).resolve())
|
||||
@@ -895,14 +901,32 @@ class LeannSearcher:
|
||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||
if backend_factory is None:
|
||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||
|
||||
# Global recompute flag for this searcher (explicit knob, default True)
|
||||
self.recompute_embeddings: bool = bool(recompute_embeddings)
|
||||
|
||||
# Warmup flag: keep using the existing enable_warmup parameter,
|
||||
# but default it to True so cold-start happens earlier.
|
||||
self._warmup: bool = bool(enable_warmup)
|
||||
|
||||
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
||||
final_kwargs["enable_warmup"] = enable_warmup
|
||||
final_kwargs["enable_warmup"] = self._warmup
|
||||
if self.embedding_options:
|
||||
final_kwargs.setdefault("embedding_options", self.embedding_options)
|
||||
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
|
||||
index_path, **final_kwargs
|
||||
)
|
||||
|
||||
# Optional one-shot warmup at construction time to hide cold-start latency.
|
||||
if self._warmup:
|
||||
try:
|
||||
_ = self.backend_impl.compute_query_embedding(
|
||||
"__LEANN_WARMUP__",
|
||||
use_server_if_available=self.recompute_embeddings,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Warmup embedding failed (ignored): {exc}")
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
@@ -910,7 +934,7 @@ class LeannSearcher:
|
||||
complexity: int = 64,
|
||||
beam_width: int = 1,
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = True,
|
||||
recompute_embeddings: Optional[bool] = None,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
expected_zmq_port: int = 5557,
|
||||
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
||||
@@ -927,7 +951,8 @@ class LeannSearcher:
|
||||
complexity: Search complexity/candidate list size, higher = more accurate but slower
|
||||
beam_width: Number of parallel search paths/IO requests per iteration
|
||||
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
|
||||
recompute_embeddings: (Deprecated) Per-call override for recompute mode.
|
||||
Configure this at LeannSearcher(..., recompute_embeddings=...) instead.
|
||||
pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
|
||||
expected_zmq_port: ZMQ port for embedding server communication
|
||||
metadata_filters: Optional filters to apply to search results based on metadata.
|
||||
@@ -966,8 +991,19 @@ class LeannSearcher:
|
||||
|
||||
zmq_port = None
|
||||
|
||||
# Resolve effective recompute flag for this search.
|
||||
if recompute_embeddings is not None:
|
||||
logger.warning(
|
||||
"LeannSearcher.search(..., recompute_embeddings=...) is deprecated and "
|
||||
"will be removed in a future version. Configure recompute at "
|
||||
"LeannSearcher(..., recompute_embeddings=...) instead."
|
||||
)
|
||||
effective_recompute = bool(recompute_embeddings)
|
||||
else:
|
||||
effective_recompute = self.recompute_embeddings
|
||||
|
||||
start_time = time.time()
|
||||
if recompute_embeddings:
|
||||
if effective_recompute:
|
||||
zmq_port = self.backend_impl._ensure_server_running(
|
||||
self.meta_path_str,
|
||||
port=expected_zmq_port,
|
||||
@@ -981,7 +1017,7 @@ class LeannSearcher:
|
||||
|
||||
query_embedding = self.backend_impl.compute_query_embedding(
|
||||
query,
|
||||
use_server_if_available=recompute_embeddings,
|
||||
use_server_if_available=effective_recompute,
|
||||
zmq_port=zmq_port,
|
||||
)
|
||||
logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
||||
@@ -993,7 +1029,7 @@ class LeannSearcher:
|
||||
"complexity": complexity,
|
||||
"beam_width": beam_width,
|
||||
"prune_ratio": prune_ratio,
|
||||
"recompute_embeddings": recompute_embeddings,
|
||||
"recompute_embeddings": effective_recompute,
|
||||
"pruning_strategy": pruning_strategy,
|
||||
"zmq_port": zmq_port,
|
||||
}
|
||||
|
||||
@@ -215,9 +215,14 @@ def compute_embeddings(
|
||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||
"""
|
||||
provider_options = provider_options or {}
|
||||
wrapper_start_time = time.time()
|
||||
logger.debug(
|
||||
f"[compute_embeddings] entry: mode={mode}, model='{model_name}', text_count={len(texts)}"
|
||||
)
|
||||
|
||||
if mode == "sentence-transformers":
|
||||
return compute_embeddings_sentence_transformers(
|
||||
inner_start_time = time.time()
|
||||
result = compute_embeddings_sentence_transformers(
|
||||
texts,
|
||||
model_name,
|
||||
is_build=is_build,
|
||||
@@ -226,6 +231,14 @@ def compute_embeddings(
|
||||
manual_tokenize=manual_tokenize,
|
||||
max_length=max_length,
|
||||
)
|
||||
inner_end_time = time.time()
|
||||
wrapper_end_time = time.time()
|
||||
logger.debug(
|
||||
"[compute_embeddings] sentence-transformers timings: "
|
||||
f"inner={inner_end_time - inner_start_time:.6f}s, "
|
||||
f"wrapper_total={wrapper_end_time - wrapper_start_time:.6f}s"
|
||||
)
|
||||
return result
|
||||
elif mode == "openai":
|
||||
return compute_embeddings_openai(
|
||||
texts,
|
||||
@@ -271,6 +284,7 @@ def compute_embeddings_sentence_transformers(
|
||||
is_build: Whether this is a build operation (shows progress bar)
|
||||
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
||||
"""
|
||||
outer_start_time = time.time()
|
||||
# Handle empty input
|
||||
if not texts:
|
||||
raise ValueError("Cannot compute embeddings for empty text list")
|
||||
@@ -301,7 +315,14 @@ def compute_embeddings_sentence_transformers(
|
||||
# Create cache key
|
||||
cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"
|
||||
|
||||
pre_model_init_end_time = time.time()
|
||||
logger.debug(
|
||||
"compute_embeddings_sentence_transformers pre-model-init time "
|
||||
f"(device/batch selection etc.): {pre_model_init_end_time - outer_start_time:.6f}s"
|
||||
)
|
||||
|
||||
# Check if model is already cached
|
||||
start_time = time.time()
|
||||
if cache_key in _model_cache:
|
||||
logger.info(f"Using cached optimized model: {model_name}")
|
||||
model = _model_cache[cache_key]
|
||||
@@ -441,10 +462,13 @@ def compute_embeddings_sentence_transformers(
|
||||
_model_cache[cache_key] = model
|
||||
logger.info(f"Model cached: {cache_key}")
|
||||
|
||||
# Compute embeddings with optimized inference mode
|
||||
logger.info(
|
||||
f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
# Compute embeddings with optimized inference mode
|
||||
logger.info(
|
||||
f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
|
||||
)
|
||||
logger.info(f"start sentence transformers {model} takes {end_time - start_time}")
|
||||
|
||||
start_time = time.time()
|
||||
if not manual_tokenize:
|
||||
@@ -465,32 +489,46 @@ def compute_embeddings_sentence_transformers(
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
# Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
|
||||
# Manual tokenization + forward pass using HF AutoTokenizer/AutoModel.
|
||||
# This path is reserved for an aggressively optimized FP pipeline
|
||||
# (no quantization), mainly for experimentation.
|
||||
try:
|
||||
from transformers import AutoModel, AutoTokenizer # type: ignore
|
||||
except Exception as e:
|
||||
raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
|
||||
|
||||
# Cache tokenizer and model
|
||||
tok_cache_key = f"hf_tokenizer_{model_name}"
|
||||
mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
|
||||
mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}_fp"
|
||||
|
||||
if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
|
||||
hf_tokenizer = _model_cache[tok_cache_key]
|
||||
hf_model = _model_cache[mdl_cache_key]
|
||||
logger.info("Using cached HF tokenizer/model for manual path")
|
||||
logger.info("Using cached HF tokenizer/model for manual FP path")
|
||||
else:
|
||||
logger.info("Loading HF tokenizer/model for manual tokenization path")
|
||||
logger.info("Loading HF tokenizer/model for manual FP path")
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||
|
||||
torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
|
||||
hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
|
||||
hf_model = AutoModel.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
hf_model.to(device)
|
||||
|
||||
hf_model.eval()
|
||||
# Optional compile on supported devices
|
||||
if device in ["cuda", "mps"]:
|
||||
try:
|
||||
hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True) # type: ignore
|
||||
except Exception:
|
||||
pass
|
||||
hf_model = torch.compile( # type: ignore
|
||||
hf_model, mode="reduce-overhead", dynamic=True
|
||||
)
|
||||
logger.info(
|
||||
f"Applied torch.compile to HF model for {model_name} "
|
||||
f"(device={device}, dtype={torch_dtype})"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"torch.compile optimization failed: {exc}")
|
||||
|
||||
_model_cache[tok_cache_key] = hf_tokenizer
|
||||
_model_cache[mdl_cache_key] = hf_model
|
||||
|
||||
@@ -516,7 +554,6 @@ def compute_embeddings_sentence_transformers(
|
||||
for start_index in batch_iter:
|
||||
end_index = min(start_index + batch_size, len(texts))
|
||||
batch_texts = texts[start_index:end_index]
|
||||
tokenize_start_time = time.time()
|
||||
inputs = hf_tokenizer(
|
||||
batch_texts,
|
||||
padding=True,
|
||||
@@ -524,34 +561,17 @@ def compute_embeddings_sentence_transformers(
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokenize_end_time = time.time()
|
||||
logger.info(
|
||||
f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
|
||||
)
|
||||
# Print shapes of all input tensors for debugging
|
||||
for k, v in inputs.items():
|
||||
print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
|
||||
to_device_start_time = time.time()
|
||||
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||
to_device_end_time = time.time()
|
||||
logger.info(
|
||||
f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
|
||||
)
|
||||
forward_start_time = time.time()
|
||||
outputs = hf_model(**inputs)
|
||||
forward_end_time = time.time()
|
||||
logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
|
||||
last_hidden_state = outputs.last_hidden_state # (B, L, H)
|
||||
attention_mask = inputs.get("attention_mask")
|
||||
if attention_mask is None:
|
||||
# Fallback: assume all tokens are valid
|
||||
pooled = last_hidden_state.mean(dim=1)
|
||||
else:
|
||||
mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
|
||||
masked = last_hidden_state * mask
|
||||
lengths = mask.sum(dim=1).clamp(min=1)
|
||||
pooled = masked.sum(dim=1) / lengths
|
||||
# Move to CPU float32
|
||||
batch_embeddings = pooled.detach().to("cpu").float().numpy()
|
||||
all_embeddings.append(batch_embeddings)
|
||||
|
||||
@@ -571,6 +591,12 @@ def compute_embeddings_sentence_transformers(
|
||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||
raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}")
|
||||
|
||||
outer_end_time = time.time()
|
||||
logger.debug(
|
||||
"compute_embeddings_sentence_transformers total time "
|
||||
f"(function entry -> return): {outer_end_time - outer_start_time:.6f}s"
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user