Compare commits


1 Commits

Author SHA1 Message Date
aakash
877fbe81f4 Fix: Prevent duplicate PDF processing when using --file-types .pdf
Fixes #175

Problem:
When --file-types .pdf is specified, PDFs are processed twice:
1. Once by the dedicated PyMuPDF/pdfplumber extractors
2. Again in the 'other file types' pass via SimpleDirectoryReader

This duplicates work and can cause conflicts between the two passes.

Solution:
- Exclude .pdf from other_file_extensions when PDFs are already
  processed separately
- Only load other file types if any extensions remain to process,
  which prevents the duplicate PDF pass

Changes (a sketch of this logic follows the change summary below):
- Added logic to filter .pdf out of code_extensions when loading
  other file types, if PDFs were processed separately
- Updated the SimpleDirectoryReader call to use the filtered extensions
- Added a check to skip loading when there are no other extensions to process
2025-11-30 11:18:57 -08:00
13 changed files with 94 additions and 2421 deletions
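For context, a minimal sketch of the exclusion logic the commit describes. This is not the actual LEANN code: the function name, its arguments, and the llama-index import path are assumptions for illustration; only the idea of filtering .pdf out of the extension list and skipping the second pass comes from the commit message.

# Hypothetical illustration only: load_other_documents and its arguments are
# assumptions; the .pdf-exclusion and skip-if-empty steps mirror the commit message.
from llama_index.core import SimpleDirectoryReader

def load_other_documents(data_dir, code_extensions, pdfs_processed_separately):
    # If PDFs were already handled by the dedicated PyMuPDF/pdfplumber path,
    # drop ".pdf" so SimpleDirectoryReader does not pick the same files up again.
    other_exts = [
        ext for ext in code_extensions
        if not (pdfs_processed_separately and ext.lower() == ".pdf")
    ]

    # Skip the second pass entirely when nothing is left to load
    # (e.g. the user passed --file-types .pdf and nothing else).
    if not other_exts:
        return []

    reader = SimpleDirectoryReader(
        input_dir=data_dir,
        required_exts=other_exts,
        recursive=True,
    )
    return reader.load_data()

With this guard in place, a run limited to --file-types .pdf never reaches SimpleDirectoryReader, so each PDF is parsed exactly once.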

.gitignore vendored

@@ -91,8 +91,7 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
 *.meta.json
 *.passages.json
-*.npy
-*.db
 batchtest.py
 tests/__pytest_cache__/
 tests/__pycache__/


@@ -201,7 +201,7 @@ LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`,
 #### LLM Backend
-LEANN supports many LLM providers for text generation (HuggingFace, Ollama, Anthropic, and Any OpenAI compatible API).
+LEANN supports many LLM providers for text generation (HuggingFace, Ollama, and Any OpenAI compatible API).
 <details>
@@ -269,7 +269,6 @@ Below is a list of base URLs for common providers to get you started.
 | **SiliconFlow** | `https://api.siliconflow.cn/v1` |
 | **Zhipu (BigModel)** | `https://open.bigmodel.cn/api/paas/v4/` |
 | **Mistral AI** | `https://api.mistral.ai/v1` |
-| **Anthropic** | `https://api.anthropic.com/v1` |
@@ -329,7 +328,7 @@ All RAG examples share these common parameters. **Interactive mode** is availabl
 --embedding-mode MODE    # sentence-transformers, openai, mlx, or ollama
 # LLM Parameters (Text generation models)
---llm TYPE               # LLM backend: openai, ollama, hf, or anthropic (default: openai)
+--llm TYPE               # LLM backend: openai, ollama, or hf (default: openai)
 --llm-model MODEL        # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
 --thinking-budget LEVEL  # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
@@ -1058,7 +1057,7 @@ Options:
 leann ask INDEX_NAME [OPTIONS]
 Options:
-  --llm {ollama,openai,hf,anthropic}   LLM provider (default: ollama)
+  --llm {ollama,openai,hf}             LLM provider (default: ollama)
   --model MODEL            Model name (default: qwen3:8b)
   --interactive            Interactive chat mode
   --top-k N                Retrieval count (default: 20)


@@ -1,132 +0,0 @@
#!/usr/bin/env python3
"""Simple test script to test colqwen2 forward pass with a single image."""
import os
import sys
from pathlib import Path
# Add the current directory to path to import leann_multi_vector
sys.path.insert(0, str(Path(__file__).parent))
import torch
from leann_multi_vector import _embed_images, _ensure_repo_paths_importable, _load_colvision
from PIL import Image
# Ensure repo paths are importable
_ensure_repo_paths_importable(__file__)
# Set environment variable
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def create_test_image():
"""Create a simple test image."""
# Create a simple RGB image (800x600)
img = Image.new("RGB", (800, 600), color="white")
return img
def load_test_image_from_file():
"""Try to load an image from the indexes directory if available."""
# Try to find an existing image in the indexes directory
indexes_dir = Path(__file__).parent / "indexes"
# Look for images in common locations
possible_paths = [
indexes_dir / "vidore_fastplaid" / "images",
indexes_dir / "colvision_large.leann.images",
indexes_dir / "colvision.leann.images",
]
for img_dir in possible_paths:
if img_dir.exists():
# Find first image file
for ext in [".png", ".jpg", ".jpeg"]:
for img_file in img_dir.glob(f"*{ext}"):
print(f"Loading test image from: {img_file}")
return Image.open(img_file)
return None
def main():
print("=" * 60)
print("Testing ColQwen2 Forward Pass")
print("=" * 60)
# Step 1: Load or create test image
print("\n[Step 1] Loading test image...")
test_image = load_test_image_from_file()
if test_image is None:
print("No existing image found, creating a simple test image...")
test_image = create_test_image()
else:
print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")
# Convert to RGB if needed
if test_image.mode != "RGB":
test_image = test_image.convert("RGB")
print(f"✓ Converted to RGB: {test_image.size}")
# Step 2: Load model
print("\n[Step 2] Loading ColQwen2 model...")
try:
model_name, model, processor, device_str, device, dtype = _load_colvision("colqwen2")
print(f"✓ Model loaded: {model_name}")
print(f"✓ Device: {device_str}, dtype: {dtype}")
# Print model info
if hasattr(model, "device"):
print(f"✓ Model device: {model.device}")
if hasattr(model, "dtype"):
print(f"✓ Model dtype: {model.dtype}")
except Exception as e:
print(f"✗ Error loading model: {e}")
import traceback
traceback.print_exc()
return
# Step 3: Test forward pass
print("\n[Step 3] Running forward pass...")
try:
# Use the _embed_images function which handles batching and forward pass
images = [test_image]
print(f"Processing {len(images)} image(s)...")
doc_vecs = _embed_images(model, processor, images)
print("✓ Forward pass completed!")
print(f"✓ Number of embeddings: {len(doc_vecs)}")
if len(doc_vecs) > 0:
emb = doc_vecs[0]
print(f"✓ Embedding shape: {emb.shape}")
print(f"✓ Embedding dtype: {emb.dtype}")
print("✓ Embedding stats:")
print(f" - Min: {emb.min().item():.4f}")
print(f" - Max: {emb.max().item():.4f}")
print(f" - Mean: {emb.mean().item():.4f}")
print(f" - Std: {emb.std().item():.4f}")
# Check for NaN or Inf
if torch.isnan(emb).any():
print("⚠ Warning: Embedding contains NaN values!")
if torch.isinf(emb).any():
print("⚠ Warning: Embedding contains Inf values!")
except Exception as e:
print(f"✗ Error during forward pass: {e}")
import traceback
traceback.print_exc()
return
print("\n" + "=" * 60)
print("Test completed successfully!")
print("=" * 60)
if __name__ == "__main__":
main()


@@ -1,11 +1,8 @@
 import concurrent.futures
-import glob
 import json
-import logging
 import os
 import re
 import sys
-import time
 from pathlib import Path
 from typing import Any, Optional, cast
@@ -13,8 +10,6 @@ import numpy as np
 from PIL import Image
 from tqdm import tqdm
-logger = logging.getLogger(__name__)
def _ensure_repo_paths_importable(current_file: str) -> None: def _ensure_repo_paths_importable(current_file: str) -> None:
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py).""" """Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
@@ -100,63 +95,12 @@ def _natural_sort_key(name: str) -> int:
return int(m.group()) if m else 0 return int(m.group()) if m else 0
def _load_images_from_dir( def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
pages_dir: str, recursive: bool = False filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
) -> tuple[list[str], list[Image.Image]]: filenames = sorted(filenames, key=_natural_sort_key)
""" filepaths = [os.path.join(pages_dir, n) for n in filenames]
Load images from a directory. images = [Image.open(p) for p in filepaths]
return filepaths, images
Args:
pages_dir: Directory path containing images
recursive: If True, recursively search subdirectories (default: False)
Returns:
Tuple of (filepaths, images)
"""
# Supported image extensions
extensions = ("*.png", "*.jpg", "*.jpeg", "*.PNG", "*.JPG", "*.JPEG", "*.webp", "*.WEBP")
if recursive:
# Recursive search
filepaths = []
for ext in extensions:
pattern = os.path.join(pages_dir, "**", ext)
filepaths.extend(glob.glob(pattern, recursive=True))
else:
# Non-recursive search (only top-level directory)
filepaths = []
for ext in extensions:
pattern = os.path.join(pages_dir, ext)
filepaths.extend(glob.glob(pattern))
# Sort files naturally
filepaths = sorted(filepaths, key=lambda x: _natural_sort_key(os.path.basename(x)))
# Load images with error handling
images = []
valid_filepaths = []
failed_count = 0
for filepath in filepaths:
try:
img = Image.open(filepath)
# Convert to RGB if necessary (handles RGBA, P, etc.)
if img.mode != "RGB":
img = img.convert("RGB")
images.append(img)
valid_filepaths.append(filepath)
except Exception as e:
failed_count += 1
print(f"Warning: Failed to load image {filepath}: {e}")
continue
if failed_count > 0:
print(
f"Warning: Failed to load {failed_count} image(s) out of {len(filepaths)} total files"
)
return valid_filepaths, images
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None: def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
@@ -206,99 +150,36 @@ def _select_device_and_dtype():
def _load_colvision(model_choice: str): def _load_colvision(model_choice: str):
import os
import torch import torch
from colpali_engine.models import ( from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
ColPali,
ColQwen2,
ColQwen2_5,
ColQwen2_5_Processor,
ColQwen2Processor,
)
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available from transformers.utils.import_utils import is_flash_attn_2_available
# Force HuggingFace Hub to use HF endpoint, avoid Google Drive
# Set environment variables to ensure models are downloaded from HuggingFace
os.environ.setdefault("HF_ENDPOINT", "https://huggingface.co")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
# Log model loading info
logger.info(f"Loading ColVision model: {model_choice}")
logger.info(f"HF_ENDPOINT: {os.environ.get('HF_ENDPOINT', 'not set')}")
logger.info("Models will be downloaded from HuggingFace Hub, not Google Drive")
device_str, device, dtype = _select_device_and_dtype() device_str, device, dtype = _select_device_and_dtype()
# Determine model name and type
# IMPORTANT: Check colqwen2.5 BEFORE colqwen2 to avoid false matches
model_choice_lower = model_choice.lower()
if model_choice == "colqwen2": if model_choice == "colqwen2":
model_name = "vidore/colqwen2-v1.0" model_name = "vidore/colqwen2-v1.0"
model_type = "colqwen2" # On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
elif model_choice == "colqwen2.5" or model_choice == "colqwen25":
model_name = "vidore/colqwen2.5-v0.2"
model_type = "colqwen2.5"
elif model_choice == "colpali":
model_name = "vidore/colpali-v1.2"
model_type = "colpali"
elif (
"colqwen2.5" in model_choice_lower
or "colqwen25" in model_choice_lower
or "colqwen2_5" in model_choice_lower
):
# Handle HuggingFace model names like "vidore/colqwen2.5-v0.2"
model_name = model_choice
model_type = "colqwen2.5"
elif "colqwen2" in model_choice_lower and "colqwen2-v1.0" in model_choice_lower:
# Handle HuggingFace model names like "vidore/colqwen2-v1.0" (but not colqwen2.5)
model_name = model_choice
model_type = "colqwen2"
elif "colpali" in model_choice_lower:
# Handle HuggingFace model names like "vidore/colpali-v1.2"
model_name = model_choice
model_type = "colpali"
else:
# Default to colpali for backward compatibility
model_name = "vidore/colpali-v1.2"
model_type = "colpali"
# Load model based on type
attn_implementation = ( attn_implementation = (
"flash_attention_2" if (device_str == "cuda" and is_flash_attn_2_available()) else "eager" "flash_attention_2"
if (device_str == "cuda" and is_flash_attn_2_available())
else "eager"
) )
# Load model from HuggingFace Hub (not Google Drive)
# Use local_files_only=False to ensure download from HF if not cached
if model_type == "colqwen2.5":
model = ColQwen2_5.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
local_files_only=False, # Ensure download from HuggingFace Hub
).eval()
processor = ColQwen2_5_Processor.from_pretrained(model_name, local_files_only=False)
elif model_type == "colqwen2":
model = ColQwen2.from_pretrained( model = ColQwen2.from_pretrained(
model_name, model_name,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device_map=device, device_map=device,
attn_implementation=attn_implementation, attn_implementation=attn_implementation,
local_files_only=False, # Ensure download from HuggingFace Hub
).eval() ).eval()
processor = ColQwen2Processor.from_pretrained(model_name, local_files_only=False) processor = ColQwen2Processor.from_pretrained(model_name)
else: # colpali else:
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained( model = ColPali.from_pretrained(
model_name, model_name,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device_map=device, device_map=device,
local_files_only=False, # Ensure download from HuggingFace Hub
).eval() ).eval()
processor = cast( processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name, local_files_only=False)
)
return model_name, model, processor, device_str, device, dtype return model_name, model, processor, device_str, device, dtype
@@ -313,7 +194,7 @@ def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
     dataloader = DataLoader(
         dataset=ListDataset[Image.Image](images),
-        batch_size=32,
+        batch_size=1,
         shuffle=False,
         collate_fn=lambda x: processor.process_images(x),
     )
@@ -337,47 +218,32 @@ def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
def _embed_queries(model, processor, queries: list[str]) -> list[Any]: def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
import torch import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
model.eval() model.eval()
# Match MTEB's exact query processing from ColPaliEngineWrapper.get_text_embeddings: dataloader = DataLoader(
# 1. MTEB receives batch["text"] which already includes instruction/prompt (from _combine_queries_with_instruction_text) dataset=ListDataset[str](queries),
# 2. Manually adds: query_prefix + text + query_augmentation_token * 10 batch_size=1,
# 3. Calls processor.process_queries(batch) where batch is now a list of strings shuffle=False,
# 4. process_queries adds: query_prefix + text + suffix (suffix = query_augmentation_token * 10) collate_fn=lambda x: processor.process_queries(x),
# )
# This results in duplicate addition: query_prefix is added twice, query_augmentation_token * 20 total
# We need to match this exactly to reproduce MTEB results
all_embeds = []
batch_size = 32 # Match MTEB's default batch_size
q_vecs: list[Any] = []
for batch_query in tqdm(dataloader, desc="Embedding queries"):
with torch.no_grad(): with torch.no_grad():
for i in tqdm(range(0, len(queries), batch_size), desc="Embedding queries"): batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
batch_queries = queries[i : i + batch_size]
# Match MTEB: manually add query_prefix + text + query_augmentation_token * 10
# Then process_queries will add them again (resulting in 20 augmentation tokens total)
batch = [
processor.query_prefix + t + processor.query_augmentation_token * 10
for t in batch_queries
]
inputs = processor.process_queries(batch)
inputs = {k: v.to(model.device) for k, v in inputs.items()}
if model.device.type == "cuda": if model.device.type == "cuda":
with torch.autocast( with torch.autocast(
device_type="cuda", device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16, dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
): ):
outs = model(**inputs) embeddings_query = model(**batch_query)
else: else:
outs = model(**inputs) embeddings_query = model(**batch_query)
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
# Match MTEB: convert to float32 on CPU return q_vecs
all_embeds.extend(list(torch.unbind(outs.cpu().to(torch.float32))))
return all_embeds
def _build_index( def _build_index(
@@ -417,279 +283,6 @@ def _load_retriever_if_index_exists(index_path: str) -> Optional[Any]:
return None return None
def _build_fast_plaid_index(
index_path: str,
doc_vecs: list[Any],
filepaths: list[str],
images: list[Image.Image],
) -> tuple[Any, float]:
"""
Build a Fast-Plaid index from document embeddings.
Args:
index_path: Path to save the Fast-Plaid index
doc_vecs: List of document embeddings (each is a tensor with shape [num_tokens, embedding_dim])
filepaths: List of filepath identifiers for each document
images: List of PIL Images corresponding to each document
Returns:
Tuple of (FastPlaid index object, build_time_in_seconds)
"""
import torch
from fast_plaid import search as fast_plaid_search
print(f" Preparing {len(doc_vecs)} document embeddings for Fast-Plaid...")
_t0 = time.perf_counter()
# Convert doc_vecs to list of tensors
documents_embeddings = []
for i, vec in enumerate(doc_vecs):
if i % 1000 == 0:
print(f" Converting embedding {i}/{len(doc_vecs)}...")
if not isinstance(vec, torch.Tensor):
vec = (
torch.tensor(vec)
if isinstance(vec, np.ndarray)
else torch.from_numpy(np.array(vec))
)
# Ensure float32 for Fast-Plaid
if vec.dtype != torch.float32:
vec = vec.float()
documents_embeddings.append(vec)
print(f" Converted {len(documents_embeddings)} embeddings")
if len(documents_embeddings) > 0:
print(f" First embedding shape: {documents_embeddings[0].shape}")
print(f" First embedding dtype: {documents_embeddings[0].dtype}")
# Prepare metadata for Fast-Plaid
print(f" Preparing metadata for {len(filepaths)} documents...")
metadata_list = []
for i, filepath in enumerate(filepaths):
metadata_list.append(
{
"filepath": filepath,
"index": i,
}
)
# Create Fast-Plaid index
print(f" Creating FastPlaid object with index path: {index_path}")
try:
fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path)
print(" FastPlaid object created successfully")
except Exception as e:
print(f" Error creating FastPlaid object: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
raise
print(f" Calling fast_plaid_index.create() with {len(documents_embeddings)} documents...")
try:
fast_plaid_index.create(
documents_embeddings=documents_embeddings,
metadata=metadata_list,
)
print(" Fast-Plaid index created successfully")
except Exception as e:
print(f" Error creating Fast-Plaid index: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
raise
build_secs = time.perf_counter() - _t0
# Save images separately (Fast-Plaid doesn't store images)
print(f" Saving {len(images)} images...")
images_dir = Path(index_path) / "images"
images_dir.mkdir(parents=True, exist_ok=True)
for i, img in enumerate(tqdm(images, desc="Saving images")):
img_path = images_dir / f"doc_{i}.png"
img.save(str(img_path))
return fast_plaid_index, build_secs
def _fast_plaid_index_exists(index_path: str) -> bool:
"""
Check if Fast-Plaid index exists by checking for key files.
This avoids creating the FastPlaid object which may trigger memory allocation.
Args:
index_path: Path to the Fast-Plaid index
Returns:
True if index appears to exist, False otherwise
"""
index_path_obj = Path(index_path)
if not index_path_obj.exists() or not index_path_obj.is_dir():
return False
# Fast-Plaid creates a SQLite database file for metadata
# Check for metadata.db as the most reliable indicator
metadata_db = index_path_obj / "metadata.db"
if metadata_db.exists() and metadata_db.stat().st_size > 0:
return True
# Also check if directory has any files (might be incomplete index)
try:
if any(index_path_obj.iterdir()):
return True
except Exception:
pass
return False
def _load_fast_plaid_index_if_exists(index_path: str) -> Optional[Any]:
"""
Load Fast-Plaid index if it exists.
First checks if index files exist, then creates the FastPlaid object.
The actual index data loading happens lazily when search is called.
Args:
index_path: Path to the Fast-Plaid index
Returns:
FastPlaid index object if exists, None otherwise
"""
try:
from fast_plaid import search as fast_plaid_search
# First check if index files exist without creating the object
if not _fast_plaid_index_exists(index_path):
return None
# Now try to create FastPlaid object
# This may trigger some memory allocation, but the full index loading is deferred
fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path)
return fast_plaid_index
except ImportError:
# fast-plaid not installed
return None
except Exception as e:
# Any error (including memory errors from Rust backend) - return None
# The error will be caught and index will be rebuilt
print(f"Warning: Could not load Fast-Plaid index: {type(e).__name__}: {e}")
return None
def _search_fast_plaid(
fast_plaid_index: Any,
query_vec: Any,
top_k: int,
) -> tuple[list[tuple[float, int]], float]:
"""
Search Fast-Plaid index with a query embedding.
Args:
fast_plaid_index: FastPlaid index object
query_vec: Query embedding tensor with shape [num_tokens, embedding_dim]
top_k: Number of top results to return
Returns:
Tuple of (results_list, search_time_in_seconds)
results_list: List of (score, doc_id) tuples
"""
import torch
_t0 = time.perf_counter()
# Ensure query is a torch tensor
if not isinstance(query_vec, torch.Tensor):
q_vec_tensor = (
torch.tensor(query_vec)
if isinstance(query_vec, np.ndarray)
else torch.from_numpy(np.array(query_vec))
)
else:
q_vec_tensor = query_vec
# Fast-Plaid expects shape [num_queries, num_tokens, embedding_dim]
if q_vec_tensor.dim() == 2:
q_vec_tensor = q_vec_tensor.unsqueeze(0) # [1, num_tokens, embedding_dim]
# Perform search
scores = fast_plaid_index.search(
queries_embeddings=q_vec_tensor,
top_k=top_k,
show_progress=True,
)
search_secs = time.perf_counter() - _t0
# Convert Fast-Plaid results to same format as LEANN: list of (score, doc_id) tuples
results = []
if scores and len(scores) > 0:
query_results = scores[0]
# Fast-Plaid returns (doc_id, score), convert to (score, doc_id) to match LEANN format
results = [(float(score), int(doc_id)) for doc_id, score in query_results]
return results, search_secs
def _get_fast_plaid_image(index_path: str, doc_id: int) -> Optional[Image.Image]:
"""
Retrieve image for a document from Fast-Plaid index.
Args:
index_path: Path to the Fast-Plaid index
doc_id: Document ID returned by Fast-Plaid search
Returns:
PIL Image if found, None otherwise
Note: Uses metadata['index'] to get the actual file index, as Fast-Plaid
doc_id may differ from the file naming index.
"""
# First get metadata to find the actual index used for file naming
metadata = _get_fast_plaid_metadata(index_path, doc_id)
if metadata is None:
# Fallback: try using doc_id directly
file_index = doc_id
else:
# Use the 'index' field from metadata, which matches the file naming
file_index = metadata.get("index", doc_id)
images_dir = Path(index_path) / "images"
image_path = images_dir / f"doc_{file_index}.png"
if image_path.exists():
return Image.open(image_path)
# If not found with index, try doc_id as fallback
if file_index != doc_id:
fallback_path = images_dir / f"doc_{doc_id}.png"
if fallback_path.exists():
return Image.open(fallback_path)
return None
def _get_fast_plaid_metadata(index_path: str, doc_id: int) -> Optional[dict]:
"""
Retrieve metadata for a document from Fast-Plaid index.
Args:
index_path: Path to the Fast-Plaid index
doc_id: Document ID
Returns:
Dictionary with metadata if found, None otherwise
"""
try:
from fast_plaid import filtering
metadata_list = filtering.get(index=index_path, subset=[doc_id])
if metadata_list and len(metadata_list) > 0:
return metadata_list[0]
except Exception:
pass
return None
def _generate_similarity_map( def _generate_similarity_map(
model, model,
processor, processor,
@@ -1085,15 +678,11 @@ class LeannMultiVector:
             return (float(score), doc_id)
         scores: list[tuple[float, int]] = []
-        # load and core time
-        start_time = time.time()
         with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
             futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids]
             for fut in concurrent.futures.as_completed(futures):
                 scores.append(fut.result())
-        end_time = time.time()
-        print(f"Number of candidate doc ids: {len(candidate_doc_ids)}")
-        print(f"Time taken in load and core time: {end_time - start_time} seconds")
         scores.sort(key=lambda x: x[0], reverse=True)
         return scores[:topk] if len(scores) >= topk else scores
@@ -1121,6 +710,7 @@ class LeannMultiVector:
         emb_path = self._embeddings_path()
         if not emb_path.exists():
             return self.search(data, topk)
         all_embeddings = np.load(emb_path, mmap_mode="r")
         if all_embeddings.dtype != np.float32:
             all_embeddings = all_embeddings.astype(np.float32)
@@ -1128,29 +718,23 @@ class LeannMultiVector:
assert self._docid_to_indices is not None assert self._docid_to_indices is not None
candidate_doc_ids = list(self._docid_to_indices.keys()) candidate_doc_ids = list(self._docid_to_indices.keys())
def _score_one(doc_id: int, _all_embeddings=all_embeddings) -> tuple[float, int]: def _score_one(doc_id: int) -> tuple[float, int]:
token_indices = self._docid_to_indices.get(doc_id, []) token_indices = self._docid_to_indices.get(doc_id, [])
if not token_indices: if not token_indices:
return (0.0, doc_id) return (0.0, doc_id)
doc_vecs = np.asarray(_all_embeddings[token_indices], dtype=np.float32) doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
sim = np.dot(data, doc_vecs.T) sim = np.dot(data, doc_vecs.T)
sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30) sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum() score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
return (float(score), doc_id) return (float(score), doc_id)
scores: list[tuple[float, int]] = [] scores: list[tuple[float, int]] = []
# load and core time
start_time = time.time()
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex: with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
futures = [ex.submit(_score_one, d) for d in candidate_doc_ids] futures = [ex.submit(_score_one, d) for d in candidate_doc_ids]
for fut in concurrent.futures.as_completed(futures): for fut in concurrent.futures.as_completed(futures):
scores.append(fut.result()) scores.append(fut.result())
end_time = time.time()
# print number of candidate doc ids
print(f"Number of candidate doc ids: {len(candidate_doc_ids)}")
print(f"Time taken in load and core time: {end_time - start_time} seconds")
scores.sort(key=lambda x: x[0], reverse=True) scores.sort(key=lambda x: x[0], reverse=True)
del all_embeddings
return scores[:topk] if len(scores) >= topk else scores return scores[:topk] if len(scores) >= topk else scores
def get_image(self, doc_id: int) -> Optional[Image.Image]: def get_image(self, doc_id: int) -> Optional[Image.Image]:
@@ -1194,259 +778,3 @@ class LeannMultiVector:
                 "image_path": meta.get("image_path", ""),
             }
         return None
class ViDoReBenchmarkEvaluator:
"""
A reusable class for evaluating ViDoRe benchmarks (v1 and v2).
This class encapsulates common functionality for building indexes, searching, and evaluating.
"""
def __init__(
self,
model_name: str,
use_fast_plaid: bool = False,
top_k: int = 100,
first_stage_k: int = 500,
k_values: Optional[list[int]] = None,
):
"""
Initialize the evaluator.
Args:
model_name: Model name ("colqwen2" or "colpali")
use_fast_plaid: Whether to use Fast-Plaid instead of LEANN
top_k: Top-k results to retrieve
first_stage_k: First stage k for LEANN search
k_values: List of k values for evaluation metrics
"""
self.model_name = model_name
self.use_fast_plaid = use_fast_plaid
self.top_k = top_k
self.first_stage_k = first_stage_k
self.k_values = k_values if k_values is not None else [1, 3, 5, 10, 100]
# Load model once (can be reused across tasks)
self._model = None
self._processor = None
self._model_name_actual = None
def _load_model_if_needed(self):
"""Lazy load the model."""
if self._model is None:
print(f"\nLoading model: {self.model_name}")
self._model_name_actual, self._model, self._processor, _, _, _ = _load_colvision(
self.model_name
)
print(f"Model loaded: {self._model_name_actual}")
def build_index_from_corpus(
self,
corpus: dict[str, Image.Image],
index_path: str,
rebuild: bool = False,
) -> tuple[Any, list[str]]:
"""
Build index from corpus images.
Args:
corpus: dict mapping corpus_id to PIL Image
index_path: Path to save/load the index
rebuild: Whether to rebuild even if index exists
Returns:
tuple: (retriever or fast_plaid_index object, list of corpus_ids in order)
"""
self._load_model_if_needed()
# Ensure consistent ordering
corpus_ids = sorted(corpus.keys())
images = [corpus[cid] for cid in corpus_ids]
if self.use_fast_plaid:
# Check if Fast-Plaid index exists
if not rebuild and _load_fast_plaid_index_if_exists(index_path) is not None:
print(f"Fast-Plaid index already exists at {index_path}")
return _load_fast_plaid_index_if_exists(index_path), corpus_ids
print(f"Building Fast-Plaid index at {index_path}...")
print("Embedding images...")
doc_vecs = _embed_images(self._model, self._processor, images)
fast_plaid_index, build_time = _build_fast_plaid_index(
index_path, doc_vecs, corpus_ids, images
)
print(f"Fast-Plaid index built in {build_time:.2f}s")
return fast_plaid_index, corpus_ids
else:
# Check if LEANN index exists
if not rebuild:
retriever = _load_retriever_if_index_exists(index_path)
if retriever is not None:
print(f"LEANN index already exists at {index_path}")
return retriever, corpus_ids
print(f"Building LEANN index at {index_path}...")
print("Embedding images...")
doc_vecs = _embed_images(self._model, self._processor, images)
retriever = _build_index(index_path, doc_vecs, corpus_ids, images)
print("LEANN index built")
return retriever, corpus_ids
def search_queries(
self,
queries: dict[str, str],
corpus_ids: list[str],
index_or_retriever: Any,
fast_plaid_index_path: Optional[str] = None,
task_prompt: Optional[dict[str, str]] = None,
) -> dict[str, dict[str, float]]:
"""
Search queries against the index.
Args:
queries: dict mapping query_id to query text
corpus_ids: list of corpus_ids in the same order as the index
index_or_retriever: index or retriever object
fast_plaid_index_path: path to Fast-Plaid index (for metadata)
task_prompt: Optional dict with prompt for query (e.g., {"query": "..."})
Returns:
results: dict mapping query_id to dict of {corpus_id: score}
"""
self._load_model_if_needed()
print(f"Searching {len(queries)} queries (top_k={self.top_k})...")
query_ids = list(queries.keys())
query_texts = [queries[qid] for qid in query_ids]
# Note: ColPaliEngineWrapper does NOT use task prompt from metadata
# It uses query_prefix + text + query_augmentation_token (handled in _embed_queries)
# So we don't append task_prompt here to match MTEB behavior
# Embed queries
print("Embedding queries...")
query_vecs = _embed_queries(self._model, self._processor, query_texts)
results = {}
for query_id, query_vec in zip(tqdm(query_ids, desc="Searching"), query_vecs):
if self.use_fast_plaid:
# Fast-Plaid search
search_results, _ = _search_fast_plaid(index_or_retriever, query_vec, self.top_k)
query_results = {}
for score, doc_id in search_results:
if doc_id < len(corpus_ids):
corpus_id = corpus_ids[doc_id]
query_results[corpus_id] = float(score)
else:
# LEANN search
import torch
query_np = (
query_vec.float().numpy() if isinstance(query_vec, torch.Tensor) else query_vec
)
search_results = index_or_retriever.search_exact(query_np, topk=self.top_k)
query_results = {}
for score, doc_id in search_results:
if doc_id < len(corpus_ids):
corpus_id = corpus_ids[doc_id]
query_results[corpus_id] = float(score)
results[query_id] = query_results
return results
@staticmethod
def evaluate_results(
results: dict[str, dict[str, float]],
qrels: dict[str, dict[str, int]],
k_values: Optional[list[int]] = None,
) -> dict[str, float]:
"""
Evaluate retrieval results using NDCG and other metrics.
Args:
results: dict mapping query_id to dict of {corpus_id: score}
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
k_values: List of k values for evaluation metrics
Returns:
Dictionary of metric scores
"""
try:
from mteb._evaluators.retrieval_metrics import (
calculate_retrieval_scores,
make_score_dict,
)
except ImportError:
raise ImportError(
"pytrec_eval is required for evaluation. Install with: pip install pytrec-eval"
)
if k_values is None:
k_values = [1, 3, 5, 10, 100]
# Check if we have any queries to evaluate
if len(results) == 0:
print("Warning: No queries to evaluate. Returning zero scores.")
scores = {}
for k in k_values:
scores[f"ndcg_at_{k}"] = 0.0
scores[f"map_at_{k}"] = 0.0
scores[f"recall_at_{k}"] = 0.0
scores[f"precision_at_{k}"] = 0.0
scores[f"mrr_at_{k}"] = 0.0
return scores
print(f"Evaluating results with k_values={k_values}...")
print(f"Before filtering: {len(results)} results, {len(qrels)} qrels")
# Filter to ensure qrels and results have the same query set
# This matches MTEB behavior: only evaluate queries that exist in both
# pytrec_eval only evaluates queries in qrels, so we need to ensure
# results contains all queries in qrels, and filter out queries not in qrels
results_filtered = {qid: res for qid, res in results.items() if qid in qrels}
qrels_filtered = {
qid: rel_docs for qid, rel_docs in qrels.items() if qid in results_filtered
}
print(f"After filtering: {len(results_filtered)} results, {len(qrels_filtered)} qrels")
if len(results_filtered) != len(qrels_filtered):
print(
f"Warning: Mismatch between results ({len(results_filtered)}) and qrels ({len(qrels_filtered)}) queries"
)
missing_in_results = set(qrels.keys()) - set(results.keys())
if missing_in_results:
print(f"Queries in qrels but not in results: {len(missing_in_results)} queries")
print(f"First 5 missing queries: {list(missing_in_results)[:5]}")
# Convert qrels to pytrec_eval format
qrels_pytrec = {}
for qid, rel_docs in qrels_filtered.items():
qrels_pytrec[qid] = dict(rel_docs.items())
# Evaluate
eval_result = calculate_retrieval_scores(
results=results_filtered,
qrels=qrels_pytrec,
k_values=k_values,
)
# Format scores
scores = make_score_dict(
ndcg=eval_result.ndcg,
_map=eval_result.map,
recall=eval_result.recall,
precision=eval_result.precision,
mrr=eval_result.mrr,
naucs=eval_result.naucs,
naucs_mrr=eval_result.naucs_mrr,
cv_recall=eval_result.cv_recall,
task_scores={},
)
return scores


@@ -1,19 +1,12 @@
 ## Jupyter-style notebook script
 # %%
 # uv pip install matplotlib qwen_vl_utils
-import argparse
-import faulthandler
 import os
-import time
 from typing import Any, Optional
-import numpy as np
 from PIL import Image
 from tqdm import tqdm
-# Enable faulthandler to get stack trace on segfault
-faulthandler.enable()
 from leann_multi_vector import ( # utility functions/classes
     _ensure_repo_paths_importable,
@@ -25,11 +18,6 @@ from leann_multi_vector import ( # utility functions/classes
     _build_index,
     _load_retriever_if_index_exists,
     _generate_similarity_map,
-    _build_fast_plaid_index,
-    _load_fast_plaid_index_if_exists,
-    _search_fast_plaid,
-    _get_fast_plaid_image,
-    _get_fast_plaid_metadata,
     QwenVL,
 )
@@ -43,52 +31,19 @@ MODEL: str = "colqwen2" # "colpali" or "colqwen2"
# Data source: set to True to use the Hugging Face dataset example (recommended) # Data source: set to True to use the Hugging Face dataset example (recommended)
USE_HF_DATASET: bool = True USE_HF_DATASET: bool = True
# Single dataset name (used when DATASET_NAMES is None)
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector" DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
# Multiple datasets to combine (if provided, DATASET_NAME is ignored) DATASET_SPLIT: str = "train"
# Can be:
# - List of strings: ["dataset1", "dataset2"]
# - List of tuples: [("dataset1", "config1"), ("dataset2", None)] # None = no config needed
# - Mixed: ["dataset1", ("dataset2", "config2")]
#
# Some potential datasets with images (may need IMAGE_FIELD_NAME adjustment):
# - "weaviate/arXiv-AI-papers-multi-vector" (current, has "page_image" field)
# - ("lmms-lab/DocVQA", "DocVQA") (has "image" field, document images, needs config)
# - ("lmms-lab/DocVQA", "InfographicVQA") (has "image" field, infographic images)
# - "pixparse/arxiv-papers" (if available, arXiv papers)
# - "allenai/ai2d" (AI2D diagram dataset, has "image" field)
# - "huggingface/document-images" (if available)
# Note: Check dataset structure first - some may need IMAGE_FIELD_NAME specified
# DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = None
DATASET_NAMES = [
"weaviate/arXiv-AI-papers-multi-vector",
# ("lmms-lab/DocVQA", "DocVQA"), # Specify config name for datasets with multiple configs
]
# Load multiple splits to get more data (e.g., ["train", "test", "validation"])
# Set to None to try loading all available splits automatically
DATASET_SPLITS: Optional[list[str]] = ["train", "test"] # None = auto-detect all splits
# Image field name in the dataset (auto-detect if None)
# Common names: "page_image", "image", "images", "img"
IMAGE_FIELD_NAME: Optional[str] = None # None = auto-detect
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
# Local pages (used when USE_HF_DATASET == False) # Local pages (used when USE_HF_DATASET == False)
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf" PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
PAGES_DIR: str = "./pages" PAGES_DIR: str = "./pages"
# Custom folder path (takes precedence over USE_HF_DATASET and PAGES_DIR)
# If set, images will be loaded directly from this folder
CUSTOM_FOLDER_PATH: Optional[str] = None # e.g., "/home/ubuntu/dr-tulu/agent/screenshots"
# Whether to recursively search subdirectories when loading from custom folder
CUSTOM_FOLDER_RECURSIVE: bool = False # Set to True to search subdirectories
# Index + retrieval settings # Index + retrieval settings
# Use a different index path for larger dataset to avoid overwriting existing index INDEX_PATH: str = "./indexes/colvision.leann"
INDEX_PATH: str = "./indexes/colvision_large.leann"
# Fast-Plaid index settings (alternative to LEANN index)
# These are now command-line arguments (see CLI overrides section)
TOPK: int = 3 TOPK: int = 3
FIRST_STAGE_K: int = 500 FIRST_STAGE_K: int = 500
REBUILD_INDEX: bool = False # Set to True to force rebuild even if index exists REBUILD_INDEX: bool = False
# Artifacts # Artifacts
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png" SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
@@ -99,97 +54,12 @@ ANSWER: bool = True
MAX_NEW_TOKENS: int = 1024 MAX_NEW_TOKENS: int = 1024
# %%
# CLI overrides
parser = argparse.ArgumentParser(description="Multi-vector LEANN similarity map demo")
parser.add_argument(
"--search-method",
type=str,
choices=["ann", "exact", "exact-all"],
default="ann",
help="Which search method to use: 'ann' (fast ANN), 'exact' (ANN + exact rerank), or 'exact-all' (exact over all docs).",
)
parser.add_argument(
"--query",
type=str,
default=QUERY,
help=f"Query string to search for. Default: '{QUERY}'",
)
parser.add_argument(
"--use-fast-plaid",
action="store_true",
default=False,
help="Set to True to use fast-plaid instead of LEANN. Default: False",
)
parser.add_argument(
"--fast-plaid-index-path",
type=str,
default="./indexes/colvision_fastplaid",
help="Path to the Fast-Plaid index. Default: './indexes/colvision_fastplaid'",
)
parser.add_argument(
"--topk",
type=int,
default=TOPK,
help=f"Number of top results to retrieve. Default: {TOPK}",
)
parser.add_argument(
"--custom-folder",
type=str,
default=None,
help="Path to a custom folder containing images to search. Takes precedence over dataset loading. Default: None",
)
parser.add_argument(
"--recursive",
action="store_true",
default=False,
help="Recursively search subdirectories when loading images from custom folder. Default: False",
)
parser.add_argument(
"--rebuild-index",
action="store_true",
default=False,
help="Force rebuild the index even if it already exists. Default: False (reuse existing index if available)",
)
cli_args, _unknown = parser.parse_known_args()
SEARCH_METHOD: str = cli_args.search_method
QUERY = cli_args.query # Override QUERY with CLI argument if provided
USE_FAST_PLAID: bool = cli_args.use_fast_plaid
FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
TOPK: int = cli_args.topk # Override TOPK with CLI argument if provided
CUSTOM_FOLDER_PATH = cli_args.custom_folder if cli_args.custom_folder else CUSTOM_FOLDER_PATH # Override with CLI argument if provided
CUSTOM_FOLDER_RECURSIVE = cli_args.recursive if cli_args.recursive else CUSTOM_FOLDER_RECURSIVE # Override with CLI argument if provided
REBUILD_INDEX = cli_args.rebuild_index # Override REBUILD_INDEX with CLI argument
# %% # %%
# Step 1: Check if we can skip data loading (index already exists) # Step 1: Check if we can skip data loading (index already exists)
retriever: Optional[Any] = None retriever: Optional[Any] = None
fast_plaid_index: Optional[Any] = None
need_to_build_index = REBUILD_INDEX need_to_build_index = REBUILD_INDEX
if USE_FAST_PLAID:
# Fast-Plaid index handling
if not REBUILD_INDEX:
try:
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
if fast_plaid_index is not None:
print(f"✓ Fast-Plaid index found at {FAST_PLAID_INDEX_PATH}")
need_to_build_index = False
else:
print(f"Fast-Plaid index not found, will build new index")
need_to_build_index = True
except Exception as e:
# If loading fails (e.g., memory error, corrupted index), rebuild
print(f"Warning: Failed to load Fast-Plaid index: {e}")
print("Will rebuild the index...")
need_to_build_index = True
fast_plaid_index = None
else:
print(f"REBUILD_INDEX=True, will rebuild Fast-Plaid index")
need_to_build_index = True
else:
# Original LEANN index handling
if not REBUILD_INDEX: if not REBUILD_INDEX:
retriever = _load_retriever_if_index_exists(INDEX_PATH) retriever = _load_retriever_if_index_exists(INDEX_PATH)
if retriever is not None: if retriever is not None:
@@ -199,247 +69,23 @@ else:
else: else:
print(f"Index not found, will build new index") print(f"Index not found, will build new index")
need_to_build_index = True need_to_build_index = True
else:
print(f"REBUILD_INDEX=True, will rebuild index")
need_to_build_index = True
# Step 2: Load data only if we need to build the index # Step 2: Load data only if we need to build the index
if need_to_build_index: if need_to_build_index:
print("Loading dataset...") print("Loading dataset...")
# Check for custom folder path first (takes precedence) if USE_HF_DATASET:
if CUSTOM_FOLDER_PATH: from datasets import load_dataset
if not os.path.isdir(CUSTOM_FOLDER_PATH):
raise RuntimeError(f"Custom folder path does not exist: {CUSTOM_FOLDER_PATH}")
print(f"Loading images from custom folder: {CUSTOM_FOLDER_PATH}")
if CUSTOM_FOLDER_RECURSIVE:
print(" (recursive mode: searching subdirectories)")
filepaths, images = _load_images_from_dir(CUSTOM_FOLDER_PATH, recursive=CUSTOM_FOLDER_RECURSIVE)
print(f" Found {len(filepaths)} image files")
if not images:
raise RuntimeError(
f"No images found in {CUSTOM_FOLDER_PATH}. Ensure the folder contains image files (.png, .jpg, .jpeg, .webp)."
)
print(f" Successfully loaded {len(images)} images")
# Use filenames as identifiers instead of full paths for cleaner metadata
filepaths = [os.path.basename(fp) for fp in filepaths]
elif USE_HF_DATASET:
from datasets import load_dataset, concatenate_datasets, DatasetDict
# Determine which datasets to load dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
if DATASET_NAMES is not None:
dataset_names_to_load = DATASET_NAMES
print(f"Loading {len(dataset_names_to_load)} datasets: {dataset_names_to_load}")
else:
dataset_names_to_load = [DATASET_NAME]
print(f"Loading single dataset: {DATASET_NAME}")
# Load and combine datasets
all_datasets_to_concat = []
for dataset_entry in dataset_names_to_load:
# Handle both string and tuple formats
if isinstance(dataset_entry, tuple):
dataset_name, config_name = dataset_entry
else:
dataset_name = dataset_entry
config_name = None
print(f"\nProcessing dataset: {dataset_name}" + (f" (config: {config_name})" if config_name else ""))
# Load dataset to check available splits
# If config_name is provided, use it; otherwise try without config
try:
if config_name:
dataset_dict = load_dataset(dataset_name, config_name)
else:
dataset_dict = load_dataset(dataset_name)
except ValueError as e:
if "Config name is missing" in str(e):
# Try to get available configs and suggest
from datasets import get_dataset_config_names
try:
available_configs = get_dataset_config_names(dataset_name)
raise ValueError(
f"Dataset '{dataset_name}' requires a config name. "
f"Available configs: {available_configs}. "
f"Please specify as: ('{dataset_name}', 'config_name')"
) from e
except Exception:
raise ValueError(
f"Dataset '{dataset_name}' requires a config name. "
f"Please specify as: ('{dataset_name}', 'config_name')"
) from e
raise
# Determine which splits to load
if DATASET_SPLITS is None:
# Auto-detect: try to load all available splits
available_splits = list(dataset_dict.keys())
print(f" Auto-detected splits: {available_splits}")
splits_to_load = available_splits
else:
splits_to_load = DATASET_SPLITS
# Load and concatenate multiple splits for this dataset
datasets_to_concat = []
for split in splits_to_load:
if split not in dataset_dict:
print(f" Warning: Split '{split}' not found in dataset. Available splits: {list(dataset_dict.keys())}")
continue
split_dataset = dataset_dict[split]
print(f" Loaded split '{split}': {len(split_dataset)} pages")
datasets_to_concat.append(split_dataset)
if not datasets_to_concat:
print(f" Warning: No valid splits found for {dataset_name}. Skipping.")
continue
# Concatenate splits for this dataset
if len(datasets_to_concat) > 1:
combined_dataset = concatenate_datasets(datasets_to_concat)
print(f" Concatenated {len(datasets_to_concat)} splits into {len(combined_dataset)} pages")
else:
combined_dataset = datasets_to_concat[0]
all_datasets_to_concat.append(combined_dataset)
if not all_datasets_to_concat:
raise RuntimeError("No valid datasets or splits found.")
# Concatenate all datasets
if len(all_datasets_to_concat) > 1:
dataset = concatenate_datasets(all_datasets_to_concat)
print(f"\nConcatenated {len(all_datasets_to_concat)} datasets into {len(dataset)} total pages")
else:
dataset = all_datasets_to_concat[0]
# Apply MAX_DOCS limit if specified
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset)) N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
if N < len(dataset):
print(f"Limiting to {N} pages (from {len(dataset)} total)")
dataset = dataset.select(range(N))
# Auto-detect image field name if not specified
if IMAGE_FIELD_NAME is None:
# Check multiple samples to find the most common image field
# (useful when datasets are merged and may have different field names)
possible_image_fields = ["page_image", "image", "images", "img", "page", "document_image"]
field_counts = {}
# Check first few samples to find image fields
num_samples_to_check = min(10, len(dataset))
for sample_idx in range(num_samples_to_check):
sample = dataset[sample_idx]
for field in possible_image_fields:
if field in sample and sample[field] is not None:
value = sample[field]
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
field_counts[field] = field_counts.get(field, 0) + 1
# Choose the most common field, or first found if tied
if field_counts:
image_field = max(field_counts.items(), key=lambda x: x[1])[0]
print(f"Auto-detected image field: '{image_field}' (found in {field_counts[image_field]}/{num_samples_to_check} samples)")
else:
# Fallback: check first sample only
sample = dataset[0]
image_field = None
for field in possible_image_fields:
if field in sample:
value = sample[field]
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
image_field = field
break
if image_field is None:
raise RuntimeError(
f"Could not auto-detect image field. Available fields: {list(sample.keys())}. "
f"Please specify IMAGE_FIELD_NAME manually."
)
print(f"Auto-detected image field: '{image_field}'")
else:
image_field = IMAGE_FIELD_NAME
if image_field not in dataset[0]:
raise RuntimeError(
f"Image field '{image_field}' not found. Available fields: {list(dataset[0].keys())}"
)
filepaths: list[str] = [] filepaths: list[str] = []
images: list[Image.Image] = [] images: list[Image.Image] = []
for i in tqdm(range(len(dataset)), desc="Loading dataset", total=len(dataset)): for i in tqdm(range(N), desc="Loading dataset", total=N):
p = dataset[i] p = dataset[i]
# Try to compose a descriptive identifier # Compose a descriptive identifier for printing later
# Handle different dataset structures identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
identifier_parts = []
# Helper function to safely get field value
def safe_get(field_name, default=None):
if field_name in p and p[field_name] is not None:
return p[field_name]
return default
# Try to get various identifier fields
if safe_get("paper_arxiv_id"):
identifier_parts.append(f"arXiv:{p['paper_arxiv_id']}")
if safe_get("paper_title"):
identifier_parts.append(f"title:{p['paper_title']}")
if safe_get("page_number") is not None:
try:
identifier_parts.append(f"page:{int(p['page_number'])}")
except (ValueError, TypeError):
# If conversion fails, use the raw value or skip
if p['page_number']:
identifier_parts.append(f"page:{p['page_number']}")
if safe_get("page_id"):
identifier_parts.append(f"id:{p['page_id']}")
elif safe_get("questionId"):
identifier_parts.append(f"qid:{p['questionId']}")
elif safe_get("docId"):
identifier_parts.append(f"docId:{p['docId']}")
elif safe_get("id"):
identifier_parts.append(f"id:{p['id']}")
# If no identifier parts found, create one from index
if identifier_parts:
identifier = "|".join(identifier_parts)
else:
# Create identifier from available fields or index
fallback_parts = []
# Try common fields that might exist
for field in ["ucsf_document_id", "docId", "questionId", "id"]:
if safe_get(field):
fallback_parts.append(f"{field}:{p[field]}")
break
if fallback_parts:
identifier = "|".join(fallback_parts) + f"|idx:{i}"
else:
identifier = f"doc_{i}"
filepaths.append(identifier) filepaths.append(identifier)
images.append(p["page_image"]) # PIL Image
# Get image - try detected field first, then fallback to other common fields
img = None
if image_field in p and p[image_field] is not None:
img = p[image_field]
else:
# Fallback: try other common image field names
for fallback_field in ["image", "page_image", "images", "img"]:
if fallback_field in p and p[fallback_field] is not None:
img = p[fallback_field]
break
if img is None:
raise RuntimeError(
f"No image found for sample {i}. Available fields: {list(p.keys())}. "
f"Expected field: {image_field}"
)
# Ensure it's a PIL Image
if not isinstance(img, Image.Image):
if hasattr(img, 'convert'):
img = img.convert('RGB')
else:
img = Image.fromarray(img) if hasattr(img, '__array__') else Image.open(img)
images.append(img)
else: else:
_maybe_convert_pdf_to_images(PDF, PAGES_DIR) _maybe_convert_pdf_to_images(PDF, PAGES_DIR)
filepaths, images = _load_images_from_dir(PAGES_DIR) filepaths, images = _load_images_from_dir(PAGES_DIR)
@@ -448,19 +94,6 @@ if need_to_build_index:
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist." f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
) )
print(f"Loaded {len(images)} images") print(f"Loaded {len(images)} images")
# Memory check before loading model
try:
import psutil
import torch
process = psutil.Process(os.getpid())
mem_info = process.memory_info()
print(f"Memory usage after loading images: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
if torch.cuda.is_available():
print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
except ImportError:
pass
else: else:
print("Skipping dataset loading (using existing index)") print("Skipping dataset loading (using existing index)")
filepaths = [] # Not needed when using existing index filepaths = [] # Not needed when using existing index
@@ -469,152 +102,36 @@ else:
# %% # %%
# Step 3: Load model and processor (only if we need to build index or perform search) # Step 3: Load model and processor (only if we need to build index or perform search)
print("Step 3: Loading model and processor...")
print(f" Model: {MODEL}")
try:
import sys
print(f" Python version: {sys.version}")
print(f" Python executable: {sys.executable}")
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL) model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
print(f"Using model={model_name}, device={device_str}, dtype={dtype}") print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
# Memory check after loading model
try:
import psutil
import torch
process = psutil.Process(os.getpid())
mem_info = process.memory_info()
print(f" Memory usage after loading model: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
if torch.cuda.is_available():
print(f" GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f" GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
except ImportError:
pass
except Exception as e:
print(f"✗ Error loading model: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
raise
# %% # %%
# %% # %%
# Step 4: Build index if needed # Step 4: Build index if needed
if need_to_build_index: if need_to_build_index and retriever is None:
print("Step 4: Building index...") print("Building index...")
print(f" Number of images: {len(images)}")
print(f" Number of filepaths: {len(filepaths)}")
try:
print(" Embedding images...")
doc_vecs = _embed_images(model, processor, images) doc_vecs = _embed_images(model, processor, images)
print(f" Embedded {len(doc_vecs)} documents")
print(f" First doc vec shape: {doc_vecs[0].shape if len(doc_vecs) > 0 else 'N/A'}")
except Exception as e:
print(f"Error embedding images: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
raise
if USE_FAST_PLAID:
# Build Fast-Plaid index
print(" Building Fast-Plaid index...")
try:
fast_plaid_index, build_secs = _build_fast_plaid_index(
FAST_PLAID_INDEX_PATH, doc_vecs, filepaths, images
)
from pathlib import Path
print(f"✓ Fast-Plaid index built in {build_secs:.3f}s")
print(f"✓ Index saved to: {FAST_PLAID_INDEX_PATH}")
print(f"✓ Images saved to: {Path(FAST_PLAID_INDEX_PATH) / 'images'}")
except Exception as e:
print(f"Error building Fast-Plaid index: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
raise
finally:
# Clear memory
print(" Clearing memory...")
del images, filepaths, doc_vecs
else:
# Build original LEANN index
try:
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images) retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}") print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
except Exception as e:
print(f"Error building LEANN index: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
raise
finally:
# Clear memory # Clear memory
print(" Clearing memory...")
del images, filepaths, doc_vecs del images, filepaths, doc_vecs
# Note: Images are now stored separately, retriever/fast_plaid_index will reference them # Note: Images are now stored in the index, retriever will load them on-demand from disk
# %% # %%
# Step 5: Embed query and search # Step 5: Embed query and search
_t0 = time.perf_counter()
q_vec = _embed_queries(model, processor, [QUERY])[0] q_vec = _embed_queries(model, processor, [QUERY])[0]
query_embed_secs = time.perf_counter() - _t0 results = retriever.search(q_vec.float().numpy(), topk=TOPK)
print(f"[Search] Method: {SEARCH_METHOD}")
print(f"[Timing] Query embedding: {query_embed_secs:.3f}s")
# Run the selected search method and time it
if USE_FAST_PLAID:
# Fast-Plaid search
if fast_plaid_index is None:
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
if fast_plaid_index is None:
raise RuntimeError(f"Fast-Plaid index not found at {FAST_PLAID_INDEX_PATH}")
results, search_secs = _search_fast_plaid(fast_plaid_index, q_vec, TOPK)
print(f"[Timing] Fast-Plaid Search: {search_secs:.3f}s")
else:
# Original LEANN search
query_np = q_vec.float().numpy()
if SEARCH_METHOD == "ann":
results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
search_secs = time.perf_counter() - _t0
print(f"[Timing] Search (ANN): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
elif SEARCH_METHOD == "exact":
results = retriever.search_exact(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
search_secs = time.perf_counter() - _t0
print(f"[Timing] Search (Exact rerank): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
elif SEARCH_METHOD == "exact-all":
results = retriever.search_exact_all(query_np, topk=TOPK)
search_secs = time.perf_counter() - _t0
print(f"[Timing] Search (Exact all): {search_secs:.3f}s")
else:
results = []
if not results: if not results:
print("No results found.") print("No results found.")
else: else:
print(f'Top {len(results)} results for query: "{QUERY}"') print(f'Top {len(results)} results for query: "{QUERY}"')
print("\n[DEBUG] Retrieval details:")
top_images: list[Image.Image] = [] top_images: list[Image.Image] = []
image_hashes = {} # Track image hashes to detect duplicates
for rank, (score, doc_id) in enumerate(results, start=1): for rank, (score, doc_id) in enumerate(results, start=1):
# Retrieve image and metadata based on index type # Retrieve image from index instead of memory
if USE_FAST_PLAID:
# Fast-Plaid: load image and get metadata
image = _get_fast_plaid_image(FAST_PLAID_INDEX_PATH, doc_id)
if image is None:
print(f"Warning: Could not find image for doc_id {doc_id}")
continue
metadata = _get_fast_plaid_metadata(FAST_PLAID_INDEX_PATH, doc_id)
path = metadata.get("filepath", f"doc_{doc_id}") if metadata else f"doc_{doc_id}"
top_images.append(image)
else:
# Original LEANN: retrieve from retriever
image = retriever.get_image(doc_id) image = retriever.get_image(doc_id)
if image is None: if image is None:
print(f"Warning: Could not retrieve image for doc_id {doc_id}") print(f"Warning: Could not retrieve image for doc_id {doc_id}")
@@ -622,29 +139,10 @@ else:
metadata = retriever.get_metadata(doc_id) metadata = retriever.get_metadata(doc_id)
path = metadata.get("filepath", "unknown") if metadata else "unknown" path = metadata.get("filepath", "unknown") if metadata else "unknown"
# For HF dataset, path is a descriptive identifier, not a real file path
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
top_images.append(image) top_images.append(image)
# Calculate image hash to detect duplicates
import hashlib
import io
# Convert image to bytes for hashing
img_bytes = io.BytesIO()
image.save(img_bytes, format='PNG')
image_bytes = img_bytes.getvalue()
image_hash = hashlib.md5(image_bytes).hexdigest()[:8]
# Check if this image was already seen
duplicate_info = ""
if image_hash in image_hashes:
duplicate_info = f" [DUPLICATE of rank {image_hashes[image_hash]}]"
else:
image_hashes[image_hash] = rank
# Print detailed information
print(f"{rank}) doc_id={doc_id}, MaxSim={score:.4f}, Page={path}, ImageHash={image_hash}{duplicate_info}")
if metadata:
print(f" Metadata: {metadata}")
if SAVE_TOP_IMAGE: if SAVE_TOP_IMAGE:
from pathlib import Path as _Path from pathlib import Path as _Path
@@ -663,6 +161,7 @@ else:
except Exception: except Exception:
print(f"Saved retrieved page (rank {rank}) to: {out_path}") print(f"Saved retrieved page (rank {rank}) to: {out_path}")
## TODO strange results: the second page of DeepSeek-V2 is returned rather than the first page
# %% # %%
# Step 6: Similarity maps for top-K results # Step 6: Similarity maps for top-K results
@@ -705,9 +204,6 @@ if results and SIMILARITY_MAP:
# Step 7: Optional answer generation # Step 7: Optional answer generation
if results and ANSWER: if results and ANSWER:
qwen = QwenVL(device=device_str) qwen = QwenVL(device=device_str)
_t0 = time.perf_counter()
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS) response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
gen_secs = time.perf_counter() - _t0
print(f"[Timing] Generation: {gen_secs:.3f}s")
print("\nAnswer:") print("\nAnswer:")
print(response) print(response)

View File

@@ -1,448 +0,0 @@
#!/usr/bin/env python3
"""
Modular script to reproduce NDCG results for ViDoRe v1 benchmark.
This script uses the interface from leann_multi_vector.py to:
1. Download ViDoRe v1 datasets
2. Build indexes (LEANN or Fast-Plaid)
3. Perform retrieval
4. Evaluate using NDCG metrics
Usage:
# Evaluate all ViDoRe v1 tasks
python vidore_v1_benchmark.py --model colqwen2 --tasks all
# Evaluate specific task
python vidore_v1_benchmark.py --model colqwen2 --task VidoreArxivQARetrieval
# Use Fast-Plaid index
python vidore_v1_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
# Rebuild index
python vidore_v1_benchmark.py --model colqwen2 --rebuild-index
"""
import argparse
import json
import os
from typing import Optional
from datasets import load_dataset
from leann_multi_vector import (
ViDoReBenchmarkEvaluator,
_ensure_repo_paths_importable,
)
_ensure_repo_paths_importable(__file__)
# ViDoRe v1 task configurations
# Prompts match MTEB task metadata prompts
VIDORE_V1_TASKS = {
"VidoreArxivQARetrieval": {
"dataset_path": "vidore/arxivqa_test_subsampled_beir",
"revision": "7d94d570960eac2408d3baa7a33f9de4822ae3e4",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreDocVQARetrieval": {
"dataset_path": "vidore/docvqa_test_subsampled_beir",
"revision": "162ba2fc1a8437eda8b6c37b240bc1c0f0deb092",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreInfoVQARetrieval": {
"dataset_path": "vidore/infovqa_test_subsampled_beir",
"revision": "b802cc5fd6c605df2d673a963667d74881d2c9a4",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreTabfquadRetrieval": {
"dataset_path": "vidore/tabfquad_test_subsampled_beir",
"revision": "61a2224bcd29b7b261a4892ff4c8bea353527a31",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreTatdqaRetrieval": {
"dataset_path": "vidore/tatdqa_test_beir",
"revision": "5feb5630fdff4d8d189ffedb2dba56862fdd45c0",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreShiftProjectRetrieval": {
"dataset_path": "vidore/shiftproject_test_beir",
"revision": "84a382e05c4473fed9cff2bbae95fe2379416117",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreSyntheticDocQAAIRetrieval": {
"dataset_path": "vidore/syntheticDocQA_artificial_intelligence_test_beir",
"revision": "2d9ebea5a1c6e9ef4a3b902a612f605dca11261c",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreSyntheticDocQAEnergyRetrieval": {
"dataset_path": "vidore/syntheticDocQA_energy_test_beir",
"revision": "9935aadbad5c8deec30910489db1b2c7133ae7a7",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreSyntheticDocQAGovernmentReportsRetrieval": {
"dataset_path": "vidore/syntheticDocQA_government_reports_test_beir",
"revision": "b4909afa930f81282fd20601e860668073ad02aa",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreSyntheticDocQAHealthcareIndustryRetrieval": {
"dataset_path": "vidore/syntheticDocQA_healthcare_industry_test_beir",
"revision": "f9e25d5b6e13e1ad9f5c3cce202565031b3ab164",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
}
# Task name aliases (short names -> full names)
TASK_ALIASES = {
"arxivqa": "VidoreArxivQARetrieval",
"docvqa": "VidoreDocVQARetrieval",
"infovqa": "VidoreInfoVQARetrieval",
"tabfquad": "VidoreTabfquadRetrieval",
"tatdqa": "VidoreTatdqaRetrieval",
"shiftproject": "VidoreShiftProjectRetrieval",
"syntheticdocqa_ai": "VidoreSyntheticDocQAAIRetrieval",
"syntheticdocqa_energy": "VidoreSyntheticDocQAEnergyRetrieval",
"syntheticdocqa_government": "VidoreSyntheticDocQAGovernmentReportsRetrieval",
"syntheticdocqa_healthcare": "VidoreSyntheticDocQAHealthcareIndustryRetrieval",
}
def normalize_task_name(task_name: str) -> str:
"""Normalize task name (handle aliases)."""
task_name_lower = task_name.lower()
if task_name in VIDORE_V1_TASKS:
return task_name
if task_name_lower in TASK_ALIASES:
return TASK_ALIASES[task_name_lower]
# Try partial match
for alias, full_name in TASK_ALIASES.items():
if alias in task_name_lower or task_name_lower in alias:
return full_name
return task_name
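# Illustrative behaviour (values derived from TASK_ALIASES above; for documentation only):
#   normalize_task_name("arxivqa")  -> "VidoreArxivQARetrieval"
#   normalize_task_name("docvqa")   -> "VidoreDocVQARetrieval"
#   normalize_task_name("UnknownX") -> "UnknownX"  (returned unchanged when no alias matches)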
def get_safe_model_name(model_name: str) -> str:
"""Get a safe model name for use in file paths."""
import hashlib
import os
# If it's a path, use basename or hash
if os.path.exists(model_name) and os.path.isdir(model_name):
# Use basename if it's reasonable, otherwise use hash
basename = os.path.basename(model_name.rstrip("/"))
if basename and len(basename) < 100 and not basename.startswith("."):
return basename
# Use hash for very long or problematic paths
return hashlib.md5(model_name.encode()).hexdigest()[:16]
# For HuggingFace model names, replace / with _
return model_name.replace("/", "_").replace(":", "_")
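# Illustrative behaviour (hypothetical inputs):
#   get_safe_model_name("vidore/colqwen2-v1.0")       -> "vidore_colqwen2-v1.0"
#   get_safe_model_name("/data/models/my-lora-ckpt/") -> "my-lora-ckpt"  (when that directory exists locally)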
def load_vidore_v1_data(
dataset_path: str,
revision: Optional[str] = None,
split: str = "test",
):
"""
Load ViDoRe v1 dataset.
Returns:
corpus: dict mapping corpus_id to PIL Image
queries: dict mapping query_id to query text
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
"""
print(f"Loading dataset: {dataset_path} (split={split})")
# Load queries
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
queries = {}
for row in query_ds:
query_id = f"query-{split}-{row['query-id']}"
queries[query_id] = row["query"]
# Load corpus (images)
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
corpus = {}
for row in corpus_ds:
corpus_id = f"corpus-{split}-{row['corpus-id']}"
# Extract image from the dataset row
if "image" in row:
corpus[corpus_id] = row["image"]
elif "page_image" in row:
corpus[corpus_id] = row["page_image"]
else:
raise ValueError(
f"No image field found in corpus. Available fields: {list(row.keys())}"
)
# Load qrels (relevance judgments)
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
qrels = {}
for row in qrels_ds:
query_id = f"query-{split}-{row['query-id']}"
corpus_id = f"corpus-{split}-{row['corpus-id']}"
if query_id not in qrels:
qrels[query_id] = {}
qrels[query_id][corpus_id] = int(row["score"])
print(
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
)
# Filter qrels to only include queries that exist
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
# Filter out queries without any relevant documents (matching MTEB behavior)
# This is important for correct NDCG calculation
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
queries_filtered = {
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
}
print(
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
)
return corpus, queries_filtered, qrels_filtered
def evaluate_task(
task_name: str,
model_name: str,
index_path: str,
use_fast_plaid: bool = False,
fast_plaid_index_path: Optional[str] = None,
rebuild_index: bool = False,
top_k: int = 1000,
first_stage_k: int = 500,
k_values: Optional[list[int]] = None,
output_dir: Optional[str] = None,
):
"""
Evaluate a single ViDoRe v1 task.
"""
print(f"\n{'=' * 80}")
print(f"Evaluating task: {task_name}")
print(f"{'=' * 80}")
# Normalize task name (handle aliases)
task_name = normalize_task_name(task_name)
# Get task config
if task_name not in VIDORE_V1_TASKS:
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
task_config = VIDORE_V1_TASKS[task_name]
dataset_path = task_config["dataset_path"]
revision = task_config["revision"]
# Load data
corpus, queries, qrels = load_vidore_v1_data(
dataset_path=dataset_path,
revision=revision,
split="test",
)
# Initialize k_values if not provided
if k_values is None:
k_values = [1, 3, 5, 10, 20, 100, 1000]
# Check if we have any queries
if len(queries) == 0:
print(f"\nWarning: No queries found for task {task_name}. Skipping evaluation.")
# Return zero scores
scores = {}
for k in k_values:
scores[f"ndcg_at_{k}"] = 0.0
scores[f"map_at_{k}"] = 0.0
scores[f"recall_at_{k}"] = 0.0
scores[f"precision_at_{k}"] = 0.0
scores[f"mrr_at_{k}"] = 0.0
return scores
# Initialize evaluator
evaluator = ViDoReBenchmarkEvaluator(
model_name=model_name,
use_fast_plaid=use_fast_plaid,
top_k=top_k,
first_stage_k=first_stage_k,
k_values=k_values,
)
# Build or load index
# Use safe model name for index path (different models need different indexes)
safe_model_name = get_safe_model_name(model_name)
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
if index_path_full is None:
index_path_full = f"./indexes/{task_name}_{safe_model_name}"
if use_fast_plaid:
index_path_full = f"./indexes/{task_name}_{safe_model_name}_fastplaid"
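# Illustrative default index locations (assuming model "colqwen2" and task "VidoreDocVQARetrieval"):
#   LEANN:      ./indexes/VidoreDocVQARetrieval_colqwen2
#   Fast-Plaid: ./indexes/VidoreDocVQARetrieval_colqwen2_fastplaid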
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
corpus=corpus,
index_path=index_path_full,
rebuild=rebuild_index,
)
# Search queries
task_prompt = task_config.get("prompt")
results = evaluator.search_queries(
queries=queries,
corpus_ids=corpus_ids_ordered,
index_or_retriever=index_or_retriever,
fast_plaid_index_path=fast_plaid_index_path,
task_prompt=task_prompt,
)
# Evaluate
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
# Print results
print(f"\n{'=' * 80}")
print(f"Results for {task_name}:")
print(f"{'=' * 80}")
for metric, value in scores.items():
if isinstance(value, (int, float)):
print(f" {metric}: {value:.5f}")
# Save results
if output_dir:
os.makedirs(output_dir, exist_ok=True)
results_file = os.path.join(output_dir, f"{task_name}_results.json")
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
with open(results_file, "w") as f:
json.dump(results, f, indent=2)
print(f"\nSaved results to: {results_file}")
with open(scores_file, "w") as f:
json.dump(scores, f, indent=2)
print(f"Saved scores to: {scores_file}")
return scores
def main():
parser = argparse.ArgumentParser(
description="Evaluate ViDoRe v1 benchmark using LEANN/Fast-Plaid indexing"
)
parser.add_argument(
"--model",
type=str,
default="colqwen2",
help="Model to use: 'colqwen2', 'colpali', or path to a model directory (supports LoRA adapters)",
)
parser.add_argument(
"--task",
type=str,
default=None,
help="Specific task to evaluate (or 'all' for all tasks)",
)
parser.add_argument(
"--tasks",
type=str,
default="all",
help="Tasks to evaluate: 'all' or comma-separated list",
)
parser.add_argument(
"--index-path",
type=str,
default=None,
help="Path to LEANN index (auto-generated if not provided)",
)
parser.add_argument(
"--use-fast-plaid",
action="store_true",
help="Use Fast-Plaid instead of LEANN",
)
parser.add_argument(
"--fast-plaid-index-path",
type=str,
default=None,
help="Path to Fast-Plaid index (auto-generated if not provided)",
)
parser.add_argument(
"--rebuild-index",
action="store_true",
help="Rebuild index even if it exists",
)
parser.add_argument(
"--top-k",
type=int,
default=1000,
help="Top-k results to retrieve (MTEB default is max(k_values)=1000)",
)
parser.add_argument(
"--first-stage-k",
type=int,
default=500,
help="First stage k for LEANN search",
)
parser.add_argument(
"--k-values",
type=str,
default="1,3,5,10,20,100,1000",
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
)
parser.add_argument(
"--output-dir",
type=str,
default="./vidore_v1_results",
help="Output directory for results",
)
args = parser.parse_args()
# Parse k_values
k_values = [int(k.strip()) for k in args.k_values.split(",")]
# Determine tasks to evaluate
if args.task:
tasks_to_eval = [normalize_task_name(args.task)]
elif args.tasks.lower() == "all":
tasks_to_eval = list(VIDORE_V1_TASKS.keys())
else:
tasks_to_eval = [normalize_task_name(t.strip()) for t in args.tasks.split(",")]
print(f"Tasks to evaluate: {tasks_to_eval}")
# Evaluate each task
all_scores = {}
for task_name in tasks_to_eval:
try:
scores = evaluate_task(
task_name=task_name,
model_name=args.model,
index_path=args.index_path,
use_fast_plaid=args.use_fast_plaid,
fast_plaid_index_path=args.fast_plaid_index_path,
rebuild_index=args.rebuild_index,
top_k=args.top_k,
first_stage_k=args.first_stage_k,
k_values=k_values,
output_dir=args.output_dir,
)
all_scores[task_name] = scores
except Exception as e:
print(f"\nError evaluating {task_name}: {e}")
import traceback
traceback.print_exc()
continue
# Print summary
if all_scores:
print(f"\n{'=' * 80}")
print("SUMMARY")
print(f"{'=' * 80}")
for task_name, scores in all_scores.items():
print(f"\n{task_name}:")
# Print main metrics
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
if metric in scores:
print(f" {metric}: {scores[metric]:.5f}")
if __name__ == "__main__":
main()

View File

@@ -1,439 +0,0 @@
#!/usr/bin/env python3
"""
Modular script to reproduce NDCG results for ViDoRe v2 benchmark.
This script uses the interface from leann_multi_vector.py to:
1. Download ViDoRe v2 datasets
2. Build indexes (LEANN or Fast-Plaid)
3. Perform retrieval
4. Evaluate using NDCG metrics
Usage:
# Evaluate all ViDoRe v2 tasks
python vidore_v2_benchmark.py --model colqwen2 --tasks all
# Evaluate specific task
python vidore_v2_benchmark.py --model colqwen2 --task Vidore2ESGReportsRetrieval
# Use Fast-Plaid index
python vidore_v2_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
# Rebuild index
python vidore_v2_benchmark.py --model colqwen2 --rebuild-index
"""
import argparse
import json
import os
from typing import Optional
from datasets import load_dataset
from leann_multi_vector import (
ViDoReBenchmarkEvaluator,
_ensure_repo_paths_importable,
)
_ensure_repo_paths_importable(__file__)
# Language name to dataset language field value mapping
# Dataset uses ISO 639-3 + ISO 15924 format (e.g., "eng-Latn")
LANGUAGE_MAPPING = {
"english": "eng-Latn",
"french": "fra-Latn",
"spanish": "spa-Latn",
"german": "deu-Latn",
}
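# Illustrative lookup (used below when filtering queries by language):
#   LANGUAGE_MAPPING.get("english", "english") -> "eng-Latn"
#   LANGUAGE_MAPPING.get("italian", "italian") -> "italian"  (unmapped names fall through unchanged)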
# ViDoRe v2 task configurations
# Prompts match MTEB task metadata prompts
VIDORE_V2_TASKS = {
"Vidore2ESGReportsRetrieval": {
"dataset_path": "vidore/esg_reports_v2",
"revision": "0542c0d03da0ec1c8cbc517c8d78e7e95c75d3d3",
"languages": ["french", "spanish", "english", "german"],
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"Vidore2EconomicsReportsRetrieval": {
"dataset_path": "vidore/economics_reports_v2",
"revision": "b3e3a04b07fbbaffe79be49dabf92f691fbca252",
"languages": ["french", "spanish", "english", "german"],
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"Vidore2BioMedicalLecturesRetrieval": {
"dataset_path": "vidore/biomedical_lectures_v2",
"revision": "a29202f0da409034d651614d87cd8938d254e2ea",
"languages": ["french", "spanish", "english", "german"],
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"Vidore2ESGReportsHLRetrieval": {
"dataset_path": "vidore/esg_reports_human_labeled_v2",
"revision": "6d467dedb09a75144ede1421747e47cf036857dd",
# Note: This dataset doesn't have language filtering - all queries are English
"languages": None, # No language filtering needed
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
}
def load_vidore_v2_data(
dataset_path: str,
revision: Optional[str] = None,
split: str = "test",
language: Optional[str] = None,
):
"""
Load ViDoRe v2 dataset.
Returns:
corpus: dict mapping corpus_id to PIL Image
queries: dict mapping query_id to query text
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
"""
print(f"Loading dataset: {dataset_path} (split={split}, language={language})")
# Load queries
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
# Check if dataset has language field before filtering
has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
if language and has_language_field:
# Map language name to dataset language field value (e.g., "english" -> "eng-Latn")
dataset_language = LANGUAGE_MAPPING.get(language, language)
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language)
# Check if filtering resulted in empty dataset
if len(query_ds_filtered) == 0:
print(
f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}')."
)
# Try with original language value (dataset might use simple names like 'english')
print(f"Trying with original language value '{language}'...")
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language)
if len(query_ds_filtered) == 0:
# Try to get a sample to see actual language values
try:
sample_ds = load_dataset(
dataset_path, "queries", split=split, revision=revision
)
if len(sample_ds) > 0 and "language" in sample_ds.column_names:
sample_langs = set(sample_ds["language"])
print(f"Available language values in dataset: {sample_langs}")
except Exception:
pass
else:
print(
f"Found {len(query_ds_filtered)} queries using original language value '{language}'"
)
query_ds = query_ds_filtered
queries = {}
for row in query_ds:
query_id = f"query-{split}-{row['query-id']}"
queries[query_id] = row["query"]
# Load corpus (images)
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
corpus = {}
for row in corpus_ds:
corpus_id = f"corpus-{split}-{row['corpus-id']}"
# Extract image from the dataset row
if "image" in row:
corpus[corpus_id] = row["image"]
elif "page_image" in row:
corpus[corpus_id] = row["page_image"]
else:
raise ValueError(
f"No image field found in corpus. Available fields: {list(row.keys())}"
)
# Load qrels (relevance judgments)
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
qrels = {}
for row in qrels_ds:
query_id = f"query-{split}-{row['query-id']}"
corpus_id = f"corpus-{split}-{row['corpus-id']}"
if query_id not in qrels:
qrels[query_id] = {}
qrels[query_id][corpus_id] = int(row["score"])
print(
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
)
# Filter qrels to only include queries that exist
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
# Filter out queries without any relevant documents (matching MTEB behavior)
# This is important for correct NDCG calculation
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
queries_filtered = {
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
}
print(
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
)
return corpus, queries_filtered, qrels_filtered
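# Illustrative usage (hypothetical arguments; dataset path taken from VIDORE_V2_TASKS above):
#   corpus, queries, qrels = load_vidore_v2_data(
#       "vidore/esg_reports_v2", split="test", language="english"
#   )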
def evaluate_task(
task_name: str,
model_name: str,
index_path: str,
use_fast_plaid: bool = False,
fast_plaid_index_path: Optional[str] = None,
language: Optional[str] = None,
rebuild_index: bool = False,
top_k: int = 100,
first_stage_k: int = 500,
k_values: Optional[list[int]] = None,
output_dir: Optional[str] = None,
):
"""
Evaluate a single ViDoRe v2 task.
"""
print(f"\n{'=' * 80}")
print(f"Evaluating task: {task_name}")
print(f"{'=' * 80}")
# Get task config
if task_name not in VIDORE_V2_TASKS:
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")
task_config = VIDORE_V2_TASKS[task_name]
dataset_path = task_config["dataset_path"]
revision = task_config["revision"]
# Determine language
if language is None:
# Use the single configured language when exactly one is available; otherwise skip language filtering
languages = task_config.get("languages")
if languages is None:
# Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval)
language = None
elif len(languages) == 1:
language = languages[0]
else:
language = None
# Initialize k_values if not provided
if k_values is None:
k_values = [1, 3, 5, 10, 100]
# Load data
corpus, queries, qrels = load_vidore_v2_data(
dataset_path=dataset_path,
revision=revision,
split="test",
language=language,
)
# Check if we have any queries
if len(queries) == 0:
print(
f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation."
)
# Return zero scores
scores = {}
for k in k_values:
scores[f"ndcg_at_{k}"] = 0.0
scores[f"map_at_{k}"] = 0.0
scores[f"recall_at_{k}"] = 0.0
scores[f"precision_at_{k}"] = 0.0
scores[f"mrr_at_{k}"] = 0.0
return scores
# Initialize evaluator
evaluator = ViDoReBenchmarkEvaluator(
model_name=model_name,
use_fast_plaid=use_fast_plaid,
top_k=top_k,
first_stage_k=first_stage_k,
k_values=k_values,
)
# Build or load index
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
if index_path_full is None:
index_path_full = f"./indexes/{task_name}_{model_name}"
if use_fast_plaid:
index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
corpus=corpus,
index_path=index_path_full,
rebuild=rebuild_index,
)
# Search queries
task_prompt = task_config.get("prompt")
results = evaluator.search_queries(
queries=queries,
corpus_ids=corpus_ids_ordered,
index_or_retriever=index_or_retriever,
fast_plaid_index_path=fast_plaid_index_path,
task_prompt=task_prompt,
)
# Evaluate
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
# Print results
print(f"\n{'=' * 80}")
print(f"Results for {task_name}:")
print(f"{'=' * 80}")
for metric, value in scores.items():
if isinstance(value, (int, float)):
print(f" {metric}: {value:.5f}")
# Save results
if output_dir:
os.makedirs(output_dir, exist_ok=True)
results_file = os.path.join(output_dir, f"{task_name}_results.json")
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
with open(results_file, "w") as f:
json.dump(results, f, indent=2)
print(f"\nSaved results to: {results_file}")
with open(scores_file, "w") as f:
json.dump(scores, f, indent=2)
print(f"Saved scores to: {scores_file}")
return scores
def main():
parser = argparse.ArgumentParser(
description="Evaluate ViDoRe v2 benchmark using LEANN/Fast-Plaid indexing"
)
parser.add_argument(
"--model",
type=str,
default="colqwen2",
choices=["colqwen2", "colpali"],
help="Model to use",
)
parser.add_argument(
"--task",
type=str,
default=None,
help="Specific task to evaluate (or 'all' for all tasks)",
)
parser.add_argument(
"--tasks",
type=str,
default="all",
help="Tasks to evaluate: 'all' or comma-separated list",
)
parser.add_argument(
"--index-path",
type=str,
default=None,
help="Path to LEANN index (auto-generated if not provided)",
)
parser.add_argument(
"--use-fast-plaid",
action="store_true",
help="Use Fast-Plaid instead of LEANN",
)
parser.add_argument(
"--fast-plaid-index-path",
type=str,
default=None,
help="Path to Fast-Plaid index (auto-generated if not provided)",
)
parser.add_argument(
"--rebuild-index",
action="store_true",
help="Rebuild index even if it exists",
)
parser.add_argument(
"--language",
type=str,
default=None,
help="Language to evaluate (default: first available)",
)
parser.add_argument(
"--top-k",
type=int,
default=100,
help="Top-k results to retrieve",
)
parser.add_argument(
"--first-stage-k",
type=int,
default=500,
help="First stage k for LEANN search",
)
parser.add_argument(
"--k-values",
type=str,
default="1,3,5,10,100",
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
)
parser.add_argument(
"--output-dir",
type=str,
default="./vidore_v2_results",
help="Output directory for results",
)
args = parser.parse_args()
# Parse k_values
k_values = [int(k.strip()) for k in args.k_values.split(",")]
# Determine tasks to evaluate
if args.task:
tasks_to_eval = [args.task]
elif args.tasks.lower() == "all":
tasks_to_eval = list(VIDORE_V2_TASKS.keys())
else:
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
print(f"Tasks to evaluate: {tasks_to_eval}")
# Evaluate each task
all_scores = {}
for task_name in tasks_to_eval:
try:
scores = evaluate_task(
task_name=task_name,
model_name=args.model,
index_path=args.index_path,
use_fast_plaid=args.use_fast_plaid,
fast_plaid_index_path=args.fast_plaid_index_path,
language=args.language,
rebuild_index=args.rebuild_index,
top_k=args.top_k,
first_stage_k=args.first_stage_k,
k_values=k_values,
output_dir=args.output_dir,
)
all_scores[task_name] = scores
except Exception as e:
print(f"\nError evaluating {task_name}: {e}")
import traceback
traceback.print_exc()
continue
# Print summary
if all_scores:
print(f"\n{'=' * 80}")
print("SUMMARY")
print(f"{'=' * 80}")
for task_name, scores in all_scores.items():
print(f"\n{task_name}:")
# Print main metrics
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
if metric in scores:
print(f" {metric}: {scores[metric]:.5f}")
if __name__ == "__main__":
main()

View File

@@ -454,7 +454,7 @@ leann search my-index "your query" \
### 2) Run remote builds with SkyPilot (cloud GPU) ### 2) Run remote builds with SkyPilot (cloud GPU)
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://docs.skypilot.co/en/latest/docs/index.html). A template is provided at `sky/leann-build.yaml`. Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
```bash ```bash
# One-time: install and configure SkyPilot # One-time: install and configure SkyPilot

View File

@@ -1251,15 +1251,15 @@ class LeannChat:
"Please provide the best answer you can based on this context and your knowledge." "Please provide the best answer you can based on this context and your knowledge."
) )
logger.info("The context provided to the LLM is:") print("The context provided to the LLM is:")
logger.info(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}") print(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
logger.info("-" * 150) print("-" * 150)
for r in results: for r in results:
chunk_relevance = f"{r.score:.3f}" chunk_relevance = f"{r.score:.3f}"
chunk_id = r.id chunk_id = r.id
chunk_content = r.text[:60] chunk_content = r.text[:60]
chunk_source = r.metadata.get("source", "")[:80] chunk_source = r.metadata.get("source", "")[:80]
logger.info( print(
f"{chunk_relevance:<10} | {chunk_id:<10} | {chunk_content:<60} | {chunk_source:<80}" f"{chunk_relevance:<10} | {chunk_id:<10} | {chunk_content:<60} | {chunk_source:<80}"
) )
ask_time = time.time() ask_time = time.time()

View File

@@ -12,13 +12,7 @@ from typing import Any, Optional
import torch import torch
from .settings import ( from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
resolve_anthropic_api_key,
resolve_anthropic_base_url,
resolve_ollama_host,
resolve_openai_api_key,
resolve_openai_base_url,
)
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@@ -851,81 +845,6 @@ class OpenAIChat(LLMInterface):
return f"Error: Could not get a response from OpenAI. Details: {e}" return f"Error: Could not get a response from OpenAI. Details: {e}"
class AnthropicChat(LLMInterface):
"""LLM interface for Anthropic Claude models."""
def __init__(
self,
model: str = "claude-haiku-4-5",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
):
self.model = model
self.base_url = resolve_anthropic_base_url(base_url)
self.api_key = resolve_anthropic_api_key(api_key)
if not self.api_key:
raise ValueError(
"Anthropic API key is required. Set ANTHROPIC_API_KEY environment variable or pass api_key parameter."
)
logger.info(
"Initializing Anthropic Chat with model='%s' and base_url='%s'",
model,
self.base_url,
)
try:
import anthropic
# Allow custom Anthropic-compatible endpoints via base_url
self.client = anthropic.Anthropic(
api_key=self.api_key,
base_url=self.base_url,
)
except ImportError:
raise ImportError(
"The 'anthropic' library is required for Anthropic models. Please install it with 'pip install anthropic'."
)
def ask(self, prompt: str, **kwargs) -> str:
logger.info(f"Sending request to Anthropic with model {self.model}")
try:
# Anthropic API parameters
params = {
"model": self.model,
"max_tokens": kwargs.get("max_tokens", 1000),
"messages": [{"role": "user", "content": prompt}],
}
# Add optional parameters
if "temperature" in kwargs:
params["temperature"] = kwargs["temperature"]
if "top_p" in kwargs:
params["top_p"] = kwargs["top_p"]
response = self.client.messages.create(**params)
# Extract text from response
response_text = response.content[0].text
# Log token usage
print(
f"Total tokens = {response.usage.input_tokens + response.usage.output_tokens}, "
f"input tokens = {response.usage.input_tokens}, "
f"output tokens = {response.usage.output_tokens}"
)
if response.stop_reason == "max_tokens":
print("The query is exceeding the maximum allowed number of tokens")
return response_text.strip()
except Exception as e:
logger.error(f"Error communicating with Anthropic: {e}")
return f"Error: Could not get a response from Anthropic. Details: {e}"
class SimulatedChat(LLMInterface): class SimulatedChat(LLMInterface):
"""A simple simulated chat for testing and development.""" """A simple simulated chat for testing and development."""
@@ -978,12 +897,6 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
) )
elif llm_type == "gemini": elif llm_type == "gemini":
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key")) return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
elif llm_type == "anthropic":
return AnthropicChat(
model=model or "claude-3-5-sonnet-20241022",
api_key=llm_config.get("api_key"),
base_url=llm_config.get("base_url"),
)
elif llm_type == "simulated": elif llm_type == "simulated":
return SimulatedChat() return SimulatedChat()
else: else:

View File

@@ -11,12 +11,7 @@ from tqdm import tqdm
from .api import LeannBuilder, LeannChat, LeannSearcher from .api import LeannBuilder, LeannChat, LeannSearcher
from .interactive_utils import create_cli_session from .interactive_utils import create_cli_session
from .registry import register_project_directory from .registry import register_project_directory
from .settings import ( from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
resolve_anthropic_base_url,
resolve_ollama_host,
resolve_openai_api_key,
resolve_openai_base_url,
)
def extract_pdf_text_with_pymupdf(file_path: str) -> str: def extract_pdf_text_with_pymupdf(file_path: str) -> str:
@@ -296,7 +291,7 @@ Examples:
"--llm", "--llm",
type=str, type=str,
default="ollama", default="ollama",
choices=["simulated", "ollama", "hf", "openai", "anthropic"], choices=["simulated", "ollama", "hf", "openai"],
help="LLM provider (default: ollama)", help="LLM provider (default: ollama)",
) )
ask_parser.add_argument( ask_parser.add_argument(
@@ -346,7 +341,7 @@ Examples:
"--api-key", "--api-key",
type=str, type=str,
default=None, default=None,
help="API key for cloud LLM providers (OpenAI, Anthropic)", help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
) )
# List command # List command
@@ -1621,12 +1616,6 @@ Examples:
resolved_api_key = resolve_openai_api_key(args.api_key) resolved_api_key = resolve_openai_api_key(args.api_key)
if resolved_api_key: if resolved_api_key:
llm_config["api_key"] = resolved_api_key llm_config["api_key"] = resolved_api_key
elif args.llm == "anthropic":
# For Anthropic, pass base_url and API key if provided
if args.api_base:
llm_config["base_url"] = resolve_anthropic_base_url(args.api_base)
if args.api_key:
llm_config["api_key"] = args.api_key
chat = LeannChat(index_path=index_path, llm_config=llm_config) chat = LeannChat(index_path=index_path, llm_config=llm_config)

View File

@@ -9,7 +9,6 @@ from typing import Any
# Default fallbacks to preserve current behaviour while keeping them in one place. # Default fallbacks to preserve current behaviour while keeping them in one place.
_DEFAULT_OLLAMA_HOST = "http://localhost:11434" _DEFAULT_OLLAMA_HOST = "http://localhost:11434"
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1" _DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
_DEFAULT_ANTHROPIC_BASE_URL = "https://api.anthropic.com"
def _clean_url(value: str) -> str: def _clean_url(value: str) -> str:
@@ -53,23 +52,6 @@ def resolve_openai_base_url(explicit: str | None = None) -> str:
return _clean_url(_DEFAULT_OPENAI_BASE_URL) return _clean_url(_DEFAULT_OPENAI_BASE_URL)
def resolve_anthropic_base_url(explicit: str | None = None) -> str:
"""Resolve the base URL for Anthropic-compatible services."""
candidates = (
explicit,
os.getenv("LEANN_ANTHROPIC_BASE_URL"),
os.getenv("ANTHROPIC_BASE_URL"),
os.getenv("LOCAL_ANTHROPIC_BASE_URL"),
)
for candidate in candidates:
if candidate:
return _clean_url(candidate)
return _clean_url(_DEFAULT_ANTHROPIC_BASE_URL)
def resolve_openai_api_key(explicit: str | None = None) -> str | None: def resolve_openai_api_key(explicit: str | None = None) -> str | None:
"""Resolve the API key for OpenAI-compatible services.""" """Resolve the API key for OpenAI-compatible services."""
@@ -79,15 +61,6 @@ def resolve_openai_api_key(explicit: str | None = None) -> str | None:
return os.getenv("OPENAI_API_KEY") return os.getenv("OPENAI_API_KEY")
def resolve_anthropic_api_key(explicit: str | None = None) -> str | None:
"""Resolve the API key for Anthropic services."""
if explicit:
return explicit
return os.getenv("ANTHROPIC_API_KEY")
def encode_provider_options(options: dict[str, Any] | None) -> str | None: def encode_provider_options(options: dict[str, Any] | None) -> str | None:
"""Serialize provider options for child processes.""" """Serialize provider options for child processes."""

View File

@@ -53,11 +53,6 @@ leann build my-project --docs $(git ls-files)
# Start Claude Code # Start Claude Code
claude claude
``` ```
**Performance tip**: For maximum speed when storage space is not a concern, add the `--no-recompute` flag to your build command. This materializes all tensors and stores them on disk, avoiding recomputation on subsequent builds:
```bash
leann build my-project --docs $(git ls-files) --no-recompute
```
## 🚀 Advanced Usage Examples to build the index ## 🚀 Advanced Usage Examples to build the index