Compare commits

..

3 Commits

Author SHA1 Message Date
aakash
697d247698 fix security vulnerability: replace eval() 2025-11-13 11:12:31 -08:00
aakash
9b7353f336 Fix linting errors in colqwen_rag.py and test_colqwen_reproduction.py
- Add noqa comments for E402 errors (imports after sys.path modifications)
- Remove unused variable assignment in colqwen_rag.py
- Use importlib.util.find_spec for dependency checks instead of unused imports
- Fix import ordering in test_colqwen_reproduction.py
2025-11-11 05:12:49 -08:00
aakash
9dd0e0b26f feat: Add ColQwen multimodal PDF retrieval integration
- Add ColQwenRAG class with easy-to-use CLI for multimodal PDF retrieval
- Support for both ColQwen2 and ColPali models with automatic device selection
- MPS optimization for Apple Silicon with memory-efficient loading
- Complete pipeline: PDF→images→embeddings→HNSW index→search
- Multi-vector indexing for fine-grained document matching
- Comprehensive user guide and reproduction test script
- Resolves #119: ColQwen Doc and Support Management

Features:
- python -m apps.colqwen_rag build --pdfs ./pdfs/ --index my_index
- python -m apps.colqwen_rag search my_index "query text"
- python -m apps.colqwen_rag ask my_index --interactive
- Automatic CPU fallback for memory constraints
- Robust error handling and progress tracking
2025-11-10 13:31:58 -08:00
20 changed files with 5214 additions and 5988 deletions

200
COLQWEN_GUIDE.md Normal file
View File

@@ -0,0 +1,200 @@
# ColQwen Integration Guide
Easy-to-use multimodal PDF retrieval with ColQwen2/ColPali models.
## Quick Start
> **🍎 Mac Users**: ColQwen is optimized for Apple Silicon with MPS acceleration for faster inference!
### 1. Install Dependencies
```bash
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
brew install poppler # macOS only, for PDF processing
```
### 2. Basic Usage
```bash
# Build index from PDFs
python -m apps.colqwen_rag build --pdfs ./my_papers/ --index research_papers
# Search with text queries
python -m apps.colqwen_rag search research_papers "How does attention mechanism work?"
# Interactive Q&A
python -m apps.colqwen_rag ask research_papers --interactive
```
## Commands
### Build Index
```bash
python -m apps.colqwen_rag build \
--pdfs ./pdf_directory/ \
--index my_index \
--model colqwen2 \
--pages-dir ./page_images/ # Optional: save page images
```
**Options:**
- `--pdfs`: Directory containing PDF files (or single PDF path)
- `--index`: Name for the index (required)
- `--model`: `colqwen2` (default) or `colpali`
- `--pages-dir`: Directory to save page images (optional)
### Search Index
```bash
python -m apps.colqwen_rag search my_index "your question here" --top-k 5
```
**Options:**
- `--top-k`: Number of results to return (default: 5)
- `--model`: Model used for search (should match build model)
### Interactive Q&A
```bash
python -m apps.colqwen_rag ask my_index --interactive
```
**Commands in interactive mode:**
- Type your questions naturally
- `help`: Show available commands
- `quit`/`exit`/`q`: Exit interactive mode
## 🧪 Test & Reproduce Results
Run the reproduction test for issue #119:
```bash
python test_colqwen_reproduction.py
```
This will:
1. ✅ Check dependencies
2. 📥 Download sample PDF (Attention Is All You Need paper)
3. 🏗️ Build test index
4. 🔍 Run sample queries
5. 📊 Show how to generate similarity maps
## 🎨 Advanced: Similarity Maps
For visual similarity analysis, use the existing advanced script:
```bash
cd apps/multimodal/vision-based-pdf-multi-vector/
python multi-vector-leann-similarity-map.py
```
Edit the script to customize:
- `QUERY`: Your question
- `MODEL`: "colqwen2" or "colpali"
- `USE_HF_DATASET`: Use HuggingFace dataset or local PDFs
- `SIMILARITY_MAP`: Generate heatmaps
- `ANSWER`: Enable Qwen-VL answer generation
## 🔧 How It Works
### ColQwen2 vs ColPali
- **ColQwen2** (`vidore/colqwen2-v1.0`): Latest vision-language model
- **ColPali** (`vidore/colpali-v1.2`): Proven multimodal retriever
### Architecture
1. **PDF → Images**: Convert PDF pages to images (150 DPI)
2. **Vision Encoding**: Process images with ColQwen2/ColPali
3. **Multi-Vector Index**: Build LEANN HNSW index with multiple embeddings per page
4. **Query Processing**: Encode text queries with same model
5. **Similarity Search**: Find most relevant pages/regions
6. **Visual Maps**: Generate attention heatmaps (optional)
### Device Support
- **CUDA**: Best performance with GPU acceleration
- **MPS**: Apple Silicon Mac support
- **CPU**: Fallback for any system (slower)
Auto-detection: CUDA > MPS > CPU
## 📊 Performance Tips
### For Best Performance:
```bash
# Use ColQwen2 for latest features
--model colqwen2
# Save page images for reuse
--pages-dir ./cached_pages/
# Adjust batch size based on GPU memory
# (automatically handled)
```
### For Large Document Sets:
- Process PDFs in batches
- Use SSD storage for index files
- Consider using CUDA if available
## 🔗 Related Resources
- **Fast-PLAID**: https://github.com/lightonai/fast-plaid
- **Pylate**: https://github.com/lightonai/pylate
- **ColBERT**: https://github.com/stanford-futuredata/ColBERT
- **ColPali Paper**: Vision-Language Models for Document Retrieval
- **Issue #119**: https://github.com/yichuan-w/LEANN/issues/119
## 🐛 Troubleshooting
### PDF Conversion Issues (macOS)
```bash
# Install poppler
brew install poppler
which pdfinfo && pdfinfo -v
```
### Memory Issues
- Reduce batch size (automatically handled)
- Use CPU instead of GPU: `export CUDA_VISIBLE_DEVICES=""`
- Process fewer PDFs at once
### Model Download Issues
- Ensure internet connection for first run
- Models are cached after first download
- Use HuggingFace mirrors if needed
### Import Errors
```bash
# Ensure all dependencies installed
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
# Check PyTorch installation
python -c "import torch; print(torch.__version__)"
```
## 💡 Examples
### Research Paper Analysis
```bash
# Index your research papers
python -m apps.colqwen_rag build --pdfs ~/Papers/AI/ --index ai_papers
# Ask research questions
python -m apps.colqwen_rag search ai_papers "What are the limitations of transformer models?"
python -m apps.colqwen_rag search ai_papers "How does BERT compare to GPT?"
```
### Document Q&A
```bash
# Index business documents
python -m apps.colqwen_rag build --pdfs ~/Documents/Reports/ --index reports
# Interactive analysis
python -m apps.colqwen_rag ask reports --interactive
```
### Visual Analysis
```bash
# Generate similarity maps for specific queries
cd apps/multimodal/vision-based-pdf-multi-vector/
# Edit multi-vector-leann-similarity-map.py with your query
python multi-vector-leann-similarity-map.py
# Check ./figures/ for generated heatmaps
```
---
**🎯 This integration makes ColQwen as easy to use as other LEANN features while maintaining the full power of multimodal document understanding!**

View File

@@ -1,110 +0,0 @@
# Issue #159 Performance Analysis - Conclusion
## Problem Summary
User reported search times of 15-30 seconds instead of the ~2 seconds mentioned in the paper.
**Configuration:**
- GPU: 4090×1
- Embedding Model: BAAI/bge-large-zh-v1.5 (~300M parameters)
- Data Size: 180MB text (~90K chunks)
- Backend: HNSW
- beam_width: 10
- Other parameters: Default values
## Root Cause Analysis
### 1. **Search Complexity Parameter**
The **default `complexity` parameter is 64**, which is too high for achieving ~2 second search times with this configuration.
**Test Results (Reproduced):**
- **Complexity 64 (default)**: **36.17 seconds**
- **Complexity 32**: **2.49 seconds**
- **Complexity 16**: **2.24 seconds** ✅ (Close to paper's ~2 seconds)
- **Complexity 8**: **1.67 seconds**
### 2. **beam_width Parameter**
The `beam_width` parameter is **mainly for DiskANN backend**, not HNSW. Setting it to 10 has minimal/no effect on HNSW search performance.
### 3. **Embedding Model Size**
The paper uses a smaller embedding model (~100M parameters), while the user is using `BAAI/bge-large-zh-v1.5` (~300M parameters). This contributes to slower embedding computation during search, but the main bottleneck is the search complexity parameter.
## Solution
### **Recommended Fix: Reduce Search Complexity**
To achieve search times close to ~2 seconds, use:
```python
from leann.api import LeannSearcher
searcher = LeannSearcher(INDEX_PATH)
results = searcher.search(
query="your query",
top_k=10,
complexity=16, # or complexity=32 for slightly better accuracy
# beam_width parameter doesn't affect HNSW, can be ignored
)
```
Or via CLI:
```bash
leann search your-index "your query" --complexity 16
```
### **Alternative Solutions**
1. **Use DiskANN Backend** (Recommended by maintainer)
- DiskANN is faster for large datasets
- Better performance scaling
- `beam_width` parameter is relevant here
```python
builder = LeannBuilder(backend_name="diskann")
```
2. **Use Smaller Embedding Model**
- Switch to a smaller model (~100M parameters) like the paper
- Faster embedding computation
- Example: `BAAI/bge-base-zh-v1.5` instead of `bge-large-zh-v1.5`
3. **Disable Recomputation** (Trade storage for speed)
- Use `--no-recompute` flag
- Stores all embeddings (much larger storage)
- Faster search (no embedding recomputation)
```bash
leann build your-index --no-recompute --no-compact
leann search your-index "query" --no-recompute
```
## Performance Comparison
| Complexity | Search Time | Accuracy | Recommendation |
|------------|-------------|----------|---------------|
| 64 (default) | ~36s | Highest | ❌ Too slow |
| 32 | ~2.5s | High | ✅ Good balance |
| 16 | ~2.2s | Good | ✅ **Recommended** (matches paper) |
| 8 | ~1.7s | Lower | ⚠️ May sacrifice accuracy |
## Key Takeaways
1. **The default `complexity=64` is optimized for accuracy, not speed**
2. **For ~2 second search times, use `complexity=16` or `complexity=32`**
3. **`beam_width` parameter is for DiskANN, not HNSW**
4. **The paper's ~2 second results likely used:**
- Smaller embedding model (~100M params)
- Lower complexity (16-32)
- Possibly DiskANN backend
## Verification
The issue has been reproduced and verified. The test script `test_issue_159.py` demonstrates:
- Default complexity (64) results in ~36 second search times
- Reducing complexity to 16-32 achieves ~2 second search times
- This matches the user's reported issue and provides a clear solution
## Next Steps
1. ✅ Issue reproduced and root cause identified
2. ✅ Solution provided (reduce complexity parameter)
3. ⏳ User should test with `complexity=16` or `complexity=32`
4. ⏳ Consider updating documentation to clarify complexity parameter trade-offs

View File

@@ -24,7 +24,7 @@ LEANN is an innovative vector database that democratizes personal AI. Transform
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276) LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#mcp-integration-rag-on-live-data-from-any-platform), [Twitter](#mcp-integration-rag-on-live-data-from-any-platform)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy. **Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#slack-messages-search-your-team-conversations), [Twitter](#-twitter-bookmarks-your-personal-tweet-library)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md) \* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)

View File

@@ -12,7 +12,6 @@ from pathlib import Path
try: try:
from leann.chunking_utils import ( from leann.chunking_utils import (
CODE_EXTENSIONS, CODE_EXTENSIONS,
_traditional_chunks_as_dicts,
create_ast_chunks, create_ast_chunks,
create_text_chunks, create_text_chunks,
create_traditional_chunks, create_traditional_chunks,
@@ -26,7 +25,6 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
sys.path.insert(0, str(leann_src)) sys.path.insert(0, str(leann_src))
from leann.chunking_utils import ( from leann.chunking_utils import (
CODE_EXTENSIONS, CODE_EXTENSIONS,
_traditional_chunks_as_dicts,
create_ast_chunks, create_ast_chunks,
create_text_chunks, create_text_chunks,
create_traditional_chunks, create_traditional_chunks,
@@ -38,7 +36,6 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
__all__ = [ __all__ = [
"CODE_EXTENSIONS", "CODE_EXTENSIONS",
"_traditional_chunks_as_dicts",
"create_ast_chunks", "create_ast_chunks",
"create_text_chunks", "create_text_chunks",
"create_traditional_chunks", "create_traditional_chunks",

364
apps/colqwen_rag.py Normal file
View File

@@ -0,0 +1,364 @@
#!/usr/bin/env python3
"""
ColQwen RAG - Easy-to-use multimodal PDF retrieval with ColQwen2/ColPali
Usage:
python -m apps.colqwen_rag build --pdfs ./my_pdfs/ --index my_index
python -m apps.colqwen_rag search my_index "How does attention work?"
python -m apps.colqwen_rag ask my_index --interactive
"""
import argparse
import os
import sys
from pathlib import Path
from typing import Optional, cast
# Add LEANN packages to path
_repo_root = Path(__file__).resolve().parents[1]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
import torch # noqa: E402
from colpali_engine import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor # noqa: E402
from colpali_engine.utils.torch_utils import ListDataset # noqa: E402
from pdf2image import convert_from_path # noqa: E402
from PIL import Image # noqa: E402
from torch.utils.data import DataLoader # noqa: E402
from tqdm import tqdm # noqa: E402
# Import the existing multi-vector implementation
sys.path.append(str(_repo_root / "apps" / "multimodal" / "vision-based-pdf-multi-vector"))
from leann_multi_vector import LeannMultiVector # noqa: E402
class ColQwenRAG:
"""Easy-to-use ColQwen RAG system for multimodal PDF retrieval."""
def __init__(self, model_type: str = "colpali"):
"""
Initialize ColQwen RAG system.
Args:
model_type: "colqwen2" or "colpali"
"""
self.model_type = model_type
self.device = self._get_device()
# Use float32 on MPS to avoid memory issues, float16 on CUDA, bfloat16 on CPU
if self.device.type == "mps":
self.dtype = torch.float32
elif self.device.type == "cuda":
self.dtype = torch.float16
else:
self.dtype = torch.bfloat16
print(f"🚀 Initializing {model_type.upper()} on {self.device} with {self.dtype}")
# Load model and processor with MPS-optimized settings
try:
if model_type == "colqwen2":
self.model_name = "vidore/colqwen2-v1.0"
if self.device.type == "mps":
# For MPS, load on CPU first then move to avoid memory allocation issues
self.model = ColQwen2.from_pretrained(
self.model_name,
torch_dtype=self.dtype,
device_map="cpu",
low_cpu_mem_usage=True,
).eval()
self.model = self.model.to(self.device)
else:
self.model = ColQwen2.from_pretrained(
self.model_name,
torch_dtype=self.dtype,
device_map=self.device,
low_cpu_mem_usage=True,
).eval()
self.processor = ColQwen2Processor.from_pretrained(self.model_name)
else: # colpali
self.model_name = "vidore/colpali-v1.2"
if self.device.type == "mps":
# For MPS, load on CPU first then move to avoid memory allocation issues
self.model = ColPali.from_pretrained(
self.model_name,
torch_dtype=self.dtype,
device_map="cpu",
low_cpu_mem_usage=True,
).eval()
self.model = self.model.to(self.device)
else:
self.model = ColPali.from_pretrained(
self.model_name,
torch_dtype=self.dtype,
device_map=self.device,
low_cpu_mem_usage=True,
).eval()
self.processor = ColPaliProcessor.from_pretrained(self.model_name)
except Exception as e:
if "memory" in str(e).lower() or "offload" in str(e).lower():
print(f"⚠️ Memory constraint on {self.device}, using CPU with optimizations...")
self.device = torch.device("cpu")
self.dtype = torch.float32
if model_type == "colqwen2":
self.model = ColQwen2.from_pretrained(
self.model_name,
torch_dtype=self.dtype,
device_map="cpu",
low_cpu_mem_usage=True,
).eval()
else:
self.model = ColPali.from_pretrained(
self.model_name,
torch_dtype=self.dtype,
device_map="cpu",
low_cpu_mem_usage=True,
).eval()
else:
raise
def _get_device(self):
"""Auto-select best available device."""
if torch.cuda.is_available():
return torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return torch.device("mps")
else:
return torch.device("cpu")
def build_index(self, pdf_paths: list[str], index_name: str, pages_dir: Optional[str] = None):
"""
Build multimodal index from PDF files.
Args:
pdf_paths: List of PDF file paths
index_name: Name for the index
pages_dir: Directory to save page images (optional)
"""
print(f"Building index '{index_name}' from {len(pdf_paths)} PDFs...")
# Convert PDFs to images
all_images = []
all_metadata = []
if pages_dir:
os.makedirs(pages_dir, exist_ok=True)
for pdf_path in tqdm(pdf_paths, desc="Converting PDFs"):
try:
images = convert_from_path(pdf_path, dpi=150)
pdf_name = Path(pdf_path).stem
for i, image in enumerate(images):
# Save image if pages_dir specified
if pages_dir:
image_path = Path(pages_dir) / f"{pdf_name}_page_{i + 1}.png"
image.save(image_path)
all_images.append(image)
all_metadata.append(
{
"pdf_path": pdf_path,
"pdf_name": pdf_name,
"page_number": i + 1,
"image_path": str(image_path) if pages_dir else None,
}
)
except Exception as e:
print(f"❌ Error processing {pdf_path}: {e}")
continue
print(f"📄 Converted {len(all_images)} pages from {len(pdf_paths)} PDFs")
print(f"All metadata: {all_metadata}")
# Generate embeddings
print("🧠 Generating embeddings...")
embeddings = self._embed_images(all_images)
# Build LEANN index
print("🔍 Building LEANN index...")
leann_mv = LeannMultiVector(
index_path=index_name,
dim=embeddings.shape[-1],
embedding_model_name=self.model_type,
)
# Create collection and insert data
leann_mv.create_collection()
for i, (embedding, metadata) in enumerate(zip(embeddings, all_metadata)):
data = {
"doc_id": i,
"filepath": metadata.get("image_path", ""),
"colbert_vecs": embedding.numpy(), # Convert tensor to numpy
}
leann_mv.insert(data)
# Build the index
leann_mv.create_index()
print(f"✅ Index '{index_name}' built successfully!")
return leann_mv
def search(self, index_name: str, query: str, top_k: int = 5):
"""
Search the index with a text query.
Args:
index_name: Name of the index to search
query: Text query
top_k: Number of results to return
"""
print(f"🔍 Searching '{index_name}' for: '{query}'")
# Load index
leann_mv = LeannMultiVector(
index_path=index_name,
dim=128, # Will be updated when loading
embedding_model_name=self.model_type,
)
# Generate query embedding
query_embedding = self._embed_query(query)
# Search (returns list of (score, doc_id) tuples)
search_results = leann_mv.search(query_embedding.numpy(), topk=top_k)
# Display results
print(f"\n📋 Top {len(search_results)} results:")
for i, (score, doc_id) in enumerate(search_results, 1):
# Get metadata for this doc_id (we need to load the metadata)
print(f"{i}. Score: {score:.3f} | Doc ID: {doc_id}")
return search_results
def ask(self, index_name: str, interactive: bool = False):
"""
Interactive Q&A with the indexed documents.
Args:
index_name: Name of the index to query
interactive: Whether to run in interactive mode
"""
print(f"💬 ColQwen Chat with '{index_name}'")
if interactive:
print("Type 'quit' to exit, 'help' for commands")
while True:
try:
query = input("\n🤔 Your question: ").strip()
if query.lower() in ["quit", "exit", "q"]:
break
elif query.lower() == "help":
print("Commands: quit/exit/q (exit), help (this message)")
continue
elif not query:
continue
self.search(index_name, query, top_k=3)
# TODO: Add answer generation with Qwen-VL
print("\n💡 For detailed answers, we can integrate Qwen-VL here!")
except KeyboardInterrupt:
print("\n👋 Goodbye!")
break
else:
query = input("🤔 Your question: ").strip()
if query:
self.search(index_name, query)
def _embed_images(self, images: list[Image.Image]) -> torch.Tensor:
"""Generate embeddings for a list of images."""
dataset = ListDataset(images)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x)
embeddings = []
with torch.no_grad():
for batch in tqdm(dataloader, desc="Embedding images"):
batch_images = cast(list, batch)
batch_inputs = self.processor.process_images(batch_images).to(self.device)
batch_embeddings = self.model(**batch_inputs)
embeddings.append(batch_embeddings.cpu())
return torch.cat(embeddings, dim=0)
def _embed_query(self, query: str) -> torch.Tensor:
"""Generate embedding for a text query."""
with torch.no_grad():
query_inputs = self.processor.process_queries([query]).to(self.device)
query_embedding = self.model(**query_inputs)
return query_embedding.cpu()
def main():
parser = argparse.ArgumentParser(description="ColQwen RAG - Easy multimodal PDF retrieval")
subparsers = parser.add_subparsers(dest="command", help="Available commands")
# Build command
build_parser = subparsers.add_parser("build", help="Build index from PDFs")
build_parser.add_argument("--pdfs", required=True, help="Directory containing PDF files")
build_parser.add_argument("--index", required=True, help="Index name")
build_parser.add_argument(
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
)
build_parser.add_argument("--pages-dir", help="Directory to save page images")
# Search command
search_parser = subparsers.add_parser("search", help="Search the index")
search_parser.add_argument("index", help="Index name")
search_parser.add_argument("query", help="Search query")
search_parser.add_argument("--top-k", type=int, default=5, help="Number of results")
search_parser.add_argument(
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
)
# Ask command
ask_parser = subparsers.add_parser("ask", help="Interactive Q&A")
ask_parser.add_argument("index", help="Index name")
ask_parser.add_argument("--interactive", action="store_true", help="Interactive mode")
ask_parser.add_argument(
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
)
args = parser.parse_args()
if not args.command:
parser.print_help()
return
# Initialize ColQwen RAG
if args.command == "build":
colqwen = ColQwenRAG(args.model)
# Get PDF files
pdf_dir = Path(args.pdfs)
if pdf_dir.is_file() and pdf_dir.suffix.lower() == ".pdf":
pdf_paths = [str(pdf_dir)]
elif pdf_dir.is_dir():
pdf_paths = [str(p) for p in pdf_dir.glob("*.pdf")]
else:
print(f"❌ Invalid PDF path: {args.pdfs}")
return
if not pdf_paths:
print(f"❌ No PDF files found in {args.pdfs}")
return
colqwen.build_index(pdf_paths, args.index, args.pages_dir)
elif args.command == "search":
colqwen = ColQwenRAG(args.model)
colqwen.search(args.index, args.query, args.top_k)
elif args.command == "ask":
colqwen = ColQwenRAG(args.model)
colqwen.ask(args.index, args.interactive)
if __name__ == "__main__":
main()

View File

@@ -1,18 +1,13 @@
import concurrent.futures from __future__ import annotations
import json
import os
import re
import sys import sys
import concurrent.futures
from pathlib import Path from pathlib import Path
from typing import Any, Optional, cast
import numpy as np import numpy as np
from PIL import Image
from tqdm import tqdm
def _ensure_repo_paths_importable(current_file: str) -> None: def _ensure_repo_paths_importable(current_file: str) -> None:
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
_repo_root = Path(current_file).resolve().parents[3] _repo_root = Path(current_file).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src" _leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw" _leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
@@ -22,380 +17,6 @@ def _ensure_repo_paths_importable(current_file: str) -> None:
sys.path.append(str(_leann_hnsw_pkg)) sys.path.append(str(_leann_hnsw_pkg))
def _find_backend_module_file() -> Optional[Path]:
"""Best-effort locate the backend leann_multi_vector.py file, avoiding this file."""
this_file = Path(__file__).resolve()
candidates: list[Path] = []
# Common in-repo location
repo_root = this_file.parents[3]
candidates.append(repo_root / "packages" / "leann-backend-hnsw" / "leann_multi_vector.py")
candidates.append(
repo_root / "packages" / "leann-backend-hnsw" / "src" / "leann_multi_vector.py"
)
for cand in candidates:
try:
if cand.exists() and cand.resolve() != this_file:
return cand.resolve()
except Exception:
pass
# Fallback: scan sys.path for another leann_multi_vector.py different from this file
for p in list(sys.path):
try:
cand = Path(p) / "leann_multi_vector.py"
if cand.exists() and cand.resolve() != this_file:
return cand.resolve()
except Exception:
continue
return None
_BACKEND_LEANN_CLASS: Optional[type] = None
def _get_backend_leann_multi_vector() -> type:
"""Load backend LeannMultiVector class even if this file shadows its module name."""
global _BACKEND_LEANN_CLASS
if _BACKEND_LEANN_CLASS is not None:
return _BACKEND_LEANN_CLASS
backend_path = _find_backend_module_file()
if backend_path is None:
# Fallback to local implementation in this module
try:
cls = LeannMultiVector # type: ignore[name-defined]
_BACKEND_LEANN_CLASS = cls
return cls
except Exception as e:
raise ImportError(
"Could not locate backend 'leann_multi_vector.py' and no local implementation found. "
"Ensure the leann backend is available under packages/leann-backend-hnsw or installed."
) from e
import importlib.util
module_name = "leann_hnsw_backend_module"
spec = importlib.util.spec_from_file_location(module_name, str(backend_path))
if spec is None or spec.loader is None:
raise ImportError(f"Failed to create spec for backend module at {backend_path}")
backend_module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = backend_module
spec.loader.exec_module(backend_module) # type: ignore[assignment]
if not hasattr(backend_module, "LeannMultiVector"):
raise ImportError(f"'LeannMultiVector' not found in backend module at {backend_path}")
_BACKEND_LEANN_CLASS = backend_module.LeannMultiVector
return _BACKEND_LEANN_CLASS
def _natural_sort_key(name: str) -> int:
m = re.search(r"\d+", name)
return int(m.group()) if m else 0
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
filenames = sorted(filenames, key=_natural_sort_key)
filepaths = [os.path.join(pages_dir, n) for n in filenames]
images = [Image.open(p) for p in filepaths]
return filepaths, images
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
if not pdf_path:
return
os.makedirs(pages_dir, exist_ok=True)
try:
from pdf2image import convert_from_path
except Exception as e:
raise RuntimeError(
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
) from e
images = convert_from_path(pdf_path, dpi=dpi)
for i, image in enumerate(images):
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
def _select_device_and_dtype():
import torch
from colpali_engine.utils.torch_utils import get_torch_device
device_str = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(device_str)
# Stable dtype selection to avoid NaNs:
# - CUDA: prefer bfloat16 if supported, else float16
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
# - CPU: float32
if device_str == "cuda":
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
try:
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
except Exception:
pass
elif device_str == "mps":
dtype = torch.float32
else:
dtype = torch.float32
return device_str, device, dtype
def _load_colvision(model_choice: str):
import torch
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
device_str, device, dtype = _select_device_and_dtype()
if model_choice == "colqwen2":
model_name = "vidore/colqwen2-v1.0"
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
attn_implementation = (
"flash_attention_2"
if (device_str == "cuda" and is_flash_attn_2_available())
else "eager"
)
model = ColQwen2.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
).eval()
processor = ColQwen2Processor.from_pretrained(model_name)
else:
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
).eval()
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
return model_name, model, processor, device_str, device, dtype
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
# Ensure deterministic eval and autocast for stability
model.eval()
dataloader = DataLoader(
dataset=ListDataset[Image.Image](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
doc_vecs: list[Any] = []
for batch_doc in tqdm(dataloader, desc="Embedding images"):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_doc = model(**batch_doc)
else:
embeddings_doc = model(**batch_doc)
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
return doc_vecs
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
model.eval()
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
q_vecs: list[Any] = []
for batch_query in tqdm(dataloader, desc="Embedding queries"):
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_query = model(**batch_query)
else:
embeddings_query = model(**batch_query)
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
return q_vecs
def _build_index(
index_path: str, doc_vecs: list[Any], filepaths: list[str], images: list[Image.Image]
) -> Any:
LeannMultiVector = _get_backend_leann_multi_vector()
dim = int(doc_vecs[0].shape[-1])
retriever = LeannMultiVector(index_path=index_path, dim=dim)
retriever.create_collection()
for i, vec in enumerate(doc_vecs):
data = {
"colbert_vecs": vec.float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
"image": images[i], # Include the original image
}
retriever.insert(data)
retriever.create_index()
return retriever
def _load_retriever_if_index_exists(index_path: str) -> Optional[Any]:
LeannMultiVector = _get_backend_leann_multi_vector()
index_base = Path(index_path)
# Check for the actual HNSW index file written by the backend + our sidecar files
index_file = index_base.parent / f"{index_base.stem}.index"
meta = index_base.parent / f"{index_base.name}.meta.json"
labels = index_base.parent / f"{index_base.name}.labels.json"
if index_file.exists() and meta.exists() and labels.exists():
try:
with open(meta, encoding="utf-8") as f:
meta_json = json.load(f)
dim = int(meta_json.get("dimensions", 128))
except Exception:
dim = 128
return LeannMultiVector(index_path=index_path, dim=dim)
return None
def _generate_similarity_map(
model,
processor,
image: Image.Image,
query: str,
token_idx: Optional[int] = None,
output_path: Optional[str] = None,
) -> tuple[int, float]:
import torch
from colpali_engine.interpretability import (
get_similarity_maps_from_embeddings,
plot_similarity_map,
)
batch_images = processor.process_images([image]).to(model.device)
batch_queries = processor.process_queries([query]).to(model.device)
with torch.no_grad():
image_embeddings = model.forward(**batch_images)
query_embeddings = model.forward(**batch_queries)
n_patches = processor.get_n_patches(
image_size=image.size,
spatial_merge_size=getattr(model, "spatial_merge_size", None),
)
image_mask = processor.get_image_mask(batch_images)
batched_similarity_maps = get_similarity_maps_from_embeddings(
image_embeddings=image_embeddings,
query_embeddings=query_embeddings,
n_patches=n_patches,
image_mask=image_mask,
)
similarity_maps = batched_similarity_maps[0]
# Determine token index if not provided: choose the token with highest max score
if token_idx is None:
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
token_idx = int(per_token_max.argmax().item())
max_sim_score = similarity_maps[token_idx, :, :].max().item()
if output_path:
import matplotlib.pyplot as plt
fig, ax = plot_similarity_map(
image=image,
similarity_map=similarity_maps[token_idx],
figsize=(14, 14),
show_colorbar=False,
)
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches="tight")
plt.close(fig)
return token_idx, float(max_sim_score)
class QwenVL:
def __init__(self, device: str):
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from transformers.utils.import_utils import is_flash_attn_2_available
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct",
torch_dtype="auto",
device_map=device,
attn_implementation=attn_implementation,
)
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
self.processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
import base64
from io import BytesIO
from qwen_vl_utils import process_vision_info
content = []
for img in images:
buffer = BytesIO()
img.save(buffer, format="jpeg")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
content.append({"type": "text", "text": query})
messages = [{"role": "user", "content": content}]
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
)
inputs = inputs.to(self.model.device)
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
return self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
# Ensure repo paths are importable for dynamic backend loading
_ensure_repo_paths_importable(__file__) _ensure_repo_paths_importable(__file__)
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402 from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402
@@ -450,7 +71,6 @@ class LeannMultiVector:
"doc_id": int(data["doc_id"]), "doc_id": int(data["doc_id"]),
"filepath": data.get("filepath", ""), "filepath": data.get("filepath", ""),
"colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]], "colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
"image": data.get("image"), # PIL Image object (optional)
} }
) )
@@ -466,11 +86,6 @@ class LeannMultiVector:
index_path_obj = Path(self.index_path) index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.emb.npy" return index_path_obj.parent / f"{index_path_obj.name}.emb.npy"
def _images_dir_path(self) -> Path:
"""Directory where original images are stored."""
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.images"
def create_index(self) -> None: def create_index(self) -> None:
if not self._pending_items: if not self._pending_items:
return return
@@ -478,23 +93,10 @@ class LeannMultiVector:
embeddings: list[np.ndarray] = [] embeddings: list[np.ndarray] = []
labels_meta: list[dict] = [] labels_meta: list[dict] = []
# Create images directory if needed
images_dir = self._images_dir_path()
images_dir.mkdir(parents=True, exist_ok=True)
for item in self._pending_items: for item in self._pending_items:
doc_id = int(item["doc_id"]) doc_id = int(item["doc_id"])
filepath = item.get("filepath", "") filepath = item.get("filepath", "")
colbert_vecs = item["colbert_vecs"] colbert_vecs = item["colbert_vecs"]
image = item.get("image")
# Save image if provided
image_path = ""
if image is not None and isinstance(image, Image.Image):
image_filename = f"doc_{doc_id}.png"
image_path = str(images_dir / image_filename)
image.save(image_path, "PNG")
for seq_id, vec in enumerate(colbert_vecs): for seq_id, vec in enumerate(colbert_vecs):
vec_np = np.asarray(vec, dtype=np.float32) vec_np = np.asarray(vec, dtype=np.float32)
embeddings.append(vec_np) embeddings.append(vec_np)
@@ -504,7 +106,6 @@ class LeannMultiVector:
"doc_id": doc_id, "doc_id": doc_id,
"seq_id": int(seq_id), "seq_id": int(seq_id),
"filepath": filepath, "filepath": filepath,
"image_path": image_path, # Store the path to the saved image
} }
) )
@@ -512,6 +113,7 @@ class LeannMultiVector:
return return
embeddings_np = np.vstack(embeddings).astype(np.float32) embeddings_np = np.vstack(embeddings).astype(np.float32)
# print shape of embeddings_np
print(embeddings_np.shape) print(embeddings_np.shape)
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim}) builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
@@ -736,45 +338,3 @@ class LeannMultiVector:
scores.sort(key=lambda x: x[0], reverse=True) scores.sort(key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores return scores[:topk] if len(scores) >= topk else scores
def get_image(self, doc_id: int) -> Optional[Image.Image]:
"""
Retrieve the original image for a given doc_id from the index.
Args:
doc_id: The document ID
Returns:
PIL Image object if found, None otherwise
"""
self._load_labels_meta_if_needed()
# Find the image_path for this doc_id (all seq_ids for same doc share the same image_path)
for meta in self._labels_meta:
if meta.get("doc_id") == doc_id:
image_path = meta.get("image_path", "")
if image_path and Path(image_path).exists():
return Image.open(image_path)
break
return None
def get_metadata(self, doc_id: int) -> Optional[dict]:
"""
Retrieve metadata for a given doc_id.
Args:
doc_id: The document ID
Returns:
Dictionary with metadata (filepath, image_path, etc.) if found, None otherwise
"""
self._load_labels_meta_if_needed()
for meta in self._labels_meta:
if meta.get("doc_id") == doc_id:
return {
"doc_id": doc_id,
"filepath": meta.get("filepath", ""),
"image_path": meta.get("image_path", ""),
}
return None

View File

@@ -2,31 +2,35 @@
# %% # %%
# uv pip install matplotlib qwen_vl_utils # uv pip install matplotlib qwen_vl_utils
import os import os
from typing import Any, Optional import json
import re
import sys
from pathlib import Path
from typing import Any, Optional, cast
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
from leann_multi_vector import ( # utility functions/classes def _ensure_repo_paths_importable(current_file: str) -> None:
_ensure_repo_paths_importable, """Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
_load_images_from_dir, _repo_root = Path(current_file).resolve().parents[3]
_maybe_convert_pdf_to_images, _leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_load_colvision, _leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
_embed_images, if str(_leann_core_src) not in sys.path:
_embed_queries, sys.path.append(str(_leann_core_src))
_build_index, if str(_leann_hnsw_pkg) not in sys.path:
_load_retriever_if_index_exists, sys.path.append(str(_leann_hnsw_pkg))
_generate_similarity_map,
QwenVL,
)
_ensure_repo_paths_importable(__file__) _ensure_repo_paths_importable(__file__)
from leann_multi_vector import LeannMultiVector # noqa: E402
# %% # %%
# Config # Config
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
QUERY = "The paper talk about the latent video generative model and data curation in the related work part?" QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
MODEL: str = "colqwen2" # "colpali" or "colqwen2" MODEL: str = "colqwen2" # "colpali" or "colqwen2"
# Data source: set to True to use the Hugging Face dataset example (recommended) # Data source: set to True to use the Hugging Face dataset example (recommended)
@@ -41,7 +45,7 @@ PAGES_DIR: str = "./pages"
# Index + retrieval settings # Index + retrieval settings
INDEX_PATH: str = "./indexes/colvision.leann" INDEX_PATH: str = "./indexes/colvision.leann"
TOPK: int = 3 TOPK: int = 1
FIRST_STAGE_K: int = 500 FIRST_STAGE_K: int = 500
REBUILD_INDEX: bool = False REBUILD_INDEX: bool = False
@@ -51,57 +55,338 @@ SIMILARITY_MAP: bool = True
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
SIM_OUTPUT: str = "./figures/similarity_map.png" SIM_OUTPUT: str = "./figures/similarity_map.png"
ANSWER: bool = True ANSWER: bool = True
MAX_NEW_TOKENS: int = 1024 MAX_NEW_TOKENS: int = 128
# %%
# Helpers
def _natural_sort_key(name: str) -> int:
m = re.search(r"\d+", name)
return int(m.group()) if m else 0
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
filenames = sorted(filenames, key=_natural_sort_key)
filepaths = [os.path.join(pages_dir, n) for n in filenames]
images = [Image.open(p) for p in filepaths]
return filepaths, images
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
if not pdf_path:
return
os.makedirs(pages_dir, exist_ok=True)
try:
from pdf2image import convert_from_path
except Exception as e:
raise RuntimeError(
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
) from e
images = convert_from_path(pdf_path, dpi=dpi)
for i, image in enumerate(images):
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
def _select_device_and_dtype():
import torch
from colpali_engine.utils.torch_utils import get_torch_device
device_str = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(device_str)
# Stable dtype selection to avoid NaNs:
# - CUDA: prefer bfloat16 if supported, else float16
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
# - CPU: float32
if device_str == "cuda":
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
try:
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
except Exception:
pass
elif device_str == "mps":
dtype = torch.float32
else:
dtype = torch.float32
return device_str, device, dtype
def _load_colvision(model_choice: str):
import torch
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
device_str, device, dtype = _select_device_and_dtype()
if model_choice == "colqwen2":
model_name = "vidore/colqwen2-v1.0"
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
attn_implementation = (
"flash_attention_2"
if (device_str == "cuda" and is_flash_attn_2_available())
else "eager"
)
model = ColQwen2.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
).eval()
processor = ColQwen2Processor.from_pretrained(model_name)
else:
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
).eval()
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
return model_name, model, processor, device_str, device, dtype
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
# Ensure deterministic eval and autocast for stability
model.eval()
dataloader = DataLoader(
dataset=ListDataset[Image.Image](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
doc_vecs: list[Any] = []
for batch_doc in tqdm(dataloader, desc="Embedding images"):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_doc = model(**batch_doc)
else:
embeddings_doc = model(**batch_doc)
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
return doc_vecs
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
model.eval()
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
q_vecs: list[Any] = []
for batch_query in tqdm(dataloader, desc="Embedding queries"):
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_query = model(**batch_query)
else:
embeddings_query = model(**batch_query)
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
return q_vecs
def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) -> LeannMultiVector:
dim = int(doc_vecs[0].shape[-1])
retriever = LeannMultiVector(index_path=index_path, dim=dim)
retriever.create_collection()
for i, vec in enumerate(doc_vecs):
data = {
"colbert_vecs": vec.float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
}
retriever.insert(data)
retriever.create_index()
return retriever
def _load_retriever_if_index_exists(index_path: str) -> Optional[LeannMultiVector]:
index_base = Path(index_path)
# Rough heuristic: index dir exists AND meta+labels files exist
meta = index_base.parent / f"{index_base.name}.meta.json"
labels = index_base.parent / f"{index_base.name}.labels.json"
if index_base.exists() and meta.exists() and labels.exists():
try:
with open(meta, "r", encoding="utf-8") as f:
meta_json = json.load(f)
dim = int(meta_json.get("dimensions", 128))
except Exception:
dim = 128
return LeannMultiVector(index_path=index_path, dim=dim)
return None
def _generate_similarity_map(
model,
processor,
image: Image.Image,
query: str,
token_idx: Optional[int] = None,
output_path: Optional[str] = None,
) -> tuple[int, float]:
import torch
from colpali_engine.interpretability import (
get_similarity_maps_from_embeddings,
plot_similarity_map,
)
batch_images = processor.process_images([image]).to(model.device)
batch_queries = processor.process_queries([query]).to(model.device)
with torch.no_grad():
image_embeddings = model.forward(**batch_images)
query_embeddings = model.forward(**batch_queries)
n_patches = processor.get_n_patches(
image_size=image.size,
spatial_merge_size=getattr(model, "spatial_merge_size", None),
)
image_mask = processor.get_image_mask(batch_images)
batched_similarity_maps = get_similarity_maps_from_embeddings(
image_embeddings=image_embeddings,
query_embeddings=query_embeddings,
n_patches=n_patches,
image_mask=image_mask,
)
similarity_maps = batched_similarity_maps[0]
# Determine token index if not provided: choose the token with highest max score
if token_idx is None:
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
token_idx = int(per_token_max.argmax().item())
max_sim_score = similarity_maps[token_idx, :, :].max().item()
if output_path:
import matplotlib.pyplot as plt
fig, ax = plot_similarity_map(
image=image,
similarity_map=similarity_maps[token_idx],
figsize=(14, 14),
show_colorbar=False,
)
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches="tight")
plt.close(fig)
return token_idx, float(max_sim_score)
class QwenVL:
def __init__(self, device: str):
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from transformers.utils.import_utils import is_flash_attn_2_available
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct",
torch_dtype="auto",
device_map=device,
attn_implementation=attn_implementation,
)
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
self.processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
import base64
from io import BytesIO
from qwen_vl_utils import process_vision_info
content = []
for img in images:
buffer = BytesIO()
img.save(buffer, format="jpeg")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
content.append({"type": "text", "text": query})
messages = [{"role": "user", "content": content}]
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
)
inputs = inputs.to(self.model.device)
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
return self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
# %% # %%
# Step 1: Check if we can skip data loading (index already exists) # Step 1: Prepare data
retriever: Optional[Any] = None if USE_HF_DATASET:
need_to_build_index = REBUILD_INDEX from datasets import load_dataset
if not REBUILD_INDEX: dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
retriever = _load_retriever_if_index_exists(INDEX_PATH) N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
if retriever is not None: filepaths: list[str] = []
print(f"✓ Index loaded from {INDEX_PATH}") images: list[Image.Image] = []
print(f"✓ Images available at: {retriever._images_dir_path()}") for i in tqdm(range(N), desc="Loading dataset", total=N ):
need_to_build_index = False p = dataset[i]
else: # Compose a descriptive identifier for printing later
print(f"Index not found, will build new index") identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
need_to_build_index = True print(identifier)
filepaths.append(identifier)
# Step 2: Load data only if we need to build the index images.append(p["page_image"]) # PIL Image
if need_to_build_index:
print("Loading dataset...")
if USE_HF_DATASET:
from datasets import load_dataset
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
filepaths: list[str] = []
images: list[Image.Image] = []
for i in tqdm(range(N), desc="Loading dataset", total=N):
p = dataset[i]
# Compose a descriptive identifier for printing later
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
filepaths.append(identifier)
images.append(p["page_image"]) # PIL Image
else:
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
filepaths, images = _load_images_from_dir(PAGES_DIR)
if not images:
raise RuntimeError(
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
)
print(f"Loaded {len(images)} images")
else: else:
print("Skipping dataset loading (using existing index)") _maybe_convert_pdf_to_images(PDF, PAGES_DIR)
filepaths = [] # Not needed when using existing index filepaths, images = _load_images_from_dir(PAGES_DIR)
images = [] # Not needed when using existing index if not images:
raise RuntimeError(
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
)
# %% # %%
# Step 3: Load model and processor (only if we need to build index or perform search) # Step 2: Load model and processor
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL) model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
print(f"Using model={model_name}, device={device_str}, dtype={dtype}") print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
@@ -109,39 +394,30 @@ print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
# %% # %%
# %% # %%
# Step 4: Build index if needed # Step 3: Build or load index
if need_to_build_index and retriever is None: retriever: Optional[LeannMultiVector] = None
print("Building index...") if not REBUILD_INDEX:
doc_vecs = _embed_images(model, processor, images) retriever = _load_retriever_if_index_exists(INDEX_PATH)
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
# Clear memory
del images, filepaths, doc_vecs
# Note: Images are now stored in the index, retriever will load them on-demand from disk if retriever is None:
doc_vecs = _embed_images(model, processor, images)
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)
# %% # %%
# Step 5: Embed query and search # Step 4: Embed query and search
q_vec = _embed_queries(model, processor, [QUERY])[0] q_vec = _embed_queries(model, processor, [QUERY])[0]
results = retriever.search(q_vec.float().numpy(), topk=TOPK) results = retriever.search(q_vec.float().numpy(), topk=TOPK, first_stage_k=FIRST_STAGE_K)
if not results: if not results:
print("No results found.") print("No results found.")
else: else:
print(f'Top {len(results)} results for query: "{QUERY}"') print(f'Top {len(results)} results for query: "{QUERY}"')
top_images: list[Image.Image] = [] top_images: list[Image.Image] = []
for rank, (score, doc_id) in enumerate(results, start=1): for rank, (score, doc_id) in enumerate(results, start=1):
# Retrieve image from index instead of memory path = filepaths[doc_id]
image = retriever.get_image(doc_id)
if image is None:
print(f"Warning: Could not retrieve image for doc_id {doc_id}")
continue
metadata = retriever.get_metadata(doc_id)
path = metadata.get("filepath", "unknown") if metadata else "unknown"
# For HF dataset, path is a descriptive identifier, not a real file path # For HF dataset, path is a descriptive identifier, not a real file path
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}") print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
top_images.append(image) top_images.append(images[doc_id])
if SAVE_TOP_IMAGE: if SAVE_TOP_IMAGE:
from pathlib import Path as _Path from pathlib import Path as _Path
@@ -154,17 +430,12 @@ else:
else: else:
out_path = base / f"retrieved_page_rank{rank}.png" out_path = base / f"retrieved_page_rank{rank}.png"
img.save(str(out_path)) img.save(str(out_path))
# Print the retrieval score (document-level MaxSim) alongside the saved path print(f"Saved retrieved page (rank {rank}) to: {out_path}")
try:
score, _doc_id = results[rank - 1]
print(f"Saved retrieved page (rank {rank}) [MaxSim={score:.4f}] to: {out_path}")
except Exception:
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
## TODO stange results of second page of DeepSeek-V2 rather than the first page ## TODO stange results of second page of DeepSeek-V2 rather than the first page
# %% # %%
# Step 6: Similarity maps for top-K results # Step 5: Similarity maps for top-K results
if results and SIMILARITY_MAP: if results and SIMILARITY_MAP:
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX) token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
from pathlib import Path as _Path from pathlib import Path as _Path
@@ -201,7 +472,7 @@ if results and SIMILARITY_MAP:
# %% # %%
# Step 7: Optional answer generation # Step 6: Optional answer generation
if results and ANSWER: if results and ANSWER:
qwen = QwenVL(device=device_str) qwen = QwenVL(device=device_str)
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS) response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)

View File

@@ -7,6 +7,7 @@ for indexing in LEANN. It supports various Slack MCP server implementations and
flexible message processing options. flexible message processing options.
""" """
import ast
import asyncio import asyncio
import json import json
import logging import logging
@@ -146,16 +147,16 @@ class SlackMCPReader:
match = re.search(r"'error':\s*(\{[^}]+\})", str(e)) match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
if match: if match:
try: try:
error_dict = eval(match.group(1)) error_dict = ast.literal_eval(match.group(1))
except (ValueError, SyntaxError, NameError): except (ValueError, SyntaxError):
pass pass
else: else:
# Try alternative format # Try alternative format
match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e)) match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
if match: if match:
try: try:
error_dict = eval(match.group(1)) error_dict = ast.literal_eval(match.group(1))
except (ValueError, SyntaxError, NameError): except (ValueError, SyntaxError):
pass pass
if self._is_cache_sync_error(error_dict): if self._is_cache_sync_error(error_dict):

View File

@@ -1,149 +0,0 @@
#!/usr/bin/env python3
"""
Test script to reproduce issue #159: Slow search performance
Configuration:
- GPU: 4090×1
- embedding_model: BAAI/bge-large-zh-v1.5
- data size: 180M text (~90K chunks)
- beam_width: 10 (though this is mainly for DiskANN, not HNSW)
- backend: hnsw
"""
import os
import time
from pathlib import Path
from leann.api import LeannBuilder, LeannSearcher, SearchResult
os.environ["LEANN_LOG_LEVEL"] = "DEBUG"
# Configuration matching the issue
INDEX_PATH = "./test_issue_159.leann"
EMBEDDING_MODEL = "BAAI/bge-large-zh-v1.5"
BACKEND_NAME = "hnsw"
def generate_test_data(num_chunks=90000, chunk_size=2000):
"""Generate test data similar to 180MB text (~90K chunks)"""
# Each chunk is approximately 2000 characters
# 90K chunks * 2000 chars ≈ 180MB
chunks = []
base_text = (
"这是一个测试文档。LEANN是一个创新的向量数据库通过图基选择性重计算实现97%的存储节省。"
)
for i in range(num_chunks):
chunk = f"{base_text} 文档编号: {i}. " * (chunk_size // len(base_text) + 1)
chunks.append(chunk[:chunk_size])
return chunks
def test_search_performance():
"""Test search performance with different configurations"""
print("=" * 80)
print("Testing LEANN Search Performance (Issue #159)")
print("=" * 80)
meta_path = Path(f"{INDEX_PATH}.meta.json")
if meta_path.exists():
print(f"\n✓ Index already exists at {INDEX_PATH}")
print(" Skipping build phase. Delete the index to rebuild.")
else:
print("\n📦 Building index...")
print(f" Backend: {BACKEND_NAME}")
print(f" Embedding Model: {EMBEDDING_MODEL}")
print(" Generating test data (~90K chunks, ~180MB)...")
chunks = generate_test_data(num_chunks=90000)
print(f" Generated {len(chunks)} chunks")
print(f" Total text size: {sum(len(c) for c in chunks) / (1024 * 1024):.2f} MB")
builder = LeannBuilder(
backend_name=BACKEND_NAME,
embedding_model=EMBEDDING_MODEL,
)
print(" Adding chunks to builder...")
start_time = time.time()
for i, chunk in enumerate(chunks):
builder.add_text(chunk)
if (i + 1) % 10000 == 0:
print(f" Added {i + 1}/{len(chunks)} chunks...")
print(" Building index...")
build_start = time.time()
builder.build_index(INDEX_PATH)
build_time = time.time() - build_start
print(f" ✓ Index built in {build_time:.2f} seconds")
# Test search with different complexity values
print("\n🔍 Testing search performance...")
searcher = LeannSearcher(INDEX_PATH)
test_query = "LEANN向量数据库存储优化"
# Test with default complexity (64)
print("\n Test 1: Default complexity (64) `1 ")
print(f" Query: '{test_query}'")
start_time = time.time()
results: list[SearchResult] = searcher.search(test_query, top_k=10, complexity=64)
search_time = time.time() - start_time
print(f" ✓ Search completed in {search_time:.2f} seconds")
print(f" Results: {len(results)} items")
# Test with default complexity (64)
print("\n Test 1: Default complexity (64)")
print(f" Query: '{test_query}'")
start_time = time.time()
results = searcher.search(test_query, top_k=10, complexity=64)
search_time = time.time() - start_time
print(f" ✓ Search completed in {search_time:.2f} seconds")
print(f" Results: {len(results)} items")
# Test with lower complexity (32)
print("\n Test 2: Lower complexity (32)")
print(f" Query: '{test_query}'")
start_time = time.time()
results = searcher.search(test_query, top_k=10, complexity=32)
search_time = time.time() - start_time
print(f" ✓ Search completed in {search_time:.2f} seconds")
print(f" Results: {len(results)} items")
# Test with even lower complexity (16)
print("\n Test 3: Lower complexity (16)")
print(f" Query: '{test_query}'")
start_time = time.time()
results = searcher.search(test_query, top_k=10, complexity=16)
search_time = time.time() - start_time
print(f" ✓ Search completed in {search_time:.2f} seconds")
print(f" Results: {len(results)} items")
# Test with minimal complexity (8)
print("\n Test 4: Minimal complexity (8)")
print(f" Query: '{test_query}'")
start_time = time.time()
results = searcher.search(test_query, top_k=10, complexity=8)
search_time = time.time() - start_time
print(f" ✓ Search completed in {search_time:.2f} seconds")
print(f" Results: {len(results)} items")
print("\n" + "=" * 80)
print("Performance Analysis:")
print("=" * 80)
print("\nKey Findings:")
print("1. beam_width parameter is mainly for DiskANN backend, not HNSW")
print("2. For HNSW, the main parameter affecting search speed is 'complexity'")
print("3. Lower complexity values (16-32) should provide faster search")
print("4. The paper mentions ~2 seconds, which likely uses:")
print(" - Smaller embedding model (~100M params vs 300M for bge-large)")
print(" - Lower complexity (16-32)")
print(" - Possibly DiskANN backend for better performance")
print("\nRecommendations:")
print("- Try complexity=16 or complexity=32 for faster search")
print("- Consider using DiskANN backend for better performance on large datasets")
print("- Or use a smaller embedding model if speed is critical")
if __name__ == "__main__":
test_search_performance()

View File

@@ -143,6 +143,8 @@ def create_hnsw_embedding_server(
pass pass
return str(nid) return str(nid)
# (legacy ZMQ thread removed; using shutdown-capable server only)
def zmq_server_thread_with_shutdown(shutdown_event): def zmq_server_thread_with_shutdown(shutdown_event):
"""ZMQ server thread that respects shutdown signal. """ZMQ server thread that respects shutdown signal.
@@ -156,238 +158,225 @@ def create_hnsw_embedding_server(
rep_socket.bind(f"tcp://*:{zmq_port}") rep_socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}") logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
rep_socket.setsockopt(zmq.RCVTIMEO, 1000) rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
# Keep sends from blocking during shutdown; fail fast and drop on close
rep_socket.setsockopt(zmq.SNDTIMEO, 1000) rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
rep_socket.setsockopt(zmq.LINGER, 0) rep_socket.setsockopt(zmq.LINGER, 0)
last_request_type = "unknown" # Track last request type/length for shape-correct fallbacks
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
last_request_length = 0 last_request_length = 0
def _build_safe_fallback():
if last_request_type == "distance":
large_distance = 1e9
fallback_len = max(0, int(last_request_length))
return [[large_distance] * fallback_len]
if last_request_type == "embedding":
bsz = max(0, int(last_request_length))
dim = max(0, int(embedding_dim))
if dim > 0:
return [[bsz, dim], [0.0] * (bsz * dim)]
return [[0, 0], []]
if last_request_type == "text":
return []
return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
def _handle_text_embedding(request: list[str]) -> None:
nonlocal last_request_type, last_request_length
e2e_start = time.time()
last_request_type = "text"
last_request_length = len(request)
embeddings = compute_embeddings(
request,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time()
logger.info(f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s")
def _handle_distance_request(request: list[Any]) -> None:
nonlocal last_request_type, last_request_length
e2e_start = time.time()
node_ids = request[0]
if len(node_ids) == 1 and isinstance(node_ids[0], list):
node_ids = node_ids[0]
query_vector = np.array(request[1], dtype=np.float32)
last_request_type = "distance"
last_request_length = len(node_ids)
logger.debug("Distance calculation request received")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
except Exception as exc:
logger.error(f"Exception looking up passage ID {nid}: {exc}")
large_distance = 1e9
response_distances = [large_distance] * len(node_ids)
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if distance_metric == "l2":
partial = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else:
partial = -np.dot(embeddings, query_vector)
for pos, dval in zip(found_indices, partial.flatten().tolist()):
response_distances[pos] = float(dval)
except Exception as exc:
logger.error(f"Distance computation error, using sentinels: {exc}")
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
def _handle_embedding_by_id(request: Any) -> None:
nonlocal last_request_type, last_request_length
if isinstance(request, list) and len(request) == 1 and isinstance(request[0], list):
node_ids = request[0]
elif isinstance(request, list):
node_ids = request
else:
node_ids = []
e2e_start = time.time()
last_request_type = "embedding"
last_request_length = len(node_ids)
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
if embedding_dim <= 0:
dims = [0, 0]
flat_data: list[float] = []
else:
dims = [len(node_ids), embedding_dim]
flat_data = [0.0] * (dims[0] * dims[1])
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage with ID {nid} not found")
except Exception as exc:
logger.error(f"Exception looking up passage ID {nid}: {exc}")
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
dims = [0, embedding_dim]
flat_data = []
else:
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
flat = emb_f32.flatten().tolist()
for j, pos in enumerate(found_indices):
start = pos * embedding_dim
end = start + embedding_dim
if end <= len(flat_data):
flat_data[start:end] = flat[
j * embedding_dim : (j + 1) * embedding_dim
]
except Exception as exc:
logger.error(f"Embedding computation error, returning zeros: {exc}")
response_payload = [dims, flat_data]
rep_socket.send(msgpack.packb(response_payload, use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Fallback Embed by Id E2E time: {e2e_end - e2e_start:.6f}s")
try: try:
while not shutdown_event.is_set(): while not shutdown_event.is_set():
try: try:
e2e_start = time.time()
logger.debug("🔍 Waiting for ZMQ message...") logger.debug("🔍 Waiting for ZMQ message...")
request_bytes = rep_socket.recv() request_bytes = rep_socket.recv()
except zmq.Again:
continue
try: # Rest of the processing logic (same as original)
request = msgpack.unpackb(request_bytes) request = msgpack.unpackb(request_bytes)
except Exception as exc:
if shutdown_event.is_set():
logger.info("Shutdown in progress, ignoring ZMQ error")
break
logger.error(f"Error unpacking ZMQ message: {exc}")
try:
safe = _build_safe_fallback()
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
continue
try: if len(request) == 1 and request[0] == "__QUERY_MODEL__":
# Model query response_bytes = msgpack.packb([model_name])
rep_socket.send(response_bytes)
continue
# Handle direct text embedding request
if ( if (
isinstance(request, list)
and len(request) == 1
and request[0] == "__QUERY_MODEL__"
):
rep_socket.send(msgpack.packb([model_name]))
# Direct text embedding
elif (
isinstance(request, list) isinstance(request, list)
and request and request
and all(isinstance(item, str) for item in request) and all(isinstance(item, str) for item in request)
): ):
_handle_text_embedding(request) last_request_type = "text"
# Distance calculation: [[ids], [query_vector]] last_request_length = len(request)
elif ( embeddings = compute_embeddings(
request,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
continue
# Handle distance calculation request: [[ids], [query_vector]]
if (
isinstance(request, list) isinstance(request, list)
and len(request) == 2 and len(request) == 2
and isinstance(request[0], list) and isinstance(request[0], list)
and isinstance(request[1], list) and isinstance(request[1], list)
): ):
_handle_distance_request(request) node_ids = request[0]
# Embedding-by-id fallback # Handle nested [[ids]] shape defensively
if len(node_ids) == 1 and isinstance(node_ids[0], list):
node_ids = node_ids[0]
query_vector = np.array(request[1], dtype=np.float32)
last_request_type = "distance"
last_request_length = len(node_ids)
logger.debug("Distance calculation request received")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
# Gather texts for found ids
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
# Prepare full-length response with large sentinel values
large_distance = 1e9
response_distances = [large_distance] * len(node_ids)
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if distance_metric == "l2":
partial = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else: # mips or cosine
partial = -np.dot(embeddings, query_vector)
for pos, dval in zip(found_indices, partial.flatten().tolist()):
response_distances[pos] = float(dval)
except Exception as e:
logger.error(f"Distance computation error, using sentinels: {e}")
# Send response in expected shape [[distances]]
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
continue
# Fallback: treat as embedding-by-id request
if (
isinstance(request, list)
and len(request) == 1
and isinstance(request[0], list)
):
node_ids = request[0]
elif isinstance(request, list):
node_ids = request
else:
node_ids = []
last_request_type = "embedding"
last_request_length = len(node_ids)
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
# Preallocate zero-filled flat data for robustness
if embedding_dim <= 0:
dims = [0, 0]
flat_data: list[float] = []
else:
dims = [len(node_ids), embedding_dim]
flat_data = [0.0] * (dims[0] * dims[1])
# Collect texts for found ids
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage with ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
dims = [0, embedding_dim]
flat_data = []
else:
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
flat = emb_f32.flatten().tolist()
for j, pos in enumerate(found_indices):
start = pos * embedding_dim
end = start + embedding_dim
if end <= len(flat_data):
flat_data[start:end] = flat[
j * embedding_dim : (j + 1) * embedding_dim
]
except Exception as e:
logger.error(f"Embedding computation error, returning zeros: {e}")
response_payload = [dims, flat_data]
response_bytes = msgpack.packb(response_payload, use_single_float=True)
rep_socket.send(response_bytes)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
# Timeout - check shutdown_event and continue
continue
except Exception as e:
if not shutdown_event.is_set():
logger.error(f"Error in ZMQ server loop: {e}")
# Shape-correct fallback
try:
if last_request_type == "distance":
large_distance = 1e9
fallback_len = max(0, int(last_request_length))
safe = [[large_distance] * fallback_len]
elif last_request_type == "embedding":
bsz = max(0, int(last_request_length))
dim = max(0, int(embedding_dim))
safe = (
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
)
elif last_request_type == "text":
safe = [] # direct text embeddings expectation is a flat list
else:
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
else: else:
_handle_embedding_by_id(request)
except Exception as exc:
if shutdown_event.is_set():
logger.info("Shutdown in progress, ignoring ZMQ error") logger.info("Shutdown in progress, ignoring ZMQ error")
break break
logger.error(f"Error in ZMQ server loop: {exc}")
try:
safe = _build_safe_fallback()
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
finally: finally:
try: try:
rep_socket.close(0) rep_socket.close(0)

View File

@@ -864,13 +864,7 @@ class LeannBuilder:
class LeannSearcher: class LeannSearcher:
def __init__( def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
self,
index_path: str,
enable_warmup: bool = True,
recompute_embeddings: bool = True,
**backend_kwargs,
):
# Fix path resolution for Colab and other environments # Fix path resolution for Colab and other environments
if not Path(index_path).is_absolute(): if not Path(index_path).is_absolute():
index_path = str(Path(index_path).resolve()) index_path = str(Path(index_path).resolve())
@@ -901,32 +895,14 @@ class LeannSearcher:
backend_factory = BACKEND_REGISTRY.get(backend_name) backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None: if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found.") raise ValueError(f"Backend '{backend_name}' not found.")
# Global recompute flag for this searcher (explicit knob, default True)
self.recompute_embeddings: bool = bool(recompute_embeddings)
# Warmup flag: keep using the existing enable_warmup parameter,
# but default it to True so cold-start happens earlier.
self._warmup: bool = bool(enable_warmup)
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs} final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
final_kwargs["enable_warmup"] = self._warmup final_kwargs["enable_warmup"] = enable_warmup
if self.embedding_options: if self.embedding_options:
final_kwargs.setdefault("embedding_options", self.embedding_options) final_kwargs.setdefault("embedding_options", self.embedding_options)
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher( self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
index_path, **final_kwargs index_path, **final_kwargs
) )
# Optional one-shot warmup at construction time to hide cold-start latency.
if self._warmup:
try:
_ = self.backend_impl.compute_query_embedding(
"__LEANN_WARMUP__",
use_server_if_available=self.recompute_embeddings,
)
except Exception as exc:
logger.warning(f"Warmup embedding failed (ignored): {exc}")
def search( def search(
self, self,
query: str, query: str,
@@ -934,7 +910,7 @@ class LeannSearcher:
complexity: int = 64, complexity: int = 64,
beam_width: int = 1, beam_width: int = 1,
prune_ratio: float = 0.0, prune_ratio: float = 0.0,
recompute_embeddings: Optional[bool] = None, recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global", pruning_strategy: Literal["global", "local", "proportional"] = "global",
expected_zmq_port: int = 5557, expected_zmq_port: int = 5557,
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None, metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
@@ -951,8 +927,7 @@ class LeannSearcher:
complexity: Search complexity/candidate list size, higher = more accurate but slower complexity: Search complexity/candidate list size, higher = more accurate but slower
beam_width: Number of parallel search paths/IO requests per iteration beam_width: Number of parallel search paths/IO requests per iteration
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0) prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: (Deprecated) Per-call override for recompute mode. recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
Configure this at LeannSearcher(..., recompute_embeddings=...) instead.
pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional" pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
expected_zmq_port: ZMQ port for embedding server communication expected_zmq_port: ZMQ port for embedding server communication
metadata_filters: Optional filters to apply to search results based on metadata. metadata_filters: Optional filters to apply to search results based on metadata.
@@ -991,19 +966,8 @@ class LeannSearcher:
zmq_port = None zmq_port = None
# Resolve effective recompute flag for this search.
if recompute_embeddings is not None:
logger.warning(
"LeannSearcher.search(..., recompute_embeddings=...) is deprecated and "
"will be removed in a future version. Configure recompute at "
"LeannSearcher(..., recompute_embeddings=...) instead."
)
effective_recompute = bool(recompute_embeddings)
else:
effective_recompute = self.recompute_embeddings
start_time = time.time() start_time = time.time()
if effective_recompute: if recompute_embeddings:
zmq_port = self.backend_impl._ensure_server_running( zmq_port = self.backend_impl._ensure_server_running(
self.meta_path_str, self.meta_path_str,
port=expected_zmq_port, port=expected_zmq_port,
@@ -1017,7 +981,7 @@ class LeannSearcher:
query_embedding = self.backend_impl.compute_query_embedding( query_embedding = self.backend_impl.compute_query_embedding(
query, query,
use_server_if_available=effective_recompute, use_server_if_available=recompute_embeddings,
zmq_port=zmq_port, zmq_port=zmq_port,
) )
logger.info(f" Generated embedding shape: {query_embedding.shape}") logger.info(f" Generated embedding shape: {query_embedding.shape}")
@@ -1029,7 +993,7 @@ class LeannSearcher:
"complexity": complexity, "complexity": complexity,
"beam_width": beam_width, "beam_width": beam_width,
"prune_ratio": prune_ratio, "prune_ratio": prune_ratio,
"recompute_embeddings": effective_recompute, "recompute_embeddings": recompute_embeddings,
"pruning_strategy": pruning_strategy, "pruning_strategy": pruning_strategy,
"zmq_port": zmq_port, "zmq_port": zmq_port,
} }

View File

@@ -5,15 +5,12 @@ Packaged within leann-core so installed wheels can import it reliably.
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Any, Optional from typing import Optional
from llama_index.core.node_parser import SentenceSplitter from llama_index.core.node_parser import SentenceSplitter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Flag to ensure AST token warning only shown once per session
_ast_token_warning_shown = False
def estimate_token_count(text: str) -> int: def estimate_token_count(text: str) -> int:
""" """
@@ -177,44 +174,37 @@ def create_ast_chunks(
max_chunk_size: int = 512, max_chunk_size: int = 512,
chunk_overlap: int = 64, chunk_overlap: int = 64,
metadata_template: str = "default", metadata_template: str = "default",
) -> list[dict[str, Any]]: ) -> list[str]:
"""Create AST-aware chunks from code documents using astchunk. """Create AST-aware chunks from code documents using astchunk.
Falls back to traditional chunking if astchunk is unavailable. Falls back to traditional chunking if astchunk is unavailable.
Returns:
List of dicts with {"text": str, "metadata": dict}
""" """
try: try:
from astchunk import ASTChunkBuilder # optional dependency from astchunk import ASTChunkBuilder # optional dependency
except ImportError as e: except ImportError as e:
logger.error(f"astchunk not available: {e}") logger.error(f"astchunk not available: {e}")
logger.info("Falling back to traditional chunking for code files") logger.info("Falling back to traditional chunking for code files")
return _traditional_chunks_as_dicts(documents, max_chunk_size, chunk_overlap) return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
all_chunks = [] all_chunks = []
for doc in documents: for doc in documents:
language = doc.metadata.get("language") language = doc.metadata.get("language")
if not language: if not language:
logger.warning("No language detected; falling back to traditional chunking") logger.warning("No language detected; falling back to traditional chunking")
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap)) all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
continue continue
try: try:
# Warn once if AST chunk size + overlap might exceed common token limits # Warn if AST chunk size + overlap might exceed common token limits
# Note: Actual truncation happens at embedding time with dynamic model limits
global _ast_token_warning_shown
estimated_max_tokens = int( estimated_max_tokens = int(
(max_chunk_size + chunk_overlap) * 1.2 (max_chunk_size + chunk_overlap) * 1.2
) # Conservative estimate ) # Conservative estimate
if estimated_max_tokens > 512 and not _ast_token_warning_shown: if estimated_max_tokens > 512:
logger.warning( logger.warning(
f"AST chunk size ({max_chunk_size}) + overlap ({chunk_overlap}) = {max_chunk_size + chunk_overlap} chars " f"AST chunk size ({max_chunk_size}) + overlap ({chunk_overlap}) = {max_chunk_size + chunk_overlap} chars "
f"may exceed 512 token limit (~{estimated_max_tokens} tokens estimated). " f"may exceed 512 token limit (~{estimated_max_tokens} tokens estimated). "
f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}. " f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}"
f"Note: Chunks will be auto-truncated at embedding time based on your model's actual token limit."
) )
_ast_token_warning_shown = True
configs = { configs = {
"max_chunk_size": max_chunk_size, "max_chunk_size": max_chunk_size,
@@ -239,40 +229,17 @@ def create_ast_chunks(
chunks = chunk_builder.chunkify(code_content) chunks = chunk_builder.chunkify(code_content)
for chunk in chunks: for chunk in chunks:
chunk_text = None
astchunk_metadata = {}
if hasattr(chunk, "text"): if hasattr(chunk, "text"):
chunk_text = chunk.text chunk_text = chunk.text
elif isinstance(chunk, dict) and "text" in chunk:
chunk_text = chunk["text"]
elif isinstance(chunk, str): elif isinstance(chunk, str):
chunk_text = chunk chunk_text = chunk
elif isinstance(chunk, dict):
# Handle astchunk format: {"content": "...", "metadata": {...}}
if "content" in chunk:
chunk_text = chunk["content"]
astchunk_metadata = chunk.get("metadata", {})
elif "text" in chunk:
chunk_text = chunk["text"]
else:
chunk_text = str(chunk) # Last resort
else: else:
chunk_text = str(chunk) chunk_text = str(chunk)
if chunk_text and chunk_text.strip(): if chunk_text and chunk_text.strip():
# Extract document-level metadata all_chunks.append(chunk_text.strip())
doc_metadata = {
"file_path": doc.metadata.get("file_path", ""),
"file_name": doc.metadata.get("file_name", ""),
}
if "creation_date" in doc.metadata:
doc_metadata["creation_date"] = doc.metadata["creation_date"]
if "last_modified_date" in doc.metadata:
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
# Merge document metadata + astchunk metadata
combined_metadata = {**doc_metadata, **astchunk_metadata}
all_chunks.append({"text": chunk_text.strip(), "metadata": combined_metadata})
logger.info( logger.info(
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}" f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
@@ -280,19 +247,15 @@ def create_ast_chunks(
except Exception as e: except Exception as e:
logger.warning(f"AST chunking failed for {language} file: {e}") logger.warning(f"AST chunking failed for {language} file: {e}")
logger.info("Falling back to traditional chunking") logger.info("Falling back to traditional chunking")
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap)) all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
return all_chunks return all_chunks
def create_traditional_chunks( def create_traditional_chunks(
documents, chunk_size: int = 256, chunk_overlap: int = 128 documents, chunk_size: int = 256, chunk_overlap: int = 128
) -> list[dict[str, Any]]: ) -> list[str]:
"""Create traditional text chunks using LlamaIndex SentenceSplitter. """Create traditional text chunks using LlamaIndex SentenceSplitter."""
Returns:
List of dicts with {"text": str, "metadata": dict}
"""
if chunk_size <= 0: if chunk_size <= 0:
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256") logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
chunk_size = 256 chunk_size = 256
@@ -308,40 +271,19 @@ def create_traditional_chunks(
paragraph_separator="\n\n", paragraph_separator="\n\n",
) )
result = [] all_texts = []
for doc in documents: for doc in documents:
# Extract document-level metadata
doc_metadata = {
"file_path": doc.metadata.get("file_path", ""),
"file_name": doc.metadata.get("file_name", ""),
}
if "creation_date" in doc.metadata:
doc_metadata["creation_date"] = doc.metadata["creation_date"]
if "last_modified_date" in doc.metadata:
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
try: try:
nodes = node_parser.get_nodes_from_documents([doc]) nodes = node_parser.get_nodes_from_documents([doc])
if nodes: if nodes:
for node in nodes: all_texts.extend(node.get_content() for node in nodes)
result.append({"text": node.get_content(), "metadata": doc_metadata})
except Exception as e: except Exception as e:
logger.error(f"Traditional chunking failed for document: {e}") logger.error(f"Traditional chunking failed for document: {e}")
content = doc.get_content() content = doc.get_content()
if content and content.strip(): if content and content.strip():
result.append({"text": content.strip(), "metadata": doc_metadata}) all_texts.append(content.strip())
return result return all_texts
def _traditional_chunks_as_dicts(
documents, chunk_size: int = 256, chunk_overlap: int = 128
) -> list[dict[str, Any]]:
"""Helper: Traditional chunking that returns dict format for consistency.
This is now just an alias for create_traditional_chunks for backwards compatibility.
"""
return create_traditional_chunks(documents, chunk_size, chunk_overlap)
def create_text_chunks( def create_text_chunks(
@@ -353,12 +295,8 @@ def create_text_chunks(
ast_chunk_overlap: int = 64, ast_chunk_overlap: int = 64,
code_file_extensions: Optional[list[str]] = None, code_file_extensions: Optional[list[str]] = None,
ast_fallback_traditional: bool = True, ast_fallback_traditional: bool = True,
) -> list[dict[str, Any]]: ) -> list[str]:
"""Create text chunks from documents with optional AST support for code files. """Create text chunks from documents with optional AST support for code files."""
Returns:
List of dicts with {"text": str, "metadata": dict}
"""
if not documents: if not documents:
logger.warning("No documents provided for chunking") logger.warning("No documents provided for chunking")
return [] return []
@@ -393,17 +331,24 @@ def create_text_chunks(
logger.error(f"AST chunking failed: {e}") logger.error(f"AST chunking failed: {e}")
if ast_fallback_traditional: if ast_fallback_traditional:
all_chunks.extend( all_chunks.extend(
_traditional_chunks_as_dicts(code_docs, chunk_size, chunk_overlap) create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
) )
else: else:
raise raise
if text_docs: if text_docs:
all_chunks.extend(_traditional_chunks_as_dicts(text_docs, chunk_size, chunk_overlap)) all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
else: else:
all_chunks = _traditional_chunks_as_dicts(documents, chunk_size, chunk_overlap) all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
logger.info(f"Total chunks created: {len(all_chunks)}") logger.info(f"Total chunks created: {len(all_chunks)}")
# Note: Token truncation is now handled at embedding time with dynamic model limits # Validate chunk token limits (default to 512 for safety)
# See get_model_token_limit() and truncate_to_token_limit() in embedding_compute.py # This provides a safety net for embedding models with token limits
return all_chunks validated_chunks, num_truncated = validate_chunk_token_limits(all_chunks, max_tokens=512)
if num_truncated > 0:
logger.info(
f"Post-chunking validation: {num_truncated} chunks were truncated to fit 512 token limit"
)
return validated_chunks

View File

@@ -1279,8 +1279,13 @@ Examples:
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True), ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
) )
# create_text_chunks now returns list[dict] with metadata preserved # Note: AST chunking currently returns plain text chunks without metadata
all_texts.extend(chunk_texts) # We preserve basic file info by associating chunks with their source documents
# For better metadata preservation, documents list order should be maintained
for chunk_text in chunk_texts:
# TODO: Enhance create_text_chunks to return metadata alongside text
# For now, we store chunks with empty metadata
all_texts.append({"text": chunk_text, "metadata": {}})
except ImportError as e: except ImportError as e:
print( print(

View File

@@ -10,63 +10,72 @@ import time
from typing import Any, Optional from typing import Any, Optional
import numpy as np import numpy as np
import tiktoken
import torch import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
# Set up logger with proper level
logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Token limit registry for embedding models def truncate_to_token_limit(texts: list[str], max_tokens: int = 512) -> list[str]:
# Used as fallback when dynamic discovery fails (e.g., LM Studio, OpenAI) """
# Ollama models use dynamic discovery via /api/show Truncate texts to token limit using tiktoken or conservative character truncation.
EMBEDDING_MODEL_LIMITS = {
# Nomic models (common across servers) Args:
"nomic-embed-text": 2048, # Corrected from 512 - verified via /api/show texts: List of texts to truncate
"nomic-embed-text-v1.5": 2048, max_tokens: Maximum tokens allowed per text
"nomic-embed-text-v2": 512,
# Other embedding models Returns:
"mxbai-embed-large": 512, List of truncated texts that should fit within token limit
"all-minilm": 512, """
"bge-m3": 8192, try:
"snowflake-arctic-embed": 512, import tiktoken
# OpenAI models
"text-embedding-3-small": 8192, encoder = tiktoken.get_encoding("cl100k_base")
"text-embedding-3-large": 8192, truncated = []
"text-embedding-ada-002": 8192,
} for text in texts:
tokens = encoder.encode(text)
if len(tokens) > max_tokens:
# Truncate to max_tokens and decode back to text
truncated_tokens = tokens[:max_tokens]
truncated_text = encoder.decode(truncated_tokens)
truncated.append(truncated_text)
logger.warning(
f"Truncated text from {len(tokens)} to {max_tokens} tokens "
f"(from {len(text)} to {len(truncated_text)} characters)"
)
else:
truncated.append(text)
return truncated
except ImportError:
# Fallback: Conservative character truncation
# Assume worst case: 1.5 tokens per character for code content
char_limit = int(max_tokens / 1.5)
truncated = []
for text in texts:
if len(text) > char_limit:
truncated_text = text[:char_limit]
truncated.append(truncated_text)
logger.warning(
f"Truncated text from {len(text)} to {char_limit} characters "
f"(conservative estimate for {max_tokens} tokens)"
)
else:
truncated.append(text)
return truncated
def get_model_token_limit( def get_model_token_limit(model_name: str) -> int:
model_name: str,
base_url: Optional[str] = None,
default: int = 2048,
) -> int:
""" """
Get token limit for a given embedding model. Get token limit for a given embedding model.
Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others.
Args: Args:
model_name: Name of the embedding model model_name: Name of the embedding model
base_url: Base URL of the embedding server (for dynamic discovery)
default: Default token limit if model not found
Returns: Returns:
Token limit for the model in tokens Token limit for the model, defaults to 512 if unknown
""" """
# Try Ollama dynamic discovery if base_url provided
if base_url:
# Detect Ollama servers by port or "ollama" in URL
if "11434" in base_url or "ollama" in base_url.lower():
limit = _query_ollama_context_limit(model_name, base_url)
if limit:
return limit
# Fallback to known model registry with version handling (from PR #154)
# Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text") # Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
base_model_name = model_name.split(":")[0] base_model_name = model_name.split(":")[0]
@@ -83,111 +92,31 @@ def get_model_token_limit(
if known_model in base_model_name or base_model_name in known_model: if known_model in base_model_name or base_model_name in known_model:
return limit return limit
# Default fallback # Default to conservative 512 token limit
logger.warning(f"Unknown model '{model_name}', using default {default} token limit") logger.warning(f"Unknown model '{model_name}', using default 512 token limit")
return default return 512
def truncate_to_token_limit(texts: list[str], token_limit: int) -> list[str]: # Set up logger with proper level
""" logger = logging.getLogger(__name__)
Truncate texts to fit within token limit using tiktoken. LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
Args: logger.setLevel(log_level)
texts: List of text strings to truncate
token_limit: Maximum number of tokens allowed
Returns:
List of truncated texts (same length as input)
"""
if not texts:
return []
# Use tiktoken with cl100k_base encoding
enc = tiktoken.get_encoding("cl100k_base")
truncated_texts = []
truncation_count = 0
total_tokens_removed = 0
max_original_length = 0
for i, text in enumerate(texts):
tokens = enc.encode(text)
original_length = len(tokens)
if original_length <= token_limit:
# Text is within limit, keep as is
truncated_texts.append(text)
else:
# Truncate to token_limit
truncated_tokens = tokens[:token_limit]
truncated_text = enc.decode(truncated_tokens)
truncated_texts.append(truncated_text)
# Track truncation statistics
truncation_count += 1
tokens_removed = original_length - token_limit
total_tokens_removed += tokens_removed
max_original_length = max(max_original_length, original_length)
# Log individual truncation at WARNING level (first few only)
if truncation_count <= 3:
logger.warning(
f"Text {i + 1} truncated: {original_length}{token_limit} tokens "
f"({tokens_removed} tokens removed)"
)
elif truncation_count == 4:
logger.warning("Further truncation warnings suppressed...")
# Log summary at INFO level
if truncation_count > 0:
logger.warning(
f"Truncation summary: {truncation_count}/{len(texts)} texts truncated "
f"(removed {total_tokens_removed} tokens total, longest was {max_original_length} tokens)"
)
else:
logger.debug(
f"No truncation needed - all {len(texts)} texts within {token_limit} token limit"
)
return truncated_texts
def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int]:
"""
Query Ollama /api/show for model context limit.
Args:
model_name: Name of the Ollama model
base_url: Base URL of the Ollama server
Returns:
Context limit in tokens if found, None otherwise
"""
try:
import requests
response = requests.post(
f"{base_url}/api/show",
json={"name": model_name},
timeout=5,
)
if response.status_code == 200:
data = response.json()
if "model_info" in data:
# Look for *.context_length in model_info
for key, value in data["model_info"].items():
if "context_length" in key and isinstance(value, int):
logger.info(f"Detected {model_name} context limit: {value} tokens")
return value
except Exception as e:
logger.debug(f"Failed to query Ollama context limit: {e}")
return None
# Global model cache to avoid repeated loading # Global model cache to avoid repeated loading
_model_cache: dict[str, Any] = {} _model_cache: dict[str, Any] = {}
# Known embedding model token limits
EMBEDDING_MODEL_LIMITS = {
"nomic-embed-text": 512,
"nomic-embed-text-v2": 512,
"mxbai-embed-large": 512,
"all-minilm": 512,
"bge-m3": 8192,
"snowflake-arctic-embed": 512,
# Add more models as needed
}
def compute_embeddings( def compute_embeddings(
texts: list[str], texts: list[str],
@@ -215,14 +144,9 @@ def compute_embeddings(
Normalized embeddings array, shape: (len(texts), embedding_dim) Normalized embeddings array, shape: (len(texts), embedding_dim)
""" """
provider_options = provider_options or {} provider_options = provider_options or {}
wrapper_start_time = time.time()
logger.debug(
f"[compute_embeddings] entry: mode={mode}, model='{model_name}', text_count={len(texts)}"
)
if mode == "sentence-transformers": if mode == "sentence-transformers":
inner_start_time = time.time() return compute_embeddings_sentence_transformers(
result = compute_embeddings_sentence_transformers(
texts, texts,
model_name, model_name,
is_build=is_build, is_build=is_build,
@@ -231,14 +155,6 @@ def compute_embeddings(
manual_tokenize=manual_tokenize, manual_tokenize=manual_tokenize,
max_length=max_length, max_length=max_length,
) )
inner_end_time = time.time()
wrapper_end_time = time.time()
logger.debug(
"[compute_embeddings] sentence-transformers timings: "
f"inner={inner_end_time - inner_start_time:.6f}s, "
f"wrapper_total={wrapper_end_time - wrapper_start_time:.6f}s"
)
return result
elif mode == "openai": elif mode == "openai":
return compute_embeddings_openai( return compute_embeddings_openai(
texts, texts,
@@ -284,7 +200,6 @@ def compute_embeddings_sentence_transformers(
is_build: Whether this is a build operation (shows progress bar) is_build: Whether this is a build operation (shows progress bar)
adaptive_optimization: Whether to use adaptive optimization based on batch size adaptive_optimization: Whether to use adaptive optimization based on batch size
""" """
outer_start_time = time.time()
# Handle empty input # Handle empty input
if not texts: if not texts:
raise ValueError("Cannot compute embeddings for empty text list") raise ValueError("Cannot compute embeddings for empty text list")
@@ -315,14 +230,7 @@ def compute_embeddings_sentence_transformers(
# Create cache key # Create cache key
cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized" cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"
pre_model_init_end_time = time.time()
logger.debug(
"compute_embeddings_sentence_transformers pre-model-init time "
f"(device/batch selection etc.): {pre_model_init_end_time - outer_start_time:.6f}s"
)
# Check if model is already cached # Check if model is already cached
start_time = time.time()
if cache_key in _model_cache: if cache_key in _model_cache:
logger.info(f"Using cached optimized model: {model_name}") logger.info(f"Using cached optimized model: {model_name}")
model = _model_cache[cache_key] model = _model_cache[cache_key]
@@ -462,13 +370,10 @@ def compute_embeddings_sentence_transformers(
_model_cache[cache_key] = model _model_cache[cache_key] = model
logger.info(f"Model cached: {cache_key}") logger.info(f"Model cached: {cache_key}")
end_time = time.time() # Compute embeddings with optimized inference mode
logger.info(
# Compute embeddings with optimized inference mode f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
logger.info( )
f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
)
logger.info(f"start sentence transformers {model} takes {end_time - start_time}")
start_time = time.time() start_time = time.time()
if not manual_tokenize: if not manual_tokenize:
@@ -489,46 +394,32 @@ def compute_embeddings_sentence_transformers(
except Exception: except Exception:
pass pass
else: else:
# Manual tokenization + forward pass using HF AutoTokenizer/AutoModel. # Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
# This path is reserved for an aggressively optimized FP pipeline
# (no quantization), mainly for experimentation.
try: try:
from transformers import AutoModel, AutoTokenizer # type: ignore from transformers import AutoModel, AutoTokenizer # type: ignore
except Exception as e: except Exception as e:
raise ImportError(f"transformers is required for manual_tokenize=True: {e}") raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
# Cache tokenizer and model
tok_cache_key = f"hf_tokenizer_{model_name}" tok_cache_key = f"hf_tokenizer_{model_name}"
mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}_fp" mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
if tok_cache_key in _model_cache and mdl_cache_key in _model_cache: if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
hf_tokenizer = _model_cache[tok_cache_key] hf_tokenizer = _model_cache[tok_cache_key]
hf_model = _model_cache[mdl_cache_key] hf_model = _model_cache[mdl_cache_key]
logger.info("Using cached HF tokenizer/model for manual FP path") logger.info("Using cached HF tokenizer/model for manual path")
else: else:
logger.info("Loading HF tokenizer/model for manual FP path") logger.info("Loading HF tokenizer/model for manual tokenization path")
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32 torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
hf_model = AutoModel.from_pretrained( hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
model_name,
torch_dtype=torch_dtype,
)
hf_model.to(device) hf_model.to(device)
hf_model.eval() hf_model.eval()
# Optional compile on supported devices # Optional compile on supported devices
if device in ["cuda", "mps"]: if device in ["cuda", "mps"]:
try: try:
hf_model = torch.compile( # type: ignore hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True) # type: ignore
hf_model, mode="reduce-overhead", dynamic=True except Exception:
) pass
logger.info(
f"Applied torch.compile to HF model for {model_name} "
f"(device={device}, dtype={torch_dtype})"
)
except Exception as exc:
logger.warning(f"torch.compile optimization failed: {exc}")
_model_cache[tok_cache_key] = hf_tokenizer _model_cache[tok_cache_key] = hf_tokenizer
_model_cache[mdl_cache_key] = hf_model _model_cache[mdl_cache_key] = hf_model
@@ -554,6 +445,7 @@ def compute_embeddings_sentence_transformers(
for start_index in batch_iter: for start_index in batch_iter:
end_index = min(start_index + batch_size, len(texts)) end_index = min(start_index + batch_size, len(texts))
batch_texts = texts[start_index:end_index] batch_texts = texts[start_index:end_index]
tokenize_start_time = time.time()
inputs = hf_tokenizer( inputs = hf_tokenizer(
batch_texts, batch_texts,
padding=True, padding=True,
@@ -561,17 +453,34 @@ def compute_embeddings_sentence_transformers(
max_length=max_length, max_length=max_length,
return_tensors="pt", return_tensors="pt",
) )
tokenize_end_time = time.time()
logger.info(
f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
)
# Print shapes of all input tensors for debugging
for k, v in inputs.items():
print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
to_device_start_time = time.time()
inputs = {k: v.to(device) for k, v in inputs.items()} inputs = {k: v.to(device) for k, v in inputs.items()}
to_device_end_time = time.time()
logger.info(
f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
)
forward_start_time = time.time()
outputs = hf_model(**inputs) outputs = hf_model(**inputs)
forward_end_time = time.time()
logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
last_hidden_state = outputs.last_hidden_state # (B, L, H) last_hidden_state = outputs.last_hidden_state # (B, L, H)
attention_mask = inputs.get("attention_mask") attention_mask = inputs.get("attention_mask")
if attention_mask is None: if attention_mask is None:
# Fallback: assume all tokens are valid
pooled = last_hidden_state.mean(dim=1) pooled = last_hidden_state.mean(dim=1)
else: else:
mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype) mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
masked = last_hidden_state * mask masked = last_hidden_state * mask
lengths = mask.sum(dim=1).clamp(min=1) lengths = mask.sum(dim=1).clamp(min=1)
pooled = masked.sum(dim=1) / lengths pooled = masked.sum(dim=1) / lengths
# Move to CPU float32
batch_embeddings = pooled.detach().to("cpu").float().numpy() batch_embeddings = pooled.detach().to("cpu").float().numpy()
all_embeddings.append(batch_embeddings) all_embeddings.append(batch_embeddings)
@@ -591,12 +500,6 @@ def compute_embeddings_sentence_transformers(
if np.isnan(embeddings).any() or np.isinf(embeddings).any(): if np.isnan(embeddings).any() or np.isinf(embeddings).any():
raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}") raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}")
outer_end_time = time.time()
logger.debug(
"compute_embeddings_sentence_transformers total time "
f"(function entry -> return): {outer_end_time - outer_start_time:.6f}s"
)
return embeddings return embeddings
@@ -911,13 +814,15 @@ def compute_embeddings_ollama(
logger.info(f"Using batch size: {batch_size} for true batch processing") logger.info(f"Using batch size: {batch_size} for true batch processing")
# Get model token limit and apply truncation before batching # Get model token limit and apply truncation
token_limit = get_model_token_limit(model_name, base_url=resolved_host) token_limit = get_model_token_limit(model_name)
logger.info(f"Model '{model_name}' token limit: {token_limit}") logger.info(f"Model '{model_name}' token limit: {token_limit}")
# Apply truncation to all texts before batch processing # Apply token-aware truncation to all texts
# Function logs truncation details internally truncated_texts = truncate_to_token_limit(texts, token_limit)
texts = truncate_to_token_limit(texts, token_limit) if len(truncated_texts) != len(texts):
logger.error("Truncation failed - text count mismatch")
truncated_texts = texts # Fallback to original texts
def get_batch_embeddings(batch_texts): def get_batch_embeddings(batch_texts):
"""Get embeddings for a batch of texts using /api/embed endpoint.""" """Get embeddings for a batch of texts using /api/embed endpoint."""
@@ -975,12 +880,12 @@ def compute_embeddings_ollama(
return None, list(range(len(batch_texts))) return None, list(range(len(batch_texts)))
# Process texts in batches # Process truncated texts in batches
all_embeddings = [] all_embeddings = []
all_failed_indices = [] all_failed_indices = []
# Setup progress bar if needed # Setup progress bar if needed
show_progress = is_build or len(texts) > 10 show_progress = is_build or len(truncated_texts) > 10
try: try:
if show_progress: if show_progress:
from tqdm import tqdm from tqdm import tqdm
@@ -988,7 +893,7 @@ def compute_embeddings_ollama(
show_progress = False show_progress = False
# Process batches # Process batches
num_batches = (len(texts) + batch_size - 1) // batch_size num_batches = (len(truncated_texts) + batch_size - 1) // batch_size
if show_progress: if show_progress:
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)") batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
@@ -997,8 +902,8 @@ def compute_embeddings_ollama(
for batch_idx in batch_iterator: for batch_idx in batch_iterator:
start_idx = batch_idx * batch_size start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, len(texts)) end_idx = min(start_idx + batch_size, len(truncated_texts))
batch_texts = texts[start_idx:end_idx] batch_texts = truncated_texts[start_idx:end_idx]
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts) batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
@@ -1013,11 +918,11 @@ def compute_embeddings_ollama(
# Handle failed embeddings # Handle failed embeddings
if all_failed_indices: if all_failed_indices:
if len(all_failed_indices) == len(texts): if len(all_failed_indices) == len(truncated_texts):
raise RuntimeError("Failed to compute any embeddings") raise RuntimeError("Failed to compute any embeddings")
logger.warning( logger.warning(
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts" f"Failed to compute embeddings for {len(all_failed_indices)}/{len(truncated_texts)} texts"
) )
# Use zero embeddings as fallback for failed ones # Use zero embeddings as fallback for failed ones

View File

@@ -57,8 +57,6 @@ dependencies = [
"tree-sitter-c-sharp>=0.20.0", "tree-sitter-c-sharp>=0.20.0",
"tree-sitter-typescript>=0.20.0", "tree-sitter-typescript>=0.20.0",
"torchvision>=0.23.0", "torchvision>=0.23.0",
"einops",
"seaborn",
] ]
[project.optional-dependencies] [project.optional-dependencies]

View File

@@ -0,0 +1,162 @@
#!/usr/bin/env python3
"""
Test script to reproduce ColQwen results from issue #119
https://github.com/yichuan-w/LEANN/issues/119
This script demonstrates the ColQwen workflow:
1. Download sample PDF
2. Convert to images
3. Build multimodal index
4. Run test queries
5. Generate similarity maps
"""
import importlib.util
import os
from pathlib import Path
def main():
print("🧪 ColQwen Reproduction Test - Issue #119")
print("=" * 50)
# Check if we're in the right directory
repo_root = Path.cwd()
if not (repo_root / "apps" / "colqwen_rag.py").exists():
print("❌ Please run this script from the LEANN repository root")
print(" cd /path/to/LEANN && python test_colqwen_reproduction.py")
return
print("✅ Repository structure looks good")
# Step 1: Check dependencies
print("\n📦 Checking dependencies...")
try:
import torch
# Check if pdf2image is available
if importlib.util.find_spec("pdf2image") is None:
raise ImportError("pdf2image not found")
# Check if colpali_engine is available
if importlib.util.find_spec("colpali_engine") is None:
raise ImportError("colpali_engine not found")
print("✅ Core dependencies available")
print(f" - PyTorch: {torch.__version__}")
print(f" - CUDA available: {torch.cuda.is_available()}")
print(
f" - MPS available: {hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()}"
)
except ImportError as e:
print(f"❌ Missing dependency: {e}")
print("\n📥 Install missing dependencies:")
print(
" uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn"
)
return
# Step 2: Download sample PDF
print("\n📄 Setting up sample PDF...")
pdf_dir = repo_root / "test_pdfs"
pdf_dir.mkdir(exist_ok=True)
sample_pdf = pdf_dir / "attention_paper.pdf"
if not sample_pdf.exists():
print("📥 Downloading sample paper (Attention Is All You Need)...")
import urllib.request
try:
urllib.request.urlretrieve("https://arxiv.org/pdf/1706.03762.pdf", sample_pdf)
print(f"✅ Downloaded: {sample_pdf}")
except Exception as e:
print(f"❌ Download failed: {e}")
print(" Please manually download a PDF to test_pdfs/attention_paper.pdf")
return
else:
print(f"✅ Using existing PDF: {sample_pdf}")
# Step 3: Test ColQwen RAG
print("\n🚀 Testing ColQwen RAG...")
# Build index
print("\n1⃣ Building multimodal index...")
build_cmd = f"python -m apps.colqwen_rag build --pdfs {pdf_dir} --index test_attention --model colqwen2 --pages-dir test_pages"
print(f" Command: {build_cmd}")
try:
result = os.system(build_cmd)
if result == 0:
print("✅ Index built successfully!")
else:
print("❌ Index building failed")
return
except Exception as e:
print(f"❌ Error building index: {e}")
return
# Test search
print("\n2⃣ Testing search...")
test_queries = [
"How does attention mechanism work?",
"What is the transformer architecture?",
"How do you compute self-attention?",
]
for query in test_queries:
print(f"\n🔍 Query: '{query}'")
search_cmd = f'python -m apps.colqwen_rag search test_attention "{query}" --top-k 3'
print(f" Command: {search_cmd}")
try:
result = os.system(search_cmd)
if result == 0:
print("✅ Search completed")
else:
print("❌ Search failed")
except Exception as e:
print(f"❌ Search error: {e}")
# Test interactive mode (briefly)
print("\n3⃣ Testing interactive mode...")
print(" You can test interactive mode with:")
print(" python -m apps.colqwen_rag ask test_attention --interactive")
# Step 4: Test similarity maps (using existing script)
print("\n4⃣ Testing similarity maps...")
similarity_script = (
repo_root
/ "apps"
/ "multimodal"
/ "vision-based-pdf-multi-vector"
/ "multi-vector-leann-similarity-map.py"
)
if similarity_script.exists():
print(" You can generate similarity maps with:")
print(f" cd {similarity_script.parent}")
print(" python multi-vector-leann-similarity-map.py")
print(" (Edit the script to use your local PDF)")
print("\n🎉 ColQwen reproduction test completed!")
print("\n📋 Summary:")
print(" ✅ Dependencies checked")
print(" ✅ Sample PDF prepared")
print(" ✅ Index building tested")
print(" ✅ Search functionality tested")
print(" ✅ Interactive mode available")
print(" ✅ Similarity maps available")
print("\n🔗 Related repositories to check:")
print(" - https://github.com/lightonai/fast-plaid")
print(" - https://github.com/lightonai/pylate")
print(" - https://github.com/stanford-futuredata/ColBERT")
print("\n📝 Next steps:")
print(" 1. Test with your own PDFs")
print(" 2. Experiment with different queries")
print(" 3. Generate similarity maps for visual analysis")
print(" 4. Compare ColQwen2 vs ColPali performance")
if __name__ == "__main__":
main()

View File

@@ -8,7 +8,7 @@ import subprocess
import sys import sys
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from unittest.mock import Mock, patch from unittest.mock import patch
import pytest import pytest
@@ -116,10 +116,8 @@ class TestChunkingFunctions:
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10) chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
assert len(chunks) > 0 assert len(chunks) > 0
# Traditional chunks now return dict format for consistency assert all(isinstance(chunk, str) for chunk in chunks)
assert all(isinstance(chunk, dict) for chunk in chunks) assert all(len(chunk.strip()) > 0 for chunk in chunks)
assert all("text" in chunk and "metadata" in chunk for chunk in chunks)
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks)
def test_create_traditional_chunks_empty_docs(self): def test_create_traditional_chunks_empty_docs(self):
"""Test traditional chunking with empty documents.""" """Test traditional chunking with empty documents."""
@@ -160,22 +158,11 @@ class Calculator:
# Should have multiple chunks due to different functions/classes # Should have multiple chunks due to different functions/classes
assert len(chunks) > 0 assert len(chunks) > 0
# R3: Expect dict format with "text" and "metadata" keys assert all(isinstance(chunk, str) for chunk in chunks)
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts" assert all(len(chunk.strip()) > 0 for chunk in chunks)
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
"Each chunk should have 'text' and 'metadata' keys"
)
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks), (
"Each chunk text should be non-empty"
)
# Check metadata is present
assert all("file_path" in chunk["metadata"] for chunk in chunks), (
"Each chunk should have file_path metadata"
)
# Check that code structure is somewhat preserved # Check that code structure is somewhat preserved
combined_content = " ".join([c["text"] for c in chunks]) combined_content = " ".join(chunks)
assert "def hello_world" in combined_content assert "def hello_world" in combined_content
assert "class Calculator" in combined_content assert "class Calculator" in combined_content
@@ -207,11 +194,7 @@ class Calculator:
chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10) chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10)
assert len(chunks) > 0 assert len(chunks) > 0
# R3: Traditional chunking should also return dict format for consistency assert all(isinstance(chunk, str) for chunk in chunks)
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
"Each chunk should have 'text' and 'metadata' keys"
)
def test_create_text_chunks_ast_mode(self): def test_create_text_chunks_ast_mode(self):
"""Test text chunking in AST mode.""" """Test text chunking in AST mode."""
@@ -230,11 +213,7 @@ class Calculator:
) )
assert len(chunks) > 0 assert len(chunks) > 0
# R3: AST mode should also return dict format assert all(isinstance(chunk, str) for chunk in chunks)
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
"Each chunk should have 'text' and 'metadata' keys"
)
def test_create_text_chunks_custom_extensions(self): def test_create_text_chunks_custom_extensions(self):
"""Test text chunking with custom code file extensions.""" """Test text chunking with custom code file extensions."""
@@ -374,552 +353,6 @@ class MathUtils:
pytest.skip("Test timed out - likely due to model download in CI") pytest.skip("Test timed out - likely due to model download in CI")
class TestASTContentExtraction:
"""Test AST content extraction bug fix.
These tests verify that astchunk's dict format with 'content' key is handled correctly,
and that the extraction logic doesn't fall through to stringifying entire dicts.
"""
def test_extract_content_from_astchunk_dict(self):
"""Test that astchunk dict format with 'content' key is handled correctly.
Bug: Current code checks for chunk["text"] but astchunk returns chunk["content"].
This causes fallthrough to str(chunk), stringifying the entire dict.
This test will FAIL until the bug is fixed because:
- Current code will stringify the dict: "{'content': '...', 'metadata': {...}}"
- Fixed code should extract just the content value
"""
# Mock the ASTChunkBuilder class
mock_builder = Mock()
# Astchunk returns this format
astchunk_format_chunk = {
"content": "def hello():\n print('world')",
"metadata": {
"filepath": "test.py",
"line_count": 2,
"start_line_no": 0,
"end_line_no": 1,
"node_count": 1,
},
}
mock_builder.chunkify.return_value = [astchunk_format_chunk]
# Create mock document
doc = MockDocument(
"def hello():\n print('world')", "/test/test.py", {"language": "python"}
)
# Mock the astchunk module and its ASTChunkBuilder class
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
# Patch sys.modules to inject our mock before the import
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should return dict format with proper metadata
assert len(chunks) > 0, "Should return at least one chunk"
# R3: Each chunk should be a dict
chunk = chunks[0]
assert isinstance(chunk, dict), "Chunk should be a dict"
assert "text" in chunk, "Chunk should have 'text' key"
assert "metadata" in chunk, "Chunk should have 'metadata' key"
chunk_text = chunk["text"]
# CRITICAL: Should NOT contain stringified dict markers in the text field
# These assertions will FAIL with current buggy code
assert "'content':" not in chunk_text, (
f"Chunk text contains stringified dict - extraction failed! Got: {chunk_text[:100]}..."
)
assert "'metadata':" not in chunk_text, (
"Chunk text contains stringified metadata - extraction failed! "
f"Got: {chunk_text[:100]}..."
)
assert "{" not in chunk_text or "def hello" in chunk_text.split("{")[0], (
"Chunk text appears to be a stringified dict"
)
# Should contain actual content
assert "def hello()" in chunk_text, "Should extract actual code content"
assert "print('world')" in chunk_text, "Should extract complete code content"
# R3: Should preserve astchunk metadata
assert "filepath" in chunk["metadata"] or "file_path" in chunk["metadata"], (
"Should preserve file path metadata"
)
def test_extract_text_key_fallback(self):
"""Test that 'text' key still works for backward compatibility.
Some chunks might use 'text' instead of 'content' - ensure backward compatibility.
This test should PASS even with current code.
"""
mock_builder = Mock()
# Some chunks might use "text" key
text_key_chunk = {"text": "def legacy_function():\n return True"}
mock_builder.chunkify.return_value = [text_key_chunk]
# Create mock document
doc = MockDocument(
"def legacy_function():\n return True", "/test/legacy.py", {"language": "python"}
)
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should extract text correctly as dict format
assert len(chunks) > 0
chunk = chunks[0]
assert isinstance(chunk, dict), "Chunk should be a dict"
assert "text" in chunk, "Chunk should have 'text' key"
chunk_text = chunk["text"]
# Should NOT be stringified
assert "'text':" not in chunk_text, "Should not stringify dict with 'text' key"
# Should contain actual content
assert "def legacy_function()" in chunk_text
assert "return True" in chunk_text
def test_handles_string_chunks(self):
"""Test that plain string chunks still work.
Some chunkers might return plain strings - verify these are preserved.
This test should PASS with current code.
"""
mock_builder = Mock()
# Plain string chunk
plain_string_chunk = "def simple_function():\n pass"
mock_builder.chunkify.return_value = [plain_string_chunk]
# Create mock document
doc = MockDocument(
"def simple_function():\n pass", "/test/simple.py", {"language": "python"}
)
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should wrap string in dict format
assert len(chunks) > 0
chunk = chunks[0]
assert isinstance(chunk, dict), "Even string chunks should be wrapped in dict"
assert "text" in chunk, "Chunk should have 'text' key"
chunk_text = chunk["text"]
assert chunk_text == plain_string_chunk.strip(), (
"Should preserve plain string chunk content"
)
assert "def simple_function()" in chunk_text
assert "pass" in chunk_text
def test_multiple_chunks_with_mixed_formats(self):
"""Test handling of multiple chunks with different formats.
Real-world scenario: astchunk might return a mix of formats.
This test will FAIL if any chunk with 'content' key gets stringified.
"""
mock_builder = Mock()
# Mix of formats
mixed_chunks = [
{"content": "def first():\n return 1", "metadata": {"line_count": 2}},
"def second():\n return 2", # Plain string
{"text": "def third():\n return 3"}, # Old format
{"content": "class MyClass:\n pass", "metadata": {"node_count": 1}},
]
mock_builder.chunkify.return_value = mixed_chunks
# Create mock document
code = "def first():\n return 1\n\ndef second():\n return 2\n\ndef third():\n return 3\n\nclass MyClass:\n pass"
doc = MockDocument(code, "/test/mixed.py", {"language": "python"})
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should extract all chunks correctly as dicts
assert len(chunks) == 4, "Should extract all 4 chunks"
# Check each chunk
for i, chunk in enumerate(chunks):
assert isinstance(chunk, dict), f"Chunk {i} should be a dict"
assert "text" in chunk, f"Chunk {i} should have 'text' key"
assert "metadata" in chunk, f"Chunk {i} should have 'metadata' key"
chunk_text = chunk["text"]
# None should be stringified dicts
assert "'content':" not in chunk_text, f"Chunk {i} text is stringified (has 'content':)"
assert "'metadata':" not in chunk_text, (
f"Chunk {i} text is stringified (has 'metadata':)"
)
assert "'text':" not in chunk_text, f"Chunk {i} text is stringified (has 'text':)"
# Verify actual content is present
combined = "\n".join([c["text"] for c in chunks])
assert "def first()" in combined
assert "def second()" in combined
assert "def third()" in combined
assert "class MyClass:" in combined
def test_empty_content_value_handling(self):
"""Test handling of chunks with empty content values.
Edge case: chunk has 'content' key but value is empty.
Should skip these chunks, not stringify them.
"""
mock_builder = Mock()
chunks_with_empty = [
{"content": "", "metadata": {"line_count": 0}}, # Empty content
{"content": " ", "metadata": {"line_count": 1}}, # Whitespace only
{"content": "def valid():\n return True", "metadata": {"line_count": 2}}, # Valid
]
mock_builder.chunkify.return_value = chunks_with_empty
doc = MockDocument(
"def valid():\n return True", "/test/empty.py", {"language": "python"}
)
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
chunks = create_ast_chunks([doc])
# R3: Should only have the valid chunk (empty ones filtered out)
assert len(chunks) == 1, "Should filter out empty content chunks"
chunk = chunks[0]
assert isinstance(chunk, dict), "Chunk should be a dict"
assert "text" in chunk, "Chunk should have 'text' key"
assert "def valid()" in chunk["text"]
# Should not have stringified the empty dict
assert "'content': ''" not in chunk["text"]
class TestASTMetadataPreservation:
"""Test metadata preservation in AST chunk dictionaries.
R3: These tests define the contract for metadata preservation when returning
chunk dictionaries instead of plain strings. Each chunk dict should have:
- "text": str - the actual chunk content
- "metadata": dict - all metadata from document AND astchunk
These tests will FAIL until G3 implementation changes return type to list[dict].
"""
def test_ast_chunks_preserve_file_metadata(self):
"""Test that document metadata is preserved in chunk metadata.
This test verifies that all document-level metadata (file_path, file_name,
creation_date, last_modified_date) is included in each chunk's metadata dict.
This will FAIL because current code returns list[str], not list[dict].
"""
# Create mock document with rich metadata
python_code = '''
def calculate_sum(numbers):
"""Calculate sum of numbers."""
return sum(numbers)
class DataProcessor:
"""Process data records."""
def process(self, data):
return [x * 2 for x in data]
'''
doc = MockDocument(
python_code,
file_path="/project/src/utils.py",
metadata={
"language": "python",
"file_path": "/project/src/utils.py",
"file_name": "utils.py",
"creation_date": "2024-01-15T10:30:00",
"last_modified_date": "2024-10-31T15:45:00",
},
)
# Mock astchunk to return chunks with metadata
mock_builder = Mock()
astchunk_chunks = [
{
"content": "def calculate_sum(numbers):\n return sum(numbers)",
"metadata": {
"filepath": "/project/src/utils.py",
"line_count": 2,
"start_line_no": 1,
"end_line_no": 2,
"node_count": 1,
},
},
{
"content": "class DataProcessor:\n def process(self, data):\n return [x * 2 for x in data]",
"metadata": {
"filepath": "/project/src/utils.py",
"line_count": 3,
"start_line_no": 5,
"end_line_no": 7,
"node_count": 2,
},
},
]
mock_builder.chunkify.return_value = astchunk_chunks
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
chunks = create_ast_chunks([doc])
# CRITICAL: These assertions will FAIL with current list[str] return type
assert len(chunks) == 2, "Should return 2 chunks"
for i, chunk in enumerate(chunks):
# Structure assertions - WILL FAIL: current code returns strings
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
assert "text" in chunk, f"Chunk {i} must have 'text' key"
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
assert isinstance(chunk["metadata"], dict), f"Chunk {i} metadata should be dict"
# Document metadata preservation - WILL FAIL
metadata = chunk["metadata"]
assert "file_path" in metadata, f"Chunk {i} should preserve file_path"
assert metadata["file_path"] == "/project/src/utils.py", (
f"Chunk {i} file_path incorrect"
)
assert "file_name" in metadata, f"Chunk {i} should preserve file_name"
assert metadata["file_name"] == "utils.py", f"Chunk {i} file_name incorrect"
assert "creation_date" in metadata, f"Chunk {i} should preserve creation_date"
assert metadata["creation_date"] == "2024-01-15T10:30:00", (
f"Chunk {i} creation_date incorrect"
)
assert "last_modified_date" in metadata, f"Chunk {i} should preserve last_modified_date"
assert metadata["last_modified_date"] == "2024-10-31T15:45:00", (
f"Chunk {i} last_modified_date incorrect"
)
# Verify metadata is consistent across chunks from same document
assert chunks[0]["metadata"]["file_path"] == chunks[1]["metadata"]["file_path"], (
"All chunks from same document should have same file_path"
)
# Verify text content is present and not stringified
assert "def calculate_sum" in chunks[0]["text"]
assert "class DataProcessor" in chunks[1]["text"]
def test_ast_chunks_include_astchunk_metadata(self):
"""Test that astchunk-specific metadata is merged into chunk metadata.
This test verifies that astchunk's metadata (line_count, start_line_no,
end_line_no, node_count) is merged with document metadata.
This will FAIL because current code returns list[str], not list[dict].
"""
python_code = '''
def function_one():
"""First function."""
x = 1
y = 2
return x + y
def function_two():
"""Second function."""
return 42
'''
doc = MockDocument(
python_code,
file_path="/test/code.py",
metadata={
"language": "python",
"file_path": "/test/code.py",
"file_name": "code.py",
},
)
# Mock astchunk with detailed metadata
mock_builder = Mock()
astchunk_chunks = [
{
"content": "def function_one():\n x = 1\n y = 2\n return x + y",
"metadata": {
"filepath": "/test/code.py",
"line_count": 4,
"start_line_no": 1,
"end_line_no": 4,
"node_count": 5, # function, assignments, return
},
},
{
"content": "def function_two():\n return 42",
"metadata": {
"filepath": "/test/code.py",
"line_count": 2,
"start_line_no": 7,
"end_line_no": 8,
"node_count": 2, # function, return
},
},
]
mock_builder.chunkify.return_value = astchunk_chunks
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
chunks = create_ast_chunks([doc])
# CRITICAL: These will FAIL with current list[str] return
assert len(chunks) == 2
# First chunk - function_one
chunk1 = chunks[0]
assert isinstance(chunk1, dict), "Chunk should be dict"
assert "metadata" in chunk1
metadata1 = chunk1["metadata"]
# Check astchunk metadata is present
assert "line_count" in metadata1, "Should include astchunk line_count"
assert metadata1["line_count"] == 4, "line_count should be 4"
assert "start_line_no" in metadata1, "Should include astchunk start_line_no"
assert metadata1["start_line_no"] == 1, "start_line_no should be 1"
assert "end_line_no" in metadata1, "Should include astchunk end_line_no"
assert metadata1["end_line_no"] == 4, "end_line_no should be 4"
assert "node_count" in metadata1, "Should include astchunk node_count"
assert metadata1["node_count"] == 5, "node_count should be 5"
# Second chunk - function_two
chunk2 = chunks[1]
metadata2 = chunk2["metadata"]
assert metadata2["line_count"] == 2, "line_count should be 2"
assert metadata2["start_line_no"] == 7, "start_line_no should be 7"
assert metadata2["end_line_no"] == 8, "end_line_no should be 8"
assert metadata2["node_count"] == 2, "node_count should be 2"
# Verify document metadata is ALSO present (merged, not replaced)
assert metadata1["file_path"] == "/test/code.py"
assert metadata1["file_name"] == "code.py"
assert metadata2["file_path"] == "/test/code.py"
assert metadata2["file_name"] == "code.py"
# Verify text content is correct
assert "def function_one" in chunk1["text"]
assert "def function_two" in chunk2["text"]
def test_traditional_chunks_as_dicts_helper(self):
"""Test the helper function that wraps traditional chunks as dicts.
This test verifies that when create_traditional_chunks is called,
its plain string chunks are wrapped into dict format with metadata.
This will FAIL because the helper function _traditional_chunks_as_dicts()
doesn't exist yet, and create_traditional_chunks returns list[str].
"""
# Create documents with various metadata
docs = [
MockDocument(
"This is the first paragraph of text. It contains multiple sentences. "
"This should be split into chunks based on size.",
file_path="/docs/readme.txt",
metadata={
"file_path": "/docs/readme.txt",
"file_name": "readme.txt",
"creation_date": "2024-01-01",
},
),
MockDocument(
"Second document with different metadata. It also has content that needs chunking.",
file_path="/docs/guide.md",
metadata={
"file_path": "/docs/guide.md",
"file_name": "guide.md",
"last_modified_date": "2024-10-31",
},
),
]
# Call create_traditional_chunks (which should now return list[dict])
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
# CRITICAL: Will FAIL - current code returns list[str]
assert len(chunks) > 0, "Should return chunks"
for i, chunk in enumerate(chunks):
# Structure assertions - WILL FAIL
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
assert "text" in chunk, f"Chunk {i} must have 'text' key"
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
# Text should be non-empty
assert len(chunk["text"].strip()) > 0, f"Chunk {i} text should be non-empty"
# Metadata should include document info
metadata = chunk["metadata"]
assert "file_path" in metadata, f"Chunk {i} should have file_path in metadata"
assert "file_name" in metadata, f"Chunk {i} should have file_name in metadata"
# Verify metadata tracking works correctly
# At least one chunk should be from readme.txt
readme_chunks = [c for c in chunks if "readme.txt" in c["metadata"]["file_name"]]
assert len(readme_chunks) > 0, "Should have chunks from readme.txt"
# At least one chunk should be from guide.md
guide_chunks = [c for c in chunks if "guide.md" in c["metadata"]["file_name"]]
assert len(guide_chunks) > 0, "Should have chunks from guide.md"
# Verify creation_date is preserved for readme chunks
for chunk in readme_chunks:
assert chunk["metadata"].get("creation_date") == "2024-01-01", (
"readme.txt chunks should preserve creation_date"
)
# Verify last_modified_date is preserved for guide chunks
for chunk in guide_chunks:
assert chunk["metadata"].get("last_modified_date") == "2024-10-31", (
"guide.md chunks should preserve last_modified_date"
)
# Verify text content is present
all_text = " ".join([c["text"] for c in chunks])
assert "first paragraph" in all_text
assert "Second document" in all_text
class TestErrorHandling: class TestErrorHandling:
"""Test error handling and edge cases.""" """Test error handling and edge cases."""

View File

@@ -1,268 +0,0 @@
"""Unit tests for token-aware truncation functionality.
This test suite defines the contract for token truncation functions that prevent
500 errors from Ollama when text exceeds model token limits. These tests verify:
1. Model token limit retrieval (known and unknown models)
2. Text truncation behavior for single and multiple texts
3. Token counting and truncation accuracy using tiktoken
All tests are written in Red Phase - they should FAIL initially because the
implementation does not exist yet.
"""
import pytest
import tiktoken
from leann.embedding_compute import (
EMBEDDING_MODEL_LIMITS,
get_model_token_limit,
truncate_to_token_limit,
)
class TestModelTokenLimits:
"""Tests for retrieving model-specific token limits."""
def test_get_model_token_limit_known_model(self):
"""Verify correct token limit is returned for known models.
Known models should return their specific token limits from
EMBEDDING_MODEL_LIMITS dictionary.
"""
# Test nomic-embed-text (2048 tokens)
limit = get_model_token_limit("nomic-embed-text")
assert limit == 2048, "nomic-embed-text should have 2048 token limit"
# Test nomic-embed-text-v1.5 (2048 tokens)
limit = get_model_token_limit("nomic-embed-text-v1.5")
assert limit == 2048, "nomic-embed-text-v1.5 should have 2048 token limit"
# Test nomic-embed-text-v2 (512 tokens)
limit = get_model_token_limit("nomic-embed-text-v2")
assert limit == 512, "nomic-embed-text-v2 should have 512 token limit"
# Test OpenAI models (8192 tokens)
limit = get_model_token_limit("text-embedding-3-small")
assert limit == 8192, "text-embedding-3-small should have 8192 token limit"
def test_get_model_token_limit_unknown_model(self):
"""Verify default token limit is returned for unknown models.
Unknown models should return the default limit (2048) to allow
operation with reasonable safety margin.
"""
# Test with completely unknown model
limit = get_model_token_limit("unknown-model-xyz")
assert limit == 2048, "Unknown models should return default 2048"
# Test with empty string
limit = get_model_token_limit("")
assert limit == 2048, "Empty model name should return default 2048"
def test_get_model_token_limit_custom_default(self):
"""Verify custom default can be specified for unknown models.
Allow callers to specify their own default token limit when
model is not in the known models dictionary.
"""
limit = get_model_token_limit("unknown-model", default=4096)
assert limit == 4096, "Should return custom default for unknown models"
# Known model should ignore custom default
limit = get_model_token_limit("nomic-embed-text", default=4096)
assert limit == 2048, "Known model should ignore custom default"
def test_embedding_model_limits_dictionary_exists(self):
"""Verify EMBEDDING_MODEL_LIMITS dictionary contains expected models.
The dictionary should be importable and contain at least the
known nomic models with correct token limits.
"""
assert isinstance(EMBEDDING_MODEL_LIMITS, dict), "Should be a dictionary"
assert "nomic-embed-text" in EMBEDDING_MODEL_LIMITS, "Should contain nomic-embed-text"
assert "nomic-embed-text-v1.5" in EMBEDDING_MODEL_LIMITS, (
"Should contain nomic-embed-text-v1.5"
)
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text"] == 2048
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v1.5"] == 2048
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v2"] == 512
# OpenAI models
assert EMBEDDING_MODEL_LIMITS["text-embedding-3-small"] == 8192
class TestTokenTruncation:
"""Tests for truncating texts to token limits."""
@pytest.fixture
def tokenizer(self):
"""Provide tiktoken tokenizer for token counting verification."""
return tiktoken.get_encoding("cl100k_base")
def test_truncate_single_text_under_limit(self, tokenizer):
"""Verify text under token limit remains unchanged.
When text is already within the token limit, it should be
returned unchanged with no truncation.
"""
text = "This is a short text that is well under the token limit."
token_count = len(tokenizer.encode(text))
assert token_count < 100, f"Test setup: text should be short (has {token_count} tokens)"
# Truncate with generous limit
result = truncate_to_token_limit([text], token_limit=512)
assert len(result) == 1, "Should return same number of texts"
assert result[0] == text, "Text under limit should be unchanged"
def test_truncate_single_text_over_limit(self, tokenizer):
"""Verify text over token limit is truncated correctly.
When text exceeds the token limit, it should be truncated to
fit within the limit while maintaining valid token boundaries.
"""
# Create a text that definitely exceeds limit
text = "word " * 200 # ~200 tokens (each "word " is typically 1-2 tokens)
original_token_count = len(tokenizer.encode(text))
assert original_token_count > 50, (
f"Test setup: text should be long (has {original_token_count} tokens)"
)
# Truncate to 50 tokens
result = truncate_to_token_limit([text], token_limit=50)
assert len(result) == 1, "Should return same number of texts"
assert result[0] != text, "Text over limit should be truncated"
assert len(result[0]) < len(text), "Truncated text should be shorter"
# Verify truncated text is within token limit
truncated_token_count = len(tokenizer.encode(result[0]))
assert truncated_token_count <= 50, (
f"Truncated text should be ≤50 tokens, got {truncated_token_count}"
)
def test_truncate_multiple_texts_mixed_lengths(self, tokenizer):
"""Verify multiple texts with mixed lengths are handled correctly.
When processing multiple texts:
- Texts under limit should remain unchanged
- Texts over limit should be truncated independently
- Output list should maintain same order and length
"""
texts = [
"Short text.", # Under limit
"word " * 200, # Over limit
"Another short one.", # Under limit
"token " * 150, # Over limit
]
# Verify test setup
for i, text in enumerate(texts):
token_count = len(tokenizer.encode(text))
if i in [1, 3]:
assert token_count > 50, f"Text {i} should be over limit (has {token_count} tokens)"
else:
assert token_count < 50, (
f"Text {i} should be under limit (has {token_count} tokens)"
)
# Truncate with 50 token limit
result = truncate_to_token_limit(texts, token_limit=50)
assert len(result) == len(texts), "Should return same number of texts"
# Verify each text individually
for i, (original, truncated) in enumerate(zip(texts, result)):
token_count = len(tokenizer.encode(truncated))
assert token_count <= 50, f"Text {i} should be ≤50 tokens, got {token_count}"
# Short texts should be unchanged
if i in [0, 2]:
assert truncated == original, f"Short text {i} should be unchanged"
# Long texts should be truncated
else:
assert len(truncated) < len(original), f"Long text {i} should be truncated"
def test_truncate_empty_list(self):
"""Verify empty input list returns empty output list.
Edge case: empty list should return empty list without errors.
"""
result = truncate_to_token_limit([], token_limit=512)
assert result == [], "Empty input should return empty output"
def test_truncate_preserves_order(self, tokenizer):
"""Verify truncation preserves original text order.
Output list should maintain the same order as input list,
regardless of which texts were truncated.
"""
texts = [
"First text " * 50, # Will be truncated
"Second text.", # Won't be truncated
"Third text " * 50, # Will be truncated
]
result = truncate_to_token_limit(texts, token_limit=20)
assert len(result) == 3, "Should preserve list length"
# Check that order is maintained by looking for distinctive words
assert "First" in result[0], "First text should remain in first position"
assert "Second" in result[1], "Second text should remain in second position"
assert "Third" in result[2], "Third text should remain in third position"
def test_truncate_extremely_long_text(self, tokenizer):
"""Verify extremely long texts are truncated efficiently.
Test with text that far exceeds token limit to ensure
truncation handles extreme cases without performance issues.
"""
# Create very long text (simulate real-world scenario)
text = "token " * 5000 # ~5000+ tokens
original_token_count = len(tokenizer.encode(text))
assert original_token_count > 1000, "Test setup: text should be very long"
# Truncate to small limit
result = truncate_to_token_limit([text], token_limit=100)
assert len(result) == 1
truncated_token_count = len(tokenizer.encode(result[0]))
assert truncated_token_count <= 100, (
f"Should truncate to ≤100 tokens, got {truncated_token_count}"
)
assert len(result[0]) < len(text) // 10, "Should significantly reduce text length"
def test_truncate_exact_token_limit(self, tokenizer):
"""Verify text at exactly token limit is handled correctly.
Edge case: text with exactly the token limit should either
remain unchanged or be safely truncated by 1 token.
"""
# Create text with approximately 50 tokens
# We'll adjust to get exactly 50
target_tokens = 50
text = "word " * 50
tokens = tokenizer.encode(text)
# Adjust to get exactly target_tokens
if len(tokens) > target_tokens:
tokens = tokens[:target_tokens]
text = tokenizer.decode(tokens)
elif len(tokens) < target_tokens:
# Add more words
while len(tokenizer.encode(text)) < target_tokens:
text += "word "
tokens = tokenizer.encode(text)[:target_tokens]
text = tokenizer.decode(tokens)
# Verify we have exactly target_tokens
assert len(tokenizer.encode(text)) == target_tokens, (
"Test setup: should have exactly 50 tokens"
)
result = truncate_to_token_limit([text], token_limit=target_tokens)
assert len(result) == 1
result_tokens = len(tokenizer.encode(result[0]))
assert result_tokens <= target_tokens, (
f"Should be ≤{target_tokens} tokens, got {result_tokens}"
)

7553
uv.lock generated
View File

File diff suppressed because it is too large Load Diff