Compare commits
32 Commits
v0.3.5
...
add-gh-pat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ae29ae9b88 | ||
|
|
e5977e4c4f | ||
|
|
cbd6c8ab34 | ||
|
|
8a2ea37871 | ||
|
|
7ddb4772c0 | ||
|
|
a1c21adbce | ||
|
|
d1b3c93a5a | ||
|
|
a6ee95b18a | ||
|
|
17cbd07b25 | ||
|
|
3629ccf8f7 | ||
|
|
0175bc9c20 | ||
|
|
af47dfdde7 | ||
|
|
f13bd02fbd | ||
|
|
a0bbf831db | ||
|
|
86287d8832 | ||
|
|
76cc798e3e | ||
|
|
d599566fd7 | ||
|
|
00770aebbb | ||
|
|
e268392d5b | ||
|
|
eb909ccec5 | ||
|
|
13beb98164 | ||
|
|
969f514564 | ||
|
|
1ef9cba7de | ||
|
|
a63550944b | ||
|
|
97493a2896 | ||
|
|
f7d2dc6e7c | ||
|
|
ea86b283cb | ||
|
|
e7519bceaa | ||
|
|
abf0b2c676 | ||
|
|
930b79cc98 | ||
|
|
9b7353f336 | ||
|
|
9dd0e0b26f |
8
.github/workflows/build-reusable.yml
vendored
8
.github/workflows/build-reusable.yml
vendored
@@ -16,8 +16,10 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
submodules: recursive
|
||||
fetch-depth: 1
|
||||
token: ${{ secrets.GH_PAT != '' && secrets.GH_PAT || secrets.GITHUB_TOKEN }}
|
||||
ref: ${{ inputs.ref }}
|
||||
|
||||
- name: Install uv and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
@@ -91,8 +93,10 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
submodules: recursive
|
||||
fetch-depth: 1
|
||||
token: ${{ secrets.GH_PAT != '' && secrets.GH_PAT || secrets.GITHUB_TOKEN }}
|
||||
ref: ${{ inputs.ref }}
|
||||
|
||||
- name: Install uv and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
5
.github/workflows/link-check.yml
vendored
5
.github/workflows/link-check.yml
vendored
@@ -12,8 +12,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
token: ${{ secrets.GH_PAT != '' && secrets.GH_PAT || secrets.GITHUB_TOKEN }}
|
||||
- uses: lycheeverse/lychee-action@v2
|
||||
with:
|
||||
args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
|
||||
args: --no-progress --insecure --user-agent 'curl/7.68.0' --exclude '.*api\.star-history\.com.*' --accept 200,201,202,203,204,205,206,207,208,226,300,301,302,303,304,305,306,307,308,503 README.md docs/ apps/ examples/ benchmarks/
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
5
.github/workflows/release-manual.yml
vendored
5
.github/workflows/release-manual.yml
vendored
@@ -19,6 +19,9 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
token: ${{ secrets.GH_PAT != '' && secrets.GH_PAT || secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Validate version
|
||||
run: |
|
||||
@@ -73,6 +76,8 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
token: ${{ secrets.GH_PAT != '' && secrets.GH_PAT || secrets.GITHUB_TOKEN }}
|
||||
ref: 'main'
|
||||
|
||||
- name: Download all artifacts
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -91,7 +91,8 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
||||
|
||||
*.meta.json
|
||||
*.passages.json
|
||||
|
||||
*.npy
|
||||
*.db
|
||||
batchtest.py
|
||||
tests/__pytest_cache__/
|
||||
tests/__pycache__/
|
||||
|
||||
75
README.md
75
README.md
@@ -16,15 +16,27 @@
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
<a href="https://forms.gle/rDbZf864gMNxhpTq8">
|
||||
<img src="https://img.shields.io/badge/📣_Community_Survey-Help_Shape_v0.4-007ec6?style=for-the-badge&logo=google-forms&logoColor=white" alt="Take Survey">
|
||||
</a>
|
||||
<p>
|
||||
We track <b>zero telemetry</b>. This survey is the ONLY way to tell us if you want <br>
|
||||
<b>GPU Acceleration</b> or <b>More Integrations</b> next.<br>
|
||||
👉 <a href="https://forms.gle/rDbZf864gMNxhpTq8"><b>Click here to cast your vote (2 mins)</b></a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
||||
The smallest vector index in the world. RAG Everything with LEANN!
|
||||
</h2>
|
||||
|
||||
LEANN is an innovative vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
||||
|
||||
|
||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#mcp-integration-rag-on-live-data-from-any-platform), [Twitter](#mcp-integration-rag-on-live-data-from-any-platform)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#slack-messages-search-your-team-conversations), [Twitter](#-twitter-bookmarks-your-personal-tweet-library)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
|
||||
|
||||
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
|
||||
@@ -189,7 +201,7 @@ LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`,
|
||||
|
||||
#### LLM Backend
|
||||
|
||||
LEANN supports many LLM providers for text generation (HuggingFace, Ollama, and Any OpenAI compatible API).
|
||||
LEANN supports many LLM providers for text generation (HuggingFace, Ollama, Anthropic, and Any OpenAI compatible API).
|
||||
|
||||
|
||||
<details>
|
||||
@@ -257,6 +269,7 @@ Below is a list of base URLs for common providers to get you started.
|
||||
| **SiliconFlow** | `https://api.siliconflow.cn/v1` |
|
||||
| **Zhipu (BigModel)** | `https://open.bigmodel.cn/api/paas/v4/` |
|
||||
| **Mistral AI** | `https://api.mistral.ai/v1` |
|
||||
| **Anthropic** | `https://api.anthropic.com/v1` |
|
||||
|
||||
|
||||
|
||||
@@ -316,7 +329,7 @@ All RAG examples share these common parameters. **Interactive mode** is availabl
|
||||
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
|
||||
|
||||
# LLM Parameters (Text generation models)
|
||||
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
|
||||
--llm TYPE # LLM backend: openai, ollama, hf, or anthropic (default: openai)
|
||||
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
|
||||
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
|
||||
|
||||
@@ -379,6 +392,54 @@ python -m apps.code_rag --repo-dir "./my_codebase" --query "How does authenticat
|
||||
|
||||
</details>
|
||||
|
||||
### 🎨 ColQwen: Multimodal PDF Retrieval with Vision-Language Models
|
||||
|
||||
Search through PDFs using both text and visual understanding with ColQwen2/ColPali models. Perfect for research papers, technical documents, and any PDFs with complex layouts, figures, or diagrams.
|
||||
|
||||
> **🍎 Mac Users**: ColQwen is optimized for Apple Silicon with MPS acceleration for faster inference!
|
||||
|
||||
```bash
|
||||
# Build index from PDFs
|
||||
python -m apps.colqwen_rag build --pdfs ./my_papers/ --index research_papers
|
||||
|
||||
# Search with text queries
|
||||
python -m apps.colqwen_rag search research_papers "How does attention mechanism work?"
|
||||
|
||||
# Interactive Q&A
|
||||
python -m apps.colqwen_rag ask research_papers --interactive
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: ColQwen Setup & Usage</strong></summary>
|
||||
|
||||
#### Prerequisites
|
||||
```bash
|
||||
# Install dependencies
|
||||
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
|
||||
brew install poppler # macOS only, for PDF processing
|
||||
```
|
||||
|
||||
#### Build Index
|
||||
```bash
|
||||
python -m apps.colqwen_rag build \
|
||||
--pdfs ./pdf_directory/ \
|
||||
--index my_index \
|
||||
--model colqwen2 # or colpali
|
||||
```
|
||||
|
||||
#### Search
|
||||
```bash
|
||||
python -m apps.colqwen_rag search my_index "your question here" --top-k 5
|
||||
```
|
||||
|
||||
#### Models
|
||||
- **ColQwen2** (`colqwen2`): Latest vision-language model with improved performance
|
||||
- **ColPali** (`colpali`): Proven multimodal retriever
|
||||
|
||||
For detailed usage, see the [ColQwen Guide](docs/COLQWEN_GUIDE.md).
|
||||
|
||||
</details>
|
||||
|
||||
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
|
||||
|
||||
> **Note:** The examples below currently support macOS only. Windows support coming soon.
|
||||
@@ -1045,10 +1106,10 @@ Options:
|
||||
leann ask INDEX_NAME [OPTIONS]
|
||||
|
||||
Options:
|
||||
--llm {ollama,openai,hf} LLM provider (default: ollama)
|
||||
--model MODEL Model name (default: qwen3:8b)
|
||||
--interactive Interactive chat mode
|
||||
--top-k N Retrieval count (default: 20)
|
||||
--llm {ollama,openai,hf,anthropic} LLM provider (default: ollama)
|
||||
--model MODEL Model name (default: qwen3:8b)
|
||||
--interactive Interactive chat mode
|
||||
--top-k N Retrieval count (default: 20)
|
||||
```
|
||||
|
||||
**List Command:**
|
||||
|
||||
@@ -6,7 +6,7 @@ Provides common parameters and functionality for all RAG examples.
|
||||
import argparse
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Union
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
@@ -257,8 +257,8 @@ class BaseRAGExample(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load data from the source. Returns list of text chunks."""
|
||||
async def load_data(self, args) -> list[Union[str, dict[str, Any]]]:
|
||||
"""Load data from the source. Returns list of text chunks (strings or dicts with 'text' key)."""
|
||||
pass
|
||||
|
||||
def get_llm_config(self, args) -> dict[str, Any]:
|
||||
@@ -282,8 +282,8 @@ class BaseRAGExample(ABC):
|
||||
|
||||
return config
|
||||
|
||||
async def build_index(self, args, texts: list[str]) -> str:
|
||||
"""Build LEANN index from texts."""
|
||||
async def build_index(self, args, texts: list[Union[str, dict[str, Any]]]) -> str:
|
||||
"""Build LEANN index from texts (accepts strings or dicts with 'text' key)."""
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
|
||||
print(f"\n[Building Index] Creating {self.name} index...")
|
||||
@@ -314,8 +314,14 @@ class BaseRAGExample(ABC):
|
||||
batch_size = 1000
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
for text in batch:
|
||||
builder.add_text(text)
|
||||
for item in batch:
|
||||
# Handle both dict format (from create_text_chunks) and plain strings
|
||||
if isinstance(item, dict):
|
||||
text = item.get("text", "")
|
||||
metadata = item.get("metadata")
|
||||
builder.add_text(text, metadata)
|
||||
else:
|
||||
builder.add_text(item)
|
||||
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
|
||||
|
||||
print("Building index structure...")
|
||||
|
||||
364
apps/colqwen_rag.py
Normal file
364
apps/colqwen_rag.py
Normal file
@@ -0,0 +1,364 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
ColQwen RAG - Easy-to-use multimodal PDF retrieval with ColQwen2/ColPali
|
||||
|
||||
Usage:
|
||||
python -m apps.colqwen_rag build --pdfs ./my_pdfs/ --index my_index
|
||||
python -m apps.colqwen_rag search my_index "How does attention work?"
|
||||
python -m apps.colqwen_rag ask my_index --interactive
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, cast
|
||||
|
||||
# Add LEANN packages to path
|
||||
_repo_root = Path(__file__).resolve().parents[1]
|
||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||
if str(_leann_core_src) not in sys.path:
|
||||
sys.path.append(str(_leann_core_src))
|
||||
if str(_leann_hnsw_pkg) not in sys.path:
|
||||
sys.path.append(str(_leann_hnsw_pkg))
|
||||
|
||||
import torch # noqa: E402
|
||||
from colpali_engine import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor # noqa: E402
|
||||
from colpali_engine.utils.torch_utils import ListDataset # noqa: E402
|
||||
from pdf2image import convert_from_path # noqa: E402
|
||||
from PIL import Image # noqa: E402
|
||||
from torch.utils.data import DataLoader # noqa: E402
|
||||
from tqdm import tqdm # noqa: E402
|
||||
|
||||
# Import the existing multi-vector implementation
|
||||
sys.path.append(str(_repo_root / "apps" / "multimodal" / "vision-based-pdf-multi-vector"))
|
||||
from leann_multi_vector import LeannMultiVector # noqa: E402
|
||||
|
||||
|
||||
class ColQwenRAG:
|
||||
"""Easy-to-use ColQwen RAG system for multimodal PDF retrieval."""
|
||||
|
||||
def __init__(self, model_type: str = "colpali"):
|
||||
"""
|
||||
Initialize ColQwen RAG system.
|
||||
|
||||
Args:
|
||||
model_type: "colqwen2" or "colpali"
|
||||
"""
|
||||
self.model_type = model_type
|
||||
self.device = self._get_device()
|
||||
# Use float32 on MPS to avoid memory issues, float16 on CUDA, bfloat16 on CPU
|
||||
if self.device.type == "mps":
|
||||
self.dtype = torch.float32
|
||||
elif self.device.type == "cuda":
|
||||
self.dtype = torch.float16
|
||||
else:
|
||||
self.dtype = torch.bfloat16
|
||||
|
||||
print(f"🚀 Initializing {model_type.upper()} on {self.device} with {self.dtype}")
|
||||
|
||||
# Load model and processor with MPS-optimized settings
|
||||
try:
|
||||
if model_type == "colqwen2":
|
||||
self.model_name = "vidore/colqwen2-v1.0"
|
||||
if self.device.type == "mps":
|
||||
# For MPS, load on CPU first then move to avoid memory allocation issues
|
||||
self.model = ColQwen2.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.model = self.model.to(self.device)
|
||||
else:
|
||||
self.model = ColQwen2.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map=self.device,
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.processor = ColQwen2Processor.from_pretrained(self.model_name)
|
||||
else: # colpali
|
||||
self.model_name = "vidore/colpali-v1.2"
|
||||
if self.device.type == "mps":
|
||||
# For MPS, load on CPU first then move to avoid memory allocation issues
|
||||
self.model = ColPali.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.model = self.model.to(self.device)
|
||||
else:
|
||||
self.model = ColPali.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map=self.device,
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.processor = ColPaliProcessor.from_pretrained(self.model_name)
|
||||
except Exception as e:
|
||||
if "memory" in str(e).lower() or "offload" in str(e).lower():
|
||||
print(f"⚠️ Memory constraint on {self.device}, using CPU with optimizations...")
|
||||
self.device = torch.device("cpu")
|
||||
self.dtype = torch.float32
|
||||
|
||||
if model_type == "colqwen2":
|
||||
self.model = ColQwen2.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
else:
|
||||
self.model = ColPali.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
else:
|
||||
raise
|
||||
|
||||
def _get_device(self):
|
||||
"""Auto-select best available device."""
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def build_index(self, pdf_paths: list[str], index_name: str, pages_dir: Optional[str] = None):
|
||||
"""
|
||||
Build multimodal index from PDF files.
|
||||
|
||||
Args:
|
||||
pdf_paths: List of PDF file paths
|
||||
index_name: Name for the index
|
||||
pages_dir: Directory to save page images (optional)
|
||||
"""
|
||||
print(f"Building index '{index_name}' from {len(pdf_paths)} PDFs...")
|
||||
|
||||
# Convert PDFs to images
|
||||
all_images = []
|
||||
all_metadata = []
|
||||
|
||||
if pages_dir:
|
||||
os.makedirs(pages_dir, exist_ok=True)
|
||||
|
||||
for pdf_path in tqdm(pdf_paths, desc="Converting PDFs"):
|
||||
try:
|
||||
images = convert_from_path(pdf_path, dpi=150)
|
||||
pdf_name = Path(pdf_path).stem
|
||||
|
||||
for i, image in enumerate(images):
|
||||
# Save image if pages_dir specified
|
||||
if pages_dir:
|
||||
image_path = Path(pages_dir) / f"{pdf_name}_page_{i + 1}.png"
|
||||
image.save(image_path)
|
||||
|
||||
all_images.append(image)
|
||||
all_metadata.append(
|
||||
{
|
||||
"pdf_path": pdf_path,
|
||||
"pdf_name": pdf_name,
|
||||
"page_number": i + 1,
|
||||
"image_path": str(image_path) if pages_dir else None,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error processing {pdf_path}: {e}")
|
||||
continue
|
||||
|
||||
print(f"📄 Converted {len(all_images)} pages from {len(pdf_paths)} PDFs")
|
||||
print(f"All metadata: {all_metadata}")
|
||||
|
||||
# Generate embeddings
|
||||
print("🧠 Generating embeddings...")
|
||||
embeddings = self._embed_images(all_images)
|
||||
|
||||
# Build LEANN index
|
||||
print("🔍 Building LEANN index...")
|
||||
leann_mv = LeannMultiVector(
|
||||
index_path=index_name,
|
||||
dim=embeddings.shape[-1],
|
||||
embedding_model_name=self.model_type,
|
||||
)
|
||||
|
||||
# Create collection and insert data
|
||||
leann_mv.create_collection()
|
||||
for i, (embedding, metadata) in enumerate(zip(embeddings, all_metadata)):
|
||||
data = {
|
||||
"doc_id": i,
|
||||
"filepath": metadata.get("image_path", ""),
|
||||
"colbert_vecs": embedding.numpy(), # Convert tensor to numpy
|
||||
}
|
||||
leann_mv.insert(data)
|
||||
|
||||
# Build the index
|
||||
leann_mv.create_index()
|
||||
print(f"✅ Index '{index_name}' built successfully!")
|
||||
|
||||
return leann_mv
|
||||
|
||||
def search(self, index_name: str, query: str, top_k: int = 5):
|
||||
"""
|
||||
Search the index with a text query.
|
||||
|
||||
Args:
|
||||
index_name: Name of the index to search
|
||||
query: Text query
|
||||
top_k: Number of results to return
|
||||
"""
|
||||
print(f"🔍 Searching '{index_name}' for: '{query}'")
|
||||
|
||||
# Load index
|
||||
leann_mv = LeannMultiVector(
|
||||
index_path=index_name,
|
||||
dim=128, # Will be updated when loading
|
||||
embedding_model_name=self.model_type,
|
||||
)
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = self._embed_query(query)
|
||||
|
||||
# Search (returns list of (score, doc_id) tuples)
|
||||
search_results = leann_mv.search(query_embedding.numpy(), topk=top_k)
|
||||
|
||||
# Display results
|
||||
print(f"\n📋 Top {len(search_results)} results:")
|
||||
for i, (score, doc_id) in enumerate(search_results, 1):
|
||||
# Get metadata for this doc_id (we need to load the metadata)
|
||||
print(f"{i}. Score: {score:.3f} | Doc ID: {doc_id}")
|
||||
|
||||
return search_results
|
||||
|
||||
def ask(self, index_name: str, interactive: bool = False):
|
||||
"""
|
||||
Interactive Q&A with the indexed documents.
|
||||
|
||||
Args:
|
||||
index_name: Name of the index to query
|
||||
interactive: Whether to run in interactive mode
|
||||
"""
|
||||
print(f"💬 ColQwen Chat with '{index_name}'")
|
||||
|
||||
if interactive:
|
||||
print("Type 'quit' to exit, 'help' for commands")
|
||||
while True:
|
||||
try:
|
||||
query = input("\n🤔 Your question: ").strip()
|
||||
if query.lower() in ["quit", "exit", "q"]:
|
||||
break
|
||||
elif query.lower() == "help":
|
||||
print("Commands: quit/exit/q (exit), help (this message)")
|
||||
continue
|
||||
elif not query:
|
||||
continue
|
||||
|
||||
self.search(index_name, query, top_k=3)
|
||||
|
||||
# TODO: Add answer generation with Qwen-VL
|
||||
print("\n💡 For detailed answers, we can integrate Qwen-VL here!")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Goodbye!")
|
||||
break
|
||||
else:
|
||||
query = input("🤔 Your question: ").strip()
|
||||
if query:
|
||||
self.search(index_name, query)
|
||||
|
||||
def _embed_images(self, images: list[Image.Image]) -> torch.Tensor:
|
||||
"""Generate embeddings for a list of images."""
|
||||
dataset = ListDataset(images)
|
||||
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x)
|
||||
|
||||
embeddings = []
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataloader, desc="Embedding images"):
|
||||
batch_images = cast(list, batch)
|
||||
batch_inputs = self.processor.process_images(batch_images).to(self.device)
|
||||
batch_embeddings = self.model(**batch_inputs)
|
||||
embeddings.append(batch_embeddings.cpu())
|
||||
|
||||
return torch.cat(embeddings, dim=0)
|
||||
|
||||
def _embed_query(self, query: str) -> torch.Tensor:
|
||||
"""Generate embedding for a text query."""
|
||||
with torch.no_grad():
|
||||
query_inputs = self.processor.process_queries([query]).to(self.device)
|
||||
query_embedding = self.model(**query_inputs)
|
||||
return query_embedding.cpu()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ColQwen RAG - Easy multimodal PDF retrieval")
|
||||
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
||||
|
||||
# Build command
|
||||
build_parser = subparsers.add_parser("build", help="Build index from PDFs")
|
||||
build_parser.add_argument("--pdfs", required=True, help="Directory containing PDF files")
|
||||
build_parser.add_argument("--index", required=True, help="Index name")
|
||||
build_parser.add_argument(
|
||||
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||
)
|
||||
build_parser.add_argument("--pages-dir", help="Directory to save page images")
|
||||
|
||||
# Search command
|
||||
search_parser = subparsers.add_parser("search", help="Search the index")
|
||||
search_parser.add_argument("index", help="Index name")
|
||||
search_parser.add_argument("query", help="Search query")
|
||||
search_parser.add_argument("--top-k", type=int, default=5, help="Number of results")
|
||||
search_parser.add_argument(
|
||||
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||
)
|
||||
|
||||
# Ask command
|
||||
ask_parser = subparsers.add_parser("ask", help="Interactive Q&A")
|
||||
ask_parser.add_argument("index", help="Index name")
|
||||
ask_parser.add_argument("--interactive", action="store_true", help="Interactive mode")
|
||||
ask_parser.add_argument(
|
||||
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.command:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
# Initialize ColQwen RAG
|
||||
if args.command == "build":
|
||||
colqwen = ColQwenRAG(args.model)
|
||||
|
||||
# Get PDF files
|
||||
pdf_dir = Path(args.pdfs)
|
||||
if pdf_dir.is_file() and pdf_dir.suffix.lower() == ".pdf":
|
||||
pdf_paths = [str(pdf_dir)]
|
||||
elif pdf_dir.is_dir():
|
||||
pdf_paths = [str(p) for p in pdf_dir.glob("*.pdf")]
|
||||
else:
|
||||
print(f"❌ Invalid PDF path: {args.pdfs}")
|
||||
return
|
||||
|
||||
if not pdf_paths:
|
||||
print(f"❌ No PDF files found in {args.pdfs}")
|
||||
return
|
||||
|
||||
colqwen.build_index(pdf_paths, args.index, args.pages_dir)
|
||||
|
||||
elif args.command == "search":
|
||||
colqwen = ColQwenRAG(args.model)
|
||||
colqwen.search(args.index, args.query, args.top_k)
|
||||
|
||||
elif args.command == "ask":
|
||||
colqwen = ColQwenRAG(args.model)
|
||||
colqwen.ask(args.index, args.interactive)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -5,6 +5,7 @@ Supports PDF, TXT, MD, and other document formats.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
@@ -51,7 +52,7 @@ class DocumentRAG(BaseRAGExample):
|
||||
help="Enable AST-aware chunking for code files in the data directory",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
async def load_data(self, args) -> list[Union[str, dict[str, Any]]]:
|
||||
"""Load documents and convert to text chunks."""
|
||||
print(f"Loading documents from: {args.data_dir}")
|
||||
if args.file_types:
|
||||
|
||||
218
apps/image_rag.py
Normal file
218
apps/image_rag.py
Normal file
@@ -0,0 +1,218 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
CLIP Image RAG Application
|
||||
|
||||
This application enables RAG (Retrieval-Augmented Generation) on images using CLIP embeddings.
|
||||
You can index a directory of images and search them using text queries.
|
||||
|
||||
Usage:
|
||||
python -m apps.image_rag --image-dir ./my_images/ --query "a sunset over mountains"
|
||||
python -m apps.image_rag --image-dir ./my_images/ --interactive
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pickle
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
|
||||
from apps.base_rag_example import BaseRAGExample
|
||||
|
||||
|
||||
class ImageRAG(BaseRAGExample):
|
||||
"""
|
||||
RAG application for images using CLIP embeddings.
|
||||
|
||||
This class provides a complete RAG pipeline for image data, including
|
||||
CLIP embedding generation, indexing, and text-based image search.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Image RAG",
|
||||
description="RAG application for images using CLIP embeddings",
|
||||
default_index_name="image_index",
|
||||
)
|
||||
# Override default embedding model to use CLIP
|
||||
self.embedding_model_default = "clip-ViT-L-14"
|
||||
self.embedding_mode_default = "sentence-transformers"
|
||||
self._image_data: list[dict] = []
|
||||
|
||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||
"""Add image-specific arguments."""
|
||||
image_group = parser.add_argument_group("Image Parameters")
|
||||
image_group.add_argument(
|
||||
"--image-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Directory containing images to index",
|
||||
)
|
||||
image_group.add_argument(
|
||||
"--image-extensions",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=[".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"],
|
||||
help="Image file extensions to process (default: .jpg .jpeg .png .gif .bmp .webp)",
|
||||
)
|
||||
image_group.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Batch size for CLIP embedding generation (default: 32)",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load images, generate CLIP embeddings, and return text descriptions."""
|
||||
self._image_data = self._load_images_and_embeddings(args)
|
||||
return [entry["text"] for entry in self._image_data]
|
||||
|
||||
def _load_images_and_embeddings(self, args) -> list[dict]:
|
||||
"""Helper to process images and produce embeddings/metadata."""
|
||||
image_dir = Path(args.image_dir)
|
||||
if not image_dir.exists():
|
||||
raise ValueError(f"Image directory does not exist: {image_dir}")
|
||||
|
||||
print(f"📸 Loading images from {image_dir}...")
|
||||
|
||||
# Find all image files
|
||||
image_files = []
|
||||
for ext in args.image_extensions:
|
||||
image_files.extend(image_dir.rglob(f"*{ext}"))
|
||||
image_files.extend(image_dir.rglob(f"*{ext.upper()}"))
|
||||
|
||||
if not image_files:
|
||||
raise ValueError(
|
||||
f"No images found in {image_dir} with extensions {args.image_extensions}"
|
||||
)
|
||||
|
||||
print(f"✅ Found {len(image_files)} images")
|
||||
|
||||
# Limit if max_items is set
|
||||
if args.max_items > 0:
|
||||
image_files = image_files[: args.max_items]
|
||||
print(f"📊 Processing {len(image_files)} images (limited by --max-items)")
|
||||
|
||||
# Load CLIP model
|
||||
print("🔍 Loading CLIP model...")
|
||||
model = SentenceTransformer(self.embedding_model_default)
|
||||
|
||||
# Process images and generate embeddings
|
||||
print("🖼️ Processing images and generating embeddings...")
|
||||
image_data = []
|
||||
batch_images = []
|
||||
batch_paths = []
|
||||
|
||||
for image_path in tqdm(image_files, desc="Processing images"):
|
||||
try:
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
batch_images.append(image)
|
||||
batch_paths.append(image_path)
|
||||
|
||||
# Process in batches
|
||||
if len(batch_images) >= args.batch_size:
|
||||
embeddings = model.encode(
|
||||
batch_images,
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=True,
|
||||
batch_size=args.batch_size,
|
||||
show_progress_bar=False,
|
||||
)
|
||||
|
||||
for img_path, embedding in zip(batch_paths, embeddings):
|
||||
image_data.append(
|
||||
{
|
||||
"text": f"Image: {img_path.name}\nPath: {img_path}",
|
||||
"metadata": {
|
||||
"image_path": str(img_path),
|
||||
"image_name": img_path.name,
|
||||
"image_dir": str(image_dir),
|
||||
},
|
||||
"embedding": embedding.astype(np.float32),
|
||||
}
|
||||
)
|
||||
|
||||
batch_images = []
|
||||
batch_paths = []
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to process {image_path}: {e}")
|
||||
continue
|
||||
|
||||
# Process remaining images
|
||||
if batch_images:
|
||||
embeddings = model.encode(
|
||||
batch_images,
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=True,
|
||||
batch_size=len(batch_images),
|
||||
show_progress_bar=False,
|
||||
)
|
||||
|
||||
for img_path, embedding in zip(batch_paths, embeddings):
|
||||
image_data.append(
|
||||
{
|
||||
"text": f"Image: {img_path.name}\nPath: {img_path}",
|
||||
"metadata": {
|
||||
"image_path": str(img_path),
|
||||
"image_name": img_path.name,
|
||||
"image_dir": str(image_dir),
|
||||
},
|
||||
"embedding": embedding.astype(np.float32),
|
||||
}
|
||||
)
|
||||
|
||||
print(f"✅ Processed {len(image_data)} images")
|
||||
return image_data
|
||||
|
||||
async def build_index(self, args, texts: list[str]) -> str:
|
||||
"""Build index using pre-computed CLIP embeddings."""
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
if not self._image_data or len(self._image_data) != len(texts):
|
||||
raise RuntimeError("No image data found. Make sure load_data() ran successfully.")
|
||||
|
||||
print("🔨 Building LEANN index with CLIP embeddings...")
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend_name,
|
||||
embedding_model=self.embedding_model_default,
|
||||
embedding_mode=self.embedding_mode_default,
|
||||
is_recompute=False,
|
||||
distance_metric="cosine",
|
||||
graph_degree=args.graph_degree,
|
||||
build_complexity=args.build_complexity,
|
||||
is_compact=not args.no_compact,
|
||||
)
|
||||
|
||||
for text, data in zip(texts, self._image_data):
|
||||
builder.add_text(text=text, metadata=data["metadata"])
|
||||
|
||||
ids = [str(i) for i in range(len(self._image_data))]
|
||||
embeddings = np.array([data["embedding"] for data in self._image_data], dtype=np.float32)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="wb", suffix=".pkl", delete=False) as f:
|
||||
pickle.dump((ids, embeddings), f)
|
||||
pkl_path = f.name
|
||||
|
||||
try:
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
builder.build_index_from_embeddings(index_path, pkl_path)
|
||||
print(f"✅ Index built successfully at {index_path}")
|
||||
return index_path
|
||||
finally:
|
||||
Path(pkl_path).unlink()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the image RAG application."""
|
||||
import asyncio
|
||||
|
||||
app = ImageRAG()
|
||||
asyncio.run(app.run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
132
apps/multimodal/vision-based-pdf-multi-vector/colqwen_forward.py
Executable file
132
apps/multimodal/vision-based-pdf-multi-vector/colqwen_forward.py
Executable file
@@ -0,0 +1,132 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Simple test script to test colqwen2 forward pass with a single image."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the current directory to path to import leann_multi_vector
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
import torch
|
||||
from leann_multi_vector import _embed_images, _ensure_repo_paths_importable, _load_colvision
|
||||
from PIL import Image
|
||||
|
||||
# Ensure repo paths are importable
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# Set environment variable
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
def create_test_image():
|
||||
"""Create a simple test image."""
|
||||
# Create a simple RGB image (800x600)
|
||||
img = Image.new("RGB", (800, 600), color="white")
|
||||
return img
|
||||
|
||||
|
||||
def load_test_image_from_file():
|
||||
"""Try to load an image from the indexes directory if available."""
|
||||
# Try to find an existing image in the indexes directory
|
||||
indexes_dir = Path(__file__).parent / "indexes"
|
||||
|
||||
# Look for images in common locations
|
||||
possible_paths = [
|
||||
indexes_dir / "vidore_fastplaid" / "images",
|
||||
indexes_dir / "colvision_large.leann.images",
|
||||
indexes_dir / "colvision.leann.images",
|
||||
]
|
||||
|
||||
for img_dir in possible_paths:
|
||||
if img_dir.exists():
|
||||
# Find first image file
|
||||
for ext in [".png", ".jpg", ".jpeg"]:
|
||||
for img_file in img_dir.glob(f"*{ext}"):
|
||||
print(f"Loading test image from: {img_file}")
|
||||
return Image.open(img_file)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("Testing ColQwen2 Forward Pass")
|
||||
print("=" * 60)
|
||||
|
||||
# Step 1: Load or create test image
|
||||
print("\n[Step 1] Loading test image...")
|
||||
test_image = load_test_image_from_file()
|
||||
if test_image is None:
|
||||
print("No existing image found, creating a simple test image...")
|
||||
test_image = create_test_image()
|
||||
else:
|
||||
print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")
|
||||
|
||||
# Convert to RGB if needed
|
||||
if test_image.mode != "RGB":
|
||||
test_image = test_image.convert("RGB")
|
||||
print(f"✓ Converted to RGB: {test_image.size}")
|
||||
|
||||
# Step 2: Load model
|
||||
print("\n[Step 2] Loading ColQwen2 model...")
|
||||
try:
|
||||
model_name, model, processor, device_str, device, dtype = _load_colvision("colqwen2")
|
||||
print(f"✓ Model loaded: {model_name}")
|
||||
print(f"✓ Device: {device_str}, dtype: {dtype}")
|
||||
|
||||
# Print model info
|
||||
if hasattr(model, "device"):
|
||||
print(f"✓ Model device: {model.device}")
|
||||
if hasattr(model, "dtype"):
|
||||
print(f"✓ Model dtype: {model.dtype}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error loading model: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
# Step 3: Test forward pass
|
||||
print("\n[Step 3] Running forward pass...")
|
||||
try:
|
||||
# Use the _embed_images function which handles batching and forward pass
|
||||
images = [test_image]
|
||||
print(f"Processing {len(images)} image(s)...")
|
||||
|
||||
doc_vecs = _embed_images(model, processor, images)
|
||||
|
||||
print("✓ Forward pass completed!")
|
||||
print(f"✓ Number of embeddings: {len(doc_vecs)}")
|
||||
|
||||
if len(doc_vecs) > 0:
|
||||
emb = doc_vecs[0]
|
||||
print(f"✓ Embedding shape: {emb.shape}")
|
||||
print(f"✓ Embedding dtype: {emb.dtype}")
|
||||
print("✓ Embedding stats:")
|
||||
print(f" - Min: {emb.min().item():.4f}")
|
||||
print(f" - Max: {emb.max().item():.4f}")
|
||||
print(f" - Mean: {emb.mean().item():.4f}")
|
||||
print(f" - Std: {emb.std().item():.4f}")
|
||||
|
||||
# Check for NaN or Inf
|
||||
if torch.isnan(emb).any():
|
||||
print("⚠ Warning: Embedding contains NaN values!")
|
||||
if torch.isinf(emb).any():
|
||||
print("⚠ Warning: Embedding contains Inf values!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error during forward pass: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Test completed successfully!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,8 +1,11 @@
|
||||
import concurrent.futures
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
@@ -10,6 +13,8 @@ import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
|
||||
@@ -95,12 +100,63 @@ def _natural_sort_key(name: str) -> int:
|
||||
return int(m.group()) if m else 0
|
||||
|
||||
|
||||
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
|
||||
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
|
||||
filenames = sorted(filenames, key=_natural_sort_key)
|
||||
filepaths = [os.path.join(pages_dir, n) for n in filenames]
|
||||
images = [Image.open(p) for p in filepaths]
|
||||
return filepaths, images
|
||||
def _load_images_from_dir(
|
||||
pages_dir: str, recursive: bool = False
|
||||
) -> tuple[list[str], list[Image.Image]]:
|
||||
"""
|
||||
Load images from a directory.
|
||||
|
||||
Args:
|
||||
pages_dir: Directory path containing images
|
||||
recursive: If True, recursively search subdirectories (default: False)
|
||||
|
||||
Returns:
|
||||
Tuple of (filepaths, images)
|
||||
"""
|
||||
|
||||
# Supported image extensions
|
||||
extensions = ("*.png", "*.jpg", "*.jpeg", "*.PNG", "*.JPG", "*.JPEG", "*.webp", "*.WEBP")
|
||||
|
||||
if recursive:
|
||||
# Recursive search
|
||||
filepaths = []
|
||||
for ext in extensions:
|
||||
pattern = os.path.join(pages_dir, "**", ext)
|
||||
filepaths.extend(glob.glob(pattern, recursive=True))
|
||||
else:
|
||||
# Non-recursive search (only top-level directory)
|
||||
filepaths = []
|
||||
for ext in extensions:
|
||||
pattern = os.path.join(pages_dir, ext)
|
||||
filepaths.extend(glob.glob(pattern))
|
||||
|
||||
# Sort files naturally
|
||||
filepaths = sorted(filepaths, key=lambda x: _natural_sort_key(os.path.basename(x)))
|
||||
|
||||
# Load images with error handling
|
||||
images = []
|
||||
valid_filepaths = []
|
||||
failed_count = 0
|
||||
|
||||
for filepath in filepaths:
|
||||
try:
|
||||
img = Image.open(filepath)
|
||||
# Convert to RGB if necessary (handles RGBA, P, etc.)
|
||||
if img.mode != "RGB":
|
||||
img = img.convert("RGB")
|
||||
images.append(img)
|
||||
valid_filepaths.append(filepath)
|
||||
except Exception as e:
|
||||
failed_count += 1
|
||||
print(f"Warning: Failed to load image {filepath}: {e}")
|
||||
continue
|
||||
|
||||
if failed_count > 0:
|
||||
print(
|
||||
f"Warning: Failed to load {failed_count} image(s) out of {len(filepaths)} total files"
|
||||
)
|
||||
|
||||
return valid_filepaths, images
|
||||
|
||||
|
||||
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
|
||||
@@ -150,36 +206,99 @@ def _select_device_and_dtype():
|
||||
|
||||
|
||||
def _load_colvision(model_choice: str):
|
||||
import os
|
||||
|
||||
import torch
|
||||
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
|
||||
from colpali_engine.models import (
|
||||
ColPali,
|
||||
ColQwen2,
|
||||
ColQwen2_5,
|
||||
ColQwen2_5_Processor,
|
||||
ColQwen2Processor,
|
||||
)
|
||||
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||
|
||||
# Force HuggingFace Hub to use HF endpoint, avoid Google Drive
|
||||
# Set environment variables to ensure models are downloaded from HuggingFace
|
||||
os.environ.setdefault("HF_ENDPOINT", "https://huggingface.co")
|
||||
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
||||
|
||||
# Log model loading info
|
||||
logger.info(f"Loading ColVision model: {model_choice}")
|
||||
logger.info(f"HF_ENDPOINT: {os.environ.get('HF_ENDPOINT', 'not set')}")
|
||||
logger.info("Models will be downloaded from HuggingFace Hub, not Google Drive")
|
||||
|
||||
device_str, device, dtype = _select_device_and_dtype()
|
||||
|
||||
# Determine model name and type
|
||||
# IMPORTANT: Check colqwen2.5 BEFORE colqwen2 to avoid false matches
|
||||
model_choice_lower = model_choice.lower()
|
||||
if model_choice == "colqwen2":
|
||||
model_name = "vidore/colqwen2-v1.0"
|
||||
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
|
||||
attn_implementation = (
|
||||
"flash_attention_2"
|
||||
if (device_str == "cuda" and is_flash_attn_2_available())
|
||||
else "eager"
|
||||
)
|
||||
model_type = "colqwen2"
|
||||
elif model_choice == "colqwen2.5" or model_choice == "colqwen25":
|
||||
model_name = "vidore/colqwen2.5-v0.2"
|
||||
model_type = "colqwen2.5"
|
||||
elif model_choice == "colpali":
|
||||
model_name = "vidore/colpali-v1.2"
|
||||
model_type = "colpali"
|
||||
elif (
|
||||
"colqwen2.5" in model_choice_lower
|
||||
or "colqwen25" in model_choice_lower
|
||||
or "colqwen2_5" in model_choice_lower
|
||||
):
|
||||
# Handle HuggingFace model names like "vidore/colqwen2.5-v0.2"
|
||||
model_name = model_choice
|
||||
model_type = "colqwen2.5"
|
||||
elif "colqwen2" in model_choice_lower and "colqwen2-v1.0" in model_choice_lower:
|
||||
# Handle HuggingFace model names like "vidore/colqwen2-v1.0" (but not colqwen2.5)
|
||||
model_name = model_choice
|
||||
model_type = "colqwen2"
|
||||
elif "colpali" in model_choice_lower:
|
||||
# Handle HuggingFace model names like "vidore/colpali-v1.2"
|
||||
model_name = model_choice
|
||||
model_type = "colpali"
|
||||
else:
|
||||
# Default to colpali for backward compatibility
|
||||
model_name = "vidore/colpali-v1.2"
|
||||
model_type = "colpali"
|
||||
|
||||
# Load model based on type
|
||||
attn_implementation = (
|
||||
"flash_attention_2" if (device_str == "cuda" and is_flash_attn_2_available()) else "eager"
|
||||
)
|
||||
|
||||
# Load model from HuggingFace Hub (not Google Drive)
|
||||
# Use local_files_only=False to ensure download from HF if not cached
|
||||
if model_type == "colqwen2.5":
|
||||
model = ColQwen2_5.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=device,
|
||||
attn_implementation=attn_implementation,
|
||||
local_files_only=False, # Ensure download from HuggingFace Hub
|
||||
).eval()
|
||||
processor = ColQwen2_5_Processor.from_pretrained(model_name, local_files_only=False)
|
||||
elif model_type == "colqwen2":
|
||||
model = ColQwen2.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=device,
|
||||
attn_implementation=attn_implementation,
|
||||
local_files_only=False, # Ensure download from HuggingFace Hub
|
||||
).eval()
|
||||
processor = ColQwen2Processor.from_pretrained(model_name)
|
||||
else:
|
||||
model_name = "vidore/colpali-v1.2"
|
||||
processor = ColQwen2Processor.from_pretrained(model_name, local_files_only=False)
|
||||
else: # colpali
|
||||
model = ColPali.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=device,
|
||||
local_files_only=False, # Ensure download from HuggingFace Hub
|
||||
).eval()
|
||||
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||
processor = cast(
|
||||
ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name, local_files_only=False)
|
||||
)
|
||||
|
||||
return model_name, model, processor, device_str, device, dtype
|
||||
|
||||
@@ -194,7 +313,7 @@ def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[Image.Image](images),
|
||||
batch_size=1,
|
||||
batch_size=32,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_images(x),
|
||||
)
|
||||
@@ -218,32 +337,47 @@ def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
|
||||
|
||||
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
|
||||
import torch
|
||||
from colpali_engine.utils.torch_utils import ListDataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
model.eval()
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[str](queries),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_queries(x),
|
||||
)
|
||||
# Match MTEB's exact query processing from ColPaliEngineWrapper.get_text_embeddings:
|
||||
# 1. MTEB receives batch["text"] which already includes instruction/prompt (from _combine_queries_with_instruction_text)
|
||||
# 2. Manually adds: query_prefix + text + query_augmentation_token * 10
|
||||
# 3. Calls processor.process_queries(batch) where batch is now a list of strings
|
||||
# 4. process_queries adds: query_prefix + text + suffix (suffix = query_augmentation_token * 10)
|
||||
#
|
||||
# This results in duplicate addition: query_prefix is added twice, query_augmentation_token * 20 total
|
||||
# We need to match this exactly to reproduce MTEB results
|
||||
|
||||
all_embeds = []
|
||||
batch_size = 32 # Match MTEB's default batch_size
|
||||
|
||||
with torch.no_grad():
|
||||
for i in tqdm(range(0, len(queries), batch_size), desc="Embedding queries"):
|
||||
batch_queries = queries[i : i + batch_size]
|
||||
|
||||
# Match MTEB: manually add query_prefix + text + query_augmentation_token * 10
|
||||
# Then process_queries will add them again (resulting in 20 augmentation tokens total)
|
||||
batch = [
|
||||
processor.query_prefix + t + processor.query_augmentation_token * 10
|
||||
for t in batch_queries
|
||||
]
|
||||
inputs = processor.process_queries(batch)
|
||||
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
||||
|
||||
q_vecs: list[Any] = []
|
||||
for batch_query in tqdm(dataloader, desc="Embedding queries"):
|
||||
with torch.no_grad():
|
||||
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||
if model.device.type == "cuda":
|
||||
with torch.autocast(
|
||||
device_type="cuda",
|
||||
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||
):
|
||||
embeddings_query = model(**batch_query)
|
||||
outs = model(**inputs)
|
||||
else:
|
||||
embeddings_query = model(**batch_query)
|
||||
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||
return q_vecs
|
||||
outs = model(**inputs)
|
||||
|
||||
# Match MTEB: convert to float32 on CPU
|
||||
all_embeds.extend(list(torch.unbind(outs.cpu().to(torch.float32))))
|
||||
|
||||
return all_embeds
|
||||
|
||||
|
||||
def _build_index(
|
||||
@@ -283,6 +417,279 @@ def _load_retriever_if_index_exists(index_path: str) -> Optional[Any]:
|
||||
return None
|
||||
|
||||
|
||||
def _build_fast_plaid_index(
|
||||
index_path: str,
|
||||
doc_vecs: list[Any],
|
||||
filepaths: list[str],
|
||||
images: list[Image.Image],
|
||||
) -> tuple[Any, float]:
|
||||
"""
|
||||
Build a Fast-Plaid index from document embeddings.
|
||||
|
||||
Args:
|
||||
index_path: Path to save the Fast-Plaid index
|
||||
doc_vecs: List of document embeddings (each is a tensor with shape [num_tokens, embedding_dim])
|
||||
filepaths: List of filepath identifiers for each document
|
||||
images: List of PIL Images corresponding to each document
|
||||
|
||||
Returns:
|
||||
Tuple of (FastPlaid index object, build_time_in_seconds)
|
||||
"""
|
||||
import torch
|
||||
from fast_plaid import search as fast_plaid_search
|
||||
|
||||
print(f" Preparing {len(doc_vecs)} document embeddings for Fast-Plaid...")
|
||||
_t0 = time.perf_counter()
|
||||
|
||||
# Convert doc_vecs to list of tensors
|
||||
documents_embeddings = []
|
||||
for i, vec in enumerate(doc_vecs):
|
||||
if i % 1000 == 0:
|
||||
print(f" Converting embedding {i}/{len(doc_vecs)}...")
|
||||
if not isinstance(vec, torch.Tensor):
|
||||
vec = (
|
||||
torch.tensor(vec)
|
||||
if isinstance(vec, np.ndarray)
|
||||
else torch.from_numpy(np.array(vec))
|
||||
)
|
||||
# Ensure float32 for Fast-Plaid
|
||||
if vec.dtype != torch.float32:
|
||||
vec = vec.float()
|
||||
documents_embeddings.append(vec)
|
||||
|
||||
print(f" Converted {len(documents_embeddings)} embeddings")
|
||||
if len(documents_embeddings) > 0:
|
||||
print(f" First embedding shape: {documents_embeddings[0].shape}")
|
||||
print(f" First embedding dtype: {documents_embeddings[0].dtype}")
|
||||
|
||||
# Prepare metadata for Fast-Plaid
|
||||
print(f" Preparing metadata for {len(filepaths)} documents...")
|
||||
metadata_list = []
|
||||
for i, filepath in enumerate(filepaths):
|
||||
metadata_list.append(
|
||||
{
|
||||
"filepath": filepath,
|
||||
"index": i,
|
||||
}
|
||||
)
|
||||
|
||||
# Create Fast-Plaid index
|
||||
print(f" Creating FastPlaid object with index path: {index_path}")
|
||||
try:
|
||||
fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path)
|
||||
print(" FastPlaid object created successfully")
|
||||
except Exception as e:
|
||||
print(f" Error creating FastPlaid object: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
print(f" Calling fast_plaid_index.create() with {len(documents_embeddings)} documents...")
|
||||
try:
|
||||
fast_plaid_index.create(
|
||||
documents_embeddings=documents_embeddings,
|
||||
metadata=metadata_list,
|
||||
)
|
||||
print(" Fast-Plaid index created successfully")
|
||||
except Exception as e:
|
||||
print(f" Error creating Fast-Plaid index: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
build_secs = time.perf_counter() - _t0
|
||||
|
||||
# Save images separately (Fast-Plaid doesn't store images)
|
||||
print(f" Saving {len(images)} images...")
|
||||
images_dir = Path(index_path) / "images"
|
||||
images_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i, img in enumerate(tqdm(images, desc="Saving images")):
|
||||
img_path = images_dir / f"doc_{i}.png"
|
||||
img.save(str(img_path))
|
||||
|
||||
return fast_plaid_index, build_secs
|
||||
|
||||
|
||||
def _fast_plaid_index_exists(index_path: str) -> bool:
|
||||
"""
|
||||
Check if Fast-Plaid index exists by checking for key files.
|
||||
This avoids creating the FastPlaid object which may trigger memory allocation.
|
||||
|
||||
Args:
|
||||
index_path: Path to the Fast-Plaid index
|
||||
|
||||
Returns:
|
||||
True if index appears to exist, False otherwise
|
||||
"""
|
||||
index_path_obj = Path(index_path)
|
||||
if not index_path_obj.exists() or not index_path_obj.is_dir():
|
||||
return False
|
||||
|
||||
# Fast-Plaid creates a SQLite database file for metadata
|
||||
# Check for metadata.db as the most reliable indicator
|
||||
metadata_db = index_path_obj / "metadata.db"
|
||||
if metadata_db.exists() and metadata_db.stat().st_size > 0:
|
||||
return True
|
||||
|
||||
# Also check if directory has any files (might be incomplete index)
|
||||
try:
|
||||
if any(index_path_obj.iterdir()):
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _load_fast_plaid_index_if_exists(index_path: str) -> Optional[Any]:
|
||||
"""
|
||||
Load Fast-Plaid index if it exists.
|
||||
First checks if index files exist, then creates the FastPlaid object.
|
||||
The actual index data loading happens lazily when search is called.
|
||||
|
||||
Args:
|
||||
index_path: Path to the Fast-Plaid index
|
||||
|
||||
Returns:
|
||||
FastPlaid index object if exists, None otherwise
|
||||
"""
|
||||
try:
|
||||
from fast_plaid import search as fast_plaid_search
|
||||
|
||||
# First check if index files exist without creating the object
|
||||
if not _fast_plaid_index_exists(index_path):
|
||||
return None
|
||||
|
||||
# Now try to create FastPlaid object
|
||||
# This may trigger some memory allocation, but the full index loading is deferred
|
||||
fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path)
|
||||
return fast_plaid_index
|
||||
except ImportError:
|
||||
# fast-plaid not installed
|
||||
return None
|
||||
except Exception as e:
|
||||
# Any error (including memory errors from Rust backend) - return None
|
||||
# The error will be caught and index will be rebuilt
|
||||
print(f"Warning: Could not load Fast-Plaid index: {type(e).__name__}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _search_fast_plaid(
|
||||
fast_plaid_index: Any,
|
||||
query_vec: Any,
|
||||
top_k: int,
|
||||
) -> tuple[list[tuple[float, int]], float]:
|
||||
"""
|
||||
Search Fast-Plaid index with a query embedding.
|
||||
|
||||
Args:
|
||||
fast_plaid_index: FastPlaid index object
|
||||
query_vec: Query embedding tensor with shape [num_tokens, embedding_dim]
|
||||
top_k: Number of top results to return
|
||||
|
||||
Returns:
|
||||
Tuple of (results_list, search_time_in_seconds)
|
||||
results_list: List of (score, doc_id) tuples
|
||||
"""
|
||||
import torch
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
|
||||
# Ensure query is a torch tensor
|
||||
if not isinstance(query_vec, torch.Tensor):
|
||||
q_vec_tensor = (
|
||||
torch.tensor(query_vec)
|
||||
if isinstance(query_vec, np.ndarray)
|
||||
else torch.from_numpy(np.array(query_vec))
|
||||
)
|
||||
else:
|
||||
q_vec_tensor = query_vec
|
||||
|
||||
# Fast-Plaid expects shape [num_queries, num_tokens, embedding_dim]
|
||||
if q_vec_tensor.dim() == 2:
|
||||
q_vec_tensor = q_vec_tensor.unsqueeze(0) # [1, num_tokens, embedding_dim]
|
||||
|
||||
# Perform search
|
||||
scores = fast_plaid_index.search(
|
||||
queries_embeddings=q_vec_tensor,
|
||||
top_k=top_k,
|
||||
show_progress=True,
|
||||
)
|
||||
|
||||
search_secs = time.perf_counter() - _t0
|
||||
|
||||
# Convert Fast-Plaid results to same format as LEANN: list of (score, doc_id) tuples
|
||||
results = []
|
||||
if scores and len(scores) > 0:
|
||||
query_results = scores[0]
|
||||
# Fast-Plaid returns (doc_id, score), convert to (score, doc_id) to match LEANN format
|
||||
results = [(float(score), int(doc_id)) for doc_id, score in query_results]
|
||||
|
||||
return results, search_secs
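# Minimal shape-handling sketch, independent of any Fast-Plaid index: a single query
# of 5 tokens with 128-dim embeddings (sizes are illustrative) is expanded to the
# [num_queries, num_tokens, embedding_dim] layout that the search call above expects.
import torch

_example_query = torch.randn(5, 128)          # [num_tokens, embedding_dim]
_example_batch = _example_query.unsqueeze(0)  # [1, num_tokens, embedding_dim]
assert _example_batch.shape == (1, 5, 128)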
|
||||
|
||||
|
||||
def _get_fast_plaid_image(index_path: str, doc_id: int) -> Optional[Image.Image]:
|
||||
"""
|
||||
Retrieve image for a document from Fast-Plaid index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the Fast-Plaid index
|
||||
doc_id: Document ID returned by Fast-Plaid search
|
||||
|
||||
Returns:
|
||||
PIL Image if found, None otherwise
|
||||
|
||||
Note: Uses metadata['index'] to get the actual file index, as Fast-Plaid
|
||||
doc_id may differ from the file naming index.
|
||||
"""
|
||||
# First get metadata to find the actual index used for file naming
|
||||
metadata = _get_fast_plaid_metadata(index_path, doc_id)
|
||||
if metadata is None:
|
||||
# Fallback: try using doc_id directly
|
||||
file_index = doc_id
|
||||
else:
|
||||
# Use the 'index' field from metadata, which matches the file naming
|
||||
file_index = metadata.get("index", doc_id)
|
||||
|
||||
images_dir = Path(index_path) / "images"
|
||||
image_path = images_dir / f"doc_{file_index}.png"
|
||||
|
||||
if image_path.exists():
|
||||
return Image.open(image_path)
|
||||
|
||||
# If not found with index, try doc_id as fallback
|
||||
if file_index != doc_id:
|
||||
fallback_path = images_dir / f"doc_{doc_id}.png"
|
||||
if fallback_path.exists():
|
||||
return Image.open(fallback_path)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_fast_plaid_metadata(index_path: str, doc_id: int) -> Optional[dict]:
|
||||
"""
|
||||
Retrieve metadata for a document from Fast-Plaid index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the Fast-Plaid index
|
||||
doc_id: Document ID
|
||||
|
||||
Returns:
|
||||
Dictionary with metadata if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
from fast_plaid import filtering
|
||||
|
||||
metadata_list = filtering.get(index=index_path, subset=[doc_id])
|
||||
if metadata_list and len(metadata_list) > 0:
|
||||
return metadata_list[0]
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _generate_similarity_map(
|
||||
model,
|
||||
processor,
|
||||
@@ -678,11 +1085,15 @@ class LeannMultiVector:
|
||||
return (float(score), doc_id)
|
||||
|
||||
scores: list[tuple[float, int]] = []
|
||||
# load and core time
|
||||
start_time = time.time()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids]
|
||||
for fut in concurrent.futures.as_completed(futures):
|
||||
scores.append(fut.result())
|
||||
|
||||
end_time = time.time()
|
||||
print(f"Number of candidate doc ids: {len(candidate_doc_ids)}")
|
||||
print(f"Time taken in load and core time: {end_time - start_time} seconds")
|
||||
scores.sort(key=lambda x: x[0], reverse=True)
|
||||
return scores[:topk] if len(scores) >= topk else scores
|
||||
|
||||
@@ -710,7 +1121,6 @@ class LeannMultiVector:
|
||||
emb_path = self._embeddings_path()
|
||||
if not emb_path.exists():
|
||||
return self.search(data, topk)
|
||||
|
||||
all_embeddings = np.load(emb_path, mmap_mode="r")
|
||||
if all_embeddings.dtype != np.float32:
|
||||
all_embeddings = all_embeddings.astype(np.float32)
|
||||
@@ -718,23 +1128,29 @@ class LeannMultiVector:
|
||||
assert self._docid_to_indices is not None
|
||||
candidate_doc_ids = list(self._docid_to_indices.keys())
|
||||
|
||||
def _score_one(doc_id: int) -> tuple[float, int]:
|
||||
def _score_one(doc_id: int, _all_embeddings=all_embeddings) -> tuple[float, int]:
|
||||
token_indices = self._docid_to_indices.get(doc_id, [])
|
||||
if not token_indices:
|
||||
return (0.0, doc_id)
|
||||
doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
|
||||
doc_vecs = np.asarray(_all_embeddings[token_indices], dtype=np.float32)
|
||||
sim = np.dot(data, doc_vecs.T)
|
||||
sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
|
||||
score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
|
||||
return (float(score), doc_id)
|
||||
|
||||
scores: list[tuple[float, int]] = []
|
||||
# load and core time
|
||||
start_time = time.time()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
futures = [ex.submit(_score_one, d) for d in candidate_doc_ids]
|
||||
for fut in concurrent.futures.as_completed(futures):
|
||||
scores.append(fut.result())
|
||||
|
||||
end_time = time.time()
|
||||
# print number of candidate doc ids
|
||||
print(f"Number of candidate doc ids: {len(candidate_doc_ids)}")
|
||||
print(f"Time taken in load and core time: {end_time - start_time} seconds")
|
||||
scores.sort(key=lambda x: x[0], reverse=True)
|
||||
del all_embeddings
|
||||
return scores[:topk] if len(scores) >= topk else scores
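# Worked numpy example of the MaxSim score computed in _score_one above: for each
# query token keep the similarity of its best-matching document token, then sum over
# query tokens. Shapes are illustrative only (3 query tokens, 4 document tokens,
# 8-dim embeddings).
import numpy as np

_rng = np.random.default_rng(0)
_q = _rng.normal(size=(3, 8)).astype(np.float32)  # query token embeddings
_d = _rng.normal(size=(4, 8)).astype(np.float32)  # document token embeddings
_sim = _q @ _d.T                                  # [3, 4] token-to-token similarities
_maxsim = float(_sim.max(axis=1).sum())           # MaxSim late-interaction score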
|
||||
|
||||
def get_image(self, doc_id: int) -> Optional[Image.Image]:
|
||||
@@ -778,3 +1194,259 @@ class LeannMultiVector:
|
||||
"image_path": meta.get("image_path", ""),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
class ViDoReBenchmarkEvaluator:
|
||||
"""
|
||||
A reusable class for evaluating ViDoRe benchmarks (v1 and v2).
|
||||
This class encapsulates common functionality for building indexes, searching, and evaluating.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
use_fast_plaid: bool = False,
|
||||
top_k: int = 100,
|
||||
first_stage_k: int = 500,
|
||||
k_values: Optional[list[int]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the evaluator.
|
||||
|
||||
Args:
|
||||
model_name: Model name ("colqwen2" or "colpali")
|
||||
use_fast_plaid: Whether to use Fast-Plaid instead of LEANN
|
||||
top_k: Top-k results to retrieve
|
||||
first_stage_k: First stage k for LEANN search
|
||||
k_values: List of k values for evaluation metrics
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.use_fast_plaid = use_fast_plaid
|
||||
self.top_k = top_k
|
||||
self.first_stage_k = first_stage_k
|
||||
self.k_values = k_values if k_values is not None else [1, 3, 5, 10, 100]
|
||||
|
||||
# Load model once (can be reused across tasks)
|
||||
self._model = None
|
||||
self._processor = None
|
||||
self._model_name_actual = None
|
||||
|
||||
def _load_model_if_needed(self):
|
||||
"""Lazy load the model."""
|
||||
if self._model is None:
|
||||
print(f"\nLoading model: {self.model_name}")
|
||||
self._model_name_actual, self._model, self._processor, _, _, _ = _load_colvision(
|
||||
self.model_name
|
||||
)
|
||||
print(f"Model loaded: {self._model_name_actual}")
|
||||
|
||||
def build_index_from_corpus(
|
||||
self,
|
||||
corpus: dict[str, Image.Image],
|
||||
index_path: str,
|
||||
rebuild: bool = False,
|
||||
) -> tuple[Any, list[str]]:
|
||||
"""
|
||||
Build index from corpus images.
|
||||
|
||||
Args:
|
||||
corpus: dict mapping corpus_id to PIL Image
|
||||
index_path: Path to save/load the index
|
||||
rebuild: Whether to rebuild even if index exists
|
||||
|
||||
Returns:
|
||||
tuple: (retriever or fast_plaid_index object, list of corpus_ids in order)
|
||||
"""
|
||||
self._load_model_if_needed()
|
||||
|
||||
# Ensure consistent ordering
|
||||
corpus_ids = sorted(corpus.keys())
|
||||
images = [corpus[cid] for cid in corpus_ids]
|
||||
|
||||
if self.use_fast_plaid:
|
||||
# Check if Fast-Plaid index exists
|
||||
if not rebuild:
|
||||
existing_index = _load_fast_plaid_index_if_exists(index_path)
|
||||
if existing_index is not None:
|
||||
print(f"Fast-Plaid index already exists at {index_path}")
|
||||
return existing_index, corpus_ids
|
||||
|
||||
print(f"Building Fast-Plaid index at {index_path}...")
|
||||
print("Embedding images...")
|
||||
doc_vecs = _embed_images(self._model, self._processor, images)
|
||||
|
||||
fast_plaid_index, build_time = _build_fast_plaid_index(
|
||||
index_path, doc_vecs, corpus_ids, images
|
||||
)
|
||||
print(f"Fast-Plaid index built in {build_time:.2f}s")
|
||||
return fast_plaid_index, corpus_ids
|
||||
else:
|
||||
# Check if LEANN index exists
|
||||
if not rebuild:
|
||||
retriever = _load_retriever_if_index_exists(index_path)
|
||||
if retriever is not None:
|
||||
print(f"LEANN index already exists at {index_path}")
|
||||
return retriever, corpus_ids
|
||||
|
||||
print(f"Building LEANN index at {index_path}...")
|
||||
print("Embedding images...")
|
||||
doc_vecs = _embed_images(self._model, self._processor, images)
|
||||
|
||||
retriever = _build_index(index_path, doc_vecs, corpus_ids, images)
|
||||
print("LEANN index built")
|
||||
return retriever, corpus_ids
|
||||
|
||||
def search_queries(
|
||||
self,
|
||||
queries: dict[str, str],
|
||||
corpus_ids: list[str],
|
||||
index_or_retriever: Any,
|
||||
fast_plaid_index_path: Optional[str] = None,
|
||||
task_prompt: Optional[dict[str, str]] = None,
|
||||
) -> dict[str, dict[str, float]]:
|
||||
"""
|
||||
Search queries against the index.
|
||||
|
||||
Args:
|
||||
queries: dict mapping query_id to query text
|
||||
corpus_ids: list of corpus_ids in the same order as the index
|
||||
index_or_retriever: index or retriever object
|
||||
fast_plaid_index_path: path to Fast-Plaid index (for metadata)
|
||||
task_prompt: Optional dict with prompt for query (e.g., {"query": "..."})
|
||||
|
||||
Returns:
|
||||
results: dict mapping query_id to dict of {corpus_id: score}
|
||||
"""
|
||||
self._load_model_if_needed()
|
||||
|
||||
print(f"Searching {len(queries)} queries (top_k={self.top_k})...")
|
||||
|
||||
query_ids = list(queries.keys())
|
||||
query_texts = [queries[qid] for qid in query_ids]
|
||||
|
||||
# Note: ColPaliEngineWrapper does NOT use task prompt from metadata
|
||||
# It uses query_prefix + text + query_augmentation_token (handled in _embed_queries)
|
||||
# So we don't append task_prompt here to match MTEB behavior
|
||||
|
||||
# Embed queries
|
||||
print("Embedding queries...")
|
||||
query_vecs = _embed_queries(self._model, self._processor, query_texts)
|
||||
|
||||
results = {}
|
||||
|
||||
for query_id, query_vec in zip(tqdm(query_ids, desc="Searching"), query_vecs):
|
||||
if self.use_fast_plaid:
|
||||
# Fast-Plaid search
|
||||
search_results, _ = _search_fast_plaid(index_or_retriever, query_vec, self.top_k)
|
||||
query_results = {}
|
||||
for score, doc_id in search_results:
|
||||
if doc_id < len(corpus_ids):
|
||||
corpus_id = corpus_ids[doc_id]
|
||||
query_results[corpus_id] = float(score)
|
||||
else:
|
||||
# LEANN search
|
||||
import torch
|
||||
|
||||
query_np = (
|
||||
query_vec.float().numpy() if isinstance(query_vec, torch.Tensor) else query_vec
|
||||
)
|
||||
search_results = index_or_retriever.search_exact(query_np, topk=self.top_k)
|
||||
query_results = {}
|
||||
for score, doc_id in search_results:
|
||||
if doc_id < len(corpus_ids):
|
||||
corpus_id = corpus_ids[doc_id]
|
||||
query_results[corpus_id] = float(score)
|
||||
|
||||
results[query_id] = query_results
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def evaluate_results(
|
||||
results: dict[str, dict[str, float]],
|
||||
qrels: dict[str, dict[str, int]],
|
||||
k_values: Optional[list[int]] = None,
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Evaluate retrieval results using NDCG and other metrics.
|
||||
|
||||
Args:
|
||||
results: dict mapping query_id to dict of {corpus_id: score}
|
||||
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
|
||||
k_values: List of k values for evaluation metrics
|
||||
|
||||
Returns:
|
||||
Dictionary of metric scores
|
||||
"""
|
||||
try:
|
||||
from mteb._evaluators.retrieval_metrics import (
|
||||
calculate_retrieval_scores,
|
||||
make_score_dict,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"pytrec_eval is required for evaluation. Install with: pip install pytrec-eval"
|
||||
)
|
||||
|
||||
if k_values is None:
|
||||
k_values = [1, 3, 5, 10, 100]
|
||||
|
||||
# Check if we have any queries to evaluate
|
||||
if len(results) == 0:
|
||||
print("Warning: No queries to evaluate. Returning zero scores.")
|
||||
scores = {}
|
||||
for k in k_values:
|
||||
scores[f"ndcg_at_{k}"] = 0.0
|
||||
scores[f"map_at_{k}"] = 0.0
|
||||
scores[f"recall_at_{k}"] = 0.0
|
||||
scores[f"precision_at_{k}"] = 0.0
|
||||
scores[f"mrr_at_{k}"] = 0.0
|
||||
return scores
|
||||
|
||||
print(f"Evaluating results with k_values={k_values}...")
|
||||
print(f"Before filtering: {len(results)} results, {len(qrels)} qrels")
|
||||
|
||||
# Filter to ensure qrels and results have the same query set
|
||||
# This matches MTEB behavior: only evaluate queries that exist in both
|
||||
# pytrec_eval only evaluates queries in qrels, so we need to ensure
|
||||
# results contains all queries in qrels, and filter out queries not in qrels
|
||||
results_filtered = {qid: res for qid, res in results.items() if qid in qrels}
|
||||
qrels_filtered = {
|
||||
qid: rel_docs for qid, rel_docs in qrels.items() if qid in results_filtered
|
||||
}
|
||||
|
||||
print(f"After filtering: {len(results_filtered)} results, {len(qrels_filtered)} qrels")
|
||||
|
||||
if len(results_filtered) != len(qrels_filtered):
|
||||
print(
|
||||
f"Warning: Mismatch between results ({len(results_filtered)}) and qrels ({len(qrels_filtered)}) queries"
|
||||
)
|
||||
missing_in_results = set(qrels.keys()) - set(results.keys())
|
||||
if missing_in_results:
|
||||
print(f"Queries in qrels but not in results: {len(missing_in_results)} queries")
|
||||
print(f"First 5 missing queries: {list(missing_in_results)[:5]}")
|
||||
|
||||
# Convert qrels to pytrec_eval format
|
||||
qrels_pytrec = {}
|
||||
for qid, rel_docs in qrels_filtered.items():
|
||||
qrels_pytrec[qid] = dict(rel_docs.items())
|
||||
|
||||
# Evaluate
|
||||
eval_result = calculate_retrieval_scores(
|
||||
results=results_filtered,
|
||||
qrels=qrels_pytrec,
|
||||
k_values=k_values,
|
||||
)
|
||||
|
||||
# Format scores
|
||||
scores = make_score_dict(
|
||||
ndcg=eval_result.ndcg,
|
||||
_map=eval_result.map,
|
||||
recall=eval_result.recall,
|
||||
precision=eval_result.precision,
|
||||
mrr=eval_result.mrr,
|
||||
naucs=eval_result.naucs,
|
||||
naucs_mrr=eval_result.naucs_mrr,
|
||||
cv_recall=eval_result.cv_recall,
|
||||
task_scores={},
|
||||
)
|
||||
|
||||
return scores
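# Tiny hand-made illustration of the inputs evaluate_results expects (all IDs below
# are hypothetical). With this data the only relevant document is ranked first, so
# ndcg_at_1 would come out as 1.0:
#
#     example_results = {"query-test-1": {"corpus-test-1": 12.3, "corpus-test-2": 7.1}}
#     example_qrels = {"query-test-1": {"corpus-test-1": 1}}
#     scores = ViDoReBenchmarkEvaluator.evaluate_results(example_results, example_qrels, k_values=[1])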
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
## Jupyter-style notebook script
|
||||
# %%
|
||||
# uv pip install matplotlib qwen_vl_utils
|
||||
import argparse
|
||||
import faulthandler
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
# Enable faulthandler to get stack trace on segfault
|
||||
faulthandler.enable()
|
||||
|
||||
|
||||
from leann_multi_vector import ( # utility functions/classes
|
||||
_ensure_repo_paths_importable,
|
||||
@@ -18,6 +25,11 @@ from leann_multi_vector import ( # utility functions/classes
|
||||
_build_index,
|
||||
_load_retriever_if_index_exists,
|
||||
_generate_similarity_map,
|
||||
_build_fast_plaid_index,
|
||||
_load_fast_plaid_index_if_exists,
|
||||
_search_fast_plaid,
|
||||
_get_fast_plaid_image,
|
||||
_get_fast_plaid_metadata,
|
||||
QwenVL,
|
||||
)
|
||||
|
||||
@@ -31,19 +43,52 @@ MODEL: str = "colqwen2" # "colpali" or "colqwen2"
|
||||
|
||||
# Data source: set to True to use the Hugging Face dataset example (recommended)
|
||||
USE_HF_DATASET: bool = True
|
||||
# Single dataset name (used when DATASET_NAMES is None)
|
||||
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
|
||||
DATASET_SPLIT: str = "train"
|
||||
# Multiple datasets to combine (if provided, DATASET_NAME is ignored)
|
||||
# Can be:
|
||||
# - List of strings: ["dataset1", "dataset2"]
|
||||
# - List of tuples: [("dataset1", "config1"), ("dataset2", None)] # None = no config needed
|
||||
# - Mixed: ["dataset1", ("dataset2", "config2")]
|
||||
#
|
||||
# Some potential datasets with images (may need IMAGE_FIELD_NAME adjustment):
|
||||
# - "weaviate/arXiv-AI-papers-multi-vector" (current, has "page_image" field)
|
||||
# - ("lmms-lab/DocVQA", "DocVQA") (has "image" field, document images, needs config)
|
||||
# - ("lmms-lab/DocVQA", "InfographicVQA") (has "image" field, infographic images)
|
||||
# - "pixparse/arxiv-papers" (if available, arXiv papers)
|
||||
# - "allenai/ai2d" (AI2D diagram dataset, has "image" field)
|
||||
# - "huggingface/document-images" (if available)
|
||||
# Note: Check dataset structure first - some may need IMAGE_FIELD_NAME specified
|
||||
# DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = None
|
||||
DATASET_NAMES = [
|
||||
"weaviate/arXiv-AI-papers-multi-vector",
|
||||
# ("lmms-lab/DocVQA", "DocVQA"), # Specify config name for datasets with multiple configs
|
||||
]
|
||||
# Load multiple splits to get more data (e.g., ["train", "test", "validation"])
|
||||
# Set to None to try loading all available splits automatically
|
||||
DATASET_SPLITS: Optional[list[str]] = ["train", "test"] # None = auto-detect all splits
|
||||
# Image field name in the dataset (auto-detect if None)
|
||||
# Common names: "page_image", "image", "images", "img"
|
||||
IMAGE_FIELD_NAME: Optional[str] = None # None = auto-detect
|
||||
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
|
||||
|
||||
# Local pages (used when USE_HF_DATASET == False)
|
||||
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
|
||||
PAGES_DIR: str = "./pages"
|
||||
# Custom folder path (takes precedence over USE_HF_DATASET and PAGES_DIR)
|
||||
# If set, images will be loaded directly from this folder
|
||||
CUSTOM_FOLDER_PATH: Optional[str] = None # e.g., "/home/ubuntu/dr-tulu/agent/screenshots"
|
||||
# Whether to recursively search subdirectories when loading from custom folder
|
||||
CUSTOM_FOLDER_RECURSIVE: bool = False # Set to True to search subdirectories
|
||||
|
||||
# Index + retrieval settings
|
||||
INDEX_PATH: str = "./indexes/colvision.leann"
|
||||
# Use a different index path for larger dataset to avoid overwriting existing index
|
||||
INDEX_PATH: str = "./indexes/colvision_large.leann"
|
||||
# Fast-Plaid index settings (alternative to LEANN index)
|
||||
# These are now command-line arguments (see CLI overrides section)
|
||||
TOPK: int = 3
|
||||
FIRST_STAGE_K: int = 500
|
||||
REBUILD_INDEX: bool = False
|
||||
REBUILD_INDEX: bool = False # Set to True to force rebuild even if index exists
|
||||
|
||||
# Artifacts
|
||||
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
|
||||
@@ -54,38 +99,347 @@ ANSWER: bool = True
|
||||
MAX_NEW_TOKENS: int = 1024
|
||||
|
||||
|
||||
# %%
|
||||
# CLI overrides
|
||||
parser = argparse.ArgumentParser(description="Multi-vector LEANN similarity map demo")
|
||||
parser.add_argument(
|
||||
"--search-method",
|
||||
type=str,
|
||||
choices=["ann", "exact", "exact-all"],
|
||||
default="ann",
|
||||
help="Which search method to use: 'ann' (fast ANN), 'exact' (ANN + exact rerank), or 'exact-all' (exact over all docs).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=QUERY,
|
||||
help=f"Query string to search for. Default: '{QUERY}'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-fast-plaid",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Set to True to use fast-plaid instead of LEANN. Default: False",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-plaid-index-path",
|
||||
type=str,
|
||||
default="./indexes/colvision_fastplaid",
|
||||
help="Path to the Fast-Plaid index. Default: './indexes/colvision_fastplaid'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--topk",
|
||||
type=int,
|
||||
default=TOPK,
|
||||
help=f"Number of top results to retrieve. Default: {TOPK}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--custom-folder",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a custom folder containing images to search. Takes precedence over dataset loading. Default: None",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recursive",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Recursively search subdirectories when loading images from custom folder. Default: False",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rebuild-index",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Force rebuild the index even if it already exists. Default: False (reuse existing index if available)",
|
||||
)
|
||||
cli_args, _unknown = parser.parse_known_args()
|
||||
SEARCH_METHOD: str = cli_args.search_method
|
||||
QUERY = cli_args.query # Override QUERY with CLI argument if provided
|
||||
USE_FAST_PLAID: bool = cli_args.use_fast_plaid
|
||||
FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
|
||||
TOPK: int = cli_args.topk # Override TOPK with CLI argument if provided
|
||||
CUSTOM_FOLDER_PATH = cli_args.custom_folder if cli_args.custom_folder else CUSTOM_FOLDER_PATH # Override with CLI argument if provided
|
||||
CUSTOM_FOLDER_RECURSIVE = cli_args.recursive if cli_args.recursive else CUSTOM_FOLDER_RECURSIVE # Override with CLI argument if provided
|
||||
REBUILD_INDEX = cli_args.rebuild_index # Override REBUILD_INDEX with CLI argument
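# Example invocations (the script filename, paths, and query text are placeholders);
# unknown arguments are tolerated via parse_known_args, so the file still runs as a
# Jupyter-style notebook as well:
#
#     python multi_vector_notebook.py --search-method exact --query "What is MaxSim?" --topk 5
#     python multi_vector_notebook.py --use-fast-plaid --fast-plaid-index-path ./indexes/colvision_fastplaid
#     python multi_vector_notebook.py --custom-folder ./screenshots --recursive --rebuild-index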
|
||||
|
||||
# %%
|
||||
|
||||
# Step 1: Check if we can skip data loading (index already exists)
|
||||
retriever: Optional[Any] = None
|
||||
fast_plaid_index: Optional[Any] = None
|
||||
need_to_build_index = REBUILD_INDEX
|
||||
|
||||
if not REBUILD_INDEX:
|
||||
retriever = _load_retriever_if_index_exists(INDEX_PATH)
|
||||
if retriever is not None:
|
||||
print(f"✓ Index loaded from {INDEX_PATH}")
|
||||
print(f"✓ Images available at: {retriever._images_dir_path()}")
|
||||
need_to_build_index = False
|
||||
if USE_FAST_PLAID:
|
||||
# Fast-Plaid index handling
|
||||
if not REBUILD_INDEX:
|
||||
try:
|
||||
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
|
||||
if fast_plaid_index is not None:
|
||||
print(f"✓ Fast-Plaid index found at {FAST_PLAID_INDEX_PATH}")
|
||||
need_to_build_index = False
|
||||
else:
|
||||
print(f"Fast-Plaid index not found, will build new index")
|
||||
need_to_build_index = True
|
||||
except Exception as e:
|
||||
# If loading fails (e.g., memory error, corrupted index), rebuild
|
||||
print(f"Warning: Failed to load Fast-Plaid index: {e}")
|
||||
print("Will rebuild the index...")
|
||||
need_to_build_index = True
|
||||
fast_plaid_index = None
|
||||
else:
|
||||
print(f"Index not found, will build new index")
|
||||
print(f"REBUILD_INDEX=True, will rebuild Fast-Plaid index")
|
||||
need_to_build_index = True
|
||||
else:
|
||||
# Original LEANN index handling
|
||||
if not REBUILD_INDEX:
|
||||
retriever = _load_retriever_if_index_exists(INDEX_PATH)
|
||||
if retriever is not None:
|
||||
print(f"✓ Index loaded from {INDEX_PATH}")
|
||||
print(f"✓ Images available at: {retriever._images_dir_path()}")
|
||||
need_to_build_index = False
|
||||
else:
|
||||
print(f"Index not found, will build new index")
|
||||
need_to_build_index = True
|
||||
else:
|
||||
print(f"REBUILD_INDEX=True, will rebuild index")
|
||||
need_to_build_index = True
|
||||
|
||||
# Step 2: Load data only if we need to build the index
|
||||
if need_to_build_index:
|
||||
print("Loading dataset...")
|
||||
if USE_HF_DATASET:
|
||||
from datasets import load_dataset
|
||||
# Check for custom folder path first (takes precedence)
|
||||
if CUSTOM_FOLDER_PATH:
|
||||
if not os.path.isdir(CUSTOM_FOLDER_PATH):
|
||||
raise RuntimeError(f"Custom folder path does not exist: {CUSTOM_FOLDER_PATH}")
|
||||
print(f"Loading images from custom folder: {CUSTOM_FOLDER_PATH}")
|
||||
if CUSTOM_FOLDER_RECURSIVE:
|
||||
print(" (recursive mode: searching subdirectories)")
|
||||
filepaths, images = _load_images_from_dir(CUSTOM_FOLDER_PATH, recursive=CUSTOM_FOLDER_RECURSIVE)
|
||||
print(f" Found {len(filepaths)} image files")
|
||||
if not images:
|
||||
raise RuntimeError(
|
||||
f"No images found in {CUSTOM_FOLDER_PATH}. Ensure the folder contains image files (.png, .jpg, .jpeg, .webp)."
|
||||
)
|
||||
print(f" Successfully loaded {len(images)} images")
|
||||
# Use filenames as identifiers instead of full paths for cleaner metadata
|
||||
filepaths = [os.path.basename(fp) for fp in filepaths]
|
||||
elif USE_HF_DATASET:
|
||||
from datasets import load_dataset, concatenate_datasets, DatasetDict
|
||||
|
||||
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
|
||||
# Determine which datasets to load
|
||||
if DATASET_NAMES is not None:
|
||||
dataset_names_to_load = DATASET_NAMES
|
||||
print(f"Loading {len(dataset_names_to_load)} datasets: {dataset_names_to_load}")
|
||||
else:
|
||||
dataset_names_to_load = [DATASET_NAME]
|
||||
print(f"Loading single dataset: {DATASET_NAME}")
|
||||
|
||||
# Load and combine datasets
|
||||
all_datasets_to_concat = []
|
||||
|
||||
for dataset_entry in dataset_names_to_load:
|
||||
# Handle both string and tuple formats
|
||||
if isinstance(dataset_entry, tuple):
|
||||
dataset_name, config_name = dataset_entry
|
||||
else:
|
||||
dataset_name = dataset_entry
|
||||
config_name = None
|
||||
|
||||
print(f"\nProcessing dataset: {dataset_name}" + (f" (config: {config_name})" if config_name else ""))
|
||||
|
||||
# Load dataset to check available splits
|
||||
# If config_name is provided, use it; otherwise try without config
|
||||
try:
|
||||
if config_name:
|
||||
dataset_dict = load_dataset(dataset_name, config_name)
|
||||
else:
|
||||
dataset_dict = load_dataset(dataset_name)
|
||||
except ValueError as e:
|
||||
if "Config name is missing" in str(e):
|
||||
# Try to get available configs and suggest
|
||||
from datasets import get_dataset_config_names
|
||||
try:
|
||||
available_configs = get_dataset_config_names(dataset_name)
|
||||
raise ValueError(
|
||||
f"Dataset '{dataset_name}' requires a config name. "
|
||||
f"Available configs: {available_configs}. "
|
||||
f"Please specify as: ('{dataset_name}', 'config_name')"
|
||||
) from e
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Dataset '{dataset_name}' requires a config name. "
|
||||
f"Please specify as: ('{dataset_name}', 'config_name')"
|
||||
) from e
|
||||
raise
|
||||
|
||||
# Determine which splits to load
|
||||
if DATASET_SPLITS is None:
|
||||
# Auto-detect: try to load all available splits
|
||||
available_splits = list(dataset_dict.keys())
|
||||
print(f" Auto-detected splits: {available_splits}")
|
||||
splits_to_load = available_splits
|
||||
else:
|
||||
splits_to_load = DATASET_SPLITS
|
||||
|
||||
# Load and concatenate multiple splits for this dataset
|
||||
datasets_to_concat = []
|
||||
for split in splits_to_load:
|
||||
if split not in dataset_dict:
|
||||
print(f" Warning: Split '{split}' not found in dataset. Available splits: {list(dataset_dict.keys())}")
|
||||
continue
|
||||
split_dataset = dataset_dict[split]
|
||||
print(f" Loaded split '{split}': {len(split_dataset)} pages")
|
||||
datasets_to_concat.append(split_dataset)
|
||||
|
||||
if not datasets_to_concat:
|
||||
print(f" Warning: No valid splits found for {dataset_name}. Skipping.")
|
||||
continue
|
||||
|
||||
# Concatenate splits for this dataset
|
||||
if len(datasets_to_concat) > 1:
|
||||
combined_dataset = concatenate_datasets(datasets_to_concat)
|
||||
print(f" Concatenated {len(datasets_to_concat)} splits into {len(combined_dataset)} pages")
|
||||
else:
|
||||
combined_dataset = datasets_to_concat[0]
|
||||
|
||||
all_datasets_to_concat.append(combined_dataset)
|
||||
|
||||
if not all_datasets_to_concat:
|
||||
raise RuntimeError("No valid datasets or splits found.")
|
||||
|
||||
# Concatenate all datasets
|
||||
if len(all_datasets_to_concat) > 1:
|
||||
dataset = concatenate_datasets(all_datasets_to_concat)
|
||||
print(f"\nConcatenated {len(all_datasets_to_concat)} datasets into {len(dataset)} total pages")
|
||||
else:
|
||||
dataset = all_datasets_to_concat[0]
|
||||
|
||||
# Apply MAX_DOCS limit if specified
|
||||
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
|
||||
if N < len(dataset):
|
||||
print(f"Limiting to {N} pages (from {len(dataset)} total)")
|
||||
dataset = dataset.select(range(N))
|
||||
|
||||
# Auto-detect image field name if not specified
|
||||
if IMAGE_FIELD_NAME is None:
|
||||
# Check multiple samples to find the most common image field
|
||||
# (useful when datasets are merged and may have different field names)
|
||||
possible_image_fields = ["page_image", "image", "images", "img", "page", "document_image"]
|
||||
field_counts = {}
|
||||
|
||||
# Check first few samples to find image fields
|
||||
num_samples_to_check = min(10, len(dataset))
|
||||
for sample_idx in range(num_samples_to_check):
|
||||
sample = dataset[sample_idx]
|
||||
for field in possible_image_fields:
|
||||
if field in sample and sample[field] is not None:
|
||||
value = sample[field]
|
||||
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
|
||||
field_counts[field] = field_counts.get(field, 0) + 1
|
||||
|
||||
# Choose the most common field, or first found if tied
|
||||
if field_counts:
|
||||
image_field = max(field_counts.items(), key=lambda x: x[1])[0]
|
||||
print(f"Auto-detected image field: '{image_field}' (found in {field_counts[image_field]}/{num_samples_to_check} samples)")
|
||||
else:
|
||||
# Fallback: check first sample only
|
||||
sample = dataset[0]
|
||||
image_field = None
|
||||
for field in possible_image_fields:
|
||||
if field in sample:
|
||||
value = sample[field]
|
||||
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
|
||||
image_field = field
|
||||
break
|
||||
if image_field is None:
|
||||
raise RuntimeError(
|
||||
f"Could not auto-detect image field. Available fields: {list(sample.keys())}. "
|
||||
f"Please specify IMAGE_FIELD_NAME manually."
|
||||
)
|
||||
print(f"Auto-detected image field: '{image_field}'")
|
||||
else:
|
||||
image_field = IMAGE_FIELD_NAME
|
||||
if image_field not in dataset[0]:
|
||||
raise RuntimeError(
|
||||
f"Image field '{image_field}' not found. Available fields: {list(dataset[0].keys())}"
|
||||
)
|
||||
|
||||
filepaths: list[str] = []
|
||||
images: list[Image.Image] = []
|
||||
for i in tqdm(range(N), desc="Loading dataset", total=N):
|
||||
for i in tqdm(range(len(dataset)), desc="Loading dataset", total=len(dataset)):
|
||||
p = dataset[i]
|
||||
# Compose a descriptive identifier for printing later
|
||||
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
|
||||
# Try to compose a descriptive identifier
|
||||
# Handle different dataset structures
|
||||
identifier_parts = []
|
||||
|
||||
# Helper function to safely get field value
|
||||
def safe_get(field_name, default=None):
|
||||
if field_name in p and p[field_name] is not None:
|
||||
return p[field_name]
|
||||
return default
|
||||
|
||||
# Try to get various identifier fields
|
||||
if safe_get("paper_arxiv_id"):
|
||||
identifier_parts.append(f"arXiv:{p['paper_arxiv_id']}")
|
||||
if safe_get("paper_title"):
|
||||
identifier_parts.append(f"title:{p['paper_title']}")
|
||||
if safe_get("page_number") is not None:
|
||||
try:
|
||||
identifier_parts.append(f"page:{int(p['page_number'])}")
|
||||
except (ValueError, TypeError):
|
||||
# If conversion fails, use the raw value or skip
|
||||
if p['page_number']:
|
||||
identifier_parts.append(f"page:{p['page_number']}")
|
||||
if safe_get("page_id"):
|
||||
identifier_parts.append(f"id:{p['page_id']}")
|
||||
elif safe_get("questionId"):
|
||||
identifier_parts.append(f"qid:{p['questionId']}")
|
||||
elif safe_get("docId"):
|
||||
identifier_parts.append(f"docId:{p['docId']}")
|
||||
elif safe_get("id"):
|
||||
identifier_parts.append(f"id:{p['id']}")
|
||||
|
||||
# If no identifier parts found, create one from index
|
||||
if identifier_parts:
|
||||
identifier = "|".join(identifier_parts)
|
||||
else:
|
||||
# Create identifier from available fields or index
|
||||
fallback_parts = []
|
||||
# Try common fields that might exist
|
||||
for field in ["ucsf_document_id", "docId", "questionId", "id"]:
|
||||
if safe_get(field):
|
||||
fallback_parts.append(f"{field}:{p[field]}")
|
||||
break
|
||||
if fallback_parts:
|
||||
identifier = "|".join(fallback_parts) + f"|idx:{i}"
|
||||
else:
|
||||
identifier = f"doc_{i}"
|
||||
|
||||
filepaths.append(identifier)
|
||||
images.append(p["page_image"]) # PIL Image
|
||||
|
||||
# Get image - try detected field first, then fallback to other common fields
|
||||
img = None
|
||||
if image_field in p and p[image_field] is not None:
|
||||
img = p[image_field]
|
||||
else:
|
||||
# Fallback: try other common image field names
|
||||
for fallback_field in ["image", "page_image", "images", "img"]:
|
||||
if fallback_field in p and p[fallback_field] is not None:
|
||||
img = p[fallback_field]
|
||||
break
|
||||
|
||||
if img is None:
|
||||
raise RuntimeError(
|
||||
f"No image found for sample {i}. Available fields: {list(p.keys())}. "
|
||||
f"Expected field: {image_field}"
|
||||
)
|
||||
|
||||
# Ensure it's a PIL Image
|
||||
if not isinstance(img, Image.Image):
|
||||
if hasattr(img, 'convert'):
|
||||
img = img.convert('RGB')
|
||||
else:
|
||||
img = Image.fromarray(img) if hasattr(img, '__array__') else Image.open(img)
|
||||
images.append(img)
|
||||
else:
|
||||
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
|
||||
filepaths, images = _load_images_from_dir(PAGES_DIR)
|
||||
@@ -94,6 +448,19 @@ if need_to_build_index:
|
||||
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
|
||||
)
|
||||
print(f"Loaded {len(images)} images")
|
||||
|
||||
# Memory check before loading model
|
||||
try:
|
||||
import psutil
|
||||
import torch
|
||||
process = psutil.Process(os.getpid())
|
||||
mem_info = process.memory_info()
|
||||
print(f"Memory usage after loading images: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
|
||||
if torch.cuda.is_available():
|
||||
print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
|
||||
print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
print("Skipping dataset loading (using existing index)")
|
||||
filepaths = [] # Not needed when using existing index
|
||||
@@ -102,46 +469,181 @@ else:
|
||||
|
||||
# %%
|
||||
# Step 3: Load model and processor (only if we need to build index or perform search)
|
||||
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
|
||||
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
|
||||
print("Step 3: Loading model and processor...")
|
||||
print(f" Model: {MODEL}")
|
||||
try:
|
||||
import sys
|
||||
print(f" Python version: {sys.version}")
|
||||
print(f" Python executable: {sys.executable}")
|
||||
|
||||
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
|
||||
print(f"✓ Using model={model_name}, device={device_str}, dtype={dtype}")
|
||||
|
||||
# Memory check after loading model
|
||||
try:
|
||||
import psutil
|
||||
import torch
|
||||
process = psutil.Process(os.getpid())
|
||||
mem_info = process.memory_info()
|
||||
print(f" Memory usage after loading model: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
|
||||
if torch.cuda.is_available():
|
||||
print(f" GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
|
||||
print(f" GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"✗ Error loading model: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
# %%
|
||||
# Step 4: Build index if needed
|
||||
if need_to_build_index and retriever is None:
|
||||
print("Building index...")
|
||||
doc_vecs = _embed_images(model, processor, images)
|
||||
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
|
||||
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
|
||||
# Clear memory
|
||||
del images, filepaths, doc_vecs
|
||||
if need_to_build_index:
|
||||
print("Step 4: Building index...")
|
||||
print(f" Number of images: {len(images)}")
|
||||
print(f" Number of filepaths: {len(filepaths)}")
|
||||
|
||||
# Note: Images are now stored in the index, retriever will load them on-demand from disk
|
||||
try:
|
||||
print(" Embedding images...")
|
||||
doc_vecs = _embed_images(model, processor, images)
|
||||
print(f" Embedded {len(doc_vecs)} documents")
|
||||
print(f" First doc vec shape: {doc_vecs[0].shape if len(doc_vecs) > 0 else 'N/A'}")
|
||||
except Exception as e:
|
||||
print(f"Error embedding images: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
if USE_FAST_PLAID:
|
||||
# Build Fast-Plaid index
|
||||
print(" Building Fast-Plaid index...")
|
||||
try:
|
||||
fast_plaid_index, build_secs = _build_fast_plaid_index(
|
||||
FAST_PLAID_INDEX_PATH, doc_vecs, filepaths, images
|
||||
)
|
||||
from pathlib import Path
|
||||
print(f"✓ Fast-Plaid index built in {build_secs:.3f}s")
|
||||
print(f"✓ Index saved to: {FAST_PLAID_INDEX_PATH}")
|
||||
print(f"✓ Images saved to: {Path(FAST_PLAID_INDEX_PATH) / 'images'}")
|
||||
except Exception as e:
|
||||
print(f"Error building Fast-Plaid index: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
# Clear memory
|
||||
print(" Clearing memory...")
|
||||
del images, filepaths, doc_vecs
|
||||
else:
|
||||
# Build original LEANN index
|
||||
try:
|
||||
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
|
||||
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
|
||||
except Exception as e:
|
||||
print(f"Error building LEANN index: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
# Clear memory
|
||||
print(" Clearing memory...")
|
||||
del images, filepaths, doc_vecs
|
||||
|
||||
# Note: Images are now stored separately, retriever/fast_plaid_index will reference them
|
||||
|
||||
|
||||
# %%
|
||||
# Step 5: Embed query and search
|
||||
_t0 = time.perf_counter()
|
||||
q_vec = _embed_queries(model, processor, [QUERY])[0]
|
||||
results = retriever.search(q_vec.float().numpy(), topk=TOPK)
|
||||
query_embed_secs = time.perf_counter() - _t0
|
||||
|
||||
print(f"[Search] Method: {SEARCH_METHOD}")
|
||||
print(f"[Timing] Query embedding: {query_embed_secs:.3f}s")
|
||||
|
||||
# Run the selected search method and time it
|
||||
if USE_FAST_PLAID:
|
||||
# Fast-Plaid search
|
||||
if fast_plaid_index is None:
|
||||
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
|
||||
if fast_plaid_index is None:
|
||||
raise RuntimeError(f"Fast-Plaid index not found at {FAST_PLAID_INDEX_PATH}")
|
||||
|
||||
results, search_secs = _search_fast_plaid(fast_plaid_index, q_vec, TOPK)
|
||||
print(f"[Timing] Fast-Plaid Search: {search_secs:.3f}s")
|
||||
else:
|
||||
# Original LEANN search
|
||||
query_np = q_vec.float().numpy()
|
||||
|
||||
if SEARCH_METHOD == "ann":
|
||||
results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (ANN): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
|
||||
elif SEARCH_METHOD == "exact":
|
||||
results = retriever.search_exact(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (Exact rerank): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
|
||||
elif SEARCH_METHOD == "exact-all":
|
||||
results = retriever.search_exact_all(query_np, topk=TOPK)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (Exact all): {search_secs:.3f}s")
|
||||
else:
|
||||
results = []
|
||||
if not results:
|
||||
print("No results found.")
|
||||
else:
|
||||
print(f'Top {len(results)} results for query: "{QUERY}"')
|
||||
print("\n[DEBUG] Retrieval details:")
|
||||
top_images: list[Image.Image] = []
|
||||
for rank, (score, doc_id) in enumerate(results, start=1):
|
||||
# Retrieve image from index instead of memory
|
||||
image = retriever.get_image(doc_id)
|
||||
if image is None:
|
||||
print(f"Warning: Could not retrieve image for doc_id {doc_id}")
|
||||
continue
|
||||
image_hashes = {} # Track image hashes to detect duplicates
|
||||
|
||||
metadata = retriever.get_metadata(doc_id)
|
||||
path = metadata.get("filepath", "unknown") if metadata else "unknown"
|
||||
# For HF dataset, path is a descriptive identifier, not a real file path
|
||||
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
|
||||
top_images.append(image)
|
||||
for rank, (score, doc_id) in enumerate(results, start=1):
|
||||
# Retrieve image and metadata based on index type
|
||||
if USE_FAST_PLAID:
|
||||
# Fast-Plaid: load image and get metadata
|
||||
image = _get_fast_plaid_image(FAST_PLAID_INDEX_PATH, doc_id)
|
||||
if image is None:
|
||||
print(f"Warning: Could not find image for doc_id {doc_id}")
|
||||
continue
|
||||
|
||||
metadata = _get_fast_plaid_metadata(FAST_PLAID_INDEX_PATH, doc_id)
|
||||
path = metadata.get("filepath", f"doc_{doc_id}") if metadata else f"doc_{doc_id}"
|
||||
top_images.append(image)
|
||||
else:
|
||||
# Original LEANN: retrieve from retriever
|
||||
image = retriever.get_image(doc_id)
|
||||
if image is None:
|
||||
print(f"Warning: Could not retrieve image for doc_id {doc_id}")
|
||||
continue
|
||||
|
||||
metadata = retriever.get_metadata(doc_id)
|
||||
path = metadata.get("filepath", "unknown") if metadata else "unknown"
|
||||
top_images.append(image)
|
||||
|
||||
# Calculate image hash to detect duplicates
|
||||
import hashlib
|
||||
import io
|
||||
# Convert image to bytes for hashing
|
||||
img_bytes = io.BytesIO()
|
||||
image.save(img_bytes, format='PNG')
|
||||
image_bytes = img_bytes.getvalue()
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()[:8]
|
||||
|
||||
# Check if this image was already seen
|
||||
duplicate_info = ""
|
||||
if image_hash in image_hashes:
|
||||
duplicate_info = f" [DUPLICATE of rank {image_hashes[image_hash]}]"
|
||||
else:
|
||||
image_hashes[image_hash] = rank
|
||||
|
||||
# Print detailed information
|
||||
print(f"{rank}) doc_id={doc_id}, MaxSim={score:.4f}, Page={path}, ImageHash={image_hash}{duplicate_info}")
|
||||
if metadata:
|
||||
print(f" Metadata: {metadata}")
|
||||
|
||||
if SAVE_TOP_IMAGE:
|
||||
from pathlib import Path as _Path
|
||||
@@ -161,7 +663,6 @@ else:
|
||||
except Exception:
|
||||
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
|
||||
|
||||
## TODO: strange result: the second page of DeepSeek-V2 is retrieved rather than the first page
|
||||
|
||||
# %%
|
||||
# Step 6: Similarity maps for top-K results
|
||||
@@ -204,6 +705,9 @@ if results and SIMILARITY_MAP:
|
||||
# Step 7: Optional answer generation
|
||||
if results and ANSWER:
|
||||
qwen = QwenVL(device=device_str)
|
||||
_t0 = time.perf_counter()
|
||||
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
|
||||
gen_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Generation: {gen_secs:.3f}s")
|
||||
print("\nAnswer:")
|
||||
print(response)
|
||||
|
||||
@@ -0,0 +1,448 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Modular script to reproduce NDCG results for ViDoRe v1 benchmark.
|
||||
|
||||
This script uses the interface from leann_multi_vector.py to:
|
||||
1. Download ViDoRe v1 datasets
|
||||
2. Build indexes (LEANN or Fast-Plaid)
|
||||
3. Perform retrieval
|
||||
4. Evaluate using NDCG metrics
|
||||
|
||||
Usage:
|
||||
# Evaluate all ViDoRe v1 tasks
|
||||
python vidore_v1_benchmark.py --model colqwen2 --tasks all
|
||||
|
||||
# Evaluate specific task
|
||||
python vidore_v1_benchmark.py --model colqwen2 --task VidoreArxivQARetrieval
|
||||
|
||||
# Use Fast-Plaid index
|
||||
python vidore_v1_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
|
||||
|
||||
# Rebuild index
|
||||
python vidore_v1_benchmark.py --model colqwen2 --rebuild-index
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from leann_multi_vector import (
|
||||
ViDoReBenchmarkEvaluator,
|
||||
_ensure_repo_paths_importable,
|
||||
)
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# ViDoRe v1 task configurations
|
||||
# Prompts match MTEB task metadata prompts
|
||||
VIDORE_V1_TASKS = {
|
||||
"VidoreArxivQARetrieval": {
|
||||
"dataset_path": "vidore/arxivqa_test_subsampled_beir",
|
||||
"revision": "7d94d570960eac2408d3baa7a33f9de4822ae3e4",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreDocVQARetrieval": {
|
||||
"dataset_path": "vidore/docvqa_test_subsampled_beir",
|
||||
"revision": "162ba2fc1a8437eda8b6c37b240bc1c0f0deb092",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreInfoVQARetrieval": {
|
||||
"dataset_path": "vidore/infovqa_test_subsampled_beir",
|
||||
"revision": "b802cc5fd6c605df2d673a963667d74881d2c9a4",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreTabfquadRetrieval": {
|
||||
"dataset_path": "vidore/tabfquad_test_subsampled_beir",
|
||||
"revision": "61a2224bcd29b7b261a4892ff4c8bea353527a31",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreTatdqaRetrieval": {
|
||||
"dataset_path": "vidore/tatdqa_test_beir",
|
||||
"revision": "5feb5630fdff4d8d189ffedb2dba56862fdd45c0",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreShiftProjectRetrieval": {
|
||||
"dataset_path": "vidore/shiftproject_test_beir",
|
||||
"revision": "84a382e05c4473fed9cff2bbae95fe2379416117",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAAIRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_artificial_intelligence_test_beir",
|
||||
"revision": "2d9ebea5a1c6e9ef4a3b902a612f605dca11261c",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAEnergyRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_energy_test_beir",
|
||||
"revision": "9935aadbad5c8deec30910489db1b2c7133ae7a7",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAGovernmentReportsRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_government_reports_test_beir",
|
||||
"revision": "b4909afa930f81282fd20601e860668073ad02aa",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAHealthcareIndustryRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_healthcare_industry_test_beir",
|
||||
"revision": "f9e25d5b6e13e1ad9f5c3cce202565031b3ab164",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
}
|
||||
|
||||
# Task name aliases (short names -> full names)
|
||||
TASK_ALIASES = {
|
||||
"arxivqa": "VidoreArxivQARetrieval",
|
||||
"docvqa": "VidoreDocVQARetrieval",
|
||||
"infovqa": "VidoreInfoVQARetrieval",
|
||||
"tabfquad": "VidoreTabfquadRetrieval",
|
||||
"tatdqa": "VidoreTatdqaRetrieval",
|
||||
"shiftproject": "VidoreShiftProjectRetrieval",
|
||||
"syntheticdocqa_ai": "VidoreSyntheticDocQAAIRetrieval",
|
||||
"syntheticdocqa_energy": "VidoreSyntheticDocQAEnergyRetrieval",
|
||||
"syntheticdocqa_government": "VidoreSyntheticDocQAGovernmentReportsRetrieval",
|
||||
"syntheticdocqa_healthcare": "VidoreSyntheticDocQAHealthcareIndustryRetrieval",
|
||||
}
|
||||
|
||||
|
||||
def normalize_task_name(task_name: str) -> str:
|
||||
"""Normalize task name (handle aliases)."""
|
||||
task_name_lower = task_name.lower()
|
||||
if task_name in VIDORE_V1_TASKS:
|
||||
return task_name
|
||||
if task_name_lower in TASK_ALIASES:
|
||||
return TASK_ALIASES[task_name_lower]
|
||||
# Try partial match
|
||||
for alias, full_name in TASK_ALIASES.items():
|
||||
if alias in task_name_lower or task_name_lower in alias:
|
||||
return full_name
|
||||
return task_name
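# Illustrative checks: a short alias and the canonical name resolve to the same task.
assert normalize_task_name("docvqa") == "VidoreDocVQARetrieval"
assert normalize_task_name("VidoreDocVQARetrieval") == "VidoreDocVQARetrieval"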
|
||||
|
||||
|
||||
def get_safe_model_name(model_name: str) -> str:
|
||||
"""Get a safe model name for use in file paths."""
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
# If it's a path, use basename or hash
|
||||
if os.path.exists(model_name) and os.path.isdir(model_name):
|
||||
# Use basename if it's reasonable, otherwise use hash
|
||||
basename = os.path.basename(model_name.rstrip("/"))
|
||||
if basename and len(basename) < 100 and not basename.startswith("."):
|
||||
return basename
|
||||
# Use hash for very long or problematic paths
|
||||
return hashlib.md5(model_name.encode()).hexdigest()[:16]
|
||||
# For HuggingFace model names, replace / with _
|
||||
return model_name.replace("/", "_").replace(":", "_")
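# Illustrative check (assumes no local directory with this name exists): a
# HuggingFace-style model id is made filesystem-safe by replacing separators, while
# local paths would use their basename or a hash instead.
assert get_safe_model_name("vidore/colqwen2-v1.0") == "vidore_colqwen2-v1.0"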
|
||||
|
||||
|
||||
def load_vidore_v1_data(
|
||||
dataset_path: str,
|
||||
revision: Optional[str] = None,
|
||||
split: str = "test",
|
||||
):
|
||||
"""
|
||||
Load ViDoRe v1 dataset.
|
||||
|
||||
Returns:
|
||||
corpus: dict mapping corpus_id to PIL Image
|
||||
queries: dict mapping query_id to query text
|
||||
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
|
||||
"""
|
||||
print(f"Loading dataset: {dataset_path} (split={split})")
|
||||
|
||||
# Load queries
|
||||
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
|
||||
|
||||
queries = {}
|
||||
for row in query_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
queries[query_id] = row["query"]
|
||||
|
||||
# Load corpus (images)
|
||||
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
|
||||
|
||||
corpus = {}
|
||||
for row in corpus_ds:
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
# Extract image from the dataset row
|
||||
if "image" in row:
|
||||
corpus[corpus_id] = row["image"]
|
||||
elif "page_image" in row:
|
||||
corpus[corpus_id] = row["page_image"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No image field found in corpus. Available fields: {list(row.keys())}"
|
||||
)
|
||||
|
||||
# Load qrels (relevance judgments)
|
||||
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
|
||||
|
||||
qrels = {}
|
||||
for row in qrels_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
if query_id not in qrels:
|
||||
qrels[query_id] = {}
|
||||
qrels[query_id][corpus_id] = int(row["score"])
|
||||
|
||||
print(
|
||||
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||
)
|
||||
|
||||
# Filter qrels to only include queries that exist
|
||||
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
|
||||
|
||||
# Filter out queries without any relevant documents (matching MTEB behavior)
|
||||
# This is important for correct NDCG calculation
|
||||
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
|
||||
queries_filtered = {
|
||||
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
|
||||
}
|
||||
|
||||
print(
|
||||
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
|
||||
)
|
||||
|
||||
return corpus, queries_filtered, qrels_filtered
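# Illustration of the returned structures (IDs and text below are made up):
#
#     corpus  = {"corpus-test-0": <PIL.Image.Image>, ...}
#     queries = {"query-test-0": "What total revenue is reported for 2021?", ...}
#     qrels   = {"query-test-0": {"corpus-test-0": 1}, ...}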
|
||||
|
||||
|
||||
def evaluate_task(
|
||||
task_name: str,
|
||||
model_name: str,
|
||||
index_path: str,
|
||||
use_fast_plaid: bool = False,
|
||||
fast_plaid_index_path: Optional[str] = None,
|
||||
rebuild_index: bool = False,
|
||||
top_k: int = 1000,
|
||||
first_stage_k: int = 500,
|
||||
k_values: Optional[list[int]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Evaluate a single ViDoRe v1 task.
|
||||
"""
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Evaluating task: {task_name}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Normalize task name (handle aliases)
|
||||
task_name = normalize_task_name(task_name)
|
||||
|
||||
# Get task config
|
||||
if task_name not in VIDORE_V1_TASKS:
|
||||
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
|
||||
|
||||
task_config = VIDORE_V1_TASKS[task_name]
|
||||
dataset_path = task_config["dataset_path"]
|
||||
revision = task_config["revision"]
|
||||
|
||||
# Load data
|
||||
corpus, queries, qrels = load_vidore_v1_data(
|
||||
dataset_path=dataset_path,
|
||||
revision=revision,
|
||||
split="test",
|
||||
)
|
||||
|
||||
# Initialize k_values if not provided
|
||||
if k_values is None:
|
||||
k_values = [1, 3, 5, 10, 20, 100, 1000]
|
||||
|
||||
# Check if we have any queries
|
||||
if len(queries) == 0:
|
||||
print(f"\nWarning: No queries found for task {task_name}. Skipping evaluation.")
|
||||
# Return zero scores
|
||||
scores = {}
|
||||
for k in k_values:
|
||||
scores[f"ndcg_at_{k}"] = 0.0
|
||||
scores[f"map_at_{k}"] = 0.0
|
||||
scores[f"recall_at_{k}"] = 0.0
|
||||
scores[f"precision_at_{k}"] = 0.0
|
||||
scores[f"mrr_at_{k}"] = 0.0
|
||||
return scores
|
||||
|
||||
# Initialize evaluator
|
||||
evaluator = ViDoReBenchmarkEvaluator(
|
||||
model_name=model_name,
|
||||
use_fast_plaid=use_fast_plaid,
|
||||
top_k=top_k,
|
||||
first_stage_k=first_stage_k,
|
||||
k_values=k_values,
|
||||
)
|
||||
|
||||
# Build or load index
|
||||
# Use safe model name for index path (different models need different indexes)
|
||||
safe_model_name = get_safe_model_name(model_name)
|
||||
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
||||
if index_path_full is None:
|
||||
index_path_full = f"./indexes/{task_name}_{safe_model_name}"
|
||||
if use_fast_plaid:
|
||||
index_path_full = f"./indexes/{task_name}_{safe_model_name}_fastplaid"
|
||||
|
||||
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||
corpus=corpus,
|
||||
index_path=index_path_full,
|
||||
rebuild=rebuild_index,
|
||||
)
|
||||
|
||||
# Search queries
|
||||
task_prompt = task_config.get("prompt")
|
||||
results = evaluator.search_queries(
|
||||
queries=queries,
|
||||
corpus_ids=corpus_ids_ordered,
|
||||
index_or_retriever=index_or_retriever,
|
||||
fast_plaid_index_path=fast_plaid_index_path,
|
||||
task_prompt=task_prompt,
|
||||
)
|
||||
|
||||
# Evaluate
|
||||
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||
|
||||
# Print results
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Results for {task_name}:")
|
||||
print(f"{'=' * 80}")
|
||||
for metric, value in scores.items():
|
||||
if isinstance(value, (int, float)):
|
||||
print(f" {metric}: {value:.5f}")
|
||||
|
||||
# Save results
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
results_file = os.path.join(output_dir, f"{task_name}_results.json")
|
||||
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
|
||||
|
||||
with open(results_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print(f"\nSaved results to: {results_file}")
|
||||
|
||||
with open(scores_file, "w") as f:
|
||||
json.dump(scores, f, indent=2)
|
||||
print(f"Saved scores to: {scores_file}")
|
||||
|
||||
return scores
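# Programmatic use, bypassing the CLI (model name and output directory are
# placeholders):
#
#     scores = evaluate_task(
#         task_name="docvqa",
#         model_name="colqwen2",
#         index_path=None,
#         use_fast_plaid=False,
#         output_dir="./vidore_v1_results",
#     )
#     print(scores["ndcg_at_5"])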
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Evaluate ViDoRe v1 benchmark using LEANN/Fast-Plaid indexing"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="colqwen2",
|
||||
help="Model to use: 'colqwen2', 'colpali', or path to a model directory (supports LoRA adapters)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specific task to evaluate (or 'all' for all tasks)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tasks",
|
||||
type=str,
|
||||
default="all",
|
||||
help="Tasks to evaluate: 'all' or comma-separated list",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to LEANN index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-fast-plaid",
|
||||
action="store_true",
|
||||
help="Use Fast-Plaid instead of LEANN",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-plaid-index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Fast-Plaid index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rebuild-index",
|
||||
action="store_true",
|
||||
help="Rebuild index even if it exists",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Top-k results to retrieve (MTEB default is max(k_values)=1000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first-stage-k",
|
||||
type=int,
|
||||
default=500,
|
||||
help="First stage k for LEANN search",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--k-values",
|
||||
type=str,
|
||||
default="1,3,5,10,20,100,1000",
|
||||
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="./vidore_v1_results",
|
||||
help="Output directory for results",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse k_values
|
||||
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
||||
|
||||
# Determine tasks to evaluate
|
||||
if args.task:
|
||||
tasks_to_eval = [normalize_task_name(args.task)]
|
||||
elif args.tasks.lower() == "all":
|
||||
tasks_to_eval = list(VIDORE_V1_TASKS.keys())
|
||||
else:
|
||||
tasks_to_eval = [normalize_task_name(t.strip()) for t in args.tasks.split(",")]
|
||||
|
||||
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||
|
||||
# Evaluate each task
|
||||
all_scores = {}
|
||||
for task_name in tasks_to_eval:
|
||||
try:
|
||||
scores = evaluate_task(
|
||||
task_name=task_name,
|
||||
model_name=args.model,
|
||||
index_path=args.index_path,
|
||||
use_fast_plaid=args.use_fast_plaid,
|
||||
fast_plaid_index_path=args.fast_plaid_index_path,
|
||||
rebuild_index=args.rebuild_index,
|
||||
top_k=args.top_k,
|
||||
first_stage_k=args.first_stage_k,
|
||||
k_values=k_values,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
all_scores[task_name] = scores
|
||||
except Exception as e:
|
||||
print(f"\nError evaluating {task_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# Print summary
|
||||
if all_scores:
|
||||
print(f"\n{'=' * 80}")
|
||||
print("SUMMARY")
|
||||
print(f"{'=' * 80}")
|
||||
for task_name, scores in all_scores.items():
|
||||
print(f"\n{task_name}:")
|
||||
# Print main metrics
|
||||
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
|
||||
if metric in scores:
|
||||
print(f" {metric}: {scores[metric]:.5f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,439 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Modular script to reproduce NDCG results for ViDoRe v2 benchmark.
|
||||
|
||||
This script uses the interface from leann_multi_vector.py to:
|
||||
1. Download ViDoRe v2 datasets
|
||||
2. Build indexes (LEANN or Fast-Plaid)
|
||||
3. Perform retrieval
|
||||
4. Evaluate using NDCG metrics
|
||||
|
||||
Usage:
|
||||
# Evaluate all ViDoRe v2 tasks
|
||||
python vidore_v2_benchmark.py --model colqwen2 --tasks all
|
||||
|
||||
# Evaluate specific task
|
||||
python vidore_v2_benchmark.py --model colqwen2 --task Vidore2ESGReportsRetrieval
|
||||
|
||||
# Use Fast-Plaid index
|
||||
python vidore_v2_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
|
||||
|
||||
# Rebuild index
|
||||
python vidore_v2_benchmark.py --model colqwen2 --rebuild-index
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from leann_multi_vector import (
|
||||
ViDoReBenchmarkEvaluator,
|
||||
_ensure_repo_paths_importable,
|
||||
)
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# Language name to dataset language field value mapping
|
||||
# Dataset uses ISO 639-3 + ISO 15924 format (e.g., "eng-Latn")
|
||||
LANGUAGE_MAPPING = {
|
||||
"english": "eng-Latn",
|
||||
"french": "fra-Latn",
|
||||
"spanish": "spa-Latn",
|
||||
"german": "deu-Latn",
|
||||
}
|
||||
|
||||
# ViDoRe v2 task configurations
|
||||
# Prompts match MTEB task metadata prompts
|
||||
VIDORE_V2_TASKS = {
|
||||
"Vidore2ESGReportsRetrieval": {
|
||||
"dataset_path": "vidore/esg_reports_v2",
|
||||
"revision": "0542c0d03da0ec1c8cbc517c8d78e7e95c75d3d3",
|
||||
"languages": ["french", "spanish", "english", "german"],
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"Vidore2EconomicsReportsRetrieval": {
|
||||
"dataset_path": "vidore/economics_reports_v2",
|
||||
"revision": "b3e3a04b07fbbaffe79be49dabf92f691fbca252",
|
||||
"languages": ["french", "spanish", "english", "german"],
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"Vidore2BioMedicalLecturesRetrieval": {
|
||||
"dataset_path": "vidore/biomedical_lectures_v2",
|
||||
"revision": "a29202f0da409034d651614d87cd8938d254e2ea",
|
||||
"languages": ["french", "spanish", "english", "german"],
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"Vidore2ESGReportsHLRetrieval": {
|
||||
"dataset_path": "vidore/esg_reports_human_labeled_v2",
|
||||
"revision": "6d467dedb09a75144ede1421747e47cf036857dd",
|
||||
# Note: This dataset doesn't have language filtering - all queries are English
|
||||
"languages": None, # No language filtering needed
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def load_vidore_v2_data(
|
||||
dataset_path: str,
|
||||
revision: Optional[str] = None,
|
||||
split: str = "test",
|
||||
language: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Load ViDoRe v2 dataset.
|
||||
|
||||
Returns:
|
||||
corpus: dict mapping corpus_id to PIL Image
|
||||
queries: dict mapping query_id to query text
|
||||
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
|
||||
"""
|
||||
print(f"Loading dataset: {dataset_path} (split={split}, language={language})")
|
||||
|
||||
# Load queries
|
||||
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
|
||||
|
||||
# Check if dataset has language field before filtering
|
||||
has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
|
||||
|
||||
if language and has_language_field:
|
||||
# Map language name to dataset language field value (e.g., "english" -> "eng-Latn")
|
||||
dataset_language = LANGUAGE_MAPPING.get(language, language)
|
||||
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language)
|
||||
# Check if filtering resulted in empty dataset
|
||||
if len(query_ds_filtered) == 0:
|
||||
print(
|
||||
f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}')."
|
||||
)
|
||||
# Try with original language value (dataset might use simple names like 'english')
|
||||
print(f"Trying with original language value '{language}'...")
|
||||
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language)
|
||||
if len(query_ds_filtered) == 0:
|
||||
# Try to get a sample to see actual language values
|
||||
try:
|
||||
sample_ds = load_dataset(
|
||||
dataset_path, "queries", split=split, revision=revision
|
||||
)
|
||||
if len(sample_ds) > 0 and "language" in sample_ds.column_names:
|
||||
sample_langs = set(sample_ds["language"])
|
||||
print(f"Available language values in dataset: {sample_langs}")
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
print(
|
||||
f"Found {len(query_ds_filtered)} queries using original language value '{language}'"
|
||||
)
|
||||
query_ds = query_ds_filtered
|
||||
|
||||
queries = {}
|
||||
for row in query_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
queries[query_id] = row["query"]
|
||||
|
||||
# Load corpus (images)
|
||||
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
|
||||
|
||||
corpus = {}
|
||||
for row in corpus_ds:
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
# Extract image from the dataset row
|
||||
if "image" in row:
|
||||
corpus[corpus_id] = row["image"]
|
||||
elif "page_image" in row:
|
||||
corpus[corpus_id] = row["page_image"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No image field found in corpus. Available fields: {list(row.keys())}"
|
||||
)
|
||||
|
||||
# Load qrels (relevance judgments)
|
||||
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
|
||||
|
||||
qrels = {}
|
||||
for row in qrels_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
if query_id not in qrels:
|
||||
qrels[query_id] = {}
|
||||
qrels[query_id][corpus_id] = int(row["score"])
|
||||
|
||||
print(
|
||||
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||
)
|
||||
|
||||
# Filter qrels to only include queries that exist
|
||||
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
|
||||
|
||||
# Filter out queries without any relevant documents (matching MTEB behavior)
|
||||
# This is important for correct NDCG calculation
|
||||
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
|
||||
queries_filtered = {
|
||||
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
|
||||
}
|
||||
|
||||
print(
|
||||
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
|
||||
)
|
||||
|
||||
return corpus, queries_filtered, qrels_filtered
|
||||
|
||||
|
||||
def evaluate_task(
|
||||
task_name: str,
|
||||
model_name: str,
|
||||
index_path: str,
|
||||
use_fast_plaid: bool = False,
|
||||
fast_plaid_index_path: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
rebuild_index: bool = False,
|
||||
top_k: int = 100,
|
||||
first_stage_k: int = 500,
|
||||
k_values: Optional[list[int]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Evaluate a single ViDoRe v2 task.
|
||||
"""
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Evaluating task: {task_name}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Get task config
|
||||
if task_name not in VIDORE_V2_TASKS:
|
||||
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")
|
||||
|
||||
task_config = VIDORE_V2_TASKS[task_name]
|
||||
dataset_path = task_config["dataset_path"]
|
||||
revision = task_config["revision"]
|
||||
|
||||
# Determine language
|
||||
if language is None:
|
||||
# Use first language if multiple available
|
||||
languages = task_config.get("languages")
|
||||
if languages is None:
|
||||
# Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval)
|
||||
language = None
|
||||
elif len(languages) == 1:
|
||||
language = languages[0]
|
||||
else:
|
||||
language = None
|
||||
|
||||
# Initialize k_values if not provided
|
||||
if k_values is None:
|
||||
k_values = [1, 3, 5, 10, 100]
|
||||
|
||||
# Load data
|
||||
corpus, queries, qrels = load_vidore_v2_data(
|
||||
dataset_path=dataset_path,
|
||||
revision=revision,
|
||||
split="test",
|
||||
language=language,
|
||||
)
|
||||
|
||||
# Check if we have any queries
|
||||
if len(queries) == 0:
|
||||
print(
|
||||
f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation."
|
||||
)
|
||||
# Return zero scores
|
||||
scores = {}
|
||||
for k in k_values:
|
||||
scores[f"ndcg_at_{k}"] = 0.0
|
||||
scores[f"map_at_{k}"] = 0.0
|
||||
scores[f"recall_at_{k}"] = 0.0
|
||||
scores[f"precision_at_{k}"] = 0.0
|
||||
scores[f"mrr_at_{k}"] = 0.0
|
||||
return scores
|
||||
|
||||
# Initialize evaluator
|
||||
evaluator = ViDoReBenchmarkEvaluator(
|
||||
model_name=model_name,
|
||||
use_fast_plaid=use_fast_plaid,
|
||||
top_k=top_k,
|
||||
first_stage_k=first_stage_k,
|
||||
k_values=k_values,
|
||||
)
|
||||
|
||||
# Build or load index
|
||||
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
||||
if index_path_full is None:
|
||||
index_path_full = f"./indexes/{task_name}_{model_name}"
|
||||
if use_fast_plaid:
|
||||
index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
|
||||
|
||||
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||
corpus=corpus,
|
||||
index_path=index_path_full,
|
||||
rebuild=rebuild_index,
|
||||
)
|
||||
|
||||
# Search queries
|
||||
task_prompt = task_config.get("prompt")
|
||||
results = evaluator.search_queries(
|
||||
queries=queries,
|
||||
corpus_ids=corpus_ids_ordered,
|
||||
index_or_retriever=index_or_retriever,
|
||||
fast_plaid_index_path=fast_plaid_index_path,
|
||||
task_prompt=task_prompt,
|
||||
)
|
||||
|
||||
# Evaluate
|
||||
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||
|
||||
# Print results
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Results for {task_name}:")
|
||||
print(f"{'=' * 80}")
|
||||
for metric, value in scores.items():
|
||||
if isinstance(value, (int, float)):
|
||||
print(f" {metric}: {value:.5f}")
|
||||
|
||||
# Save results
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
results_file = os.path.join(output_dir, f"{task_name}_results.json")
|
||||
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
|
||||
|
||||
with open(results_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print(f"\nSaved results to: {results_file}")
|
||||
|
||||
with open(scores_file, "w") as f:
|
||||
json.dump(scores, f, indent=2)
|
||||
print(f"Saved scores to: {scores_file}")
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Evaluate ViDoRe v2 benchmark using LEANN/Fast-Plaid indexing"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="colqwen2",
|
||||
choices=["colqwen2", "colpali"],
|
||||
help="Model to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specific task to evaluate (or 'all' for all tasks)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tasks",
|
||||
type=str,
|
||||
default="all",
|
||||
help="Tasks to evaluate: 'all' or comma-separated list",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to LEANN index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-fast-plaid",
|
||||
action="store_true",
|
||||
help="Use Fast-Plaid instead of LEANN",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-plaid-index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Fast-Plaid index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rebuild-index",
|
||||
action="store_true",
|
||||
help="Rebuild index even if it exists",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--language",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Language to evaluate (default: first available)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Top-k results to retrieve",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first-stage-k",
|
||||
type=int,
|
||||
default=500,
|
||||
help="First stage k for LEANN search",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--k-values",
|
||||
type=str,
|
||||
default="1,3,5,10,100",
|
||||
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="./vidore_v2_results",
|
||||
help="Output directory for results",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse k_values
|
||||
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
||||
|
||||
# Determine tasks to evaluate
|
||||
if args.task:
|
||||
tasks_to_eval = [args.task]
|
||||
elif args.tasks.lower() == "all":
|
||||
tasks_to_eval = list(VIDORE_V2_TASKS.keys())
|
||||
else:
|
||||
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
|
||||
|
||||
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||
|
||||
# Evaluate each task
|
||||
all_scores = {}
|
||||
for task_name in tasks_to_eval:
|
||||
try:
|
||||
scores = evaluate_task(
|
||||
task_name=task_name,
|
||||
model_name=args.model,
|
||||
index_path=args.index_path,
|
||||
use_fast_plaid=args.use_fast_plaid,
|
||||
fast_plaid_index_path=args.fast_plaid_index_path,
|
||||
language=args.language,
|
||||
rebuild_index=args.rebuild_index,
|
||||
top_k=args.top_k,
|
||||
first_stage_k=args.first_stage_k,
|
||||
k_values=k_values,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
all_scores[task_name] = scores
|
||||
except Exception as e:
|
||||
print(f"\nError evaluating {task_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# Print summary
|
||||
if all_scores:
|
||||
print(f"\n{'=' * 80}")
|
||||
print("SUMMARY")
|
||||
print(f"{'=' * 80}")
|
||||
for task_name, scores in all_scores.items():
|
||||
print(f"\n{task_name}:")
|
||||
# Print main metrics
|
||||
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
|
||||
if metric in scores:
|
||||
print(f" {metric}: {scores[metric]:.5f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -7,6 +7,7 @@ for indexing in LEANN. It supports various Slack MCP server implementations and
|
||||
flexible message processing options.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
@@ -146,16 +147,16 @@ class SlackMCPReader:
|
||||
match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
|
||||
if match:
|
||||
try:
|
||||
error_dict = eval(match.group(1))
|
||||
except (ValueError, SyntaxError, NameError):
|
||||
error_dict = ast.literal_eval(match.group(1))
|
||||
except (ValueError, SyntaxError):
|
||||
pass
|
||||
else:
|
||||
# Try alternative format
|
||||
match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
|
||||
if match:
|
||||
try:
|
||||
error_dict = eval(match.group(1))
|
||||
except (ValueError, SyntaxError, NameError):
|
||||
error_dict = ast.literal_eval(match.group(1))
|
||||
except (ValueError, SyntaxError):
|
||||
pass
|
||||
|
||||
if self._is_cache_sync_error(error_dict):
|
||||
|
||||
200
docs/COLQWEN_GUIDE.md
Normal file
200
docs/COLQWEN_GUIDE.md
Normal file
@@ -0,0 +1,200 @@
|
||||
# ColQwen Integration Guide
|
||||
|
||||
Easy-to-use multimodal PDF retrieval with ColQwen2/ColPali models.
|
||||
|
||||
## Quick Start
|
||||
|
||||
> **🍎 Mac Users**: ColQwen is optimized for Apple Silicon with MPS acceleration for faster inference!
|
||||
|
||||
### 1. Install Dependencies
|
||||
```bash
|
||||
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
|
||||
brew install poppler # macOS only, for PDF processing
|
||||
```
|
||||
|
||||
### 2. Basic Usage
|
||||
```bash
|
||||
# Build index from PDFs
|
||||
python -m apps.colqwen_rag build --pdfs ./my_papers/ --index research_papers
|
||||
|
||||
# Search with text queries
|
||||
python -m apps.colqwen_rag search research_papers "How does attention mechanism work?"
|
||||
|
||||
# Interactive Q&A
|
||||
python -m apps.colqwen_rag ask research_papers --interactive
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
### Build Index
|
||||
```bash
|
||||
python -m apps.colqwen_rag build \
|
||||
--pdfs ./pdf_directory/ \
|
||||
--index my_index \
|
||||
--model colqwen2 \
|
||||
--pages-dir ./page_images/ # Optional: save page images
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- `--pdfs`: Directory containing PDF files (or single PDF path)
|
||||
- `--index`: Name for the index (required)
|
||||
- `--model`: `colqwen2` (default) or `colpali`
|
||||
- `--pages-dir`: Directory to save page images (optional)
|
||||
|
||||
### Search Index
|
||||
```bash
|
||||
python -m apps.colqwen_rag search my_index "your question here" --top-k 5
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- `--top-k`: Number of results to return (default: 5)
|
||||
- `--model`: Model used for search (should match build model)
|
||||
|
||||
### Interactive Q&A
|
||||
```bash
|
||||
python -m apps.colqwen_rag ask my_index --interactive
|
||||
```
|
||||
|
||||
**Commands in interactive mode:**
|
||||
- Type your questions naturally
|
||||
- `help`: Show available commands
|
||||
- `quit`/`exit`/`q`: Exit interactive mode
|
||||
|
||||
## 🧪 Test & Reproduce Results
|
||||
|
||||
Run the reproduction test for issue #119:
|
||||
```bash
|
||||
python test_colqwen_reproduction.py
|
||||
```
|
||||
|
||||
This will:
|
||||
1. ✅ Check dependencies
|
||||
2. 📥 Download sample PDF (Attention Is All You Need paper)
|
||||
3. 🏗️ Build test index
|
||||
4. 🔍 Run sample queries
|
||||
5. 📊 Show how to generate similarity maps
|
||||
|
||||
## 🎨 Advanced: Similarity Maps
|
||||
|
||||
For visual similarity analysis, use the existing advanced script:
|
||||
```bash
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector/
|
||||
python multi-vector-leann-similarity-map.py
|
||||
```
|
||||
|
||||
Edit the script to customize:
|
||||
- `QUERY`: Your question
|
||||
- `MODEL`: "colqwen2" or "colpali"
|
||||
- `USE_HF_DATASET`: Use HuggingFace dataset or local PDFs
|
||||
- `SIMILARITY_MAP`: Generate heatmaps
|
||||
- `ANSWER`: Enable Qwen-VL answer generation
|
||||
|
||||
## 🔧 How It Works
|
||||
|
||||
### ColQwen2 vs ColPali
|
||||
- **ColQwen2** (`vidore/colqwen2-v1.0`): Latest vision-language model
|
||||
- **ColPali** (`vidore/colpali-v1.2`): Proven multimodal retriever
|
||||
|
||||
### Architecture
|
||||
1. **PDF → Images**: Convert PDF pages to images (150 DPI)
|
||||
2. **Vision Encoding**: Process images with ColQwen2/ColPali
|
||||
3. **Multi-Vector Index**: Build LEANN HNSW index with multiple embeddings per page
|
||||
4. **Query Processing**: Encode text queries with the same model
|
||||
5. **Similarity Search**: Find most relevant pages/regions
|
||||
6. **Visual Maps**: Generate attention heatmaps (optional)
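
The six steps above amount to only a few lines of Python. Below is a minimal, illustrative sketch (no index, brute-force MaxSim scoring over a single PDF) that assumes the `colpali_engine` model/processor interfaces and a placeholder `paper.pdf`; the `apps.colqwen_rag` commands above wrap the same flow together with LEANN index building.

```python
# Minimal sketch only: encode PDF pages and a query, then score with MaxSim.
import torch
from pdf2image import convert_from_path
from colpali_engine.models import ColQwen2, ColQwen2Processor

device = "cpu"  # or "cuda"/"mps"; see "Device Support" below

model = ColQwen2.from_pretrained("vidore/colqwen2-v1.0").to(device).eval()
processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")

# Steps 1-2: PDF pages -> images -> multi-vector page embeddings
pages = convert_from_path("paper.pdf", dpi=150)  # "paper.pdf" is a placeholder path
with torch.no_grad():
    page_embs = [model(**processor.process_images([p]).to(device)).squeeze(0) for p in pages]

# Step 4: encode the query with the same model
with torch.no_grad():
    query_emb = model(**processor.process_queries(["How does attention work?"]).to(device)).squeeze(0)

# Step 5: late-interaction (MaxSim) scoring: for each query token, take its
# best-matching page token and sum those maxima to get a page score.
scores = [float((query_emb @ emb.T).max(dim=1).values.sum()) for emb in page_embs]
best = max(range(len(scores)), key=scores.__getitem__)
print(f"Most relevant page: {best + 1} (score {scores[best]:.2f})")
```
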
### Device Support
|
||||
- **CUDA**: Best performance with GPU acceleration
|
||||
- **MPS**: Apple Silicon Mac support
|
||||
- **CPU**: Fallback for any system (slower)
|
||||
|
||||
Auto-detection: CUDA > MPS > CPU
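
A minimal sketch of that preference order (the packaged scripts implement their own detection; this is just for illustration):

```python
import torch

def pick_device() -> str:
    """Return the preferred device: CUDA > MPS > CPU."""
    if torch.cuda.is_available():
        return "cuda"
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```
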
## 📊 Performance Tips
|
||||
|
||||
### For Best Performance:
|
||||
```bash
|
||||
# Use ColQwen2 for latest features
|
||||
--model colqwen2
|
||||
|
||||
# Save page images for reuse
|
||||
--pages-dir ./cached_pages/
|
||||
|
||||
# Adjust batch size based on GPU memory
|
||||
# (automatically handled)
|
||||
```
|
||||
|
||||
### For Large Document Sets:
|
||||
- Process PDFs in batches
|
||||
- Use SSD storage for index files
|
||||
- Consider using CUDA if available
|
||||
|
||||
## 🔗 Related Resources
|
||||
|
||||
- **Fast-PLAID**: https://github.com/lightonai/fast-plaid
|
||||
- **Pylate**: https://github.com/lightonai/pylate
|
||||
- **ColBERT**: https://github.com/stanford-futuredata/ColBERT
|
||||
- **ColPali Paper**: Vision-Language Models for Document Retrieval
|
||||
- **Issue #119**: https://github.com/yichuan-w/LEANN/issues/119
|
||||
|
||||
## 🐛 Troubleshooting
|
||||
|
||||
### PDF Conversion Issues (macOS)
|
||||
```bash
|
||||
# Install poppler
|
||||
brew install poppler
|
||||
which pdfinfo && pdfinfo -v
|
||||
```
|
||||
|
||||
### Memory Issues
|
||||
- Reduce batch size (automatically handled)
|
||||
- Use CPU instead of GPU: `export CUDA_VISIBLE_DEVICES=""`
|
||||
- Process fewer PDFs at once
|
||||
|
||||
### Model Download Issues
|
||||
- Ensure internet connection for first run
|
||||
- Models are cached after first download
|
||||
- Use HuggingFace mirrors if needed
|
||||
|
||||
### Import Errors
|
||||
```bash
|
||||
# Ensure all dependencies installed
|
||||
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
|
||||
|
||||
# Check PyTorch installation
|
||||
python -c "import torch; print(torch.__version__)"
|
||||
```
|
||||
|
||||
## 💡 Examples
|
||||
|
||||
### Research Paper Analysis
|
||||
```bash
|
||||
# Index your research papers
|
||||
python -m apps.colqwen_rag build --pdfs ~/Papers/AI/ --index ai_papers
|
||||
|
||||
# Ask research questions
|
||||
python -m apps.colqwen_rag search ai_papers "What are the limitations of transformer models?"
|
||||
python -m apps.colqwen_rag search ai_papers "How does BERT compare to GPT?"
|
||||
```
|
||||
|
||||
### Document Q&A
|
||||
```bash
|
||||
# Index business documents
|
||||
python -m apps.colqwen_rag build --pdfs ~/Documents/Reports/ --index reports
|
||||
|
||||
# Interactive analysis
|
||||
python -m apps.colqwen_rag ask reports --interactive
|
||||
```
|
||||
|
||||
### Visual Analysis
|
||||
```bash
|
||||
# Generate similarity maps for specific queries
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector/
|
||||
# Edit multi-vector-leann-similarity-map.py with your query
|
||||
python multi-vector-leann-similarity-map.py
|
||||
# Check ./figures/ for generated heatmaps
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
**🎯 This integration makes ColQwen as easy to use as other LEANN features while maintaining the full power of multimodal document understanding!**
|
||||
@@ -158,6 +158,95 @@ builder.build_index("./indexes/my-notes", chunks)
|
||||
|
||||
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
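
For example, a later search session only needs the index path; the provider settings are read back from the saved `meta.json`. A sketch, assuming the `./indexes/my-notes` index built above (parameter names follow the Python API used elsewhere in this changeset):

```python
from leann.api import LeannSearcher

# No embedding_mode / embedding_model / embedding_options here: they are
# restored from the index's meta.json, so queries use the same provider.
searcher = LeannSearcher(index_path="./indexes/my-notes")
results = searcher.search("what did I write about quantum computing?", top_k=5)
for r in results:
    print(f"{r.score:.3f}  {r.text[:80]}")
```
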
## Optional Embedding Features
|
||||
|
||||
### Task-Specific Prompt Templates
|
||||
|
||||
Some embedding models are trained with task-specific prompts to differentiate between documents and queries. The most notable example is **Google's EmbeddingGemma**, which requires different prompts depending on the use case:
|
||||
|
||||
- **Indexing documents**: `"title: none | text: "`
|
||||
- **Search queries**: `"task: search result | query: "`
|
||||
|
||||
LEANN supports automatic prompt prepending via the `--embedding-prompt-template` flag:
|
||||
|
||||
```bash
|
||||
# Build index with EmbeddingGemma (via LM Studio or Ollama)
|
||||
leann build my-docs \
|
||||
--docs ./documents \
|
||||
--embedding-mode openai \
|
||||
--embedding-model text-embedding-embeddinggemma-300m-qat \
|
||||
--embedding-api-base http://localhost:1234/v1 \
|
||||
--embedding-prompt-template "title: none | text: " \
|
||||
--force
|
||||
|
||||
# Search with query-specific prompt
|
||||
leann search my-docs \
|
||||
--query "What is quantum computing?" \
|
||||
--embedding-prompt-template "task: search result | query: "
|
||||
```
|
||||
|
||||
**Important Notes:**
|
||||
- **Only use with compatible models**: EmbeddingGemma and similar task-specific models
|
||||
- **NOT for regular models**: Adding prompts to models like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` will corrupt embeddings
|
||||
- **Template is saved**: Build-time templates are saved to `.meta.json` for reference
|
||||
- **Flexible prompts**: You can use any prompt string, or leave it empty (`""`)
|
||||
|
||||
**Python API:**
|
||||
```python
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
builder = LeannBuilder(
|
||||
embedding_mode="openai",
|
||||
embedding_model="text-embedding-embeddinggemma-300m-qat",
|
||||
embedding_options={
|
||||
"base_url": "http://localhost:1234/v1",
|
||||
"api_key": "lm-studio",
|
||||
"prompt_template": "title: none | text: ",
|
||||
},
|
||||
)
|
||||
builder.build_index("./indexes/my-docs", chunks)
|
||||
```
|
||||
|
||||
**References:**
|
||||
- [HuggingFace Blog: EmbeddingGemma](https://huggingface.co/blog/embeddinggemma) - Technical details
|
||||
|
||||
### LM Studio Auto-Detection (Optional)
|
||||
|
||||
When using LM Studio with the OpenAI-compatible API, LEANN can optionally auto-detect model context lengths via the LM Studio SDK. This eliminates manual configuration for token limits.
|
||||
|
||||
**Prerequisites:**
|
||||
```bash
|
||||
# Install Node.js (if not already installed)
|
||||
# Then install the LM Studio SDK globally
|
||||
npm install -g @lmstudio/sdk
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
1. LEANN detects LM Studio URLs (`:1234`, `lmstudio` in URL)
|
||||
2. Queries model metadata via Node.js subprocess
|
||||
3. Automatically unloads model after query (respects your JIT auto-evict settings)
|
||||
4. Falls back to static registry if SDK unavailable
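
Step 1 boils down to a small URL heuristic, sketched here after the detection logic in the embedding code: the default LM Studio port or an "lmstudio" host marks the server, and the HTTP endpoint is rewritten to the WebSocket form the SDK expects.

```python
from typing import Optional

def lmstudio_ws_url(base_url: str) -> Optional[str]:
    """Return a WebSocket URL for the LM Studio SDK, or None if this is not LM Studio."""
    url = base_url.lower()
    if "1234" not in url and "lmstudio" not in url and "lm.studio" not in url:
        return None
    ws = base_url.replace("https://", "wss://").replace("http://", "ws://")
    return ws[:-3] if ws.endswith("/v1") else ws  # drop the /v1 REST suffix

print(lmstudio_ws_url("http://localhost:1234/v1"))  # -> ws://localhost:1234
```
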
**No configuration needed** - it works automatically when the SDK is installed:
|
||||
|
||||
```bash
|
||||
leann build my-docs \
|
||||
--docs ./documents \
|
||||
--embedding-mode openai \
|
||||
--embedding-model text-embedding-nomic-embed-text-v1.5 \
|
||||
--embedding-api-base http://localhost:1234/v1
|
||||
# Context length auto-detected if SDK available
|
||||
# Falls back to registry (2048) if not
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- ✅ Automatic token limit detection
|
||||
- ✅ Respects LM Studio JIT auto-evict settings
|
||||
- ✅ No manual registry maintenance
|
||||
- ✅ Graceful fallback if SDK unavailable
|
||||
|
||||
**Note:** This is completely optional. LEANN works perfectly fine without the SDK, using the built-in token limit registry.
|
||||
|
||||
## Index Selection: Matching Your Scale
|
||||
|
||||
### HNSW (Hierarchical Navigable Small World)
|
||||
@@ -365,7 +454,7 @@ leann search my-index "your query" \
|
||||
|
||||
### 2) Run remote builds with SkyPilot (cloud GPU)
|
||||
|
||||
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
|
||||
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://docs.skypilot.co/en/latest/docs/index.html). A template is provided at `sky/leann-build.yaml`.
|
||||
|
||||
```bash
|
||||
# One-time: install and configure SkyPilot
|
||||
|
||||
48
docs/faq.md
48
docs/faq.md
@@ -8,3 +8,51 @@ You can speed up the process by using a lightweight embedding model. Add this to
|
||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||
```
|
||||
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
||||
|
||||
## 2. When should I use prompt templates?
|
||||
|
||||
**Use prompt templates ONLY with task-specific embedding models** like Google's EmbeddingGemma. These models are specially trained to use different prompts for documents vs queries.
|
||||
|
||||
**DO NOT use with regular models** like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` - adding prompts to these models will corrupt the embeddings.
|
||||
|
||||
**Example usage with EmbeddingGemma:**
|
||||
```bash
|
||||
# Build with document prompt
|
||||
leann build my-docs --embedding-prompt-template "title: none | text: "
|
||||
|
||||
# Search with query prompt
|
||||
leann search my-docs --query "your question" --embedding-prompt-template "task: search result | query: "
|
||||
```
|
||||
|
||||
See the [Configuration Guide: Task-Specific Prompt Templates](configuration-guide.md#task-specific-prompt-templates) for detailed usage.
|
||||
|
||||
## 3. Why is LM Studio loading multiple copies of my model?
|
||||
|
||||
This was fixed in recent versions. LEANN now properly unloads models after querying metadata, respecting your LM Studio JIT auto-evict settings.
|
||||
|
||||
**If you still see duplicates:**
|
||||
- Update to the latest LEANN version
|
||||
- Restart LM Studio to clear loaded models
|
||||
- Check that you have JIT auto-evict enabled in LM Studio settings
|
||||
|
||||
**How it works now:**
|
||||
1. LEANN loads model temporarily to get context length
|
||||
2. Immediately unloads after query
|
||||
3. LM Studio JIT loads model on-demand for actual embeddings
|
||||
4. Auto-evicts per your settings
|
||||
|
||||
## 4. Do I need Node.js and @lmstudio/sdk?
|
||||
|
||||
**No, it's completely optional.** LEANN works perfectly fine without them, using a built-in token limit registry.
|
||||
|
||||
**Benefits if you install it:**
|
||||
- Automatic context length detection for LM Studio models
|
||||
- No manual registry maintenance
|
||||
- Always gets accurate token limits from the model itself
|
||||
|
||||
**To install (optional):**
|
||||
```bash
|
||||
npm install -g @lmstudio/sdk
|
||||
```
|
||||
|
||||
See [Configuration Guide: LM Studio Auto-Detection](configuration-guide.md#lm-studio-auto-detection-optional) for details.
|
||||
|
||||
@@ -916,6 +916,7 @@ class LeannSearcher:
|
||||
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
||||
batch_size: int = 0,
|
||||
use_grep: bool = False,
|
||||
provider_options: Optional[dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> list[SearchResult]:
|
||||
"""
|
||||
@@ -979,10 +980,24 @@ class LeannSearcher:
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Extract query template from stored embedding_options with fallback chain:
|
||||
# 1. Check provider_options override (highest priority)
|
||||
# 2. Check query_prompt_template (new format)
|
||||
# 3. Check prompt_template (old format for backward compat)
|
||||
# 4. None (no template)
|
||||
query_template = None
|
||||
if provider_options and "prompt_template" in provider_options:
|
||||
query_template = provider_options["prompt_template"]
|
||||
elif "query_prompt_template" in self.embedding_options:
|
||||
query_template = self.embedding_options["query_prompt_template"]
|
||||
elif "prompt_template" in self.embedding_options:
|
||||
query_template = self.embedding_options["prompt_template"]
|
||||
|
||||
query_embedding = self.backend_impl.compute_query_embedding(
|
||||
query,
|
||||
use_server_if_available=recompute_embeddings,
|
||||
zmq_port=zmq_port,
|
||||
query_template=query_template,
|
||||
)
|
||||
logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
||||
embedding_time = time.time() - start_time
|
||||
@@ -1236,15 +1251,15 @@ class LeannChat:
|
||||
"Please provide the best answer you can based on this context and your knowledge."
|
||||
)
|
||||
|
||||
print("The context provided to the LLM is:")
|
||||
print(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
|
||||
print("-" * 150)
|
||||
logger.info("The context provided to the LLM is:")
|
||||
logger.info(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
|
||||
logger.info("-" * 150)
|
||||
for r in results:
|
||||
chunk_relevance = f"{r.score:.3f}"
|
||||
chunk_id = r.id
|
||||
chunk_content = r.text[:60]
|
||||
chunk_source = r.metadata.get("source", "")[:80]
|
||||
print(
|
||||
logger.info(
|
||||
f"{chunk_relevance:<10} | {chunk_id:<10} | {chunk_content:<60} | {chunk_source:<80}"
|
||||
)
|
||||
ask_time = time.time()
|
||||
|
||||
@@ -12,7 +12,13 @@ from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
from .settings import (
|
||||
resolve_anthropic_api_key,
|
||||
resolve_anthropic_base_url,
|
||||
resolve_ollama_host,
|
||||
resolve_openai_api_key,
|
||||
resolve_openai_base_url,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -845,6 +851,81 @@ class OpenAIChat(LLMInterface):
|
||||
return f"Error: Could not get a response from OpenAI. Details: {e}"
|
||||
|
||||
|
||||
class AnthropicChat(LLMInterface):
|
||||
"""LLM interface for Anthropic Claude models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "claude-haiku-4-5",
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
):
|
||||
self.model = model
|
||||
self.base_url = resolve_anthropic_base_url(base_url)
|
||||
self.api_key = resolve_anthropic_api_key(api_key)
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"Anthropic API key is required. Set ANTHROPIC_API_KEY environment variable or pass api_key parameter."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Initializing Anthropic Chat with model='%s' and base_url='%s'",
|
||||
model,
|
||||
self.base_url,
|
||||
)
|
||||
|
||||
try:
|
||||
import anthropic
|
||||
|
||||
# Allow custom Anthropic-compatible endpoints via base_url
|
||||
self.client = anthropic.Anthropic(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'anthropic' library is required for Anthropic models. Please install it with 'pip install anthropic'."
|
||||
)
|
||||
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
logger.info(f"Sending request to Anthropic with model {self.model}")
|
||||
|
||||
try:
|
||||
# Anthropic API parameters
|
||||
params = {
|
||||
"model": self.model,
|
||||
"max_tokens": kwargs.get("max_tokens", 1000),
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if "temperature" in kwargs:
|
||||
params["temperature"] = kwargs["temperature"]
|
||||
if "top_p" in kwargs:
|
||||
params["top_p"] = kwargs["top_p"]
|
||||
|
||||
response = self.client.messages.create(**params)
|
||||
|
||||
# Extract text from response
|
||||
response_text = response.content[0].text
|
||||
|
||||
# Log token usage
|
||||
print(
|
||||
f"Total tokens = {response.usage.input_tokens + response.usage.output_tokens}, "
|
||||
f"input tokens = {response.usage.input_tokens}, "
|
||||
f"output tokens = {response.usage.output_tokens}"
|
||||
)
|
||||
|
||||
if response.stop_reason == "max_tokens":
|
||||
print("The query is exceeding the maximum allowed number of tokens")
|
||||
|
||||
return response_text.strip()
|
||||
except Exception as e:
|
||||
logger.error(f"Error communicating with Anthropic: {e}")
|
||||
return f"Error: Could not get a response from Anthropic. Details: {e}"
|
||||
|
||||
|
||||
class SimulatedChat(LLMInterface):
|
||||
"""A simple simulated chat for testing and development."""
|
||||
|
||||
@@ -897,6 +978,12 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
|
||||
)
|
||||
elif llm_type == "gemini":
|
||||
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
|
||||
elif llm_type == "anthropic":
|
||||
return AnthropicChat(
|
||||
model=model or "claude-3-5-sonnet-20241022",
|
||||
api_key=llm_config.get("api_key"),
|
||||
base_url=llm_config.get("base_url"),
|
||||
)
|
||||
elif llm_type == "simulated":
|
||||
return SimulatedChat()
|
||||
else:
|
||||
|
||||
@@ -11,7 +11,12 @@ from tqdm import tqdm
|
||||
from .api import LeannBuilder, LeannChat, LeannSearcher
|
||||
from .interactive_utils import create_cli_session
|
||||
from .registry import register_project_directory
|
||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
from .settings import (
|
||||
resolve_anthropic_base_url,
|
||||
resolve_ollama_host,
|
||||
resolve_openai_api_key,
|
||||
resolve_openai_base_url,
|
||||
)
|
||||
|
||||
|
||||
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
|
||||
@@ -144,6 +149,18 @@ Examples:
|
||||
default=None,
|
||||
help="API key for embedding service (defaults to OPENAI_API_KEY)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--embedding-prompt-template",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Prompt template to prepend to all texts for embedding (e.g., 'query: ' for search)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--query-prompt-template",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Prompt template for queries (different from build template for task-specific models)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--force", "-f", action="store_true", help="Force rebuild existing index"
|
||||
)
|
||||
@@ -260,6 +277,12 @@ Examples:
|
||||
action="store_true",
|
||||
help="Display file paths and metadata in search results",
|
||||
)
|
||||
search_parser.add_argument(
|
||||
"--embedding-prompt-template",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Prompt template to prepend to query for embedding (e.g., 'query: ' for search)",
|
||||
)
|
||||
|
||||
# Ask command
|
||||
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
||||
@@ -273,7 +296,7 @@ Examples:
|
||||
"--llm",
|
||||
type=str,
|
||||
default="ollama",
|
||||
choices=["simulated", "ollama", "hf", "openai"],
|
||||
choices=["simulated", "ollama", "hf", "openai", "anthropic"],
|
||||
help="LLM provider (default: ollama)",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
@@ -323,7 +346,7 @@ Examples:
|
||||
"--api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
|
||||
help="API key for cloud LLM providers (OpenAI, Anthropic)",
|
||||
)
|
||||
|
||||
# List command
|
||||
@@ -1162,6 +1185,11 @@ Examples:
|
||||
print(f"Warning: Could not process {file_path}: {e}")
|
||||
|
||||
# Load other file types with default reader
|
||||
# Exclude PDFs from code_extensions if they were already processed separately
|
||||
other_file_extensions = code_extensions
|
||||
if should_process_pdfs and ".pdf" in code_extensions:
|
||||
other_file_extensions = [ext for ext in code_extensions if ext != ".pdf"]
|
||||
|
||||
try:
|
||||
# Create a custom file filter function using our PathSpec
|
||||
def file_filter(
|
||||
@@ -1177,15 +1205,19 @@ Examples:
|
||||
except (ValueError, OSError):
|
||||
return True # Include files that can't be processed
|
||||
|
||||
other_docs = SimpleDirectoryReader(
|
||||
docs_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=code_extensions,
|
||||
file_extractor={}, # Use default extractors
|
||||
exclude_hidden=not include_hidden,
|
||||
filename_as_id=True,
|
||||
).load_data(show_progress=True)
|
||||
# Only load other file types if there are extensions to process
|
||||
if other_file_extensions:
|
||||
other_docs = SimpleDirectoryReader(
|
||||
docs_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=other_file_extensions,
|
||||
file_extractor={}, # Use default extractors
|
||||
exclude_hidden=not include_hidden,
|
||||
filename_as_id=True,
|
||||
).load_data(show_progress=True)
|
||||
else:
|
||||
other_docs = []
|
||||
|
||||
# Filter documents after loading based on gitignore rules
|
||||
filtered_docs = []
|
||||
@@ -1398,6 +1430,14 @@ Examples:
|
||||
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
|
||||
if resolved_embedding_key:
|
||||
embedding_options["api_key"] = resolved_embedding_key
|
||||
if args.query_prompt_template:
|
||||
# New format: separate templates
|
||||
if args.embedding_prompt_template:
|
||||
embedding_options["build_prompt_template"] = args.embedding_prompt_template
|
||||
embedding_options["query_prompt_template"] = args.query_prompt_template
|
||||
elif args.embedding_prompt_template:
|
||||
# Old format: single template (backward compat)
|
||||
embedding_options["prompt_template"] = args.embedding_prompt_template
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend_name,
|
||||
@@ -1519,6 +1559,11 @@ Examples:
|
||||
print("Invalid input. Aborting search.")
|
||||
return
|
||||
|
||||
# Build provider_options for runtime override
|
||||
provider_options = {}
|
||||
if args.embedding_prompt_template:
|
||||
provider_options["prompt_template"] = args.embedding_prompt_template
|
||||
|
||||
searcher = LeannSearcher(index_path=index_path)
|
||||
results = searcher.search(
|
||||
query,
|
||||
@@ -1528,6 +1573,7 @@ Examples:
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
provider_options=provider_options if provider_options else None,
|
||||
)
|
||||
|
||||
print(f"Search results for '{query}' (top {len(results)}):")
|
||||
@@ -1575,6 +1621,12 @@ Examples:
|
||||
resolved_api_key = resolve_openai_api_key(args.api_key)
|
||||
if resolved_api_key:
|
||||
llm_config["api_key"] = resolved_api_key
|
||||
elif args.llm == "anthropic":
|
||||
# For Anthropic, pass base_url and API key if provided
|
||||
if args.api_base:
|
||||
llm_config["base_url"] = resolve_anthropic_base_url(args.api_base)
|
||||
if args.api_key:
|
||||
llm_config["api_key"] = args.api_key
|
||||
|
||||
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
||||
|
||||
|
||||
@@ -4,8 +4,10 @@ Consolidates all embedding computation logic using SentenceTransformer
|
||||
Preserves all optimization parameters to ensure performance
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -40,6 +42,11 @@ EMBEDDING_MODEL_LIMITS = {
|
||||
"text-embedding-ada-002": 8192,
|
||||
}
|
||||
|
||||
# Runtime cache for dynamically discovered token limits
|
||||
# Key: (model_name, base_url), Value: token_limit
|
||||
# Prevents repeated SDK/API calls for the same model
|
||||
_token_limit_cache: dict[tuple[str, str], int] = {}
|
||||
|
||||
|
||||
def get_model_token_limit(
|
||||
model_name: str,
|
||||
@@ -49,6 +56,7 @@ def get_model_token_limit(
|
||||
"""
|
||||
Get token limit for a given embedding model.
|
||||
Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others.
|
||||
Caches discovered limits to prevent repeated API/SDK calls.
|
||||
|
||||
Args:
|
||||
model_name: Name of the embedding model
|
||||
@@ -58,12 +66,33 @@ def get_model_token_limit(
|
||||
Returns:
|
||||
Token limit for the model in tokens
|
||||
"""
|
||||
# Check cache first to avoid repeated SDK/API calls
|
||||
cache_key = (model_name, base_url or "")
|
||||
if cache_key in _token_limit_cache:
|
||||
cached_limit = _token_limit_cache[cache_key]
|
||||
logger.debug(f"Using cached token limit for {model_name}: {cached_limit}")
|
||||
return cached_limit
|
||||
|
||||
# Try Ollama dynamic discovery if base_url provided
|
||||
if base_url:
|
||||
# Detect Ollama servers by port or "ollama" in URL
|
||||
if "11434" in base_url or "ollama" in base_url.lower():
|
||||
limit = _query_ollama_context_limit(model_name, base_url)
|
||||
if limit:
|
||||
_token_limit_cache[cache_key] = limit
|
||||
return limit
|
||||
|
||||
# Try LM Studio SDK discovery
|
||||
if "1234" in base_url or "lmstudio" in base_url.lower() or "lm.studio" in base_url.lower():
|
||||
# Convert HTTP to WebSocket URL
|
||||
ws_url = base_url.replace("https://", "wss://").replace("http://", "ws://")
|
||||
# Remove /v1 suffix if present
|
||||
if ws_url.endswith("/v1"):
|
||||
ws_url = ws_url[:-3]
|
||||
|
||||
limit = _query_lmstudio_context_limit(model_name, ws_url)
|
||||
if limit:
|
||||
_token_limit_cache[cache_key] = limit
|
||||
return limit
|
||||
|
||||
# Fallback to known model registry with version handling (from PR #154)
|
||||
@@ -72,19 +101,25 @@ def get_model_token_limit(
|
||||
|
||||
# Check exact match first
|
||||
if model_name in EMBEDDING_MODEL_LIMITS:
|
||||
return EMBEDDING_MODEL_LIMITS[model_name]
|
||||
limit = EMBEDDING_MODEL_LIMITS[model_name]
|
||||
_token_limit_cache[cache_key] = limit
|
||||
return limit
|
||||
|
||||
# Check base name match
|
||||
if base_model_name in EMBEDDING_MODEL_LIMITS:
|
||||
return EMBEDDING_MODEL_LIMITS[base_model_name]
|
||||
limit = EMBEDDING_MODEL_LIMITS[base_model_name]
|
||||
_token_limit_cache[cache_key] = limit
|
||||
return limit
|
||||
|
||||
# Check partial matches for common patterns
|
||||
for known_model, limit in EMBEDDING_MODEL_LIMITS.items():
|
||||
for known_model, registry_limit in EMBEDDING_MODEL_LIMITS.items():
|
||||
if known_model in base_model_name or base_model_name in known_model:
|
||||
return limit
|
||||
_token_limit_cache[cache_key] = registry_limit
|
||||
return registry_limit
|
||||
|
||||
# Default fallback
|
||||
logger.warning(f"Unknown model '{model_name}', using default {default} token limit")
|
||||
_token_limit_cache[cache_key] = default
|
||||
return default
|
||||
|
||||
|
||||
@@ -185,6 +220,91 @@ def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int]
|
||||
return None
|
||||
|
||||
|
||||
def _query_lmstudio_context_limit(model_name: str, base_url: str) -> Optional[int]:
|
||||
"""
|
||||
Query LM Studio SDK for model context length via Node.js subprocess.
|
||||
|
||||
Args:
|
||||
model_name: Name of the LM Studio model
|
||||
base_url: Base URL of the LM Studio server (WebSocket format, e.g., "ws://localhost:1234")
|
||||
|
||||
Returns:
|
||||
Context limit in tokens if found, None otherwise
|
||||
"""
|
||||
# Inline JavaScript using @lmstudio/sdk
|
||||
# Note: Load model temporarily for metadata, then unload to respect JIT auto-evict
|
||||
js_code = f"""
|
||||
const {{ LMStudioClient }} = require('@lmstudio/sdk');
|
||||
(async () => {{
|
||||
try {{
|
||||
const client = new LMStudioClient({{ baseUrl: '{base_url}' }});
|
||||
const model = await client.embedding.load('{model_name}', {{ verbose: false }});
|
||||
const contextLength = await model.getContextLength();
|
||||
await model.unload(); // Unload immediately to respect JIT auto-evict settings
|
||||
console.log(JSON.stringify({{ contextLength, identifier: '{model_name}' }}));
|
||||
}} catch (error) {{
|
||||
console.error(JSON.stringify({{ error: error.message }}));
|
||||
process.exit(1);
|
||||
}}
|
||||
}})();
|
||||
"""
|
||||
|
||||
try:
|
||||
# Set NODE_PATH to include global modules for @lmstudio/sdk resolution
|
||||
env = os.environ.copy()
|
||||
|
||||
# Try to get npm global root (works with nvm, brew node, etc.)
|
||||
try:
|
||||
npm_root = subprocess.run(
|
||||
["npm", "root", "-g"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
if npm_root.returncode == 0:
|
||||
global_modules = npm_root.stdout.strip()
|
||||
# Append to existing NODE_PATH if present
|
||||
existing_node_path = env.get("NODE_PATH", "")
|
||||
env["NODE_PATH"] = (
|
||||
f"{global_modules}:{existing_node_path}"
|
||||
if existing_node_path
|
||||
else global_modules
|
||||
)
|
||||
except Exception:
|
||||
# If npm not available, continue with existing NODE_PATH
|
||||
pass
|
||||
|
||||
result = subprocess.run(
|
||||
["node", "-e", js_code],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
env=env,
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
logger.debug(f"LM Studio SDK error: {result.stderr}")
|
||||
return None
|
||||
|
||||
data = json.loads(result.stdout)
|
||||
context_length = data.get("contextLength")
|
||||
|
||||
if context_length and context_length > 0:
|
||||
logger.info(f"LM Studio SDK detected {model_name} context length: {context_length}")
|
||||
return context_length
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.debug("Node.js not found - install Node.js for LM Studio SDK features")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.debug("LM Studio SDK query timeout")
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("LM Studio SDK returned invalid JSON")
|
||||
except Exception as e:
|
||||
logger.debug(f"LM Studio SDK query failed: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Global model cache to avoid repeated loading
|
||||
_model_cache: dict[str, Any] = {}
|
||||
|
||||
@@ -232,6 +352,7 @@ def compute_embeddings(
|
||||
model_name,
|
||||
base_url=provider_options.get("base_url"),
|
||||
api_key=provider_options.get("api_key"),
|
||||
provider_options=provider_options,
|
||||
)
|
||||
elif mode == "mlx":
|
||||
return compute_embeddings_mlx(texts, model_name)
|
||||
@@ -241,6 +362,7 @@ def compute_embeddings(
|
||||
model_name,
|
||||
is_build=is_build,
|
||||
host=provider_options.get("host"),
|
||||
provider_options=provider_options,
|
||||
)
|
||||
elif mode == "gemini":
|
||||
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
|
||||
@@ -579,6 +701,7 @@ def compute_embeddings_openai(
|
||||
model_name: str,
|
||||
base_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
provider_options: Optional[dict[str, Any]] = None,
|
||||
) -> np.ndarray:
|
||||
# TODO: @yichuan-w add progress bar only in build mode
|
||||
"""Compute embeddings using OpenAI API"""
|
||||
@@ -597,26 +720,40 @@ def compute_embeddings_openai(
|
||||
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
|
||||
)
|
||||
|
||||
resolved_base_url = resolve_openai_base_url(base_url)
|
||||
resolved_api_key = resolve_openai_api_key(api_key)
|
||||
# Extract base_url and api_key from provider_options if not provided directly
|
||||
provider_options = provider_options or {}
|
||||
effective_base_url = base_url or provider_options.get("base_url")
|
||||
effective_api_key = api_key or provider_options.get("api_key")
|
||||
|
||||
resolved_base_url = resolve_openai_base_url(effective_base_url)
|
||||
resolved_api_key = resolve_openai_api_key(effective_api_key)
|
||||
|
||||
if not resolved_api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
# Cache OpenAI client
|
||||
cache_key = f"openai_client::{resolved_base_url}"
|
||||
if cache_key in _model_cache:
|
||||
client = _model_cache[cache_key]
|
||||
else:
|
||||
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
|
||||
_model_cache[cache_key] = client
|
||||
logger.info("OpenAI client cached")
|
||||
# Create OpenAI client
|
||||
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
|
||||
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
|
||||
)
|
||||
print(f"len of texts: {len(texts)}")
|
||||
|
||||
# Apply prompt template if provided
|
||||
# Priority: build_prompt_template (new format) > prompt_template (old format)
|
||||
prompt_template = provider_options.get("build_prompt_template") or provider_options.get(
|
||||
"prompt_template"
|
||||
)
|
||||
|
||||
if prompt_template:
|
||||
logger.warning(f"Applying prompt template: '{prompt_template}'")
|
||||
texts = [f"{prompt_template}{text}" for text in texts]
|
||||
|
||||
# Query token limit and apply truncation
|
||||
token_limit = get_model_token_limit(model_name, base_url=effective_base_url)
|
||||
logger.info(f"Using token limit: {token_limit} for model '{model_name}'")
|
||||
texts = truncate_to_token_limit(texts, token_limit)
|
||||
|
||||
# OpenAI has limits on batch size and input length
|
||||
max_batch_size = 800 # Conservative batch size because the token limit is 300K
|
||||
all_embeddings = []
|
||||
@@ -647,7 +784,15 @@ def compute_embeddings_openai(
|
||||
try:
|
||||
response = client.embeddings.create(model=model_name, input=batch_texts)
|
||||
batch_embeddings = [embedding.embedding for embedding in response.data]
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
|
||||
# Verify we got the expected number of embeddings
|
||||
if len(batch_embeddings) != len(batch_texts):
|
||||
logger.warning(
|
||||
f"Expected {len(batch_texts)} embeddings but got {len(batch_embeddings)}"
|
||||
)
|
||||
|
||||
# Only take the number of embeddings that match the batch size
|
||||
all_embeddings.extend(batch_embeddings[: len(batch_texts)])
|
||||
except Exception as e:
|
||||
logger.error(f"Batch {i} failed: {e}")
|
||||
raise
|
||||
@@ -737,6 +882,7 @@ def compute_embeddings_ollama(
    model_name: str,
    is_build: bool = False,
    host: Optional[str] = None,
    provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray:
    """
    Compute embeddings using Ollama API with true batch processing.
@@ -749,6 +895,7 @@ def compute_embeddings_ollama(
        model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
        is_build: Whether this is a build operation (shows progress bar)
        host: Ollama host URL (defaults to environment or http://localhost:11434)
        provider_options: Optional provider-specific options (e.g., prompt_template)

    Returns:
        Normalized embeddings array, shape: (len(texts), embedding_dim)
@@ -885,6 +1032,17 @@ def compute_embeddings_ollama(

    logger.info(f"Using batch size: {batch_size} for true batch processing")

    # Apply prompt template if provided
    provider_options = provider_options or {}
    # Priority: build_prompt_template (new format) > prompt_template (old format)
    prompt_template = provider_options.get("build_prompt_template") or provider_options.get(
        "prompt_template"
    )

    if prompt_template:
        logger.warning(f"Applying prompt template: '{prompt_template}'")
        texts = [f"{prompt_template}{text}" for text in texts]

    # Get model token limit and apply truncation before batching
    token_limit = get_model_token_limit(model_name, base_url=resolved_host)
    logger.info(f"Model '{model_name}' token limit: {token_limit}")
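A usage sketch of the extended Ollama path. The `texts=` keyword and positional order are assumed from the matching OpenAI helper; the model name, host, and template are examples, and regular embedding models should be called without a template.

```python
# Illustrative call, assuming a local Ollama instance with nomic-embed-text pulled.
from leann.embedding_compute import compute_embeddings_ollama

embeddings = compute_embeddings_ollama(
    texts=["LEANN stores graph-based vector indexes compactly."],
    model_name="nomic-embed-text",
    host="http://localhost:11434",
    provider_options={"prompt_template": "search_document: "},
)
print(embeddings.shape)  # (1, embedding_dim)
```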
@@ -77,6 +77,7 @@ class LeannBackendSearcherInterface(ABC):
        query: str,
        use_server_if_available: bool = True,
        zmq_port: Optional[int] = None,
        query_template: Optional[str] = None,
    ) -> np.ndarray:
        """Compute embedding for a query string

@@ -84,6 +85,7 @@ class LeannBackendSearcherInterface(ABC):
            query: The query string to embed
            zmq_port: ZMQ port for embedding server
            use_server_if_available: Whether to try using embedding server first
            query_template: Optional prompt template to prepend to query

        Returns:
            Query embedding as numpy array with shape (1, D)

@@ -33,6 +33,8 @@ def autodiscover_backends():
    discovered_backends = []
    for dist in importlib.metadata.distributions():
        dist_name = dist.metadata["name"]
        if dist_name is None:
            continue
        if dist_name.startswith("leann-backend-"):
            backend_module_name = dist_name.replace("-", "_")
            discovered_backends.append(backend_module_name)
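For reference, the discovery rule in the hunk above can be exercised on its own: installed distributions named `leann-backend-*` map to importable module names by replacing dashes with underscores. This standalone snippet mirrors that logic; the printed names depend on which backend packages are installed locally.

```python
# Standalone illustration of the autodiscovery rule shown above.
import importlib.metadata

discovered = []
for dist in importlib.metadata.distributions():
    name = dist.metadata["name"]
    if name and name.startswith("leann-backend-"):
        discovered.append(name.replace("-", "_"))

print(discovered)  # e.g. ["leann_backend_hnsw", "leann_backend_diskann"] if installed
```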
@@ -71,6 +71,15 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
            or "mips"
        )

        # Filter out ALL prompt templates from provider_options during search
        # Templates are applied in compute_query_embedding (line 109-110) BEFORE server call
        # The server should never apply templates during search to avoid double-templating
        search_provider_options = {
            k: v
            for k, v in self.embedding_options.items()
            if k not in ("build_prompt_template", "query_prompt_template", "prompt_template")
        }

        server_started, actual_port = self.embedding_server_manager.start_server(
            port=port,
            model_name=self.embedding_model,
@@ -78,7 +87,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
            passages_file=passages_source_file,
            distance_metric=distance_metric,
            enable_warmup=kwargs.get("enable_warmup", False),
            provider_options=self.embedding_options,
            provider_options=search_provider_options,
        )
        if not server_started:
            raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
@@ -90,6 +99,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
        query: str,
        use_server_if_available: bool = True,
        zmq_port: int = 5557,
        query_template: Optional[str] = None,
    ) -> np.ndarray:
        """
        Compute embedding for a query string.
@@ -98,10 +108,16 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
            query: The query string to embed
            zmq_port: ZMQ port for embedding server
            use_server_if_available: Whether to try using embedding server first
            query_template: Optional prompt template to prepend to query

        Returns:
            Query embedding as numpy array
        """
        # Apply query template BEFORE any computation path
        # This ensures template is applied consistently for both server and fallback paths
        if query_template:
            query = f"{query_template}{query}"

        # Try to use embedding server if available and requested
        if use_server_if_available:
            try:
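The filtering above is what prevents double-templating: the query is templated exactly once in `compute_query_embedding`, so the embedding server must not see any template keys. A standalone illustration of the same dict comprehension, with example option values:

```python
# Standalone illustration of the provider_options filtering shown above.
embedding_options = {
    "build_prompt_template": "search_document: ",
    "query_prompt_template": "search_query: ",
    "base_url": "http://localhost:1234/v1",
}
search_provider_options = {
    k: v
    for k, v in embedding_options.items()
    if k not in ("build_prompt_template", "query_prompt_template", "prompt_template")
}
print(search_provider_options)  # {'base_url': 'http://localhost:1234/v1'}
```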
@@ -9,6 +9,7 @@ from typing import Any
# Default fallbacks to preserve current behaviour while keeping them in one place.
_DEFAULT_OLLAMA_HOST = "http://localhost:11434"
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
_DEFAULT_ANTHROPIC_BASE_URL = "https://api.anthropic.com"


def _clean_url(value: str) -> str:
@@ -52,6 +53,23 @@ def resolve_openai_base_url(explicit: str | None = None) -> str:
    return _clean_url(_DEFAULT_OPENAI_BASE_URL)


def resolve_anthropic_base_url(explicit: str | None = None) -> str:
    """Resolve the base URL for Anthropic-compatible services."""

    candidates = (
        explicit,
        os.getenv("LEANN_ANTHROPIC_BASE_URL"),
        os.getenv("ANTHROPIC_BASE_URL"),
        os.getenv("LOCAL_ANTHROPIC_BASE_URL"),
    )

    for candidate in candidates:
        if candidate:
            return _clean_url(candidate)

    return _clean_url(_DEFAULT_ANTHROPIC_BASE_URL)


def resolve_openai_api_key(explicit: str | None = None) -> str | None:
    """Resolve the API key for OpenAI-compatible services."""

@@ -61,6 +79,15 @@ def resolve_openai_api_key(explicit: str | None = None) -> str | None:
    return os.getenv("OPENAI_API_KEY")


def resolve_anthropic_api_key(explicit: str | None = None) -> str | None:
    """Resolve the API key for Anthropic services."""

    if explicit:
        return explicit

    return os.getenv("ANTHROPIC_API_KEY")


def encode_provider_options(options: dict[str, Any] | None) -> str | None:
    """Serialize provider options for child processes."""

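A usage sketch of the resolution order added above: an explicit argument wins, then the `LEANN_ANTHROPIC_BASE_URL`, `ANTHROPIC_BASE_URL`, and `LOCAL_ANTHROPIC_BASE_URL` environment variables in that order, then the default. The `leann.settings` import path is assumed here, since the diff does not show the file name.

```python
# Illustrative only; module path assumed.
import os

from leann.settings import resolve_anthropic_base_url

os.environ["ANTHROPIC_BASE_URL"] = "http://localhost:8081"
print(resolve_anthropic_base_url())                          # env var is used
print(resolve_anthropic_base_url("https://proxy.internal"))  # explicit value wins
```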
@@ -53,6 +53,11 @@ leann build my-project --docs $(git ls-files)
# Start Claude Code
claude
```
**Performance tip**: For maximum speed when storage space is not a concern, add the `--no-recompute` flag to your build command. This materializes all tensors and stores them on disk, avoiding recomputation on subsequent builds:

```bash
leann build my-project --docs $(git ls-files) --no-recompute
```

## 🚀 Advanced Usage Examples to build the index


@@ -69,7 +69,8 @@ diskann = [
# Add a new optional dependency group for document processing
documents = [
    "beautifulsoup4>=4.13.0",  # For HTML parsing
    "python-docx>=0.8.11",  # For Word documents
    "python-docx>=0.8.11",  # For Word documents (creating/editing)
    "docx2txt>=0.9",  # For Word documents (text extraction)
    "openpyxl>=3.1.0",  # For Excel files
    "pandas>=2.2.0",  # For data processing
]
@@ -164,6 +165,7 @@ python_functions = ["test_*"]
markers = [
    "slow: marks tests as slow (deselect with '-m \"not slow\"')",
    "openai: marks tests that require OpenAI API key",
    "integration: marks tests that require live services (Ollama, LM Studio, etc.)",
]
timeout = 300  # Reduced from 600s (10min) to 300s (5min) for CI safety
addopts = [

@@ -36,6 +36,14 @@ Tests DiskANN graph partitioning functionality:
- Includes performance comparison between DiskANN (with partition) and HNSW
- **Note**: These tests are skipped in CI due to hardware requirements and computation time

### `test_prompt_template_e2e.py`
Integration tests for the prompt template feature with live embedding services:
- Tests prompt template prepending with EmbeddingGemma (OpenAI-compatible API via LM Studio)
- Tests hybrid token limit discovery (Ollama dynamic detection, registry fallback, default)
- Tests LM Studio SDK bridge for automatic context length detection (requires Node.js + @lmstudio/sdk)
- **Note**: These tests require live services (LM Studio, Ollama) and are marked with `@pytest.mark.integration`
- **Important**: Prompt templates are ONLY for EmbeddingGemma and similar task-specific models, NOT regular embedding models

## Running Tests

### Install test dependencies:
@@ -66,6 +74,12 @@ pytest tests/ -m "not openai"
# Skip slow tests
pytest tests/ -m "not slow"

# Skip integration tests (that require live services)
pytest tests/ -m "not integration"

# Run only integration tests (requires LM Studio or Ollama running)
pytest tests/test_prompt_template_e2e.py -v -s

# Run DiskANN partition tests (requires local machine, not CI)
pytest tests/test_diskann_partition.py
```
@@ -101,6 +115,20 @@ The `pytest.ini` file configures:
- Custom markers for slow and OpenAI tests
- Verbose output with short tracebacks

### Integration Test Prerequisites

Integration tests (`test_prompt_template_e2e.py`) require live services:

**Required:**
- LM Studio running at `http://localhost:1234` with an EmbeddingGemma model loaded

**Optional:**
- Ollama running at `http://localhost:11434` for token limit detection tests
- Node.js + @lmstudio/sdk installed (`npm install -g @lmstudio/sdk`) for SDK bridge tests

Tests gracefully skip if services are unavailable.

### Known Issues

- OpenAI tests are automatically skipped if no API key is provided
- Integration tests require live embedding services and may fail due to proxy settings (run `unset ALL_PROXY all_proxy` if needed)

533
tests/test_cli_prompt_template.py
Normal file
@@ -0,0 +1,533 @@
|
||||
"""
|
||||
Tests for CLI argument integration of --embedding-prompt-template.
|
||||
|
||||
These tests verify that:
|
||||
1. The --embedding-prompt-template flag is properly registered on build and search commands
|
||||
2. The template value flows from CLI args to embedding_options dict
|
||||
3. The template is passed through to compute_embeddings() function
|
||||
4. Default behavior (no flag) is handled correctly
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from leann.cli import LeannCLI
|
||||
|
||||
|
||||
class TestCLIPromptTemplateArgument:
|
||||
"""Tests for --embedding-prompt-template on build and search commands."""
|
||||
|
||||
def test_commands_accept_prompt_template_argument(self):
|
||||
"""Verify that build and search parsers accept --embedding-prompt-template flag."""
|
||||
cli = LeannCLI()
|
||||
parser = cli.create_parser()
|
||||
template_value = "search_query: "
|
||||
|
||||
# Test build command
|
||||
build_args = parser.parse_args(
|
||||
[
|
||||
"build",
|
||||
"test-index",
|
||||
"--docs",
|
||||
"/tmp/test-docs",
|
||||
"--embedding-prompt-template",
|
||||
template_value,
|
||||
]
|
||||
)
|
||||
assert build_args.command == "build"
|
||||
assert hasattr(build_args, "embedding_prompt_template"), (
|
||||
"build command should have embedding_prompt_template attribute"
|
||||
)
|
||||
assert build_args.embedding_prompt_template == template_value
|
||||
|
||||
# Test search command
|
||||
search_args = parser.parse_args(
|
||||
["search", "test-index", "my query", "--embedding-prompt-template", template_value]
|
||||
)
|
||||
assert search_args.command == "search"
|
||||
assert hasattr(search_args, "embedding_prompt_template"), (
|
||||
"search command should have embedding_prompt_template attribute"
|
||||
)
|
||||
assert search_args.embedding_prompt_template == template_value
|
||||
|
||||
def test_commands_default_to_none(self):
|
||||
"""Verify default value is None when flag not provided (backward compatibility)."""
|
||||
cli = LeannCLI()
|
||||
parser = cli.create_parser()
|
||||
|
||||
# Test build command default
|
||||
build_args = parser.parse_args(["build", "test-index", "--docs", "/tmp/test-docs"])
|
||||
assert hasattr(build_args, "embedding_prompt_template"), (
|
||||
"build command should have embedding_prompt_template attribute"
|
||||
)
|
||||
assert build_args.embedding_prompt_template is None, (
|
||||
"Build default value should be None when flag not provided"
|
||||
)
|
||||
|
||||
# Test search command default
|
||||
search_args = parser.parse_args(["search", "test-index", "my query"])
|
||||
assert hasattr(search_args, "embedding_prompt_template"), (
|
||||
"search command should have embedding_prompt_template attribute"
|
||||
)
|
||||
assert search_args.embedding_prompt_template is None, (
|
||||
"Search default value should be None when flag not provided"
|
||||
)
|
||||
|
||||
|
||||
class TestBuildCommandPromptTemplateArgumentExtras:
|
||||
"""Additional build-specific tests for prompt template argument."""
|
||||
|
||||
def test_build_command_prompt_template_with_multiword_value(self):
|
||||
"""
|
||||
Verify that template values with spaces are handled correctly.
|
||||
|
||||
Templates like "search_document: " or "Represent this sentence for searching: "
|
||||
should be accepted as a single string argument.
|
||||
"""
|
||||
cli = LeannCLI()
|
||||
parser = cli.create_parser()
|
||||
|
||||
template = "Represent this sentence for searching: "
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"build",
|
||||
"test-index",
|
||||
"--docs",
|
||||
"/tmp/test-docs",
|
||||
"--embedding-prompt-template",
|
||||
template,
|
||||
]
|
||||
)
|
||||
|
||||
assert args.embedding_prompt_template == template
|
||||
|
||||
|
||||
class TestPromptTemplateStoredInEmbeddingOptions:
|
||||
"""Tests for template storage in embedding_options dict."""
|
||||
|
||||
@patch("leann.cli.LeannBuilder")
|
||||
def test_prompt_template_stored_in_embedding_options_on_build(
|
||||
self, mock_builder_class, tmp_path
|
||||
):
|
||||
"""
|
||||
Verify that when --embedding-prompt-template is provided to build command,
|
||||
the value is stored in embedding_options dict passed to LeannBuilder.
|
||||
|
||||
This test will fail because the CLI doesn't currently process this argument
|
||||
and add it to embedding_options.
|
||||
"""
|
||||
# Setup mocks
|
||||
mock_builder = Mock()
|
||||
mock_builder_class.return_value = mock_builder
|
||||
|
||||
# Create CLI and run build command
|
||||
cli = LeannCLI()
|
||||
|
||||
# Mock load_documents to return a document so builder is created
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||
|
||||
parser = cli.create_parser()
|
||||
|
||||
template = "search_query: "
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"build",
|
||||
"test-index",
|
||||
"--docs",
|
||||
str(tmp_path),
|
||||
"--embedding-prompt-template",
|
||||
template,
|
||||
"--force", # Force rebuild to ensure LeannBuilder is called
|
||||
]
|
||||
)
|
||||
|
||||
# Run the build command
|
||||
import asyncio
|
||||
|
||||
asyncio.run(cli.build_index(args))
|
||||
|
||||
# Check that LeannBuilder was called with embedding_options containing prompt_template
|
||||
call_kwargs = mock_builder_class.call_args.kwargs
|
||||
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
|
||||
|
||||
embedding_options = call_kwargs["embedding_options"]
|
||||
assert embedding_options is not None, (
|
||||
"embedding_options should not be None when template provided"
|
||||
)
|
||||
assert "prompt_template" in embedding_options, (
|
||||
"embedding_options should contain 'prompt_template' key"
|
||||
)
|
||||
assert embedding_options["prompt_template"] == template, (
|
||||
f"Template should be '{template}', got {embedding_options.get('prompt_template')}"
|
||||
)
|
||||
|
||||
@patch("leann.cli.LeannBuilder")
|
||||
def test_prompt_template_not_in_options_when_not_provided(self, mock_builder_class, tmp_path):
|
||||
"""
|
||||
Verify that when --embedding-prompt-template is NOT provided,
|
||||
embedding_options either doesn't have the key or it's None.
|
||||
|
||||
This ensures we don't pass empty/None values unnecessarily.
|
||||
"""
|
||||
# Setup mocks
|
||||
mock_builder = Mock()
|
||||
mock_builder_class.return_value = mock_builder
|
||||
|
||||
cli = LeannCLI()
|
||||
|
||||
# Mock load_documents to return a document so builder is created
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||
|
||||
parser = cli.create_parser()
|
||||
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"build",
|
||||
"test-index",
|
||||
"--docs",
|
||||
str(tmp_path),
|
||||
"--force", # Force rebuild to ensure LeannBuilder is called
|
||||
]
|
||||
)
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(cli.build_index(args))
|
||||
|
||||
# Check that if embedding_options is passed, it doesn't have prompt_template
|
||||
call_kwargs = mock_builder_class.call_args.kwargs
|
||||
if call_kwargs.get("embedding_options"):
|
||||
embedding_options = call_kwargs["embedding_options"]
|
||||
# Either the key shouldn't exist, or it should be None
|
||||
assert (
|
||||
"prompt_template" not in embedding_options
|
||||
or embedding_options["prompt_template"] is None
|
||||
), "prompt_template should not be set when flag not provided"
|
||||
|
||||
# R1 Tests: Build-time separate template storage
|
||||
@patch("leann.cli.LeannBuilder")
|
||||
def test_build_stores_separate_templates(self, mock_builder_class, tmp_path):
|
||||
"""
|
||||
R1 Test 1: Verify that when both --embedding-prompt-template and
|
||||
--query-prompt-template are provided to build command, both values
|
||||
are stored separately in embedding_options dict as build_prompt_template
|
||||
and query_prompt_template.
|
||||
|
||||
This test will fail because:
|
||||
1. CLI doesn't accept --query-prompt-template flag yet
|
||||
2. CLI doesn't store templates as separate build_prompt_template and
|
||||
query_prompt_template keys
|
||||
|
||||
Expected behavior after implementation:
|
||||
- .meta.json contains: {"embedding_options": {
|
||||
"build_prompt_template": "doc: ",
|
||||
"query_prompt_template": "query: "
|
||||
}}
|
||||
"""
|
||||
# Setup mocks
|
||||
mock_builder = Mock()
|
||||
mock_builder_class.return_value = mock_builder
|
||||
|
||||
cli = LeannCLI()
|
||||
|
||||
# Mock load_documents to return a document so builder is created
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||
|
||||
parser = cli.create_parser()
|
||||
|
||||
build_template = "doc: "
|
||||
query_template = "query: "
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"build",
|
||||
"test-index",
|
||||
"--docs",
|
||||
str(tmp_path),
|
||||
"--embedding-prompt-template",
|
||||
build_template,
|
||||
"--query-prompt-template",
|
||||
query_template,
|
||||
"--force",
|
||||
]
|
||||
)
|
||||
|
||||
# Run the build command
|
||||
import asyncio
|
||||
|
||||
asyncio.run(cli.build_index(args))
|
||||
|
||||
# Check that LeannBuilder was called with separate template keys
|
||||
call_kwargs = mock_builder_class.call_args.kwargs
|
||||
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
|
||||
|
||||
embedding_options = call_kwargs["embedding_options"]
|
||||
assert embedding_options is not None, (
|
||||
"embedding_options should not be None when templates provided"
|
||||
)
|
||||
|
||||
assert "build_prompt_template" in embedding_options, (
|
||||
"embedding_options should contain 'build_prompt_template' key"
|
||||
)
|
||||
assert embedding_options["build_prompt_template"] == build_template, (
|
||||
f"build_prompt_template should be '{build_template}'"
|
||||
)
|
||||
|
||||
assert "query_prompt_template" in embedding_options, (
|
||||
"embedding_options should contain 'query_prompt_template' key"
|
||||
)
|
||||
assert embedding_options["query_prompt_template"] == query_template, (
|
||||
f"query_prompt_template should be '{query_template}'"
|
||||
)
|
||||
|
||||
# Old key should NOT be present when using new separate template format
|
||||
assert "prompt_template" not in embedding_options, (
|
||||
"Old 'prompt_template' key should not be present with separate templates"
|
||||
)
|
||||
|
||||
@patch("leann.cli.LeannBuilder")
|
||||
def test_build_backward_compat_single_template(self, mock_builder_class, tmp_path):
|
||||
"""
|
||||
R1 Test 2: Verify backward compatibility - when only
|
||||
--embedding-prompt-template is provided (old behavior), it should
|
||||
still be stored as 'prompt_template' in embedding_options.
|
||||
|
||||
This ensures existing workflows continue to work unchanged.
|
||||
|
||||
This test currently passes because it matches existing behavior, but it
|
||||
documents the requirement that this behavior must be preserved after
|
||||
implementing the separate template feature.
|
||||
|
||||
Expected behavior:
|
||||
- .meta.json contains: {"embedding_options": {"prompt_template": "prompt: "}}
|
||||
- No build_prompt_template or query_prompt_template keys
|
||||
"""
|
||||
# Setup mocks
|
||||
mock_builder = Mock()
|
||||
mock_builder_class.return_value = mock_builder
|
||||
|
||||
cli = LeannCLI()
|
||||
|
||||
# Mock load_documents to return a document so builder is created
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||
|
||||
parser = cli.create_parser()
|
||||
|
||||
template = "prompt: "
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"build",
|
||||
"test-index",
|
||||
"--docs",
|
||||
str(tmp_path),
|
||||
"--embedding-prompt-template",
|
||||
template,
|
||||
"--force",
|
||||
]
|
||||
)
|
||||
|
||||
# Run the build command
|
||||
import asyncio
|
||||
|
||||
asyncio.run(cli.build_index(args))
|
||||
|
||||
# Check that LeannBuilder was called with old format
|
||||
call_kwargs = mock_builder_class.call_args.kwargs
|
||||
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
|
||||
|
||||
embedding_options = call_kwargs["embedding_options"]
|
||||
assert embedding_options is not None, (
|
||||
"embedding_options should not be None when template provided"
|
||||
)
|
||||
|
||||
assert "prompt_template" in embedding_options, (
|
||||
"embedding_options should contain old 'prompt_template' key for backward compat"
|
||||
)
|
||||
assert embedding_options["prompt_template"] == template, (
|
||||
f"prompt_template should be '{template}'"
|
||||
)
|
||||
|
||||
# New keys should NOT be present in backward compat mode
|
||||
assert "build_prompt_template" not in embedding_options, (
|
||||
"build_prompt_template should not be present with single template flag"
|
||||
)
|
||||
assert "query_prompt_template" not in embedding_options, (
|
||||
"query_prompt_template should not be present with single template flag"
|
||||
)
|
||||
|
||||
@patch("leann.cli.LeannBuilder")
|
||||
def test_build_no_templates(self, mock_builder_class, tmp_path):
|
||||
"""
|
||||
R1 Test 3: Verify that when no template flags are provided,
|
||||
embedding_options has no prompt template keys.
|
||||
|
||||
This ensures clean defaults and no unnecessary keys in .meta.json.
|
||||
|
||||
This test currently passes because it matches existing behavior, but it
|
||||
documents the requirement that this behavior must be preserved after
|
||||
implementing the separate template feature.
|
||||
|
||||
Expected behavior:
|
||||
- .meta.json has no prompt_template, build_prompt_template, or
|
||||
query_prompt_template keys (or embedding_options is empty/None)
|
||||
"""
|
||||
# Setup mocks
|
||||
mock_builder = Mock()
|
||||
mock_builder_class.return_value = mock_builder
|
||||
|
||||
cli = LeannCLI()
|
||||
|
||||
# Mock load_documents to return a document so builder is created
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||
|
||||
parser = cli.create_parser()
|
||||
|
||||
args = parser.parse_args(["build", "test-index", "--docs", str(tmp_path), "--force"])
|
||||
|
||||
# Run the build command
|
||||
import asyncio
|
||||
|
||||
asyncio.run(cli.build_index(args))
|
||||
|
||||
# Check that no template keys are present
|
||||
call_kwargs = mock_builder_class.call_args.kwargs
|
||||
if call_kwargs.get("embedding_options"):
|
||||
embedding_options = call_kwargs["embedding_options"]
|
||||
|
||||
# None of the template keys should be present
|
||||
assert "prompt_template" not in embedding_options, (
|
||||
"prompt_template should not be present when no flags provided"
|
||||
)
|
||||
assert "build_prompt_template" not in embedding_options, (
|
||||
"build_prompt_template should not be present when no flags provided"
|
||||
)
|
||||
assert "query_prompt_template" not in embedding_options, (
|
||||
"query_prompt_template should not be present when no flags provided"
|
||||
)
|
||||
|
||||
|
||||
class TestPromptTemplateFlowsToComputeEmbeddings:
|
||||
"""Tests for template flowing through to compute_embeddings function."""
|
||||
|
||||
@patch("leann.api.compute_embeddings")
|
||||
def test_prompt_template_flows_to_compute_embeddings_via_provider_options(
|
||||
self, mock_compute_embeddings, tmp_path
|
||||
):
|
||||
"""
|
||||
Verify that the prompt template flows from CLI args through LeannBuilder
|
||||
to compute_embeddings() function via provider_options parameter.
|
||||
|
||||
This is an integration test that verifies the complete flow:
|
||||
CLI → embedding_options → LeannBuilder → compute_embeddings(provider_options)
|
||||
|
||||
This test will fail because:
|
||||
1. CLI doesn't capture the argument yet
|
||||
2. embedding_options doesn't include prompt_template
|
||||
3. LeannBuilder doesn't pass it through to compute_embeddings
|
||||
"""
|
||||
# Mock compute_embeddings to return dummy embeddings as numpy array
|
||||
import numpy as np
|
||||
|
||||
mock_compute_embeddings.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||
|
||||
# Use real LeannBuilder (not mocked) to test the actual flow
|
||||
cli = LeannCLI()
|
||||
|
||||
# Mock load_documents to return a simple document
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||
|
||||
parser = cli.create_parser()
|
||||
|
||||
template = "search_document: "
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"build",
|
||||
"test-index",
|
||||
"--docs",
|
||||
str(tmp_path),
|
||||
"--embedding-prompt-template",
|
||||
template,
|
||||
"--backend-name",
|
||||
"hnsw", # Use hnsw backend
|
||||
"--force", # Force rebuild to ensure index is created
|
||||
]
|
||||
)
|
||||
|
||||
# This should fail because the flow isn't implemented yet
|
||||
import asyncio
|
||||
|
||||
asyncio.run(cli.build_index(args))
|
||||
|
||||
# Verify compute_embeddings was called with provider_options containing prompt_template
|
||||
assert mock_compute_embeddings.called, "compute_embeddings should have been called"
|
||||
|
||||
# Check the call arguments
|
||||
call_kwargs = mock_compute_embeddings.call_args.kwargs
|
||||
assert "provider_options" in call_kwargs, (
|
||||
"compute_embeddings should receive provider_options parameter"
|
||||
)
|
||||
|
||||
provider_options = call_kwargs["provider_options"]
|
||||
assert provider_options is not None, "provider_options should not be None"
|
||||
assert "prompt_template" in provider_options, (
|
||||
"provider_options should contain prompt_template key"
|
||||
)
|
||||
assert provider_options["prompt_template"] == template, (
|
||||
f"Template should be '{template}', got {provider_options.get('prompt_template')}"
|
||||
)
|
||||
|
||||
|
||||
class TestPromptTemplateArgumentHelp:
|
||||
"""Tests for argument help text and documentation."""
|
||||
|
||||
def test_build_command_prompt_template_has_help_text(self):
|
||||
"""
|
||||
Verify that --embedding-prompt-template has descriptive help text.
|
||||
|
||||
Good help text is crucial for CLI usability.
|
||||
"""
|
||||
cli = LeannCLI()
|
||||
parser = cli.create_parser()
|
||||
|
||||
# Get the build subparser
|
||||
# This is a bit tricky - we need to parse to get the help
|
||||
# We'll check that the help includes relevant keywords
|
||||
import io
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
f = io.StringIO()
|
||||
try:
|
||||
with redirect_stdout(f):
|
||||
parser.parse_args(["build", "--help"])
|
||||
except SystemExit:
|
||||
pass # --help causes sys.exit(0)
|
||||
|
||||
help_text = f.getvalue()
|
||||
assert "--embedding-prompt-template" in help_text, (
|
||||
"Help text should mention --embedding-prompt-template"
|
||||
)
|
||||
# Check for keywords that should be in the help
|
||||
help_lower = help_text.lower()
|
||||
assert any(keyword in help_lower for keyword in ["template", "prompt", "prepend"]), (
|
||||
"Help text should explain what the prompt template does"
|
||||
)
|
||||
|
||||
def test_search_command_prompt_template_has_help_text(self):
|
||||
"""
|
||||
Verify that search command also has help text for --embedding-prompt-template.
|
||||
"""
|
||||
cli = LeannCLI()
|
||||
parser = cli.create_parser()
|
||||
|
||||
import io
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
f = io.StringIO()
|
||||
try:
|
||||
with redirect_stdout(f):
|
||||
parser.parse_args(["search", "--help"])
|
||||
except SystemExit:
|
||||
pass # --help causes sys.exit(0)
|
||||
|
||||
help_text = f.getvalue()
|
||||
assert "--embedding-prompt-template" in help_text, (
|
||||
"Search help text should mention --embedding-prompt-template"
|
||||
)
|
||||
281
tests/test_embedding_prompt_template.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""Unit tests for prompt template prepending in OpenAI embeddings.
|
||||
|
||||
This test suite defines the contract for prompt template functionality that allows
|
||||
users to prepend a consistent prompt to all embedding inputs. These tests verify:
|
||||
|
||||
1. Template prepending to all input texts before embedding computation
|
||||
2. Graceful handling of None/missing provider_options
|
||||
3. Empty string template behavior (no-op)
|
||||
4. Logging of template application for observability
|
||||
5. Template application before token truncation
|
||||
|
||||
All tests are written in Red Phase - they should FAIL initially because the
|
||||
implementation does not exist yet.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from leann.embedding_compute import compute_embeddings_openai
|
||||
|
||||
|
||||
class TestPromptTemplatePrepending:
|
||||
"""Tests for prompt template prepending in compute_embeddings_openai."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_client(self):
|
||||
"""Create mock OpenAI client that captures input texts."""
|
||||
mock_client = MagicMock()
|
||||
|
||||
# Mock the embeddings.create response
|
||||
mock_response = Mock()
|
||||
mock_response.data = [
|
||||
Mock(embedding=[0.1, 0.2, 0.3]),
|
||||
Mock(embedding=[0.4, 0.5, 0.6]),
|
||||
]
|
||||
mock_client.embeddings.create.return_value = mock_response
|
||||
|
||||
return mock_client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_module(self, mock_openai_client, monkeypatch):
|
||||
"""Mock the openai module to return our mock client."""
|
||||
# Mock the API key environment variable
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "fake-test-key-for-mocking")
|
||||
|
||||
# openai is imported inside the function, so we need to patch it there
|
||||
with patch("openai.OpenAI", return_value=mock_openai_client) as mock_openai:
|
||||
yield mock_openai
|
||||
|
||||
def test_prompt_template_prepended_to_all_texts(self, mock_openai_module, mock_openai_client):
|
||||
"""Verify template is prepended to all input texts.
|
||||
|
||||
When provider_options contains "prompt_template", that template should
|
||||
be prepended to every text in the input list before sending to OpenAI API.
|
||||
|
||||
This is the core functionality: the template acts as a consistent prefix
|
||||
that provides context or instruction for the embedding model.
|
||||
"""
|
||||
texts = ["First document", "Second document"]
|
||||
template = "search_document: "
|
||||
provider_options = {"prompt_template": template}
|
||||
|
||||
# Call compute_embeddings_openai with provider_options
|
||||
result = compute_embeddings_openai(
|
||||
texts=texts,
|
||||
model_name="text-embedding-3-small",
|
||||
provider_options=provider_options,
|
||||
)
|
||||
|
||||
# Verify embeddings.create was called with templated texts
|
||||
mock_openai_client.embeddings.create.assert_called_once()
|
||||
call_args = mock_openai_client.embeddings.create.call_args
|
||||
|
||||
# Extract the input texts sent to API
|
||||
sent_texts = call_args.kwargs["input"]
|
||||
|
||||
# Verify template was prepended to all texts
|
||||
assert len(sent_texts) == 2, "Should send same number of texts"
|
||||
assert sent_texts[0] == "search_document: First document", (
|
||||
"Template should be prepended to first text"
|
||||
)
|
||||
assert sent_texts[1] == "search_document: Second document", (
|
||||
"Template should be prepended to second text"
|
||||
)
|
||||
|
||||
# Verify result is valid embeddings array
|
||||
assert isinstance(result, np.ndarray)
|
||||
assert result.shape == (2, 3), "Should return correct shape"
|
||||
|
||||
def test_template_not_applied_when_missing_or_empty(
|
||||
self, mock_openai_module, mock_openai_client
|
||||
):
|
||||
"""Verify template not applied when provider_options is None, missing key, or empty string.
|
||||
|
||||
This consolidated test covers three scenarios where templates should NOT be applied:
|
||||
1. provider_options is None (default behavior)
|
||||
2. provider_options exists but missing 'prompt_template' key
|
||||
3. prompt_template is explicitly set to empty string ""
|
||||
|
||||
In all cases, texts should be sent to the API unchanged.
|
||||
"""
|
||||
# Scenario 1: None provider_options
|
||||
texts = ["Original text one", "Original text two"]
|
||||
result = compute_embeddings_openai(
|
||||
texts=texts,
|
||||
model_name="text-embedding-3-small",
|
||||
provider_options=None,
|
||||
)
|
||||
call_args = mock_openai_client.embeddings.create.call_args
|
||||
sent_texts = call_args.kwargs["input"]
|
||||
assert sent_texts[0] == "Original text one", (
|
||||
"Text should be unchanged with None provider_options"
|
||||
)
|
||||
assert sent_texts[1] == "Original text two"
|
||||
assert isinstance(result, np.ndarray)
|
||||
assert result.shape == (2, 3)
|
||||
|
||||
# Reset mock for next scenario
|
||||
mock_openai_client.reset_mock()
|
||||
mock_response = Mock()
|
||||
mock_response.data = [
|
||||
Mock(embedding=[0.1, 0.2, 0.3]),
|
||||
Mock(embedding=[0.4, 0.5, 0.6]),
|
||||
]
|
||||
mock_openai_client.embeddings.create.return_value = mock_response
|
||||
|
||||
# Scenario 2: Missing 'prompt_template' key
|
||||
texts = ["Text without template", "Another text"]
|
||||
provider_options = {"base_url": "https://api.openai.com/v1"}
|
||||
result = compute_embeddings_openai(
|
||||
texts=texts,
|
||||
model_name="text-embedding-3-small",
|
||||
provider_options=provider_options,
|
||||
)
|
||||
call_args = mock_openai_client.embeddings.create.call_args
|
||||
sent_texts = call_args.kwargs["input"]
|
||||
assert sent_texts[0] == "Text without template", "Text should be unchanged with missing key"
|
||||
assert sent_texts[1] == "Another text"
|
||||
assert isinstance(result, np.ndarray)
|
||||
|
||||
# Reset mock for next scenario
|
||||
mock_openai_client.reset_mock()
|
||||
mock_openai_client.embeddings.create.return_value = mock_response
|
||||
|
||||
# Scenario 3: Empty string template
|
||||
texts = ["Text one", "Text two"]
|
||||
provider_options = {"prompt_template": ""}
|
||||
result = compute_embeddings_openai(
|
||||
texts=texts,
|
||||
model_name="text-embedding-3-small",
|
||||
provider_options=provider_options,
|
||||
)
|
||||
call_args = mock_openai_client.embeddings.create.call_args
|
||||
sent_texts = call_args.kwargs["input"]
|
||||
assert sent_texts[0] == "Text one", "Empty template should not modify text"
|
||||
assert sent_texts[1] == "Text two"
|
||||
assert isinstance(result, np.ndarray)
|
||||
|
||||
def test_prompt_template_with_multiple_batches(self, mock_openai_module, mock_openai_client):
|
||||
"""Verify template is prepended in all batches when texts exceed batch size.
|
||||
|
||||
OpenAI API has batch size limits. When input texts are split into
|
||||
multiple batches, the template should be prepended to texts in every batch.
|
||||
|
||||
This ensures consistency across all API calls.
|
||||
"""
|
||||
# Create many texts that will be split into multiple batches
|
||||
texts = [f"Document {i}" for i in range(1000)]
|
||||
template = "passage: "
|
||||
provider_options = {"prompt_template": template}
|
||||
|
||||
# Mock multiple batch responses
|
||||
mock_response = Mock()
|
||||
mock_response.data = [Mock(embedding=[0.1, 0.2, 0.3]) for _ in range(1000)]
|
||||
mock_openai_client.embeddings.create.return_value = mock_response
|
||||
|
||||
result = compute_embeddings_openai(
|
||||
texts=texts,
|
||||
model_name="text-embedding-3-small",
|
||||
provider_options=provider_options,
|
||||
)
|
||||
|
||||
# Verify embeddings.create was called multiple times (batching)
|
||||
assert mock_openai_client.embeddings.create.call_count >= 2, (
|
||||
"Should make multiple API calls for large text list"
|
||||
)
|
||||
|
||||
# Verify template was prepended in ALL batches
|
||||
for call in mock_openai_client.embeddings.create.call_args_list:
|
||||
sent_texts = call.kwargs["input"]
|
||||
for text in sent_texts:
|
||||
assert text.startswith(template), (
|
||||
f"All texts in all batches should start with template. Got: {text}"
|
||||
)
|
||||
|
||||
# Verify result shape
|
||||
assert result.shape[0] == 1000, "Should return embeddings for all texts"
|
||||
|
||||
def test_prompt_template_with_special_characters(self, mock_openai_module, mock_openai_client):
|
||||
"""Verify template with special characters is handled correctly.
|
||||
|
||||
Templates may contain special characters, Unicode, newlines, etc.
|
||||
These should all be prepended correctly without encoding issues.
|
||||
"""
|
||||
texts = ["Document content"]
|
||||
# Template with various special characters
|
||||
template = "🔍 Search query [EN]: "
|
||||
provider_options = {"prompt_template": template}
|
||||
|
||||
result = compute_embeddings_openai(
|
||||
texts=texts,
|
||||
model_name="text-embedding-3-small",
|
||||
provider_options=provider_options,
|
||||
)
|
||||
|
||||
# Verify special characters in template were preserved
|
||||
call_args = mock_openai_client.embeddings.create.call_args
|
||||
sent_texts = call_args.kwargs["input"]
|
||||
|
||||
assert sent_texts[0] == "🔍 Search query [EN]: Document content", (
|
||||
"Special characters in template should be preserved"
|
||||
)
|
||||
|
||||
assert isinstance(result, np.ndarray)
|
||||
|
||||
def test_prompt_template_integration_with_existing_validation(
|
||||
self, mock_openai_module, mock_openai_client
|
||||
):
|
||||
"""Verify template works with existing input validation.
|
||||
|
||||
compute_embeddings_openai has validation for empty texts and whitespace.
|
||||
Template prepending should happen AFTER validation, so validation errors
|
||||
are thrown based on original texts, not templated texts.
|
||||
|
||||
This ensures users get clear error messages about their input.
|
||||
"""
|
||||
# Empty text should still raise ValueError even with template
|
||||
texts = [""]
|
||||
provider_options = {"prompt_template": "prefix: "}
|
||||
|
||||
with pytest.raises(ValueError, match="empty/invalid"):
|
||||
compute_embeddings_openai(
|
||||
texts=texts,
|
||||
model_name="text-embedding-3-small",
|
||||
provider_options=provider_options,
|
||||
)
|
||||
|
||||
def test_prompt_template_with_api_key_and_base_url(
|
||||
self, mock_openai_module, mock_openai_client
|
||||
):
|
||||
"""Verify template works alongside other provider_options.
|
||||
|
||||
provider_options may contain multiple settings: prompt_template,
|
||||
base_url, api_key. All should work together correctly.
|
||||
"""
|
||||
texts = ["Test document"]
|
||||
provider_options = {
|
||||
"prompt_template": "embed: ",
|
||||
"base_url": "https://custom.api.com/v1",
|
||||
"api_key": "test-key-123",
|
||||
}
|
||||
|
||||
result = compute_embeddings_openai(
|
||||
texts=texts,
|
||||
model_name="text-embedding-3-small",
|
||||
provider_options=provider_options,
|
||||
)
|
||||
|
||||
# Verify template was applied
|
||||
call_args = mock_openai_client.embeddings.create.call_args
|
||||
sent_texts = call_args.kwargs["input"]
|
||||
assert sent_texts[0] == "embed: Test document"
|
||||
|
||||
# Verify OpenAI client was created with correct base_url
|
||||
mock_openai_module.assert_called()
|
||||
client_init_kwargs = mock_openai_module.call_args.kwargs
|
||||
assert client_init_kwargs["base_url"] == "https://custom.api.com/v1"
|
||||
assert client_init_kwargs["api_key"] == "test-key-123"
|
||||
|
||||
assert isinstance(result, np.ndarray)
|
||||
315
tests/test_lmstudio_bridge.py
Normal file
@@ -0,0 +1,315 @@
|
||||
"""Unit tests for LM Studio TypeScript SDK bridge functionality.
|
||||
|
||||
This test suite defines the contract for the LM Studio SDK bridge that queries
|
||||
model context length via Node.js subprocess. These tests verify:
|
||||
|
||||
1. Successful SDK query returns context length
|
||||
2. Graceful fallback when Node.js not installed (FileNotFoundError)
|
||||
3. Graceful fallback when SDK not installed (npm error)
|
||||
4. Timeout handling (subprocess.TimeoutExpired)
|
||||
5. Invalid JSON response handling
|
||||
|
||||
All tests are written in Red Phase - they should FAIL initially because the
|
||||
`_query_lmstudio_context_limit` function does not exist yet.
|
||||
|
||||
The function contract:
|
||||
- Inputs: model_name (str), base_url (str, WebSocket format "ws://localhost:1234")
|
||||
- Outputs: context_length (int) or None on error
|
||||
- Requirements:
|
||||
1. Call Node.js with inline JavaScript using @lmstudio/sdk
|
||||
2. 10-second timeout (accounts for Node.js startup)
|
||||
3. Graceful fallback on any error (returns None, doesn't raise)
|
||||
4. Parse JSON response with contextLength field
|
||||
5. Log errors at debug level (not warning/error)
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
# Try to import the function - if it doesn't exist, tests will fail as expected
|
||||
try:
|
||||
from leann.embedding_compute import _query_lmstudio_context_limit
|
||||
except ImportError:
|
||||
# Function doesn't exist yet (Red Phase) - create a placeholder that will fail
|
||||
def _query_lmstudio_context_limit(*args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"_query_lmstudio_context_limit not implemented yet - this is the Red Phase"
|
||||
)
|
||||
|
||||
|
||||
class TestLMStudioBridge:
|
||||
"""Tests for LM Studio TypeScript SDK bridge integration."""
|
||||
|
||||
def test_query_lmstudio_success(self, monkeypatch):
|
||||
"""Verify successful SDK query returns context length.
|
||||
|
||||
When the Node.js subprocess successfully queries the LM Studio SDK,
|
||||
it should return a JSON response with contextLength field. The function
|
||||
should parse this and return the integer context length.
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
# Verify timeout is set to 10 seconds
|
||||
assert kwargs.get("timeout") == 10, "Should use 10-second timeout for Node.js startup"
|
||||
|
||||
# Verify capture_output and text=True are set
|
||||
assert kwargs.get("capture_output") is True, "Should capture stdout/stderr"
|
||||
assert kwargs.get("text") is True, "Should decode output as text"
|
||||
|
||||
# Return successful JSON response
|
||||
mock_result = Mock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = '{"contextLength": 8192, "identifier": "custom-model"}'
|
||||
mock_result.stderr = ""
|
||||
return mock_result
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
# Test with typical LM Studio model
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="custom-model", base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
assert limit == 8192, "Should return context length from SDK response"
|
||||
|
||||
def test_query_lmstudio_nodejs_not_found(self, monkeypatch):
|
||||
"""Verify graceful fallback when Node.js not installed.
|
||||
|
||||
When Node.js is not installed, subprocess.run will raise FileNotFoundError.
|
||||
The function should catch this and return None (graceful fallback to registry).
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
raise FileNotFoundError("node: command not found")
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="custom-model", base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
assert limit is None, "Should return None when Node.js not installed"
|
||||
|
||||
def test_query_lmstudio_sdk_not_installed(self, monkeypatch):
|
||||
"""Verify graceful fallback when @lmstudio/sdk not installed.
|
||||
|
||||
When the SDK npm package is not installed, Node.js will return non-zero
|
||||
exit code with error message in stderr. The function should detect this
|
||||
and return None.
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
mock_result = Mock()
|
||||
mock_result.returncode = 1
|
||||
mock_result.stdout = ""
|
||||
mock_result.stderr = (
|
||||
"Error: Cannot find module '@lmstudio/sdk'\nRequire stack:\n- /path/to/script.js"
|
||||
)
|
||||
return mock_result
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="custom-model", base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
assert limit is None, "Should return None when SDK not installed"
|
||||
|
||||
def test_query_lmstudio_timeout(self, monkeypatch):
|
||||
"""Verify graceful fallback when subprocess times out.
|
||||
|
||||
When the Node.js process takes longer than 10 seconds (e.g., LM Studio
|
||||
not responding), subprocess.TimeoutExpired should be raised. The function
|
||||
should catch this and return None.
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
raise subprocess.TimeoutExpired(cmd=["node", "lmstudio_bridge.js"], timeout=10)
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="custom-model", base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
assert limit is None, "Should return None on timeout"
|
||||
|
||||
def test_query_lmstudio_invalid_json(self, monkeypatch):
|
||||
"""Verify graceful fallback when response is invalid JSON.
|
||||
|
||||
When the subprocess returns malformed JSON (e.g., due to SDK error),
|
||||
json.loads will raise ValueError/JSONDecodeError. The function should
|
||||
catch this and return None.
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
mock_result = Mock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = "This is not valid JSON"
|
||||
mock_result.stderr = ""
|
||||
return mock_result
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="custom-model", base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
assert limit is None, "Should return None when JSON parsing fails"
|
||||
|
||||
def test_query_lmstudio_missing_context_length_field(self, monkeypatch):
|
||||
"""Verify graceful fallback when JSON lacks contextLength field.
|
||||
|
||||
When the SDK returns valid JSON but without the expected contextLength
|
||||
field (e.g., error response), the function should return None.
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
mock_result = Mock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = '{"identifier": "test-model", "error": "Model not found"}'
|
||||
mock_result.stderr = ""
|
||||
return mock_result
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="nonexistent-model", base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
assert limit is None, "Should return None when contextLength field missing"
|
||||
|
||||
def test_query_lmstudio_null_context_length(self, monkeypatch):
|
||||
"""Verify graceful fallback when contextLength is null.
|
||||
|
||||
When the SDK returns contextLength: null (model couldn't be loaded),
|
||||
the function should return None for registry fallback.
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
mock_result = Mock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = '{"contextLength": null, "identifier": "test-model"}'
|
||||
mock_result.stderr = ""
|
||||
return mock_result
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="test-model", base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
assert limit is None, "Should return None when contextLength is null"
|
||||
|
||||
def test_query_lmstudio_zero_context_length(self, monkeypatch):
|
||||
"""Verify graceful fallback when contextLength is zero.
|
||||
|
||||
When the SDK returns contextLength: 0 (invalid value), the function
|
||||
should return None to trigger registry fallback.
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
mock_result = Mock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = '{"contextLength": 0, "identifier": "test-model"}'
|
||||
mock_result.stderr = ""
|
||||
return mock_result
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="test-model", base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
assert limit is None, "Should return None when contextLength is zero"
|
||||
|
||||
def test_query_lmstudio_with_custom_port(self, monkeypatch):
|
||||
"""Verify SDK query works with non-default WebSocket port.
|
||||
|
||||
LM Studio can run on custom ports. The function should pass the
|
||||
provided base_url to the Node.js subprocess.
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
# Verify the base_url argument is passed correctly
|
||||
command = args[0] if args else kwargs.get("args", [])
|
||||
assert "ws://localhost:8080" in " ".join(command), (
|
||||
"Should pass custom port to subprocess"
|
||||
)
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = '{"contextLength": 4096, "identifier": "custom-model"}'
|
||||
mock_result.stderr = ""
|
||||
return mock_result
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="custom-model", base_url="ws://localhost:8080"
|
||||
)
|
||||
|
||||
assert limit == 4096, "Should work with custom WebSocket port"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"context_length,expected",
|
||||
[
|
||||
(512, 512), # Small context
|
||||
(2048, 2048), # Common context
|
||||
(8192, 8192), # Large context
|
||||
(32768, 32768), # Very large context
|
||||
],
|
||||
)
|
||||
def test_query_lmstudio_various_context_lengths(self, monkeypatch, context_length, expected):
|
||||
"""Verify SDK query handles various context length values.
|
||||
|
||||
Different models have different context lengths. The function should
|
||||
correctly parse and return any positive integer value.
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
mock_result = Mock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = f'{{"contextLength": {context_length}, "identifier": "test"}}'
|
||||
mock_result.stderr = ""
|
||||
return mock_result
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="test-model", base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
assert limit == expected, f"Should return {expected} for context length {context_length}"
|
||||
|
||||
def test_query_lmstudio_logs_at_debug_level(self, monkeypatch, caplog):
|
||||
"""Verify errors are logged at DEBUG level, not WARNING/ERROR.
|
||||
|
||||
Following the graceful fallback pattern from Ollama implementation,
|
||||
errors should be logged at debug level to avoid alarming users when
|
||||
fallback to registry works fine.
|
||||
"""
|
||||
import logging
|
||||
|
||||
caplog.set_level(logging.DEBUG, logger="leann.embedding_compute")
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
raise FileNotFoundError("node: command not found")
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
_query_lmstudio_context_limit(model_name="test-model", base_url="ws://localhost:1234")
|
||||
|
||||
# Check that debug logging occurred (not warning/error)
|
||||
debug_logs = [record for record in caplog.records if record.levelname == "DEBUG"]
|
||||
assert len(debug_logs) > 0, "Should log error at DEBUG level"
|
||||
|
||||
# Verify no WARNING or ERROR logs
|
||||
warning_or_error_logs = [
|
||||
record for record in caplog.records if record.levelname in ["WARNING", "ERROR"]
|
||||
]
|
||||
assert len(warning_or_error_logs) == 0, (
|
||||
"Should not log at WARNING/ERROR level for expected failures"
|
||||
)
|
||||
400
tests/test_prompt_template_e2e.py
Normal file
@@ -0,0 +1,400 @@
|
||||
"""End-to-end integration tests for prompt template and token limit features.
|
||||
|
||||
These tests verify real-world functionality with live services:
|
||||
- OpenAI-compatible APIs (OpenAI, LM Studio) with prompt template support
|
||||
- Ollama with dynamic token limit detection
|
||||
- Hybrid token limit discovery mechanism
|
||||
|
||||
Run with: pytest tests/test_prompt_template_e2e.py -v -s
|
||||
Skip if services unavailable: pytest tests/test_prompt_template_e2e.py -m "not integration"
|
||||
|
||||
Prerequisites:
|
||||
1. LM Studio running with embedding model: http://localhost:1234
|
||||
2. [Optional] Ollama running: ollama serve
|
||||
3. [Optional] Ollama model: ollama pull nomic-embed-text
|
||||
4. [Optional] Node.js + @lmstudio/sdk for context length detection
|
||||
"""
|
||||
|
||||
import logging
|
||||
import socket
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import requests
|
||||
from leann.embedding_compute import (
|
||||
compute_embeddings_ollama,
|
||||
compute_embeddings_openai,
|
||||
get_model_token_limit,
|
||||
)
|
||||
|
||||
# Test markers for conditional execution
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_service_available(host: str, port: int, timeout: float = 2.0) -> bool:
|
||||
"""Check if a service is available on the given host:port."""
|
||||
try:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(timeout)
|
||||
result = sock.connect_ex((host, port))
|
||||
sock.close()
|
||||
return result == 0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def check_ollama_available() -> bool:
|
||||
"""Check if Ollama service is available."""
|
||||
if not check_service_available("localhost", 11434):
|
||||
return False
|
||||
try:
|
||||
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def check_lmstudio_available() -> bool:
|
||||
"""Check if LM Studio service is available."""
|
||||
if not check_service_available("localhost", 1234):
|
||||
return False
|
||||
try:
|
||||
response = requests.get("http://localhost:1234/v1/models", timeout=2.0)
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def get_lmstudio_first_model() -> str:
|
||||
"""Get the first available model from LM Studio."""
|
||||
try:
|
||||
response = requests.get("http://localhost:1234/v1/models", timeout=5.0)
|
||||
data = response.json()
|
||||
models = data.get("data", [])
|
||||
if models:
|
||||
return models[0]["id"]
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
class TestPromptTemplateOpenAI:
|
||||
"""End-to-end tests for prompt template with OpenAI-compatible APIs (LM Studio)."""
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not check_lmstudio_available(), reason="LM Studio service not available on localhost:1234"
|
||||
)
|
||||
def test_lmstudio_embedding_with_prompt_template(self):
|
||||
"""Test prompt templates with LM Studio using OpenAI-compatible API."""
|
||||
model_name = get_lmstudio_first_model()
|
||||
if not model_name:
|
||||
pytest.skip("No models loaded in LM Studio")
|
||||
|
||||
texts = ["artificial intelligence", "machine learning"]
|
||||
prompt_template = "search_query: "
|
||||
|
||||
# Get embeddings with prompt template via provider_options
|
||||
provider_options = {"prompt_template": prompt_template}
|
||||
embeddings = compute_embeddings_openai(
|
||||
texts=texts,
|
||||
model_name=model_name,
|
||||
base_url="http://localhost:1234/v1",
|
||||
api_key="lm-studio", # LM Studio doesn't require real key
|
||||
provider_options=provider_options,
|
||||
)
|
||||
|
||||
assert embeddings is not None
|
||||
assert len(embeddings) == 2
|
||||
assert all(isinstance(emb, np.ndarray) for emb in embeddings)
|
||||
assert all(len(emb) > 0 for emb in embeddings)
|
||||
|
||||
logger.info(
|
||||
f"✓ LM Studio embeddings with prompt template: {len(embeddings)} vectors, {len(embeddings[0])} dimensions"
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
|
||||
def test_lmstudio_prompt_template_affects_embeddings(self):
|
||||
"""Verify that prompt templates actually change embedding values."""
|
||||
model_name = get_lmstudio_first_model()
|
||||
if not model_name:
|
||||
pytest.skip("No models loaded in LM Studio")
|
||||
|
||||
text = "machine learning"
|
||||
base_url = "http://localhost:1234/v1"
|
||||
api_key = "lm-studio"
|
||||
|
||||
# Get embeddings without template
|
||||
embeddings_no_template = compute_embeddings_openai(
|
||||
texts=[text],
|
||||
model_name=model_name,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
provider_options={},
|
||||
)
|
||||
|
||||
# Get embeddings with template
|
||||
embeddings_with_template = compute_embeddings_openai(
|
||||
texts=[text],
|
||||
model_name=model_name,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
provider_options={"prompt_template": "search_query: "},
|
||||
)
|
||||
|
||||
# Embeddings should be different when template is applied
|
||||
assert not np.allclose(embeddings_no_template[0], embeddings_with_template[0])
|
||||
|
||||
logger.info("✓ Prompt template changes embedding values as expected")
|
||||
|
||||
|
||||
class TestPromptTemplateOllama:
|
||||
"""End-to-end tests for prompt template with Ollama."""
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not check_ollama_available(), reason="Ollama service not available on localhost:11434"
|
||||
)
|
||||
def test_ollama_embedding_with_prompt_template(self):
|
||||
"""Test prompt templates with Ollama using any available embedding model."""
|
||||
# Get any available embedding model
|
||||
try:
|
||||
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
|
||||
models = response.json().get("models", [])
|
||||
|
||||
embedding_models = []
|
||||
for model in models:
|
||||
name = model["name"]
|
||||
base_name = name.split(":")[0]
|
||||
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
|
||||
embedding_models.append(name)
|
||||
|
||||
if not embedding_models:
|
||||
pytest.skip("No embedding models available in Ollama")
|
||||
|
||||
model_name = embedding_models[0]
|
||||
|
||||
texts = ["artificial intelligence", "machine learning"]
|
||||
prompt_template = "search_query: "
|
||||
|
||||
# Get embeddings with prompt template via provider_options
|
||||
provider_options = {"prompt_template": prompt_template}
|
||||
embeddings = compute_embeddings_ollama(
|
||||
texts=texts,
|
||||
model_name=model_name,
|
||||
is_build=False,
|
||||
host="http://localhost:11434",
|
||||
provider_options=provider_options,
|
||||
)
|
||||
|
||||
assert embeddings is not None
|
||||
assert len(embeddings) == 2
|
||||
assert all(isinstance(emb, np.ndarray) for emb in embeddings)
|
||||
assert all(len(emb) > 0 for emb in embeddings)
|
||||
|
||||
logger.info(
|
||||
f"✓ Ollama embeddings with prompt template: {len(embeddings)} vectors, {len(embeddings[0])} dimensions"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
pytest.skip(f"Could not test Ollama prompt template: {e}")
|
||||
|
||||
@pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
|
||||
def test_ollama_prompt_template_affects_embeddings(self):
|
||||
"""Verify that prompt templates actually change embedding values with Ollama."""
|
||||
# Get any available embedding model
|
||||
try:
|
||||
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
|
||||
models = response.json().get("models", [])
|
||||
|
||||
embedding_models = []
|
||||
for model in models:
|
||||
name = model["name"]
|
||||
base_name = name.split(":")[0]
|
||||
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
|
||||
embedding_models.append(name)
|
||||
|
||||
if not embedding_models:
|
||||
pytest.skip("No embedding models available in Ollama")
|
||||
|
||||
model_name = embedding_models[0]
|
||||
text = "machine learning"
|
||||
host = "http://localhost:11434"
|
||||
|
||||
# Get embeddings without template
|
||||
embeddings_no_template = compute_embeddings_ollama(
|
||||
texts=[text], model_name=model_name, is_build=False, host=host, provider_options={}
|
||||
)
|
||||
|
||||
# Get embeddings with template
|
||||
embeddings_with_template = compute_embeddings_ollama(
|
||||
texts=[text],
|
||||
model_name=model_name,
|
||||
is_build=False,
|
||||
host=host,
|
||||
provider_options={"prompt_template": "search_query: "},
|
||||
)
|
||||
|
||||
# Embeddings should be different when template is applied
|
||||
assert not np.allclose(embeddings_no_template[0], embeddings_with_template[0])
|
||||
|
||||
logger.info("✓ Ollama prompt template changes embedding values as expected")
|
||||
|
||||
except Exception as e:
|
||||
pytest.skip(f"Could not test Ollama prompt template: {e}")
|
||||
|
||||
|
||||
class TestLMStudioSDK:
|
||||
"""End-to-end tests for LM Studio SDK integration."""
|
||||
|
||||
@pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
|
||||
def test_lmstudio_model_listing(self):
|
||||
"""Test that we can list models from LM Studio."""
|
||||
try:
|
||||
response = requests.get("http://localhost:1234/v1/models", timeout=5.0)
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
|
||||
models = data["data"]
|
||||
logger.info(f"✓ LM Studio models available: {len(models)}")
|
||||
|
||||
if models:
|
||||
logger.info(f" First model: {models[0].get('id', 'unknown')}")
|
||||
except Exception as e:
|
||||
pytest.skip(f"LM Studio API error: {e}")
|
||||
|
||||
@pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
|
||||
def test_lmstudio_sdk_context_length_detection(self):
|
||||
"""Test context length detection via LM Studio SDK bridge (requires Node.js + SDK)."""
|
||||
model_name = get_lmstudio_first_model()
|
||||
if not model_name:
|
||||
pytest.skip("No models loaded in LM Studio")
|
||||
|
||||
try:
|
||||
from leann.embedding_compute import _query_lmstudio_context_limit
|
||||
|
||||
# SDK requires WebSocket URL (ws://)
|
||||
context_length = _query_lmstudio_context_limit(
|
||||
model_name=model_name, base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
if context_length is None:
|
||||
logger.warning(
|
||||
"⚠ LM Studio SDK bridge returned None (Node.js or SDK may not be available)"
|
||||
)
|
||||
pytest.skip("Node.js or @lmstudio/sdk not available - SDK bridge unavailable")
|
||||
else:
|
||||
assert context_length > 0
|
||||
logger.info(
|
||||
f"✓ LM Studio context length detected via SDK: {context_length} for {model_name}"
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
pytest.skip("_query_lmstudio_context_limit not implemented yet")
|
||||
except Exception as e:
|
||||
logger.error(f"LM Studio SDK test error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
class TestOllamaTokenLimit:
|
||||
"""End-to-end tests for Ollama token limit discovery."""
|
||||
|
||||
@pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
|
||||
def test_ollama_token_limit_detection(self):
|
||||
"""Test dynamic token limit detection from Ollama /api/show endpoint."""
|
||||
# Get any available embedding model
|
||||
try:
|
||||
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
|
||||
models = response.json().get("models", [])
|
||||
|
||||
embedding_models = []
|
||||
for model in models:
|
||||
name = model["name"]
|
||||
base_name = name.split(":")[0]
|
||||
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
|
||||
embedding_models.append(name)
|
||||
|
||||
if not embedding_models:
|
||||
pytest.skip("No embedding models available in Ollama")
|
||||
|
||||
test_model = embedding_models[0]
|
||||
|
||||
# Test token limit detection
|
||||
limit = get_model_token_limit(model_name=test_model, base_url="http://localhost:11434")
|
||||
|
||||
assert limit > 0
|
||||
logger.info(f"✓ Ollama token limit detected: {limit} for {test_model}")
|
||||
|
||||
except Exception as e:
|
||||
pytest.skip(f"Could not test Ollama token detection: {e}")
|
||||
|
||||
|
||||
class TestHybridTokenLimit:
|
||||
"""End-to-end tests for hybrid token limit discovery mechanism."""
|
||||
|
||||
def test_hybrid_discovery_registry_fallback(self):
|
||||
"""Test fallback to static registry for known OpenAI models."""
|
||||
# Use a known OpenAI model (should be in registry)
|
||||
limit = get_model_token_limit(
|
||||
model_name="text-embedding-3-small",
|
||||
base_url="http://fake-server:9999", # Fake URL to force registry lookup
|
||||
)
|
||||
|
||||
# text-embedding-3-small should have 8192 in registry
|
||||
assert limit == 8192
|
||||
logger.info(f"✓ Hybrid discovery (registry fallback): {limit} tokens")
|
||||
|
||||
def test_hybrid_discovery_default_fallback(self):
|
||||
"""Test fallback to safe default for completely unknown models."""
|
||||
limit = get_model_token_limit(
|
||||
model_name="completely-unknown-model-xyz-12345",
|
||||
base_url="http://fake-server:9999",
|
||||
default=512,
|
||||
)
|
||||
|
||||
# Should get the specified default
|
||||
assert limit == 512
|
||||
logger.info(f"✓ Hybrid discovery (default fallback): {limit} tokens")
|
||||
|
||||
@pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
|
||||
def test_hybrid_discovery_ollama_dynamic_first(self):
|
||||
"""Test that Ollama models use dynamic discovery first."""
|
||||
# Get any available embedding model
|
||||
try:
|
||||
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
|
||||
models = response.json().get("models", [])
|
||||
|
||||
embedding_models = []
|
||||
for model in models:
|
||||
name = model["name"]
|
||||
base_name = name.split(":")[0]
|
||||
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
|
||||
embedding_models.append(name)
|
||||
|
||||
if not embedding_models:
|
||||
pytest.skip("No embedding models available in Ollama")
|
||||
|
||||
test_model = embedding_models[0]
|
||||
|
||||
# Should query Ollama /api/show dynamically
|
||||
limit = get_model_token_limit(model_name=test_model, base_url="http://localhost:11434")
|
||||
|
||||
assert limit > 0
|
||||
logger.info(f"✓ Hybrid discovery (Ollama dynamic): {limit} tokens for {test_model}")
|
||||
|
||||
except Exception as e:
|
||||
pytest.skip(f"Could not test hybrid Ollama discovery: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("\n" + "=" * 70)
|
||||
print("INTEGRATION TEST SUITE - Real Service Testing")
|
||||
print("=" * 70)
|
||||
print("\nThese tests require live services:")
|
||||
print(" • LM Studio: http://localhost:1234 (with embedding model loaded)")
|
||||
print(" • [Optional] Ollama: http://localhost:11434")
|
||||
print(" • [Optional] Node.js + @lmstudio/sdk for SDK bridge tests")
|
||||
print("\nRun with: pytest tests/test_prompt_template_e2e.py -v -s")
|
||||
print("=" * 70 + "\n")
|
||||
808 tests/test_prompt_template_persistence.py Normal file
@@ -0,0 +1,808 @@
|
||||
"""
|
||||
Integration tests for prompt template metadata persistence and reuse.
|
||||
|
||||
These tests verify the complete lifecycle of prompt template persistence:
|
||||
1. Template is saved to .meta.json during index build
|
||||
2. Template is automatically loaded during search operations
|
||||
3. Template can be overridden with explicit flag during search
|
||||
4. Template is reused during chat/ask operations
|
||||
|
||||
These are integration tests that:
|
||||
- Use real file system with temporary directories
|
||||
- Run actual build and search operations
|
||||
- Inspect .meta.json file contents directly
|
||||
- Mock embedding servers to avoid external dependencies
|
||||
- Use small test codebases for fast execution
|
||||
|
||||
Expected to FAIL in Red Phase because metadata persistence verification is not yet implemented.
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
|
||||
class TestPromptTemplateMetadataPersistence:
|
||||
"""Tests for prompt template storage in .meta.json during build."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_index_dir(self):
|
||||
"""Create temporary directory for test indexes."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield Path(tmpdir)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embeddings(self):
|
||||
"""Mock compute_embeddings to return dummy embeddings."""
|
||||
with patch("leann.api.compute_embeddings") as mock_compute:
|
||||
# Return dummy embeddings as numpy array
|
||||
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||
yield mock_compute
|
||||
|
||||
def test_prompt_template_saved_to_metadata(self, temp_index_dir, mock_embeddings):
|
||||
"""
|
||||
Verify that when build is run with embedding_options containing prompt_template,
|
||||
the template value is saved to the .meta.json file.
|
||||
|
||||
This is the core persistence requirement - templates must be saved to allow
|
||||
reuse in subsequent search operations without re-specifying the flag.
|
||||
|
||||
Expected failure: .meta.json exists but doesn't contain embedding_options
|
||||
with prompt_template, or the value is not persisted correctly.
|
||||
"""
|
||||
# Setup test data
|
||||
index_path = temp_index_dir / "test_index.leann"
|
||||
template = "search_document: "
|
||||
|
||||
# Build index with prompt template in embedding_options
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
embedding_options={"prompt_template": template},
|
||||
)
|
||||
|
||||
# Add a simple document
|
||||
builder.add_text("This is a test document for indexing")
|
||||
|
||||
# Build the index
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
# Verify .meta.json was created and contains the template
|
||||
meta_path = temp_index_dir / "test_index.leann.meta.json"
|
||||
assert meta_path.exists(), ".meta.json file should be created during build"
|
||||
|
||||
# Read and parse metadata
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta_data = json.load(f)
|
||||
|
||||
# Verify embedding_options exists in metadata
|
||||
assert "embedding_options" in meta_data, (
|
||||
"embedding_options should be saved to .meta.json when provided"
|
||||
)
|
||||
|
||||
# Verify prompt_template is in embedding_options
|
||||
embedding_options = meta_data["embedding_options"]
|
||||
assert "prompt_template" in embedding_options, (
|
||||
"prompt_template should be saved within embedding_options"
|
||||
)
|
||||
|
||||
# Verify the template value matches what we provided
|
||||
assert embedding_options["prompt_template"] == template, (
|
||||
f"Template should be '{template}', got '{embedding_options.get('prompt_template')}'"
|
||||
)
|
||||
|
||||
def test_prompt_template_absent_when_not_provided(self, temp_index_dir, mock_embeddings):
|
||||
"""
|
||||
Verify that when no prompt template is provided during build,
|
||||
the .meta.json file either has no embedding_options or its embedding_options lacks the prompt_template key.
|
||||
|
||||
This ensures clean metadata without unnecessary keys when features aren't used.
|
||||
|
||||
Expected behavior: Build succeeds, .meta.json doesn't contain prompt_template.
|
||||
"""
|
||||
index_path = temp_index_dir / "test_no_template.leann"
|
||||
|
||||
# Build index WITHOUT prompt template
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
# No embedding_options provided
|
||||
)
|
||||
|
||||
builder.add_text("Document without template")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
# Verify metadata
|
||||
meta_path = temp_index_dir / "test_no_template.leann.meta.json"
|
||||
assert meta_path.exists()
|
||||
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta_data = json.load(f)
|
||||
|
||||
# If embedding_options exists, it should not contain prompt_template
|
||||
if "embedding_options" in meta_data:
|
||||
embedding_options = meta_data["embedding_options"]
|
||||
assert "prompt_template" not in embedding_options, (
|
||||
"prompt_template should not be in metadata when not provided"
|
||||
)
|
||||
|
||||
|
||||
class TestPromptTemplateAutoLoadOnSearch:
|
||||
"""Tests for automatic loading of prompt template during search operations.
|
||||
|
||||
NOTE: Over-mocked test removed (test_prompt_template_auto_loaded_on_search).
|
||||
This functionality is now comprehensively tested by TestQueryPromptTemplateAutoLoad
|
||||
which uses simpler mocking and doesn't hang.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_index_dir(self):
|
||||
"""Create temporary directory for test indexes."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield Path(tmpdir)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embeddings(self):
|
||||
"""Mock compute_embeddings to capture calls and return dummy embeddings."""
|
||||
with patch("leann.api.compute_embeddings") as mock_compute:
|
||||
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||
yield mock_compute
|
||||
|
||||
def test_search_without_template_in_metadata(self, temp_index_dir, mock_embeddings):
|
||||
"""
|
||||
Verify that searching an index built WITHOUT a prompt template
|
||||
works correctly (backward compatibility).
|
||||
|
||||
The searcher should handle missing prompt_template gracefully.
|
||||
|
||||
Expected behavior: Search succeeds, no template is used.
|
||||
"""
|
||||
# Build index without template
|
||||
index_path = temp_index_dir / "no_template.leann"
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
)
|
||||
builder.add_text("Document without template")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
# Reset mocks
|
||||
mock_embeddings.reset_mock()
|
||||
|
||||
# Create searcher and search
|
||||
searcher = LeannSearcher(index_path=str(index_path))
|
||||
|
||||
# Verify no template in embedding_options
|
||||
assert "prompt_template" not in searcher.embedding_options, (
|
||||
"Searcher should not have prompt_template when not in metadata"
|
||||
)
|
||||
|
||||
|
||||
class TestQueryPromptTemplateAutoLoad:
|
||||
"""Tests for automatic loading of separate query_prompt_template during search (R2).
|
||||
|
||||
These tests verify the new two-template system where:
|
||||
- build_prompt_template: Applied during index building
|
||||
- query_prompt_template: Applied during search operations
|
||||
|
||||
Expected to FAIL in Red Phase (R2) because query template extraction
|
||||
and application is not yet implemented in LeannSearcher.search().
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_index_dir(self):
|
||||
"""Create temporary directory for test indexes."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield Path(tmpdir)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_compute_embeddings(self):
|
||||
"""Mock compute_embeddings to capture calls and return dummy embeddings."""
|
||||
with patch("leann.embedding_compute.compute_embeddings") as mock_compute:
|
||||
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||
yield mock_compute
|
||||
|
||||
def test_search_auto_loads_query_template(self, temp_index_dir, mock_compute_embeddings):
|
||||
"""
|
||||
Verify that search() automatically loads and applies query_prompt_template from .meta.json.
|
||||
|
||||
Given: Index built with separate build_prompt_template and query_prompt_template
|
||||
When: LeannSearcher.search("my query") is called
|
||||
Then: Query embedding is computed with "query: my query" (query template applied)
|
||||
|
||||
This is the core R2 requirement - query templates must be auto-loaded and applied
|
||||
during search without user intervention.
|
||||
|
||||
Expected failure: compute_embeddings called with raw "my query" instead of
|
||||
"query: my query" because query template extraction is not implemented.
|
||||
"""
|
||||
# Setup: Build index with separate templates in new format
|
||||
index_path = temp_index_dir / "query_template.leann"
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
embedding_options={
|
||||
"build_prompt_template": "doc: ",
|
||||
"query_prompt_template": "query: ",
|
||||
},
|
||||
)
|
||||
builder.add_text("Test document")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
# Reset mock to ignore build calls
|
||||
mock_compute_embeddings.reset_mock()
|
||||
|
||||
# Act: Search with query
|
||||
searcher = LeannSearcher(index_path=str(index_path))
|
||||
|
||||
# Mock the backend search to avoid actual search
|
||||
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
|
||||
mock_backend_search.return_value = {
|
||||
"labels": [["test_id_0"]], # IDs (nested list for batch support)
|
||||
"distances": [[0.9]], # Distances (nested list for batch support)
|
||||
}
|
||||
|
||||
searcher.search("my query", top_k=1, recompute_embeddings=False)
|
||||
|
||||
# Assert: compute_embeddings was called with query template applied
|
||||
assert mock_compute_embeddings.called, "compute_embeddings should be called during search"
|
||||
|
||||
# Get the actual text passed to compute_embeddings
|
||||
call_args = mock_compute_embeddings.call_args
|
||||
texts_arg = call_args[0][0] # First positional arg (list of texts)
|
||||
|
||||
assert len(texts_arg) == 1, "Should compute embedding for one query"
|
||||
assert texts_arg[0] == "query: my query", (
|
||||
f"Query template should be applied: expected 'query: my query', got '{texts_arg[0]}'"
|
||||
)
|
||||
|
||||
def test_search_backward_compat_single_template(self, temp_index_dir, mock_compute_embeddings):
|
||||
"""
|
||||
Verify backward compatibility with old single prompt_template format.
|
||||
|
||||
Given: Index with old format (single prompt_template, no query_prompt_template)
|
||||
When: LeannSearcher.search("my query") is called
|
||||
Then: Query embedding computed with "doc: my query" (old template applied)
|
||||
|
||||
This ensures indexes built with the old single-template system continue
|
||||
to work correctly with the new search implementation.
|
||||
|
||||
Expected failure: Old template not recognized/applied because backward
|
||||
compatibility logic is not implemented.
|
||||
"""
|
||||
# Setup: Build index with old single-template format
|
||||
index_path = temp_index_dir / "old_template.leann"
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
embedding_options={"prompt_template": "doc: "}, # Old format
|
||||
)
|
||||
builder.add_text("Test document")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
# Reset mock
|
||||
mock_compute_embeddings.reset_mock()
|
||||
|
||||
# Act: Search
|
||||
searcher = LeannSearcher(index_path=str(index_path))
|
||||
|
||||
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
|
||||
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
|
||||
|
||||
searcher.search("my query", top_k=1, recompute_embeddings=False)
|
||||
|
||||
# Assert: Old template was applied
|
||||
call_args = mock_compute_embeddings.call_args
|
||||
texts_arg = call_args[0][0]
|
||||
|
||||
assert texts_arg[0] == "doc: my query", (
|
||||
f"Old prompt_template should be applied for backward compatibility: "
|
||||
f"expected 'doc: my query', got '{texts_arg[0]}'"
|
||||
)
|
||||
|
||||
def test_search_backward_compat_no_template(self, temp_index_dir, mock_compute_embeddings):
|
||||
"""
|
||||
Verify backward compatibility when no template is present in .meta.json.
|
||||
|
||||
Given: Index with no template in .meta.json (very old indexes)
|
||||
When: LeannSearcher.search("my query") is called
|
||||
Then: Query embedding computed with "my query" (no template, raw query)
|
||||
|
||||
This ensures the most basic backward compatibility - indexes without
|
||||
any template support continue to work as before.
|
||||
|
||||
Expected failure: May fail if default template is incorrectly applied,
|
||||
or if missing template causes error.
|
||||
"""
|
||||
# Setup: Build index without any template
|
||||
index_path = temp_index_dir / "no_template.leann"
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
# No embedding_options at all
|
||||
)
|
||||
builder.add_text("Test document")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
# Reset mock
|
||||
mock_compute_embeddings.reset_mock()
|
||||
|
||||
# Act: Search
|
||||
searcher = LeannSearcher(index_path=str(index_path))
|
||||
|
||||
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
|
||||
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
|
||||
|
||||
searcher.search("my query", top_k=1, recompute_embeddings=False)
|
||||
|
||||
# Assert: No template applied (raw query)
|
||||
call_args = mock_compute_embeddings.call_args
|
||||
texts_arg = call_args[0][0]
|
||||
|
||||
assert texts_arg[0] == "my query", (
|
||||
f"No template should be applied when missing from metadata: "
|
||||
f"expected 'my query', got '{texts_arg[0]}'"
|
||||
)
|
||||
|
||||
def test_search_override_via_provider_options(self, temp_index_dir, mock_compute_embeddings):
|
||||
"""
|
||||
Verify that explicit provider_options can override stored query template.
|
||||
|
||||
Given: Index with query_prompt_template: "query: "
|
||||
When: search() called with provider_options={"prompt_template": "override: "}
|
||||
Then: Query embedding computed with "override: test" (override takes precedence)
|
||||
|
||||
This enables users to experiment with different query templates without
|
||||
rebuilding the index, or to handle special query types differently.
|
||||
|
||||
Expected failure: provider_options parameter is accepted via **kwargs but
|
||||
not used. Query embedding computed with raw "test" instead of "override: test"
|
||||
because override logic is not implemented.
|
||||
"""
|
||||
# Setup: Build index with query template
|
||||
index_path = temp_index_dir / "override_template.leann"
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
embedding_options={
|
||||
"build_prompt_template": "doc: ",
|
||||
"query_prompt_template": "query: ",
|
||||
},
|
||||
)
|
||||
builder.add_text("Test document")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
# Reset mock
|
||||
mock_compute_embeddings.reset_mock()
|
||||
|
||||
# Act: Search with override
|
||||
searcher = LeannSearcher(index_path=str(index_path))
|
||||
|
||||
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
|
||||
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
|
||||
|
||||
# This should accept provider_options parameter
|
||||
searcher.search(
|
||||
"test",
|
||||
top_k=1,
|
||||
recompute_embeddings=False,
|
||||
provider_options={"prompt_template": "override: "},
|
||||
)
|
||||
|
||||
# Assert: Override template was applied
|
||||
call_args = mock_compute_embeddings.call_args
|
||||
texts_arg = call_args[0][0]
|
||||
|
||||
assert texts_arg[0] == "override: test", (
|
||||
f"Override template should take precedence: "
|
||||
f"expected 'override: test', got '{texts_arg[0]}'"
|
||||
)
|
||||
|
||||
|
||||
class TestPromptTemplateReuseInChat:
|
||||
"""Tests for prompt template reuse in chat/ask operations."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_index_dir(self):
|
||||
"""Create temporary directory for test indexes."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield Path(tmpdir)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embeddings(self):
|
||||
"""Mock compute_embeddings to return dummy embeddings."""
|
||||
with patch("leann.api.compute_embeddings") as mock_compute:
|
||||
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||
yield mock_compute
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding_server_manager(self):
|
||||
"""Mock EmbeddingServerManager for chat tests."""
|
||||
with patch("leann.searcher_base.EmbeddingServerManager") as mock_manager_class:
|
||||
mock_manager = Mock()
|
||||
mock_manager.start_server.return_value = (True, 5557)
|
||||
mock_manager_class.return_value = mock_manager
|
||||
yield mock_manager
|
||||
|
||||
@pytest.fixture
|
||||
def index_with_template(self, temp_index_dir, mock_embeddings):
|
||||
"""Build an index with a prompt template."""
|
||||
index_path = temp_index_dir / "chat_template_index.leann"
|
||||
template = "document_query: "
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
embedding_options={"prompt_template": template},
|
||||
)
|
||||
|
||||
builder.add_text("Test document for chat")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
return str(index_path), template
|
||||
|
||||
|
||||
class TestPromptTemplateIntegrationWithEmbeddingModes:
|
||||
"""Tests for prompt template compatibility with different embedding modes."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_index_dir(self):
|
||||
"""Create temporary directory for test indexes."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield Path(tmpdir)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mode,model,template,filename_prefix",
|
||||
[
|
||||
(
|
||||
"openai",
|
||||
"text-embedding-3-small",
|
||||
"Represent this for searching: ",
|
||||
"openai_template",
|
||||
),
|
||||
("ollama", "nomic-embed-text", "search_query: ", "ollama_template"),
|
||||
("sentence-transformers", "facebook/contriever", "query: ", "st_template"),
|
||||
],
|
||||
)
|
||||
def test_prompt_template_metadata_with_embedding_modes(
|
||||
self, temp_index_dir, mode, model, template, filename_prefix
|
||||
):
|
||||
"""Verify prompt template is saved correctly across different embedding modes.
|
||||
|
||||
Tests that prompt templates are persisted to .meta.json for:
|
||||
- OpenAI mode (primary use case)
|
||||
- Ollama mode (also supports templates)
|
||||
- Sentence-transformers mode (saved for forward compatibility)
|
||||
|
||||
Expected behavior: Template is saved to .meta.json regardless of mode.
|
||||
"""
|
||||
with patch("leann.api.compute_embeddings") as mock_compute:
|
||||
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||
|
||||
index_path = temp_index_dir / f"{filename_prefix}.leann"
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=model,
|
||||
embedding_mode=mode,
|
||||
embedding_options={"prompt_template": template},
|
||||
)
|
||||
|
||||
builder.add_text(f"{mode.capitalize()} test document")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
# Verify metadata
|
||||
meta_path = temp_index_dir / f"{filename_prefix}.leann.meta.json"
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta_data = json.load(f)
|
||||
|
||||
assert meta_data["embedding_mode"] == mode
|
||||
# Template should be saved for all modes (even if not used by some)
|
||||
if "embedding_options" in meta_data:
|
||||
assert meta_data["embedding_options"]["prompt_template"] == template
|
||||
|
||||
|
||||
class TestQueryTemplateApplicationInComputeEmbedding:
|
||||
"""Tests for query template application in compute_query_embedding() (Bug Fix).
|
||||
|
||||
These tests verify that query templates are applied consistently in BOTH
|
||||
code paths (server and fallback) when computing query embeddings.
|
||||
|
||||
This addresses the bug where query templates were only applied in the
|
||||
fallback path, not when using the embedding server (the default path).
|
||||
|
||||
Bug Context:
|
||||
- Issue: Query templates were stored in metadata but only applied during
|
||||
fallback (direct) computation, not when using embedding server
|
||||
- Fix: Move template application to BEFORE any computation path in
|
||||
compute_query_embedding() (searcher_base.py:107-110)
|
||||
- Impact: Critical for models like EmbeddingGemma that require task-specific
|
||||
templates for optimal performance
|
||||
|
||||
These tests ensure the fix works correctly and prevent regression.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_index_with_template(self):
|
||||
"""Create a temporary index with query template in metadata"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
index_dir = Path(tmpdir)
|
||||
index_file = index_dir / "test.leann"
|
||||
meta_file = index_dir / "test.leann.meta.json"
|
||||
|
||||
# Create minimal metadata with query template
|
||||
metadata = {
|
||||
"version": "1.0",
|
||||
"backend_name": "hnsw",
|
||||
"embedding_model": "text-embedding-embeddinggemma-300m-qat",
|
||||
"dimensions": 768,
|
||||
"embedding_mode": "openai",
|
||||
"backend_kwargs": {
|
||||
"graph_degree": 32,
|
||||
"complexity": 64,
|
||||
"distance_metric": "cosine",
|
||||
},
|
||||
"embedding_options": {
|
||||
"base_url": "http://localhost:1234/v1",
|
||||
"api_key": "test-key",
|
||||
"build_prompt_template": "title: none | text: ",
|
||||
"query_prompt_template": "task: search result | query: ",
|
||||
},
|
||||
}
|
||||
|
||||
meta_file.write_text(json.dumps(metadata, indent=2))
|
||||
|
||||
# Create minimal HNSW index file (empty is okay for this test)
|
||||
index_file.write_bytes(b"")
|
||||
|
||||
yield str(index_file)
|
||||
|
||||
def test_query_template_applied_in_fallback_path(self, temp_index_with_template):
|
||||
"""Test that query template is applied when using fallback (direct) path"""
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
# Create a concrete implementation for testing
|
||||
class TestSearcher(BaseSearcher):
|
||||
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||
return {"labels": [], "distances": []}
|
||||
|
||||
searcher = object.__new__(TestSearcher)
|
||||
searcher.index_path = Path(temp_index_with_template)
|
||||
searcher.index_dir = searcher.index_path.parent
|
||||
|
||||
# Load metadata
|
||||
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||
with open(meta_file) as f:
|
||||
searcher.meta = json.load(f)
|
||||
|
||||
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||
|
||||
# Mock compute_embeddings to capture the query text
|
||||
captured_queries = []
|
||||
|
||||
def mock_compute_embeddings(texts, model, mode, provider_options=None):
|
||||
captured_queries.extend(texts)
|
||||
return np.random.rand(len(texts), 768).astype(np.float32)
|
||||
|
||||
with patch(
|
||||
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
|
||||
):
|
||||
# Call compute_query_embedding with template (fallback path)
|
||||
result = searcher.compute_query_embedding(
|
||||
query="vector database",
|
||||
use_server_if_available=False, # Force fallback path
|
||||
query_template="task: search result | query: ",
|
||||
)
|
||||
|
||||
# Verify template was applied
|
||||
assert len(captured_queries) == 1
|
||||
assert captured_queries[0] == "task: search result | query: vector database"
|
||||
assert result.shape == (1, 768)
|
||||
|
||||
def test_query_template_applied_in_server_path(self, temp_index_with_template):
|
||||
"""Test that query template is applied when using server path"""
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
# Create a concrete implementation for testing
|
||||
class TestSearcher(BaseSearcher):
|
||||
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||
return {"labels": [], "distances": []}
|
||||
|
||||
searcher = object.__new__(TestSearcher)
|
||||
searcher.index_path = Path(temp_index_with_template)
|
||||
searcher.index_dir = searcher.index_path.parent
|
||||
|
||||
# Load metadata
|
||||
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||
with open(meta_file) as f:
|
||||
searcher.meta = json.load(f)
|
||||
|
||||
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||
|
||||
# Mock the server methods to capture the query text
|
||||
captured_queries = []
|
||||
|
||||
def mock_ensure_server_running(passages_file, port):
|
||||
return port
|
||||
|
||||
def mock_compute_embedding_via_server(chunks, port):
|
||||
captured_queries.extend(chunks)
|
||||
return np.random.rand(len(chunks), 768).astype(np.float32)
|
||||
|
||||
searcher._ensure_server_running = mock_ensure_server_running
|
||||
searcher._compute_embedding_via_server = mock_compute_embedding_via_server
|
||||
|
||||
# Call compute_query_embedding with template (server path)
|
||||
result = searcher.compute_query_embedding(
|
||||
query="vector database",
|
||||
use_server_if_available=True, # Use server path
|
||||
query_template="task: search result | query: ",
|
||||
)
|
||||
|
||||
# Verify template was applied BEFORE calling server
|
||||
assert len(captured_queries) == 1
|
||||
assert captured_queries[0] == "task: search result | query: vector database"
|
||||
assert result.shape == (1, 768)
|
||||
|
||||
def test_query_template_without_template_parameter(self, temp_index_with_template):
|
||||
"""Test that query is unchanged when no template is provided"""
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
class TestSearcher(BaseSearcher):
|
||||
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||
return {"labels": [], "distances": []}
|
||||
|
||||
searcher = object.__new__(TestSearcher)
|
||||
searcher.index_path = Path(temp_index_with_template)
|
||||
searcher.index_dir = searcher.index_path.parent
|
||||
|
||||
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||
with open(meta_file) as f:
|
||||
searcher.meta = json.load(f)
|
||||
|
||||
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||
|
||||
captured_queries = []
|
||||
|
||||
def mock_compute_embeddings(texts, model, mode, provider_options=None):
|
||||
captured_queries.extend(texts)
|
||||
return np.random.rand(len(texts), 768).astype(np.float32)
|
||||
|
||||
with patch(
|
||||
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
|
||||
):
|
||||
searcher.compute_query_embedding(
|
||||
query="vector database",
|
||||
use_server_if_available=False,
|
||||
query_template=None, # No template
|
||||
)
|
||||
|
||||
# Verify query is unchanged
|
||||
assert len(captured_queries) == 1
|
||||
assert captured_queries[0] == "vector database"
|
||||
|
||||
def test_query_template_consistency_between_paths(self, temp_index_with_template):
|
||||
"""Test that both paths apply template identically"""
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
class TestSearcher(BaseSearcher):
|
||||
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||
return {"labels": [], "distances": []}
|
||||
|
||||
searcher = object.__new__(TestSearcher)
|
||||
searcher.index_path = Path(temp_index_with_template)
|
||||
searcher.index_dir = searcher.index_path.parent
|
||||
|
||||
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||
with open(meta_file) as f:
|
||||
searcher.meta = json.load(f)
|
||||
|
||||
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||
|
||||
query_template = "task: search result | query: "
|
||||
original_query = "vector database"
|
||||
|
||||
# Capture queries from fallback path
|
||||
fallback_queries = []
|
||||
|
||||
def mock_compute_embeddings(texts, model, mode, provider_options=None):
|
||||
fallback_queries.extend(texts)
|
||||
return np.random.rand(len(texts), 768).astype(np.float32)
|
||||
|
||||
with patch(
|
||||
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
|
||||
):
|
||||
searcher.compute_query_embedding(
|
||||
query=original_query,
|
||||
use_server_if_available=False,
|
||||
query_template=query_template,
|
||||
)
|
||||
|
||||
# Capture queries from server path
|
||||
server_queries = []
|
||||
|
||||
def mock_ensure_server_running(passages_file, port):
|
||||
return port
|
||||
|
||||
def mock_compute_embedding_via_server(chunks, port):
|
||||
server_queries.extend(chunks)
|
||||
return np.random.rand(len(chunks), 768).astype(np.float32)
|
||||
|
||||
searcher._ensure_server_running = mock_ensure_server_running
|
||||
searcher._compute_embedding_via_server = mock_compute_embedding_via_server
|
||||
|
||||
searcher.compute_query_embedding(
|
||||
query=original_query,
|
||||
use_server_if_available=True,
|
||||
query_template=query_template,
|
||||
)
|
||||
|
||||
# Verify both paths produced identical templated queries
|
||||
assert len(fallback_queries) == 1
|
||||
assert len(server_queries) == 1
|
||||
assert fallback_queries[0] == server_queries[0]
|
||||
assert fallback_queries[0] == f"{query_template}{original_query}"
|
||||
|
||||
def test_query_template_with_empty_string(self, temp_index_with_template):
|
||||
"""Test behavior with empty template string"""
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
class TestSearcher(BaseSearcher):
|
||||
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||
return {"labels": [], "distances": []}
|
||||
|
||||
searcher = object.__new__(TestSearcher)
|
||||
searcher.index_path = Path(temp_index_with_template)
|
||||
searcher.index_dir = searcher.index_path.parent
|
||||
|
||||
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||
with open(meta_file) as f:
|
||||
searcher.meta = json.load(f)
|
||||
|
||||
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||
|
||||
captured_queries = []
|
||||
|
||||
def mock_compute_embeddings(texts, model, mode, provider_options=None):
|
||||
captured_queries.extend(texts)
|
||||
return np.random.rand(len(texts), 768).astype(np.float32)
|
||||
|
||||
with patch(
|
||||
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
|
||||
):
|
||||
searcher.compute_query_embedding(
|
||||
query="vector database",
|
||||
use_server_if_available=False,
|
||||
query_template="", # Empty string
|
||||
)
|
||||
|
||||
# Empty string is falsy, so no template should be applied
|
||||
assert captured_queries[0] == "vector database"
|
||||
@@ -266,3 +266,378 @@ class TestTokenTruncation:
|
||||
assert result_tokens <= target_tokens, (
|
||||
f"Should be ≤{target_tokens} tokens, got {result_tokens}"
|
||||
)
|
||||
|
||||
|
||||
class TestLMStudioHybridDiscovery:
|
||||
"""Tests for LM Studio integration in get_model_token_limit() hybrid discovery.
|
||||
|
||||
These tests verify that get_model_token_limit() properly integrates with
|
||||
the LM Studio SDK bridge for dynamic token limit discovery. The integration
|
||||
should:
|
||||
|
||||
1. Detect LM Studio URLs (port 1234 or 'lmstudio'/'lm.studio' in URL)
|
||||
2. Convert HTTP URLs to WebSocket format for SDK queries
|
||||
3. Query LM Studio SDK and use discovered limit
|
||||
4. Fall back to registry when SDK returns None
|
||||
5. Execute AFTER Ollama detection but BEFORE registry fallback
|
||||
|
||||
All tests are written in Red Phase - they should FAIL initially because the
|
||||
LM Studio detection and integration logic does not exist yet in get_model_token_limit().
|
||||
"""
|
||||
|
||||
def test_get_model_token_limit_lmstudio_success(self, monkeypatch):
|
||||
"""Verify LM Studio SDK query succeeds and returns detected limit.
|
||||
|
||||
When an LM Studio base_url is detected and the SDK query succeeds,
|
||||
get_model_token_limit() should return the dynamically discovered
|
||||
context length without falling back to the registry.
|
||||
"""
|
||||
|
||||
# Mock _query_lmstudio_context_limit to return successful SDK query
|
||||
def mock_query_lmstudio(model_name, base_url):
|
||||
# Verify WebSocket URL was passed (not HTTP)
|
||||
assert base_url.startswith("ws://"), (
|
||||
f"Should convert HTTP to WebSocket format, got: {base_url}"
|
||||
)
|
||||
return 8192 # Successful SDK query
|
||||
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||
mock_query_lmstudio,
|
||||
)
|
||||
|
||||
# Test with HTTP URL that should be converted to WebSocket
|
||||
limit = get_model_token_limit(
|
||||
model_name="custom-model", base_url="http://localhost:1234/v1"
|
||||
)
|
||||
|
||||
assert limit == 8192, "Should return limit from LM Studio SDK query"
|
||||
|
||||
def test_get_model_token_limit_lmstudio_fallback_to_registry(self, monkeypatch):
|
||||
"""Verify fallback to registry when LM Studio SDK returns None.
|
||||
|
||||
When LM Studio SDK query fails (returns None), get_model_token_limit()
|
||||
should fall back to the EMBEDDING_MODEL_LIMITS registry.
|
||||
"""
|
||||
|
||||
# Mock _query_lmstudio_context_limit to return None (SDK failure)
|
||||
def mock_query_lmstudio(model_name, base_url):
|
||||
return None # SDK query failed
|
||||
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||
mock_query_lmstudio,
|
||||
)
|
||||
|
||||
# Test with known model that exists in registry
|
||||
limit = get_model_token_limit(
|
||||
model_name="nomic-embed-text", base_url="http://localhost:1234/v1"
|
||||
)
|
||||
|
||||
# Should fall back to registry value
|
||||
assert limit == 2048, "Should fall back to registry when SDK returns None"
|
||||
|
||||
def test_get_model_token_limit_lmstudio_port_detection(self, monkeypatch):
|
||||
"""Verify detection of LM Studio via port 1234.
|
||||
|
||||
get_model_token_limit() should recognize port 1234 as an LM Studio
|
||||
server and attempt an SDK query, regardless of hostname.
|
||||
"""
|
||||
query_called = False
|
||||
|
||||
def mock_query_lmstudio(model_name, base_url):
|
||||
nonlocal query_called
|
||||
query_called = True
|
||||
return 4096
|
||||
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||
mock_query_lmstudio,
|
||||
)
|
||||
|
||||
# Test with port 1234 (default LM Studio port)
|
||||
limit = get_model_token_limit(model_name="test-model", base_url="http://127.0.0.1:1234/v1")
|
||||
|
||||
assert query_called, "Should detect port 1234 and call LM Studio SDK query"
|
||||
assert limit == 4096, "Should return SDK query result"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_url,expected_limit,keyword",
|
||||
[
|
||||
("http://lmstudio.local:8080/v1", 16384, "lmstudio"),
|
||||
("http://api.lm.studio:5000/v1", 32768, "lm.studio"),
|
||||
],
|
||||
)
|
||||
def test_get_model_token_limit_lmstudio_url_keyword_detection(
|
||||
self, monkeypatch, test_url, expected_limit, keyword
|
||||
):
|
||||
"""Verify detection of LM Studio via keywords in URL.
|
||||
|
||||
get_model_token_limit() should recognize 'lmstudio' or 'lm.studio'
|
||||
in the URL as indicating an LM Studio server.
|
||||
"""
|
||||
query_called = False
|
||||
|
||||
def mock_query_lmstudio(model_name, base_url):
|
||||
nonlocal query_called
|
||||
query_called = True
|
||||
return expected_limit
|
||||
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||
mock_query_lmstudio,
|
||||
)
|
||||
|
||||
limit = get_model_token_limit(model_name="test-model", base_url=test_url)
|
||||
|
||||
assert query_called, f"Should detect '{keyword}' keyword and call SDK query"
|
||||
assert limit == expected_limit, f"Should return SDK query result for {keyword}"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_url,expected_protocol,expected_host",
|
||||
[
|
||||
("http://localhost:1234/v1", "ws://", "localhost:1234"),
|
||||
("https://lmstudio.example.com:1234/v1", "wss://", "lmstudio.example.com:1234"),
|
||||
],
|
||||
)
|
||||
def test_get_model_token_limit_protocol_conversion(
|
||||
self, monkeypatch, input_url, expected_protocol, expected_host
|
||||
):
|
||||
"""Verify HTTP/HTTPS URL is converted to WebSocket format for SDK query.
|
||||
|
||||
LM Studio SDK requires WebSocket URLs. get_model_token_limit() should:
|
||||
1. Convert 'http://' to 'ws://'
|
||||
2. Convert 'https://' to 'wss://'
|
||||
3. Remove '/v1' or other path suffixes (SDK expects base URL)
|
||||
4. Preserve host and port
|
||||
"""
|
||||
conversions_tested = []
|
||||
|
||||
def mock_query_lmstudio(model_name, base_url):
|
||||
conversions_tested.append(base_url)
|
||||
return 8192
|
||||
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||
mock_query_lmstudio,
|
||||
)
|
||||
|
||||
get_model_token_limit(model_name="test-model", base_url=input_url)
|
||||
|
||||
# Verify conversion happened
|
||||
assert len(conversions_tested) == 1, "Should have called SDK query once"
|
||||
assert conversions_tested[0].startswith(expected_protocol), (
|
||||
f"Should convert to {expected_protocol}"
|
||||
)
|
||||
assert expected_host in conversions_tested[0], (
|
||||
f"Should preserve host and port: {expected_host}"
|
||||
)
|
||||
|
||||
def test_get_model_token_limit_lmstudio_executes_after_ollama(self, monkeypatch):
|
||||
"""Verify LM Studio detection happens AFTER Ollama detection.
|
||||
|
||||
The hybrid discovery order should be:
|
||||
1. Ollama dynamic discovery (port 11434 or 'ollama' in URL)
|
||||
2. LM Studio dynamic discovery (port 1234 or 'lmstudio' in URL)
|
||||
3. Registry fallback
|
||||
|
||||
If both Ollama and LM Studio patterns match, Ollama should take precedence.
|
||||
This test verifies that the LM Studio query is skipped when Ollama discovery succeeds.
|
||||
"""
|
||||
ollama_called = False
|
||||
lmstudio_called = False
|
||||
|
||||
def mock_query_ollama(model_name, base_url):
|
||||
nonlocal ollama_called
|
||||
ollama_called = True
|
||||
return 2048 # Ollama query succeeds
|
||||
|
||||
def mock_query_lmstudio(model_name, base_url):
|
||||
nonlocal lmstudio_called
|
||||
lmstudio_called = True
|
||||
return None # Should not be reached if Ollama succeeds
|
||||
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_compute._query_ollama_context_limit",
|
||||
mock_query_ollama,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||
mock_query_lmstudio,
|
||||
)
|
||||
|
||||
# Test with Ollama URL
|
||||
limit = get_model_token_limit(
|
||||
model_name="test-model", base_url="http://localhost:11434/api"
|
||||
)
|
||||
|
||||
assert ollama_called, "Should attempt Ollama query first"
|
||||
assert not lmstudio_called, "Should not attempt LM Studio query when Ollama succeeds"
|
||||
assert limit == 2048, "Should return Ollama result"
|
||||
|
||||
def test_get_model_token_limit_lmstudio_not_detected_for_non_lmstudio_urls(self, monkeypatch):
|
||||
"""Verify LM Studio SDK query is NOT called for non-LM Studio URLs.
|
||||
|
||||
Only URLs with port 1234 or 'lmstudio'/'lm.studio' keywords should
|
||||
trigger LM Studio SDK queries. Other URLs should skip to registry fallback.
|
||||
"""
|
||||
lmstudio_called = False
|
||||
|
||||
def mock_query_lmstudio(model_name, base_url):
|
||||
nonlocal lmstudio_called
|
||||
lmstudio_called = True
|
||||
return 8192
|
||||
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||
mock_query_lmstudio,
|
||||
)
|
||||
|
||||
# Test with non-LM Studio URLs
|
||||
test_cases = [
|
||||
"http://localhost:8080/v1", # Different port
|
||||
"http://openai.example.com/v1", # Different service
|
||||
"http://localhost:3000/v1", # Another port
|
||||
]
|
||||
|
||||
for base_url in test_cases:
|
||||
lmstudio_called = False # Reset for each test
|
||||
get_model_token_limit(model_name="nomic-embed-text", base_url=base_url)
|
||||
assert not lmstudio_called, f"Should NOT call LM Studio SDK for URL: {base_url}"
|
||||
|
||||
def test_get_model_token_limit_lmstudio_case_insensitive_detection(self, monkeypatch):
|
||||
"""Verify LM Studio detection is case-insensitive for keywords.
|
||||
|
||||
Keywords 'lmstudio' and 'lm.studio' should be detected regardless
|
||||
of case (LMStudio, LMSTUDIO, LmStudio, etc.).
|
||||
"""
|
||||
query_called = False
|
||||
|
||||
def mock_query_lmstudio(model_name, base_url):
|
||||
nonlocal query_called
|
||||
query_called = True
|
||||
return 8192
|
||||
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||
mock_query_lmstudio,
|
||||
)
|
||||
|
||||
# Test various case variations
|
||||
test_cases = [
|
||||
"http://LMStudio.local:8080/v1",
|
||||
"http://LMSTUDIO.example.com/v1",
|
||||
"http://LmStudio.local/v1",
|
||||
"http://api.LM.STUDIO:5000/v1",
|
||||
]
|
||||
|
||||
for base_url in test_cases:
|
||||
query_called = False # Reset for each test
|
||||
limit = get_model_token_limit(model_name="test-model", base_url=base_url)
|
||||
assert query_called, f"Should detect LM Studio in URL: {base_url}"
|
||||
assert limit == 8192, f"Should return SDK result for URL: {base_url}"
|
||||
|
||||
|
||||
class TestTokenLimitCaching:
|
||||
"""Tests for token limit caching to prevent repeated SDK/API calls.
|
||||
|
||||
Caching prevents duplicate SDK/API calls within the same Python process,
|
||||
which is important because:
|
||||
1. LM Studio SDK load() can load duplicate model instances
|
||||
2. Ollama /api/show queries add latency
|
||||
3. Registry lookups are pure overhead
|
||||
|
||||
Cache is process-scoped and resets between leann build invocations.
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear cache before each test."""
|
||||
from leann.embedding_compute import _token_limit_cache
|
||||
|
||||
_token_limit_cache.clear()
|
||||
|
||||
def test_registry_lookup_is_cached(self):
|
||||
"""Verify that registry lookups are cached."""
|
||||
from leann.embedding_compute import _token_limit_cache
|
||||
|
||||
# First call
|
||||
limit1 = get_model_token_limit("text-embedding-3-small")
|
||||
assert limit1 == 8192
|
||||
|
||||
# Verify it's in cache
|
||||
cache_key = ("text-embedding-3-small", "")
|
||||
assert cache_key in _token_limit_cache
|
||||
assert _token_limit_cache[cache_key] == 8192
|
||||
|
||||
# Second call should use cache
|
||||
limit2 = get_model_token_limit("text-embedding-3-small")
|
||||
assert limit2 == 8192
|
||||
|
||||
def test_default_fallback_is_cached(self):
|
||||
"""Verify that default fallbacks are cached."""
|
||||
from leann.embedding_compute import _token_limit_cache
|
||||
|
||||
# First call with unknown model
|
||||
limit1 = get_model_token_limit("unknown-model-xyz", default=512)
|
||||
assert limit1 == 512
|
||||
|
||||
# Verify it's in cache
|
||||
cache_key = ("unknown-model-xyz", "")
|
||||
assert cache_key in _token_limit_cache
|
||||
assert _token_limit_cache[cache_key] == 512
|
||||
|
||||
# Second call should use cache
|
||||
limit2 = get_model_token_limit("unknown-model-xyz", default=512)
|
||||
assert limit2 == 512
|
||||
|
||||
def test_different_urls_create_separate_cache_entries(self):
|
||||
"""Verify that different base_urls create separate cache entries."""
|
||||
from leann.embedding_compute import _token_limit_cache
|
||||
|
||||
# Same model, different URLs
|
||||
limit1 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:11434")
|
||||
limit2 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:1234/v1")
|
||||
|
||||
# Both should find the model in registry (2048)
|
||||
assert limit1 == 2048
|
||||
assert limit2 == 2048
|
||||
|
||||
# But they should be separate cache entries
|
||||
cache_key1 = ("nomic-embed-text", "http://localhost:11434")
|
||||
cache_key2 = ("nomic-embed-text", "http://localhost:1234/v1")
|
||||
|
||||
assert cache_key1 in _token_limit_cache
|
||||
assert cache_key2 in _token_limit_cache
|
||||
assert len(_token_limit_cache) == 2
|
||||
|
||||
def test_cache_prevents_repeated_lookups(self):
|
||||
"""Verify that cache prevents repeated registry/API lookups."""
|
||||
from leann.embedding_compute import _token_limit_cache
|
||||
|
||||
model_name = "text-embedding-ada-002"
|
||||
|
||||
# First call - should add to cache
|
||||
assert len(_token_limit_cache) == 0
|
||||
limit1 = get_model_token_limit(model_name)
|
||||
|
||||
cache_size_after_first = len(_token_limit_cache)
|
||||
assert cache_size_after_first == 1
|
||||
|
||||
# Multiple subsequent calls - cache size should not change
|
||||
for _ in range(5):
|
||||
limit = get_model_token_limit(model_name)
|
||||
assert limit == limit1
|
||||
assert len(_token_limit_cache) == cache_size_after_first
|
||||
|
||||
def test_versioned_model_names_cached_correctly(self):
|
||||
"""Verify that versioned model names (e.g., model:tag) are cached."""
|
||||
from leann.embedding_compute import _token_limit_cache
|
||||
|
||||
# Model with version tag
|
||||
limit = get_model_token_limit("nomic-embed-text:latest", base_url="http://localhost:11434")
|
||||
assert limit == 2048
|
||||
|
||||
# Should be cached with full name including version
|
||||
cache_key = ("nomic-embed-text:latest", "http://localhost:11434")
|
||||
assert cache_key in _token_limit_cache
|
||||
assert _token_limit_cache[cache_key] == 2048
|
||||
|
||||