Compare commits

..

2 Commits

Author SHA1 Message Date
yichuan-w
5be0c144ad fix readme 2025-10-08 21:38:55 +00:00
yichuan-w
3ec5e8d035 gitignore 2025-10-08 21:23:29 +00:00
38 changed files with 4287 additions and 10898 deletions

6
.gitignore vendored
View File

@@ -91,8 +91,7 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
*.meta.json
*.passages.json
*.npy
*.db
batchtest.py
tests/__pytest_cache__/
tests/__pycache__/
@@ -106,6 +105,3 @@ apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weavia
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
!apps/multimodal/vision-based-pdf-multi-vector/fig/*
# AUR build directory (Arch Linux)
paru-bin/

View File

@@ -8,12 +8,8 @@
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q">
<img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
</a>
<a href="assets/wechat_user_group.JPG" title="Join WeChat group">
<img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group">
</a>
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q"><img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
<a href="assets/wechat_user_group.JPG" title="Join WeChat group"><img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group"></a>
</p>
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
@@ -781,7 +777,7 @@ Once your iMessage conversations are indexed, you can search with queries like:
### MCP Integration: RAG on Live Data from Any Platform
Connect to live data sources through the Model Context Protocol (MCP). LEANN now supports real-time RAG on platforms like Slack, Twitter, and more through standardized MCP servers.
**NEW!** Connect to live data sources through the Model Context Protocol (MCP). LEANN now supports real-time RAG on platforms like Slack, Twitter, and more through standardized MCP servers.
**Key Benefits:**
- **Live Data Access**: Fetch real-time data without manual exports
@@ -805,17 +801,18 @@ python -m apps.slack_rag \
--query "What did we decide about the product launch?"
```
**📖 Comprehensive Setup Guide**: For detailed setup instructions, troubleshooting common issues (like "users cache is not ready yet"), and advanced configuration options, see our [**Slack Setup Guide**](docs/slack-setup-guide.md).
**Quick Setup:**
**Setup Requirements:**
1. Install a Slack MCP server (e.g., `npm install -g slack-mcp-server`)
2. Create a Slack App and get API credentials (see detailed guide above)
3. Set environment variables:
2. Create a Slack App and get API credentials:
- Go to [api.slack.com/apps](https://api.slack.com/apps) and create a new app
- Under "OAuth & Permissions", add these Bot Token Scopes: `channels:read`, `channels:history`, `groups:read`, `groups:history`, `im:read`, `im:history`, `mpim:read`, `mpim:history`
- Install the app to your workspace and copy the "Bot User OAuth Token" (starts with `xoxb-`)
- Under "App-Level Tokens", create a token with `connections:write` scope (starts with `xapp-`)
```bash
export SLACK_BOT_TOKEN="xoxb-your-bot-token"
export SLACK_APP_TOKEN="xapp-your-app-token" # Optional
export SLACK_APP_TOKEN="xapp-your-app-token"
```
4. Test connection with `--test-connection` flag
3. Test connection with `--test-connection` flag
**Arguments:**
- `--mcp-server`: Command to start the Slack MCP server
@@ -823,8 +820,6 @@ python -m apps.slack_rag \
- `--channels`: Specific channels to index (optional)
- `--concatenate-conversations`: Group messages by channel (default: true)
- `--max-messages-per-channel`: Limit messages per channel (default: 100)
- `--max-retries`: Maximum retries for cache sync issues (default: 5)
- `--retry-delay`: Initial delay between retries in seconds (default: 2.0)
#### 🐦 Twitter Bookmarks: Your Personal Tweet Library
@@ -863,7 +858,7 @@ python -m apps.twitter_rag \
- `--no-tweet-content`: Exclude tweet content, only metadata
- `--no-metadata`: Exclude engagement metadata
</details>
<!-- </details> -->
<details>
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
@@ -880,7 +875,7 @@ python -m apps.twitter_rag \
- "Show me bookmarked threads about startup advice"
- "What Python tutorials did I save?"
</details>
<details>
<summary><strong>🔧 Using MCP with CLI Commands</strong></summary>
**Want to use MCP data with regular LEANN CLI?** You can combine MCP apps with CLI commands:
@@ -926,7 +921,7 @@ Want to add support for other platforms? LEANN's MCP integration is designed for
### 🚀 Claude Code Integration: Transform Your Development Workflow!
<details>
<summary><strong>ASTAware Code Chunking</strong></summary>
<summary><strong>NEW!! ASTAware Code Chunking</strong></summary>
LEANN features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript, improving code understanding compared to text-based chunking.
@@ -1213,7 +1208,3 @@ This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.ed
<p align="center">
Made with ❤️ by the Leann team
</p>
## 🤖 Explore LEANN with AI
LEANN is indexed on [DeepWiki](https://deepwiki.com/yichuan-w/LEANN), so you can ask questions to LLMs using Deep Research to explore the codebase and get help to add new features.

View File

@@ -10,39 +10,9 @@ from typing import Any
import dotenv
from leann.api import LeannBuilder, LeannChat
# Optional import: older PyPI builds may not include interactive_utils
try:
from leann.interactive_utils import create_rag_session
except ImportError:
def create_rag_session(app_name: str, data_description: str):
class _SimpleSession:
def run_interactive_loop(self, handler):
print(f"Interactive session for {app_name}: {data_description}")
print("Interactive mode not available in this build")
return _SimpleSession()
from leann.interactive_utils import create_rag_session
from leann.registry import register_project_directory
# Optional import: older PyPI builds may not include settings
try:
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
except ImportError:
# Minimal fallbacks if settings helpers are unavailable
import os
def resolve_ollama_host(value: str | None) -> str | None:
return value or os.getenv("LEANN_OLLAMA_HOST") or os.getenv("OLLAMA_HOST")
def resolve_openai_api_key(value: str | None) -> str | None:
return value or os.getenv("OPENAI_API_KEY")
def resolve_openai_base_url(value: str | None) -> str | None:
return value or os.getenv("OPENAI_BASE_URL")
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
dotenv.load_dotenv()
@@ -180,14 +150,14 @@ class BaseRAGExample(ABC):
ast_group.add_argument(
"--ast-chunk-size",
type=int,
default=300,
help="Maximum CHARACTERS per AST chunk (default: 300). Final chunks may be larger due to overlap. For 512 token models: recommended 300 chars",
default=512,
help="Maximum characters per AST chunk (default: 512)",
)
ast_group.add_argument(
"--ast-chunk-overlap",
type=int,
default=64,
help="Overlap between AST chunks in CHARACTERS (default: 64). Added to chunk size, not included in it",
help="Overlap between AST chunks (default: 64)",
)
ast_group.add_argument(
"--code-file-extensions",

View File

@@ -12,7 +12,6 @@ from pathlib import Path
try:
from leann.chunking_utils import (
CODE_EXTENSIONS,
_traditional_chunks_as_dicts,
create_ast_chunks,
create_text_chunks,
create_traditional_chunks,
@@ -26,7 +25,6 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
sys.path.insert(0, str(leann_src))
from leann.chunking_utils import (
CODE_EXTENSIONS,
_traditional_chunks_as_dicts,
create_ast_chunks,
create_text_chunks,
create_traditional_chunks,
@@ -38,7 +36,6 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
__all__ = [
"CODE_EXTENSIONS",
"_traditional_chunks_as_dicts",
"create_ast_chunks",
"create_text_chunks",
"create_traditional_chunks",

View File

@@ -1,132 +0,0 @@
#!/usr/bin/env python3
"""Simple test script to test colqwen2 forward pass with a single image."""
import os
import sys
from pathlib import Path
# Add the current directory to path to import leann_multi_vector
sys.path.insert(0, str(Path(__file__).parent))
import torch
from leann_multi_vector import _embed_images, _ensure_repo_paths_importable, _load_colvision
from PIL import Image
# Ensure repo paths are importable
_ensure_repo_paths_importable(__file__)
# Set environment variable
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def create_test_image():
"""Create a simple test image."""
# Create a simple RGB image (800x600)
img = Image.new("RGB", (800, 600), color="white")
return img
def load_test_image_from_file():
"""Try to load an image from the indexes directory if available."""
# Try to find an existing image in the indexes directory
indexes_dir = Path(__file__).parent / "indexes"
# Look for images in common locations
possible_paths = [
indexes_dir / "vidore_fastplaid" / "images",
indexes_dir / "colvision_large.leann.images",
indexes_dir / "colvision.leann.images",
]
for img_dir in possible_paths:
if img_dir.exists():
# Find first image file
for ext in [".png", ".jpg", ".jpeg"]:
for img_file in img_dir.glob(f"*{ext}"):
print(f"Loading test image from: {img_file}")
return Image.open(img_file)
return None
def main():
print("=" * 60)
print("Testing ColQwen2 Forward Pass")
print("=" * 60)
# Step 1: Load or create test image
print("\n[Step 1] Loading test image...")
test_image = load_test_image_from_file()
if test_image is None:
print("No existing image found, creating a simple test image...")
test_image = create_test_image()
else:
print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")
# Convert to RGB if needed
if test_image.mode != "RGB":
test_image = test_image.convert("RGB")
print(f"✓ Converted to RGB: {test_image.size}")
# Step 2: Load model
print("\n[Step 2] Loading ColQwen2 model...")
try:
model_name, model, processor, device_str, device, dtype = _load_colvision("colqwen2")
print(f"✓ Model loaded: {model_name}")
print(f"✓ Device: {device_str}, dtype: {dtype}")
# Print model info
if hasattr(model, "device"):
print(f"✓ Model device: {model.device}")
if hasattr(model, "dtype"):
print(f"✓ Model dtype: {model.dtype}")
except Exception as e:
print(f"✗ Error loading model: {e}")
import traceback
traceback.print_exc()
return
# Step 3: Test forward pass
print("\n[Step 3] Running forward pass...")
try:
# Use the _embed_images function which handles batching and forward pass
images = [test_image]
print(f"Processing {len(images)} image(s)...")
doc_vecs = _embed_images(model, processor, images)
print("✓ Forward pass completed!")
print(f"✓ Number of embeddings: {len(doc_vecs)}")
if len(doc_vecs) > 0:
emb = doc_vecs[0]
print(f"✓ Embedding shape: {emb.shape}")
print(f"✓ Embedding dtype: {emb.dtype}")
print("✓ Embedding stats:")
print(f" - Min: {emb.min().item():.4f}")
print(f" - Max: {emb.max().item():.4f}")
print(f" - Mean: {emb.mean().item():.4f}")
print(f" - Std: {emb.std().item():.4f}")
# Check for NaN or Inf
if torch.isnan(emb).any():
print("⚠ Warning: Embedding contains NaN values!")
if torch.isinf(emb).any():
print("⚠ Warning: Embedding contains Inf values!")
except Exception as e:
print(f"✗ Error during forward pass: {e}")
import traceback
traceback.print_exc()
return
print("\n" + "=" * 60)
print("Test completed successfully!")
print("=" * 60)
if __name__ == "__main__":
main()

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,75 +1,41 @@
## Jupyter-style notebook script
# %%
# uv pip install matplotlib qwen_vl_utils
import argparse
import faulthandler
import os
import time
from typing import Any, Optional
import re
import sys
from pathlib import Path
from typing import Any, Optional, cast
import numpy as np
from PIL import Image
from tqdm import tqdm
# Enable faulthandler to get stack trace on segfault
faulthandler.enable()
def _ensure_repo_paths_importable(current_file: str) -> None:
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
_repo_root = Path(current_file).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
from leann_multi_vector import ( # utility functions/classes
_ensure_repo_paths_importable,
_load_images_from_dir,
_maybe_convert_pdf_to_images,
_load_colvision,
_embed_images,
_embed_queries,
_build_index,
_load_retriever_if_index_exists,
_generate_similarity_map,
_build_fast_plaid_index,
_load_fast_plaid_index_if_exists,
_search_fast_plaid,
_get_fast_plaid_image,
_get_fast_plaid_metadata,
QwenVL,
)
_ensure_repo_paths_importable(__file__)
from leann_multi_vector import LeannMultiVector # noqa: E402
# %%
# Config
os.environ["TOKENIZERS_PARALLELISM"] = "false"
QUERY = "The paper talk about the latent video generative model and data curation in the related work part?"
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
# Data source: set to True to use the Hugging Face dataset example (recommended)
USE_HF_DATASET: bool = True
# Single dataset name (used when DATASET_NAMES is None)
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
# Multiple datasets to combine (if provided, DATASET_NAME is ignored)
# Can be:
# - List of strings: ["dataset1", "dataset2"]
# - List of tuples: [("dataset1", "config1"), ("dataset2", None)] # None = no config needed
# - Mixed: ["dataset1", ("dataset2", "config2")]
#
# Some potential datasets with images (may need IMAGE_FIELD_NAME adjustment):
# - "weaviate/arXiv-AI-papers-multi-vector" (current, has "page_image" field)
# - ("lmms-lab/DocVQA", "DocVQA") (has "image" field, document images, needs config)
# - ("lmms-lab/DocVQA", "InfographicVQA") (has "image" field, infographic images)
# - "pixparse/arxiv-papers" (if available, arXiv papers)
# - "allenai/ai2d" (AI2D diagram dataset, has "image" field)
# - "huggingface/document-images" (if available)
# Note: Check dataset structure first - some may need IMAGE_FIELD_NAME specified
# DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = None
DATASET_NAMES = [
"weaviate/arXiv-AI-papers-multi-vector",
("lmms-lab/DocVQA", "DocVQA"), # Specify config name for datasets with multiple configs
]
# Load multiple splits to get more data (e.g., ["train", "test", "validation"])
# Set to None to try loading all available splits automatically
DATASET_SPLITS: Optional[list[str]] = ["train", "test"] # None = auto-detect all splits
# Image field name in the dataset (auto-detect if None)
# Common names: "page_image", "image", "images", "img"
IMAGE_FIELD_NAME: Optional[str] = None # None = auto-detect
DATASET_SPLIT: str = "train"
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
# Local pages (used when USE_HF_DATASET == False)
@@ -77,13 +43,10 @@ PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
PAGES_DIR: str = "./pages"
# Index + retrieval settings
# Use a different index path for larger dataset to avoid overwriting existing index
INDEX_PATH: str = "./indexes/colvision_large.leann"
# Fast-Plaid index settings (alternative to LEANN index)
# These are now command-line arguments (see CLI overrides section)
TOPK: int = 3
INDEX_PATH: str = "./indexes/colvision.leann"
TOPK: int = 1
FIRST_STAGE_K: int = 500
REBUILD_INDEX: bool = True
REBUILD_INDEX: bool = False
# Artifacts
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
@@ -91,517 +54,367 @@ SIMILARITY_MAP: bool = True
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
SIM_OUTPUT: str = "./figures/similarity_map.png"
ANSWER: bool = True
MAX_NEW_TOKENS: int = 1024
MAX_NEW_TOKENS: int = 128
# %%
# CLI overrides
parser = argparse.ArgumentParser(description="Multi-vector LEANN similarity map demo")
parser.add_argument(
"--search-method",
type=str,
choices=["ann", "exact", "exact-all"],
default="ann",
help="Which search method to use: 'ann' (fast ANN), 'exact' (ANN + exact rerank), or 'exact-all' (exact over all docs).",
)
parser.add_argument(
"--query",
type=str,
default=QUERY,
help=f"Query string to search for. Default: '{QUERY}'",
)
parser.add_argument(
"--use-fast-plaid",
action="store_true",
default=False,
help="Set to True to use fast-plaid instead of LEANN. Default: False",
)
parser.add_argument(
"--fast-plaid-index-path",
type=str,
default="./indexes/colvision_fastplaid",
help="Path to the Fast-Plaid index. Default: './indexes/colvision_fastplaid'",
)
parser.add_argument(
"--topk",
type=int,
default=TOPK,
help=f"Number of top results to retrieve. Default: {TOPK}",
)
cli_args, _unknown = parser.parse_known_args()
SEARCH_METHOD: str = cli_args.search_method
QUERY = cli_args.query # Override QUERY with CLI argument if provided
USE_FAST_PLAID: bool = cli_args.use_fast_plaid
FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
TOPK: int = cli_args.topk # Override TOPK with CLI argument if provided
# Helpers
def _natural_sort_key(name: str) -> int:
m = re.search(r"\d+", name)
return int(m.group()) if m else 0
# %%
# Step 1: Check if we can skip data loading (index already exists)
retriever: Optional[Any] = None
fast_plaid_index: Optional[Any] = None
need_to_build_index = REBUILD_INDEX
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
filenames = sorted(filenames, key=_natural_sort_key)
filepaths = [os.path.join(pages_dir, n) for n in filenames]
images = [Image.open(p) for p in filepaths]
return filepaths, images
if USE_FAST_PLAID:
# Fast-Plaid index handling
if not REBUILD_INDEX:
try:
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
if fast_plaid_index is not None:
print(f"✓ Fast-Plaid index found at {FAST_PLAID_INDEX_PATH}")
need_to_build_index = False
else:
print(f"Fast-Plaid index not found, will build new index")
need_to_build_index = True
except Exception as e:
# If loading fails (e.g., memory error, corrupted index), rebuild
print(f"Warning: Failed to load Fast-Plaid index: {e}")
print("Will rebuild the index...")
need_to_build_index = True
fast_plaid_index = None
else:
print(f"REBUILD_INDEX=True, will rebuild Fast-Plaid index")
need_to_build_index = True
else:
# Original LEANN index handling
if not REBUILD_INDEX:
retriever = _load_retriever_if_index_exists(INDEX_PATH)
if retriever is not None:
print(f"✓ Index loaded from {INDEX_PATH}")
print(f"✓ Images available at: {retriever._images_dir_path()}")
need_to_build_index = False
else:
print(f"Index not found, will build new index")
need_to_build_index = True
else:
print(f"REBUILD_INDEX=True, will rebuild index")
need_to_build_index = True
# Step 2: Load data only if we need to build the index
if need_to_build_index:
print("Loading dataset...")
if USE_HF_DATASET:
from datasets import load_dataset, concatenate_datasets, DatasetDict
# Determine which datasets to load
if DATASET_NAMES is not None:
dataset_names_to_load = DATASET_NAMES
print(f"Loading {len(dataset_names_to_load)} datasets: {dataset_names_to_load}")
else:
dataset_names_to_load = [DATASET_NAME]
print(f"Loading single dataset: {DATASET_NAME}")
# Load and combine datasets
all_datasets_to_concat = []
for dataset_entry in dataset_names_to_load:
# Handle both string and tuple formats
if isinstance(dataset_entry, tuple):
dataset_name, config_name = dataset_entry
else:
dataset_name = dataset_entry
config_name = None
print(f"\nProcessing dataset: {dataset_name}" + (f" (config: {config_name})" if config_name else ""))
# Load dataset to check available splits
# If config_name is provided, use it; otherwise try without config
try:
if config_name:
dataset_dict = load_dataset(dataset_name, config_name)
else:
dataset_dict = load_dataset(dataset_name)
except ValueError as e:
if "Config name is missing" in str(e):
# Try to get available configs and suggest
from datasets import get_dataset_config_names
try:
available_configs = get_dataset_config_names(dataset_name)
raise ValueError(
f"Dataset '{dataset_name}' requires a config name. "
f"Available configs: {available_configs}. "
f"Please specify as: ('{dataset_name}', 'config_name')"
) from e
except Exception:
raise ValueError(
f"Dataset '{dataset_name}' requires a config name. "
f"Please specify as: ('{dataset_name}', 'config_name')"
) from e
raise
# Determine which splits to load
if DATASET_SPLITS is None:
# Auto-detect: try to load all available splits
available_splits = list(dataset_dict.keys())
print(f" Auto-detected splits: {available_splits}")
splits_to_load = available_splits
else:
splits_to_load = DATASET_SPLITS
# Load and concatenate multiple splits for this dataset
datasets_to_concat = []
for split in splits_to_load:
if split not in dataset_dict:
print(f" Warning: Split '{split}' not found in dataset. Available splits: {list(dataset_dict.keys())}")
continue
split_dataset = dataset_dict[split]
print(f" Loaded split '{split}': {len(split_dataset)} pages")
datasets_to_concat.append(split_dataset)
if not datasets_to_concat:
print(f" Warning: No valid splits found for {dataset_name}. Skipping.")
continue
# Concatenate splits for this dataset
if len(datasets_to_concat) > 1:
combined_dataset = concatenate_datasets(datasets_to_concat)
print(f" Concatenated {len(datasets_to_concat)} splits into {len(combined_dataset)} pages")
else:
combined_dataset = datasets_to_concat[0]
all_datasets_to_concat.append(combined_dataset)
if not all_datasets_to_concat:
raise RuntimeError("No valid datasets or splits found.")
# Concatenate all datasets
if len(all_datasets_to_concat) > 1:
dataset = concatenate_datasets(all_datasets_to_concat)
print(f"\nConcatenated {len(all_datasets_to_concat)} datasets into {len(dataset)} total pages")
else:
dataset = all_datasets_to_concat[0]
# Apply MAX_DOCS limit if specified
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
if N < len(dataset):
print(f"Limiting to {N} pages (from {len(dataset)} total)")
dataset = dataset.select(range(N))
# Auto-detect image field name if not specified
if IMAGE_FIELD_NAME is None:
# Check multiple samples to find the most common image field
# (useful when datasets are merged and may have different field names)
possible_image_fields = ["page_image", "image", "images", "img", "page", "document_image"]
field_counts = {}
# Check first few samples to find image fields
num_samples_to_check = min(10, len(dataset))
for sample_idx in range(num_samples_to_check):
sample = dataset[sample_idx]
for field in possible_image_fields:
if field in sample and sample[field] is not None:
value = sample[field]
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
field_counts[field] = field_counts.get(field, 0) + 1
# Choose the most common field, or first found if tied
if field_counts:
image_field = max(field_counts.items(), key=lambda x: x[1])[0]
print(f"Auto-detected image field: '{image_field}' (found in {field_counts[image_field]}/{num_samples_to_check} samples)")
else:
# Fallback: check first sample only
sample = dataset[0]
image_field = None
for field in possible_image_fields:
if field in sample:
value = sample[field]
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
image_field = field
break
if image_field is None:
raise RuntimeError(
f"Could not auto-detect image field. Available fields: {list(sample.keys())}. "
f"Please specify IMAGE_FIELD_NAME manually."
)
print(f"Auto-detected image field: '{image_field}'")
else:
image_field = IMAGE_FIELD_NAME
if image_field not in dataset[0]:
raise RuntimeError(
f"Image field '{image_field}' not found. Available fields: {list(dataset[0].keys())}"
)
filepaths: list[str] = []
images: list[Image.Image] = []
for i in tqdm(range(len(dataset)), desc="Loading dataset", total=len(dataset)):
p = dataset[i]
# Try to compose a descriptive identifier
# Handle different dataset structures
identifier_parts = []
# Helper function to safely get field value
def safe_get(field_name, default=None):
if field_name in p and p[field_name] is not None:
return p[field_name]
return default
# Try to get various identifier fields
if safe_get("paper_arxiv_id"):
identifier_parts.append(f"arXiv:{p['paper_arxiv_id']}")
if safe_get("paper_title"):
identifier_parts.append(f"title:{p['paper_title']}")
if safe_get("page_number") is not None:
try:
identifier_parts.append(f"page:{int(p['page_number'])}")
except (ValueError, TypeError):
# If conversion fails, use the raw value or skip
if p['page_number']:
identifier_parts.append(f"page:{p['page_number']}")
if safe_get("page_id"):
identifier_parts.append(f"id:{p['page_id']}")
elif safe_get("questionId"):
identifier_parts.append(f"qid:{p['questionId']}")
elif safe_get("docId"):
identifier_parts.append(f"docId:{p['docId']}")
elif safe_get("id"):
identifier_parts.append(f"id:{p['id']}")
# If no identifier parts found, create one from index
if identifier_parts:
identifier = "|".join(identifier_parts)
else:
# Create identifier from available fields or index
fallback_parts = []
# Try common fields that might exist
for field in ["ucsf_document_id", "docId", "questionId", "id"]:
if safe_get(field):
fallback_parts.append(f"{field}:{p[field]}")
break
if fallback_parts:
identifier = "|".join(fallback_parts) + f"|idx:{i}"
else:
identifier = f"doc_{i}"
filepaths.append(identifier)
# Get image - try detected field first, then fallback to other common fields
img = None
if image_field in p and p[image_field] is not None:
img = p[image_field]
else:
# Fallback: try other common image field names
for fallback_field in ["image", "page_image", "images", "img"]:
if fallback_field in p and p[fallback_field] is not None:
img = p[fallback_field]
break
if img is None:
raise RuntimeError(
f"No image found for sample {i}. Available fields: {list(p.keys())}. "
f"Expected field: {image_field}"
)
# Ensure it's a PIL Image
if not isinstance(img, Image.Image):
if hasattr(img, 'convert'):
img = img.convert('RGB')
else:
img = Image.fromarray(img) if hasattr(img, '__array__') else Image.open(img)
images.append(img)
else:
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
filepaths, images = _load_images_from_dir(PAGES_DIR)
if not images:
raise RuntimeError(
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
)
print(f"Loaded {len(images)} images")
# Memory check before loading model
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
if not pdf_path:
return
os.makedirs(pages_dir, exist_ok=True)
try:
import psutil
import torch
process = psutil.Process(os.getpid())
mem_info = process.memory_info()
print(f"Memory usage after loading images: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
if torch.cuda.is_available():
print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
except ImportError:
pass
else:
print("Skipping dataset loading (using existing index)")
filepaths = [] # Not needed when using existing index
images = [] # Not needed when using existing index
# %%
# Step 3: Load model and processor (only if we need to build index or perform search)
print("Step 3: Loading model and processor...")
print(f" Model: {MODEL}")
try:
import sys
print(f" Python version: {sys.version}")
print(f" Python executable: {sys.executable}")
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
print(f"✓ Using model={model_name}, device={device_str}, dtype={dtype}")
# Memory check after loading model
try:
import psutil
import torch
process = psutil.Process(os.getpid())
mem_info = process.memory_info()
print(f" Memory usage after loading model: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
if torch.cuda.is_available():
print(f" GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f" GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
except ImportError:
pass
except Exception as e:
print(f"✗ Error loading model: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
raise
# %%
# %%
# Step 4: Build index if needed
if need_to_build_index:
print("Step 4: Building index...")
print(f" Number of images: {len(images)}")
print(f" Number of filepaths: {len(filepaths)}")
try:
print(" Embedding images...")
doc_vecs = _embed_images(model, processor, images)
print(f" Embedded {len(doc_vecs)} documents")
print(f" First doc vec shape: {doc_vecs[0].shape if len(doc_vecs) > 0 else 'N/A'}")
from pdf2image import convert_from_path
except Exception as e:
print(f"Error embedding images: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
raise
raise RuntimeError(
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
) from e
images = convert_from_path(pdf_path, dpi=dpi)
for i, image in enumerate(images):
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
if USE_FAST_PLAID:
# Build Fast-Plaid index
print(" Building Fast-Plaid index...")
def _select_device_and_dtype():
import torch
from colpali_engine.utils.torch_utils import get_torch_device
device_str = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(device_str)
# Stable dtype selection to avoid NaNs:
# - CUDA: prefer bfloat16 if supported, else float16
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
# - CPU: float32
if device_str == "cuda":
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
try:
fast_plaid_index, build_secs = _build_fast_plaid_index(
FAST_PLAID_INDEX_PATH, doc_vecs, filepaths, images
)
from pathlib import Path
print(f"✓ Fast-Plaid index built in {build_secs:.3f}s")
print(f"✓ Index saved to: {FAST_PLAID_INDEX_PATH}")
print(f"✓ Images saved to: {Path(FAST_PLAID_INDEX_PATH) / 'images'}")
except Exception as e:
print(f"Error building Fast-Plaid index: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
raise
finally:
# Clear memory
print(" Clearing memory...")
del images, filepaths, doc_vecs
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
except Exception:
pass
elif device_str == "mps":
dtype = torch.float32
else:
# Build original LEANN index
try:
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
except Exception as e:
print(f"Error building LEANN index: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
raise
finally:
# Clear memory
print(" Clearing memory...")
del images, filepaths, doc_vecs
dtype = torch.float32
return device_str, device, dtype
# Note: Images are now stored separately, retriever/fast_plaid_index will reference them
def _load_colvision(model_choice: str):
import torch
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
device_str, device, dtype = _select_device_and_dtype()
if model_choice == "colqwen2":
model_name = "vidore/colqwen2-v1.0"
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
attn_implementation = (
"flash_attention_2"
if (device_str == "cuda" and is_flash_attn_2_available())
else "eager"
)
model = ColQwen2.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
).eval()
processor = ColQwen2Processor.from_pretrained(model_name)
else:
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
).eval()
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
return model_name, model, processor, device_str, device, dtype
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
# Ensure deterministic eval and autocast for stability
model.eval()
dataloader = DataLoader(
dataset=ListDataset[Image.Image](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
doc_vecs: list[Any] = []
for batch_doc in tqdm(dataloader, desc="Embedding images"):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_doc = model(**batch_doc)
else:
embeddings_doc = model(**batch_doc)
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
return doc_vecs
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
model.eval()
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
q_vecs: list[Any] = []
for batch_query in tqdm(dataloader, desc="Embedding queries"):
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_query = model(**batch_query)
else:
embeddings_query = model(**batch_query)
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
return q_vecs
def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) -> LeannMultiVector:
dim = int(doc_vecs[0].shape[-1])
retriever = LeannMultiVector(index_path=index_path, dim=dim)
retriever.create_collection()
for i, vec in enumerate(doc_vecs):
data = {
"colbert_vecs": vec.float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
}
retriever.insert(data)
retriever.create_index()
return retriever
def _load_retriever_if_index_exists(index_path: str, dim: int) -> Optional[LeannMultiVector]:
index_base = Path(index_path)
# Rough heuristic: index dir exists AND meta+labels files exist
meta = index_base.parent / f"{index_base.name}.meta.json"
labels = index_base.parent / f"{index_base.name}.labels.json"
if index_base.exists() and meta.exists() and labels.exists():
return LeannMultiVector(index_path=index_path, dim=dim)
return None
def _generate_similarity_map(
model,
processor,
image: Image.Image,
query: str,
token_idx: Optional[int] = None,
output_path: Optional[str] = None,
) -> tuple[int, float]:
import torch
from colpali_engine.interpretability import (
get_similarity_maps_from_embeddings,
plot_similarity_map,
)
batch_images = processor.process_images([image]).to(model.device)
batch_queries = processor.process_queries([query]).to(model.device)
with torch.no_grad():
image_embeddings = model.forward(**batch_images)
query_embeddings = model.forward(**batch_queries)
n_patches = processor.get_n_patches(
image_size=image.size,
spatial_merge_size=getattr(model, "spatial_merge_size", None),
)
image_mask = processor.get_image_mask(batch_images)
batched_similarity_maps = get_similarity_maps_from_embeddings(
image_embeddings=image_embeddings,
query_embeddings=query_embeddings,
n_patches=n_patches,
image_mask=image_mask,
)
similarity_maps = batched_similarity_maps[0]
# Determine token index if not provided: choose the token with highest max score
if token_idx is None:
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
token_idx = int(per_token_max.argmax().item())
max_sim_score = similarity_maps[token_idx, :, :].max().item()
if output_path:
import matplotlib.pyplot as plt
fig, ax = plot_similarity_map(
image=image,
similarity_map=similarity_maps[token_idx],
figsize=(14, 14),
show_colorbar=False,
)
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches="tight")
plt.close(fig)
return token_idx, float(max_sim_score)
class QwenVL:
def __init__(self, device: str):
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from transformers.utils.import_utils import is_flash_attn_2_available
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct",
torch_dtype="auto",
device_map=device,
attn_implementation=attn_implementation,
)
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
self.processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
import base64
from io import BytesIO
from qwen_vl_utils import process_vision_info
content = []
for img in images:
buffer = BytesIO()
img.save(buffer, format="jpeg")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
content.append({"type": "text", "text": query})
messages = [{"role": "user", "content": content}]
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
)
inputs = inputs.to(self.model.device)
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
return self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
# %%
# Step 5: Embed query and search
_t0 = time.perf_counter()
q_vec = _embed_queries(model, processor, [QUERY])[0]
query_embed_secs = time.perf_counter() - _t0
print(f"[Search] Method: {SEARCH_METHOD}")
print(f"[Timing] Query embedding: {query_embed_secs:.3f}s")
# Step 1: Prepare data
if USE_HF_DATASET:
from datasets import load_dataset
# Run the selected search method and time it
if USE_FAST_PLAID:
# Fast-Plaid search
if fast_plaid_index is None:
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
if fast_plaid_index is None:
raise RuntimeError(f"Fast-Plaid index not found at {FAST_PLAID_INDEX_PATH}")
results, search_secs = _search_fast_plaid(fast_plaid_index, q_vec, TOPK)
print(f"[Timing] Fast-Plaid Search: {search_secs:.3f}s")
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
filepaths: list[str] = []
images: list[Image.Image] = []
for i in tqdm(range(N), desc="Loading dataset", total=N ):
p = dataset[i]
# Compose a descriptive identifier for printing later
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
print(identifier)
filepaths.append(identifier)
images.append(p["page_image"]) # PIL Image
else:
# Original LEANN search
query_np = q_vec.float().numpy()
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
filepaths, images = _load_images_from_dir(PAGES_DIR)
if not images:
raise RuntimeError(
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
)
if SEARCH_METHOD == "ann":
results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
search_secs = time.perf_counter() - _t0
print(f"[Timing] Search (ANN): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
elif SEARCH_METHOD == "exact":
results = retriever.search_exact(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
search_secs = time.perf_counter() - _t0
print(f"[Timing] Search (Exact rerank): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
elif SEARCH_METHOD == "exact-all":
results = retriever.search_exact_all(query_np, topk=TOPK)
search_secs = time.perf_counter() - _t0
print(f"[Timing] Search (Exact all): {search_secs:.3f}s")
else:
results = []
# %%
# Step 2: Load model and processor
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
# %%
# %%
# Step 3: Build or load index
retriever: Optional[LeannMultiVector] = None
if not REBUILD_INDEX:
try:
one_vec = _embed_images(model, processor, [images[0]])[0]
retriever = _load_retriever_if_index_exists(INDEX_PATH, dim=int(one_vec.shape[-1]))
except Exception:
retriever = None
if retriever is None:
doc_vecs = _embed_images(model, processor, images)
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)
# %%
# Step 4: Embed query and search
q_vec = _embed_queries(model, processor, [QUERY])[0]
results = retriever.search(q_vec.float().numpy(), topk=TOPK, first_stage_k=FIRST_STAGE_K)
if not results:
print("No results found.")
else:
print(f'Top {len(results)} results for query: "{QUERY}"')
print("\n[DEBUG] Retrieval details:")
top_images: list[Image.Image] = []
image_hashes = {} # Track image hashes to detect duplicates
for rank, (score, doc_id) in enumerate(results, start=1):
# Retrieve image and metadata based on index type
if USE_FAST_PLAID:
# Fast-Plaid: load image and get metadata
image = _get_fast_plaid_image(FAST_PLAID_INDEX_PATH, doc_id)
if image is None:
print(f"Warning: Could not find image for doc_id {doc_id}")
continue
metadata = _get_fast_plaid_metadata(FAST_PLAID_INDEX_PATH, doc_id)
path = metadata.get("filepath", f"doc_{doc_id}") if metadata else f"doc_{doc_id}"
top_images.append(image)
else:
# Original LEANN: retrieve from retriever
image = retriever.get_image(doc_id)
if image is None:
print(f"Warning: Could not retrieve image for doc_id {doc_id}")
continue
metadata = retriever.get_metadata(doc_id)
path = metadata.get("filepath", "unknown") if metadata else "unknown"
top_images.append(image)
# Calculate image hash to detect duplicates
import hashlib
import io
# Convert image to bytes for hashing
img_bytes = io.BytesIO()
image.save(img_bytes, format='PNG')
image_bytes = img_bytes.getvalue()
image_hash = hashlib.md5(image_bytes).hexdigest()[:8]
# Check if this image was already seen
duplicate_info = ""
if image_hash in image_hashes:
duplicate_info = f" [DUPLICATE of rank {image_hashes[image_hash]}]"
else:
image_hashes[image_hash] = rank
# Print detailed information
print(f"{rank}) doc_id={doc_id}, MaxSim={score:.4f}, Page={path}, ImageHash={image_hash}{duplicate_info}")
if metadata:
print(f" Metadata: {metadata}")
path = filepaths[doc_id]
# For HF dataset, path is a descriptive identifier, not a real file path
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
top_images.append(images[doc_id])
if SAVE_TOP_IMAGE:
from pathlib import Path as _Path
@@ -614,17 +427,12 @@ else:
else:
out_path = base / f"retrieved_page_rank{rank}.png"
img.save(str(out_path))
# Print the retrieval score (document-level MaxSim) alongside the saved path
try:
score, _doc_id = results[rank - 1]
print(f"Saved retrieved page (rank {rank}) [MaxSim={score:.4f}] to: {out_path}")
except Exception:
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
## TODO stange results of second page of DeepSeek-V2 rather than the first page
# %%
# Step 6: Similarity maps for top-K results
# Step 5: Similarity maps for top-K results
if results and SIMILARITY_MAP:
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
from pathlib import Path as _Path
@@ -661,12 +469,9 @@ if results and SIMILARITY_MAP:
# %%
# Step 7: Optional answer generation
# Step 6: Optional answer generation
if results and ANSWER:
qwen = QwenVL(device=device_str)
_t0 = time.perf_counter()
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
gen_secs = time.perf_counter() - _t0
print(f"[Timing] Generation: {gen_secs:.3f}s")
print("\nAnswer:")
print(response)

View File

@@ -1,399 +0,0 @@
#!/usr/bin/env python3
"""
Modular script to reproduce NDCG results for ViDoRe v1 benchmark.
This script uses the interface from leann_multi_vector.py to:
1. Download ViDoRe v1 datasets
2. Build indexes (LEANN or Fast-Plaid)
3. Perform retrieval
4. Evaluate using NDCG metrics
Usage:
# Evaluate all ViDoRe v1 tasks
python vidore_v1_benchmark.py --model colqwen2 --tasks all
# Evaluate specific task
python vidore_v1_benchmark.py --model colqwen2 --task VidoreArxivQARetrieval
# Use Fast-Plaid index
python vidore_v1_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
# Rebuild index
python vidore_v1_benchmark.py --model colqwen2 --rebuild-index
"""
import argparse
import json
import os
from typing import Optional
from datasets import load_dataset
from leann_multi_vector import (
ViDoReBenchmarkEvaluator,
_ensure_repo_paths_importable,
)
_ensure_repo_paths_importable(__file__)
# ViDoRe v1 task configurations
# Prompts match MTEB task metadata prompts
VIDORE_V1_TASKS = {
"VidoreArxivQARetrieval": {
"dataset_path": "vidore/arxivqa_test_subsampled_beir",
"revision": "7d94d570960eac2408d3baa7a33f9de4822ae3e4",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreDocVQARetrieval": {
"dataset_path": "vidore/docvqa_test_subsampled_beir",
"revision": "162ba2fc1a8437eda8b6c37b240bc1c0f0deb092",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreInfoVQARetrieval": {
"dataset_path": "vidore/infovqa_test_subsampled_beir",
"revision": "b802cc5fd6c605df2d673a963667d74881d2c9a4",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreTabfquadRetrieval": {
"dataset_path": "vidore/tabfquad_test_subsampled_beir",
"revision": "61a2224bcd29b7b261a4892ff4c8bea353527a31",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreTatdqaRetrieval": {
"dataset_path": "vidore/tatdqa_test_beir",
"revision": "5feb5630fdff4d8d189ffedb2dba56862fdd45c0",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreShiftProjectRetrieval": {
"dataset_path": "vidore/shiftproject_test_beir",
"revision": "84a382e05c4473fed9cff2bbae95fe2379416117",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreSyntheticDocQAAIRetrieval": {
"dataset_path": "vidore/syntheticDocQA_artificial_intelligence_test_beir",
"revision": "2d9ebea5a1c6e9ef4a3b902a612f605dca11261c",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreSyntheticDocQAEnergyRetrieval": {
"dataset_path": "vidore/syntheticDocQA_energy_test_beir",
"revision": "9935aadbad5c8deec30910489db1b2c7133ae7a7",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreSyntheticDocQAGovernmentReportsRetrieval": {
"dataset_path": "vidore/syntheticDocQA_government_reports_test_beir",
"revision": "b4909afa930f81282fd20601e860668073ad02aa",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreSyntheticDocQAHealthcareIndustryRetrieval": {
"dataset_path": "vidore/syntheticDocQA_healthcare_industry_test_beir",
"revision": "f9e25d5b6e13e1ad9f5c3cce202565031b3ab164",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
}
def load_vidore_v1_data(
dataset_path: str,
revision: Optional[str] = None,
split: str = "test",
):
"""
Load ViDoRe v1 dataset.
Returns:
corpus: dict mapping corpus_id to PIL Image
queries: dict mapping query_id to query text
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
"""
print(f"Loading dataset: {dataset_path} (split={split})")
# Load queries
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
queries = {}
for row in query_ds:
query_id = f"query-{split}-{row['query-id']}"
queries[query_id] = row["query"]
# Load corpus (images)
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
corpus = {}
for row in corpus_ds:
corpus_id = f"corpus-{split}-{row['corpus-id']}"
# Extract image from the dataset row
if "image" in row:
corpus[corpus_id] = row["image"]
elif "page_image" in row:
corpus[corpus_id] = row["page_image"]
else:
raise ValueError(
f"No image field found in corpus. Available fields: {list(row.keys())}"
)
# Load qrels (relevance judgments)
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
qrels = {}
for row in qrels_ds:
query_id = f"query-{split}-{row['query-id']}"
corpus_id = f"corpus-{split}-{row['corpus-id']}"
if query_id not in qrels:
qrels[query_id] = {}
qrels[query_id][corpus_id] = int(row["score"])
print(
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
)
# Filter qrels to only include queries that exist
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
# Filter out queries without any relevant documents (matching MTEB behavior)
# This is important for correct NDCG calculation
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
queries_filtered = {
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
}
print(
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
)
return corpus, queries_filtered, qrels_filtered
def evaluate_task(
task_name: str,
model_name: str,
index_path: str,
use_fast_plaid: bool = False,
fast_plaid_index_path: Optional[str] = None,
rebuild_index: bool = False,
top_k: int = 1000,
first_stage_k: int = 500,
k_values: Optional[list[int]] = None,
output_dir: Optional[str] = None,
):
"""
Evaluate a single ViDoRe v1 task.
"""
print(f"\n{'=' * 80}")
print(f"Evaluating task: {task_name}")
print(f"{'=' * 80}")
# Get task config
if task_name not in VIDORE_V1_TASKS:
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
task_config = VIDORE_V1_TASKS[task_name]
dataset_path = task_config["dataset_path"]
revision = task_config["revision"]
# Load data
corpus, queries, qrels = load_vidore_v1_data(
dataset_path=dataset_path,
revision=revision,
split="test",
)
# Initialize k_values if not provided
if k_values is None:
k_values = [1, 3, 5, 10, 20, 100, 1000]
# Check if we have any queries
if len(queries) == 0:
print(f"\nWarning: No queries found for task {task_name}. Skipping evaluation.")
# Return zero scores
scores = {}
for k in k_values:
scores[f"ndcg_at_{k}"] = 0.0
scores[f"map_at_{k}"] = 0.0
scores[f"recall_at_{k}"] = 0.0
scores[f"precision_at_{k}"] = 0.0
scores[f"mrr_at_{k}"] = 0.0
return scores
# Initialize evaluator
evaluator = ViDoReBenchmarkEvaluator(
model_name=model_name,
use_fast_plaid=use_fast_plaid,
top_k=top_k,
first_stage_k=first_stage_k,
k_values=k_values,
)
# Build or load index
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
if index_path_full is None:
index_path_full = f"./indexes/{task_name}_{model_name}"
if use_fast_plaid:
index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
corpus=corpus,
index_path=index_path_full,
rebuild=rebuild_index,
)
# Search queries
task_prompt = task_config.get("prompt")
results = evaluator.search_queries(
queries=queries,
corpus_ids=corpus_ids_ordered,
index_or_retriever=index_or_retriever,
fast_plaid_index_path=fast_plaid_index_path,
task_prompt=task_prompt,
)
# Evaluate
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
# Print results
print(f"\n{'=' * 80}")
print(f"Results for {task_name}:")
print(f"{'=' * 80}")
for metric, value in scores.items():
if isinstance(value, (int, float)):
print(f" {metric}: {value:.5f}")
# Save results
if output_dir:
os.makedirs(output_dir, exist_ok=True)
results_file = os.path.join(output_dir, f"{task_name}_results.json")
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
with open(results_file, "w") as f:
json.dump(results, f, indent=2)
print(f"\nSaved results to: {results_file}")
with open(scores_file, "w") as f:
json.dump(scores, f, indent=2)
print(f"Saved scores to: {scores_file}")
return scores
def main():
parser = argparse.ArgumentParser(
description="Evaluate ViDoRe v1 benchmark using LEANN/Fast-Plaid indexing"
)
parser.add_argument(
"--model",
type=str,
default="colqwen2",
choices=["colqwen2", "colpali"],
help="Model to use",
)
parser.add_argument(
"--task",
type=str,
default=None,
help="Specific task to evaluate (or 'all' for all tasks)",
)
parser.add_argument(
"--tasks",
type=str,
default="all",
help="Tasks to evaluate: 'all' or comma-separated list",
)
parser.add_argument(
"--index-path",
type=str,
default=None,
help="Path to LEANN index (auto-generated if not provided)",
)
parser.add_argument(
"--use-fast-plaid",
action="store_true",
help="Use Fast-Plaid instead of LEANN",
)
parser.add_argument(
"--fast-plaid-index-path",
type=str,
default=None,
help="Path to Fast-Plaid index (auto-generated if not provided)",
)
parser.add_argument(
"--rebuild-index",
action="store_true",
help="Rebuild index even if it exists",
)
parser.add_argument(
"--top-k",
type=int,
default=1000,
help="Top-k results to retrieve (MTEB default is max(k_values)=1000)",
)
parser.add_argument(
"--first-stage-k",
type=int,
default=500,
help="First stage k for LEANN search",
)
parser.add_argument(
"--k-values",
type=str,
default="1,3,5,10,20,100,1000",
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
)
parser.add_argument(
"--output-dir",
type=str,
default="./vidore_v1_results",
help="Output directory for results",
)
args = parser.parse_args()
# Parse k_values
k_values = [int(k.strip()) for k in args.k_values.split(",")]
# Determine tasks to evaluate
if args.task:
tasks_to_eval = [args.task]
elif args.tasks.lower() == "all":
tasks_to_eval = list(VIDORE_V1_TASKS.keys())
else:
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
print(f"Tasks to evaluate: {tasks_to_eval}")
# Evaluate each task
all_scores = {}
for task_name in tasks_to_eval:
try:
scores = evaluate_task(
task_name=task_name,
model_name=args.model,
index_path=args.index_path,
use_fast_plaid=args.use_fast_plaid,
fast_plaid_index_path=args.fast_plaid_index_path,
rebuild_index=args.rebuild_index,
top_k=args.top_k,
first_stage_k=args.first_stage_k,
k_values=k_values,
output_dir=args.output_dir,
)
all_scores[task_name] = scores
except Exception as e:
print(f"\nError evaluating {task_name}: {e}")
import traceback
traceback.print_exc()
continue
# Print summary
if all_scores:
print(f"\n{'=' * 80}")
print("SUMMARY")
print(f"{'=' * 80}")
for task_name, scores in all_scores.items():
print(f"\n{task_name}:")
# Print main metrics
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
if metric in scores:
print(f" {metric}: {scores[metric]:.5f}")
if __name__ == "__main__":
main()

View File

@@ -1,439 +0,0 @@
#!/usr/bin/env python3
"""
Modular script to reproduce NDCG results for ViDoRe v2 benchmark.
This script uses the interface from leann_multi_vector.py to:
1. Download ViDoRe v2 datasets
2. Build indexes (LEANN or Fast-Plaid)
3. Perform retrieval
4. Evaluate using NDCG metrics
Usage:
# Evaluate all ViDoRe v2 tasks
python vidore_v2_benchmark.py --model colqwen2 --tasks all
# Evaluate specific task
python vidore_v2_benchmark.py --model colqwen2 --task Vidore2ESGReportsRetrieval
# Use Fast-Plaid index
python vidore_v2_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
# Rebuild index
python vidore_v2_benchmark.py --model colqwen2 --rebuild-index
"""
import argparse
import json
import os
from typing import Optional
from datasets import load_dataset
from leann_multi_vector import (
ViDoReBenchmarkEvaluator,
_ensure_repo_paths_importable,
)
_ensure_repo_paths_importable(__file__)
# Language name to dataset language field value mapping
# Dataset uses ISO 639-3 + ISO 15924 format (e.g., "eng-Latn")
LANGUAGE_MAPPING = {
"english": "eng-Latn",
"french": "fra-Latn",
"spanish": "spa-Latn",
"german": "deu-Latn",
}
# ViDoRe v2 task configurations
# Prompts match MTEB task metadata prompts
VIDORE_V2_TASKS = {
"Vidore2ESGReportsRetrieval": {
"dataset_path": "vidore/esg_reports_v2",
"revision": "0542c0d03da0ec1c8cbc517c8d78e7e95c75d3d3",
"languages": ["french", "spanish", "english", "german"],
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"Vidore2EconomicsReportsRetrieval": {
"dataset_path": "vidore/economics_reports_v2",
"revision": "b3e3a04b07fbbaffe79be49dabf92f691fbca252",
"languages": ["french", "spanish", "english", "german"],
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"Vidore2BioMedicalLecturesRetrieval": {
"dataset_path": "vidore/biomedical_lectures_v2",
"revision": "a29202f0da409034d651614d87cd8938d254e2ea",
"languages": ["french", "spanish", "english", "german"],
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"Vidore2ESGReportsHLRetrieval": {
"dataset_path": "vidore/esg_reports_human_labeled_v2",
"revision": "6d467dedb09a75144ede1421747e47cf036857dd",
# Note: This dataset doesn't have language filtering - all queries are English
"languages": None, # No language filtering needed
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
}
def load_vidore_v2_data(
dataset_path: str,
revision: Optional[str] = None,
split: str = "test",
language: Optional[str] = None,
):
"""
Load ViDoRe v2 dataset.
Returns:
corpus: dict mapping corpus_id to PIL Image
queries: dict mapping query_id to query text
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
"""
print(f"Loading dataset: {dataset_path} (split={split}, language={language})")
# Load queries
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
# Check if dataset has language field before filtering
has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
if language and has_language_field:
# Map language name to dataset language field value (e.g., "english" -> "eng-Latn")
dataset_language = LANGUAGE_MAPPING.get(language, language)
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language)
# Check if filtering resulted in empty dataset
if len(query_ds_filtered) == 0:
print(
f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}')."
)
# Try with original language value (dataset might use simple names like 'english')
print(f"Trying with original language value '{language}'...")
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language)
if len(query_ds_filtered) == 0:
# Try to get a sample to see actual language values
try:
sample_ds = load_dataset(
dataset_path, "queries", split=split, revision=revision
)
if len(sample_ds) > 0 and "language" in sample_ds.column_names:
sample_langs = set(sample_ds["language"])
print(f"Available language values in dataset: {sample_langs}")
except Exception:
pass
else:
print(
f"Found {len(query_ds_filtered)} queries using original language value '{language}'"
)
query_ds = query_ds_filtered
queries = {}
for row in query_ds:
query_id = f"query-{split}-{row['query-id']}"
queries[query_id] = row["query"]
# Load corpus (images)
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
corpus = {}
for row in corpus_ds:
corpus_id = f"corpus-{split}-{row['corpus-id']}"
# Extract image from the dataset row
if "image" in row:
corpus[corpus_id] = row["image"]
elif "page_image" in row:
corpus[corpus_id] = row["page_image"]
else:
raise ValueError(
f"No image field found in corpus. Available fields: {list(row.keys())}"
)
# Load qrels (relevance judgments)
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
qrels = {}
for row in qrels_ds:
query_id = f"query-{split}-{row['query-id']}"
corpus_id = f"corpus-{split}-{row['corpus-id']}"
if query_id not in qrels:
qrels[query_id] = {}
qrels[query_id][corpus_id] = int(row["score"])
print(
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
)
# Filter qrels to only include queries that exist
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
# Filter out queries without any relevant documents (matching MTEB behavior)
# This is important for correct NDCG calculation
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
queries_filtered = {
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
}
print(
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
)
return corpus, queries_filtered, qrels_filtered
def evaluate_task(
task_name: str,
model_name: str,
index_path: str,
use_fast_plaid: bool = False,
fast_plaid_index_path: Optional[str] = None,
language: Optional[str] = None,
rebuild_index: bool = False,
top_k: int = 100,
first_stage_k: int = 500,
k_values: Optional[list[int]] = None,
output_dir: Optional[str] = None,
):
"""
Evaluate a single ViDoRe v2 task.
"""
print(f"\n{'=' * 80}")
print(f"Evaluating task: {task_name}")
print(f"{'=' * 80}")
# Get task config
if task_name not in VIDORE_V2_TASKS:
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")
task_config = VIDORE_V2_TASKS[task_name]
dataset_path = task_config["dataset_path"]
revision = task_config["revision"]
# Determine language
if language is None:
# Use first language if multiple available
languages = task_config.get("languages")
if languages is None:
# Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval)
language = None
elif len(languages) == 1:
language = languages[0]
else:
language = None
# Initialize k_values if not provided
if k_values is None:
k_values = [1, 3, 5, 10, 100]
# Load data
corpus, queries, qrels = load_vidore_v2_data(
dataset_path=dataset_path,
revision=revision,
split="test",
language=language,
)
# Check if we have any queries
if len(queries) == 0:
print(
f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation."
)
# Return zero scores
scores = {}
for k in k_values:
scores[f"ndcg_at_{k}"] = 0.0
scores[f"map_at_{k}"] = 0.0
scores[f"recall_at_{k}"] = 0.0
scores[f"precision_at_{k}"] = 0.0
scores[f"mrr_at_{k}"] = 0.0
return scores
# Initialize evaluator
evaluator = ViDoReBenchmarkEvaluator(
model_name=model_name,
use_fast_plaid=use_fast_plaid,
top_k=top_k,
first_stage_k=first_stage_k,
k_values=k_values,
)
# Build or load index
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
if index_path_full is None:
index_path_full = f"./indexes/{task_name}_{model_name}"
if use_fast_plaid:
index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
corpus=corpus,
index_path=index_path_full,
rebuild=rebuild_index,
)
# Search queries
task_prompt = task_config.get("prompt")
results = evaluator.search_queries(
queries=queries,
corpus_ids=corpus_ids_ordered,
index_or_retriever=index_or_retriever,
fast_plaid_index_path=fast_plaid_index_path,
task_prompt=task_prompt,
)
# Evaluate
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
# Print results
print(f"\n{'=' * 80}")
print(f"Results for {task_name}:")
print(f"{'=' * 80}")
for metric, value in scores.items():
if isinstance(value, (int, float)):
print(f" {metric}: {value:.5f}")
# Save results
if output_dir:
os.makedirs(output_dir, exist_ok=True)
results_file = os.path.join(output_dir, f"{task_name}_results.json")
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
with open(results_file, "w") as f:
json.dump(results, f, indent=2)
print(f"\nSaved results to: {results_file}")
with open(scores_file, "w") as f:
json.dump(scores, f, indent=2)
print(f"Saved scores to: {scores_file}")
return scores
def main():
parser = argparse.ArgumentParser(
description="Evaluate ViDoRe v2 benchmark using LEANN/Fast-Plaid indexing"
)
parser.add_argument(
"--model",
type=str,
default="colqwen2",
choices=["colqwen2", "colpali"],
help="Model to use",
)
parser.add_argument(
"--task",
type=str,
default=None,
help="Specific task to evaluate (or 'all' for all tasks)",
)
parser.add_argument(
"--tasks",
type=str,
default="all",
help="Tasks to evaluate: 'all' or comma-separated list",
)
parser.add_argument(
"--index-path",
type=str,
default=None,
help="Path to LEANN index (auto-generated if not provided)",
)
parser.add_argument(
"--use-fast-plaid",
action="store_true",
help="Use Fast-Plaid instead of LEANN",
)
parser.add_argument(
"--fast-plaid-index-path",
type=str,
default=None,
help="Path to Fast-Plaid index (auto-generated if not provided)",
)
parser.add_argument(
"--rebuild-index",
action="store_true",
help="Rebuild index even if it exists",
)
parser.add_argument(
"--language",
type=str,
default=None,
help="Language to evaluate (default: first available)",
)
parser.add_argument(
"--top-k",
type=int,
default=100,
help="Top-k results to retrieve",
)
parser.add_argument(
"--first-stage-k",
type=int,
default=500,
help="First stage k for LEANN search",
)
parser.add_argument(
"--k-values",
type=str,
default="1,3,5,10,100",
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
)
parser.add_argument(
"--output-dir",
type=str,
default="./vidore_v2_results",
help="Output directory for results",
)
args = parser.parse_args()
# Parse k_values
k_values = [int(k.strip()) for k in args.k_values.split(",")]
# Determine tasks to evaluate
if args.task:
tasks_to_eval = [args.task]
elif args.tasks.lower() == "all":
tasks_to_eval = list(VIDORE_V2_TASKS.keys())
else:
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
print(f"Tasks to evaluate: {tasks_to_eval}")
# Evaluate each task
all_scores = {}
for task_name in tasks_to_eval:
try:
scores = evaluate_task(
task_name=task_name,
model_name=args.model,
index_path=args.index_path,
use_fast_plaid=args.use_fast_plaid,
fast_plaid_index_path=args.fast_plaid_index_path,
language=args.language,
rebuild_index=args.rebuild_index,
top_k=args.top_k,
first_stage_k=args.first_stage_k,
k_values=k_values,
output_dir=args.output_dir,
)
all_scores[task_name] = scores
except Exception as e:
print(f"\nError evaluating {task_name}: {e}")
import traceback
traceback.print_exc()
continue
# Print summary
if all_scores:
print(f"\n{'=' * 80}")
print("SUMMARY")
print(f"{'=' * 80}")
for task_name, scores in all_scores.items():
print(f"\n{task_name}:")
# Print main metrics
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
if metric in scores:
print(f" {metric}: {scores[metric]:.5f}")
if __name__ == "__main__":
main()

View File

@@ -29,8 +29,6 @@ class SlackMCPReader:
workspace_name: Optional[str] = None,
concatenate_conversations: bool = True,
max_messages_per_conversation: int = 100,
max_retries: int = 5,
retry_delay: float = 2.0,
):
"""
Initialize the Slack MCP Reader.
@@ -40,15 +38,11 @@ class SlackMCPReader:
workspace_name: Optional workspace name to filter messages
concatenate_conversations: Whether to group messages by channel/thread
max_messages_per_conversation: Maximum messages to include per conversation
max_retries: Maximum number of retries for failed operations
retry_delay: Initial delay between retries in seconds
"""
self.mcp_server_command = mcp_server_command
self.workspace_name = workspace_name
self.concatenate_conversations = concatenate_conversations
self.max_messages_per_conversation = max_messages_per_conversation
self.max_retries = max_retries
self.retry_delay = retry_delay
self.mcp_process = None
async def start_mcp_server(self):
@@ -116,73 +110,11 @@ class SlackMCPReader:
return response.get("result", {}).get("tools", [])
def _is_cache_sync_error(self, error: dict) -> bool:
"""Check if the error is related to users cache not being ready."""
if isinstance(error, dict):
message = error.get("message", "").lower()
return (
"users cache is not ready" in message or "sync process is still running" in message
)
return False
async def _retry_with_backoff(self, func, *args, **kwargs):
"""Retry a function with exponential backoff, especially for cache sync issues."""
last_exception = None
for attempt in range(self.max_retries + 1):
try:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
# Check if this is a cache sync error
error_dict = {}
if hasattr(e, "args") and e.args and isinstance(e.args[0], dict):
error_dict = e.args[0]
elif "Failed to fetch messages" in str(e):
# Try to extract error from the exception message
import re
match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
if match:
try:
error_dict = eval(match.group(1))
except (ValueError, SyntaxError, NameError):
pass
else:
# Try alternative format
match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
if match:
try:
error_dict = eval(match.group(1))
except (ValueError, SyntaxError, NameError):
pass
if self._is_cache_sync_error(error_dict):
if attempt < self.max_retries:
delay = self.retry_delay * (2**attempt) # Exponential backoff
logger.info(
f"Cache sync not ready, waiting {delay:.1f}s before retry {attempt + 1}/{self.max_retries}"
)
await asyncio.sleep(delay)
continue
else:
logger.warning(
f"Cache sync still not ready after {self.max_retries} retries, giving up"
)
break
else:
# Not a cache sync error, don't retry
break
# If we get here, all retries failed or it's not a retryable error
raise last_exception
async def fetch_slack_messages(
self, channel: Optional[str] = None, limit: int = 100
) -> list[dict[str, Any]]:
"""
Fetch Slack messages using MCP tools with retry logic for cache sync issues.
Fetch Slack messages using MCP tools.
Args:
channel: Optional channel name to filter messages
@@ -191,59 +123,32 @@ class SlackMCPReader:
Returns:
List of message dictionaries
"""
return await self._retry_with_backoff(self._fetch_slack_messages_impl, channel, limit)
async def _fetch_slack_messages_impl(
self, channel: Optional[str] = None, limit: int = 100
) -> list[dict[str, Any]]:
"""
Internal implementation of fetch_slack_messages without retry logic.
"""
# This is a generic implementation - specific MCP servers may have different tool names
# Common tool names might be: 'get_messages', 'list_messages', 'fetch_channel_history'
tools = await self.list_available_tools()
logger.info(f"Available tools: {[tool.get('name') for tool in tools]}")
message_tool = None
# Look for a tool that can fetch messages - prioritize conversations_history
message_tool = None
# First, try to find conversations_history specifically
# Look for a tool that can fetch messages
for tool in tools:
tool_name = tool.get("name", "").lower()
if "conversations_history" in tool_name:
if any(
keyword in tool_name
for keyword in ["message", "history", "channel", "conversation"]
):
message_tool = tool
logger.info(f"Found conversations_history tool: {tool}")
break
# If not found, look for other message-fetching tools
if not message_tool:
for tool in tools:
tool_name = tool.get("name", "").lower()
if any(
keyword in tool_name
for keyword in ["conversations_search", "message", "history"]
):
message_tool = tool
break
if not message_tool:
raise RuntimeError("No message fetching tool found in MCP server")
# Prepare tool call parameters
tool_params = {"limit": "180d"} # Use 180 days to get older messages
tool_params = {"limit": limit}
if channel:
# For conversations_history, use channel_id parameter
if message_tool["name"] == "conversations_history":
tool_params["channel_id"] = channel
else:
# Try common parameter names for channel specification
for param_name in ["channel", "channel_id", "channel_name"]:
tool_params[param_name] = channel
break
logger.info(f"Tool parameters: {tool_params}")
# Try common parameter names for channel specification
for param_name in ["channel", "channel_id", "channel_name"]:
tool_params[param_name] = channel
break
fetch_request = {
"jsonrpc": "2.0",
@@ -265,8 +170,8 @@ class SlackMCPReader:
try:
messages = json.loads(content["text"])
except json.JSONDecodeError:
# If not JSON, try to parse as CSV format (Slack MCP server format)
messages = self._parse_csv_messages(content["text"], channel)
# If not JSON, treat as plain text
messages = [{"text": content["text"], "channel": channel or "unknown"}]
else:
messages = result["content"]
else:
@@ -275,56 +180,6 @@ class SlackMCPReader:
return messages if isinstance(messages, list) else [messages]
def _parse_csv_messages(self, csv_text: str, channel: str) -> list[dict[str, Any]]:
"""Parse CSV format messages from Slack MCP server."""
import csv
import io
messages = []
try:
# Split by lines and process each line as a CSV row
lines = csv_text.strip().split("\n")
if not lines:
return messages
# Skip header line if it exists
start_idx = 0
if lines[0].startswith("MsgID,UserID,UserName"):
start_idx = 1
for line in lines[start_idx:]:
if not line.strip():
continue
# Parse CSV line
reader = csv.reader(io.StringIO(line))
try:
row = next(reader)
if len(row) >= 7: # Ensure we have enough columns
message = {
"ts": row[0],
"user": row[1],
"username": row[2],
"real_name": row[3],
"channel": row[4],
"thread_ts": row[5],
"text": row[6],
"time": row[7] if len(row) > 7 else "",
"reactions": row[8] if len(row) > 8 else "",
"cursor": row[9] if len(row) > 9 else "",
}
messages.append(message)
except Exception as e:
logger.warning(f"Failed to parse CSV line: {line[:100]}... Error: {e}")
continue
except Exception as e:
logger.warning(f"Failed to parse CSV messages: {e}")
# Fallback: treat entire text as one message
messages = [{"text": csv_text, "channel": channel or "unknown"}]
return messages
def _format_message(self, message: dict[str, Any]) -> str:
"""Format a single message for indexing."""
text = message.get("text", "")
@@ -396,40 +251,6 @@ class SlackMCPReader:
return "\n".join(content_parts)
async def get_all_channels(self) -> list[str]:
"""Get list of all available channels."""
try:
channels_list_request = {
"jsonrpc": "2.0",
"id": 4,
"method": "tools/call",
"params": {"name": "channels_list", "arguments": {}},
}
channels_response = await self.send_mcp_request(channels_list_request)
if "result" in channels_response:
result = channels_response["result"]
if "content" in result and isinstance(result["content"], list):
content = result["content"][0] if result["content"] else {}
if "text" in content:
# Parse the channels from the response
channels = []
lines = content["text"].split("\n")
for line in lines:
if line.strip() and ("#" in line or "C" in line[:10]):
# Extract channel ID or name
parts = line.split()
for part in parts:
if part.startswith("C") and len(part) > 5:
channels.append(part)
elif part.startswith("#"):
channels.append(part[1:]) # Remove #
logger.info(f"Found {len(channels)} channels: {channels}")
return channels
return []
except Exception as e:
logger.warning(f"Failed to get channels list: {e}")
return []
async def read_slack_data(self, channels: Optional[list[str]] = None) -> list[str]:
"""
Read Slack data and return formatted text chunks.
@@ -466,33 +287,36 @@ class SlackMCPReader:
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
continue
else:
# Fetch from all available channels
logger.info("Fetching from all available channels...")
all_channels = await self.get_all_channels()
# Fetch from all available channels/conversations
# This is a simplified approach - real implementation would need to
# discover available channels first
try:
messages = await self.fetch_slack_messages(limit=1000)
if messages:
# Group messages by channel if concatenating
if self.concatenate_conversations:
channel_messages = {}
for message in messages:
channel = message.get(
"channel", message.get("channel_name", "general")
)
if channel not in channel_messages:
channel_messages[channel] = []
channel_messages[channel].append(message)
if not all_channels:
# Fallback to common channel names if we can't get the list
all_channels = ["general", "random", "announcements", "C0GN5BX0F"]
logger.info(f"Using fallback channels: {all_channels}")
for channel in all_channels:
try:
logger.info(f"Searching channel: {channel}")
messages = await self.fetch_slack_messages(channel=channel, limit=1000)
if messages:
if self.concatenate_conversations:
text_content = self._create_concatenated_content(messages, channel)
# Create concatenated content for each channel
for channel, msgs in channel_messages.items():
text_content = self._create_concatenated_content(msgs, channel)
if text_content.strip():
all_texts.append(text_content)
else:
# Process individual messages
for message in messages:
formatted_msg = self._format_message(message)
if formatted_msg.strip():
all_texts.append(formatted_msg)
except Exception as e:
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
continue
else:
# Process individual messages
for message in messages:
formatted_msg = self._format_message(message)
if formatted_msg.strip():
all_texts.append(formatted_msg)
except Exception as e:
logger.error(f"Failed to fetch messages: {e}")
return all_texts

View File

@@ -78,20 +78,6 @@ class SlackMCPRAG(BaseRAGExample):
help="Test MCP server connection and list available tools without indexing",
)
parser.add_argument(
"--max-retries",
type=int,
default=5,
help="Maximum number of retries for failed operations (default: 5)",
)
parser.add_argument(
"--retry-delay",
type=float,
default=2.0,
help="Initial delay between retries in seconds (default: 2.0)",
)
async def test_mcp_connection(self, args) -> bool:
"""Test the MCP server connection and display available tools."""
print(f"Testing connection to MCP server: {args.mcp_server}")
@@ -102,14 +88,12 @@ class SlackMCPRAG(BaseRAGExample):
workspace_name=args.workspace_name,
concatenate_conversations=not args.no_concatenate_conversations,
max_messages_per_conversation=args.max_messages_per_channel,
max_retries=args.max_retries,
retry_delay=args.retry_delay,
)
async with reader:
tools = await reader.list_available_tools()
print("Successfully connected to MCP server!")
print("\nSuccessfully connected to MCP server!")
print(f"Available tools ({len(tools)}):")
for i, tool in enumerate(tools, 1):
@@ -131,7 +115,7 @@ class SlackMCPRAG(BaseRAGExample):
return True
except Exception as e:
print(f"Failed to connect to MCP server: {e}")
print(f"\nFailed to connect to MCP server: {e}")
print("\nTroubleshooting tips:")
print("1. Make sure the MCP server is installed and accessible")
print("2. Check if the server command is correct")
@@ -146,11 +130,8 @@ class SlackMCPRAG(BaseRAGExample):
if args.workspace_name:
print(f"Workspace: {args.workspace_name}")
# Filter out empty strings from channels
channels = [ch for ch in args.channels if ch.strip()] if args.channels else None
if channels:
print(f"Channels: {', '.join(channels)}")
if args.channels:
print(f"Channels: {', '.join(args.channels)}")
else:
print("Fetching from all available channels")
@@ -165,20 +146,18 @@ class SlackMCPRAG(BaseRAGExample):
workspace_name=args.workspace_name,
concatenate_conversations=concatenate,
max_messages_per_conversation=args.max_messages_per_channel,
max_retries=args.max_retries,
retry_delay=args.retry_delay,
)
texts = await reader.read_slack_data(channels=channels)
texts = await reader.read_slack_data(channels=args.channels)
if not texts:
print("No messages found! This could mean:")
print("No messages found! This could mean:")
print("- The MCP server couldn't fetch messages")
print("- The specified channels don't exist or are empty")
print("- Authentication issues with the Slack workspace")
return []
print(f"Successfully loaded {len(texts)} text chunks from Slack")
print(f"Successfully loaded {len(texts)} text chunks from Slack")
# Show sample of what was loaded
if texts:
@@ -191,7 +170,7 @@ class SlackMCPRAG(BaseRAGExample):
return texts
except Exception as e:
print(f"Error loading Slack data: {e}")
print(f"Error loading Slack data: {e}")
print("\nThis might be due to:")
print("- MCP server connection issues")
print("- Authentication problems")
@@ -209,7 +188,7 @@ class SlackMCPRAG(BaseRAGExample):
if not success:
return
print(
"MCP server is working! You can now run without --test-connection to start indexing."
"\n🎉 MCP server is working! You can now run without --test-connection to start indexing."
)
return

View File

@@ -1,143 +0,0 @@
# Update Benchmarks
This directory hosts two benchmark suites that exercise LEANNs HNSW “update +
search” pipeline under different assumptions:
1. **RNG recompute latency** measure how random-neighbour pruning and cache
settings influence incremental `add()` latency when embeddings are fetched
over the ZMQ embedding server.
2. **Update strategy comparison** compare a fully sequential update pipeline
against an offline approach that keeps the graph static and fuses results.
Both suites build a non-compact, `is_recompute=True` index so that new
embeddings are pulled from the embedding server. Benchmark outputs are written
under `.leann/bench/` by default and appended to CSV files for later plotting.
## Benchmarks
### 1. HNSW RNG Recompute Benchmark
`bench_hnsw_rng_recompute.py` evaluates incremental update latency under four
random-neighbour (RNG) configurations. Each scenario uses the same dataset but
changes the forward / reverse RNG pruning flags and whether the embedding cache
is enabled:
| Scenario name | Forward RNG | Reverse RNG | ZMQ embedding cache |
| ---------------------------------- | ----------- | ----------- | ------------------- |
| `baseline` | Enabled | Enabled | Enabled |
| `no_cache_baseline` | Enabled | Enabled | **Disabled** |
| `disable_forward_rng` | **Disabled**| Enabled | Enabled |
| `disable_forward_and_reverse_rng` | **Disabled**| **Disabled**| Enabled |
For each scenario the script:
1. (Re)builds a `is_recompute=True` index and writes it to `.leann/bench/`.
2. Starts `leann_backend_hnsw.hnsw_embedding_server` for remote embeddings.
3. Appends the requested updates using the scenarios RNG flags.
4. Records total time, latency per passage, ZMQ fetch counts, and stage-level
timings before appending a row to the CSV output.
**Run:**
```bash
LEANN_HNSW_LOG_PATH=.leann/bench/hnsw_server.log \
LEANN_LOG_LEVEL=INFO \
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
--runs 1 \
--index-path .leann/bench/test.leann \
--initial-files data/PrideandPrejudice.txt \
--update-files data/huawei_pangu.md \
--max-initial 300 \
--max-updates 1 \
--add-timeout 120
```
**Output:**
- `benchmarks/update/bench_results.csv` per-scenario timing statistics
(including ms/passage) for each run.
- `.leann/bench/hnsw_server.log` detailed ZMQ/server logs (path controlled by
`LEANN_HNSW_LOG_PATH`).
_The reference CSVs checked into this branch were generated on a workstation with an NVIDIA RTX 4090 GPU; throughput numbers will differ on other hardware._
### 2. Sequential vs. Offline Update Benchmark
`bench_update_vs_offline_search.py` compares two end-to-end strategies on the
same dataset:
- **Scenario A Sequential Update**
- Start an embedding server.
- Sequentially call `index.add()`; each call fetches embeddings via ZMQ and
mutates the HNSW graph.
- After all inserts, run a search on the updated graph.
- Metrics recorded: update time (`add_total_s`), post-update search time
(`search_time_s`), combined total (`total_time_s`), and per-passage
latency.
- **Scenario B Offline Embedding + Concurrent Search**
- Stop Scenario As server and start a fresh embedding server.
- Spawn two threads: one generates embeddings for the new passages offline
(graph unchanged); the other computes the query embedding and searches the
existing graph.
- Merge offline similarities with the graph search results to emulate late
fusion, then report the merged topk preview.
- Metrics recorded: embedding time (`emb_time_s`), search time
(`search_time_s`), concurrent makespan (`makespan_s`), and scenario total.
**Run (both scenarios):**
```bash
uv run -m benchmarks.update.bench_update_vs_offline_search \
--index-path .leann/bench/offline_vs_update.leann \
--max-initial 300 \
--num-updates 1
```
You can pass `--only A` or `--only B` to run a single scenario. The script will
print timing summaries to stdout and append the results to CSV.
**Output:**
- `benchmarks/update/offline_vs_update.csv` per-scenario timing statistics for
Scenario A and B.
- Console output includes Scenario Bs merged topk preview for quick sanity
checks.
_The sample results committed here come from runs on an RTX 4090-equipped machine; expect variations if you benchmark on different GPUs._
### 3. Visualisation
`plot_bench_results.py` combines the RNG benchmark and the update strategy
benchmark into a single two-panel plot.
**Run:**
```bash
uv run -m benchmarks.update.plot_bench_results \
--csv benchmarks/update/bench_results.csv \
--csv-right benchmarks/update/offline_vs_update.csv \
--out benchmarks/update/bench_latency_from_csv.png
```
**Options:**
- `--broken-y` Enable a broken Y-axis (default: true when appropriate).
- `--csv` RNG benchmark results CSV (left panel).
- `--csv-right` Update strategy results CSV (right panel).
- `--out` Output image path (PNG/PDF supported).
**Output:**
- `benchmarks/update/bench_latency_from_csv.png` visual comparison of the two
suites.
- `benchmarks/update/bench_latency_from_csv.pdf` PDF version, suitable for
slides/papers.
## Parameters & Environment
### Common CLI Flags
- `--max-initial` Number of initial passages used to seed the index.
- `--max-updates` / `--num-updates` Number of passages to treat as updates.
- `--index-path` Base path (without extension) where the LEANN index is stored.
- `--runs` Number of repetitions (RNG benchmark only).
### Environment Variables
- `LEANN_HNSW_LOG_PATH` File to receive embedding-server logs (optional).
- `LEANN_LOG_LEVEL` Logging verbosity (DEBUG/INFO/WARNING/ERROR).
- `CUDA_VISIBLE_DEVICES` Set to empty string if you want to force CPU
execution of the embedding model.
With these scripts you can easily replicate LEANNs update benchmarks, compare
multiple RNG strategies, and evaluate whether sequential updates or offline
fusion better match your latency/accuracy trade-offs.

View File

@@ -1,16 +0,0 @@
"""Benchmarks for LEANN update workflows."""
# Expose helper to locate repository root for other modules that need it.
from pathlib import Path
def find_repo_root() -> Path:
"""Return the project root containing pyproject.toml."""
current = Path(__file__).resolve()
for parent in current.parents:
if (parent / "pyproject.toml").exists():
return parent
return current.parents[1]
__all__ = ["find_repo_root"]

View File

@@ -1,804 +0,0 @@
"""Benchmark incremental HNSW add() under different RNG pruning modes with real
embedding recomputation.
This script clones the structure of ``examples/dynamic_update_no_recompute.py``
so that we build a non-compact ``is_recompute=True`` index, spin up the
standard HNSW embedding server, and measure how long incremental ``add`` takes
when RNG pruning is fully enabled vs. partially/fully disabled.
Example usage (run from the repo root; downloads the model on first run)::
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
--index-path .leann/bench/leann-demo.leann \
--runs 1
You can tweak the input documents with ``--initial-files`` / ``--update-files``
if you want a larger or different workload, and change the embedding model via
``--model-name``.
"""
import argparse
import json
import logging
import os
import pickle
import re
import sys
import time
from pathlib import Path
from typing import Any
import msgpack
import numpy as np
import zmq
from leann.api import LeannBuilder
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
from leann.embedding_compute import compute_embeddings
from leann.embedding_server_manager import EmbeddingServerManager
from leann.registry import register_project_directory
from leann_backend_hnsw import faiss # type: ignore
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
logger = logging.getLogger(__name__)
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO)
def _find_repo_root() -> Path:
"""Locate project root by walking up until pyproject.toml is found."""
current = Path(__file__).resolve()
for parent in current.parents:
if (parent / "pyproject.toml").exists():
return parent
# Fallback: assume repo is two levels up (../..)
return current.parents[2]
REPO_ROOT = _find_repo_root()
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from apps.chunking import create_text_chunks # noqa: E402
DEFAULT_INITIAL_FILES = [
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
REPO_ROOT / "data" / "huawei_pangu.md",
]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
DEFAULT_HNSW_LOG = Path(".leann/bench/hnsw_server.log")
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
from llama_index.core import SimpleDirectoryReader
documents = []
for path in paths:
p = path.expanduser().resolve()
if not p.exists():
raise FileNotFoundError(f"Input path not found: {p}")
if p.is_dir():
reader = SimpleDirectoryReader(str(p), recursive=False)
documents.extend(reader.load_data(show_progress=True))
else:
reader = SimpleDirectoryReader(input_files=[str(p)])
documents.extend(reader.load_data(show_progress=True))
if not documents:
return []
chunks = create_text_chunks(
documents,
chunk_size=512,
chunk_overlap=128,
use_ast_chunking=False,
)
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
if limit is not None:
cleaned = cleaned[:limit]
return cleaned
def ensure_index_dir(index_path: Path) -> None:
index_path.parent.mkdir(parents=True, exist_ok=True)
def cleanup_index_files(index_path: Path) -> None:
parent = index_path.parent
if not parent.exists():
return
stem = index_path.stem
for file in parent.glob(f"{stem}*"):
if file.is_file():
file.unlink()
def build_initial_index(
index_path: Path,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
distance_metric: str,
ef_construction: int,
) -> None:
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=True,
distance_metric=distance_metric,
backend_kwargs={
"distance_metric": distance_metric,
"is_compact": False,
"is_recompute": True,
"efConstruction": ef_construction,
},
)
for idx, passage in enumerate(paragraphs):
builder.add_text(passage, metadata={"id": str(idx)})
builder.build_index(str(index_path))
def prepare_new_chunks(paragraphs: list[str]) -> list[dict[str, Any]]:
return [{"text": text, "metadata": {}} for text in paragraphs]
def benchmark_update_with_mode(
index_path: Path,
new_chunks: list[dict[str, Any]],
model_name: str,
embedding_mode: str,
distance_metric: str,
disable_forward_rng: bool,
disable_reverse_rng: bool,
server_port: int,
add_timeout: int,
ef_construction: int,
) -> tuple[float, float]:
meta_path = index_path.parent / f"{index_path.name}.meta.json"
passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
offset_file = index_path.parent / f"{index_path.name}.passages.idx"
index_file = index_path.parent / f"{index_path.stem}.index"
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
with open(offset_file, "rb") as f:
offset_map: dict[str, int] = pickle.load(f)
existing_ids = set(offset_map.keys())
valid_chunks: list[dict[str, Any]] = []
for chunk in new_chunks:
text = chunk.get("text", "")
if not isinstance(text, str) or not text.strip():
continue
metadata = chunk.setdefault("metadata", {})
passage_id = chunk.get("id") or metadata.get("id")
if passage_id and passage_id in existing_ids:
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
valid_chunks.append(chunk)
if not valid_chunks:
raise ValueError("No valid chunks to append.")
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
embeddings = compute_embeddings(
texts_to_embed,
model_name,
mode=embedding_mode,
is_build=False,
batch_size=16,
)
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
if distance_metric == "cosine":
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norms[norms == 0] = 1
embeddings = embeddings / norms
index = faiss.read_index(str(index_file))
index.is_recompute = True
if getattr(index, "storage", None) is None:
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
storage_index = faiss.IndexFlatIP(index.d)
else:
storage_index = faiss.IndexFlatL2(index.d)
index.storage = storage_index
index.own_fields = True
try:
storage_index.ntotal = index.ntotal
except AttributeError:
pass
try:
index.hnsw.set_disable_rng_during_add(disable_forward_rng)
index.hnsw.set_disable_reverse_prune(disable_reverse_rng)
if ef_construction is not None:
index.hnsw.efConstruction = ef_construction
except AttributeError:
pass
applied_forward = getattr(index.hnsw, "disable_rng_during_add", None)
applied_reverse = getattr(index.hnsw, "disable_reverse_prune", None)
logger.info(
"HNSW RNG config -> requested forward=%s, reverse=%s | applied forward=%s, reverse=%s",
disable_forward_rng,
disable_reverse_rng,
applied_forward,
applied_reverse,
)
base_id = index.ntotal
for offset, chunk in enumerate(valid_chunks):
new_id = str(base_id + offset)
chunk.setdefault("metadata", {})["id"] = new_id
chunk["id"] = new_id
rollback_size = passages_file.stat().st_size if passages_file.exists() else 0
offset_map_backup = offset_map.copy()
try:
with open(passages_file, "a", encoding="utf-8") as f:
for chunk in valid_chunks:
offset = f.tell()
json.dump(
{
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk.get("metadata", {}),
},
f,
ensure_ascii=False,
)
f.write("\n")
offset_map[chunk["id"]] = offset
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
server_started, actual_port = server_manager.start_server(
port=server_port,
model_name=model_name,
embedding_mode=embedding_mode,
passages_file=str(meta_path),
distance_metric=distance_metric,
)
if not server_started:
raise RuntimeError("Failed to start embedding server.")
if hasattr(index.hnsw, "set_zmq_port"):
index.hnsw.set_zmq_port(actual_port)
elif hasattr(index, "set_zmq_port"):
index.set_zmq_port(actual_port)
_warmup_embedding_server(actual_port)
total_start = time.time()
add_elapsed = 0.0
try:
import signal
def _timeout_handler(signum, frame):
raise TimeoutError("incremental add timed out")
if add_timeout > 0:
signal.signal(signal.SIGALRM, _timeout_handler)
signal.alarm(add_timeout)
add_start = time.time()
for i in range(embeddings.shape[0]):
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
add_elapsed = time.time() - add_start
if add_timeout > 0:
signal.alarm(0)
faiss.write_index(index, str(index_file))
finally:
server_manager.stop_server()
except TimeoutError:
raise
except Exception:
if passages_file.exists():
with open(passages_file, "rb+") as f:
f.truncate(rollback_size)
with open(offset_file, "wb") as f:
pickle.dump(offset_map_backup, f)
raise
prune_hnsw_embeddings_inplace(str(index_file))
meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
# Reset toggles so the index on disk returns to baseline behaviour.
try:
index.hnsw.set_disable_rng_during_add(False)
index.hnsw.set_disable_reverse_prune(False)
except AttributeError:
pass
faiss.write_index(index, str(index_file))
total_elapsed = time.time() - total_start
return total_elapsed, add_elapsed
def _total_zmq_nodes(log_path: Path) -> int:
if not log_path.exists():
return 0
with log_path.open("r", encoding="utf-8") as log_file:
text = log_file.read()
return sum(int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", text))
def _warmup_embedding_server(port: int) -> None:
"""Send a dummy REQ so the embedding server loads its model."""
ctx = zmq.Context()
try:
sock = ctx.socket(zmq.REQ)
sock.setsockopt(zmq.LINGER, 0)
sock.setsockopt(zmq.RCVTIMEO, 5000)
sock.setsockopt(zmq.SNDTIMEO, 5000)
sock.connect(f"tcp://127.0.0.1:{port}")
payload = msgpack.packb(["__WARMUP__"], use_bin_type=True)
sock.send(payload)
try:
sock.recv()
except zmq.error.Again:
pass
finally:
sock.close()
ctx.term()
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--index-path",
type=Path,
default=Path(".leann/bench/leann-demo.leann"),
help="Output index base path (without extension).",
)
parser.add_argument(
"--initial-files",
nargs="*",
type=Path,
default=DEFAULT_INITIAL_FILES,
help="Files used to build the initial index.",
)
parser.add_argument(
"--update-files",
nargs="*",
type=Path,
default=DEFAULT_UPDATE_FILES,
help="Files appended during the benchmark.",
)
parser.add_argument(
"--runs", type=int, default=1, help="How many times to repeat each scenario."
)
parser.add_argument(
"--model-name",
default="sentence-transformers/all-MiniLM-L6-v2",
help="Embedding model used for build/update.",
)
parser.add_argument(
"--embedding-mode",
default="sentence-transformers",
help="Embedding mode passed to LeannBuilder/embedding server.",
)
parser.add_argument(
"--distance-metric",
default="mips",
choices=["mips", "l2", "cosine"],
help="Distance metric for HNSW backend.",
)
parser.add_argument(
"--ef-construction",
type=int,
default=200,
help="efConstruction setting for initial build.",
)
parser.add_argument(
"--server-port",
type=int,
default=5557,
help="Port for the real embedding server.",
)
parser.add_argument(
"--max-initial",
type=int,
default=300,
help="Optional cap on initial passages (after chunking).",
)
parser.add_argument(
"--max-updates",
type=int,
default=1,
help="Optional cap on update passages (after chunking).",
)
parser.add_argument(
"--add-timeout",
type=int,
default=900,
help="Timeout in seconds for the incremental add loop (0 = no timeout).",
)
parser.add_argument(
"--plot-path",
type=Path,
default=Path("bench_latency.png"),
help="Where to save the latency bar plot.",
)
parser.add_argument(
"--cap-y",
type=float,
default=None,
help="Cap Y-axis (ms). Bars above are hatched and annotated.",
)
parser.add_argument(
"--broken-y",
action="store_true",
help="Use broken Y-axis (two stacked axes with gap). Overrides --cap-y unless both provided.",
)
parser.add_argument(
"--lower-cap-y",
type=float,
default=None,
help="Lower axes upper bound for broken Y (ms). Default=1.1x second-highest.",
)
parser.add_argument(
"--upper-start-y",
type=float,
default=None,
help="Upper axes lower bound for broken Y (ms). Default=1.2x second-highest.",
)
parser.add_argument(
"--csv-path",
type=Path,
default=Path("benchmarks/update/bench_results.csv"),
help="Where to append per-scenario results as CSV.",
)
args = parser.parse_args()
register_project_directory(REPO_ROOT)
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
update_paragraphs = load_chunks_from_files(args.update_files, args.max_updates)
if not update_paragraphs:
raise ValueError("No update passages found; please provide --update-files with content.")
update_chunks = prepare_new_chunks(update_paragraphs)
ensure_index_dir(args.index_path)
scenarios = [
("baseline", False, False, True),
("no_cache_baseline", False, False, False),
("disable_forward_rng", True, False, True),
("disable_forward_and_reverse_rng", True, True, True),
]
log_path = Path(os.environ.get("LEANN_HNSW_LOG_PATH", DEFAULT_HNSW_LOG))
log_path.parent.mkdir(parents=True, exist_ok=True)
os.environ["LEANN_HNSW_LOG_PATH"] = str(log_path.resolve())
os.environ.setdefault("LEANN_LOG_LEVEL", "INFO")
results_total: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_add: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_zmq: dict[str, list[int]] = {name: [] for name, *_ in scenarios}
results_stageA: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_stageBC: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_ms_per_passage: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
# CSV setup
import csv
run_id = time.strftime("%Y%m%d-%H%M%S")
csv_fields = [
"run_id",
"scenario",
"cache_enabled",
"ef_construction",
"max_initial",
"max_updates",
"total_time_s",
"add_only_s",
"latency_ms_per_passage",
"zmq_nodes",
"stageA_time_s",
"stageBC_time_s",
"model_name",
"embedding_mode",
"distance_metric",
]
# Create CSV with header if missing
if args.csv_path:
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writeheader()
for run in range(args.runs):
print(f"\n=== Benchmark run {run + 1}/{args.runs} ===")
for name, disable_forward, disable_reverse, cache_enabled in scenarios:
print(f"\nScenario: {name}")
cleanup_index_files(args.index_path)
if log_path.exists():
try:
log_path.unlink()
except OSError:
pass
os.environ["LEANN_ZMQ_EMBED_CACHE"] = "1" if cache_enabled else "0"
build_initial_index(
args.index_path,
initial_paragraphs,
args.model_name,
args.embedding_mode,
args.distance_metric,
args.ef_construction,
)
prev_size = log_path.stat().st_size if log_path.exists() else 0
try:
total_elapsed, add_elapsed = benchmark_update_with_mode(
args.index_path,
update_chunks,
args.model_name,
args.embedding_mode,
args.distance_metric,
disable_forward,
disable_reverse,
args.server_port,
args.add_timeout,
args.ef_construction,
)
except TimeoutError as exc:
print(f"Scenario {name} timed out: {exc}")
continue
curr_size = log_path.stat().st_size if log_path.exists() else 0
if curr_size < prev_size:
prev_size = 0
zmq_count = 0
if log_path.exists():
with log_path.open("r", encoding="utf-8") as log_file:
log_file.seek(prev_size)
new_entries = log_file.read()
zmq_count = sum(
int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", new_entries)
)
stageA = sum(
float(x)
for x in re.findall(r"Distance calculation E2E time: ([0-9.]+)s", new_entries)
)
stageBC = sum(
float(x) for x in re.findall(r"ZMQ E2E time: ([0-9.]+)s", new_entries)
)
else:
stageA = 0.0
stageBC = 0.0
per_chunk = add_elapsed / len(update_chunks)
print(
f"Total time: {total_elapsed:.3f} s | add-only: {add_elapsed:.3f} s "
f"for {len(update_chunks)} passages => {per_chunk * 1e3:.3f} ms/passage"
)
print(f"ZMQ node fetch total: {zmq_count}")
results_total[name].append(total_elapsed)
results_add[name].append(add_elapsed)
results_zmq[name].append(zmq_count)
results_ms_per_passage[name].append(per_chunk * 1e3)
results_stageA[name].append(stageA)
results_stageBC[name].append(stageBC)
# Append row to CSV
if args.csv_path:
row = {
"run_id": run_id,
"scenario": name,
"cache_enabled": 1 if cache_enabled else 0,
"ef_construction": args.ef_construction,
"max_initial": args.max_initial,
"max_updates": args.max_updates,
"total_time_s": round(total_elapsed, 6),
"add_only_s": round(add_elapsed, 6),
"latency_ms_per_passage": round(per_chunk * 1e3, 6),
"zmq_nodes": int(zmq_count),
"stageA_time_s": round(stageA, 6),
"stageBC_time_s": round(stageBC, 6),
"model_name": args.model_name,
"embedding_mode": args.embedding_mode,
"distance_metric": args.distance_metric,
}
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writerow(row)
print("\n=== Summary ===")
for name in results_add:
add_values = results_add[name]
total_values = results_total[name]
zmq_values = results_zmq[name]
latency_values = results_ms_per_passage[name]
if not add_values:
print(f"{name}: no successful runs")
continue
avg_add = sum(add_values) / len(add_values)
avg_total = sum(total_values) / len(total_values)
avg_zmq = sum(zmq_values) / len(zmq_values) if zmq_values else 0.0
avg_latency = sum(latency_values) / len(latency_values) if latency_values else 0.0
runs = len(add_values)
print(
f"{name}: add-only avg {avg_add:.3f} s | total avg {avg_total:.3f} s "
f"| ZMQ avg {avg_zmq:.1f} node fetches | latency {avg_latency:.2f} ms/passage over {runs} run(s)"
)
if args.plot_path:
try:
import matplotlib.pyplot as plt
labels = [name for name, *_ in scenarios]
values = [
sum(results_ms_per_passage[name]) / len(results_ms_per_passage[name])
if results_ms_per_passage[name]
else 0.0
for name in labels
]
def _auto_cap(vals: list[float]) -> float | None:
s = sorted(vals, reverse=True)
if len(s) < 2:
return None
if s[1] > 0 and s[0] >= 2.5 * s[1]:
return s[1] * 1.1
return None
def _fmt_ms(v: float) -> str:
return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.1f}"
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
if args.broken_y:
s = sorted(values, reverse=True)
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
upper_start = (
args.upper_start_y
if args.upper_start_y is not None
else max(second * 1.2, lower_cap * 1.02)
)
ymax = max(values) * 1.10 if values else 1.0
fig, (ax_top, ax_bottom) = plt.subplots(
2,
1,
sharex=True,
figsize=(7.4, 5.0),
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.05},
)
x = list(range(len(labels)))
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_bottom.set_ylim(0, lower_cap)
ax_top.set_ylim(upper_start, ymax)
for i, v in enumerate(values):
if v <= lower_cap:
ax_bottom.text(
i,
v + lower_cap * 0.02,
_fmt_ms(v),
ha="center",
va="bottom",
fontsize=9,
)
else:
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
ax_top.spines["bottom"].set_visible(False)
ax_bottom.spines["top"].set_visible(False)
ax_top.tick_params(labeltop=False)
ax_bottom.xaxis.tick_bottom()
d = 0.015
kwargs = {"transform": ax_top.transAxes, "color": "k", "clip_on": False}
ax_top.plot((-d, +d), (-d, +d), **kwargs)
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
kwargs.update({"transform": ax_bottom.transAxes})
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
ax_bottom.set_xticks(range(len(labels)))
ax_bottom.set_xticklabels(labels)
ax = ax_bottom
else:
cap = args.cap_y or _auto_cap(values)
plt.figure(figsize=(7.2, 4.2))
ax = plt.gca()
if cap is not None:
show_vals = [min(v, cap) for v in values]
bars = []
for i, (v, show) in enumerate(zip(values, show_vals)):
b = ax.bar(i, show, color=colors[i], width=0.8)
bars.append(b[0])
if v > cap:
bars[-1].set_hatch("//")
ax.text(i, cap * 1.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
else:
ax.text(
i,
show + max(1.0, 0.01 * (cap or show)),
_fmt_ms(v),
ha="center",
va="bottom",
fontsize=9,
)
ax.set_ylim(0, cap * 1.10)
ax.plot(
[0.02 - 0.02, 0.02 + 0.02],
[0.98 + 0.02, 0.98 - 0.02],
transform=ax.transAxes,
color="k",
lw=1,
)
ax.plot(
[0.98 - 0.02, 0.98 + 0.02],
[0.98 + 0.02, 0.98 - 0.02],
transform=ax.transAxes,
color="k",
lw=1,
)
if any(v > cap for v in values):
ax.legend(
[bars[0]], ["capped"], fontsize=8, frameon=False, loc="upper right"
)
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels)
else:
ax.bar(labels, values, color=colors[: len(labels)])
for idx, val in enumerate(values):
ax.text(idx, val + 1.0, f"{val:.1f}", ha="center", va="bottom")
plt.ylabel("Average add latency (ms per passage)")
plt.title(f"Initial passages {args.max_initial}, updates {args.max_updates}")
plt.tight_layout()
plt.savefig(args.plot_path)
print(f"Saved latency bar plot to {args.plot_path}")
# ZMQ time split (Stage A vs B/C)
try:
plt.figure(figsize=(6, 4))
a_vals = [sum(results_stageA[n]) / max(1, len(results_stageA[n])) for n in labels]
bc_vals = [
sum(results_stageBC[n]) / max(1, len(results_stageBC[n])) for n in labels
]
ind = range(len(labels))
plt.bar(ind, a_vals, color="#4e79a7", label="Stage A distance (s)")
plt.bar(
ind, bc_vals, bottom=a_vals, color="#e15759", label="Stage B/C embed-by-id (s)"
)
plt.xticks(list(ind), labels, rotation=10)
plt.ylabel("Server ZMQ time (s)")
plt.title(
f"ZMQ time split (initial {args.max_initial}, updates {args.max_updates})"
)
plt.legend()
out2 = args.plot_path.with_name(
args.plot_path.stem + "_zmq_split" + args.plot_path.suffix
)
plt.tight_layout()
plt.savefig(out2)
print(f"Saved ZMQ time split plot to {out2}")
except Exception as e:
print("Failed to plot ZMQ split:", e)
except ImportError:
print("matplotlib not available; skipping plot generation")
# leave the last build on disk for inspection
if __name__ == "__main__":
main()

View File

@@ -1,5 +0,0 @@
run_id,scenario,cache_enabled,ef_construction,max_initial,max_updates,total_time_s,add_only_s,latency_ms_per_passage,zmq_nodes,stageA_time_s,stageBC_time_s,model_name,embedding_mode,distance_metric
20251024-133101,baseline,1,200,300,1,3.391856,1.120359,1120.359421,126,0.507821,0.601608,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,no_cache_baseline,0,200,300,1,34.941514,32.91376,32913.760185,4033,0.506933,32.159928,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,disable_forward_rng,1,200,300,1,2.746756,0.8202,820.200443,66,0.474354,0.338454,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,disable_forward_and_reverse_rng,1,200,300,1,2.396566,0.521478,521.478415,1,0.508973,0.006938,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
1 run_id scenario cache_enabled ef_construction max_initial max_updates total_time_s add_only_s latency_ms_per_passage zmq_nodes stageA_time_s stageBC_time_s model_name embedding_mode distance_metric
2 20251024-133101 baseline 1 200 300 1 3.391856 1.120359 1120.359421 126 0.507821 0.601608 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
3 20251024-133101 no_cache_baseline 0 200 300 1 34.941514 32.91376 32913.760185 4033 0.506933 32.159928 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
4 20251024-133101 disable_forward_rng 1 200 300 1 2.746756 0.8202 820.200443 66 0.474354 0.338454 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
5 20251024-133101 disable_forward_and_reverse_rng 1 200 300 1 2.396566 0.521478 521.478415 1 0.508973 0.006938 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips

View File

@@ -1,704 +0,0 @@
"""
Compare two latency models for small incremental updates vs. search:
Scenario A (sequential update then search):
- Build initial HNSW (is_recompute=True)
- Start embedding server (ZMQ) for recompute
- Add N passages one-by-one (each triggers recompute over ZMQ)
- Then run a search query on the updated index
- Report total time = sum(add_i) + search_time, with breakdowns
Scenario B (offline embeds + concurrent search; no graph updates):
- Do NOT insert the N passages into the graph
- In parallel: (1) compute embeddings for the N passages; (2) compute query
embedding and run a search on the existing index
- After both finish, compute similarity between the query embedding and the N
new passage embeddings, merge with the index search results by score, and
report time = max(embed_time, search_time) (i.e., no blocking on updates)
This script reuses the model/data loading conventions of
examples/bench_hnsw_rng_recompute.py but focuses on end-to-end latency
comparison for the two execution strategies above.
Example (from the repository root):
uv run -m benchmarks.update.bench_update_vs_offline_search \
--index-path .leann/bench/offline_vs_update.leann \
--max-initial 300 --num-updates 5 --k 10
"""
import argparse
import csv
import json
import logging
import os
import pickle
import sys
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import psutil # type: ignore
from leann.api import LeannBuilder
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
from leann.embedding_compute import compute_embeddings
from leann.embedding_server_manager import EmbeddingServerManager
from leann.registry import register_project_directory
from leann_backend_hnsw import faiss # type: ignore
logger = logging.getLogger(__name__)
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO)
def _find_repo_root() -> Path:
"""Locate project root by walking up until pyproject.toml is found."""
current = Path(__file__).resolve()
for parent in current.parents:
if (parent / "pyproject.toml").exists():
return parent
# Fallback: assume repo is two levels up (../..)
return current.parents[2]
REPO_ROOT = _find_repo_root()
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from apps.chunking import create_text_chunks # noqa: E402
DEFAULT_INITIAL_FILES = [
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
REPO_ROOT / "data" / "huawei_pangu.md",
]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
from llama_index.core import SimpleDirectoryReader
documents = []
for path in paths:
p = path.expanduser().resolve()
if not p.exists():
raise FileNotFoundError(f"Input path not found: {p}")
if p.is_dir():
reader = SimpleDirectoryReader(str(p), recursive=False)
documents.extend(reader.load_data(show_progress=True))
else:
reader = SimpleDirectoryReader(input_files=[str(p)])
documents.extend(reader.load_data(show_progress=True))
if not documents:
return []
chunks = create_text_chunks(
documents,
chunk_size=512,
chunk_overlap=128,
use_ast_chunking=False,
)
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
if limit is not None:
cleaned = cleaned[:limit]
return cleaned
def ensure_index_dir(index_path: Path) -> None:
index_path.parent.mkdir(parents=True, exist_ok=True)
def cleanup_index_files(index_path: Path) -> None:
parent = index_path.parent
if not parent.exists():
return
stem = index_path.stem
for file in parent.glob(f"{stem}*"):
if file.is_file():
file.unlink()
def build_initial_index(
index_path: Path,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
distance_metric: str,
ef_construction: int,
) -> None:
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=True,
distance_metric=distance_metric,
backend_kwargs={
"distance_metric": distance_metric,
"is_compact": False,
"is_recompute": True,
"efConstruction": ef_construction,
},
)
for idx, passage in enumerate(paragraphs):
builder.add_text(passage, metadata={"id": str(idx)})
builder.build_index(str(index_path))
def _maybe_norm_cosine(vecs: np.ndarray, metric: str) -> np.ndarray:
if metric == "cosine":
vecs = np.ascontiguousarray(vecs, dtype=np.float32)
norms = np.linalg.norm(vecs, axis=1, keepdims=True)
norms[norms == 0] = 1
vecs = vecs / norms
return vecs
def _read_index_for_search(index_path: Path) -> Any:
index_file = index_path.parent / f"{index_path.stem}.index"
# Force-disable experimental disk cache when loading the index so that
# incremental benchmarks don't pick up stale top-degree bitmaps.
cfg = faiss.HNSWIndexConfig()
cfg.is_recompute = True
if hasattr(cfg, "disk_cache_ratio"):
cfg.disk_cache_ratio = 0.0
if hasattr(cfg, "external_storage_path"):
cfg.external_storage_path = None
io_flags = getattr(faiss, "IO_FLAG_MMAP", 0)
index = faiss.read_index(str(index_file), io_flags, cfg)
# ensure recompute mode persists after reload
try:
index.is_recompute = True
except AttributeError:
pass
try:
actual_ntotal = index.hnsw.levels.size()
except AttributeError:
actual_ntotal = index.ntotal
if actual_ntotal != index.ntotal:
print(
f"[bench_update_vs_offline_search] Correcting ntotal from {index.ntotal} to {actual_ntotal}",
flush=True,
)
index.ntotal = actual_ntotal
if getattr(index, "storage", None) is None:
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
storage_index = faiss.IndexFlatIP(index.d)
else:
storage_index = faiss.IndexFlatL2(index.d)
index.storage = storage_index
index.own_fields = True
return index
def _append_passages_for_updates(
meta_path: Path,
start_id: int,
texts: list[str],
) -> list[str]:
"""Append update passages so the embedding server can serve recompute fetches."""
if not texts:
return []
index_dir = meta_path.parent
meta_name = meta_path.name
if not meta_name.endswith(".meta.json"):
raise ValueError(f"Unexpected meta filename: {meta_path}")
index_base = meta_name[: -len(".meta.json")]
passages_file = index_dir / f"{index_base}.passages.jsonl"
offsets_file = index_dir / f"{index_base}.passages.idx"
if not passages_file.exists() or not offsets_file.exists():
raise FileNotFoundError(
"Passage store missing; cannot register update passages for recompute mode."
)
with open(offsets_file, "rb") as f:
offset_map: dict[str, int] = pickle.load(f)
assigned_ids: list[str] = []
with open(passages_file, "a", encoding="utf-8") as f:
for i, text in enumerate(texts):
passage_id = str(start_id + i)
offset = f.tell()
json.dump({"id": passage_id, "text": text, "metadata": {}}, f, ensure_ascii=False)
f.write("\n")
offset_map[passage_id] = offset
assigned_ids.append(passage_id)
with open(offsets_file, "wb") as f:
pickle.dump(offset_map, f)
try:
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
except json.JSONDecodeError:
meta = {}
meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
return assigned_ids
def _search(index: Any, q: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
q = np.ascontiguousarray(q, dtype=np.float32)
distances = np.zeros((1, k), dtype=np.float32)
indices = np.zeros((1, k), dtype=np.int64)
index.search(
1,
faiss.swig_ptr(q),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(indices),
)
return distances[0], indices[0]
def _score_for_metric(dist: float, metric: str) -> float:
# Convert FAISS distance to a "higher is better" score
if metric in ("mips", "cosine"):
return float(dist)
# l2 distance (smaller better) -> negative distance as score
return -float(dist)
def _merge_results(
index_results: tuple[np.ndarray, np.ndarray],
offline_scores: list[tuple[int, float]],
k: int,
metric: str,
) -> list[tuple[str, float]]:
distances, indices = index_results
merged: list[tuple[str, float]] = []
for distance, idx in zip(distances.tolist(), indices.tolist()):
merged.append((f"idx:{idx}", _score_for_metric(distance, metric)))
for j, s in offline_scores:
merged.append((f"offline:{j}", s))
merged.sort(key=lambda x: x[1], reverse=True)
return merged[:k]
@dataclass
class ScenarioResult:
name: str
update_total_s: float
search_s: float
overall_s: float
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--index-path",
type=Path,
default=Path(".leann/bench/offline-vs-update.leann"),
)
parser.add_argument(
"--initial-files",
nargs="*",
type=Path,
default=DEFAULT_INITIAL_FILES,
)
parser.add_argument(
"--update-files",
nargs="*",
type=Path,
default=DEFAULT_UPDATE_FILES,
)
parser.add_argument("--max-initial", type=int, default=300)
parser.add_argument("--num-updates", type=int, default=5)
parser.add_argument("--k", type=int, default=10, help="Top-k for search/merge")
parser.add_argument(
"--query",
type=str,
default="neural network",
help="Query text used for the search benchmark.",
)
parser.add_argument("--server-port", type=int, default=5557)
parser.add_argument("--add-timeout", type=int, default=600)
parser.add_argument("--model-name", default="sentence-transformers/all-MiniLM-L6-v2")
parser.add_argument("--embedding-mode", default="sentence-transformers")
parser.add_argument(
"--distance-metric",
default="mips",
choices=["mips", "l2", "cosine"],
)
parser.add_argument("--ef-construction", type=int, default=200)
parser.add_argument(
"--only",
choices=["A", "B", "both"],
default="both",
help="Run only Scenario A, Scenario B, or both",
)
parser.add_argument(
"--csv-path",
type=Path,
default=Path("benchmarks/update/offline_vs_update.csv"),
help="Where to append results (CSV).",
)
args = parser.parse_args()
register_project_directory(REPO_ROOT)
# Load data
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
update_paragraphs = load_chunks_from_files(args.update_files, None)
if not update_paragraphs:
raise ValueError("No update passages loaded from --update-files")
update_paragraphs = update_paragraphs[: args.num_updates]
if len(update_paragraphs) < args.num_updates:
raise ValueError(
f"Not enough update passages ({len(update_paragraphs)}) for --num-updates={args.num_updates}"
)
ensure_index_dir(args.index_path)
cleanup_index_files(args.index_path)
# Build initial index
build_initial_index(
args.index_path,
initial_paragraphs,
args.model_name,
args.embedding_mode,
args.distance_metric,
args.ef_construction,
)
# Prepare index object and meta
meta_path = args.index_path.parent / f"{args.index_path.name}.meta.json"
index = _read_index_for_search(args.index_path)
# CSV setup
run_id = time.strftime("%Y%m%d-%H%M%S")
if args.csv_path:
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
csv_fields = [
"run_id",
"scenario",
"max_initial",
"num_updates",
"k",
"total_time_s",
"add_total_s",
"search_time_s",
"emb_time_s",
"makespan_s",
"model_name",
"embedding_mode",
"distance_metric",
]
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writeheader()
# Debug: list existing HNSW server PIDs before starting
try:
existing = [
p
for p in psutil.process_iter(attrs=["pid", "cmdline"])
if any(
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
for arg in (p.info.get("cmdline") or [])
)
]
if existing:
print("[debug] Found existing hnsw_embedding_server processes before run:")
for p in existing:
print(f"[debug] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}")
except Exception as _e:
pass
add_total = 0.0
search_after_add = 0.0
total_seq = 0.0
port_a = None
if args.only in ("A", "both"):
# Scenario A: sequential update then search
start_id = index.ntotal
assigned_ids = _append_passages_for_updates(meta_path, start_id, update_paragraphs)
if assigned_ids:
logger.debug(
"Registered %d update passages starting at id %s",
len(assigned_ids),
assigned_ids[0],
)
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
ok, port = server_manager.start_server(
port=args.server_port,
model_name=args.model_name,
embedding_mode=args.embedding_mode,
passages_file=str(meta_path),
distance_metric=args.distance_metric,
)
if not ok:
raise RuntimeError("Failed to start embedding server")
try:
# Set ZMQ port for recompute mode
if hasattr(index.hnsw, "set_zmq_port"):
index.hnsw.set_zmq_port(port)
elif hasattr(index, "set_zmq_port"):
index.set_zmq_port(port)
# Start A overall timer BEFORE computing update embeddings
t0 = time.time()
# Compute embeddings for updates (counted into A's overall)
t_emb0 = time.time()
upd_embs = compute_embeddings(
update_paragraphs,
args.model_name,
mode=args.embedding_mode,
is_build=False,
batch_size=16,
)
emb_time_updates = time.time() - t_emb0
upd_embs = np.asarray(upd_embs, dtype=np.float32)
upd_embs = _maybe_norm_cosine(upd_embs, args.distance_metric)
# Perform sequential adds
for i in range(upd_embs.shape[0]):
t_add0 = time.time()
index.add(1, faiss.swig_ptr(upd_embs[i : i + 1]))
add_total += time.time() - t_add0
# Don't persist index after adds to avoid contaminating Scenario B
# index_file = args.index_path.parent / f"{args.index_path.stem}.index"
# faiss.write_index(index, str(index_file))
# Search after updates
q_emb = compute_embeddings(
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
)
q_emb = np.asarray(q_emb, dtype=np.float32)
q_emb = _maybe_norm_cosine(q_emb, args.distance_metric)
# Warm up search with a dummy query first
print("[DEBUG] Warming up search...")
_ = _search(index, q_emb, 1)
t_s0 = time.time()
D_upd, I_upd = _search(index, q_emb, args.k)
search_after_add = time.time() - t_s0
total_seq = time.time() - t0
finally:
server_manager.stop_server()
port_a = port
print("\n=== Scenario A: update->search (sequential) ===")
# emb_time_updates is defined only when A runs
try:
_emb_a = emb_time_updates
except NameError:
_emb_a = 0.0
print(
f"Adds: {args.num_updates} passages; embeds={_emb_a:.3f}s; add_total={add_total:.3f}s; "
f"search={search_after_add:.3f}s; overall={total_seq:.3f}s"
)
# CSV row for A
if args.csv_path:
row_a = {
"run_id": run_id,
"scenario": "A",
"max_initial": args.max_initial,
"num_updates": args.num_updates,
"k": args.k,
"total_time_s": round(total_seq, 6),
"add_total_s": round(add_total, 6),
"search_time_s": round(search_after_add, 6),
"emb_time_s": round(_emb_a, 6),
"makespan_s": 0.0,
"model_name": args.model_name,
"embedding_mode": args.embedding_mode,
"distance_metric": args.distance_metric,
}
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writerow(row_a)
# Verify server cleanup
try:
# short sleep to allow signal handling to finish
time.sleep(0.5)
leftovers = [
p
for p in psutil.process_iter(attrs=["pid", "cmdline"])
if any(
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
for arg in (p.info.get("cmdline") or [])
)
]
if leftovers:
print("[warn] hnsw_embedding_server process(es) still alive after A-stop:")
for p in leftovers:
print(
f"[warn] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}"
)
else:
print("[debug] server cleanup confirmed: no hnsw_embedding_server found")
except Exception:
pass
# Scenario B: offline embeds + concurrent search (no graph updates)
if args.only in ("B", "both"):
# ensure a server is available for recompute search
server_manager_b = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
requested_port = args.server_port if port_a is None else port_a
ok_b, port_b = server_manager_b.start_server(
port=requested_port,
model_name=args.model_name,
embedding_mode=args.embedding_mode,
passages_file=str(meta_path),
distance_metric=args.distance_metric,
)
if not ok_b:
raise RuntimeError("Failed to start embedding server for Scenario B")
# Wait for server to fully initialize
print("[DEBUG] Waiting 2s for embedding server to fully initialize...")
time.sleep(2)
try:
# Read the index first
index_no_update = _read_index_for_search(args.index_path) # unchanged index
# Then configure ZMQ port on the correct index object
if hasattr(index_no_update.hnsw, "set_zmq_port"):
index_no_update.hnsw.set_zmq_port(port_b)
elif hasattr(index_no_update, "set_zmq_port"):
index_no_update.set_zmq_port(port_b)
# Warmup the embedding model before benchmarking (do this for both --only B and --only both)
# This ensures fair comparison as Scenario A has warmed up the model during update embeddings
logger.info("Warming up embedding model for Scenario B...")
_ = compute_embeddings(
["warmup text"], args.model_name, mode=args.embedding_mode, is_build=False
)
# Prepare worker A: compute embeddings for the same N passages
emb_time = 0.0
updates_embs_offline: np.ndarray | None = None
def _worker_emb():
nonlocal emb_time, updates_embs_offline
t = time.time()
updates_embs_offline = compute_embeddings(
update_paragraphs,
args.model_name,
mode=args.embedding_mode,
is_build=False,
batch_size=16,
)
emb_time = time.time() - t
# Pre-compute query embedding and warm up search outside of timed section.
q_vec = compute_embeddings(
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
)
q_vec = np.asarray(q_vec, dtype=np.float32)
q_vec = _maybe_norm_cosine(q_vec, args.distance_metric)
print("[DEBUG B] Warming up search...")
_ = _search(index_no_update, q_vec, 1)
# Worker B: timed search on the warmed index
search_time = 0.0
offline_elapsed = 0.0
index_results: tuple[np.ndarray, np.ndarray] | None = None
def _worker_search():
nonlocal search_time, index_results
t = time.time()
distances, indices = _search(index_no_update, q_vec, args.k)
search_time = time.time() - t
index_results = (distances, indices)
# Run two workers concurrently
t0 = time.time()
th1 = threading.Thread(target=_worker_emb)
th2 = threading.Thread(target=_worker_search)
th1.start()
th2.start()
th1.join()
th2.join()
offline_elapsed = time.time() - t0
# For mixing: compute query vs. offline update similarities (pure client-side)
offline_scores: list[tuple[int, float]] = []
if updates_embs_offline is not None:
upd2 = np.asarray(updates_embs_offline, dtype=np.float32)
upd2 = _maybe_norm_cosine(upd2, args.distance_metric)
# For mips/cosine, score = dot; for l2, score = -||x-y||^2
for j in range(upd2.shape[0]):
if args.distance_metric in ("mips", "cosine"):
s = float(np.dot(q_vec[0], upd2[j]))
else:
diff = q_vec[0] - upd2[j]
s = -float(np.dot(diff, diff))
offline_scores.append((j, s))
merged_topk = (
_merge_results(index_results, offline_scores, args.k, args.distance_metric)
if index_results
else []
)
print("\n=== Scenario B: offline embeds + concurrent search (no add) ===")
print(
f"embeddings({args.num_updates})={emb_time:.3f}s; search={search_time:.3f}s; makespan≈{offline_elapsed:.3f}s (≈max)"
)
if merged_topk:
preview = ", ".join([f"{lab}:{score:.3f}" for lab, score in merged_topk[:5]])
print(f"Merged top-5 preview: {preview}")
# CSV row for B
if args.csv_path:
row_b = {
"run_id": run_id,
"scenario": "B",
"max_initial": args.max_initial,
"num_updates": args.num_updates,
"k": args.k,
"total_time_s": 0.0,
"add_total_s": 0.0,
"search_time_s": round(search_time, 6),
"emb_time_s": round(emb_time, 6),
"makespan_s": round(offline_elapsed, 6),
"model_name": args.model_name,
"embedding_mode": args.embedding_mode,
"distance_metric": args.distance_metric,
}
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writerow(row_b)
finally:
server_manager_b.stop_server()
# Summary
print("\n=== Summary ===")
msg_a = (
f"A: seq-add+search overall={total_seq:.3f}s (adds={add_total:.3f}s, search={search_after_add:.3f}s)"
if args.only in ("A", "both")
else "A: skipped"
)
msg_b = (
f"B: offline+concurrent overall≈{offline_elapsed:.3f}s (emb={emb_time:.3f}s, search={search_time:.3f}s)"
if args.only in ("B", "both")
else "B: skipped"
)
print(msg_a + "\n" + msg_b)
if __name__ == "__main__":
main()

View File

@@ -1,5 +0,0 @@
run_id,scenario,max_initial,num_updates,k,total_time_s,add_total_s,search_time_s,emb_time_s,makespan_s,model_name,embedding_mode,distance_metric
20251024-141607,A,300,1,10,3.273957,3.050168,0.097825,0.017339,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-141607,B,300,1,10,0.0,0.0,0.111892,0.007869,0.112635,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251025-160652,A,300,5,10,5.061945,4.805962,0.123271,0.015008,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251025-160652,B,300,5,10,0.0,0.0,0.101809,0.008817,0.102447,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
1 run_id scenario max_initial num_updates k total_time_s add_total_s search_time_s emb_time_s makespan_s model_name embedding_mode distance_metric
2 20251024-141607 A 300 1 10 3.273957 3.050168 0.097825 0.017339 0.0 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
3 20251024-141607 B 300 1 10 0.0 0.0 0.111892 0.007869 0.112635 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
4 20251025-160652 A 300 5 10 5.061945 4.805962 0.123271 0.015008 0.0 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
5 20251025-160652 B 300 5 10 0.0 0.0 0.101809 0.008817 0.102447 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips

View File

@@ -1,645 +0,0 @@
#!/usr/bin/env python3
"""
Plot latency bars from the benchmark CSV produced by
benchmarks/update/bench_hnsw_rng_recompute.py.
If you also provide an offline_vs_update.csv via --csv-right
(from benchmarks/update/bench_update_vs_offline_search.py), this script will
output a side-by-side figure:
- Left: ms/passage bars (four RNG scenarios).
- Right: seconds bars (Scenario A seq add+search vs Scenario B offline+search).
Usage:
uv run python benchmarks/update/plot_bench_results.py \
--csv benchmarks/update/bench_results.csv \
--out benchmarks/update/bench_latency_from_csv.png
The script selects the latest run_id in the CSV and plots four bars for
the default scenarios:
- baseline
- no_cache_baseline
- disable_forward_rng
- disable_forward_and_reverse_rng
If multiple rows exist per scenario for that run_id, the script averages
their latency_ms_per_passage values.
"""
import argparse
import csv
from collections import defaultdict
from pathlib import Path
DEFAULT_SCENARIOS = [
"no_cache_baseline",
"baseline",
"disable_forward_rng",
"disable_forward_and_reverse_rng",
]
SCENARIO_LABELS = {
"baseline": "+ Cache",
"no_cache_baseline": "Naive \n Recompute",
"disable_forward_rng": "+ w/o \n Fwd RNG",
"disable_forward_and_reverse_rng": "+ w/o \n Bwd RNG",
}
# Paper-style colors and hatches for scenarios
SCENARIO_STYLES = {
"no_cache_baseline": {"edgecolor": "dimgrey", "hatch": "/////"},
"baseline": {"edgecolor": "#63B8B6", "hatch": "xxxxx"},
"disable_forward_rng": {"edgecolor": "green", "hatch": "....."},
"disable_forward_and_reverse_rng": {"edgecolor": "tomato", "hatch": "\\\\\\\\\\"},
}
def load_latest_run(csv_path: Path):
rows = []
with csv_path.open("r", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
rows.append(row)
if not rows:
raise SystemExit("CSV is empty: no rows to plot")
# Choose latest run_id lexicographically (YYYYMMDD-HHMMSS)
run_ids = [r.get("run_id", "") for r in rows]
latest = max(run_ids)
latest_rows = [r for r in rows if r.get("run_id", "") == latest]
if not latest_rows:
# Fallback: take last 4 rows
latest_rows = rows[-4:]
latest = latest_rows[-1].get("run_id", "unknown")
return latest, latest_rows
def aggregate_latency(rows):
acc = defaultdict(list)
for r in rows:
sc = r.get("scenario", "")
try:
val = float(r.get("latency_ms_per_passage", "nan"))
except ValueError:
continue
acc[sc].append(val)
avg = {k: (sum(v) / len(v) if v else 0.0) for k, v in acc.items()}
return avg
def _auto_cap(values: list[float]) -> float | None:
if not values:
return None
sorted_vals = sorted(values, reverse=True)
if len(sorted_vals) < 2:
return None
max_v, second = sorted_vals[0], sorted_vals[1]
if second <= 0:
return None
# If the tallest bar dwarfs the second by 2.5x+, cap near the second
if max_v >= 2.5 * second:
return second * 1.1
return None
def _add_break_marker(ax, y, rel_x0=0.02, rel_x1=0.98, size=0.02):
# Draw small diagonal ticks near left/right to signal cap
x0, x1 = rel_x0, rel_x1
ax.plot([x0 - size, x0 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
ax.plot([x1 - size, x1 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
def _fmt_ms(v: float) -> str:
if v >= 1000:
return f"{v / 1000:.1f}k"
return f"{v:.1f}"
def main():
# Set LaTeX style for paper figures (matching paper_fig.py)
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.5
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
ap = argparse.ArgumentParser(description=__doc__)
ap.add_argument(
"--csv",
type=Path,
default=Path("benchmarks/update/bench_results.csv"),
help="Path to results CSV (defaults to bench_results.csv)",
)
ap.add_argument(
"--out",
type=Path,
default=Path("add_ablation.pdf"),
help="Output image path",
)
ap.add_argument(
"--csv-right",
type=Path,
default=Path("benchmarks/update/offline_vs_update.csv"),
help="Optional: offline_vs_update.csv to render right subplot (A vs B)",
)
ap.add_argument(
"--cap-y",
type=float,
default=None,
help="Cap Y-axis at this ms value; bars above are hatched and annotated.",
)
ap.add_argument(
"--no-auto-cap",
action="store_true",
help="Disable auto-cap heuristic when --cap-y is not provided.",
)
ap.add_argument(
"--broken-y",
action="store_true",
default=True,
help="Use a broken Y-axis (two stacked axes with a gap). Overrides --cap-y unless both provided.",
)
ap.add_argument(
"--lower-cap-y",
type=float,
default=None,
help="Lower axes upper bound for broken Y (ms). Default = 1.1x second-highest.",
)
ap.add_argument(
"--upper-start-y",
type=float,
default=None,
help="Upper axes lower bound for broken Y (ms). Default = 1.2x second-highest.",
)
args = ap.parse_args()
latest_run, latest_rows = load_latest_run(args.csv)
avg = aggregate_latency(latest_rows)
try:
import matplotlib.pyplot as plt
except Exception as e:
raise SystemExit(f"matplotlib not available: {e}")
scenarios = DEFAULT_SCENARIOS
values = [avg.get(name, 0.0) for name in scenarios]
labels = [SCENARIO_LABELS.get(name, name) for name in scenarios]
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
# If right CSV is provided, build side-by-side figure
if args.csv_right is not None:
try:
right_rows_all = []
with args.csv_right.open("r", encoding="utf-8") as f:
rreader = csv.DictReader(f)
right_rows_all = list(rreader)
if right_rows_all:
r_latest = max(r.get("run_id", "") for r in right_rows_all)
right_rows = [r for r in right_rows_all if r.get("run_id", "") == r_latest]
else:
r_latest = None
right_rows = []
except Exception:
r_latest = None
right_rows = []
a_total = 0.0
b_makespan = 0.0
for r in right_rows:
sc = (r.get("scenario", "") or "").strip().upper()
if sc == "A":
try:
a_total = float(r.get("total_time_s", 0.0))
except Exception:
pass
elif sc == "B":
try:
b_makespan = float(r.get("makespan_s", 0.0))
except Exception:
pass
import matplotlib.pyplot as plt
from matplotlib import gridspec
# Left subplot (reuse current style, with optional cap)
cap = args.cap_y
if cap is None and not args.no_auto_cap:
cap = _auto_cap(values)
x = list(range(len(labels)))
if args.broken_y:
# Use broken axis for left subplot
# Auto-adjust width ratios: left has 4 bars, right has 2 bars
fig = plt.figure(figsize=(4.8, 1.8)) # Scaled down to 80%
gs = gridspec.GridSpec(
2, 2, height_ratios=[1, 3], width_ratios=[1.5, 1], hspace=0.08, wspace=0.35
)
ax_left_top = fig.add_subplot(gs[0, 0])
ax_left_bottom = fig.add_subplot(gs[1, 0], sharex=ax_left_top)
ax_right = fig.add_subplot(gs[:, 1])
# Determine break points
s = sorted(values, reverse=True)
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
lower_cap = (
args.lower_cap_y if args.lower_cap_y is not None else second * 1.4
) # Increased to show more range
upper_start = (
args.upper_start_y
if args.upper_start_y is not None
else max(second * 1.5, lower_cap * 1.02)
)
ymax = (
max(values) * 1.90 if values else 1.0
) # Increase headroom to 1.90 for text label and tick range
# Draw bars on both axes
ax_left_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_left_top.bar(x, values, color=colors[: len(labels)], width=0.8)
# Set limits
ax_left_bottom.set_ylim(0, lower_cap)
ax_left_top.set_ylim(upper_start, ymax)
# Annotate values (convert ms to s)
values_s = [v / 1000.0 for v in values]
lower_cap_s = lower_cap / 1000.0
upper_start_s = upper_start / 1000.0
ymax_s = ymax / 1000.0
ax_left_bottom.set_ylim(0, lower_cap_s)
ax_left_top.set_ylim(upper_start_s, ymax_s)
# Redraw bars with s values (paper style: white fill + colored edge + hatch)
ax_left_bottom.clear()
ax_left_top.clear()
bar_width = 0.50 # Reduced for wider spacing between bars
for i, (scenario_name, v) in enumerate(zip(scenarios, values_s)):
style = SCENARIO_STYLES.get(scenario_name, {"edgecolor": "black", "hatch": ""})
# Draw in bottom axis for all bars
ax_left_bottom.bar(
i,
v,
width=bar_width,
color="white",
edgecolor=style["edgecolor"],
hatch=style["hatch"],
linewidth=1.2,
)
# Only draw in top axis if the bar is tall enough to reach the upper range
if v > upper_start_s:
ax_left_top.bar(
i,
v,
width=bar_width,
color="white",
edgecolor=style["edgecolor"],
hatch=style["hatch"],
linewidth=1.2,
)
ax_left_bottom.set_ylim(0, lower_cap_s)
ax_left_top.set_ylim(upper_start_s, ymax_s)
for i, v in enumerate(values_s):
if v <= lower_cap_s:
ax_left_bottom.text(
i,
v + lower_cap_s * 0.02,
f"{v:.2f}",
ha="center",
va="bottom",
fontsize=8,
fontweight="bold",
)
else:
ax_left_top.text(
i,
v + (ymax_s - upper_start_s) * 0.02,
f"{v:.2f}",
ha="center",
va="bottom",
fontsize=8,
fontweight="bold",
)
# Hide spines between axes
ax_left_top.spines["bottom"].set_visible(False)
ax_left_bottom.spines["top"].set_visible(False)
ax_left_top.tick_params(
labeltop=False, labelbottom=False, bottom=False
) # Hide tick marks
ax_left_bottom.xaxis.tick_bottom()
ax_left_bottom.tick_params(top=False) # Hide top tick marks
# Draw break marks (matching paper_fig.py style)
d = 0.015
kwargs = {
"transform": ax_left_top.transAxes,
"color": "k",
"clip_on": False,
"linewidth": 0.8,
"zorder": 10,
}
ax_left_top.plot((-d, +d), (-d, +d), **kwargs)
ax_left_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
kwargs.update({"transform": ax_left_bottom.transAxes})
ax_left_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
ax_left_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
ax_left_bottom.set_xticks(x)
ax_left_bottom.set_xticklabels(labels, rotation=0, fontsize=7)
# Don't set ylabel here - will use fig.text for alignment
ax_left_bottom.tick_params(axis="y", labelsize=10)
ax_left_top.tick_params(axis="y", labelsize=10)
# Add subtle grid for better readability
ax_left_bottom.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
ax_left_top.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
ax_left_top.set_title("Single Add Operation", fontsize=11, pad=10, fontweight="bold")
# Set x-axis limits to match bar width with right subplot
ax_left_bottom.set_xlim(-0.6, 3.6)
ax_left_top.set_xlim(-0.6, 3.6)
ax_left = ax_left_bottom # for compatibility
else:
# Regular side-by-side layout
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8.4, 3.15))
if cap is not None:
show_vals = [min(v, cap) for v in values]
bars = ax_left.bar(x, show_vals, color=colors[: len(labels)], width=0.8)
for i, (val, show) in enumerate(zip(values, show_vals)):
if val > cap:
bars[i].set_hatch("//")
ax_left.text(
i, cap * 1.02, _fmt_ms(val), ha="center", va="bottom", fontsize=9
)
else:
ax_left.text(
i,
show + max(1.0, 0.01 * (cap or show)),
_fmt_ms(val),
ha="center",
va="bottom",
fontsize=9,
)
ax_left.set_ylim(0, cap * 1.10)
_add_break_marker(ax_left, y=0.98)
ax_left.set_xticks(x)
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
else:
ax_left.bar(x, values, color=colors[: len(labels)], width=0.8)
for i, v in enumerate(values):
ax_left.text(i, v + 1.0, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
ax_left.set_xticks(x)
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
ax_left.set_ylabel("Latency (ms per passage)")
max_initial = latest_rows[0].get("max_initial", "?")
max_updates = latest_rows[0].get("max_updates", "?")
ax_left.set_title(
f"HNSW RNG (run {latest_run}) | init={max_initial}, upd={max_updates}"
)
# Right subplot (A vs B, seconds) - paper style
r_labels = ["Sequential", "Delayed \n Add+Search"]
r_values = [a_total or 0.0, b_makespan or 0.0]
r_styles = [
{"edgecolor": "#59a14f", "hatch": "xxxxx"},
{"edgecolor": "#edc948", "hatch": "/////"},
]
# 2 bars, centered with proper spacing
xr = [0, 1]
bar_width = 0.50 # Reduced for wider spacing between bars
for i, (v, style) in enumerate(zip(r_values, r_styles)):
ax_right.bar(
xr[i],
v,
width=bar_width,
color="white",
edgecolor=style["edgecolor"],
hatch=style["hatch"],
linewidth=1.2,
)
for i, v in enumerate(r_values):
max_v = max(r_values) if r_values else 1.0
offset = max(0.0002, 0.02 * max_v)
ax_right.text(
xr[i],
v + offset,
f"{v:.2f}",
ha="center",
va="bottom",
fontsize=8,
fontweight="bold",
)
ax_right.set_xticks(xr)
ax_right.set_xticklabels(r_labels, rotation=0, fontsize=7)
# Don't set ylabel here - will use fig.text for alignment
ax_right.tick_params(axis="y", labelsize=10)
# Add subtle grid for better readability
ax_right.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
ax_right.set_title("Batched Add Operation", fontsize=11, pad=10, fontweight="bold")
# Set x-axis limits to match left subplot's bar width visually
# Accounting for width_ratios=[1.5, 1]:
# Left: 4 bars, xlim(-0.6, 3.6), range=4.2, physical_width=1.5*unit
# bar_width_visual = 0.72 * (1.5*unit / 4.2)
# Right: 2 bars, need same visual width
# 0.72 * (1.0*unit / range_right) = 0.72 * (1.5*unit / 4.2)
# range_right = 4.2 / 1.5 = 2.8
# For bars at 0, 1: padding = (2.8 - 1) / 2 = 0.9
ax_right.set_xlim(-0.9, 1.9)
# Set y-axis limit with headroom for text labels
if r_values:
max_v = max(r_values)
ax_right.set_ylim(0, max_v * 1.15)
# Format y-axis to avoid scientific notation
ax_right.ticklabel_format(style="plain", axis="y")
plt.tight_layout()
# Add aligned ylabels using fig.text (after tight_layout)
# Get the vertical center of the entire figure
fig_center_y = 0.5
# Left ylabel - closer to left plot
left_x = 0.05
fig.text(
left_x,
fig_center_y,
"Latency (s)",
va="center",
rotation="vertical",
fontsize=11,
fontweight="bold",
)
# Right ylabel - closer to right plot
right_bbox = ax_right.get_position()
right_x = right_bbox.x0 - 0.07
fig.text(
right_x,
fig_center_y,
"Latency (s)",
va="center",
rotation="vertical",
fontsize=11,
fontweight="bold",
)
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
# Also save PDF for paper
pdf_out = args.out.with_suffix(".pdf")
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
print(f"Saved: {args.out}")
print(f"Saved: {pdf_out}")
return
# Broken-Y mode
if args.broken_y:
import matplotlib.pyplot as plt
fig, (ax_top, ax_bottom) = plt.subplots(
2,
1,
sharex=True,
figsize=(7.5, 6.75),
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.08},
)
# Determine default breaks from second-highest
s = sorted(values, reverse=True)
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
upper_start = (
args.upper_start_y
if args.upper_start_y is not None
else max(second * 1.2, lower_cap * 1.02)
)
ymax = max(values) * 1.10 if values else 1.0
x = list(range(len(labels)))
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
# Limits
ax_bottom.set_ylim(0, lower_cap)
ax_top.set_ylim(upper_start, ymax)
# Annotate values
for i, v in enumerate(values):
if v <= lower_cap:
ax_bottom.text(
i, v + lower_cap * 0.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9
)
else:
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
# Hide spines between axes and draw diagonal break marks
ax_top.spines["bottom"].set_visible(False)
ax_bottom.spines["top"].set_visible(False)
ax_top.tick_params(labeltop=False) # don't put tick labels at the top
ax_bottom.xaxis.tick_bottom()
# Diagonal lines at the break (matching paper_fig.py style)
d = 0.015
kwargs = {
"transform": ax_top.transAxes,
"color": "k",
"clip_on": False,
"linewidth": 0.8,
"zorder": 10,
}
ax_top.plot((-d, +d), (-d, +d), **kwargs) # top-left diagonal
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs) # top-right diagonal
kwargs.update({"transform": ax_bottom.transAxes})
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs) # bottom-left diagonal
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs) # bottom-right diagonal
ax_bottom.set_xticks(x)
ax_bottom.set_xticklabels(labels, rotation=0, fontsize=10)
ax = ax_bottom # for labeling below
else:
cap = args.cap_y
if cap is None and not args.no_auto_cap:
cap = _auto_cap(values)
plt.figure(figsize=(5.4, 3.15))
ax = plt.gca()
if cap is not None:
show_vals = [min(v, cap) for v in values]
bars = []
for i, (_label, val, show) in enumerate(zip(labels, values, show_vals)):
bar = ax.bar(i, show, color=colors[i], width=0.8)
bars.append(bar[0])
# Hatch and annotate when capped
if val > cap:
bars[-1].set_hatch("//")
ax.text(i, cap * 1.02, f"{_fmt_ms(val)}", ha="center", va="bottom", fontsize=9)
else:
ax.text(
i,
show + max(1.0, 0.01 * (cap or show)),
f"{_fmt_ms(val)}",
ha="center",
va="bottom",
fontsize=9,
)
ax.set_ylim(0, cap * 1.10)
_add_break_marker(ax, y=0.98)
ax.legend([bars[1]], ["capped"], fontsize=8, frameon=False, loc="upper right") if any(
v > cap for v in values
) else None
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
else:
ax.bar(labels, values, color=colors[: len(labels)])
for idx, val in enumerate(values):
ax.text(
idx,
val + 1.0,
f"{_fmt_ms(val)}",
ha="center",
va="bottom",
fontsize=10,
fontweight="bold",
)
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
# Try to extract some context for title
max_initial = latest_rows[0].get("max_initial", "?")
max_updates = latest_rows[0].get("max_updates", "?")
if args.broken_y:
fig.text(
0.02,
0.5,
"Latency (s)",
va="center",
rotation="vertical",
fontsize=11,
fontweight="bold",
)
fig.suptitle(
"Add Operation Latency",
fontsize=11,
y=0.98,
fontweight="bold",
)
plt.tight_layout(rect=(0.03, 0.04, 1, 0.96))
else:
plt.ylabel("Latency (s)", fontsize=11, fontweight="bold")
plt.title("Add Operation Latency", fontsize=11, fontweight="bold")
plt.tight_layout()
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
# Also save PDF for paper
pdf_out = args.out.with_suffix(".pdf")
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
print(f"Saved: {args.out}")
print(f"Saved: {pdf_out}")
if __name__ == "__main__":
main()

View File

@@ -1,395 +0,0 @@
# Slack Integration Setup Guide
This guide provides step-by-step instructions for setting up Slack integration with LEANN.
## Overview
LEANN's Slack integration uses MCP (Model Context Protocol) servers to fetch and index your Slack messages for RAG (Retrieval-Augmented Generation). This allows you to search through your Slack conversations using natural language queries.
## Prerequisites
1. **Slack Workspace Access**: You need admin or owner permissions in your Slack workspace to create apps and configure OAuth tokens.
2. **Slack MCP Server**: Install a Slack MCP server (e.g., `slack-mcp-server` via npm)
3. **LEANN**: Ensure you have LEANN installed and working
## Step 1: Create a Slack App
### 1.1 Go to Slack API Dashboard
1. Visit [https://api.slack.com/apps](https://api.slack.com/apps)
2. Click **"Create New App"**
3. Choose **"From scratch"**
4. Enter your app name (e.g., "LEANN Slack Integration")
5. Select your workspace
6. Click **"Create App"**
### 1.2 Configure App Permissions
#### Token Scopes
1. In your app dashboard, go to **"OAuth & Permissions"** in the left sidebar
2. Scroll down to **"Scopes"** section
3. Under **"Bot Token Scopes & OAuth Scope"**, click **"Add an OAuth Scope"**
4. Add the following scopes:
- `channels:read` - Read public channel information
- `channels:history` - Read messages in public channels
- `groups:read` - Read private channel information
- `groups:history` - Read messages in private channels
- `im:read` - Read direct message information
- `im:history` - Read direct messages
- `mpim:read` - Read group direct message information
- `mpim:history` - Read group direct messages
- `users:read` - Read user information
- `team:read` - Read workspace information
#### App-Level Tokens (Optional)
Some MCP servers may require app-level tokens:
1. Go to **"Basic Information"** in the left sidebar
2. Scroll down to **"App-Level Tokens"**
3. Click **"Generate Token and Scopes"**
4. Enter a name (e.g., "LEANN Integration")
5. Add the `connections:write` scope
6. Click **"Generate"**
7. Copy the token (starts with `xapp-`)
### 1.3 Install App to Workspace
1. Go to **"OAuth & Permissions"** in the left sidebar
2. Click **"Install to Workspace"**
3. Review the permissions and click **"Allow"**
4. Copy the **"Bot User OAuth Token"** (starts with `xoxb-`)
5. Copy the **"User OAuth Token"** (starts with `xoxp-`)
## Step 2: Install Slack MCP Server
### Option A: Using npm (Recommended)
```bash
# Install globally
npm install -g slack-mcp-server
# Or install locally
npm install slack-mcp-server
```
### Option B: Using npx (No installation required)
```bash
# Use directly without installation
npx slack-mcp-server
```
## Step 3: Install and Configure Ollama (for Real LLM Responses)
### 3.1 Install Ollama
```bash
# Install Ollama using Homebrew (macOS)
brew install ollama
# Or download from https://ollama.ai/
```
### 3.2 Start Ollama Service
```bash
# Start Ollama as a service
brew services start ollama
# Or start manually
ollama serve
```
### 3.3 Pull a Model
```bash
# Pull a lightweight model for testing
ollama pull llama3.2:1b
# Verify the model is available
ollama list
```
## Step 4: Configure Environment Variables
Create a `.env` file or set environment variables:
```bash
# Required: User OAuth Token
SLACK_OAUTH_TOKEN=xoxp-your-user-oauth-token-here
# Optional: App-Level Token (if your MCP server requires it)
SLACK_APP_TOKEN=xapp-your-app-token-here
# Optional: Workspace-specific settings
SLACK_WORKSPACE_ID=T1234567890 # Your workspace ID (optional)
```
## Step 5: Test the Setup
### 5.1 Test MCP Server Connection
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--test-connection \
--workspace-name "Your Workspace Name"
```
This will test the connection and list available tools without indexing any data.
### 5.2 Index a Specific Channel
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "Your Workspace Name" \
--channels general \
--query "What did we discuss about the project?"
```
### 5.3 Real RAG Query Examples
This section demonstrates successful Slack RAG integration queries against the Sky Lab Computing workspace's "random" channel. The system successfully retrieves actual conversation messages and performs semantic search with high relevance scores, including finding specific research paper announcements and technical discussions.
### Example 1: Advisor Models Query
**Query:** "train black-box models to adopt to your personal data"
This query demonstrates the system's ability to find specific research announcements about training black-box models for personal data adaptation.
![Advisor Models Query - Command Setup](videos/slack_integration_1.1.png)
![Advisor Models Query - Search Results](videos/slack_integration_1.2.png)
![Advisor Models Query - LLM Response](videos/slack_integration_1.3.png)
### Example 2: Barbarians at the Gate Query
**Query:** "AI-driven research systems ADRS"
This query demonstrates the system's ability to find specific research announcements about AI-driven research systems and algorithm discovery.
![Barbarians Query - Command Setup](videos/slack_integration_2.1.png)
![Barbarians Query - Search Results](videos/slack_integration_2.2.png)
![Barbarians Query - LLM Response](videos/slack_integration_2.3.png)
### Prerequisites
- Bot is installed in the Sky Lab Computing workspace and invited to the target channel (run `/invite @YourBotName` in the channel if needed)
- Bot token available and exported in the same terminal session
### Commands
1) Set the workspace token for this shell
```bash
export SLACK_MCP_XOXP_TOKEN="xoxp-***-redacted-***"
```
2) Run queries against the "random" channel by channel ID (C0GN5BX0F)
**Advisor Models Query:**
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "Sky Lab Computing" \
--channels C0GN5BX0F \
--max-messages-per-channel 100000 \
--query "train black-box models to adopt to your personal data" \
--llm ollama \
--llm-model "llama3.2:1b" \
--llm-host "http://localhost:11434" \
--no-concatenate-conversations
```
**Barbarians at the Gate Query:**
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "Sky Lab Computing" \
--channels C0GN5BX0F \
--max-messages-per-channel 100000 \
--query "AI-driven research systems ADRS" \
--llm ollama \
--llm-model "llama3.2:1b" \
--llm-host "http://localhost:11434" \
--no-concatenate-conversations
```
These examples demonstrate the system's ability to find and retrieve specific research announcements and technical discussions from the conversation history, showcasing the power of semantic search in Slack data.
3) Optional: Ask a broader question
```bash
python test_channel_by_id_or_name.py \
--channel-id C0GN5BX0F \
--workspace-name "Sky Lab Computing" \
--query "What is LEANN about?"
```
Notes:
- If you see `not_in_channel`, invite the bot to the channel and re-run.
- If you see `channel_not_found`, confirm the channel ID and workspace.
- Deep search via server-side “search” tools may require additional Slack scopes; the example above performs client-side filtering over retrieved history.
## Common Issues and Solutions
### Issue 1: "users cache is not ready yet" Error
**Problem**: You see this warning:
```
WARNING - Failed to fetch messages from channel random: Failed to fetch messages: {'code': -32603, 'message': 'users cache is not ready yet, sync process is still running... please wait'}
```
**Solution**: This is a common timing issue. The LEANN integration now includes automatic retry logic:
1. **Wait and Retry**: The system will automatically retry with exponential backoff (2s, 4s, 8s, etc.)
2. **Increase Retry Parameters**: If needed, you can customize retry behavior:
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--max-retries 10 \
--retry-delay 3.0 \
--channels general \
--query "Your query here"
```
3. **Keep MCP Server Running**: Start the MCP server separately and keep it running:
```bash
# Terminal 1: Start MCP server
slack-mcp-server
# Terminal 2: Run LEANN (it will connect to the running server)
python -m apps.slack_rag --mcp-server "slack-mcp-server" --channels general --query "test"
```
### Issue 2: "No message fetching tool found"
**Problem**: The MCP server doesn't have the expected tools.
**Solution**:
1. Check if your MCP server is properly installed and configured
2. Verify your Slack tokens are correct
3. Try a different MCP server implementation
4. Check the MCP server documentation for required configuration
### Issue 3: Permission Denied Errors
**Problem**: You get permission errors when trying to access channels.
**Solutions**:
1. **Check Bot Permissions**: Ensure your bot has been added to the channels you want to access
2. **Verify Token Scopes**: Make sure you have all required scopes configured
3. **Channel Access**: For private channels, the bot needs to be explicitly invited
4. **Workspace Permissions**: Ensure your Slack app has the necessary workspace permissions
### Issue 4: Empty Results
**Problem**: No messages are returned even though the channel has messages.
**Solutions**:
1. **Check Channel Names**: Ensure channel names are correct (without the # symbol)
2. **Verify Bot Access**: Make sure the bot can access the channels
3. **Check Date Ranges**: Some MCP servers have limitations on message history
4. **Increase Message Limits**: Try increasing the message limit:
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--channels general \
--max-messages-per-channel 1000 \
--query "test"
```
## Advanced Configuration
### Custom MCP Server Commands
If you need to pass additional parameters to your MCP server:
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server --token-file /path/to/tokens.json" \
--workspace-name "Your Workspace" \
--channels general \
--query "Your query"
```
### Multiple Workspaces
To work with multiple Slack workspaces, you can:
1. Create separate apps for each workspace
2. Use different environment variables
3. Run separate instances with different configurations
### Performance Optimization
For better performance with large workspaces:
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "Your Workspace" \
--max-messages-per-channel 500 \
--no-concatenate-conversations \
--query "Your query"
```
---
## Troubleshooting Checklist
- [ ] Slack app created with proper permissions
- [ ] Bot token (xoxb-) copied correctly
- [ ] App-level token (xapp-) created if needed
- [ ] MCP server installed and accessible
- [ ] Ollama installed and running (`brew services start ollama`)
- [ ] Ollama model pulled (`ollama pull llama3.2:1b`)
- [ ] Environment variables set correctly
- [ ] Bot invited to relevant channels
- [ ] Channel names specified without # symbol
- [ ] Sufficient retry attempts configured
- [ ] Network connectivity to Slack APIs
## Getting Help
If you continue to have issues:
1. **Check Logs**: Look for detailed error messages in the console output
2. **Test MCP Server**: Use `--test-connection` to verify the MCP server is working
3. **Verify Tokens**: Double-check that your Slack tokens are valid and have the right scopes
4. **Check Ollama**: Ensure Ollama is running (`ollama serve`) and the model is available (`ollama list`)
5. **Community Support**: Reach out to the LEANN community for help
## Example Commands
### Basic Usage
```bash
# Test connection
python -m apps.slack_rag --mcp-server "slack-mcp-server" --test-connection
# Index specific channels
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "My Company" \
--channels general random \
--query "What did we decide about the project timeline?"
```
### Advanced Usage
```bash
# With custom retry settings
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "My Company" \
--channels general \
--max-retries 10 \
--retry-delay 5.0 \
--max-messages-per-channel 2000 \
--query "Show me all decisions made in the last month"
```

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 445 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 508 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 437 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 474 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 501 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 454 KiB

View File

@@ -29,25 +29,12 @@ if(APPLE)
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
endif()
# Find ZMQ using pkg-config with IMPORTED_TARGET for automatic target creation
# Use system ZeroMQ instead of building from source
find_package(PkgConfig REQUIRED)
# On ARM64 macOS, ensure pkg-config finds ARM64 Homebrew packages first
if(APPLE AND CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
set(ENV{PKG_CONFIG_PATH} "/opt/homebrew/lib/pkgconfig:/opt/homebrew/share/pkgconfig:$ENV{PKG_CONFIG_PATH}")
endif()
pkg_check_modules(ZMQ REQUIRED IMPORTED_TARGET libzmq)
# This creates PkgConfig::ZMQ target automatically with correct properties
if(TARGET PkgConfig::ZMQ)
message(STATUS "Found and configured ZMQ target: PkgConfig::ZMQ")
else()
message(FATAL_ERROR "pkg_check_modules did not create IMPORTED target for ZMQ.")
endif()
pkg_check_modules(ZMQ REQUIRED libzmq)
# Add cppzmq headers
include_directories(SYSTEM third_party/cppzmq)
include_directories(third_party/cppzmq)
# Configure msgpack-c - disable boost dependency
set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)

View File

@@ -215,8 +215,6 @@ class HNSWSearcher(BaseSearcher):
if recompute_embeddings:
if zmq_port is None:
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
if hasattr(self._index, "set_zmq_port"):
self._index.set_zmq_port(zmq_port)
if query.dtype != np.float32:
query = query.astype(np.float32)

View File

@@ -820,10 +820,10 @@ class LeannBuilder:
actual_port,
requested_zmq_port,
)
if hasattr(index.hnsw, "set_zmq_port"):
index.hnsw.set_zmq_port(actual_port)
elif hasattr(index, "set_zmq_port"):
index.set_zmq_port(actual_port)
try:
index.hnsw.zmq_port = actual_port
except AttributeError:
pass
if needs_recompute:
for i in range(embeddings.shape[0]):
@@ -1236,17 +1236,6 @@ class LeannChat:
"Please provide the best answer you can based on this context and your knowledge."
)
print("The context provided to the LLM is:")
print(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
print("-" * 150)
for r in results:
chunk_relevance = f"{r.score:.3f}"
chunk_id = r.id
chunk_content = r.text[:60]
chunk_source = r.metadata.get("source", "")[:80]
print(
f"{chunk_relevance:<10} | {chunk_id:<10} | {chunk_content:<60} | {chunk_source:<80}"
)
ask_time = time.time()
ans = self.llm.ask(prompt, **llm_kwargs)
ask_time = time.time() - ask_time

View File

@@ -834,11 +834,6 @@ class OpenAIChat(LLMInterface):
try:
response = self.client.chat.completions.create(**params)
print(
f"Total tokens = {response.usage.total_tokens}, prompt tokens = {response.usage.prompt_tokens}, completion tokens = {response.usage.completion_tokens}"
)
if response.choices[0].finish_reason == "length":
print("The query is exceeding the maximum allowed number of tokens")
return response.choices[0].message.content.strip()
except Exception as e:
logger.error(f"Error communicating with OpenAI: {e}")

View File

@@ -5,128 +5,12 @@ Packaged within leann-core so installed wheels can import it reliably.
import logging
from pathlib import Path
from typing import Any, Optional
from typing import Optional
from llama_index.core.node_parser import SentenceSplitter
logger = logging.getLogger(__name__)
# Flag to ensure AST token warning only shown once per session
_ast_token_warning_shown = False
def estimate_token_count(text: str) -> int:
"""
Estimate token count for a text string.
Uses conservative estimation: ~4 characters per token for natural text,
~1.2 tokens per character for code (worse tokenization).
Args:
text: Input text to estimate tokens for
Returns:
Estimated token count
"""
try:
import tiktoken
encoder = tiktoken.get_encoding("cl100k_base")
return len(encoder.encode(text))
except ImportError:
# Fallback: Conservative character-based estimation
# Assume worst case for code: 1.2 tokens per character
return int(len(text) * 1.2)
def calculate_safe_chunk_size(
model_token_limit: int,
overlap_tokens: int,
chunking_mode: str = "traditional",
safety_factor: float = 0.9,
) -> int:
"""
Calculate safe chunk size accounting for overlap and safety margin.
Args:
model_token_limit: Maximum tokens supported by embedding model
overlap_tokens: Overlap size (tokens for traditional, chars for AST)
chunking_mode: "traditional" (tokens) or "ast" (characters)
safety_factor: Safety margin (0.9 = 10% safety margin)
Returns:
Safe chunk size: tokens for traditional, characters for AST
"""
safe_limit = int(model_token_limit * safety_factor)
if chunking_mode == "traditional":
# Traditional chunking uses tokens
# Max chunk = chunk_size + overlap, so chunk_size = limit - overlap
return max(1, safe_limit - overlap_tokens)
else: # AST chunking
# AST uses characters, need to convert
# Conservative estimate: 1.2 tokens per char for code
overlap_chars = int(overlap_tokens * 3) # ~3 chars per token for code
safe_chars = int(safe_limit / 1.2)
return max(1, safe_chars - overlap_chars)
def validate_chunk_token_limits(chunks: list[str], max_tokens: int = 512) -> tuple[list[str], int]:
"""
Validate that chunks don't exceed token limits and truncate if necessary.
Args:
chunks: List of text chunks to validate
max_tokens: Maximum tokens allowed per chunk
Returns:
Tuple of (validated_chunks, num_truncated)
"""
validated_chunks = []
num_truncated = 0
for i, chunk in enumerate(chunks):
estimated_tokens = estimate_token_count(chunk)
if estimated_tokens > max_tokens:
# Truncate chunk to fit token limit
try:
import tiktoken
encoder = tiktoken.get_encoding("cl100k_base")
tokens = encoder.encode(chunk)
if len(tokens) > max_tokens:
truncated_tokens = tokens[:max_tokens]
truncated_chunk = encoder.decode(truncated_tokens)
validated_chunks.append(truncated_chunk)
num_truncated += 1
logger.warning(
f"Truncated chunk {i} from {len(tokens)} to {max_tokens} tokens "
f"(from {len(chunk)} to {len(truncated_chunk)} characters)"
)
else:
validated_chunks.append(chunk)
except ImportError:
# Fallback: Conservative character truncation
char_limit = int(max_tokens / 1.2) # Conservative for code
if len(chunk) > char_limit:
truncated_chunk = chunk[:char_limit]
validated_chunks.append(truncated_chunk)
num_truncated += 1
logger.warning(
f"Truncated chunk {i} from {len(chunk)} to {char_limit} characters "
f"(conservative estimate for {max_tokens} tokens)"
)
else:
validated_chunks.append(chunk)
else:
validated_chunks.append(chunk)
if num_truncated > 0:
logger.warning(f"Truncated {num_truncated}/{len(chunks)} chunks to fit token limits")
return validated_chunks, num_truncated
# Code file extensions supported by astchunk
CODE_EXTENSIONS = {
".py": "python",
@@ -177,45 +61,27 @@ def create_ast_chunks(
max_chunk_size: int = 512,
chunk_overlap: int = 64,
metadata_template: str = "default",
) -> list[dict[str, Any]]:
) -> list[str]:
"""Create AST-aware chunks from code documents using astchunk.
Falls back to traditional chunking if astchunk is unavailable.
Returns:
List of dicts with {"text": str, "metadata": dict}
"""
try:
from astchunk import ASTChunkBuilder # optional dependency
except ImportError as e:
logger.error(f"astchunk not available: {e}")
logger.info("Falling back to traditional chunking for code files")
return _traditional_chunks_as_dicts(documents, max_chunk_size, chunk_overlap)
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
all_chunks = []
for doc in documents:
language = doc.metadata.get("language")
if not language:
logger.warning("No language detected; falling back to traditional chunking")
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
continue
try:
# Warn once if AST chunk size + overlap might exceed common token limits
# Note: Actual truncation happens at embedding time with dynamic model limits
global _ast_token_warning_shown
estimated_max_tokens = int(
(max_chunk_size + chunk_overlap) * 1.2
) # Conservative estimate
if estimated_max_tokens > 512 and not _ast_token_warning_shown:
logger.warning(
f"AST chunk size ({max_chunk_size}) + overlap ({chunk_overlap}) = {max_chunk_size + chunk_overlap} chars "
f"may exceed 512 token limit (~{estimated_max_tokens} tokens estimated). "
f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}. "
f"Note: Chunks will be auto-truncated at embedding time based on your model's actual token limit."
)
_ast_token_warning_shown = True
configs = {
"max_chunk_size": max_chunk_size,
"language": language,
@@ -239,40 +105,17 @@ def create_ast_chunks(
chunks = chunk_builder.chunkify(code_content)
for chunk in chunks:
chunk_text = None
astchunk_metadata = {}
if hasattr(chunk, "text"):
chunk_text = chunk.text
elif isinstance(chunk, dict) and "text" in chunk:
chunk_text = chunk["text"]
elif isinstance(chunk, str):
chunk_text = chunk
elif isinstance(chunk, dict):
# Handle astchunk format: {"content": "...", "metadata": {...}}
if "content" in chunk:
chunk_text = chunk["content"]
astchunk_metadata = chunk.get("metadata", {})
elif "text" in chunk:
chunk_text = chunk["text"]
else:
chunk_text = str(chunk) # Last resort
else:
chunk_text = str(chunk)
if chunk_text and chunk_text.strip():
# Extract document-level metadata
doc_metadata = {
"file_path": doc.metadata.get("file_path", ""),
"file_name": doc.metadata.get("file_name", ""),
}
if "creation_date" in doc.metadata:
doc_metadata["creation_date"] = doc.metadata["creation_date"]
if "last_modified_date" in doc.metadata:
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
# Merge document metadata + astchunk metadata
combined_metadata = {**doc_metadata, **astchunk_metadata}
all_chunks.append({"text": chunk_text.strip(), "metadata": combined_metadata})
all_chunks.append(chunk_text.strip())
logger.info(
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
@@ -280,19 +123,15 @@ def create_ast_chunks(
except Exception as e:
logger.warning(f"AST chunking failed for {language} file: {e}")
logger.info("Falling back to traditional chunking")
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
return all_chunks
def create_traditional_chunks(
documents, chunk_size: int = 256, chunk_overlap: int = 128
) -> list[dict[str, Any]]:
"""Create traditional text chunks using LlamaIndex SentenceSplitter.
Returns:
List of dicts with {"text": str, "metadata": dict}
"""
) -> list[str]:
"""Create traditional text chunks using LlamaIndex SentenceSplitter."""
if chunk_size <= 0:
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
chunk_size = 256
@@ -308,40 +147,19 @@ def create_traditional_chunks(
paragraph_separator="\n\n",
)
result = []
all_texts = []
for doc in documents:
# Extract document-level metadata
doc_metadata = {
"file_path": doc.metadata.get("file_path", ""),
"file_name": doc.metadata.get("file_name", ""),
}
if "creation_date" in doc.metadata:
doc_metadata["creation_date"] = doc.metadata["creation_date"]
if "last_modified_date" in doc.metadata:
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
try:
nodes = node_parser.get_nodes_from_documents([doc])
if nodes:
for node in nodes:
result.append({"text": node.get_content(), "metadata": doc_metadata})
all_texts.extend(node.get_content() for node in nodes)
except Exception as e:
logger.error(f"Traditional chunking failed for document: {e}")
content = doc.get_content()
if content and content.strip():
result.append({"text": content.strip(), "metadata": doc_metadata})
all_texts.append(content.strip())
return result
def _traditional_chunks_as_dicts(
documents, chunk_size: int = 256, chunk_overlap: int = 128
) -> list[dict[str, Any]]:
"""Helper: Traditional chunking that returns dict format for consistency.
This is now just an alias for create_traditional_chunks for backwards compatibility.
"""
return create_traditional_chunks(documents, chunk_size, chunk_overlap)
return all_texts
def create_text_chunks(
@@ -353,12 +171,8 @@ def create_text_chunks(
ast_chunk_overlap: int = 64,
code_file_extensions: Optional[list[str]] = None,
ast_fallback_traditional: bool = True,
) -> list[dict[str, Any]]:
"""Create text chunks from documents with optional AST support for code files.
Returns:
List of dicts with {"text": str, "metadata": dict}
"""
) -> list[str]:
"""Create text chunks from documents with optional AST support for code files."""
if not documents:
logger.warning("No documents provided for chunking")
return []
@@ -393,17 +207,14 @@ def create_text_chunks(
logger.error(f"AST chunking failed: {e}")
if ast_fallback_traditional:
all_chunks.extend(
_traditional_chunks_as_dicts(code_docs, chunk_size, chunk_overlap)
create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
)
else:
raise
if text_docs:
all_chunks.extend(_traditional_chunks_as_dicts(text_docs, chunk_size, chunk_overlap))
all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
else:
all_chunks = _traditional_chunks_as_dicts(documents, chunk_size, chunk_overlap)
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
logger.info(f"Total chunks created: {len(all_chunks)}")
# Note: Token truncation is now handled at embedding time with dynamic model limits
# See get_model_token_limit() and truncate_to_token_limit() in embedding_compute.py
return all_chunks

View File

@@ -1,6 +1,5 @@
import argparse
import asyncio
import time
from pathlib import Path
from typing import Any, Optional, Union
@@ -107,7 +106,7 @@ Examples:
help="Documents directories and/or files (default: current directory)",
)
build_parser.add_argument(
"--backend-name",
"--backend",
type=str,
default="hnsw",
choices=["hnsw", "diskann"],
@@ -181,25 +180,25 @@ Examples:
"--doc-chunk-size",
type=int,
default=256,
help="Document chunk size in TOKENS (default: 256). Final chunks may be larger due to overlap. For 512 token models: recommended 350 tokens (350 + 128 overlap = 478 max)",
help="Document chunk size in tokens/characters (default: 256)",
)
build_parser.add_argument(
"--doc-chunk-overlap",
type=int,
default=128,
help="Document chunk overlap in TOKENS (default: 128). Added to chunk size, not included in it",
help="Document chunk overlap (default: 128)",
)
build_parser.add_argument(
"--code-chunk-size",
type=int,
default=512,
help="Code chunk size in TOKENS (default: 512). Final chunks may be larger due to overlap. For 512 token models: recommended 400 tokens (400 + 50 overlap = 450 max)",
help="Code chunk size in tokens/lines (default: 512)",
)
build_parser.add_argument(
"--code-chunk-overlap",
type=int,
default=50,
help="Code chunk overlap in TOKENS (default: 50). Added to chunk size, not included in it",
help="Code chunk overlap (default: 50)",
)
build_parser.add_argument(
"--use-ast-chunking",
@@ -209,14 +208,14 @@ Examples:
build_parser.add_argument(
"--ast-chunk-size",
type=int,
default=300,
help="AST chunk size in CHARACTERS (non-whitespace) (default: 300). Final chunks may be larger due to overlap and expansion. For 512 token models: recommended 300 chars (300 + 64 overlap ~= 480 tokens)",
default=768,
help="AST chunk size in characters (default: 768)",
)
build_parser.add_argument(
"--ast-chunk-overlap",
type=int,
default=64,
help="AST chunk overlap in CHARACTERS (default: 64). Added to chunk size, not included in it. ~1.2 tokens per character for code",
default=96,
help="AST chunk overlap in characters (default: 96)",
)
build_parser.add_argument(
"--ast-fallback-traditional",
@@ -255,11 +254,6 @@ Examples:
action="store_true",
help="Non-interactive mode: automatically select index without prompting",
)
search_parser.add_argument(
"--show-metadata",
action="store_true",
help="Display file paths and metadata in search results",
)
# Ask command
ask_parser = subparsers.add_parser("ask", help="Ask questions")
@@ -1192,7 +1186,6 @@ Examples:
for doc in other_docs:
file_path = doc.metadata.get("file_path", "")
if file_filter(file_path):
doc.metadata["source"] = file_path
filtered_docs.append(doc)
documents.extend(filtered_docs)
@@ -1268,7 +1261,7 @@ Examples:
from .chunking_utils import create_text_chunks
# Use enhanced chunking with AST support
chunk_texts = create_text_chunks(
all_texts = create_text_chunks(
documents,
chunk_size=self.node_parser.chunk_size,
chunk_overlap=self.node_parser.chunk_overlap,
@@ -1279,9 +1272,6 @@ Examples:
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
)
# create_text_chunks now returns list[dict] with metadata preserved
all_texts.extend(chunk_texts)
except ImportError as e:
print(
f"⚠️ AST chunking utilities not available in package ({e}), falling back to traditional chunking"
@@ -1293,27 +1283,14 @@ Examples:
for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
# Check if this is a code file based on source path
source_path = doc.metadata.get("source", "")
file_path = doc.metadata.get("file_path", "")
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
# Extract metadata to preserve with chunks
chunk_metadata = {
"file_path": file_path or source_path,
"file_name": doc.metadata.get("file_name", ""),
}
# Add optional metadata if available
if "creation_date" in doc.metadata:
chunk_metadata["creation_date"] = doc.metadata["creation_date"]
if "last_modified_date" in doc.metadata:
chunk_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
# Use appropriate parser based on file type
parser = self.code_parser if is_code_file else self.node_parser
nodes = parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append({"text": node.get_content(), "metadata": chunk_metadata})
all_texts.append(node.get_content())
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
return all_texts
@@ -1388,7 +1365,7 @@ Examples:
index_dir.mkdir(parents=True, exist_ok=True)
print(f"Building index '{index_name}' with {args.backend_name} backend...")
print(f"Building index '{index_name}' with {args.backend} backend...")
embedding_options: dict[str, Any] = {}
if args.embedding_mode == "ollama":
@@ -1400,7 +1377,7 @@ Examples:
embedding_options["api_key"] = resolved_embedding_key
builder = LeannBuilder(
backend_name=args.backend_name,
backend_name=args.backend,
embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
embedding_options=embedding_options or None,
@@ -1411,8 +1388,8 @@ Examples:
num_threads=args.num_threads,
)
for chunk in all_texts:
builder.add_text(chunk["text"], metadata=chunk["metadata"])
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"Index built at {index_path}")
@@ -1533,25 +1510,7 @@ Examples:
print(f"Search results for '{query}' (top {len(results)}):")
for i, result in enumerate(results, 1):
print(f"{i}. Score: {result.score:.3f}")
# Display metadata if flag is set
if args.show_metadata and result.metadata:
file_path = result.metadata.get("file_path", "")
if file_path:
print(f" 📄 File: {file_path}")
file_name = result.metadata.get("file_name", "")
if file_name and file_name != file_path:
print(f" 📝 Name: {file_name}")
# Show timestamps if available
if "creation_date" in result.metadata:
print(f" 🕐 Created: {result.metadata['creation_date']}")
if "last_modified_date" in result.metadata:
print(f" 🕑 Modified: {result.metadata['last_modified_date']}")
print(f" {result.text[:200]}...")
print(f" Source: {result.metadata.get('source', '')}")
print()
async def ask_questions(self, args):
@@ -1583,7 +1542,6 @@ Examples:
llm_kwargs["thinking_budget"] = args.thinking_budget
def _ask_once(prompt: str) -> None:
query_start_time = time.time()
response = chat.ask(
prompt,
top_k=args.top_k,
@@ -1594,9 +1552,7 @@ Examples:
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
query_completion_time = time.time() - query_start_time
print(f"LEANN: {response}")
print(f"The query took {query_completion_time:.3f} seconds to finish")
initial_query = (args.query or "").strip()

View File

@@ -10,7 +10,6 @@ import time
from typing import Any, Optional
import numpy as np
import tiktoken
import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
@@ -21,170 +20,6 @@ LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Token limit registry for embedding models
# Used as fallback when dynamic discovery fails (e.g., LM Studio, OpenAI)
# Ollama models use dynamic discovery via /api/show
EMBEDDING_MODEL_LIMITS = {
# Nomic models (common across servers)
"nomic-embed-text": 2048, # Corrected from 512 - verified via /api/show
"nomic-embed-text-v1.5": 2048,
"nomic-embed-text-v2": 512,
# Other embedding models
"mxbai-embed-large": 512,
"all-minilm": 512,
"bge-m3": 8192,
"snowflake-arctic-embed": 512,
# OpenAI models
"text-embedding-3-small": 8192,
"text-embedding-3-large": 8192,
"text-embedding-ada-002": 8192,
}
def get_model_token_limit(
model_name: str,
base_url: Optional[str] = None,
default: int = 2048,
) -> int:
"""
Get token limit for a given embedding model.
Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others.
Args:
model_name: Name of the embedding model
base_url: Base URL of the embedding server (for dynamic discovery)
default: Default token limit if model not found
Returns:
Token limit for the model in tokens
"""
# Try Ollama dynamic discovery if base_url provided
if base_url:
# Detect Ollama servers by port or "ollama" in URL
if "11434" in base_url or "ollama" in base_url.lower():
limit = _query_ollama_context_limit(model_name, base_url)
if limit:
return limit
# Fallback to known model registry with version handling (from PR #154)
# Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
base_model_name = model_name.split(":")[0]
# Check exact match first
if model_name in EMBEDDING_MODEL_LIMITS:
return EMBEDDING_MODEL_LIMITS[model_name]
# Check base name match
if base_model_name in EMBEDDING_MODEL_LIMITS:
return EMBEDDING_MODEL_LIMITS[base_model_name]
# Check partial matches for common patterns
for known_model, limit in EMBEDDING_MODEL_LIMITS.items():
if known_model in base_model_name or base_model_name in known_model:
return limit
# Default fallback
logger.warning(f"Unknown model '{model_name}', using default {default} token limit")
return default
def truncate_to_token_limit(texts: list[str], token_limit: int) -> list[str]:
"""
Truncate texts to fit within token limit using tiktoken.
Args:
texts: List of text strings to truncate
token_limit: Maximum number of tokens allowed
Returns:
List of truncated texts (same length as input)
"""
if not texts:
return []
# Use tiktoken with cl100k_base encoding
enc = tiktoken.get_encoding("cl100k_base")
truncated_texts = []
truncation_count = 0
total_tokens_removed = 0
max_original_length = 0
for i, text in enumerate(texts):
tokens = enc.encode(text)
original_length = len(tokens)
if original_length <= token_limit:
# Text is within limit, keep as is
truncated_texts.append(text)
else:
# Truncate to token_limit
truncated_tokens = tokens[:token_limit]
truncated_text = enc.decode(truncated_tokens)
truncated_texts.append(truncated_text)
# Track truncation statistics
truncation_count += 1
tokens_removed = original_length - token_limit
total_tokens_removed += tokens_removed
max_original_length = max(max_original_length, original_length)
# Log individual truncation at WARNING level (first few only)
if truncation_count <= 3:
logger.warning(
f"Text {i + 1} truncated: {original_length}{token_limit} tokens "
f"({tokens_removed} tokens removed)"
)
elif truncation_count == 4:
logger.warning("Further truncation warnings suppressed...")
# Log summary at INFO level
if truncation_count > 0:
logger.warning(
f"Truncation summary: {truncation_count}/{len(texts)} texts truncated "
f"(removed {total_tokens_removed} tokens total, longest was {max_original_length} tokens)"
)
else:
logger.debug(
f"No truncation needed - all {len(texts)} texts within {token_limit} token limit"
)
return truncated_texts
def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int]:
"""
Query Ollama /api/show for model context limit.
Args:
model_name: Name of the Ollama model
base_url: Base URL of the Ollama server
Returns:
Context limit in tokens if found, None otherwise
"""
try:
import requests
response = requests.post(
f"{base_url}/api/show",
json={"name": model_name},
timeout=5,
)
if response.status_code == 200:
data = response.json()
if "model_info" in data:
# Look for *.context_length in model_info
for key, value in data["model_info"].items():
if "context_length" in key and isinstance(value, int):
logger.info(f"Detected {model_name} context limit: {value} tokens")
return value
except Exception as e:
logger.debug(f"Failed to query Ollama context limit: {e}")
return None
# Global model cache to avoid repeated loading
_model_cache: dict[str, Any] = {}
@@ -739,10 +574,9 @@ def compute_embeddings_ollama(
host: Optional[str] = None,
) -> np.ndarray:
"""
Compute embeddings using Ollama API with true batch processing.
Compute embeddings using Ollama API with simplified batch processing.
Uses the /api/embed endpoint which supports batch inputs.
Batch size: 32 for MPS/CPU, 128 for CUDA to optimize performance.
Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
Args:
texts: List of texts to compute embeddings for
@@ -847,11 +681,11 @@ def compute_embeddings_ollama(
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
model_name = resolved_model_name
# Verify the model supports embeddings by testing it with /api/embed
# Verify the model supports embeddings by testing it
try:
test_response = requests.post(
f"{resolved_host}/api/embed",
json={"model": model_name, "input": "test"},
f"{resolved_host}/api/embeddings",
json={"model": model_name, "prompt": "test"},
timeout=10,
)
if test_response.status_code != 200:
@@ -883,71 +717,56 @@ def compute_embeddings_ollama(
# If torch is not available, use conservative batch size
batch_size = 32
logger.info(f"Using batch size: {batch_size} for true batch processing")
# Get model token limit and apply truncation before batching
token_limit = get_model_token_limit(model_name, base_url=resolved_host)
logger.info(f"Model '{model_name}' token limit: {token_limit}")
# Apply truncation to all texts before batch processing
# Function logs truncation details internally
texts = truncate_to_token_limit(texts, token_limit)
logger.info(f"Using batch size: {batch_size}")
def get_batch_embeddings(batch_texts):
"""Get embeddings for a batch of texts using /api/embed endpoint."""
max_retries = 3
retry_count = 0
"""Get embeddings for a batch of texts."""
all_embeddings = []
failed_indices = []
# Texts are already truncated to token limit by the outer function
while retry_count < max_retries:
try:
# Use /api/embed endpoint with "input" parameter for batch processing
response = requests.post(
f"{resolved_host}/api/embed",
json={"model": model_name, "input": batch_texts},
timeout=60, # Increased timeout for batch processing
)
response.raise_for_status()
for i, text in enumerate(batch_texts):
max_retries = 3
retry_count = 0
result = response.json()
batch_embeddings = result.get("embeddings")
if batch_embeddings is None:
raise ValueError("No embeddings returned from API")
if not isinstance(batch_embeddings, list):
raise ValueError(f"Invalid embeddings format: {type(batch_embeddings)}")
if len(batch_embeddings) != len(batch_texts):
raise ValueError(
f"Mismatch: requested {len(batch_texts)} embeddings, got {len(batch_embeddings)}"
# Truncate very long texts to avoid API issues
truncated_text = text[:8000] if len(text) > 8000 else text
while retry_count < max_retries:
try:
response = requests.post(
f"{resolved_host}/api/embeddings",
json={"model": model_name, "prompt": truncated_text},
timeout=30,
)
response.raise_for_status()
return batch_embeddings, []
result = response.json()
embedding = result.get("embedding")
except requests.exceptions.Timeout:
retry_count += 1
if retry_count >= max_retries:
logger.warning(f"Timeout for batch after {max_retries} retries")
return None, list(range(len(batch_texts)))
if embedding is None:
raise ValueError(f"No embedding returned for text {i}")
except Exception as e:
retry_count += 1
if retry_count >= max_retries:
# Enhanced error detection for token limit violations
error_msg = str(e).lower()
if "token" in error_msg and (
"limit" in error_msg or "exceed" in error_msg or "length" in error_msg
):
logger.error(
f"Token limit exceeded for batch. Error: {e}. "
f"Consider reducing chunk sizes or check token truncation."
)
else:
logger.error(f"Failed to get embeddings for batch: {e}")
return None, list(range(len(batch_texts)))
if not isinstance(embedding, list) or len(embedding) == 0:
raise ValueError(f"Invalid embedding format for text {i}")
return None, list(range(len(batch_texts)))
all_embeddings.append(embedding)
break
except requests.exceptions.Timeout:
retry_count += 1
if retry_count >= max_retries:
logger.warning(f"Timeout for text {i} after {max_retries} retries")
failed_indices.append(i)
all_embeddings.append(None)
break
except Exception as e:
retry_count += 1
if retry_count >= max_retries:
logger.error(f"Failed to get embedding for text {i}: {e}")
failed_indices.append(i)
all_embeddings.append(None)
break
return all_embeddings, failed_indices
# Process texts in batches
all_embeddings = []
@@ -965,7 +784,7 @@ def compute_embeddings_ollama(
num_batches = (len(texts) + batch_size - 1) // batch_size
if show_progress:
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
else:
batch_iterator = range(num_batches)
@@ -976,14 +795,10 @@ def compute_embeddings_ollama(
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
if batch_embeddings is not None:
all_embeddings.extend(batch_embeddings)
else:
# Entire batch failed, add None placeholders
all_embeddings.extend([None] * len(batch_texts))
# Adjust failed indices to global indices
global_failed = [start_idx + idx for idx in batch_failed]
all_failed_indices.extend(global_failed)
# Adjust failed indices to global indices
global_failed = [start_idx + idx for idx in batch_failed]
all_failed_indices.extend(global_failed)
all_embeddings.extend(batch_embeddings)
# Handle failed embeddings
if all_failed_indices:

View File

@@ -60,11 +60,6 @@ def handle_request(request):
"maximum": 128,
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
},
"show_metadata": {
"type": "boolean",
"default": False,
"description": "Include file paths and metadata in search results. Useful for understanding which files contain the results.",
},
},
"required": ["index_name", "query"],
},
@@ -109,8 +104,6 @@ def handle_request(request):
f"--complexity={args.get('complexity', 32)}",
"--non-interactive",
]
if args.get("show_metadata", False):
cmd.append("--show-metadata")
result = subprocess.run(cmd, capture_output=True, text=True)
elif tool_name == "leann_list":

View File

@@ -57,8 +57,6 @@ dependencies = [
"tree-sitter-c-sharp>=0.20.0",
"tree-sitter-typescript>=0.20.0",
"torchvision>=0.23.0",
"einops",
"seaborn",
]
[project.optional-dependencies]

View File

@@ -8,7 +8,7 @@ import subprocess
import sys
import tempfile
from pathlib import Path
from unittest.mock import Mock, patch
from unittest.mock import patch
import pytest
@@ -116,10 +116,8 @@ class TestChunkingFunctions:
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
assert len(chunks) > 0
# Traditional chunks now return dict format for consistency
assert all(isinstance(chunk, dict) for chunk in chunks)
assert all("text" in chunk and "metadata" in chunk for chunk in chunks)
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks)
assert all(isinstance(chunk, str) for chunk in chunks)
assert all(len(chunk.strip()) > 0 for chunk in chunks)
def test_create_traditional_chunks_empty_docs(self):
"""Test traditional chunking with empty documents."""
@@ -160,22 +158,11 @@ class Calculator:
# Should have multiple chunks due to different functions/classes
assert len(chunks) > 0
# R3: Expect dict format with "text" and "metadata" keys
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
"Each chunk should have 'text' and 'metadata' keys"
)
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks), (
"Each chunk text should be non-empty"
)
# Check metadata is present
assert all("file_path" in chunk["metadata"] for chunk in chunks), (
"Each chunk should have file_path metadata"
)
assert all(isinstance(chunk, str) for chunk in chunks)
assert all(len(chunk.strip()) > 0 for chunk in chunks)
# Check that code structure is somewhat preserved
combined_content = " ".join([c["text"] for c in chunks])
combined_content = " ".join(chunks)
assert "def hello_world" in combined_content
assert "class Calculator" in combined_content
@@ -207,11 +194,7 @@ class Calculator:
chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10)
assert len(chunks) > 0
# R3: Traditional chunking should also return dict format for consistency
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
"Each chunk should have 'text' and 'metadata' keys"
)
assert all(isinstance(chunk, str) for chunk in chunks)
def test_create_text_chunks_ast_mode(self):
"""Test text chunking in AST mode."""
@@ -230,11 +213,7 @@ class Calculator:
)
assert len(chunks) > 0
# R3: AST mode should also return dict format
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
"Each chunk should have 'text' and 'metadata' keys"
)
assert all(isinstance(chunk, str) for chunk in chunks)
def test_create_text_chunks_custom_extensions(self):
"""Test text chunking with custom code file extensions."""
@@ -374,552 +353,6 @@ class MathUtils:
pytest.skip("Test timed out - likely due to model download in CI")
class TestASTContentExtraction:
"""Test AST content extraction bug fix.
These tests verify that astchunk's dict format with 'content' key is handled correctly,
and that the extraction logic doesn't fall through to stringifying entire dicts.
"""
def test_extract_content_from_astchunk_dict(self):
"""Test that astchunk dict format with 'content' key is handled correctly.
Bug: Current code checks for chunk["text"] but astchunk returns chunk["content"].
This causes fallthrough to str(chunk), stringifying the entire dict.
This test will FAIL until the bug is fixed because:
- Current code will stringify the dict: "{'content': '...', 'metadata': {...}}"
- Fixed code should extract just the content value
"""
# Mock the ASTChunkBuilder class
mock_builder = Mock()
# Astchunk returns this format
astchunk_format_chunk = {
"content": "def hello():\n print('world')",
"metadata": {
"filepath": "test.py",
"line_count": 2,
"start_line_no": 0,
"end_line_no": 1,
"node_count": 1,
},
}
mock_builder.chunkify.return_value = [astchunk_format_chunk]
# Create mock document
doc = MockDocument(
"def hello():\n print('world')", "/test/test.py", {"language": "python"}
)
# Mock the astchunk module and its ASTChunkBuilder class
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
# Patch sys.modules to inject our mock before the import
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should return dict format with proper metadata
assert len(chunks) > 0, "Should return at least one chunk"
# R3: Each chunk should be a dict
chunk = chunks[0]
assert isinstance(chunk, dict), "Chunk should be a dict"
assert "text" in chunk, "Chunk should have 'text' key"
assert "metadata" in chunk, "Chunk should have 'metadata' key"
chunk_text = chunk["text"]
# CRITICAL: Should NOT contain stringified dict markers in the text field
# These assertions will FAIL with current buggy code
assert "'content':" not in chunk_text, (
f"Chunk text contains stringified dict - extraction failed! Got: {chunk_text[:100]}..."
)
assert "'metadata':" not in chunk_text, (
"Chunk text contains stringified metadata - extraction failed! "
f"Got: {chunk_text[:100]}..."
)
assert "{" not in chunk_text or "def hello" in chunk_text.split("{")[0], (
"Chunk text appears to be a stringified dict"
)
# Should contain actual content
assert "def hello()" in chunk_text, "Should extract actual code content"
assert "print('world')" in chunk_text, "Should extract complete code content"
# R3: Should preserve astchunk metadata
assert "filepath" in chunk["metadata"] or "file_path" in chunk["metadata"], (
"Should preserve file path metadata"
)
def test_extract_text_key_fallback(self):
"""Test that 'text' key still works for backward compatibility.
Some chunks might use 'text' instead of 'content' - ensure backward compatibility.
This test should PASS even with current code.
"""
mock_builder = Mock()
# Some chunks might use "text" key
text_key_chunk = {"text": "def legacy_function():\n return True"}
mock_builder.chunkify.return_value = [text_key_chunk]
# Create mock document
doc = MockDocument(
"def legacy_function():\n return True", "/test/legacy.py", {"language": "python"}
)
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should extract text correctly as dict format
assert len(chunks) > 0
chunk = chunks[0]
assert isinstance(chunk, dict), "Chunk should be a dict"
assert "text" in chunk, "Chunk should have 'text' key"
chunk_text = chunk["text"]
# Should NOT be stringified
assert "'text':" not in chunk_text, "Should not stringify dict with 'text' key"
# Should contain actual content
assert "def legacy_function()" in chunk_text
assert "return True" in chunk_text
def test_handles_string_chunks(self):
"""Test that plain string chunks still work.
Some chunkers might return plain strings - verify these are preserved.
This test should PASS with current code.
"""
mock_builder = Mock()
# Plain string chunk
plain_string_chunk = "def simple_function():\n pass"
mock_builder.chunkify.return_value = [plain_string_chunk]
# Create mock document
doc = MockDocument(
"def simple_function():\n pass", "/test/simple.py", {"language": "python"}
)
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should wrap string in dict format
assert len(chunks) > 0
chunk = chunks[0]
assert isinstance(chunk, dict), "Even string chunks should be wrapped in dict"
assert "text" in chunk, "Chunk should have 'text' key"
chunk_text = chunk["text"]
assert chunk_text == plain_string_chunk.strip(), (
"Should preserve plain string chunk content"
)
assert "def simple_function()" in chunk_text
assert "pass" in chunk_text
def test_multiple_chunks_with_mixed_formats(self):
"""Test handling of multiple chunks with different formats.
Real-world scenario: astchunk might return a mix of formats.
This test will FAIL if any chunk with 'content' key gets stringified.
"""
mock_builder = Mock()
# Mix of formats
mixed_chunks = [
{"content": "def first():\n return 1", "metadata": {"line_count": 2}},
"def second():\n return 2", # Plain string
{"text": "def third():\n return 3"}, # Old format
{"content": "class MyClass:\n pass", "metadata": {"node_count": 1}},
]
mock_builder.chunkify.return_value = mixed_chunks
# Create mock document
code = "def first():\n return 1\n\ndef second():\n return 2\n\ndef third():\n return 3\n\nclass MyClass:\n pass"
doc = MockDocument(code, "/test/mixed.py", {"language": "python"})
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should extract all chunks correctly as dicts
assert len(chunks) == 4, "Should extract all 4 chunks"
# Check each chunk
for i, chunk in enumerate(chunks):
assert isinstance(chunk, dict), f"Chunk {i} should be a dict"
assert "text" in chunk, f"Chunk {i} should have 'text' key"
assert "metadata" in chunk, f"Chunk {i} should have 'metadata' key"
chunk_text = chunk["text"]
# None should be stringified dicts
assert "'content':" not in chunk_text, f"Chunk {i} text is stringified (has 'content':)"
assert "'metadata':" not in chunk_text, (
f"Chunk {i} text is stringified (has 'metadata':)"
)
assert "'text':" not in chunk_text, f"Chunk {i} text is stringified (has 'text':)"
# Verify actual content is present
combined = "\n".join([c["text"] for c in chunks])
assert "def first()" in combined
assert "def second()" in combined
assert "def third()" in combined
assert "class MyClass:" in combined
def test_empty_content_value_handling(self):
"""Test handling of chunks with empty content values.
Edge case: chunk has 'content' key but value is empty.
Should skip these chunks, not stringify them.
"""
mock_builder = Mock()
chunks_with_empty = [
{"content": "", "metadata": {"line_count": 0}}, # Empty content
{"content": " ", "metadata": {"line_count": 1}}, # Whitespace only
{"content": "def valid():\n return True", "metadata": {"line_count": 2}}, # Valid
]
mock_builder.chunkify.return_value = chunks_with_empty
doc = MockDocument(
"def valid():\n return True", "/test/empty.py", {"language": "python"}
)
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
chunks = create_ast_chunks([doc])
# R3: Should only have the valid chunk (empty ones filtered out)
assert len(chunks) == 1, "Should filter out empty content chunks"
chunk = chunks[0]
assert isinstance(chunk, dict), "Chunk should be a dict"
assert "text" in chunk, "Chunk should have 'text' key"
assert "def valid()" in chunk["text"]
# Should not have stringified the empty dict
assert "'content': ''" not in chunk["text"]
class TestASTMetadataPreservation:
"""Test metadata preservation in AST chunk dictionaries.
R3: These tests define the contract for metadata preservation when returning
chunk dictionaries instead of plain strings. Each chunk dict should have:
- "text": str - the actual chunk content
- "metadata": dict - all metadata from document AND astchunk
These tests will FAIL until G3 implementation changes return type to list[dict].
"""
def test_ast_chunks_preserve_file_metadata(self):
"""Test that document metadata is preserved in chunk metadata.
This test verifies that all document-level metadata (file_path, file_name,
creation_date, last_modified_date) is included in each chunk's metadata dict.
This will FAIL because current code returns list[str], not list[dict].
"""
# Create mock document with rich metadata
python_code = '''
def calculate_sum(numbers):
"""Calculate sum of numbers."""
return sum(numbers)
class DataProcessor:
"""Process data records."""
def process(self, data):
return [x * 2 for x in data]
'''
doc = MockDocument(
python_code,
file_path="/project/src/utils.py",
metadata={
"language": "python",
"file_path": "/project/src/utils.py",
"file_name": "utils.py",
"creation_date": "2024-01-15T10:30:00",
"last_modified_date": "2024-10-31T15:45:00",
},
)
# Mock astchunk to return chunks with metadata
mock_builder = Mock()
astchunk_chunks = [
{
"content": "def calculate_sum(numbers):\n return sum(numbers)",
"metadata": {
"filepath": "/project/src/utils.py",
"line_count": 2,
"start_line_no": 1,
"end_line_no": 2,
"node_count": 1,
},
},
{
"content": "class DataProcessor:\n def process(self, data):\n return [x * 2 for x in data]",
"metadata": {
"filepath": "/project/src/utils.py",
"line_count": 3,
"start_line_no": 5,
"end_line_no": 7,
"node_count": 2,
},
},
]
mock_builder.chunkify.return_value = astchunk_chunks
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
chunks = create_ast_chunks([doc])
# CRITICAL: These assertions will FAIL with current list[str] return type
assert len(chunks) == 2, "Should return 2 chunks"
for i, chunk in enumerate(chunks):
# Structure assertions - WILL FAIL: current code returns strings
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
assert "text" in chunk, f"Chunk {i} must have 'text' key"
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
assert isinstance(chunk["metadata"], dict), f"Chunk {i} metadata should be dict"
# Document metadata preservation - WILL FAIL
metadata = chunk["metadata"]
assert "file_path" in metadata, f"Chunk {i} should preserve file_path"
assert metadata["file_path"] == "/project/src/utils.py", (
f"Chunk {i} file_path incorrect"
)
assert "file_name" in metadata, f"Chunk {i} should preserve file_name"
assert metadata["file_name"] == "utils.py", f"Chunk {i} file_name incorrect"
assert "creation_date" in metadata, f"Chunk {i} should preserve creation_date"
assert metadata["creation_date"] == "2024-01-15T10:30:00", (
f"Chunk {i} creation_date incorrect"
)
assert "last_modified_date" in metadata, f"Chunk {i} should preserve last_modified_date"
assert metadata["last_modified_date"] == "2024-10-31T15:45:00", (
f"Chunk {i} last_modified_date incorrect"
)
# Verify metadata is consistent across chunks from same document
assert chunks[0]["metadata"]["file_path"] == chunks[1]["metadata"]["file_path"], (
"All chunks from same document should have same file_path"
)
# Verify text content is present and not stringified
assert "def calculate_sum" in chunks[0]["text"]
assert "class DataProcessor" in chunks[1]["text"]
def test_ast_chunks_include_astchunk_metadata(self):
"""Test that astchunk-specific metadata is merged into chunk metadata.
This test verifies that astchunk's metadata (line_count, start_line_no,
end_line_no, node_count) is merged with document metadata.
This will FAIL because current code returns list[str], not list[dict].
"""
python_code = '''
def function_one():
"""First function."""
x = 1
y = 2
return x + y
def function_two():
"""Second function."""
return 42
'''
doc = MockDocument(
python_code,
file_path="/test/code.py",
metadata={
"language": "python",
"file_path": "/test/code.py",
"file_name": "code.py",
},
)
# Mock astchunk with detailed metadata
mock_builder = Mock()
astchunk_chunks = [
{
"content": "def function_one():\n x = 1\n y = 2\n return x + y",
"metadata": {
"filepath": "/test/code.py",
"line_count": 4,
"start_line_no": 1,
"end_line_no": 4,
"node_count": 5, # function, assignments, return
},
},
{
"content": "def function_two():\n return 42",
"metadata": {
"filepath": "/test/code.py",
"line_count": 2,
"start_line_no": 7,
"end_line_no": 8,
"node_count": 2, # function, return
},
},
]
mock_builder.chunkify.return_value = astchunk_chunks
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
chunks = create_ast_chunks([doc])
# CRITICAL: These will FAIL with current list[str] return
assert len(chunks) == 2
# First chunk - function_one
chunk1 = chunks[0]
assert isinstance(chunk1, dict), "Chunk should be dict"
assert "metadata" in chunk1
metadata1 = chunk1["metadata"]
# Check astchunk metadata is present
assert "line_count" in metadata1, "Should include astchunk line_count"
assert metadata1["line_count"] == 4, "line_count should be 4"
assert "start_line_no" in metadata1, "Should include astchunk start_line_no"
assert metadata1["start_line_no"] == 1, "start_line_no should be 1"
assert "end_line_no" in metadata1, "Should include astchunk end_line_no"
assert metadata1["end_line_no"] == 4, "end_line_no should be 4"
assert "node_count" in metadata1, "Should include astchunk node_count"
assert metadata1["node_count"] == 5, "node_count should be 5"
# Second chunk - function_two
chunk2 = chunks[1]
metadata2 = chunk2["metadata"]
assert metadata2["line_count"] == 2, "line_count should be 2"
assert metadata2["start_line_no"] == 7, "start_line_no should be 7"
assert metadata2["end_line_no"] == 8, "end_line_no should be 8"
assert metadata2["node_count"] == 2, "node_count should be 2"
# Verify document metadata is ALSO present (merged, not replaced)
assert metadata1["file_path"] == "/test/code.py"
assert metadata1["file_name"] == "code.py"
assert metadata2["file_path"] == "/test/code.py"
assert metadata2["file_name"] == "code.py"
# Verify text content is correct
assert "def function_one" in chunk1["text"]
assert "def function_two" in chunk2["text"]
def test_traditional_chunks_as_dicts_helper(self):
"""Test the helper function that wraps traditional chunks as dicts.
This test verifies that when create_traditional_chunks is called,
its plain string chunks are wrapped into dict format with metadata.
This will FAIL because the helper function _traditional_chunks_as_dicts()
doesn't exist yet, and create_traditional_chunks returns list[str].
"""
# Create documents with various metadata
docs = [
MockDocument(
"This is the first paragraph of text. It contains multiple sentences. "
"This should be split into chunks based on size.",
file_path="/docs/readme.txt",
metadata={
"file_path": "/docs/readme.txt",
"file_name": "readme.txt",
"creation_date": "2024-01-01",
},
),
MockDocument(
"Second document with different metadata. It also has content that needs chunking.",
file_path="/docs/guide.md",
metadata={
"file_path": "/docs/guide.md",
"file_name": "guide.md",
"last_modified_date": "2024-10-31",
},
),
]
# Call create_traditional_chunks (which should now return list[dict])
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
# CRITICAL: Will FAIL - current code returns list[str]
assert len(chunks) > 0, "Should return chunks"
for i, chunk in enumerate(chunks):
# Structure assertions - WILL FAIL
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
assert "text" in chunk, f"Chunk {i} must have 'text' key"
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
# Text should be non-empty
assert len(chunk["text"].strip()) > 0, f"Chunk {i} text should be non-empty"
# Metadata should include document info
metadata = chunk["metadata"]
assert "file_path" in metadata, f"Chunk {i} should have file_path in metadata"
assert "file_name" in metadata, f"Chunk {i} should have file_name in metadata"
# Verify metadata tracking works correctly
# At least one chunk should be from readme.txt
readme_chunks = [c for c in chunks if "readme.txt" in c["metadata"]["file_name"]]
assert len(readme_chunks) > 0, "Should have chunks from readme.txt"
# At least one chunk should be from guide.md
guide_chunks = [c for c in chunks if "guide.md" in c["metadata"]["file_name"]]
assert len(guide_chunks) > 0, "Should have chunks from guide.md"
# Verify creation_date is preserved for readme chunks
for chunk in readme_chunks:
assert chunk["metadata"].get("creation_date") == "2024-01-01", (
"readme.txt chunks should preserve creation_date"
)
# Verify last_modified_date is preserved for guide chunks
for chunk in guide_chunks:
assert chunk["metadata"].get("last_modified_date") == "2024-10-31", (
"guide.md chunks should preserve last_modified_date"
)
# Verify text content is present
all_text = " ".join([c["text"] for c in chunks])
assert "first paragraph" in all_text
assert "Second document" in all_text
class TestErrorHandling:
"""Test error handling and edge cases."""

View File

@@ -1,268 +0,0 @@
"""Unit tests for token-aware truncation functionality.
This test suite defines the contract for token truncation functions that prevent
500 errors from Ollama when text exceeds model token limits. These tests verify:
1. Model token limit retrieval (known and unknown models)
2. Text truncation behavior for single and multiple texts
3. Token counting and truncation accuracy using tiktoken
All tests are written in Red Phase - they should FAIL initially because the
implementation does not exist yet.
"""
import pytest
import tiktoken
from leann.embedding_compute import (
EMBEDDING_MODEL_LIMITS,
get_model_token_limit,
truncate_to_token_limit,
)
class TestModelTokenLimits:
"""Tests for retrieving model-specific token limits."""
def test_get_model_token_limit_known_model(self):
"""Verify correct token limit is returned for known models.
Known models should return their specific token limits from
EMBEDDING_MODEL_LIMITS dictionary.
"""
# Test nomic-embed-text (2048 tokens)
limit = get_model_token_limit("nomic-embed-text")
assert limit == 2048, "nomic-embed-text should have 2048 token limit"
# Test nomic-embed-text-v1.5 (2048 tokens)
limit = get_model_token_limit("nomic-embed-text-v1.5")
assert limit == 2048, "nomic-embed-text-v1.5 should have 2048 token limit"
# Test nomic-embed-text-v2 (512 tokens)
limit = get_model_token_limit("nomic-embed-text-v2")
assert limit == 512, "nomic-embed-text-v2 should have 512 token limit"
# Test OpenAI models (8192 tokens)
limit = get_model_token_limit("text-embedding-3-small")
assert limit == 8192, "text-embedding-3-small should have 8192 token limit"
def test_get_model_token_limit_unknown_model(self):
"""Verify default token limit is returned for unknown models.
Unknown models should return the default limit (2048) to allow
operation with reasonable safety margin.
"""
# Test with completely unknown model
limit = get_model_token_limit("unknown-model-xyz")
assert limit == 2048, "Unknown models should return default 2048"
# Test with empty string
limit = get_model_token_limit("")
assert limit == 2048, "Empty model name should return default 2048"
def test_get_model_token_limit_custom_default(self):
"""Verify custom default can be specified for unknown models.
Allow callers to specify their own default token limit when
model is not in the known models dictionary.
"""
limit = get_model_token_limit("unknown-model", default=4096)
assert limit == 4096, "Should return custom default for unknown models"
# Known model should ignore custom default
limit = get_model_token_limit("nomic-embed-text", default=4096)
assert limit == 2048, "Known model should ignore custom default"
def test_embedding_model_limits_dictionary_exists(self):
"""Verify EMBEDDING_MODEL_LIMITS dictionary contains expected models.
The dictionary should be importable and contain at least the
known nomic models with correct token limits.
"""
assert isinstance(EMBEDDING_MODEL_LIMITS, dict), "Should be a dictionary"
assert "nomic-embed-text" in EMBEDDING_MODEL_LIMITS, "Should contain nomic-embed-text"
assert "nomic-embed-text-v1.5" in EMBEDDING_MODEL_LIMITS, (
"Should contain nomic-embed-text-v1.5"
)
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text"] == 2048
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v1.5"] == 2048
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v2"] == 512
# OpenAI models
assert EMBEDDING_MODEL_LIMITS["text-embedding-3-small"] == 8192
class TestTokenTruncation:
"""Tests for truncating texts to token limits."""
@pytest.fixture
def tokenizer(self):
"""Provide tiktoken tokenizer for token counting verification."""
return tiktoken.get_encoding("cl100k_base")
def test_truncate_single_text_under_limit(self, tokenizer):
"""Verify text under token limit remains unchanged.
When text is already within the token limit, it should be
returned unchanged with no truncation.
"""
text = "This is a short text that is well under the token limit."
token_count = len(tokenizer.encode(text))
assert token_count < 100, f"Test setup: text should be short (has {token_count} tokens)"
# Truncate with generous limit
result = truncate_to_token_limit([text], token_limit=512)
assert len(result) == 1, "Should return same number of texts"
assert result[0] == text, "Text under limit should be unchanged"
def test_truncate_single_text_over_limit(self, tokenizer):
"""Verify text over token limit is truncated correctly.
When text exceeds the token limit, it should be truncated to
fit within the limit while maintaining valid token boundaries.
"""
# Create a text that definitely exceeds limit
text = "word " * 200 # ~200 tokens (each "word " is typically 1-2 tokens)
original_token_count = len(tokenizer.encode(text))
assert original_token_count > 50, (
f"Test setup: text should be long (has {original_token_count} tokens)"
)
# Truncate to 50 tokens
result = truncate_to_token_limit([text], token_limit=50)
assert len(result) == 1, "Should return same number of texts"
assert result[0] != text, "Text over limit should be truncated"
assert len(result[0]) < len(text), "Truncated text should be shorter"
# Verify truncated text is within token limit
truncated_token_count = len(tokenizer.encode(result[0]))
assert truncated_token_count <= 50, (
f"Truncated text should be ≤50 tokens, got {truncated_token_count}"
)
def test_truncate_multiple_texts_mixed_lengths(self, tokenizer):
"""Verify multiple texts with mixed lengths are handled correctly.
When processing multiple texts:
- Texts under limit should remain unchanged
- Texts over limit should be truncated independently
- Output list should maintain same order and length
"""
texts = [
"Short text.", # Under limit
"word " * 200, # Over limit
"Another short one.", # Under limit
"token " * 150, # Over limit
]
# Verify test setup
for i, text in enumerate(texts):
token_count = len(tokenizer.encode(text))
if i in [1, 3]:
assert token_count > 50, f"Text {i} should be over limit (has {token_count} tokens)"
else:
assert token_count < 50, (
f"Text {i} should be under limit (has {token_count} tokens)"
)
# Truncate with 50 token limit
result = truncate_to_token_limit(texts, token_limit=50)
assert len(result) == len(texts), "Should return same number of texts"
# Verify each text individually
for i, (original, truncated) in enumerate(zip(texts, result)):
token_count = len(tokenizer.encode(truncated))
assert token_count <= 50, f"Text {i} should be ≤50 tokens, got {token_count}"
# Short texts should be unchanged
if i in [0, 2]:
assert truncated == original, f"Short text {i} should be unchanged"
# Long texts should be truncated
else:
assert len(truncated) < len(original), f"Long text {i} should be truncated"
def test_truncate_empty_list(self):
"""Verify empty input list returns empty output list.
Edge case: empty list should return empty list without errors.
"""
result = truncate_to_token_limit([], token_limit=512)
assert result == [], "Empty input should return empty output"
def test_truncate_preserves_order(self, tokenizer):
"""Verify truncation preserves original text order.
Output list should maintain the same order as input list,
regardless of which texts were truncated.
"""
texts = [
"First text " * 50, # Will be truncated
"Second text.", # Won't be truncated
"Third text " * 50, # Will be truncated
]
result = truncate_to_token_limit(texts, token_limit=20)
assert len(result) == 3, "Should preserve list length"
# Check that order is maintained by looking for distinctive words
assert "First" in result[0], "First text should remain in first position"
assert "Second" in result[1], "Second text should remain in second position"
assert "Third" in result[2], "Third text should remain in third position"
def test_truncate_extremely_long_text(self, tokenizer):
"""Verify extremely long texts are truncated efficiently.
Test with text that far exceeds token limit to ensure
truncation handles extreme cases without performance issues.
"""
# Create very long text (simulate real-world scenario)
text = "token " * 5000 # ~5000+ tokens
original_token_count = len(tokenizer.encode(text))
assert original_token_count > 1000, "Test setup: text should be very long"
# Truncate to small limit
result = truncate_to_token_limit([text], token_limit=100)
assert len(result) == 1
truncated_token_count = len(tokenizer.encode(result[0]))
assert truncated_token_count <= 100, (
f"Should truncate to ≤100 tokens, got {truncated_token_count}"
)
assert len(result[0]) < len(text) // 10, "Should significantly reduce text length"
def test_truncate_exact_token_limit(self, tokenizer):
"""Verify text at exactly token limit is handled correctly.
Edge case: text with exactly the token limit should either
remain unchanged or be safely truncated by 1 token.
"""
# Create text with approximately 50 tokens
# We'll adjust to get exactly 50
target_tokens = 50
text = "word " * 50
tokens = tokenizer.encode(text)
# Adjust to get exactly target_tokens
if len(tokens) > target_tokens:
tokens = tokens[:target_tokens]
text = tokenizer.decode(tokens)
elif len(tokens) < target_tokens:
# Add more words
while len(tokenizer.encode(text)) < target_tokens:
text += "word "
tokens = tokenizer.encode(text)[:target_tokens]
text = tokenizer.decode(tokens)
# Verify we have exactly target_tokens
assert len(tokenizer.encode(text)) == target_tokens, (
"Test setup: should have exactly 50 tokens"
)
result = truncate_to_token_limit([text], token_limit=target_tokens)
assert len(result) == 1
result_tokens = len(tokenizer.encode(result[0]))
assert result_tokens <= target_tokens, (
f"Should be ≤{target_tokens} tokens, got {result_tokens}"
)

7553
uv.lock generated
View File

File diff suppressed because it is too large Load Diff