Compare commits
12 Commits
feature/op
...
fix/pdf-du
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2afcdf7b77 | ||
|
|
13beb98164 | ||
|
|
9b7353f336 | ||
|
|
9dd0e0b26f | ||
|
|
dc6c9f696e | ||
|
|
2406c41eef | ||
|
|
d4f5f2896f | ||
|
|
366984e92e | ||
|
|
64b92a04a7 | ||
|
|
a85d0ad4a7 | ||
|
|
dbb5f4d352 | ||
|
|
f180b83589 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -105,3 +105,6 @@ apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weavia
|
|||||||
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
|
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
|
||||||
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
|
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
|
||||||
!apps/multimodal/vision-based-pdf-multi-vector/fig/*
|
!apps/multimodal/vision-based-pdf-multi-vector/fig/*
|
||||||
|
|
||||||
|
# AUR build directory (Arch Linux)
|
||||||
|
paru-bin/
|
||||||
|
|||||||
200
COLQWEN_GUIDE.md
Normal file
200
COLQWEN_GUIDE.md
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
# ColQwen Integration Guide
|
||||||
|
|
||||||
|
Easy-to-use multimodal PDF retrieval with ColQwen2/ColPali models.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
> **🍎 Mac Users**: ColQwen is optimized for Apple Silicon with MPS acceleration for faster inference!
|
||||||
|
|
||||||
|
### 1. Install Dependencies
|
||||||
|
```bash
|
||||||
|
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
|
||||||
|
brew install poppler # macOS only, for PDF processing
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Basic Usage
|
||||||
|
```bash
|
||||||
|
# Build index from PDFs
|
||||||
|
python -m apps.colqwen_rag build --pdfs ./my_papers/ --index research_papers
|
||||||
|
|
||||||
|
# Search with text queries
|
||||||
|
python -m apps.colqwen_rag search research_papers "How does attention mechanism work?"
|
||||||
|
|
||||||
|
# Interactive Q&A
|
||||||
|
python -m apps.colqwen_rag ask research_papers --interactive
|
||||||
|
```
|
||||||
|
|
||||||
|
## Commands
|
||||||
|
|
||||||
|
### Build Index
|
||||||
|
```bash
|
||||||
|
python -m apps.colqwen_rag build \
|
||||||
|
--pdfs ./pdf_directory/ \
|
||||||
|
--index my_index \
|
||||||
|
--model colqwen2 \
|
||||||
|
--pages-dir ./page_images/ # Optional: save page images
|
||||||
|
```
|
||||||
|
|
||||||
|
**Options:**
|
||||||
|
- `--pdfs`: Directory containing PDF files (or single PDF path)
|
||||||
|
- `--index`: Name for the index (required)
|
||||||
|
- `--model`: `colqwen2` (default) or `colpali`
|
||||||
|
- `--pages-dir`: Directory to save page images (optional)
|
||||||
|
|
||||||
|
### Search Index
|
||||||
|
```bash
|
||||||
|
python -m apps.colqwen_rag search my_index "your question here" --top-k 5
|
||||||
|
```
|
||||||
|
|
||||||
|
**Options:**
|
||||||
|
- `--top-k`: Number of results to return (default: 5)
|
||||||
|
- `--model`: Model used for search (should match build model)
|
||||||
|
|
||||||
|
### Interactive Q&A
|
||||||
|
```bash
|
||||||
|
python -m apps.colqwen_rag ask my_index --interactive
|
||||||
|
```
|
||||||
|
|
||||||
|
**Commands in interactive mode:**
|
||||||
|
- Type your questions naturally
|
||||||
|
- `help`: Show available commands
|
||||||
|
- `quit`/`exit`/`q`: Exit interactive mode
|
||||||
|
|
||||||
|
## 🧪 Test & Reproduce Results
|
||||||
|
|
||||||
|
Run the reproduction test for issue #119:
|
||||||
|
```bash
|
||||||
|
python test_colqwen_reproduction.py
|
||||||
|
```
|
||||||
|
|
||||||
|
This will:
|
||||||
|
1. ✅ Check dependencies
|
||||||
|
2. 📥 Download sample PDF (Attention Is All You Need paper)
|
||||||
|
3. 🏗️ Build test index
|
||||||
|
4. 🔍 Run sample queries
|
||||||
|
5. 📊 Show how to generate similarity maps
|
||||||
|
|
||||||
|
## 🎨 Advanced: Similarity Maps
|
||||||
|
|
||||||
|
For visual similarity analysis, use the existing advanced script:
|
||||||
|
```bash
|
||||||
|
cd apps/multimodal/vision-based-pdf-multi-vector/
|
||||||
|
python multi-vector-leann-similarity-map.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Edit the script to customize:
|
||||||
|
- `QUERY`: Your question
|
||||||
|
- `MODEL`: "colqwen2" or "colpali"
|
||||||
|
- `USE_HF_DATASET`: Use HuggingFace dataset or local PDFs
|
||||||
|
- `SIMILARITY_MAP`: Generate heatmaps
|
||||||
|
- `ANSWER`: Enable Qwen-VL answer generation
|
||||||
|
|
||||||
|
## 🔧 How It Works
|
||||||
|
|
||||||
|
### ColQwen2 vs ColPali
|
||||||
|
- **ColQwen2** (`vidore/colqwen2-v1.0`): Latest vision-language model
|
||||||
|
- **ColPali** (`vidore/colpali-v1.2`): Proven multimodal retriever
|
||||||
|
|
||||||
|
### Architecture
|
||||||
|
1. **PDF → Images**: Convert PDF pages to images (150 DPI)
|
||||||
|
2. **Vision Encoding**: Process images with ColQwen2/ColPali
|
||||||
|
3. **Multi-Vector Index**: Build LEANN HNSW index with multiple embeddings per page
|
||||||
|
4. **Query Processing**: Encode text queries with same model
|
||||||
|
5. **Similarity Search**: Find most relevant pages/regions
|
||||||
|
6. **Visual Maps**: Generate attention heatmaps (optional)
|
||||||
|
|
||||||
|
### Device Support
|
||||||
|
- **CUDA**: Best performance with GPU acceleration
|
||||||
|
- **MPS**: Apple Silicon Mac support
|
||||||
|
- **CPU**: Fallback for any system (slower)
|
||||||
|
|
||||||
|
Auto-detection: CUDA > MPS > CPU
|
||||||
|
|
||||||
|
## 📊 Performance Tips
|
||||||
|
|
||||||
|
### For Best Performance:
|
||||||
|
```bash
|
||||||
|
# Use ColQwen2 for latest features
|
||||||
|
--model colqwen2
|
||||||
|
|
||||||
|
# Save page images for reuse
|
||||||
|
--pages-dir ./cached_pages/
|
||||||
|
|
||||||
|
# Adjust batch size based on GPU memory
|
||||||
|
# (automatically handled)
|
||||||
|
```
|
||||||
|
|
||||||
|
### For Large Document Sets:
|
||||||
|
- Process PDFs in batches
|
||||||
|
- Use SSD storage for index files
|
||||||
|
- Consider using CUDA if available
|
||||||
|
|
||||||
|
## 🔗 Related Resources
|
||||||
|
|
||||||
|
- **Fast-PLAID**: https://github.com/lightonai/fast-plaid
|
||||||
|
- **Pylate**: https://github.com/lightonai/pylate
|
||||||
|
- **ColBERT**: https://github.com/stanford-futuredata/ColBERT
|
||||||
|
- **ColPali Paper**: Vision-Language Models for Document Retrieval
|
||||||
|
- **Issue #119**: https://github.com/yichuan-w/LEANN/issues/119
|
||||||
|
|
||||||
|
## 🐛 Troubleshooting
|
||||||
|
|
||||||
|
### PDF Conversion Issues (macOS)
|
||||||
|
```bash
|
||||||
|
# Install poppler
|
||||||
|
brew install poppler
|
||||||
|
which pdfinfo && pdfinfo -v
|
||||||
|
```
|
||||||
|
|
||||||
|
### Memory Issues
|
||||||
|
- Reduce batch size (automatically handled)
|
||||||
|
- Use CPU instead of GPU: `export CUDA_VISIBLE_DEVICES=""`
|
||||||
|
- Process fewer PDFs at once
|
||||||
|
|
||||||
|
### Model Download Issues
|
||||||
|
- Ensure internet connection for first run
|
||||||
|
- Models are cached after first download
|
||||||
|
- Use HuggingFace mirrors if needed
|
||||||
|
|
||||||
|
### Import Errors
|
||||||
|
```bash
|
||||||
|
# Ensure all dependencies installed
|
||||||
|
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
|
||||||
|
|
||||||
|
# Check PyTorch installation
|
||||||
|
python -c "import torch; print(torch.__version__)"
|
||||||
|
```
|
||||||
|
|
||||||
|
## 💡 Examples
|
||||||
|
|
||||||
|
### Research Paper Analysis
|
||||||
|
```bash
|
||||||
|
# Index your research papers
|
||||||
|
python -m apps.colqwen_rag build --pdfs ~/Papers/AI/ --index ai_papers
|
||||||
|
|
||||||
|
# Ask research questions
|
||||||
|
python -m apps.colqwen_rag search ai_papers "What are the limitations of transformer models?"
|
||||||
|
python -m apps.colqwen_rag search ai_papers "How does BERT compare to GPT?"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Document Q&A
|
||||||
|
```bash
|
||||||
|
# Index business documents
|
||||||
|
python -m apps.colqwen_rag build --pdfs ~/Documents/Reports/ --index reports
|
||||||
|
|
||||||
|
# Interactive analysis
|
||||||
|
python -m apps.colqwen_rag ask reports --interactive
|
||||||
|
```
|
||||||
|
|
||||||
|
### Visual Analysis
|
||||||
|
```bash
|
||||||
|
# Generate similarity maps for specific queries
|
||||||
|
cd apps/multimodal/vision-based-pdf-multi-vector/
|
||||||
|
# Edit multi-vector-leann-similarity-map.py with your query
|
||||||
|
python multi-vector-leann-similarity-map.py
|
||||||
|
# Check ./figures/ for generated heatmaps
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**🎯 This integration makes ColQwen as easy to use as other LEANN features while maintaining the full power of multimodal document understanding!**
|
||||||
@@ -24,7 +24,7 @@ LEANN is an innovative vector database that democratizes personal AI. Transform
|
|||||||
|
|
||||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||||
|
|
||||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#mcp-integration-rag-on-live-data-from-any-platform), [Twitter](#mcp-integration-rag-on-live-data-from-any-platform)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#slack-messages-search-your-team-conversations), [Twitter](#-twitter-bookmarks-your-personal-tweet-library)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||||
|
|
||||||
|
|
||||||
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
|
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
|
||||||
@@ -1213,3 +1213,7 @@ This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.ed
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
Made with ❤️ by the Leann team
|
Made with ❤️ by the Leann team
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
## 🤖 Explore LEANN with AI
|
||||||
|
|
||||||
|
LEANN is indexed on [DeepWiki](https://deepwiki.com/yichuan-w/LEANN), so you can ask questions to LLMs using Deep Research to explore the codebase and get help to add new features.
|
||||||
|
|||||||
@@ -180,14 +180,14 @@ class BaseRAGExample(ABC):
|
|||||||
ast_group.add_argument(
|
ast_group.add_argument(
|
||||||
"--ast-chunk-size",
|
"--ast-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=512,
|
default=300,
|
||||||
help="Maximum characters per AST chunk (default: 512)",
|
help="Maximum CHARACTERS per AST chunk (default: 300). Final chunks may be larger due to overlap. For 512 token models: recommended 300 chars",
|
||||||
)
|
)
|
||||||
ast_group.add_argument(
|
ast_group.add_argument(
|
||||||
"--ast-chunk-overlap",
|
"--ast-chunk-overlap",
|
||||||
type=int,
|
type=int,
|
||||||
default=64,
|
default=64,
|
||||||
help="Overlap between AST chunks (default: 64)",
|
help="Overlap between AST chunks in CHARACTERS (default: 64). Added to chunk size, not included in it",
|
||||||
)
|
)
|
||||||
ast_group.add_argument(
|
ast_group.add_argument(
|
||||||
"--code-file-extensions",
|
"--code-file-extensions",
|
||||||
|
|||||||
364
apps/colqwen_rag.py
Normal file
364
apps/colqwen_rag.py
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
ColQwen RAG - Easy-to-use multimodal PDF retrieval with ColQwen2/ColPali
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -m apps.colqwen_rag build --pdfs ./my_pdfs/ --index my_index
|
||||||
|
python -m apps.colqwen_rag search my_index "How does attention work?"
|
||||||
|
python -m apps.colqwen_rag ask my_index --interactive
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, cast
|
||||||
|
|
||||||
|
# Add LEANN packages to path
|
||||||
|
_repo_root = Path(__file__).resolve().parents[1]
|
||||||
|
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||||
|
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||||
|
if str(_leann_core_src) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_core_src))
|
||||||
|
if str(_leann_hnsw_pkg) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_hnsw_pkg))
|
||||||
|
|
||||||
|
import torch # noqa: E402
|
||||||
|
from colpali_engine import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor # noqa: E402
|
||||||
|
from colpali_engine.utils.torch_utils import ListDataset # noqa: E402
|
||||||
|
from pdf2image import convert_from_path # noqa: E402
|
||||||
|
from PIL import Image # noqa: E402
|
||||||
|
from torch.utils.data import DataLoader # noqa: E402
|
||||||
|
from tqdm import tqdm # noqa: E402
|
||||||
|
|
||||||
|
# Import the existing multi-vector implementation
|
||||||
|
sys.path.append(str(_repo_root / "apps" / "multimodal" / "vision-based-pdf-multi-vector"))
|
||||||
|
from leann_multi_vector import LeannMultiVector # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
class ColQwenRAG:
|
||||||
|
"""Easy-to-use ColQwen RAG system for multimodal PDF retrieval."""
|
||||||
|
|
||||||
|
def __init__(self, model_type: str = "colpali"):
|
||||||
|
"""
|
||||||
|
Initialize ColQwen RAG system.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: "colqwen2" or "colpali"
|
||||||
|
"""
|
||||||
|
self.model_type = model_type
|
||||||
|
self.device = self._get_device()
|
||||||
|
# Use float32 on MPS to avoid memory issues, float16 on CUDA, bfloat16 on CPU
|
||||||
|
if self.device.type == "mps":
|
||||||
|
self.dtype = torch.float32
|
||||||
|
elif self.device.type == "cuda":
|
||||||
|
self.dtype = torch.float16
|
||||||
|
else:
|
||||||
|
self.dtype = torch.bfloat16
|
||||||
|
|
||||||
|
print(f"🚀 Initializing {model_type.upper()} on {self.device} with {self.dtype}")
|
||||||
|
|
||||||
|
# Load model and processor with MPS-optimized settings
|
||||||
|
try:
|
||||||
|
if model_type == "colqwen2":
|
||||||
|
self.model_name = "vidore/colqwen2-v1.0"
|
||||||
|
if self.device.type == "mps":
|
||||||
|
# For MPS, load on CPU first then move to avoid memory allocation issues
|
||||||
|
self.model = ColQwen2.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
torch_dtype=self.dtype,
|
||||||
|
device_map="cpu",
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
).eval()
|
||||||
|
self.model = self.model.to(self.device)
|
||||||
|
else:
|
||||||
|
self.model = ColQwen2.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
torch_dtype=self.dtype,
|
||||||
|
device_map=self.device,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
).eval()
|
||||||
|
self.processor = ColQwen2Processor.from_pretrained(self.model_name)
|
||||||
|
else: # colpali
|
||||||
|
self.model_name = "vidore/colpali-v1.2"
|
||||||
|
if self.device.type == "mps":
|
||||||
|
# For MPS, load on CPU first then move to avoid memory allocation issues
|
||||||
|
self.model = ColPali.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
torch_dtype=self.dtype,
|
||||||
|
device_map="cpu",
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
).eval()
|
||||||
|
self.model = self.model.to(self.device)
|
||||||
|
else:
|
||||||
|
self.model = ColPali.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
torch_dtype=self.dtype,
|
||||||
|
device_map=self.device,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
).eval()
|
||||||
|
self.processor = ColPaliProcessor.from_pretrained(self.model_name)
|
||||||
|
except Exception as e:
|
||||||
|
if "memory" in str(e).lower() or "offload" in str(e).lower():
|
||||||
|
print(f"⚠️ Memory constraint on {self.device}, using CPU with optimizations...")
|
||||||
|
self.device = torch.device("cpu")
|
||||||
|
self.dtype = torch.float32
|
||||||
|
|
||||||
|
if model_type == "colqwen2":
|
||||||
|
self.model = ColQwen2.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
torch_dtype=self.dtype,
|
||||||
|
device_map="cpu",
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
).eval()
|
||||||
|
else:
|
||||||
|
self.model = ColPali.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
torch_dtype=self.dtype,
|
||||||
|
device_map="cpu",
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
).eval()
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _get_device(self):
|
||||||
|
"""Auto-select best available device."""
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return torch.device("cuda")
|
||||||
|
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
|
return torch.device("mps")
|
||||||
|
else:
|
||||||
|
return torch.device("cpu")
|
||||||
|
|
||||||
|
def build_index(self, pdf_paths: list[str], index_name: str, pages_dir: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Build multimodal index from PDF files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pdf_paths: List of PDF file paths
|
||||||
|
index_name: Name for the index
|
||||||
|
pages_dir: Directory to save page images (optional)
|
||||||
|
"""
|
||||||
|
print(f"Building index '{index_name}' from {len(pdf_paths)} PDFs...")
|
||||||
|
|
||||||
|
# Convert PDFs to images
|
||||||
|
all_images = []
|
||||||
|
all_metadata = []
|
||||||
|
|
||||||
|
if pages_dir:
|
||||||
|
os.makedirs(pages_dir, exist_ok=True)
|
||||||
|
|
||||||
|
for pdf_path in tqdm(pdf_paths, desc="Converting PDFs"):
|
||||||
|
try:
|
||||||
|
images = convert_from_path(pdf_path, dpi=150)
|
||||||
|
pdf_name = Path(pdf_path).stem
|
||||||
|
|
||||||
|
for i, image in enumerate(images):
|
||||||
|
# Save image if pages_dir specified
|
||||||
|
if pages_dir:
|
||||||
|
image_path = Path(pages_dir) / f"{pdf_name}_page_{i + 1}.png"
|
||||||
|
image.save(image_path)
|
||||||
|
|
||||||
|
all_images.append(image)
|
||||||
|
all_metadata.append(
|
||||||
|
{
|
||||||
|
"pdf_path": pdf_path,
|
||||||
|
"pdf_name": pdf_name,
|
||||||
|
"page_number": i + 1,
|
||||||
|
"image_path": str(image_path) if pages_dir else None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error processing {pdf_path}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"📄 Converted {len(all_images)} pages from {len(pdf_paths)} PDFs")
|
||||||
|
print(f"All metadata: {all_metadata}")
|
||||||
|
|
||||||
|
# Generate embeddings
|
||||||
|
print("🧠 Generating embeddings...")
|
||||||
|
embeddings = self._embed_images(all_images)
|
||||||
|
|
||||||
|
# Build LEANN index
|
||||||
|
print("🔍 Building LEANN index...")
|
||||||
|
leann_mv = LeannMultiVector(
|
||||||
|
index_path=index_name,
|
||||||
|
dim=embeddings.shape[-1],
|
||||||
|
embedding_model_name=self.model_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create collection and insert data
|
||||||
|
leann_mv.create_collection()
|
||||||
|
for i, (embedding, metadata) in enumerate(zip(embeddings, all_metadata)):
|
||||||
|
data = {
|
||||||
|
"doc_id": i,
|
||||||
|
"filepath": metadata.get("image_path", ""),
|
||||||
|
"colbert_vecs": embedding.numpy(), # Convert tensor to numpy
|
||||||
|
}
|
||||||
|
leann_mv.insert(data)
|
||||||
|
|
||||||
|
# Build the index
|
||||||
|
leann_mv.create_index()
|
||||||
|
print(f"✅ Index '{index_name}' built successfully!")
|
||||||
|
|
||||||
|
return leann_mv
|
||||||
|
|
||||||
|
def search(self, index_name: str, query: str, top_k: int = 5):
|
||||||
|
"""
|
||||||
|
Search the index with a text query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_name: Name of the index to search
|
||||||
|
query: Text query
|
||||||
|
top_k: Number of results to return
|
||||||
|
"""
|
||||||
|
print(f"🔍 Searching '{index_name}' for: '{query}'")
|
||||||
|
|
||||||
|
# Load index
|
||||||
|
leann_mv = LeannMultiVector(
|
||||||
|
index_path=index_name,
|
||||||
|
dim=128, # Will be updated when loading
|
||||||
|
embedding_model_name=self.model_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate query embedding
|
||||||
|
query_embedding = self._embed_query(query)
|
||||||
|
|
||||||
|
# Search (returns list of (score, doc_id) tuples)
|
||||||
|
search_results = leann_mv.search(query_embedding.numpy(), topk=top_k)
|
||||||
|
|
||||||
|
# Display results
|
||||||
|
print(f"\n📋 Top {len(search_results)} results:")
|
||||||
|
for i, (score, doc_id) in enumerate(search_results, 1):
|
||||||
|
# Get metadata for this doc_id (we need to load the metadata)
|
||||||
|
print(f"{i}. Score: {score:.3f} | Doc ID: {doc_id}")
|
||||||
|
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
def ask(self, index_name: str, interactive: bool = False):
|
||||||
|
"""
|
||||||
|
Interactive Q&A with the indexed documents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_name: Name of the index to query
|
||||||
|
interactive: Whether to run in interactive mode
|
||||||
|
"""
|
||||||
|
print(f"💬 ColQwen Chat with '{index_name}'")
|
||||||
|
|
||||||
|
if interactive:
|
||||||
|
print("Type 'quit' to exit, 'help' for commands")
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
query = input("\n🤔 Your question: ").strip()
|
||||||
|
if query.lower() in ["quit", "exit", "q"]:
|
||||||
|
break
|
||||||
|
elif query.lower() == "help":
|
||||||
|
print("Commands: quit/exit/q (exit), help (this message)")
|
||||||
|
continue
|
||||||
|
elif not query:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.search(index_name, query, top_k=3)
|
||||||
|
|
||||||
|
# TODO: Add answer generation with Qwen-VL
|
||||||
|
print("\n💡 For detailed answers, we can integrate Qwen-VL here!")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n👋 Goodbye!")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
query = input("🤔 Your question: ").strip()
|
||||||
|
if query:
|
||||||
|
self.search(index_name, query)
|
||||||
|
|
||||||
|
def _embed_images(self, images: list[Image.Image]) -> torch.Tensor:
|
||||||
|
"""Generate embeddings for a list of images."""
|
||||||
|
dataset = ListDataset(images)
|
||||||
|
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x)
|
||||||
|
|
||||||
|
embeddings = []
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in tqdm(dataloader, desc="Embedding images"):
|
||||||
|
batch_images = cast(list, batch)
|
||||||
|
batch_inputs = self.processor.process_images(batch_images).to(self.device)
|
||||||
|
batch_embeddings = self.model(**batch_inputs)
|
||||||
|
embeddings.append(batch_embeddings.cpu())
|
||||||
|
|
||||||
|
return torch.cat(embeddings, dim=0)
|
||||||
|
|
||||||
|
def _embed_query(self, query: str) -> torch.Tensor:
|
||||||
|
"""Generate embedding for a text query."""
|
||||||
|
with torch.no_grad():
|
||||||
|
query_inputs = self.processor.process_queries([query]).to(self.device)
|
||||||
|
query_embedding = self.model(**query_inputs)
|
||||||
|
return query_embedding.cpu()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="ColQwen RAG - Easy multimodal PDF retrieval")
|
||||||
|
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
||||||
|
|
||||||
|
# Build command
|
||||||
|
build_parser = subparsers.add_parser("build", help="Build index from PDFs")
|
||||||
|
build_parser.add_argument("--pdfs", required=True, help="Directory containing PDF files")
|
||||||
|
build_parser.add_argument("--index", required=True, help="Index name")
|
||||||
|
build_parser.add_argument(
|
||||||
|
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||||
|
)
|
||||||
|
build_parser.add_argument("--pages-dir", help="Directory to save page images")
|
||||||
|
|
||||||
|
# Search command
|
||||||
|
search_parser = subparsers.add_parser("search", help="Search the index")
|
||||||
|
search_parser.add_argument("index", help="Index name")
|
||||||
|
search_parser.add_argument("query", help="Search query")
|
||||||
|
search_parser.add_argument("--top-k", type=int, default=5, help="Number of results")
|
||||||
|
search_parser.add_argument(
|
||||||
|
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ask command
|
||||||
|
ask_parser = subparsers.add_parser("ask", help="Interactive Q&A")
|
||||||
|
ask_parser.add_argument("index", help="Index name")
|
||||||
|
ask_parser.add_argument("--interactive", action="store_true", help="Interactive mode")
|
||||||
|
ask_parser.add_argument(
|
||||||
|
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not args.command:
|
||||||
|
parser.print_help()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Initialize ColQwen RAG
|
||||||
|
if args.command == "build":
|
||||||
|
colqwen = ColQwenRAG(args.model)
|
||||||
|
|
||||||
|
# Get PDF files
|
||||||
|
pdf_dir = Path(args.pdfs)
|
||||||
|
if pdf_dir.is_file() and pdf_dir.suffix.lower() == ".pdf":
|
||||||
|
pdf_paths = [str(pdf_dir)]
|
||||||
|
elif pdf_dir.is_dir():
|
||||||
|
pdf_paths = [str(p) for p in pdf_dir.glob("*.pdf")]
|
||||||
|
else:
|
||||||
|
print(f"❌ Invalid PDF path: {args.pdfs}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not pdf_paths:
|
||||||
|
print(f"❌ No PDF files found in {args.pdfs}")
|
||||||
|
return
|
||||||
|
|
||||||
|
colqwen.build_index(pdf_paths, args.index, args.pages_dir)
|
||||||
|
|
||||||
|
elif args.command == "search":
|
||||||
|
colqwen = ColQwenRAG(args.model)
|
||||||
|
colqwen.search(args.index, args.query, args.top_k)
|
||||||
|
|
||||||
|
elif args.command == "ask":
|
||||||
|
colqwen = ColQwenRAG(args.model)
|
||||||
|
colqwen.ask(args.index, args.interactive)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
218
apps/image_rag.py
Normal file
218
apps/image_rag.py
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
CLIP Image RAG Application
|
||||||
|
|
||||||
|
This application enables RAG (Retrieval-Augmented Generation) on images using CLIP embeddings.
|
||||||
|
You can index a directory of images and search them using text queries.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -m apps.image_rag --image-dir ./my_images/ --query "a sunset over mountains"
|
||||||
|
python -m apps.image_rag --image-dir ./my_images/ --interactive
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import pickle
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from apps.base_rag_example import BaseRAGExample
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRAG(BaseRAGExample):
|
||||||
|
"""
|
||||||
|
RAG application for images using CLIP embeddings.
|
||||||
|
|
||||||
|
This class provides a complete RAG pipeline for image data, including
|
||||||
|
CLIP embedding generation, indexing, and text-based image search.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="Image RAG",
|
||||||
|
description="RAG application for images using CLIP embeddings",
|
||||||
|
default_index_name="image_index",
|
||||||
|
)
|
||||||
|
# Override default embedding model to use CLIP
|
||||||
|
self.embedding_model_default = "clip-ViT-L-14"
|
||||||
|
self.embedding_mode_default = "sentence-transformers"
|
||||||
|
self._image_data: list[dict] = []
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||||
|
"""Add image-specific arguments."""
|
||||||
|
image_group = parser.add_argument_group("Image Parameters")
|
||||||
|
image_group.add_argument(
|
||||||
|
"--image-dir",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Directory containing images to index",
|
||||||
|
)
|
||||||
|
image_group.add_argument(
|
||||||
|
"--image-extensions",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
default=[".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"],
|
||||||
|
help="Image file extensions to process (default: .jpg .jpeg .png .gif .bmp .webp)",
|
||||||
|
)
|
||||||
|
image_group.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help="Batch size for CLIP embedding generation (default: 32)",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load images, generate CLIP embeddings, and return text descriptions."""
|
||||||
|
self._image_data = self._load_images_and_embeddings(args)
|
||||||
|
return [entry["text"] for entry in self._image_data]
|
||||||
|
|
||||||
|
def _load_images_and_embeddings(self, args) -> list[dict]:
|
||||||
|
"""Helper to process images and produce embeddings/metadata."""
|
||||||
|
image_dir = Path(args.image_dir)
|
||||||
|
if not image_dir.exists():
|
||||||
|
raise ValueError(f"Image directory does not exist: {image_dir}")
|
||||||
|
|
||||||
|
print(f"📸 Loading images from {image_dir}...")
|
||||||
|
|
||||||
|
# Find all image files
|
||||||
|
image_files = []
|
||||||
|
for ext in args.image_extensions:
|
||||||
|
image_files.extend(image_dir.rglob(f"*{ext}"))
|
||||||
|
image_files.extend(image_dir.rglob(f"*{ext.upper()}"))
|
||||||
|
|
||||||
|
if not image_files:
|
||||||
|
raise ValueError(
|
||||||
|
f"No images found in {image_dir} with extensions {args.image_extensions}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"✅ Found {len(image_files)} images")
|
||||||
|
|
||||||
|
# Limit if max_items is set
|
||||||
|
if args.max_items > 0:
|
||||||
|
image_files = image_files[: args.max_items]
|
||||||
|
print(f"📊 Processing {len(image_files)} images (limited by --max-items)")
|
||||||
|
|
||||||
|
# Load CLIP model
|
||||||
|
print("🔍 Loading CLIP model...")
|
||||||
|
model = SentenceTransformer(self.embedding_model_default)
|
||||||
|
|
||||||
|
# Process images and generate embeddings
|
||||||
|
print("🖼️ Processing images and generating embeddings...")
|
||||||
|
image_data = []
|
||||||
|
batch_images = []
|
||||||
|
batch_paths = []
|
||||||
|
|
||||||
|
for image_path in tqdm(image_files, desc="Processing images"):
|
||||||
|
try:
|
||||||
|
image = Image.open(image_path).convert("RGB")
|
||||||
|
batch_images.append(image)
|
||||||
|
batch_paths.append(image_path)
|
||||||
|
|
||||||
|
# Process in batches
|
||||||
|
if len(batch_images) >= args.batch_size:
|
||||||
|
embeddings = model.encode(
|
||||||
|
batch_images,
|
||||||
|
convert_to_numpy=True,
|
||||||
|
normalize_embeddings=True,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
show_progress_bar=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
for img_path, embedding in zip(batch_paths, embeddings):
|
||||||
|
image_data.append(
|
||||||
|
{
|
||||||
|
"text": f"Image: {img_path.name}\nPath: {img_path}",
|
||||||
|
"metadata": {
|
||||||
|
"image_path": str(img_path),
|
||||||
|
"image_name": img_path.name,
|
||||||
|
"image_dir": str(image_dir),
|
||||||
|
},
|
||||||
|
"embedding": embedding.astype(np.float32),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_images = []
|
||||||
|
batch_paths = []
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Failed to process {image_path}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Process remaining images
|
||||||
|
if batch_images:
|
||||||
|
embeddings = model.encode(
|
||||||
|
batch_images,
|
||||||
|
convert_to_numpy=True,
|
||||||
|
normalize_embeddings=True,
|
||||||
|
batch_size=len(batch_images),
|
||||||
|
show_progress_bar=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
for img_path, embedding in zip(batch_paths, embeddings):
|
||||||
|
image_data.append(
|
||||||
|
{
|
||||||
|
"text": f"Image: {img_path.name}\nPath: {img_path}",
|
||||||
|
"metadata": {
|
||||||
|
"image_path": str(img_path),
|
||||||
|
"image_name": img_path.name,
|
||||||
|
"image_dir": str(image_dir),
|
||||||
|
},
|
||||||
|
"embedding": embedding.astype(np.float32),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"✅ Processed {len(image_data)} images")
|
||||||
|
return image_data
|
||||||
|
|
||||||
|
async def build_index(self, args, texts: list[str]) -> str:
|
||||||
|
"""Build index using pre-computed CLIP embeddings."""
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
if not self._image_data or len(self._image_data) != len(texts):
|
||||||
|
raise RuntimeError("No image data found. Make sure load_data() ran successfully.")
|
||||||
|
|
||||||
|
print("🔨 Building LEANN index with CLIP embeddings...")
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=args.backend_name,
|
||||||
|
embedding_model=self.embedding_model_default,
|
||||||
|
embedding_mode=self.embedding_mode_default,
|
||||||
|
is_recompute=False,
|
||||||
|
distance_metric="cosine",
|
||||||
|
graph_degree=args.graph_degree,
|
||||||
|
build_complexity=args.build_complexity,
|
||||||
|
is_compact=not args.no_compact,
|
||||||
|
)
|
||||||
|
|
||||||
|
for text, data in zip(texts, self._image_data):
|
||||||
|
builder.add_text(text=text, metadata=data["metadata"])
|
||||||
|
|
||||||
|
ids = [str(i) for i in range(len(self._image_data))]
|
||||||
|
embeddings = np.array([data["embedding"] for data in self._image_data], dtype=np.float32)
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(mode="wb", suffix=".pkl", delete=False) as f:
|
||||||
|
pickle.dump((ids, embeddings), f)
|
||||||
|
pkl_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||||
|
builder.build_index_from_embeddings(index_path, pkl_path)
|
||||||
|
print(f"✅ Index built successfully at {index_path}")
|
||||||
|
return index_path
|
||||||
|
finally:
|
||||||
|
Path(pkl_path).unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main entry point for the image RAG application."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
app = ImageRAG()
|
||||||
|
asyncio.run(app.run())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
import concurrent.futures
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -45,6 +46,7 @@ class LeannMultiVector:
|
|||||||
"is_recompute": is_recompute,
|
"is_recompute": is_recompute,
|
||||||
}
|
}
|
||||||
self._labels_meta: list[dict] = []
|
self._labels_meta: list[dict] = []
|
||||||
|
self._docid_to_indices: dict[int, list[int]] | None = None
|
||||||
|
|
||||||
def _meta_dict(self) -> dict:
|
def _meta_dict(self) -> dict:
|
||||||
return {
|
return {
|
||||||
@@ -80,6 +82,10 @@ class LeannMultiVector:
|
|||||||
index_path_obj = Path(self.index_path)
|
index_path_obj = Path(self.index_path)
|
||||||
return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
|
return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
|
||||||
|
|
||||||
|
def _embeddings_path(self) -> Path:
|
||||||
|
index_path_obj = Path(self.index_path)
|
||||||
|
return index_path_obj.parent / f"{index_path_obj.name}.emb.npy"
|
||||||
|
|
||||||
def create_index(self) -> None:
|
def create_index(self) -> None:
|
||||||
if not self._pending_items:
|
if not self._pending_items:
|
||||||
return
|
return
|
||||||
@@ -121,6 +127,9 @@ class LeannMultiVector:
|
|||||||
with open(self._labels_path(), "w", encoding="utf-8") as f:
|
with open(self._labels_path(), "w", encoding="utf-8") as f:
|
||||||
_json.dump(labels_meta, f)
|
_json.dump(labels_meta, f)
|
||||||
|
|
||||||
|
# Persist embeddings for exact reranking
|
||||||
|
np.save(self._embeddings_path(), embeddings_np)
|
||||||
|
|
||||||
self._labels_meta = labels_meta
|
self._labels_meta = labels_meta
|
||||||
|
|
||||||
def _load_labels_meta_if_needed(self) -> None:
|
def _load_labels_meta_if_needed(self) -> None:
|
||||||
@@ -133,6 +142,19 @@ class LeannMultiVector:
|
|||||||
with open(labels_path, encoding="utf-8") as f:
|
with open(labels_path, encoding="utf-8") as f:
|
||||||
self._labels_meta = _json.load(f)
|
self._labels_meta = _json.load(f)
|
||||||
|
|
||||||
|
def _build_docid_to_indices_if_needed(self) -> None:
|
||||||
|
if self._docid_to_indices is not None:
|
||||||
|
return
|
||||||
|
self._load_labels_meta_if_needed()
|
||||||
|
mapping: dict[int, list[int]] = {}
|
||||||
|
for idx, meta in enumerate(self._labels_meta):
|
||||||
|
try:
|
||||||
|
doc_id = int(meta["doc_id"]) # type: ignore[index]
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
mapping.setdefault(doc_id, []).append(idx)
|
||||||
|
self._docid_to_indices = mapping
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self, data: np.ndarray, topk: int, first_stage_k: int = 50
|
self, data: np.ndarray, topk: int, first_stage_k: int = 50
|
||||||
) -> list[tuple[float, int]]:
|
) -> list[tuple[float, int]]:
|
||||||
@@ -180,3 +202,139 @@ class LeannMultiVector:
|
|||||||
|
|
||||||
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
|
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
|
||||||
return scores[:topk] if len(scores) >= topk else scores
|
return scores[:topk] if len(scores) >= topk else scores
|
||||||
|
|
||||||
|
def search_exact(
|
||||||
|
self,
|
||||||
|
data: np.ndarray,
|
||||||
|
topk: int,
|
||||||
|
*,
|
||||||
|
first_stage_k: int = 200,
|
||||||
|
max_workers: int = 32,
|
||||||
|
) -> list[tuple[float, int]]:
|
||||||
|
"""
|
||||||
|
High-precision MaxSim reranking over candidate documents.
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
1) Run a first-stage ANN to collect candidate doc_ids (using seq-level neighbors).
|
||||||
|
2) For each candidate doc, load all its token embeddings and compute
|
||||||
|
MaxSim(query_tokens, doc_tokens) exactly: sum(max(dot(q_i, d_j))).
|
||||||
|
|
||||||
|
Returns top-k list of (score, doc_id).
|
||||||
|
"""
|
||||||
|
# Normalize inputs
|
||||||
|
if data.ndim == 1:
|
||||||
|
data = data.reshape(1, -1)
|
||||||
|
if data.dtype != np.float32:
|
||||||
|
data = data.astype(np.float32)
|
||||||
|
|
||||||
|
self._load_labels_meta_if_needed()
|
||||||
|
self._build_docid_to_indices_if_needed()
|
||||||
|
|
||||||
|
emb_path = self._embeddings_path()
|
||||||
|
if not emb_path.exists():
|
||||||
|
# Fallback to approximate if we don't have persisted embeddings
|
||||||
|
return self.search(data, topk, first_stage_k=first_stage_k)
|
||||||
|
|
||||||
|
# Memory-map embeddings to avoid loading all into RAM
|
||||||
|
all_embeddings = np.load(emb_path, mmap_mode="r")
|
||||||
|
if all_embeddings.dtype != np.float32:
|
||||||
|
all_embeddings = all_embeddings.astype(np.float32)
|
||||||
|
|
||||||
|
# First-stage ANN to collect candidate doc_ids
|
||||||
|
searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
|
||||||
|
raw = searcher.search(
|
||||||
|
data,
|
||||||
|
first_stage_k,
|
||||||
|
recompute_embeddings=False,
|
||||||
|
complexity=128,
|
||||||
|
beam_width=1,
|
||||||
|
prune_ratio=0.0,
|
||||||
|
batch_size=0,
|
||||||
|
)
|
||||||
|
labels = raw.get("labels")
|
||||||
|
if labels is None:
|
||||||
|
return []
|
||||||
|
candidate_doc_ids: set[int] = set()
|
||||||
|
for batch in labels:
|
||||||
|
for sid in batch:
|
||||||
|
try:
|
||||||
|
idx = int(sid)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if 0 <= idx < len(self._labels_meta):
|
||||||
|
candidate_doc_ids.add(int(self._labels_meta[idx]["doc_id"])) # type: ignore[index]
|
||||||
|
|
||||||
|
# Exact scoring per doc (parallelized)
|
||||||
|
assert self._docid_to_indices is not None
|
||||||
|
|
||||||
|
def _score_one(doc_id: int) -> tuple[float, int]:
|
||||||
|
token_indices = self._docid_to_indices.get(doc_id, [])
|
||||||
|
if not token_indices:
|
||||||
|
return (0.0, doc_id)
|
||||||
|
doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
|
||||||
|
# (Q, D) x (P, D)^T -> (Q, P) then MaxSim over P, sum over Q
|
||||||
|
sim = np.dot(data, doc_vecs.T)
|
||||||
|
# nan-safe
|
||||||
|
sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
|
||||||
|
score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
|
||||||
|
return (float(score), doc_id)
|
||||||
|
|
||||||
|
scores: list[tuple[float, int]] = []
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||||
|
futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids]
|
||||||
|
for fut in concurrent.futures.as_completed(futures):
|
||||||
|
scores.append(fut.result())
|
||||||
|
|
||||||
|
scores.sort(key=lambda x: x[0], reverse=True)
|
||||||
|
return scores[:topk] if len(scores) >= topk else scores
|
||||||
|
|
||||||
|
def search_exact_all(
|
||||||
|
self,
|
||||||
|
data: np.ndarray,
|
||||||
|
topk: int,
|
||||||
|
*,
|
||||||
|
max_workers: int = 32,
|
||||||
|
) -> list[tuple[float, int]]:
|
||||||
|
"""
|
||||||
|
Exact MaxSim over ALL documents (no ANN pre-filtering).
|
||||||
|
|
||||||
|
This computes, for each document, sum_i max_j dot(q_i, d_j).
|
||||||
|
It memory-maps the persisted token-embedding matrix for scalability.
|
||||||
|
"""
|
||||||
|
if data.ndim == 1:
|
||||||
|
data = data.reshape(1, -1)
|
||||||
|
if data.dtype != np.float32:
|
||||||
|
data = data.astype(np.float32)
|
||||||
|
|
||||||
|
self._load_labels_meta_if_needed()
|
||||||
|
self._build_docid_to_indices_if_needed()
|
||||||
|
|
||||||
|
emb_path = self._embeddings_path()
|
||||||
|
if not emb_path.exists():
|
||||||
|
return self.search(data, topk)
|
||||||
|
|
||||||
|
all_embeddings = np.load(emb_path, mmap_mode="r")
|
||||||
|
if all_embeddings.dtype != np.float32:
|
||||||
|
all_embeddings = all_embeddings.astype(np.float32)
|
||||||
|
|
||||||
|
assert self._docid_to_indices is not None
|
||||||
|
candidate_doc_ids = list(self._docid_to_indices.keys())
|
||||||
|
|
||||||
|
def _score_one(doc_id: int) -> tuple[float, int]:
|
||||||
|
token_indices = self._docid_to_indices.get(doc_id, [])
|
||||||
|
if not token_indices:
|
||||||
|
return (0.0, doc_id)
|
||||||
|
doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
|
||||||
|
sim = np.dot(data, doc_vecs.T)
|
||||||
|
sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
|
||||||
|
score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
|
||||||
|
return (float(score), doc_id)
|
||||||
|
|
||||||
|
scores: list[tuple[float, int]] = []
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||||
|
futures = [ex.submit(_score_one, d) for d in candidate_doc_ids]
|
||||||
|
for fut in concurrent.futures.as_completed(futures):
|
||||||
|
scores.append(fut.result())
|
||||||
|
|
||||||
|
scores.sort(key=lambda x: x[0], reverse=True)
|
||||||
|
return scores[:topk] if len(scores) >= topk else scores
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# %%
|
# %%
|
||||||
# uv pip install matplotlib qwen_vl_utils
|
# uv pip install matplotlib qwen_vl_utils
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -230,12 +231,18 @@ def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) ->
|
|||||||
return retriever
|
return retriever
|
||||||
|
|
||||||
|
|
||||||
def _load_retriever_if_index_exists(index_path: str, dim: int) -> Optional[LeannMultiVector]:
|
def _load_retriever_if_index_exists(index_path: str) -> Optional[LeannMultiVector]:
|
||||||
index_base = Path(index_path)
|
index_base = Path(index_path)
|
||||||
# Rough heuristic: index dir exists AND meta+labels files exist
|
# Rough heuristic: index dir exists AND meta+labels files exist
|
||||||
meta = index_base.parent / f"{index_base.name}.meta.json"
|
meta = index_base.parent / f"{index_base.name}.meta.json"
|
||||||
labels = index_base.parent / f"{index_base.name}.labels.json"
|
labels = index_base.parent / f"{index_base.name}.labels.json"
|
||||||
if index_base.exists() and meta.exists() and labels.exists():
|
if index_base.exists() and meta.exists() and labels.exists():
|
||||||
|
try:
|
||||||
|
with open(meta, "r", encoding="utf-8") as f:
|
||||||
|
meta_json = json.load(f)
|
||||||
|
dim = int(meta_json.get("dimensions", 128))
|
||||||
|
except Exception:
|
||||||
|
dim = 128
|
||||||
return LeannMultiVector(index_path=index_path, dim=dim)
|
return LeannMultiVector(index_path=index_path, dim=dim)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -390,11 +397,7 @@ print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
|
|||||||
# Step 3: Build or load index
|
# Step 3: Build or load index
|
||||||
retriever: Optional[LeannMultiVector] = None
|
retriever: Optional[LeannMultiVector] = None
|
||||||
if not REBUILD_INDEX:
|
if not REBUILD_INDEX:
|
||||||
try:
|
retriever = _load_retriever_if_index_exists(INDEX_PATH)
|
||||||
one_vec = _embed_images(model, processor, [images[0]])[0]
|
|
||||||
retriever = _load_retriever_if_index_exists(INDEX_PATH, dim=int(one_vec.shape[-1]))
|
|
||||||
except Exception:
|
|
||||||
retriever = None
|
|
||||||
|
|
||||||
if retriever is None:
|
if retriever is None:
|
||||||
doc_vecs = _embed_images(model, processor, images)
|
doc_vecs = _embed_images(model, processor, images)
|
||||||
|
|||||||
143
benchmarks/update/README.md
Normal file
143
benchmarks/update/README.md
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
# Update Benchmarks
|
||||||
|
|
||||||
|
This directory hosts two benchmark suites that exercise LEANN’s HNSW “update +
|
||||||
|
search” pipeline under different assumptions:
|
||||||
|
|
||||||
|
1. **RNG recompute latency** – measure how random-neighbour pruning and cache
|
||||||
|
settings influence incremental `add()` latency when embeddings are fetched
|
||||||
|
over the ZMQ embedding server.
|
||||||
|
2. **Update strategy comparison** – compare a fully sequential update pipeline
|
||||||
|
against an offline approach that keeps the graph static and fuses results.
|
||||||
|
|
||||||
|
Both suites build a non-compact, `is_recompute=True` index so that new
|
||||||
|
embeddings are pulled from the embedding server. Benchmark outputs are written
|
||||||
|
under `.leann/bench/` by default and appended to CSV files for later plotting.
|
||||||
|
|
||||||
|
## Benchmarks
|
||||||
|
|
||||||
|
### 1. HNSW RNG Recompute Benchmark
|
||||||
|
|
||||||
|
`bench_hnsw_rng_recompute.py` evaluates incremental update latency under four
|
||||||
|
random-neighbour (RNG) configurations. Each scenario uses the same dataset but
|
||||||
|
changes the forward / reverse RNG pruning flags and whether the embedding cache
|
||||||
|
is enabled:
|
||||||
|
|
||||||
|
| Scenario name | Forward RNG | Reverse RNG | ZMQ embedding cache |
|
||||||
|
| ---------------------------------- | ----------- | ----------- | ------------------- |
|
||||||
|
| `baseline` | Enabled | Enabled | Enabled |
|
||||||
|
| `no_cache_baseline` | Enabled | Enabled | **Disabled** |
|
||||||
|
| `disable_forward_rng` | **Disabled**| Enabled | Enabled |
|
||||||
|
| `disable_forward_and_reverse_rng` | **Disabled**| **Disabled**| Enabled |
|
||||||
|
|
||||||
|
For each scenario the script:
|
||||||
|
1. (Re)builds a `is_recompute=True` index and writes it to `.leann/bench/`.
|
||||||
|
2. Starts `leann_backend_hnsw.hnsw_embedding_server` for remote embeddings.
|
||||||
|
3. Appends the requested updates using the scenario’s RNG flags.
|
||||||
|
4. Records total time, latency per passage, ZMQ fetch counts, and stage-level
|
||||||
|
timings before appending a row to the CSV output.
|
||||||
|
|
||||||
|
**Run:**
|
||||||
|
```bash
|
||||||
|
LEANN_HNSW_LOG_PATH=.leann/bench/hnsw_server.log \
|
||||||
|
LEANN_LOG_LEVEL=INFO \
|
||||||
|
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
|
||||||
|
--runs 1 \
|
||||||
|
--index-path .leann/bench/test.leann \
|
||||||
|
--initial-files data/PrideandPrejudice.txt \
|
||||||
|
--update-files data/huawei_pangu.md \
|
||||||
|
--max-initial 300 \
|
||||||
|
--max-updates 1 \
|
||||||
|
--add-timeout 120
|
||||||
|
```
|
||||||
|
|
||||||
|
**Output:**
|
||||||
|
- `benchmarks/update/bench_results.csv` – per-scenario timing statistics
|
||||||
|
(including ms/passage) for each run.
|
||||||
|
- `.leann/bench/hnsw_server.log` – detailed ZMQ/server logs (path controlled by
|
||||||
|
`LEANN_HNSW_LOG_PATH`).
|
||||||
|
_The reference CSVs checked into this branch were generated on a workstation with an NVIDIA RTX 4090 GPU; throughput numbers will differ on other hardware._
|
||||||
|
|
||||||
|
### 2. Sequential vs. Offline Update Benchmark
|
||||||
|
|
||||||
|
`bench_update_vs_offline_search.py` compares two end-to-end strategies on the
|
||||||
|
same dataset:
|
||||||
|
|
||||||
|
- **Scenario A – Sequential Update**
|
||||||
|
- Start an embedding server.
|
||||||
|
- Sequentially call `index.add()`; each call fetches embeddings via ZMQ and
|
||||||
|
mutates the HNSW graph.
|
||||||
|
- After all inserts, run a search on the updated graph.
|
||||||
|
- Metrics recorded: update time (`add_total_s`), post-update search time
|
||||||
|
(`search_time_s`), combined total (`total_time_s`), and per-passage
|
||||||
|
latency.
|
||||||
|
|
||||||
|
- **Scenario B – Offline Embedding + Concurrent Search**
|
||||||
|
- Stop Scenario A’s server and start a fresh embedding server.
|
||||||
|
- Spawn two threads: one generates embeddings for the new passages offline
|
||||||
|
(graph unchanged); the other computes the query embedding and searches the
|
||||||
|
existing graph.
|
||||||
|
- Merge offline similarities with the graph search results to emulate late
|
||||||
|
fusion, then report the merged top‑k preview.
|
||||||
|
- Metrics recorded: embedding time (`emb_time_s`), search time
|
||||||
|
(`search_time_s`), concurrent makespan (`makespan_s`), and scenario total.
|
||||||
|
|
||||||
|
**Run (both scenarios):**
|
||||||
|
```bash
|
||||||
|
uv run -m benchmarks.update.bench_update_vs_offline_search \
|
||||||
|
--index-path .leann/bench/offline_vs_update.leann \
|
||||||
|
--max-initial 300 \
|
||||||
|
--num-updates 1
|
||||||
|
```
|
||||||
|
|
||||||
|
You can pass `--only A` or `--only B` to run a single scenario. The script will
|
||||||
|
print timing summaries to stdout and append the results to CSV.
|
||||||
|
|
||||||
|
**Output:**
|
||||||
|
- `benchmarks/update/offline_vs_update.csv` – per-scenario timing statistics for
|
||||||
|
Scenario A and B.
|
||||||
|
- Console output includes Scenario B’s merged top‑k preview for quick sanity
|
||||||
|
checks.
|
||||||
|
_The sample results committed here come from runs on an RTX 4090-equipped machine; expect variations if you benchmark on different GPUs._
|
||||||
|
|
||||||
|
### 3. Visualisation
|
||||||
|
|
||||||
|
`plot_bench_results.py` combines the RNG benchmark and the update strategy
|
||||||
|
benchmark into a single two-panel plot.
|
||||||
|
|
||||||
|
**Run:**
|
||||||
|
```bash
|
||||||
|
uv run -m benchmarks.update.plot_bench_results \
|
||||||
|
--csv benchmarks/update/bench_results.csv \
|
||||||
|
--csv-right benchmarks/update/offline_vs_update.csv \
|
||||||
|
--out benchmarks/update/bench_latency_from_csv.png
|
||||||
|
```
|
||||||
|
|
||||||
|
**Options:**
|
||||||
|
- `--broken-y` – Enable a broken Y-axis (default: true when appropriate).
|
||||||
|
- `--csv` – RNG benchmark results CSV (left panel).
|
||||||
|
- `--csv-right` – Update strategy results CSV (right panel).
|
||||||
|
- `--out` – Output image path (PNG/PDF supported).
|
||||||
|
|
||||||
|
**Output:**
|
||||||
|
- `benchmarks/update/bench_latency_from_csv.png` – visual comparison of the two
|
||||||
|
suites.
|
||||||
|
- `benchmarks/update/bench_latency_from_csv.pdf` – PDF version, suitable for
|
||||||
|
slides/papers.
|
||||||
|
|
||||||
|
## Parameters & Environment
|
||||||
|
|
||||||
|
### Common CLI Flags
|
||||||
|
- `--max-initial` – Number of initial passages used to seed the index.
|
||||||
|
- `--max-updates` / `--num-updates` – Number of passages to treat as updates.
|
||||||
|
- `--index-path` – Base path (without extension) where the LEANN index is stored.
|
||||||
|
- `--runs` – Number of repetitions (RNG benchmark only).
|
||||||
|
|
||||||
|
### Environment Variables
|
||||||
|
- `LEANN_HNSW_LOG_PATH` – File to receive embedding-server logs (optional).
|
||||||
|
- `LEANN_LOG_LEVEL` – Logging verbosity (DEBUG/INFO/WARNING/ERROR).
|
||||||
|
- `CUDA_VISIBLE_DEVICES` – Set to empty string if you want to force CPU
|
||||||
|
execution of the embedding model.
|
||||||
|
|
||||||
|
With these scripts you can easily replicate LEANN’s update benchmarks, compare
|
||||||
|
multiple RNG strategies, and evaluate whether sequential updates or offline
|
||||||
|
fusion better match your latency/accuracy trade-offs.
|
||||||
16
benchmarks/update/__init__.py
Normal file
16
benchmarks/update/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""Benchmarks for LEANN update workflows."""
|
||||||
|
|
||||||
|
# Expose helper to locate repository root for other modules that need it.
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def find_repo_root() -> Path:
|
||||||
|
"""Return the project root containing pyproject.toml."""
|
||||||
|
current = Path(__file__).resolve()
|
||||||
|
for parent in current.parents:
|
||||||
|
if (parent / "pyproject.toml").exists():
|
||||||
|
return parent
|
||||||
|
return current.parents[1]
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["find_repo_root"]
|
||||||
804
benchmarks/update/bench_hnsw_rng_recompute.py
Normal file
804
benchmarks/update/bench_hnsw_rng_recompute.py
Normal file
@@ -0,0 +1,804 @@
|
|||||||
|
"""Benchmark incremental HNSW add() under different RNG pruning modes with real
|
||||||
|
embedding recomputation.
|
||||||
|
|
||||||
|
This script clones the structure of ``examples/dynamic_update_no_recompute.py``
|
||||||
|
so that we build a non-compact ``is_recompute=True`` index, spin up the
|
||||||
|
standard HNSW embedding server, and measure how long incremental ``add`` takes
|
||||||
|
when RNG pruning is fully enabled vs. partially/fully disabled.
|
||||||
|
|
||||||
|
Example usage (run from the repo root; downloads the model on first run)::
|
||||||
|
|
||||||
|
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
|
||||||
|
--index-path .leann/bench/leann-demo.leann \
|
||||||
|
--runs 1
|
||||||
|
|
||||||
|
You can tweak the input documents with ``--initial-files`` / ``--update-files``
|
||||||
|
if you want a larger or different workload, and change the embedding model via
|
||||||
|
``--model-name``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import msgpack
|
||||||
|
import numpy as np
|
||||||
|
import zmq
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
|
||||||
|
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
|
||||||
|
|
||||||
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
from leann.embedding_server_manager import EmbeddingServerManager
|
||||||
|
from leann.registry import register_project_directory
|
||||||
|
from leann_backend_hnsw import faiss # type: ignore
|
||||||
|
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
if not logging.getLogger().handlers:
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_repo_root() -> Path:
|
||||||
|
"""Locate project root by walking up until pyproject.toml is found."""
|
||||||
|
current = Path(__file__).resolve()
|
||||||
|
for parent in current.parents:
|
||||||
|
if (parent / "pyproject.toml").exists():
|
||||||
|
return parent
|
||||||
|
# Fallback: assume repo is two levels up (../..)
|
||||||
|
return current.parents[2]
|
||||||
|
|
||||||
|
|
||||||
|
REPO_ROOT = _find_repo_root()
|
||||||
|
if str(REPO_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(REPO_ROOT))
|
||||||
|
|
||||||
|
from apps.chunking import create_text_chunks # noqa: E402
|
||||||
|
|
||||||
|
DEFAULT_INITIAL_FILES = [
|
||||||
|
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
||||||
|
REPO_ROOT / "data" / "huawei_pangu.md",
|
||||||
|
]
|
||||||
|
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||||
|
|
||||||
|
DEFAULT_HNSW_LOG = Path(".leann/bench/hnsw_server.log")
|
||||||
|
|
||||||
|
|
||||||
|
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
documents = []
|
||||||
|
for path in paths:
|
||||||
|
p = path.expanduser().resolve()
|
||||||
|
if not p.exists():
|
||||||
|
raise FileNotFoundError(f"Input path not found: {p}")
|
||||||
|
if p.is_dir():
|
||||||
|
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
else:
|
||||||
|
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
return []
|
||||||
|
|
||||||
|
chunks = create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size=512,
|
||||||
|
chunk_overlap=128,
|
||||||
|
use_ast_chunking=False,
|
||||||
|
)
|
||||||
|
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||||
|
if limit is not None:
|
||||||
|
cleaned = cleaned[:limit]
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_index_dir(index_path: Path) -> None:
|
||||||
|
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_index_files(index_path: Path) -> None:
|
||||||
|
parent = index_path.parent
|
||||||
|
if not parent.exists():
|
||||||
|
return
|
||||||
|
stem = index_path.stem
|
||||||
|
for file in parent.glob(f"{stem}*"):
|
||||||
|
if file.is_file():
|
||||||
|
file.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def build_initial_index(
|
||||||
|
index_path: Path,
|
||||||
|
paragraphs: list[str],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
distance_metric: str,
|
||||||
|
ef_construction: int,
|
||||||
|
) -> None:
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
is_compact=False,
|
||||||
|
is_recompute=True,
|
||||||
|
distance_metric=distance_metric,
|
||||||
|
backend_kwargs={
|
||||||
|
"distance_metric": distance_metric,
|
||||||
|
"is_compact": False,
|
||||||
|
"is_recompute": True,
|
||||||
|
"efConstruction": ef_construction,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for idx, passage in enumerate(paragraphs):
|
||||||
|
builder.add_text(passage, metadata={"id": str(idx)})
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_new_chunks(paragraphs: list[str]) -> list[dict[str, Any]]:
|
||||||
|
return [{"text": text, "metadata": {}} for text in paragraphs]
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_update_with_mode(
|
||||||
|
index_path: Path,
|
||||||
|
new_chunks: list[dict[str, Any]],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
distance_metric: str,
|
||||||
|
disable_forward_rng: bool,
|
||||||
|
disable_reverse_rng: bool,
|
||||||
|
server_port: int,
|
||||||
|
add_timeout: int,
|
||||||
|
ef_construction: int,
|
||||||
|
) -> tuple[float, float]:
|
||||||
|
meta_path = index_path.parent / f"{index_path.name}.meta.json"
|
||||||
|
passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
|
||||||
|
offset_file = index_path.parent / f"{index_path.name}.passages.idx"
|
||||||
|
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||||
|
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
with open(offset_file, "rb") as f:
|
||||||
|
offset_map: dict[str, int] = pickle.load(f)
|
||||||
|
existing_ids = set(offset_map.keys())
|
||||||
|
|
||||||
|
valid_chunks: list[dict[str, Any]] = []
|
||||||
|
for chunk in new_chunks:
|
||||||
|
text = chunk.get("text", "")
|
||||||
|
if not isinstance(text, str) or not text.strip():
|
||||||
|
continue
|
||||||
|
metadata = chunk.setdefault("metadata", {})
|
||||||
|
passage_id = chunk.get("id") or metadata.get("id")
|
||||||
|
if passage_id and passage_id in existing_ids:
|
||||||
|
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
|
||||||
|
valid_chunks.append(chunk)
|
||||||
|
|
||||||
|
if not valid_chunks:
|
||||||
|
raise ValueError("No valid chunks to append.")
|
||||||
|
|
||||||
|
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
texts_to_embed,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
is_build=False,
|
||||||
|
batch_size=16,
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
|
if distance_metric == "cosine":
|
||||||
|
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||||
|
norms[norms == 0] = 1
|
||||||
|
embeddings = embeddings / norms
|
||||||
|
|
||||||
|
index = faiss.read_index(str(index_file))
|
||||||
|
index.is_recompute = True
|
||||||
|
if getattr(index, "storage", None) is None:
|
||||||
|
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||||
|
storage_index = faiss.IndexFlatIP(index.d)
|
||||||
|
else:
|
||||||
|
storage_index = faiss.IndexFlatL2(index.d)
|
||||||
|
index.storage = storage_index
|
||||||
|
index.own_fields = True
|
||||||
|
try:
|
||||||
|
storage_index.ntotal = index.ntotal
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
index.hnsw.set_disable_rng_during_add(disable_forward_rng)
|
||||||
|
index.hnsw.set_disable_reverse_prune(disable_reverse_rng)
|
||||||
|
if ef_construction is not None:
|
||||||
|
index.hnsw.efConstruction = ef_construction
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
applied_forward = getattr(index.hnsw, "disable_rng_during_add", None)
|
||||||
|
applied_reverse = getattr(index.hnsw, "disable_reverse_prune", None)
|
||||||
|
logger.info(
|
||||||
|
"HNSW RNG config -> requested forward=%s, reverse=%s | applied forward=%s, reverse=%s",
|
||||||
|
disable_forward_rng,
|
||||||
|
disable_reverse_rng,
|
||||||
|
applied_forward,
|
||||||
|
applied_reverse,
|
||||||
|
)
|
||||||
|
|
||||||
|
base_id = index.ntotal
|
||||||
|
for offset, chunk in enumerate(valid_chunks):
|
||||||
|
new_id = str(base_id + offset)
|
||||||
|
chunk.setdefault("metadata", {})["id"] = new_id
|
||||||
|
chunk["id"] = new_id
|
||||||
|
|
||||||
|
rollback_size = passages_file.stat().st_size if passages_file.exists() else 0
|
||||||
|
offset_map_backup = offset_map.copy()
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(passages_file, "a", encoding="utf-8") as f:
|
||||||
|
for chunk in valid_chunks:
|
||||||
|
offset = f.tell()
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"id": chunk["id"],
|
||||||
|
"text": chunk["text"],
|
||||||
|
"metadata": chunk.get("metadata", {}),
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
f.write("\n")
|
||||||
|
offset_map[chunk["id"]] = offset
|
||||||
|
|
||||||
|
with open(offset_file, "wb") as f:
|
||||||
|
pickle.dump(offset_map, f)
|
||||||
|
|
||||||
|
server_manager = EmbeddingServerManager(
|
||||||
|
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||||
|
)
|
||||||
|
server_started, actual_port = server_manager.start_server(
|
||||||
|
port=server_port,
|
||||||
|
model_name=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
passages_file=str(meta_path),
|
||||||
|
distance_metric=distance_metric,
|
||||||
|
)
|
||||||
|
if not server_started:
|
||||||
|
raise RuntimeError("Failed to start embedding server.")
|
||||||
|
|
||||||
|
if hasattr(index.hnsw, "set_zmq_port"):
|
||||||
|
index.hnsw.set_zmq_port(actual_port)
|
||||||
|
elif hasattr(index, "set_zmq_port"):
|
||||||
|
index.set_zmq_port(actual_port)
|
||||||
|
|
||||||
|
_warmup_embedding_server(actual_port)
|
||||||
|
|
||||||
|
total_start = time.time()
|
||||||
|
add_elapsed = 0.0
|
||||||
|
|
||||||
|
try:
|
||||||
|
import signal
|
||||||
|
|
||||||
|
def _timeout_handler(signum, frame):
|
||||||
|
raise TimeoutError("incremental add timed out")
|
||||||
|
|
||||||
|
if add_timeout > 0:
|
||||||
|
signal.signal(signal.SIGALRM, _timeout_handler)
|
||||||
|
signal.alarm(add_timeout)
|
||||||
|
|
||||||
|
add_start = time.time()
|
||||||
|
for i in range(embeddings.shape[0]):
|
||||||
|
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
|
||||||
|
add_elapsed = time.time() - add_start
|
||||||
|
if add_timeout > 0:
|
||||||
|
signal.alarm(0)
|
||||||
|
faiss.write_index(index, str(index_file))
|
||||||
|
finally:
|
||||||
|
server_manager.stop_server()
|
||||||
|
|
||||||
|
except TimeoutError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
if passages_file.exists():
|
||||||
|
with open(passages_file, "rb+") as f:
|
||||||
|
f.truncate(rollback_size)
|
||||||
|
with open(offset_file, "wb") as f:
|
||||||
|
pickle.dump(offset_map_backup, f)
|
||||||
|
raise
|
||||||
|
|
||||||
|
prune_hnsw_embeddings_inplace(str(index_file))
|
||||||
|
|
||||||
|
meta["total_passages"] = len(offset_map)
|
||||||
|
with open(meta_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(meta, f, indent=2)
|
||||||
|
|
||||||
|
# Reset toggles so the index on disk returns to baseline behaviour.
|
||||||
|
try:
|
||||||
|
index.hnsw.set_disable_rng_during_add(False)
|
||||||
|
index.hnsw.set_disable_reverse_prune(False)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
faiss.write_index(index, str(index_file))
|
||||||
|
|
||||||
|
total_elapsed = time.time() - total_start
|
||||||
|
|
||||||
|
return total_elapsed, add_elapsed
|
||||||
|
|
||||||
|
|
||||||
|
def _total_zmq_nodes(log_path: Path) -> int:
|
||||||
|
if not log_path.exists():
|
||||||
|
return 0
|
||||||
|
with log_path.open("r", encoding="utf-8") as log_file:
|
||||||
|
text = log_file.read()
|
||||||
|
return sum(int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", text))
|
||||||
|
|
||||||
|
|
||||||
|
def _warmup_embedding_server(port: int) -> None:
|
||||||
|
"""Send a dummy REQ so the embedding server loads its model."""
|
||||||
|
ctx = zmq.Context()
|
||||||
|
try:
|
||||||
|
sock = ctx.socket(zmq.REQ)
|
||||||
|
sock.setsockopt(zmq.LINGER, 0)
|
||||||
|
sock.setsockopt(zmq.RCVTIMEO, 5000)
|
||||||
|
sock.setsockopt(zmq.SNDTIMEO, 5000)
|
||||||
|
sock.connect(f"tcp://127.0.0.1:{port}")
|
||||||
|
payload = msgpack.packb(["__WARMUP__"], use_bin_type=True)
|
||||||
|
sock.send(payload)
|
||||||
|
try:
|
||||||
|
sock.recv()
|
||||||
|
except zmq.error.Again:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
sock.close()
|
||||||
|
ctx.term()
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path(".leann/bench/leann-demo.leann"),
|
||||||
|
help="Output index base path (without extension).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--initial-files",
|
||||||
|
nargs="*",
|
||||||
|
type=Path,
|
||||||
|
default=DEFAULT_INITIAL_FILES,
|
||||||
|
help="Files used to build the initial index.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--update-files",
|
||||||
|
nargs="*",
|
||||||
|
type=Path,
|
||||||
|
default=DEFAULT_UPDATE_FILES,
|
||||||
|
help="Files appended during the benchmark.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--runs", type=int, default=1, help="How many times to repeat each scenario."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-name",
|
||||||
|
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
help="Embedding model used for build/update.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-mode",
|
||||||
|
default="sentence-transformers",
|
||||||
|
help="Embedding mode passed to LeannBuilder/embedding server.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--distance-metric",
|
||||||
|
default="mips",
|
||||||
|
choices=["mips", "l2", "cosine"],
|
||||||
|
help="Distance metric for HNSW backend.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ef-construction",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="efConstruction setting for initial build.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--server-port",
|
||||||
|
type=int,
|
||||||
|
default=5557,
|
||||||
|
help="Port for the real embedding server.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-initial",
|
||||||
|
type=int,
|
||||||
|
default=300,
|
||||||
|
help="Optional cap on initial passages (after chunking).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-updates",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Optional cap on update passages (after chunking).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--add-timeout",
|
||||||
|
type=int,
|
||||||
|
default=900,
|
||||||
|
help="Timeout in seconds for the incremental add loop (0 = no timeout).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--plot-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path("bench_latency.png"),
|
||||||
|
help="Where to save the latency bar plot.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cap-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Cap Y-axis (ms). Bars above are hatched and annotated.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--broken-y",
|
||||||
|
action="store_true",
|
||||||
|
help="Use broken Y-axis (two stacked axes with gap). Overrides --cap-y unless both provided.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lower-cap-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Lower axes upper bound for broken Y (ms). Default=1.1x second-highest.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--upper-start-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Upper axes lower bound for broken Y (ms). Default=1.2x second-highest.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--csv-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path("benchmarks/update/bench_results.csv"),
|
||||||
|
help="Where to append per-scenario results as CSV.",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
register_project_directory(REPO_ROOT)
|
||||||
|
|
||||||
|
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
|
||||||
|
update_paragraphs = load_chunks_from_files(args.update_files, args.max_updates)
|
||||||
|
if not update_paragraphs:
|
||||||
|
raise ValueError("No update passages found; please provide --update-files with content.")
|
||||||
|
|
||||||
|
update_chunks = prepare_new_chunks(update_paragraphs)
|
||||||
|
ensure_index_dir(args.index_path)
|
||||||
|
|
||||||
|
scenarios = [
|
||||||
|
("baseline", False, False, True),
|
||||||
|
("no_cache_baseline", False, False, False),
|
||||||
|
("disable_forward_rng", True, False, True),
|
||||||
|
("disable_forward_and_reverse_rng", True, True, True),
|
||||||
|
]
|
||||||
|
|
||||||
|
log_path = Path(os.environ.get("LEANN_HNSW_LOG_PATH", DEFAULT_HNSW_LOG))
|
||||||
|
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
os.environ["LEANN_HNSW_LOG_PATH"] = str(log_path.resolve())
|
||||||
|
os.environ.setdefault("LEANN_LOG_LEVEL", "INFO")
|
||||||
|
|
||||||
|
results_total: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||||
|
results_add: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||||
|
results_zmq: dict[str, list[int]] = {name: [] for name, *_ in scenarios}
|
||||||
|
results_stageA: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||||
|
results_stageBC: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||||
|
results_ms_per_passage: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||||
|
|
||||||
|
# CSV setup
|
||||||
|
import csv
|
||||||
|
|
||||||
|
run_id = time.strftime("%Y%m%d-%H%M%S")
|
||||||
|
csv_fields = [
|
||||||
|
"run_id",
|
||||||
|
"scenario",
|
||||||
|
"cache_enabled",
|
||||||
|
"ef_construction",
|
||||||
|
"max_initial",
|
||||||
|
"max_updates",
|
||||||
|
"total_time_s",
|
||||||
|
"add_only_s",
|
||||||
|
"latency_ms_per_passage",
|
||||||
|
"zmq_nodes",
|
||||||
|
"stageA_time_s",
|
||||||
|
"stageBC_time_s",
|
||||||
|
"model_name",
|
||||||
|
"embedding_mode",
|
||||||
|
"distance_metric",
|
||||||
|
]
|
||||||
|
# Create CSV with header if missing
|
||||||
|
if args.csv_path:
|
||||||
|
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
|
||||||
|
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||||
|
writer.writeheader()
|
||||||
|
|
||||||
|
for run in range(args.runs):
|
||||||
|
print(f"\n=== Benchmark run {run + 1}/{args.runs} ===")
|
||||||
|
for name, disable_forward, disable_reverse, cache_enabled in scenarios:
|
||||||
|
print(f"\nScenario: {name}")
|
||||||
|
cleanup_index_files(args.index_path)
|
||||||
|
if log_path.exists():
|
||||||
|
try:
|
||||||
|
log_path.unlink()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
os.environ["LEANN_ZMQ_EMBED_CACHE"] = "1" if cache_enabled else "0"
|
||||||
|
build_initial_index(
|
||||||
|
args.index_path,
|
||||||
|
initial_paragraphs,
|
||||||
|
args.model_name,
|
||||||
|
args.embedding_mode,
|
||||||
|
args.distance_metric,
|
||||||
|
args.ef_construction,
|
||||||
|
)
|
||||||
|
|
||||||
|
prev_size = log_path.stat().st_size if log_path.exists() else 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
total_elapsed, add_elapsed = benchmark_update_with_mode(
|
||||||
|
args.index_path,
|
||||||
|
update_chunks,
|
||||||
|
args.model_name,
|
||||||
|
args.embedding_mode,
|
||||||
|
args.distance_metric,
|
||||||
|
disable_forward,
|
||||||
|
disable_reverse,
|
||||||
|
args.server_port,
|
||||||
|
args.add_timeout,
|
||||||
|
args.ef_construction,
|
||||||
|
)
|
||||||
|
except TimeoutError as exc:
|
||||||
|
print(f"Scenario {name} timed out: {exc}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
curr_size = log_path.stat().st_size if log_path.exists() else 0
|
||||||
|
if curr_size < prev_size:
|
||||||
|
prev_size = 0
|
||||||
|
zmq_count = 0
|
||||||
|
if log_path.exists():
|
||||||
|
with log_path.open("r", encoding="utf-8") as log_file:
|
||||||
|
log_file.seek(prev_size)
|
||||||
|
new_entries = log_file.read()
|
||||||
|
zmq_count = sum(
|
||||||
|
int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", new_entries)
|
||||||
|
)
|
||||||
|
stageA = sum(
|
||||||
|
float(x)
|
||||||
|
for x in re.findall(r"Distance calculation E2E time: ([0-9.]+)s", new_entries)
|
||||||
|
)
|
||||||
|
stageBC = sum(
|
||||||
|
float(x) for x in re.findall(r"ZMQ E2E time: ([0-9.]+)s", new_entries)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
stageA = 0.0
|
||||||
|
stageBC = 0.0
|
||||||
|
|
||||||
|
per_chunk = add_elapsed / len(update_chunks)
|
||||||
|
print(
|
||||||
|
f"Total time: {total_elapsed:.3f} s | add-only: {add_elapsed:.3f} s "
|
||||||
|
f"for {len(update_chunks)} passages => {per_chunk * 1e3:.3f} ms/passage"
|
||||||
|
)
|
||||||
|
print(f"ZMQ node fetch total: {zmq_count}")
|
||||||
|
results_total[name].append(total_elapsed)
|
||||||
|
results_add[name].append(add_elapsed)
|
||||||
|
results_zmq[name].append(zmq_count)
|
||||||
|
results_ms_per_passage[name].append(per_chunk * 1e3)
|
||||||
|
results_stageA[name].append(stageA)
|
||||||
|
results_stageBC[name].append(stageBC)
|
||||||
|
|
||||||
|
# Append row to CSV
|
||||||
|
if args.csv_path:
|
||||||
|
row = {
|
||||||
|
"run_id": run_id,
|
||||||
|
"scenario": name,
|
||||||
|
"cache_enabled": 1 if cache_enabled else 0,
|
||||||
|
"ef_construction": args.ef_construction,
|
||||||
|
"max_initial": args.max_initial,
|
||||||
|
"max_updates": args.max_updates,
|
||||||
|
"total_time_s": round(total_elapsed, 6),
|
||||||
|
"add_only_s": round(add_elapsed, 6),
|
||||||
|
"latency_ms_per_passage": round(per_chunk * 1e3, 6),
|
||||||
|
"zmq_nodes": int(zmq_count),
|
||||||
|
"stageA_time_s": round(stageA, 6),
|
||||||
|
"stageBC_time_s": round(stageBC, 6),
|
||||||
|
"model_name": args.model_name,
|
||||||
|
"embedding_mode": args.embedding_mode,
|
||||||
|
"distance_metric": args.distance_metric,
|
||||||
|
}
|
||||||
|
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||||
|
writer.writerow(row)
|
||||||
|
|
||||||
|
print("\n=== Summary ===")
|
||||||
|
for name in results_add:
|
||||||
|
add_values = results_add[name]
|
||||||
|
total_values = results_total[name]
|
||||||
|
zmq_values = results_zmq[name]
|
||||||
|
latency_values = results_ms_per_passage[name]
|
||||||
|
if not add_values:
|
||||||
|
print(f"{name}: no successful runs")
|
||||||
|
continue
|
||||||
|
avg_add = sum(add_values) / len(add_values)
|
||||||
|
avg_total = sum(total_values) / len(total_values)
|
||||||
|
avg_zmq = sum(zmq_values) / len(zmq_values) if zmq_values else 0.0
|
||||||
|
avg_latency = sum(latency_values) / len(latency_values) if latency_values else 0.0
|
||||||
|
runs = len(add_values)
|
||||||
|
print(
|
||||||
|
f"{name}: add-only avg {avg_add:.3f} s | total avg {avg_total:.3f} s "
|
||||||
|
f"| ZMQ avg {avg_zmq:.1f} node fetches | latency {avg_latency:.2f} ms/passage over {runs} run(s)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.plot_path:
|
||||||
|
try:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
labels = [name for name, *_ in scenarios]
|
||||||
|
values = [
|
||||||
|
sum(results_ms_per_passage[name]) / len(results_ms_per_passage[name])
|
||||||
|
if results_ms_per_passage[name]
|
||||||
|
else 0.0
|
||||||
|
for name in labels
|
||||||
|
]
|
||||||
|
|
||||||
|
def _auto_cap(vals: list[float]) -> float | None:
|
||||||
|
s = sorted(vals, reverse=True)
|
||||||
|
if len(s) < 2:
|
||||||
|
return None
|
||||||
|
if s[1] > 0 and s[0] >= 2.5 * s[1]:
|
||||||
|
return s[1] * 1.1
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _fmt_ms(v: float) -> str:
|
||||||
|
return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.1f}"
|
||||||
|
|
||||||
|
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
|
||||||
|
|
||||||
|
if args.broken_y:
|
||||||
|
s = sorted(values, reverse=True)
|
||||||
|
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||||
|
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
|
||||||
|
upper_start = (
|
||||||
|
args.upper_start_y
|
||||||
|
if args.upper_start_y is not None
|
||||||
|
else max(second * 1.2, lower_cap * 1.02)
|
||||||
|
)
|
||||||
|
ymax = max(values) * 1.10 if values else 1.0
|
||||||
|
fig, (ax_top, ax_bottom) = plt.subplots(
|
||||||
|
2,
|
||||||
|
1,
|
||||||
|
sharex=True,
|
||||||
|
figsize=(7.4, 5.0),
|
||||||
|
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.05},
|
||||||
|
)
|
||||||
|
x = list(range(len(labels)))
|
||||||
|
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
ax_bottom.set_ylim(0, lower_cap)
|
||||||
|
ax_top.set_ylim(upper_start, ymax)
|
||||||
|
for i, v in enumerate(values):
|
||||||
|
if v <= lower_cap:
|
||||||
|
ax_bottom.text(
|
||||||
|
i,
|
||||||
|
v + lower_cap * 0.02,
|
||||||
|
_fmt_ms(v),
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=9,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||||
|
ax_top.spines["bottom"].set_visible(False)
|
||||||
|
ax_bottom.spines["top"].set_visible(False)
|
||||||
|
ax_top.tick_params(labeltop=False)
|
||||||
|
ax_bottom.xaxis.tick_bottom()
|
||||||
|
d = 0.015
|
||||||
|
kwargs = {"transform": ax_top.transAxes, "color": "k", "clip_on": False}
|
||||||
|
ax_top.plot((-d, +d), (-d, +d), **kwargs)
|
||||||
|
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
|
||||||
|
kwargs.update({"transform": ax_bottom.transAxes})
|
||||||
|
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
|
||||||
|
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
|
||||||
|
ax_bottom.set_xticks(range(len(labels)))
|
||||||
|
ax_bottom.set_xticklabels(labels)
|
||||||
|
ax = ax_bottom
|
||||||
|
else:
|
||||||
|
cap = args.cap_y or _auto_cap(values)
|
||||||
|
plt.figure(figsize=(7.2, 4.2))
|
||||||
|
ax = plt.gca()
|
||||||
|
if cap is not None:
|
||||||
|
show_vals = [min(v, cap) for v in values]
|
||||||
|
bars = []
|
||||||
|
for i, (v, show) in enumerate(zip(values, show_vals)):
|
||||||
|
b = ax.bar(i, show, color=colors[i], width=0.8)
|
||||||
|
bars.append(b[0])
|
||||||
|
if v > cap:
|
||||||
|
bars[-1].set_hatch("//")
|
||||||
|
ax.text(i, cap * 1.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||||
|
else:
|
||||||
|
ax.text(
|
||||||
|
i,
|
||||||
|
show + max(1.0, 0.01 * (cap or show)),
|
||||||
|
_fmt_ms(v),
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=9,
|
||||||
|
)
|
||||||
|
ax.set_ylim(0, cap * 1.10)
|
||||||
|
ax.plot(
|
||||||
|
[0.02 - 0.02, 0.02 + 0.02],
|
||||||
|
[0.98 + 0.02, 0.98 - 0.02],
|
||||||
|
transform=ax.transAxes,
|
||||||
|
color="k",
|
||||||
|
lw=1,
|
||||||
|
)
|
||||||
|
ax.plot(
|
||||||
|
[0.98 - 0.02, 0.98 + 0.02],
|
||||||
|
[0.98 + 0.02, 0.98 - 0.02],
|
||||||
|
transform=ax.transAxes,
|
||||||
|
color="k",
|
||||||
|
lw=1,
|
||||||
|
)
|
||||||
|
if any(v > cap for v in values):
|
||||||
|
ax.legend(
|
||||||
|
[bars[0]], ["capped"], fontsize=8, frameon=False, loc="upper right"
|
||||||
|
)
|
||||||
|
ax.set_xticks(range(len(labels)))
|
||||||
|
ax.set_xticklabels(labels)
|
||||||
|
else:
|
||||||
|
ax.bar(labels, values, color=colors[: len(labels)])
|
||||||
|
for idx, val in enumerate(values):
|
||||||
|
ax.text(idx, val + 1.0, f"{val:.1f}", ha="center", va="bottom")
|
||||||
|
|
||||||
|
plt.ylabel("Average add latency (ms per passage)")
|
||||||
|
plt.title(f"Initial passages {args.max_initial}, updates {args.max_updates}")
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(args.plot_path)
|
||||||
|
print(f"Saved latency bar plot to {args.plot_path}")
|
||||||
|
# ZMQ time split (Stage A vs B/C)
|
||||||
|
try:
|
||||||
|
plt.figure(figsize=(6, 4))
|
||||||
|
a_vals = [sum(results_stageA[n]) / max(1, len(results_stageA[n])) for n in labels]
|
||||||
|
bc_vals = [
|
||||||
|
sum(results_stageBC[n]) / max(1, len(results_stageBC[n])) for n in labels
|
||||||
|
]
|
||||||
|
ind = range(len(labels))
|
||||||
|
plt.bar(ind, a_vals, color="#4e79a7", label="Stage A distance (s)")
|
||||||
|
plt.bar(
|
||||||
|
ind, bc_vals, bottom=a_vals, color="#e15759", label="Stage B/C embed-by-id (s)"
|
||||||
|
)
|
||||||
|
plt.xticks(list(ind), labels, rotation=10)
|
||||||
|
plt.ylabel("Server ZMQ time (s)")
|
||||||
|
plt.title(
|
||||||
|
f"ZMQ time split (initial {args.max_initial}, updates {args.max_updates})"
|
||||||
|
)
|
||||||
|
plt.legend()
|
||||||
|
out2 = args.plot_path.with_name(
|
||||||
|
args.plot_path.stem + "_zmq_split" + args.plot_path.suffix
|
||||||
|
)
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(out2)
|
||||||
|
print(f"Saved ZMQ time split plot to {out2}")
|
||||||
|
except Exception as e:
|
||||||
|
print("Failed to plot ZMQ split:", e)
|
||||||
|
except ImportError:
|
||||||
|
print("matplotlib not available; skipping plot generation")
|
||||||
|
|
||||||
|
# leave the last build on disk for inspection
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
5
benchmarks/update/bench_results.csv
Normal file
5
benchmarks/update/bench_results.csv
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
run_id,scenario,cache_enabled,ef_construction,max_initial,max_updates,total_time_s,add_only_s,latency_ms_per_passage,zmq_nodes,stageA_time_s,stageBC_time_s,model_name,embedding_mode,distance_metric
|
||||||
|
20251024-133101,baseline,1,200,300,1,3.391856,1.120359,1120.359421,126,0.507821,0.601608,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||||
|
20251024-133101,no_cache_baseline,0,200,300,1,34.941514,32.91376,32913.760185,4033,0.506933,32.159928,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||||
|
20251024-133101,disable_forward_rng,1,200,300,1,2.746756,0.8202,820.200443,66,0.474354,0.338454,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||||
|
20251024-133101,disable_forward_and_reverse_rng,1,200,300,1,2.396566,0.521478,521.478415,1,0.508973,0.006938,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||||
|
704
benchmarks/update/bench_update_vs_offline_search.py
Normal file
704
benchmarks/update/bench_update_vs_offline_search.py
Normal file
@@ -0,0 +1,704 @@
|
|||||||
|
"""
|
||||||
|
Compare two latency models for small incremental updates vs. search:
|
||||||
|
|
||||||
|
Scenario A (sequential update then search):
|
||||||
|
- Build initial HNSW (is_recompute=True)
|
||||||
|
- Start embedding server (ZMQ) for recompute
|
||||||
|
- Add N passages one-by-one (each triggers recompute over ZMQ)
|
||||||
|
- Then run a search query on the updated index
|
||||||
|
- Report total time = sum(add_i) + search_time, with breakdowns
|
||||||
|
|
||||||
|
Scenario B (offline embeds + concurrent search; no graph updates):
|
||||||
|
- Do NOT insert the N passages into the graph
|
||||||
|
- In parallel: (1) compute embeddings for the N passages; (2) compute query
|
||||||
|
embedding and run a search on the existing index
|
||||||
|
- After both finish, compute similarity between the query embedding and the N
|
||||||
|
new passage embeddings, merge with the index search results by score, and
|
||||||
|
report time = max(embed_time, search_time) (i.e., no blocking on updates)
|
||||||
|
|
||||||
|
This script reuses the model/data loading conventions of
|
||||||
|
examples/bench_hnsw_rng_recompute.py but focuses on end-to-end latency
|
||||||
|
comparison for the two execution strategies above.
|
||||||
|
|
||||||
|
Example (from the repository root):
|
||||||
|
uv run -m benchmarks.update.bench_update_vs_offline_search \
|
||||||
|
--index-path .leann/bench/offline_vs_update.leann \
|
||||||
|
--max-initial 300 --num-updates 5 --k 10
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import psutil # type: ignore
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
|
||||||
|
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
|
||||||
|
|
||||||
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
from leann.embedding_server_manager import EmbeddingServerManager
|
||||||
|
from leann.registry import register_project_directory
|
||||||
|
from leann_backend_hnsw import faiss # type: ignore
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
if not logging.getLogger().handlers:
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_repo_root() -> Path:
|
||||||
|
"""Locate project root by walking up until pyproject.toml is found."""
|
||||||
|
current = Path(__file__).resolve()
|
||||||
|
for parent in current.parents:
|
||||||
|
if (parent / "pyproject.toml").exists():
|
||||||
|
return parent
|
||||||
|
# Fallback: assume repo is two levels up (../..)
|
||||||
|
return current.parents[2]
|
||||||
|
|
||||||
|
|
||||||
|
REPO_ROOT = _find_repo_root()
|
||||||
|
if str(REPO_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(REPO_ROOT))
|
||||||
|
|
||||||
|
from apps.chunking import create_text_chunks # noqa: E402
|
||||||
|
|
||||||
|
DEFAULT_INITIAL_FILES = [
|
||||||
|
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
||||||
|
REPO_ROOT / "data" / "huawei_pangu.md",
|
||||||
|
]
|
||||||
|
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||||
|
|
||||||
|
|
||||||
|
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
documents = []
|
||||||
|
for path in paths:
|
||||||
|
p = path.expanduser().resolve()
|
||||||
|
if not p.exists():
|
||||||
|
raise FileNotFoundError(f"Input path not found: {p}")
|
||||||
|
if p.is_dir():
|
||||||
|
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
else:
|
||||||
|
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
return []
|
||||||
|
|
||||||
|
chunks = create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size=512,
|
||||||
|
chunk_overlap=128,
|
||||||
|
use_ast_chunking=False,
|
||||||
|
)
|
||||||
|
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||||
|
if limit is not None:
|
||||||
|
cleaned = cleaned[:limit]
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_index_dir(index_path: Path) -> None:
|
||||||
|
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_index_files(index_path: Path) -> None:
|
||||||
|
parent = index_path.parent
|
||||||
|
if not parent.exists():
|
||||||
|
return
|
||||||
|
stem = index_path.stem
|
||||||
|
for file in parent.glob(f"{stem}*"):
|
||||||
|
if file.is_file():
|
||||||
|
file.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def build_initial_index(
|
||||||
|
index_path: Path,
|
||||||
|
paragraphs: list[str],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
distance_metric: str,
|
||||||
|
ef_construction: int,
|
||||||
|
) -> None:
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
is_compact=False,
|
||||||
|
is_recompute=True,
|
||||||
|
distance_metric=distance_metric,
|
||||||
|
backend_kwargs={
|
||||||
|
"distance_metric": distance_metric,
|
||||||
|
"is_compact": False,
|
||||||
|
"is_recompute": True,
|
||||||
|
"efConstruction": ef_construction,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for idx, passage in enumerate(paragraphs):
|
||||||
|
builder.add_text(passage, metadata={"id": str(idx)})
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_norm_cosine(vecs: np.ndarray, metric: str) -> np.ndarray:
|
||||||
|
if metric == "cosine":
|
||||||
|
vecs = np.ascontiguousarray(vecs, dtype=np.float32)
|
||||||
|
norms = np.linalg.norm(vecs, axis=1, keepdims=True)
|
||||||
|
norms[norms == 0] = 1
|
||||||
|
vecs = vecs / norms
|
||||||
|
return vecs
|
||||||
|
|
||||||
|
|
||||||
|
def _read_index_for_search(index_path: Path) -> Any:
|
||||||
|
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||||
|
# Force-disable experimental disk cache when loading the index so that
|
||||||
|
# incremental benchmarks don't pick up stale top-degree bitmaps.
|
||||||
|
cfg = faiss.HNSWIndexConfig()
|
||||||
|
cfg.is_recompute = True
|
||||||
|
if hasattr(cfg, "disk_cache_ratio"):
|
||||||
|
cfg.disk_cache_ratio = 0.0
|
||||||
|
if hasattr(cfg, "external_storage_path"):
|
||||||
|
cfg.external_storage_path = None
|
||||||
|
io_flags = getattr(faiss, "IO_FLAG_MMAP", 0)
|
||||||
|
index = faiss.read_index(str(index_file), io_flags, cfg)
|
||||||
|
# ensure recompute mode persists after reload
|
||||||
|
try:
|
||||||
|
index.is_recompute = True
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
actual_ntotal = index.hnsw.levels.size()
|
||||||
|
except AttributeError:
|
||||||
|
actual_ntotal = index.ntotal
|
||||||
|
if actual_ntotal != index.ntotal:
|
||||||
|
print(
|
||||||
|
f"[bench_update_vs_offline_search] Correcting ntotal from {index.ntotal} to {actual_ntotal}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
index.ntotal = actual_ntotal
|
||||||
|
if getattr(index, "storage", None) is None:
|
||||||
|
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||||
|
storage_index = faiss.IndexFlatIP(index.d)
|
||||||
|
else:
|
||||||
|
storage_index = faiss.IndexFlatL2(index.d)
|
||||||
|
index.storage = storage_index
|
||||||
|
index.own_fields = True
|
||||||
|
return index
|
||||||
|
|
||||||
|
|
||||||
|
def _append_passages_for_updates(
|
||||||
|
meta_path: Path,
|
||||||
|
start_id: int,
|
||||||
|
texts: list[str],
|
||||||
|
) -> list[str]:
|
||||||
|
"""Append update passages so the embedding server can serve recompute fetches."""
|
||||||
|
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
index_dir = meta_path.parent
|
||||||
|
meta_name = meta_path.name
|
||||||
|
if not meta_name.endswith(".meta.json"):
|
||||||
|
raise ValueError(f"Unexpected meta filename: {meta_path}")
|
||||||
|
index_base = meta_name[: -len(".meta.json")]
|
||||||
|
|
||||||
|
passages_file = index_dir / f"{index_base}.passages.jsonl"
|
||||||
|
offsets_file = index_dir / f"{index_base}.passages.idx"
|
||||||
|
|
||||||
|
if not passages_file.exists() or not offsets_file.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
"Passage store missing; cannot register update passages for recompute mode."
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(offsets_file, "rb") as f:
|
||||||
|
offset_map: dict[str, int] = pickle.load(f)
|
||||||
|
|
||||||
|
assigned_ids: list[str] = []
|
||||||
|
with open(passages_file, "a", encoding="utf-8") as f:
|
||||||
|
for i, text in enumerate(texts):
|
||||||
|
passage_id = str(start_id + i)
|
||||||
|
offset = f.tell()
|
||||||
|
json.dump({"id": passage_id, "text": text, "metadata": {}}, f, ensure_ascii=False)
|
||||||
|
f.write("\n")
|
||||||
|
offset_map[passage_id] = offset
|
||||||
|
assigned_ids.append(passage_id)
|
||||||
|
|
||||||
|
with open(offsets_file, "wb") as f:
|
||||||
|
pickle.dump(offset_map, f)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
meta = {}
|
||||||
|
meta["total_passages"] = len(offset_map)
|
||||||
|
with open(meta_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(meta, f, indent=2)
|
||||||
|
|
||||||
|
return assigned_ids
|
||||||
|
|
||||||
|
|
||||||
|
def _search(index: Any, q: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
q = np.ascontiguousarray(q, dtype=np.float32)
|
||||||
|
distances = np.zeros((1, k), dtype=np.float32)
|
||||||
|
indices = np.zeros((1, k), dtype=np.int64)
|
||||||
|
index.search(
|
||||||
|
1,
|
||||||
|
faiss.swig_ptr(q),
|
||||||
|
k,
|
||||||
|
faiss.swig_ptr(distances),
|
||||||
|
faiss.swig_ptr(indices),
|
||||||
|
)
|
||||||
|
return distances[0], indices[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _score_for_metric(dist: float, metric: str) -> float:
|
||||||
|
# Convert FAISS distance to a "higher is better" score
|
||||||
|
if metric in ("mips", "cosine"):
|
||||||
|
return float(dist)
|
||||||
|
# l2 distance (smaller better) -> negative distance as score
|
||||||
|
return -float(dist)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_results(
|
||||||
|
index_results: tuple[np.ndarray, np.ndarray],
|
||||||
|
offline_scores: list[tuple[int, float]],
|
||||||
|
k: int,
|
||||||
|
metric: str,
|
||||||
|
) -> list[tuple[str, float]]:
|
||||||
|
distances, indices = index_results
|
||||||
|
merged: list[tuple[str, float]] = []
|
||||||
|
for distance, idx in zip(distances.tolist(), indices.tolist()):
|
||||||
|
merged.append((f"idx:{idx}", _score_for_metric(distance, metric)))
|
||||||
|
for j, s in offline_scores:
|
||||||
|
merged.append((f"offline:{j}", s))
|
||||||
|
merged.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
return merged[:k]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ScenarioResult:
|
||||||
|
name: str
|
||||||
|
update_total_s: float
|
||||||
|
search_s: float
|
||||||
|
overall_s: float
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path(".leann/bench/offline-vs-update.leann"),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--initial-files",
|
||||||
|
nargs="*",
|
||||||
|
type=Path,
|
||||||
|
default=DEFAULT_INITIAL_FILES,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--update-files",
|
||||||
|
nargs="*",
|
||||||
|
type=Path,
|
||||||
|
default=DEFAULT_UPDATE_FILES,
|
||||||
|
)
|
||||||
|
parser.add_argument("--max-initial", type=int, default=300)
|
||||||
|
parser.add_argument("--num-updates", type=int, default=5)
|
||||||
|
parser.add_argument("--k", type=int, default=10, help="Top-k for search/merge")
|
||||||
|
parser.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default="neural network",
|
||||||
|
help="Query text used for the search benchmark.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--server-port", type=int, default=5557)
|
||||||
|
parser.add_argument("--add-timeout", type=int, default=600)
|
||||||
|
parser.add_argument("--model-name", default="sentence-transformers/all-MiniLM-L6-v2")
|
||||||
|
parser.add_argument("--embedding-mode", default="sentence-transformers")
|
||||||
|
parser.add_argument(
|
||||||
|
"--distance-metric",
|
||||||
|
default="mips",
|
||||||
|
choices=["mips", "l2", "cosine"],
|
||||||
|
)
|
||||||
|
parser.add_argument("--ef-construction", type=int, default=200)
|
||||||
|
parser.add_argument(
|
||||||
|
"--only",
|
||||||
|
choices=["A", "B", "both"],
|
||||||
|
default="both",
|
||||||
|
help="Run only Scenario A, Scenario B, or both",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--csv-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path("benchmarks/update/offline_vs_update.csv"),
|
||||||
|
help="Where to append results (CSV).",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
register_project_directory(REPO_ROOT)
|
||||||
|
|
||||||
|
# Load data
|
||||||
|
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
|
||||||
|
update_paragraphs = load_chunks_from_files(args.update_files, None)
|
||||||
|
if not update_paragraphs:
|
||||||
|
raise ValueError("No update passages loaded from --update-files")
|
||||||
|
update_paragraphs = update_paragraphs[: args.num_updates]
|
||||||
|
if len(update_paragraphs) < args.num_updates:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough update passages ({len(update_paragraphs)}) for --num-updates={args.num_updates}"
|
||||||
|
)
|
||||||
|
|
||||||
|
ensure_index_dir(args.index_path)
|
||||||
|
cleanup_index_files(args.index_path)
|
||||||
|
|
||||||
|
# Build initial index
|
||||||
|
build_initial_index(
|
||||||
|
args.index_path,
|
||||||
|
initial_paragraphs,
|
||||||
|
args.model_name,
|
||||||
|
args.embedding_mode,
|
||||||
|
args.distance_metric,
|
||||||
|
args.ef_construction,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare index object and meta
|
||||||
|
meta_path = args.index_path.parent / f"{args.index_path.name}.meta.json"
|
||||||
|
index = _read_index_for_search(args.index_path)
|
||||||
|
|
||||||
|
# CSV setup
|
||||||
|
run_id = time.strftime("%Y%m%d-%H%M%S")
|
||||||
|
if args.csv_path:
|
||||||
|
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
csv_fields = [
|
||||||
|
"run_id",
|
||||||
|
"scenario",
|
||||||
|
"max_initial",
|
||||||
|
"num_updates",
|
||||||
|
"k",
|
||||||
|
"total_time_s",
|
||||||
|
"add_total_s",
|
||||||
|
"search_time_s",
|
||||||
|
"emb_time_s",
|
||||||
|
"makespan_s",
|
||||||
|
"model_name",
|
||||||
|
"embedding_mode",
|
||||||
|
"distance_metric",
|
||||||
|
]
|
||||||
|
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
|
||||||
|
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||||
|
writer.writeheader()
|
||||||
|
|
||||||
|
# Debug: list existing HNSW server PIDs before starting
|
||||||
|
try:
|
||||||
|
existing = [
|
||||||
|
p
|
||||||
|
for p in psutil.process_iter(attrs=["pid", "cmdline"])
|
||||||
|
if any(
|
||||||
|
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
|
||||||
|
for arg in (p.info.get("cmdline") or [])
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if existing:
|
||||||
|
print("[debug] Found existing hnsw_embedding_server processes before run:")
|
||||||
|
for p in existing:
|
||||||
|
print(f"[debug] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}")
|
||||||
|
except Exception as _e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
add_total = 0.0
|
||||||
|
search_after_add = 0.0
|
||||||
|
total_seq = 0.0
|
||||||
|
port_a = None
|
||||||
|
if args.only in ("A", "both"):
|
||||||
|
# Scenario A: sequential update then search
|
||||||
|
start_id = index.ntotal
|
||||||
|
assigned_ids = _append_passages_for_updates(meta_path, start_id, update_paragraphs)
|
||||||
|
if assigned_ids:
|
||||||
|
logger.debug(
|
||||||
|
"Registered %d update passages starting at id %s",
|
||||||
|
len(assigned_ids),
|
||||||
|
assigned_ids[0],
|
||||||
|
)
|
||||||
|
server_manager = EmbeddingServerManager(
|
||||||
|
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||||
|
)
|
||||||
|
ok, port = server_manager.start_server(
|
||||||
|
port=args.server_port,
|
||||||
|
model_name=args.model_name,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
passages_file=str(meta_path),
|
||||||
|
distance_metric=args.distance_metric,
|
||||||
|
)
|
||||||
|
if not ok:
|
||||||
|
raise RuntimeError("Failed to start embedding server")
|
||||||
|
try:
|
||||||
|
# Set ZMQ port for recompute mode
|
||||||
|
if hasattr(index.hnsw, "set_zmq_port"):
|
||||||
|
index.hnsw.set_zmq_port(port)
|
||||||
|
elif hasattr(index, "set_zmq_port"):
|
||||||
|
index.set_zmq_port(port)
|
||||||
|
|
||||||
|
# Start A overall timer BEFORE computing update embeddings
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
# Compute embeddings for updates (counted into A's overall)
|
||||||
|
t_emb0 = time.time()
|
||||||
|
upd_embs = compute_embeddings(
|
||||||
|
update_paragraphs,
|
||||||
|
args.model_name,
|
||||||
|
mode=args.embedding_mode,
|
||||||
|
is_build=False,
|
||||||
|
batch_size=16,
|
||||||
|
)
|
||||||
|
emb_time_updates = time.time() - t_emb0
|
||||||
|
upd_embs = np.asarray(upd_embs, dtype=np.float32)
|
||||||
|
upd_embs = _maybe_norm_cosine(upd_embs, args.distance_metric)
|
||||||
|
|
||||||
|
# Perform sequential adds
|
||||||
|
for i in range(upd_embs.shape[0]):
|
||||||
|
t_add0 = time.time()
|
||||||
|
index.add(1, faiss.swig_ptr(upd_embs[i : i + 1]))
|
||||||
|
add_total += time.time() - t_add0
|
||||||
|
# Don't persist index after adds to avoid contaminating Scenario B
|
||||||
|
# index_file = args.index_path.parent / f"{args.index_path.stem}.index"
|
||||||
|
# faiss.write_index(index, str(index_file))
|
||||||
|
|
||||||
|
# Search after updates
|
||||||
|
q_emb = compute_embeddings(
|
||||||
|
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
|
||||||
|
)
|
||||||
|
q_emb = np.asarray(q_emb, dtype=np.float32)
|
||||||
|
q_emb = _maybe_norm_cosine(q_emb, args.distance_metric)
|
||||||
|
|
||||||
|
# Warm up search with a dummy query first
|
||||||
|
print("[DEBUG] Warming up search...")
|
||||||
|
_ = _search(index, q_emb, 1)
|
||||||
|
|
||||||
|
t_s0 = time.time()
|
||||||
|
D_upd, I_upd = _search(index, q_emb, args.k)
|
||||||
|
search_after_add = time.time() - t_s0
|
||||||
|
total_seq = time.time() - t0
|
||||||
|
finally:
|
||||||
|
server_manager.stop_server()
|
||||||
|
port_a = port
|
||||||
|
|
||||||
|
print("\n=== Scenario A: update->search (sequential) ===")
|
||||||
|
# emb_time_updates is defined only when A runs
|
||||||
|
try:
|
||||||
|
_emb_a = emb_time_updates
|
||||||
|
except NameError:
|
||||||
|
_emb_a = 0.0
|
||||||
|
print(
|
||||||
|
f"Adds: {args.num_updates} passages; embeds={_emb_a:.3f}s; add_total={add_total:.3f}s; "
|
||||||
|
f"search={search_after_add:.3f}s; overall={total_seq:.3f}s"
|
||||||
|
)
|
||||||
|
# CSV row for A
|
||||||
|
if args.csv_path:
|
||||||
|
row_a = {
|
||||||
|
"run_id": run_id,
|
||||||
|
"scenario": "A",
|
||||||
|
"max_initial": args.max_initial,
|
||||||
|
"num_updates": args.num_updates,
|
||||||
|
"k": args.k,
|
||||||
|
"total_time_s": round(total_seq, 6),
|
||||||
|
"add_total_s": round(add_total, 6),
|
||||||
|
"search_time_s": round(search_after_add, 6),
|
||||||
|
"emb_time_s": round(_emb_a, 6),
|
||||||
|
"makespan_s": 0.0,
|
||||||
|
"model_name": args.model_name,
|
||||||
|
"embedding_mode": args.embedding_mode,
|
||||||
|
"distance_metric": args.distance_metric,
|
||||||
|
}
|
||||||
|
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||||
|
writer.writerow(row_a)
|
||||||
|
|
||||||
|
# Verify server cleanup
|
||||||
|
try:
|
||||||
|
# short sleep to allow signal handling to finish
|
||||||
|
time.sleep(0.5)
|
||||||
|
leftovers = [
|
||||||
|
p
|
||||||
|
for p in psutil.process_iter(attrs=["pid", "cmdline"])
|
||||||
|
if any(
|
||||||
|
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
|
||||||
|
for arg in (p.info.get("cmdline") or [])
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if leftovers:
|
||||||
|
print("[warn] hnsw_embedding_server process(es) still alive after A-stop:")
|
||||||
|
for p in leftovers:
|
||||||
|
print(
|
||||||
|
f"[warn] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("[debug] server cleanup confirmed: no hnsw_embedding_server found")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Scenario B: offline embeds + concurrent search (no graph updates)
|
||||||
|
if args.only in ("B", "both"):
|
||||||
|
# ensure a server is available for recompute search
|
||||||
|
server_manager_b = EmbeddingServerManager(
|
||||||
|
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||||
|
)
|
||||||
|
requested_port = args.server_port if port_a is None else port_a
|
||||||
|
ok_b, port_b = server_manager_b.start_server(
|
||||||
|
port=requested_port,
|
||||||
|
model_name=args.model_name,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
passages_file=str(meta_path),
|
||||||
|
distance_metric=args.distance_metric,
|
||||||
|
)
|
||||||
|
if not ok_b:
|
||||||
|
raise RuntimeError("Failed to start embedding server for Scenario B")
|
||||||
|
|
||||||
|
# Wait for server to fully initialize
|
||||||
|
print("[DEBUG] Waiting 2s for embedding server to fully initialize...")
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Read the index first
|
||||||
|
index_no_update = _read_index_for_search(args.index_path) # unchanged index
|
||||||
|
|
||||||
|
# Then configure ZMQ port on the correct index object
|
||||||
|
if hasattr(index_no_update.hnsw, "set_zmq_port"):
|
||||||
|
index_no_update.hnsw.set_zmq_port(port_b)
|
||||||
|
elif hasattr(index_no_update, "set_zmq_port"):
|
||||||
|
index_no_update.set_zmq_port(port_b)
|
||||||
|
|
||||||
|
# Warmup the embedding model before benchmarking (do this for both --only B and --only both)
|
||||||
|
# This ensures fair comparison as Scenario A has warmed up the model during update embeddings
|
||||||
|
logger.info("Warming up embedding model for Scenario B...")
|
||||||
|
_ = compute_embeddings(
|
||||||
|
["warmup text"], args.model_name, mode=args.embedding_mode, is_build=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare worker A: compute embeddings for the same N passages
|
||||||
|
emb_time = 0.0
|
||||||
|
updates_embs_offline: np.ndarray | None = None
|
||||||
|
|
||||||
|
def _worker_emb():
|
||||||
|
nonlocal emb_time, updates_embs_offline
|
||||||
|
t = time.time()
|
||||||
|
updates_embs_offline = compute_embeddings(
|
||||||
|
update_paragraphs,
|
||||||
|
args.model_name,
|
||||||
|
mode=args.embedding_mode,
|
||||||
|
is_build=False,
|
||||||
|
batch_size=16,
|
||||||
|
)
|
||||||
|
emb_time = time.time() - t
|
||||||
|
|
||||||
|
# Pre-compute query embedding and warm up search outside of timed section.
|
||||||
|
q_vec = compute_embeddings(
|
||||||
|
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
|
||||||
|
)
|
||||||
|
q_vec = np.asarray(q_vec, dtype=np.float32)
|
||||||
|
q_vec = _maybe_norm_cosine(q_vec, args.distance_metric)
|
||||||
|
print("[DEBUG B] Warming up search...")
|
||||||
|
_ = _search(index_no_update, q_vec, 1)
|
||||||
|
|
||||||
|
# Worker B: timed search on the warmed index
|
||||||
|
search_time = 0.0
|
||||||
|
offline_elapsed = 0.0
|
||||||
|
index_results: tuple[np.ndarray, np.ndarray] | None = None
|
||||||
|
|
||||||
|
def _worker_search():
|
||||||
|
nonlocal search_time, index_results
|
||||||
|
t = time.time()
|
||||||
|
distances, indices = _search(index_no_update, q_vec, args.k)
|
||||||
|
search_time = time.time() - t
|
||||||
|
index_results = (distances, indices)
|
||||||
|
|
||||||
|
# Run two workers concurrently
|
||||||
|
t0 = time.time()
|
||||||
|
th1 = threading.Thread(target=_worker_emb)
|
||||||
|
th2 = threading.Thread(target=_worker_search)
|
||||||
|
th1.start()
|
||||||
|
th2.start()
|
||||||
|
th1.join()
|
||||||
|
th2.join()
|
||||||
|
offline_elapsed = time.time() - t0
|
||||||
|
|
||||||
|
# For mixing: compute query vs. offline update similarities (pure client-side)
|
||||||
|
offline_scores: list[tuple[int, float]] = []
|
||||||
|
if updates_embs_offline is not None:
|
||||||
|
upd2 = np.asarray(updates_embs_offline, dtype=np.float32)
|
||||||
|
upd2 = _maybe_norm_cosine(upd2, args.distance_metric)
|
||||||
|
# For mips/cosine, score = dot; for l2, score = -||x-y||^2
|
||||||
|
for j in range(upd2.shape[0]):
|
||||||
|
if args.distance_metric in ("mips", "cosine"):
|
||||||
|
s = float(np.dot(q_vec[0], upd2[j]))
|
||||||
|
else:
|
||||||
|
diff = q_vec[0] - upd2[j]
|
||||||
|
s = -float(np.dot(diff, diff))
|
||||||
|
offline_scores.append((j, s))
|
||||||
|
|
||||||
|
merged_topk = (
|
||||||
|
_merge_results(index_results, offline_scores, args.k, args.distance_metric)
|
||||||
|
if index_results
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n=== Scenario B: offline embeds + concurrent search (no add) ===")
|
||||||
|
print(
|
||||||
|
f"embeddings({args.num_updates})={emb_time:.3f}s; search={search_time:.3f}s; makespan≈{offline_elapsed:.3f}s (≈max)"
|
||||||
|
)
|
||||||
|
if merged_topk:
|
||||||
|
preview = ", ".join([f"{lab}:{score:.3f}" for lab, score in merged_topk[:5]])
|
||||||
|
print(f"Merged top-5 preview: {preview}")
|
||||||
|
# CSV row for B
|
||||||
|
if args.csv_path:
|
||||||
|
row_b = {
|
||||||
|
"run_id": run_id,
|
||||||
|
"scenario": "B",
|
||||||
|
"max_initial": args.max_initial,
|
||||||
|
"num_updates": args.num_updates,
|
||||||
|
"k": args.k,
|
||||||
|
"total_time_s": 0.0,
|
||||||
|
"add_total_s": 0.0,
|
||||||
|
"search_time_s": round(search_time, 6),
|
||||||
|
"emb_time_s": round(emb_time, 6),
|
||||||
|
"makespan_s": round(offline_elapsed, 6),
|
||||||
|
"model_name": args.model_name,
|
||||||
|
"embedding_mode": args.embedding_mode,
|
||||||
|
"distance_metric": args.distance_metric,
|
||||||
|
}
|
||||||
|
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||||
|
writer.writerow(row_b)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
server_manager_b.stop_server()
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
print("\n=== Summary ===")
|
||||||
|
msg_a = (
|
||||||
|
f"A: seq-add+search overall={total_seq:.3f}s (adds={add_total:.3f}s, search={search_after_add:.3f}s)"
|
||||||
|
if args.only in ("A", "both")
|
||||||
|
else "A: skipped"
|
||||||
|
)
|
||||||
|
msg_b = (
|
||||||
|
f"B: offline+concurrent overall≈{offline_elapsed:.3f}s (emb={emb_time:.3f}s, search={search_time:.3f}s)"
|
||||||
|
if args.only in ("B", "both")
|
||||||
|
else "B: skipped"
|
||||||
|
)
|
||||||
|
print(msg_a + "\n" + msg_b)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
5
benchmarks/update/offline_vs_update.csv
Normal file
5
benchmarks/update/offline_vs_update.csv
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
run_id,scenario,max_initial,num_updates,k,total_time_s,add_total_s,search_time_s,emb_time_s,makespan_s,model_name,embedding_mode,distance_metric
|
||||||
|
20251024-141607,A,300,1,10,3.273957,3.050168,0.097825,0.017339,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||||
|
20251024-141607,B,300,1,10,0.0,0.0,0.111892,0.007869,0.112635,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||||
|
20251025-160652,A,300,5,10,5.061945,4.805962,0.123271,0.015008,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||||
|
20251025-160652,B,300,5,10,0.0,0.0,0.101809,0.008817,0.102447,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||||
|
645
benchmarks/update/plot_bench_results.py
Normal file
645
benchmarks/update/plot_bench_results.py
Normal file
@@ -0,0 +1,645 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Plot latency bars from the benchmark CSV produced by
|
||||||
|
benchmarks/update/bench_hnsw_rng_recompute.py.
|
||||||
|
|
||||||
|
If you also provide an offline_vs_update.csv via --csv-right
|
||||||
|
(from benchmarks/update/bench_update_vs_offline_search.py), this script will
|
||||||
|
output a side-by-side figure:
|
||||||
|
- Left: ms/passage bars (four RNG scenarios).
|
||||||
|
- Right: seconds bars (Scenario A seq add+search vs Scenario B offline+search).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
uv run python benchmarks/update/plot_bench_results.py \
|
||||||
|
--csv benchmarks/update/bench_results.csv \
|
||||||
|
--out benchmarks/update/bench_latency_from_csv.png
|
||||||
|
|
||||||
|
The script selects the latest run_id in the CSV and plots four bars for
|
||||||
|
the default scenarios:
|
||||||
|
- baseline
|
||||||
|
- no_cache_baseline
|
||||||
|
- disable_forward_rng
|
||||||
|
- disable_forward_and_reverse_rng
|
||||||
|
|
||||||
|
If multiple rows exist per scenario for that run_id, the script averages
|
||||||
|
their latency_ms_per_passage values.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
DEFAULT_SCENARIOS = [
|
||||||
|
"no_cache_baseline",
|
||||||
|
"baseline",
|
||||||
|
"disable_forward_rng",
|
||||||
|
"disable_forward_and_reverse_rng",
|
||||||
|
]
|
||||||
|
|
||||||
|
SCENARIO_LABELS = {
|
||||||
|
"baseline": "+ Cache",
|
||||||
|
"no_cache_baseline": "Naive \n Recompute",
|
||||||
|
"disable_forward_rng": "+ w/o \n Fwd RNG",
|
||||||
|
"disable_forward_and_reverse_rng": "+ w/o \n Bwd RNG",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Paper-style colors and hatches for scenarios
|
||||||
|
SCENARIO_STYLES = {
|
||||||
|
"no_cache_baseline": {"edgecolor": "dimgrey", "hatch": "/////"},
|
||||||
|
"baseline": {"edgecolor": "#63B8B6", "hatch": "xxxxx"},
|
||||||
|
"disable_forward_rng": {"edgecolor": "green", "hatch": "....."},
|
||||||
|
"disable_forward_and_reverse_rng": {"edgecolor": "tomato", "hatch": "\\\\\\\\\\"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_latest_run(csv_path: Path):
|
||||||
|
rows = []
|
||||||
|
with csv_path.open("r", encoding="utf-8") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for row in reader:
|
||||||
|
rows.append(row)
|
||||||
|
if not rows:
|
||||||
|
raise SystemExit("CSV is empty: no rows to plot")
|
||||||
|
# Choose latest run_id lexicographically (YYYYMMDD-HHMMSS)
|
||||||
|
run_ids = [r.get("run_id", "") for r in rows]
|
||||||
|
latest = max(run_ids)
|
||||||
|
latest_rows = [r for r in rows if r.get("run_id", "") == latest]
|
||||||
|
if not latest_rows:
|
||||||
|
# Fallback: take last 4 rows
|
||||||
|
latest_rows = rows[-4:]
|
||||||
|
latest = latest_rows[-1].get("run_id", "unknown")
|
||||||
|
return latest, latest_rows
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_latency(rows):
|
||||||
|
acc = defaultdict(list)
|
||||||
|
for r in rows:
|
||||||
|
sc = r.get("scenario", "")
|
||||||
|
try:
|
||||||
|
val = float(r.get("latency_ms_per_passage", "nan"))
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
acc[sc].append(val)
|
||||||
|
avg = {k: (sum(v) / len(v) if v else 0.0) for k, v in acc.items()}
|
||||||
|
return avg
|
||||||
|
|
||||||
|
|
||||||
|
def _auto_cap(values: list[float]) -> float | None:
|
||||||
|
if not values:
|
||||||
|
return None
|
||||||
|
sorted_vals = sorted(values, reverse=True)
|
||||||
|
if len(sorted_vals) < 2:
|
||||||
|
return None
|
||||||
|
max_v, second = sorted_vals[0], sorted_vals[1]
|
||||||
|
if second <= 0:
|
||||||
|
return None
|
||||||
|
# If the tallest bar dwarfs the second by 2.5x+, cap near the second
|
||||||
|
if max_v >= 2.5 * second:
|
||||||
|
return second * 1.1
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _add_break_marker(ax, y, rel_x0=0.02, rel_x1=0.98, size=0.02):
|
||||||
|
# Draw small diagonal ticks near left/right to signal cap
|
||||||
|
x0, x1 = rel_x0, rel_x1
|
||||||
|
ax.plot([x0 - size, x0 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
|
||||||
|
ax.plot([x1 - size, x1 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
|
||||||
|
|
||||||
|
|
||||||
|
def _fmt_ms(v: float) -> str:
|
||||||
|
if v >= 1000:
|
||||||
|
return f"{v / 1000:.1f}k"
|
||||||
|
return f"{v:.1f}"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Set LaTeX style for paper figures (matching paper_fig.py)
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
plt.rcParams["font.family"] = "Helvetica"
|
||||||
|
plt.rcParams["ytick.direction"] = "in"
|
||||||
|
plt.rcParams["hatch.linewidth"] = 1.5
|
||||||
|
plt.rcParams["font.weight"] = "bold"
|
||||||
|
plt.rcParams["axes.labelweight"] = "bold"
|
||||||
|
plt.rcParams["text.usetex"] = True
|
||||||
|
|
||||||
|
ap = argparse.ArgumentParser(description=__doc__)
|
||||||
|
ap.add_argument(
|
||||||
|
"--csv",
|
||||||
|
type=Path,
|
||||||
|
default=Path("benchmarks/update/bench_results.csv"),
|
||||||
|
help="Path to results CSV (defaults to bench_results.csv)",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--out",
|
||||||
|
type=Path,
|
||||||
|
default=Path("add_ablation.pdf"),
|
||||||
|
help="Output image path",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--csv-right",
|
||||||
|
type=Path,
|
||||||
|
default=Path("benchmarks/update/offline_vs_update.csv"),
|
||||||
|
help="Optional: offline_vs_update.csv to render right subplot (A vs B)",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--cap-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Cap Y-axis at this ms value; bars above are hatched and annotated.",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--no-auto-cap",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable auto-cap heuristic when --cap-y is not provided.",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--broken-y",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Use a broken Y-axis (two stacked axes with a gap). Overrides --cap-y unless both provided.",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--lower-cap-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Lower axes upper bound for broken Y (ms). Default = 1.1x second-highest.",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--upper-start-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Upper axes lower bound for broken Y (ms). Default = 1.2x second-highest.",
|
||||||
|
)
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
latest_run, latest_rows = load_latest_run(args.csv)
|
||||||
|
avg = aggregate_latency(latest_rows)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
except Exception as e:
|
||||||
|
raise SystemExit(f"matplotlib not available: {e}")
|
||||||
|
|
||||||
|
scenarios = DEFAULT_SCENARIOS
|
||||||
|
values = [avg.get(name, 0.0) for name in scenarios]
|
||||||
|
labels = [SCENARIO_LABELS.get(name, name) for name in scenarios]
|
||||||
|
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
|
||||||
|
|
||||||
|
# If right CSV is provided, build side-by-side figure
|
||||||
|
if args.csv_right is not None:
|
||||||
|
try:
|
||||||
|
right_rows_all = []
|
||||||
|
with args.csv_right.open("r", encoding="utf-8") as f:
|
||||||
|
rreader = csv.DictReader(f)
|
||||||
|
right_rows_all = list(rreader)
|
||||||
|
if right_rows_all:
|
||||||
|
r_latest = max(r.get("run_id", "") for r in right_rows_all)
|
||||||
|
right_rows = [r for r in right_rows_all if r.get("run_id", "") == r_latest]
|
||||||
|
else:
|
||||||
|
r_latest = None
|
||||||
|
right_rows = []
|
||||||
|
except Exception:
|
||||||
|
r_latest = None
|
||||||
|
right_rows = []
|
||||||
|
|
||||||
|
a_total = 0.0
|
||||||
|
b_makespan = 0.0
|
||||||
|
for r in right_rows:
|
||||||
|
sc = (r.get("scenario", "") or "").strip().upper()
|
||||||
|
if sc == "A":
|
||||||
|
try:
|
||||||
|
a_total = float(r.get("total_time_s", 0.0))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
elif sc == "B":
|
||||||
|
try:
|
||||||
|
b_makespan = float(r.get("makespan_s", 0.0))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib import gridspec
|
||||||
|
|
||||||
|
# Left subplot (reuse current style, with optional cap)
|
||||||
|
cap = args.cap_y
|
||||||
|
if cap is None and not args.no_auto_cap:
|
||||||
|
cap = _auto_cap(values)
|
||||||
|
x = list(range(len(labels)))
|
||||||
|
|
||||||
|
if args.broken_y:
|
||||||
|
# Use broken axis for left subplot
|
||||||
|
# Auto-adjust width ratios: left has 4 bars, right has 2 bars
|
||||||
|
fig = plt.figure(figsize=(4.8, 1.8)) # Scaled down to 80%
|
||||||
|
gs = gridspec.GridSpec(
|
||||||
|
2, 2, height_ratios=[1, 3], width_ratios=[1.5, 1], hspace=0.08, wspace=0.35
|
||||||
|
)
|
||||||
|
ax_left_top = fig.add_subplot(gs[0, 0])
|
||||||
|
ax_left_bottom = fig.add_subplot(gs[1, 0], sharex=ax_left_top)
|
||||||
|
ax_right = fig.add_subplot(gs[:, 1])
|
||||||
|
|
||||||
|
# Determine break points
|
||||||
|
s = sorted(values, reverse=True)
|
||||||
|
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||||
|
lower_cap = (
|
||||||
|
args.lower_cap_y if args.lower_cap_y is not None else second * 1.4
|
||||||
|
) # Increased to show more range
|
||||||
|
upper_start = (
|
||||||
|
args.upper_start_y
|
||||||
|
if args.upper_start_y is not None
|
||||||
|
else max(second * 1.5, lower_cap * 1.02)
|
||||||
|
)
|
||||||
|
ymax = (
|
||||||
|
max(values) * 1.90 if values else 1.0
|
||||||
|
) # Increase headroom to 1.90 for text label and tick range
|
||||||
|
|
||||||
|
# Draw bars on both axes
|
||||||
|
ax_left_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
ax_left_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
|
||||||
|
# Set limits
|
||||||
|
ax_left_bottom.set_ylim(0, lower_cap)
|
||||||
|
ax_left_top.set_ylim(upper_start, ymax)
|
||||||
|
|
||||||
|
# Annotate values (convert ms to s)
|
||||||
|
values_s = [v / 1000.0 for v in values]
|
||||||
|
lower_cap_s = lower_cap / 1000.0
|
||||||
|
upper_start_s = upper_start / 1000.0
|
||||||
|
ymax_s = ymax / 1000.0
|
||||||
|
|
||||||
|
ax_left_bottom.set_ylim(0, lower_cap_s)
|
||||||
|
ax_left_top.set_ylim(upper_start_s, ymax_s)
|
||||||
|
|
||||||
|
# Redraw bars with s values (paper style: white fill + colored edge + hatch)
|
||||||
|
ax_left_bottom.clear()
|
||||||
|
ax_left_top.clear()
|
||||||
|
bar_width = 0.50 # Reduced for wider spacing between bars
|
||||||
|
for i, (scenario_name, v) in enumerate(zip(scenarios, values_s)):
|
||||||
|
style = SCENARIO_STYLES.get(scenario_name, {"edgecolor": "black", "hatch": ""})
|
||||||
|
# Draw in bottom axis for all bars
|
||||||
|
ax_left_bottom.bar(
|
||||||
|
i,
|
||||||
|
v,
|
||||||
|
width=bar_width,
|
||||||
|
color="white",
|
||||||
|
edgecolor=style["edgecolor"],
|
||||||
|
hatch=style["hatch"],
|
||||||
|
linewidth=1.2,
|
||||||
|
)
|
||||||
|
# Only draw in top axis if the bar is tall enough to reach the upper range
|
||||||
|
if v > upper_start_s:
|
||||||
|
ax_left_top.bar(
|
||||||
|
i,
|
||||||
|
v,
|
||||||
|
width=bar_width,
|
||||||
|
color="white",
|
||||||
|
edgecolor=style["edgecolor"],
|
||||||
|
hatch=style["hatch"],
|
||||||
|
linewidth=1.2,
|
||||||
|
)
|
||||||
|
ax_left_bottom.set_ylim(0, lower_cap_s)
|
||||||
|
ax_left_top.set_ylim(upper_start_s, ymax_s)
|
||||||
|
|
||||||
|
for i, v in enumerate(values_s):
|
||||||
|
if v <= lower_cap_s:
|
||||||
|
ax_left_bottom.text(
|
||||||
|
i,
|
||||||
|
v + lower_cap_s * 0.02,
|
||||||
|
f"{v:.2f}",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=8,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ax_left_top.text(
|
||||||
|
i,
|
||||||
|
v + (ymax_s - upper_start_s) * 0.02,
|
||||||
|
f"{v:.2f}",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=8,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Hide spines between axes
|
||||||
|
ax_left_top.spines["bottom"].set_visible(False)
|
||||||
|
ax_left_bottom.spines["top"].set_visible(False)
|
||||||
|
ax_left_top.tick_params(
|
||||||
|
labeltop=False, labelbottom=False, bottom=False
|
||||||
|
) # Hide tick marks
|
||||||
|
ax_left_bottom.xaxis.tick_bottom()
|
||||||
|
ax_left_bottom.tick_params(top=False) # Hide top tick marks
|
||||||
|
|
||||||
|
# Draw break marks (matching paper_fig.py style)
|
||||||
|
d = 0.015
|
||||||
|
kwargs = {
|
||||||
|
"transform": ax_left_top.transAxes,
|
||||||
|
"color": "k",
|
||||||
|
"clip_on": False,
|
||||||
|
"linewidth": 0.8,
|
||||||
|
"zorder": 10,
|
||||||
|
}
|
||||||
|
ax_left_top.plot((-d, +d), (-d, +d), **kwargs)
|
||||||
|
ax_left_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
|
||||||
|
kwargs.update({"transform": ax_left_bottom.transAxes})
|
||||||
|
ax_left_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
|
||||||
|
ax_left_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
|
||||||
|
|
||||||
|
ax_left_bottom.set_xticks(x)
|
||||||
|
ax_left_bottom.set_xticklabels(labels, rotation=0, fontsize=7)
|
||||||
|
# Don't set ylabel here - will use fig.text for alignment
|
||||||
|
ax_left_bottom.tick_params(axis="y", labelsize=10)
|
||||||
|
ax_left_top.tick_params(axis="y", labelsize=10)
|
||||||
|
# Add subtle grid for better readability
|
||||||
|
ax_left_bottom.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||||
|
ax_left_top.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||||
|
ax_left_top.set_title("Single Add Operation", fontsize=11, pad=10, fontweight="bold")
|
||||||
|
|
||||||
|
# Set x-axis limits to match bar width with right subplot
|
||||||
|
ax_left_bottom.set_xlim(-0.6, 3.6)
|
||||||
|
ax_left_top.set_xlim(-0.6, 3.6)
|
||||||
|
|
||||||
|
ax_left = ax_left_bottom # for compatibility
|
||||||
|
else:
|
||||||
|
# Regular side-by-side layout
|
||||||
|
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8.4, 3.15))
|
||||||
|
|
||||||
|
if cap is not None:
|
||||||
|
show_vals = [min(v, cap) for v in values]
|
||||||
|
bars = ax_left.bar(x, show_vals, color=colors[: len(labels)], width=0.8)
|
||||||
|
for i, (val, show) in enumerate(zip(values, show_vals)):
|
||||||
|
if val > cap:
|
||||||
|
bars[i].set_hatch("//")
|
||||||
|
ax_left.text(
|
||||||
|
i, cap * 1.02, _fmt_ms(val), ha="center", va="bottom", fontsize=9
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ax_left.text(
|
||||||
|
i,
|
||||||
|
show + max(1.0, 0.01 * (cap or show)),
|
||||||
|
_fmt_ms(val),
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=9,
|
||||||
|
)
|
||||||
|
ax_left.set_ylim(0, cap * 1.10)
|
||||||
|
_add_break_marker(ax_left, y=0.98)
|
||||||
|
ax_left.set_xticks(x)
|
||||||
|
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||||
|
else:
|
||||||
|
ax_left.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
for i, v in enumerate(values):
|
||||||
|
ax_left.text(i, v + 1.0, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||||
|
ax_left.set_xticks(x)
|
||||||
|
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||||
|
ax_left.set_ylabel("Latency (ms per passage)")
|
||||||
|
max_initial = latest_rows[0].get("max_initial", "?")
|
||||||
|
max_updates = latest_rows[0].get("max_updates", "?")
|
||||||
|
ax_left.set_title(
|
||||||
|
f"HNSW RNG (run {latest_run}) | init={max_initial}, upd={max_updates}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Right subplot (A vs B, seconds) - paper style
|
||||||
|
r_labels = ["Sequential", "Delayed \n Add+Search"]
|
||||||
|
r_values = [a_total or 0.0, b_makespan or 0.0]
|
||||||
|
r_styles = [
|
||||||
|
{"edgecolor": "#59a14f", "hatch": "xxxxx"},
|
||||||
|
{"edgecolor": "#edc948", "hatch": "/////"},
|
||||||
|
]
|
||||||
|
# 2 bars, centered with proper spacing
|
||||||
|
xr = [0, 1]
|
||||||
|
bar_width = 0.50 # Reduced for wider spacing between bars
|
||||||
|
for i, (v, style) in enumerate(zip(r_values, r_styles)):
|
||||||
|
ax_right.bar(
|
||||||
|
xr[i],
|
||||||
|
v,
|
||||||
|
width=bar_width,
|
||||||
|
color="white",
|
||||||
|
edgecolor=style["edgecolor"],
|
||||||
|
hatch=style["hatch"],
|
||||||
|
linewidth=1.2,
|
||||||
|
)
|
||||||
|
for i, v in enumerate(r_values):
|
||||||
|
max_v = max(r_values) if r_values else 1.0
|
||||||
|
offset = max(0.0002, 0.02 * max_v)
|
||||||
|
ax_right.text(
|
||||||
|
xr[i],
|
||||||
|
v + offset,
|
||||||
|
f"{v:.2f}",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=8,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
ax_right.set_xticks(xr)
|
||||||
|
ax_right.set_xticklabels(r_labels, rotation=0, fontsize=7)
|
||||||
|
# Don't set ylabel here - will use fig.text for alignment
|
||||||
|
ax_right.tick_params(axis="y", labelsize=10)
|
||||||
|
# Add subtle grid for better readability
|
||||||
|
ax_right.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||||
|
ax_right.set_title("Batched Add Operation", fontsize=11, pad=10, fontweight="bold")
|
||||||
|
|
||||||
|
# Set x-axis limits to match left subplot's bar width visually
|
||||||
|
# Accounting for width_ratios=[1.5, 1]:
|
||||||
|
# Left: 4 bars, xlim(-0.6, 3.6), range=4.2, physical_width=1.5*unit
|
||||||
|
# bar_width_visual = 0.72 * (1.5*unit / 4.2)
|
||||||
|
# Right: 2 bars, need same visual width
|
||||||
|
# 0.72 * (1.0*unit / range_right) = 0.72 * (1.5*unit / 4.2)
|
||||||
|
# range_right = 4.2 / 1.5 = 2.8
|
||||||
|
# For bars at 0, 1: padding = (2.8 - 1) / 2 = 0.9
|
||||||
|
ax_right.set_xlim(-0.9, 1.9)
|
||||||
|
|
||||||
|
# Set y-axis limit with headroom for text labels
|
||||||
|
if r_values:
|
||||||
|
max_v = max(r_values)
|
||||||
|
ax_right.set_ylim(0, max_v * 1.15)
|
||||||
|
|
||||||
|
# Format y-axis to avoid scientific notation
|
||||||
|
ax_right.ticklabel_format(style="plain", axis="y")
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# Add aligned ylabels using fig.text (after tight_layout)
|
||||||
|
# Get the vertical center of the entire figure
|
||||||
|
fig_center_y = 0.5
|
||||||
|
# Left ylabel - closer to left plot
|
||||||
|
left_x = 0.05
|
||||||
|
fig.text(
|
||||||
|
left_x,
|
||||||
|
fig_center_y,
|
||||||
|
"Latency (s)",
|
||||||
|
va="center",
|
||||||
|
rotation="vertical",
|
||||||
|
fontsize=11,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
# Right ylabel - closer to right plot
|
||||||
|
right_bbox = ax_right.get_position()
|
||||||
|
right_x = right_bbox.x0 - 0.07
|
||||||
|
fig.text(
|
||||||
|
right_x,
|
||||||
|
fig_center_y,
|
||||||
|
"Latency (s)",
|
||||||
|
va="center",
|
||||||
|
rotation="vertical",
|
||||||
|
fontsize=11,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
|
||||||
|
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
|
||||||
|
# Also save PDF for paper
|
||||||
|
pdf_out = args.out.with_suffix(".pdf")
|
||||||
|
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
|
||||||
|
print(f"Saved: {args.out}")
|
||||||
|
print(f"Saved: {pdf_out}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Broken-Y mode
|
||||||
|
if args.broken_y:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
fig, (ax_top, ax_bottom) = plt.subplots(
|
||||||
|
2,
|
||||||
|
1,
|
||||||
|
sharex=True,
|
||||||
|
figsize=(7.5, 6.75),
|
||||||
|
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.08},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine default breaks from second-highest
|
||||||
|
s = sorted(values, reverse=True)
|
||||||
|
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||||
|
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
|
||||||
|
upper_start = (
|
||||||
|
args.upper_start_y
|
||||||
|
if args.upper_start_y is not None
|
||||||
|
else max(second * 1.2, lower_cap * 1.02)
|
||||||
|
)
|
||||||
|
ymax = max(values) * 1.10 if values else 1.0
|
||||||
|
|
||||||
|
x = list(range(len(labels)))
|
||||||
|
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
|
||||||
|
# Limits
|
||||||
|
ax_bottom.set_ylim(0, lower_cap)
|
||||||
|
ax_top.set_ylim(upper_start, ymax)
|
||||||
|
|
||||||
|
# Annotate values
|
||||||
|
for i, v in enumerate(values):
|
||||||
|
if v <= lower_cap:
|
||||||
|
ax_bottom.text(
|
||||||
|
i, v + lower_cap * 0.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||||
|
|
||||||
|
# Hide spines between axes and draw diagonal break marks
|
||||||
|
ax_top.spines["bottom"].set_visible(False)
|
||||||
|
ax_bottom.spines["top"].set_visible(False)
|
||||||
|
ax_top.tick_params(labeltop=False) # don't put tick labels at the top
|
||||||
|
ax_bottom.xaxis.tick_bottom()
|
||||||
|
|
||||||
|
# Diagonal lines at the break (matching paper_fig.py style)
|
||||||
|
d = 0.015
|
||||||
|
kwargs = {
|
||||||
|
"transform": ax_top.transAxes,
|
||||||
|
"color": "k",
|
||||||
|
"clip_on": False,
|
||||||
|
"linewidth": 0.8,
|
||||||
|
"zorder": 10,
|
||||||
|
}
|
||||||
|
ax_top.plot((-d, +d), (-d, +d), **kwargs) # top-left diagonal
|
||||||
|
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs) # top-right diagonal
|
||||||
|
kwargs.update({"transform": ax_bottom.transAxes})
|
||||||
|
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs) # bottom-left diagonal
|
||||||
|
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs) # bottom-right diagonal
|
||||||
|
|
||||||
|
ax_bottom.set_xticks(x)
|
||||||
|
ax_bottom.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||||
|
ax = ax_bottom # for labeling below
|
||||||
|
else:
|
||||||
|
cap = args.cap_y
|
||||||
|
if cap is None and not args.no_auto_cap:
|
||||||
|
cap = _auto_cap(values)
|
||||||
|
|
||||||
|
plt.figure(figsize=(5.4, 3.15))
|
||||||
|
ax = plt.gca()
|
||||||
|
|
||||||
|
if cap is not None:
|
||||||
|
show_vals = [min(v, cap) for v in values]
|
||||||
|
bars = []
|
||||||
|
for i, (_label, val, show) in enumerate(zip(labels, values, show_vals)):
|
||||||
|
bar = ax.bar(i, show, color=colors[i], width=0.8)
|
||||||
|
bars.append(bar[0])
|
||||||
|
# Hatch and annotate when capped
|
||||||
|
if val > cap:
|
||||||
|
bars[-1].set_hatch("//")
|
||||||
|
ax.text(i, cap * 1.02, f"{_fmt_ms(val)}", ha="center", va="bottom", fontsize=9)
|
||||||
|
else:
|
||||||
|
ax.text(
|
||||||
|
i,
|
||||||
|
show + max(1.0, 0.01 * (cap or show)),
|
||||||
|
f"{_fmt_ms(val)}",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=9,
|
||||||
|
)
|
||||||
|
ax.set_ylim(0, cap * 1.10)
|
||||||
|
_add_break_marker(ax, y=0.98)
|
||||||
|
ax.legend([bars[1]], ["capped"], fontsize=8, frameon=False, loc="upper right") if any(
|
||||||
|
v > cap for v in values
|
||||||
|
) else None
|
||||||
|
ax.set_xticks(range(len(labels)))
|
||||||
|
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
|
||||||
|
else:
|
||||||
|
ax.bar(labels, values, color=colors[: len(labels)])
|
||||||
|
for idx, val in enumerate(values):
|
||||||
|
ax.text(
|
||||||
|
idx,
|
||||||
|
val + 1.0,
|
||||||
|
f"{_fmt_ms(val)}",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=10,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
|
||||||
|
# Try to extract some context for title
|
||||||
|
max_initial = latest_rows[0].get("max_initial", "?")
|
||||||
|
max_updates = latest_rows[0].get("max_updates", "?")
|
||||||
|
|
||||||
|
if args.broken_y:
|
||||||
|
fig.text(
|
||||||
|
0.02,
|
||||||
|
0.5,
|
||||||
|
"Latency (s)",
|
||||||
|
va="center",
|
||||||
|
rotation="vertical",
|
||||||
|
fontsize=11,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
fig.suptitle(
|
||||||
|
"Add Operation Latency",
|
||||||
|
fontsize=11,
|
||||||
|
y=0.98,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
plt.tight_layout(rect=(0.03, 0.04, 1, 0.96))
|
||||||
|
else:
|
||||||
|
plt.ylabel("Latency (s)", fontsize=11, fontweight="bold")
|
||||||
|
plt.title("Add Operation Latency", fontsize=11, fontweight="bold")
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
|
||||||
|
# Also save PDF for paper
|
||||||
|
pdf_out = args.out.with_suffix(".pdf")
|
||||||
|
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
|
||||||
|
print(f"Saved: {args.out}")
|
||||||
|
print(f"Saved: {pdf_out}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -29,12 +29,25 @@ if(APPLE)
|
|||||||
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
|
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Use system ZeroMQ instead of building from source
|
# Find ZMQ using pkg-config with IMPORTED_TARGET for automatic target creation
|
||||||
find_package(PkgConfig REQUIRED)
|
find_package(PkgConfig REQUIRED)
|
||||||
pkg_check_modules(ZMQ REQUIRED libzmq)
|
|
||||||
|
# On ARM64 macOS, ensure pkg-config finds ARM64 Homebrew packages first
|
||||||
|
if(APPLE AND CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
|
||||||
|
set(ENV{PKG_CONFIG_PATH} "/opt/homebrew/lib/pkgconfig:/opt/homebrew/share/pkgconfig:$ENV{PKG_CONFIG_PATH}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
pkg_check_modules(ZMQ REQUIRED IMPORTED_TARGET libzmq)
|
||||||
|
|
||||||
|
# This creates PkgConfig::ZMQ target automatically with correct properties
|
||||||
|
if(TARGET PkgConfig::ZMQ)
|
||||||
|
message(STATUS "Found and configured ZMQ target: PkgConfig::ZMQ")
|
||||||
|
else()
|
||||||
|
message(FATAL_ERROR "pkg_check_modules did not create IMPORTED target for ZMQ.")
|
||||||
|
endif()
|
||||||
|
|
||||||
# Add cppzmq headers
|
# Add cppzmq headers
|
||||||
include_directories(third_party/cppzmq)
|
include_directories(SYSTEM third_party/cppzmq)
|
||||||
|
|
||||||
# Configure msgpack-c - disable boost dependency
|
# Configure msgpack-c - disable boost dependency
|
||||||
set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
|
set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
|
||||||
|
|||||||
@@ -215,6 +215,8 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
if recompute_embeddings:
|
if recompute_embeddings:
|
||||||
if zmq_port is None:
|
if zmq_port is None:
|
||||||
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
||||||
|
if hasattr(self._index, "set_zmq_port"):
|
||||||
|
self._index.set_zmq_port(zmq_port)
|
||||||
|
|
||||||
if query.dtype != np.float32:
|
if query.dtype != np.float32:
|
||||||
query = query.astype(np.float32)
|
query = query.astype(np.float32)
|
||||||
|
|||||||
@@ -820,10 +820,10 @@ class LeannBuilder:
|
|||||||
actual_port,
|
actual_port,
|
||||||
requested_zmq_port,
|
requested_zmq_port,
|
||||||
)
|
)
|
||||||
try:
|
if hasattr(index.hnsw, "set_zmq_port"):
|
||||||
index.hnsw.zmq_port = actual_port
|
index.hnsw.set_zmq_port(actual_port)
|
||||||
except AttributeError:
|
elif hasattr(index, "set_zmq_port"):
|
||||||
pass
|
index.set_zmq_port(actual_port)
|
||||||
|
|
||||||
if needs_recompute:
|
if needs_recompute:
|
||||||
for i in range(embeddings.shape[0]):
|
for i in range(embeddings.shape[0]):
|
||||||
|
|||||||
@@ -11,6 +11,119 @@ from llama_index.core.node_parser import SentenceSplitter
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_token_count(text: str) -> int:
|
||||||
|
"""
|
||||||
|
Estimate token count for a text string.
|
||||||
|
Uses conservative estimation: ~4 characters per token for natural text,
|
||||||
|
~1.2 tokens per character for code (worse tokenization).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Input text to estimate tokens for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated token count
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
encoder = tiktoken.get_encoding("cl100k_base")
|
||||||
|
return len(encoder.encode(text))
|
||||||
|
except ImportError:
|
||||||
|
# Fallback: Conservative character-based estimation
|
||||||
|
# Assume worst case for code: 1.2 tokens per character
|
||||||
|
return int(len(text) * 1.2)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_safe_chunk_size(
|
||||||
|
model_token_limit: int,
|
||||||
|
overlap_tokens: int,
|
||||||
|
chunking_mode: str = "traditional",
|
||||||
|
safety_factor: float = 0.9,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Calculate safe chunk size accounting for overlap and safety margin.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_token_limit: Maximum tokens supported by embedding model
|
||||||
|
overlap_tokens: Overlap size (tokens for traditional, chars for AST)
|
||||||
|
chunking_mode: "traditional" (tokens) or "ast" (characters)
|
||||||
|
safety_factor: Safety margin (0.9 = 10% safety margin)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Safe chunk size: tokens for traditional, characters for AST
|
||||||
|
"""
|
||||||
|
safe_limit = int(model_token_limit * safety_factor)
|
||||||
|
|
||||||
|
if chunking_mode == "traditional":
|
||||||
|
# Traditional chunking uses tokens
|
||||||
|
# Max chunk = chunk_size + overlap, so chunk_size = limit - overlap
|
||||||
|
return max(1, safe_limit - overlap_tokens)
|
||||||
|
else: # AST chunking
|
||||||
|
# AST uses characters, need to convert
|
||||||
|
# Conservative estimate: 1.2 tokens per char for code
|
||||||
|
overlap_chars = int(overlap_tokens * 3) # ~3 chars per token for code
|
||||||
|
safe_chars = int(safe_limit / 1.2)
|
||||||
|
return max(1, safe_chars - overlap_chars)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_chunk_token_limits(chunks: list[str], max_tokens: int = 512) -> tuple[list[str], int]:
|
||||||
|
"""
|
||||||
|
Validate that chunks don't exceed token limits and truncate if necessary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunks: List of text chunks to validate
|
||||||
|
max_tokens: Maximum tokens allowed per chunk
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (validated_chunks, num_truncated)
|
||||||
|
"""
|
||||||
|
validated_chunks = []
|
||||||
|
num_truncated = 0
|
||||||
|
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
estimated_tokens = estimate_token_count(chunk)
|
||||||
|
|
||||||
|
if estimated_tokens > max_tokens:
|
||||||
|
# Truncate chunk to fit token limit
|
||||||
|
try:
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
encoder = tiktoken.get_encoding("cl100k_base")
|
||||||
|
tokens = encoder.encode(chunk)
|
||||||
|
if len(tokens) > max_tokens:
|
||||||
|
truncated_tokens = tokens[:max_tokens]
|
||||||
|
truncated_chunk = encoder.decode(truncated_tokens)
|
||||||
|
validated_chunks.append(truncated_chunk)
|
||||||
|
num_truncated += 1
|
||||||
|
logger.warning(
|
||||||
|
f"Truncated chunk {i} from {len(tokens)} to {max_tokens} tokens "
|
||||||
|
f"(from {len(chunk)} to {len(truncated_chunk)} characters)"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
validated_chunks.append(chunk)
|
||||||
|
except ImportError:
|
||||||
|
# Fallback: Conservative character truncation
|
||||||
|
char_limit = int(max_tokens / 1.2) # Conservative for code
|
||||||
|
if len(chunk) > char_limit:
|
||||||
|
truncated_chunk = chunk[:char_limit]
|
||||||
|
validated_chunks.append(truncated_chunk)
|
||||||
|
num_truncated += 1
|
||||||
|
logger.warning(
|
||||||
|
f"Truncated chunk {i} from {len(chunk)} to {char_limit} characters "
|
||||||
|
f"(conservative estimate for {max_tokens} tokens)"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
validated_chunks.append(chunk)
|
||||||
|
else:
|
||||||
|
validated_chunks.append(chunk)
|
||||||
|
|
||||||
|
if num_truncated > 0:
|
||||||
|
logger.warning(f"Truncated {num_truncated}/{len(chunks)} chunks to fit token limits")
|
||||||
|
|
||||||
|
return validated_chunks, num_truncated
|
||||||
|
|
||||||
|
|
||||||
# Code file extensions supported by astchunk
|
# Code file extensions supported by astchunk
|
||||||
CODE_EXTENSIONS = {
|
CODE_EXTENSIONS = {
|
||||||
".py": "python",
|
".py": "python",
|
||||||
@@ -82,6 +195,17 @@ def create_ast_chunks(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Warn if AST chunk size + overlap might exceed common token limits
|
||||||
|
estimated_max_tokens = int(
|
||||||
|
(max_chunk_size + chunk_overlap) * 1.2
|
||||||
|
) # Conservative estimate
|
||||||
|
if estimated_max_tokens > 512:
|
||||||
|
logger.warning(
|
||||||
|
f"AST chunk size ({max_chunk_size}) + overlap ({chunk_overlap}) = {max_chunk_size + chunk_overlap} chars "
|
||||||
|
f"may exceed 512 token limit (~{estimated_max_tokens} tokens estimated). "
|
||||||
|
f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}"
|
||||||
|
)
|
||||||
|
|
||||||
configs = {
|
configs = {
|
||||||
"max_chunk_size": max_chunk_size,
|
"max_chunk_size": max_chunk_size,
|
||||||
"language": language,
|
"language": language,
|
||||||
@@ -217,4 +341,14 @@ def create_text_chunks(
|
|||||||
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
||||||
|
|
||||||
logger.info(f"Total chunks created: {len(all_chunks)}")
|
logger.info(f"Total chunks created: {len(all_chunks)}")
|
||||||
return all_chunks
|
|
||||||
|
# Validate chunk token limits (default to 512 for safety)
|
||||||
|
# This provides a safety net for embedding models with token limits
|
||||||
|
validated_chunks, num_truncated = validate_chunk_token_limits(all_chunks, max_tokens=512)
|
||||||
|
|
||||||
|
if num_truncated > 0:
|
||||||
|
logger.info(
|
||||||
|
f"Post-chunking validation: {num_truncated} chunks were truncated to fit 512 token limit"
|
||||||
|
)
|
||||||
|
|
||||||
|
return validated_chunks
|
||||||
|
|||||||
@@ -181,25 +181,25 @@ Examples:
|
|||||||
"--doc-chunk-size",
|
"--doc-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=256,
|
default=256,
|
||||||
help="Document chunk size in tokens/characters (default: 256)",
|
help="Document chunk size in TOKENS (default: 256). Final chunks may be larger due to overlap. For 512 token models: recommended 350 tokens (350 + 128 overlap = 478 max)",
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--doc-chunk-overlap",
|
"--doc-chunk-overlap",
|
||||||
type=int,
|
type=int,
|
||||||
default=128,
|
default=128,
|
||||||
help="Document chunk overlap (default: 128)",
|
help="Document chunk overlap in TOKENS (default: 128). Added to chunk size, not included in it",
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--code-chunk-size",
|
"--code-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=512,
|
default=512,
|
||||||
help="Code chunk size in tokens/lines (default: 512)",
|
help="Code chunk size in TOKENS (default: 512). Final chunks may be larger due to overlap. For 512 token models: recommended 400 tokens (400 + 50 overlap = 450 max)",
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--code-chunk-overlap",
|
"--code-chunk-overlap",
|
||||||
type=int,
|
type=int,
|
||||||
default=50,
|
default=50,
|
||||||
help="Code chunk overlap (default: 50)",
|
help="Code chunk overlap in TOKENS (default: 50). Added to chunk size, not included in it",
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--use-ast-chunking",
|
"--use-ast-chunking",
|
||||||
@@ -209,14 +209,14 @@ Examples:
|
|||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--ast-chunk-size",
|
"--ast-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=768,
|
default=300,
|
||||||
help="AST chunk size in characters (default: 768)",
|
help="AST chunk size in CHARACTERS (non-whitespace) (default: 300). Final chunks may be larger due to overlap and expansion. For 512 token models: recommended 300 chars (300 + 64 overlap ~= 480 tokens)",
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--ast-chunk-overlap",
|
"--ast-chunk-overlap",
|
||||||
type=int,
|
type=int,
|
||||||
default=96,
|
default=64,
|
||||||
help="AST chunk overlap in characters (default: 96)",
|
help="AST chunk overlap in CHARACTERS (default: 64). Added to chunk size, not included in it. ~1.2 tokens per character for code",
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--ast-fallback-traditional",
|
"--ast-fallback-traditional",
|
||||||
@@ -255,6 +255,11 @@ Examples:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Non-interactive mode: automatically select index without prompting",
|
help="Non-interactive mode: automatically select index without prompting",
|
||||||
)
|
)
|
||||||
|
search_parser.add_argument(
|
||||||
|
"--show-metadata",
|
||||||
|
action="store_true",
|
||||||
|
help="Display file paths and metadata in search results",
|
||||||
|
)
|
||||||
|
|
||||||
# Ask command
|
# Ask command
|
||||||
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
||||||
@@ -1157,6 +1162,11 @@ Examples:
|
|||||||
print(f"Warning: Could not process {file_path}: {e}")
|
print(f"Warning: Could not process {file_path}: {e}")
|
||||||
|
|
||||||
# Load other file types with default reader
|
# Load other file types with default reader
|
||||||
|
# Exclude PDFs from code_extensions if they were already processed separately
|
||||||
|
other_file_extensions = code_extensions
|
||||||
|
if should_process_pdfs and ".pdf" in code_extensions:
|
||||||
|
other_file_extensions = [ext for ext in code_extensions if ext != ".pdf"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Create a custom file filter function using our PathSpec
|
# Create a custom file filter function using our PathSpec
|
||||||
def file_filter(
|
def file_filter(
|
||||||
@@ -1172,15 +1182,19 @@ Examples:
|
|||||||
except (ValueError, OSError):
|
except (ValueError, OSError):
|
||||||
return True # Include files that can't be processed
|
return True # Include files that can't be processed
|
||||||
|
|
||||||
other_docs = SimpleDirectoryReader(
|
# Only load other file types if there are extensions to process
|
||||||
docs_dir,
|
if other_file_extensions:
|
||||||
recursive=True,
|
other_docs = SimpleDirectoryReader(
|
||||||
encoding="utf-8",
|
docs_dir,
|
||||||
required_exts=code_extensions,
|
recursive=True,
|
||||||
file_extractor={}, # Use default extractors
|
encoding="utf-8",
|
||||||
exclude_hidden=not include_hidden,
|
required_exts=other_file_extensions,
|
||||||
filename_as_id=True,
|
file_extractor={}, # Use default extractors
|
||||||
).load_data(show_progress=True)
|
exclude_hidden=not include_hidden,
|
||||||
|
filename_as_id=True,
|
||||||
|
).load_data(show_progress=True)
|
||||||
|
else:
|
||||||
|
other_docs = []
|
||||||
|
|
||||||
# Filter documents after loading based on gitignore rules
|
# Filter documents after loading based on gitignore rules
|
||||||
filtered_docs = []
|
filtered_docs = []
|
||||||
@@ -1263,7 +1277,7 @@ Examples:
|
|||||||
from .chunking_utils import create_text_chunks
|
from .chunking_utils import create_text_chunks
|
||||||
|
|
||||||
# Use enhanced chunking with AST support
|
# Use enhanced chunking with AST support
|
||||||
all_texts = create_text_chunks(
|
chunk_texts = create_text_chunks(
|
||||||
documents,
|
documents,
|
||||||
chunk_size=self.node_parser.chunk_size,
|
chunk_size=self.node_parser.chunk_size,
|
||||||
chunk_overlap=self.node_parser.chunk_overlap,
|
chunk_overlap=self.node_parser.chunk_overlap,
|
||||||
@@ -1274,6 +1288,14 @@ Examples:
|
|||||||
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Note: AST chunking currently returns plain text chunks without metadata
|
||||||
|
# We preserve basic file info by associating chunks with their source documents
|
||||||
|
# For better metadata preservation, documents list order should be maintained
|
||||||
|
for chunk_text in chunk_texts:
|
||||||
|
# TODO: Enhance create_text_chunks to return metadata alongside text
|
||||||
|
# For now, we store chunks with empty metadata
|
||||||
|
all_texts.append({"text": chunk_text, "metadata": {}})
|
||||||
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(
|
print(
|
||||||
f"⚠️ AST chunking utilities not available in package ({e}), falling back to traditional chunking"
|
f"⚠️ AST chunking utilities not available in package ({e}), falling back to traditional chunking"
|
||||||
@@ -1285,17 +1307,27 @@ Examples:
|
|||||||
for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
|
for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
|
||||||
# Check if this is a code file based on source path
|
# Check if this is a code file based on source path
|
||||||
source_path = doc.metadata.get("source", "")
|
source_path = doc.metadata.get("source", "")
|
||||||
|
file_path = doc.metadata.get("file_path", "")
|
||||||
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
|
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
|
||||||
|
|
||||||
|
# Extract metadata to preserve with chunks
|
||||||
|
chunk_metadata = {
|
||||||
|
"file_path": file_path or source_path,
|
||||||
|
"file_name": doc.metadata.get("file_name", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional metadata if available
|
||||||
|
if "creation_date" in doc.metadata:
|
||||||
|
chunk_metadata["creation_date"] = doc.metadata["creation_date"]
|
||||||
|
if "last_modified_date" in doc.metadata:
|
||||||
|
chunk_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
|
||||||
|
|
||||||
# Use appropriate parser based on file type
|
# Use appropriate parser based on file type
|
||||||
parser = self.code_parser if is_code_file else self.node_parser
|
parser = self.code_parser if is_code_file else self.node_parser
|
||||||
nodes = parser.get_nodes_from_documents([doc])
|
nodes = parser.get_nodes_from_documents([doc])
|
||||||
|
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
text_with_source = (
|
all_texts.append({"text": node.get_content(), "metadata": chunk_metadata})
|
||||||
"Chunk source:" + source_path + "\n" + node.get_content().replace("\n", " ")
|
|
||||||
)
|
|
||||||
all_texts.append(text_with_source)
|
|
||||||
|
|
||||||
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
|
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
|
||||||
return all_texts
|
return all_texts
|
||||||
@@ -1370,7 +1402,7 @@ Examples:
|
|||||||
|
|
||||||
index_dir.mkdir(parents=True, exist_ok=True)
|
index_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
print(f"Building index '{index_name}' with {args.backend} backend...")
|
print(f"Building index '{index_name}' with {args.backend_name} backend...")
|
||||||
|
|
||||||
embedding_options: dict[str, Any] = {}
|
embedding_options: dict[str, Any] = {}
|
||||||
if args.embedding_mode == "ollama":
|
if args.embedding_mode == "ollama":
|
||||||
@@ -1382,7 +1414,7 @@ Examples:
|
|||||||
embedding_options["api_key"] = resolved_embedding_key
|
embedding_options["api_key"] = resolved_embedding_key
|
||||||
|
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name=args.backend,
|
backend_name=args.backend_name,
|
||||||
embedding_model=args.embedding_model,
|
embedding_model=args.embedding_model,
|
||||||
embedding_mode=args.embedding_mode,
|
embedding_mode=args.embedding_mode,
|
||||||
embedding_options=embedding_options or None,
|
embedding_options=embedding_options or None,
|
||||||
@@ -1393,10 +1425,8 @@ Examples:
|
|||||||
num_threads=args.num_threads,
|
num_threads=args.num_threads,
|
||||||
)
|
)
|
||||||
|
|
||||||
for chunk_text_with_source in all_texts:
|
for chunk in all_texts:
|
||||||
chunk_source = chunk_text_with_source.split("\n")[0].split(":")[1]
|
builder.add_text(chunk["text"], metadata=chunk["metadata"])
|
||||||
chunk_text = chunk_text_with_source.split("\n")[1]
|
|
||||||
builder.add_text(chunk_text, {"source": chunk_source})
|
|
||||||
|
|
||||||
builder.build_index(index_path)
|
builder.build_index(index_path)
|
||||||
print(f"Index built at {index_path}")
|
print(f"Index built at {index_path}")
|
||||||
@@ -1517,6 +1547,23 @@ Examples:
|
|||||||
print(f"Search results for '{query}' (top {len(results)}):")
|
print(f"Search results for '{query}' (top {len(results)}):")
|
||||||
for i, result in enumerate(results, 1):
|
for i, result in enumerate(results, 1):
|
||||||
print(f"{i}. Score: {result.score:.3f}")
|
print(f"{i}. Score: {result.score:.3f}")
|
||||||
|
|
||||||
|
# Display metadata if flag is set
|
||||||
|
if args.show_metadata and result.metadata:
|
||||||
|
file_path = result.metadata.get("file_path", "")
|
||||||
|
if file_path:
|
||||||
|
print(f" 📄 File: {file_path}")
|
||||||
|
|
||||||
|
file_name = result.metadata.get("file_name", "")
|
||||||
|
if file_name and file_name != file_path:
|
||||||
|
print(f" 📝 Name: {file_name}")
|
||||||
|
|
||||||
|
# Show timestamps if available
|
||||||
|
if "creation_date" in result.metadata:
|
||||||
|
print(f" 🕐 Created: {result.metadata['creation_date']}")
|
||||||
|
if "last_modified_date" in result.metadata:
|
||||||
|
print(f" 🕑 Modified: {result.metadata['last_modified_date']}")
|
||||||
|
|
||||||
print(f" {result.text[:200]}...")
|
print(f" {result.text[:200]}...")
|
||||||
print(f" Source: {result.metadata.get('source', '')}")
|
print(f" Source: {result.metadata.get('source', '')}")
|
||||||
print()
|
print()
|
||||||
|
|||||||
@@ -14,6 +14,89 @@ import torch
|
|||||||
|
|
||||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_to_token_limit(texts: list[str], max_tokens: int = 512) -> list[str]:
|
||||||
|
"""
|
||||||
|
Truncate texts to token limit using tiktoken or conservative character truncation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of texts to truncate
|
||||||
|
max_tokens: Maximum tokens allowed per text
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of truncated texts that should fit within token limit
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
encoder = tiktoken.get_encoding("cl100k_base")
|
||||||
|
truncated = []
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
tokens = encoder.encode(text)
|
||||||
|
if len(tokens) > max_tokens:
|
||||||
|
# Truncate to max_tokens and decode back to text
|
||||||
|
truncated_tokens = tokens[:max_tokens]
|
||||||
|
truncated_text = encoder.decode(truncated_tokens)
|
||||||
|
truncated.append(truncated_text)
|
||||||
|
logger.warning(
|
||||||
|
f"Truncated text from {len(tokens)} to {max_tokens} tokens "
|
||||||
|
f"(from {len(text)} to {len(truncated_text)} characters)"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
truncated.append(text)
|
||||||
|
return truncated
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
# Fallback: Conservative character truncation
|
||||||
|
# Assume worst case: 1.5 tokens per character for code content
|
||||||
|
char_limit = int(max_tokens / 1.5)
|
||||||
|
truncated = []
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
if len(text) > char_limit:
|
||||||
|
truncated_text = text[:char_limit]
|
||||||
|
truncated.append(truncated_text)
|
||||||
|
logger.warning(
|
||||||
|
f"Truncated text from {len(text)} to {char_limit} characters "
|
||||||
|
f"(conservative estimate for {max_tokens} tokens)"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
truncated.append(text)
|
||||||
|
return truncated
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_token_limit(model_name: str) -> int:
|
||||||
|
"""
|
||||||
|
Get token limit for a given embedding model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the embedding model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token limit for the model, defaults to 512 if unknown
|
||||||
|
"""
|
||||||
|
# Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
|
||||||
|
base_model_name = model_name.split(":")[0]
|
||||||
|
|
||||||
|
# Check exact match first
|
||||||
|
if model_name in EMBEDDING_MODEL_LIMITS:
|
||||||
|
return EMBEDDING_MODEL_LIMITS[model_name]
|
||||||
|
|
||||||
|
# Check base name match
|
||||||
|
if base_model_name in EMBEDDING_MODEL_LIMITS:
|
||||||
|
return EMBEDDING_MODEL_LIMITS[base_model_name]
|
||||||
|
|
||||||
|
# Check partial matches for common patterns
|
||||||
|
for known_model, limit in EMBEDDING_MODEL_LIMITS.items():
|
||||||
|
if known_model in base_model_name or base_model_name in known_model:
|
||||||
|
return limit
|
||||||
|
|
||||||
|
# Default to conservative 512 token limit
|
||||||
|
logger.warning(f"Unknown model '{model_name}', using default 512 token limit")
|
||||||
|
return 512
|
||||||
|
|
||||||
|
|
||||||
# Set up logger with proper level
|
# Set up logger with proper level
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
@@ -23,6 +106,17 @@ logger.setLevel(log_level)
|
|||||||
# Global model cache to avoid repeated loading
|
# Global model cache to avoid repeated loading
|
||||||
_model_cache: dict[str, Any] = {}
|
_model_cache: dict[str, Any] = {}
|
||||||
|
|
||||||
|
# Known embedding model token limits
|
||||||
|
EMBEDDING_MODEL_LIMITS = {
|
||||||
|
"nomic-embed-text": 512,
|
||||||
|
"nomic-embed-text-v2": 512,
|
||||||
|
"mxbai-embed-large": 512,
|
||||||
|
"all-minilm": 512,
|
||||||
|
"bge-m3": 8192,
|
||||||
|
"snowflake-arctic-embed": 512,
|
||||||
|
# Add more models as needed
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings(
|
def compute_embeddings(
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
@@ -574,9 +668,10 @@ def compute_embeddings_ollama(
|
|||||||
host: Optional[str] = None,
|
host: Optional[str] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Compute embeddings using Ollama API with simplified batch processing.
|
Compute embeddings using Ollama API with true batch processing.
|
||||||
|
|
||||||
Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
|
Uses the /api/embed endpoint which supports batch inputs.
|
||||||
|
Batch size: 32 for MPS/CPU, 128 for CUDA to optimize performance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
texts: List of texts to compute embeddings for
|
texts: List of texts to compute embeddings for
|
||||||
@@ -681,11 +776,11 @@ def compute_embeddings_ollama(
|
|||||||
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
|
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
|
||||||
model_name = resolved_model_name
|
model_name = resolved_model_name
|
||||||
|
|
||||||
# Verify the model supports embeddings by testing it
|
# Verify the model supports embeddings by testing it with /api/embed
|
||||||
try:
|
try:
|
||||||
test_response = requests.post(
|
test_response = requests.post(
|
||||||
f"{resolved_host}/api/embeddings",
|
f"{resolved_host}/api/embed",
|
||||||
json={"model": model_name, "prompt": "test"},
|
json={"model": model_name, "input": "test"},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
if test_response.status_code != 200:
|
if test_response.status_code != 200:
|
||||||
@@ -717,63 +812,80 @@ def compute_embeddings_ollama(
|
|||||||
# If torch is not available, use conservative batch size
|
# If torch is not available, use conservative batch size
|
||||||
batch_size = 32
|
batch_size = 32
|
||||||
|
|
||||||
logger.info(f"Using batch size: {batch_size}")
|
logger.info(f"Using batch size: {batch_size} for true batch processing")
|
||||||
|
|
||||||
|
# Get model token limit and apply truncation
|
||||||
|
token_limit = get_model_token_limit(model_name)
|
||||||
|
logger.info(f"Model '{model_name}' token limit: {token_limit}")
|
||||||
|
|
||||||
|
# Apply token-aware truncation to all texts
|
||||||
|
truncated_texts = truncate_to_token_limit(texts, token_limit)
|
||||||
|
if len(truncated_texts) != len(texts):
|
||||||
|
logger.error("Truncation failed - text count mismatch")
|
||||||
|
truncated_texts = texts # Fallback to original texts
|
||||||
|
|
||||||
def get_batch_embeddings(batch_texts):
|
def get_batch_embeddings(batch_texts):
|
||||||
"""Get embeddings for a batch of texts."""
|
"""Get embeddings for a batch of texts using /api/embed endpoint."""
|
||||||
all_embeddings = []
|
max_retries = 3
|
||||||
failed_indices = []
|
retry_count = 0
|
||||||
|
|
||||||
for i, text in enumerate(batch_texts):
|
# Texts are already truncated to token limit by the outer function
|
||||||
max_retries = 3
|
while retry_count < max_retries:
|
||||||
retry_count = 0
|
try:
|
||||||
|
# Use /api/embed endpoint with "input" parameter for batch processing
|
||||||
|
response = requests.post(
|
||||||
|
f"{resolved_host}/api/embed",
|
||||||
|
json={"model": model_name, "input": batch_texts},
|
||||||
|
timeout=60, # Increased timeout for batch processing
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
# Truncate very long texts to avoid API issues
|
result = response.json()
|
||||||
truncated_text = text[:8000] if len(text) > 8000 else text
|
batch_embeddings = result.get("embeddings")
|
||||||
while retry_count < max_retries:
|
|
||||||
try:
|
if batch_embeddings is None:
|
||||||
response = requests.post(
|
raise ValueError("No embeddings returned from API")
|
||||||
f"{resolved_host}/api/embeddings",
|
|
||||||
json={"model": model_name, "prompt": truncated_text},
|
if not isinstance(batch_embeddings, list):
|
||||||
timeout=30,
|
raise ValueError(f"Invalid embeddings format: {type(batch_embeddings)}")
|
||||||
|
|
||||||
|
if len(batch_embeddings) != len(batch_texts):
|
||||||
|
raise ValueError(
|
||||||
|
f"Mismatch: requested {len(batch_texts)} embeddings, got {len(batch_embeddings)}"
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
result = response.json()
|
return batch_embeddings, []
|
||||||
embedding = result.get("embedding")
|
|
||||||
|
|
||||||
if embedding is None:
|
except requests.exceptions.Timeout:
|
||||||
raise ValueError(f"No embedding returned for text {i}")
|
retry_count += 1
|
||||||
|
if retry_count >= max_retries:
|
||||||
|
logger.warning(f"Timeout for batch after {max_retries} retries")
|
||||||
|
return None, list(range(len(batch_texts)))
|
||||||
|
|
||||||
if not isinstance(embedding, list) or len(embedding) == 0:
|
except Exception as e:
|
||||||
raise ValueError(f"Invalid embedding format for text {i}")
|
retry_count += 1
|
||||||
|
if retry_count >= max_retries:
|
||||||
|
# Enhanced error detection for token limit violations
|
||||||
|
error_msg = str(e).lower()
|
||||||
|
if "token" in error_msg and (
|
||||||
|
"limit" in error_msg or "exceed" in error_msg or "length" in error_msg
|
||||||
|
):
|
||||||
|
logger.error(
|
||||||
|
f"Token limit exceeded for batch. Error: {e}. "
|
||||||
|
f"Consider reducing chunk sizes or check token truncation."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to get embeddings for batch: {e}")
|
||||||
|
return None, list(range(len(batch_texts)))
|
||||||
|
|
||||||
all_embeddings.append(embedding)
|
return None, list(range(len(batch_texts)))
|
||||||
break
|
|
||||||
|
|
||||||
except requests.exceptions.Timeout:
|
# Process truncated texts in batches
|
||||||
retry_count += 1
|
|
||||||
if retry_count >= max_retries:
|
|
||||||
logger.warning(f"Timeout for text {i} after {max_retries} retries")
|
|
||||||
failed_indices.append(i)
|
|
||||||
all_embeddings.append(None)
|
|
||||||
break
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
retry_count += 1
|
|
||||||
if retry_count >= max_retries:
|
|
||||||
logger.error(f"Failed to get embedding for text {i}: {e}")
|
|
||||||
failed_indices.append(i)
|
|
||||||
all_embeddings.append(None)
|
|
||||||
break
|
|
||||||
return all_embeddings, failed_indices
|
|
||||||
|
|
||||||
# Process texts in batches
|
|
||||||
all_embeddings = []
|
all_embeddings = []
|
||||||
all_failed_indices = []
|
all_failed_indices = []
|
||||||
|
|
||||||
# Setup progress bar if needed
|
# Setup progress bar if needed
|
||||||
show_progress = is_build or len(texts) > 10
|
show_progress = is_build or len(truncated_texts) > 10
|
||||||
try:
|
try:
|
||||||
if show_progress:
|
if show_progress:
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -781,32 +893,36 @@ def compute_embeddings_ollama(
|
|||||||
show_progress = False
|
show_progress = False
|
||||||
|
|
||||||
# Process batches
|
# Process batches
|
||||||
num_batches = (len(texts) + batch_size - 1) // batch_size
|
num_batches = (len(truncated_texts) + batch_size - 1) // batch_size
|
||||||
|
|
||||||
if show_progress:
|
if show_progress:
|
||||||
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
|
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
|
||||||
else:
|
else:
|
||||||
batch_iterator = range(num_batches)
|
batch_iterator = range(num_batches)
|
||||||
|
|
||||||
for batch_idx in batch_iterator:
|
for batch_idx in batch_iterator:
|
||||||
start_idx = batch_idx * batch_size
|
start_idx = batch_idx * batch_size
|
||||||
end_idx = min(start_idx + batch_size, len(texts))
|
end_idx = min(start_idx + batch_size, len(truncated_texts))
|
||||||
batch_texts = texts[start_idx:end_idx]
|
batch_texts = truncated_texts[start_idx:end_idx]
|
||||||
|
|
||||||
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
|
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
|
||||||
|
|
||||||
# Adjust failed indices to global indices
|
if batch_embeddings is not None:
|
||||||
global_failed = [start_idx + idx for idx in batch_failed]
|
all_embeddings.extend(batch_embeddings)
|
||||||
all_failed_indices.extend(global_failed)
|
else:
|
||||||
all_embeddings.extend(batch_embeddings)
|
# Entire batch failed, add None placeholders
|
||||||
|
all_embeddings.extend([None] * len(batch_texts))
|
||||||
|
# Adjust failed indices to global indices
|
||||||
|
global_failed = [start_idx + idx for idx in batch_failed]
|
||||||
|
all_failed_indices.extend(global_failed)
|
||||||
|
|
||||||
# Handle failed embeddings
|
# Handle failed embeddings
|
||||||
if all_failed_indices:
|
if all_failed_indices:
|
||||||
if len(all_failed_indices) == len(texts):
|
if len(all_failed_indices) == len(truncated_texts):
|
||||||
raise RuntimeError("Failed to compute any embeddings")
|
raise RuntimeError("Failed to compute any embeddings")
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts"
|
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(truncated_texts)} texts"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use zero embeddings as fallback for failed ones
|
# Use zero embeddings as fallback for failed ones
|
||||||
|
|||||||
@@ -60,6 +60,11 @@ def handle_request(request):
|
|||||||
"maximum": 128,
|
"maximum": 128,
|
||||||
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
|
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
|
||||||
},
|
},
|
||||||
|
"show_metadata": {
|
||||||
|
"type": "boolean",
|
||||||
|
"default": False,
|
||||||
|
"description": "Include file paths and metadata in search results. Useful for understanding which files contain the results.",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"required": ["index_name", "query"],
|
"required": ["index_name", "query"],
|
||||||
},
|
},
|
||||||
@@ -104,6 +109,8 @@ def handle_request(request):
|
|||||||
f"--complexity={args.get('complexity', 32)}",
|
f"--complexity={args.get('complexity', 32)}",
|
||||||
"--non-interactive",
|
"--non-interactive",
|
||||||
]
|
]
|
||||||
|
if args.get("show_metadata", False):
|
||||||
|
cmd.append("--show-metadata")
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
|
||||||
elif tool_name == "leann_list":
|
elif tool_name == "leann_list":
|
||||||
|
|||||||
162
test_colqwen_reproduction.py
Normal file
162
test_colqwen_reproduction.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test script to reproduce ColQwen results from issue #119
|
||||||
|
https://github.com/yichuan-w/LEANN/issues/119
|
||||||
|
|
||||||
|
This script demonstrates the ColQwen workflow:
|
||||||
|
1. Download sample PDF
|
||||||
|
2. Convert to images
|
||||||
|
3. Build multimodal index
|
||||||
|
4. Run test queries
|
||||||
|
5. Generate similarity maps
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("🧪 ColQwen Reproduction Test - Issue #119")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# Check if we're in the right directory
|
||||||
|
repo_root = Path.cwd()
|
||||||
|
if not (repo_root / "apps" / "colqwen_rag.py").exists():
|
||||||
|
print("❌ Please run this script from the LEANN repository root")
|
||||||
|
print(" cd /path/to/LEANN && python test_colqwen_reproduction.py")
|
||||||
|
return
|
||||||
|
|
||||||
|
print("✅ Repository structure looks good")
|
||||||
|
|
||||||
|
# Step 1: Check dependencies
|
||||||
|
print("\n📦 Checking dependencies...")
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Check if pdf2image is available
|
||||||
|
if importlib.util.find_spec("pdf2image") is None:
|
||||||
|
raise ImportError("pdf2image not found")
|
||||||
|
# Check if colpali_engine is available
|
||||||
|
if importlib.util.find_spec("colpali_engine") is None:
|
||||||
|
raise ImportError("colpali_engine not found")
|
||||||
|
|
||||||
|
print("✅ Core dependencies available")
|
||||||
|
print(f" - PyTorch: {torch.__version__}")
|
||||||
|
print(f" - CUDA available: {torch.cuda.is_available()}")
|
||||||
|
print(
|
||||||
|
f" - MPS available: {hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()}"
|
||||||
|
)
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"❌ Missing dependency: {e}")
|
||||||
|
print("\n📥 Install missing dependencies:")
|
||||||
|
print(
|
||||||
|
" uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 2: Download sample PDF
|
||||||
|
print("\n📄 Setting up sample PDF...")
|
||||||
|
pdf_dir = repo_root / "test_pdfs"
|
||||||
|
pdf_dir.mkdir(exist_ok=True)
|
||||||
|
sample_pdf = pdf_dir / "attention_paper.pdf"
|
||||||
|
|
||||||
|
if not sample_pdf.exists():
|
||||||
|
print("📥 Downloading sample paper (Attention Is All You Need)...")
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
try:
|
||||||
|
urllib.request.urlretrieve("https://arxiv.org/pdf/1706.03762.pdf", sample_pdf)
|
||||||
|
print(f"✅ Downloaded: {sample_pdf}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Download failed: {e}")
|
||||||
|
print(" Please manually download a PDF to test_pdfs/attention_paper.pdf")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
print(f"✅ Using existing PDF: {sample_pdf}")
|
||||||
|
|
||||||
|
# Step 3: Test ColQwen RAG
|
||||||
|
print("\n🚀 Testing ColQwen RAG...")
|
||||||
|
|
||||||
|
# Build index
|
||||||
|
print("\n1️⃣ Building multimodal index...")
|
||||||
|
build_cmd = f"python -m apps.colqwen_rag build --pdfs {pdf_dir} --index test_attention --model colqwen2 --pages-dir test_pages"
|
||||||
|
print(f" Command: {build_cmd}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = os.system(build_cmd)
|
||||||
|
if result == 0:
|
||||||
|
print("✅ Index built successfully!")
|
||||||
|
else:
|
||||||
|
print("❌ Index building failed")
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error building index: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Test search
|
||||||
|
print("\n2️⃣ Testing search...")
|
||||||
|
test_queries = [
|
||||||
|
"How does attention mechanism work?",
|
||||||
|
"What is the transformer architecture?",
|
||||||
|
"How do you compute self-attention?",
|
||||||
|
]
|
||||||
|
|
||||||
|
for query in test_queries:
|
||||||
|
print(f"\n🔍 Query: '{query}'")
|
||||||
|
search_cmd = f'python -m apps.colqwen_rag search test_attention "{query}" --top-k 3'
|
||||||
|
print(f" Command: {search_cmd}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = os.system(search_cmd)
|
||||||
|
if result == 0:
|
||||||
|
print("✅ Search completed")
|
||||||
|
else:
|
||||||
|
print("❌ Search failed")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Search error: {e}")
|
||||||
|
|
||||||
|
# Test interactive mode (briefly)
|
||||||
|
print("\n3️⃣ Testing interactive mode...")
|
||||||
|
print(" You can test interactive mode with:")
|
||||||
|
print(" python -m apps.colqwen_rag ask test_attention --interactive")
|
||||||
|
|
||||||
|
# Step 4: Test similarity maps (using existing script)
|
||||||
|
print("\n4️⃣ Testing similarity maps...")
|
||||||
|
similarity_script = (
|
||||||
|
repo_root
|
||||||
|
/ "apps"
|
||||||
|
/ "multimodal"
|
||||||
|
/ "vision-based-pdf-multi-vector"
|
||||||
|
/ "multi-vector-leann-similarity-map.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
if similarity_script.exists():
|
||||||
|
print(" You can generate similarity maps with:")
|
||||||
|
print(f" cd {similarity_script.parent}")
|
||||||
|
print(" python multi-vector-leann-similarity-map.py")
|
||||||
|
print(" (Edit the script to use your local PDF)")
|
||||||
|
|
||||||
|
print("\n🎉 ColQwen reproduction test completed!")
|
||||||
|
print("\n📋 Summary:")
|
||||||
|
print(" ✅ Dependencies checked")
|
||||||
|
print(" ✅ Sample PDF prepared")
|
||||||
|
print(" ✅ Index building tested")
|
||||||
|
print(" ✅ Search functionality tested")
|
||||||
|
print(" ✅ Interactive mode available")
|
||||||
|
print(" ✅ Similarity maps available")
|
||||||
|
|
||||||
|
print("\n🔗 Related repositories to check:")
|
||||||
|
print(" - https://github.com/lightonai/fast-plaid")
|
||||||
|
print(" - https://github.com/lightonai/pylate")
|
||||||
|
print(" - https://github.com/stanford-futuredata/ColBERT")
|
||||||
|
|
||||||
|
print("\n📝 Next steps:")
|
||||||
|
print(" 1. Test with your own PDFs")
|
||||||
|
print(" 2. Experiment with different queries")
|
||||||
|
print(" 3. Generate similarity maps for visual analysis")
|
||||||
|
print(" 4. Compare ColQwen2 vs ColPali performance")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user