Compare commits
14 Commits
fix/drop-p
...
feature/mc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4b687600da | ||
|
|
dfae37d0ee | ||
|
|
a43fafe44e | ||
|
|
32710cf5a1 | ||
|
|
c24e62a3d9 | ||
|
|
4ccbbf3e6b | ||
|
|
d3e6cfa1f7 | ||
|
|
523eef7e79 | ||
|
|
99bb98748d | ||
|
|
fe904ec992 | ||
|
|
d2432b45f6 | ||
|
|
28521775f8 | ||
|
|
98cdcf600b | ||
|
|
1fdc9dfbfa |
85
.github/workflows/build-reusable.yml
vendored
85
.github/workflows/build-reusable.yml
vendored
@@ -35,8 +35,8 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
# Note: Python 3.9 dropped - uses PEP 604 union syntax (str | None)
|
||||
# which requires Python 3.10+
|
||||
- os: ubuntu-22.04
|
||||
python: '3.9'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.10'
|
||||
- os: ubuntu-22.04
|
||||
@@ -46,6 +46,8 @@ jobs:
|
||||
- os: ubuntu-22.04
|
||||
python: '3.13'
|
||||
# ARM64 Linux builds
|
||||
- os: ubuntu-24.04-arm
|
||||
python: '3.9'
|
||||
- os: ubuntu-24.04-arm
|
||||
python: '3.10'
|
||||
- os: ubuntu-24.04-arm
|
||||
@@ -54,6 +56,8 @@ jobs:
|
||||
python: '3.12'
|
||||
- os: ubuntu-24.04-arm
|
||||
python: '3.13'
|
||||
- os: macos-14
|
||||
python: '3.9'
|
||||
- os: macos-14
|
||||
python: '3.10'
|
||||
- os: macos-14
|
||||
@@ -62,6 +66,8 @@ jobs:
|
||||
python: '3.12'
|
||||
- os: macos-14
|
||||
python: '3.13'
|
||||
- os: macos-15
|
||||
python: '3.9'
|
||||
- os: macos-15
|
||||
python: '3.10'
|
||||
- os: macos-15
|
||||
@@ -70,24 +76,16 @@ jobs:
|
||||
python: '3.12'
|
||||
- os: macos-15
|
||||
python: '3.13'
|
||||
# Intel Mac builds (x86_64) - replaces deprecated macos-13
|
||||
# Note: Python 3.13 excluded - PyTorch has no wheels for macOS x86_64 + Python 3.13
|
||||
# (PyTorch <=2.4.1 lacks cp313, PyTorch >=2.5.0 dropped Intel Mac support)
|
||||
- os: macos-15-intel
|
||||
- os: macos-13
|
||||
python: '3.9'
|
||||
- os: macos-13
|
||||
python: '3.10'
|
||||
- os: macos-15-intel
|
||||
- os: macos-13
|
||||
python: '3.11'
|
||||
- os: macos-15-intel
|
||||
- os: macos-13
|
||||
python: '3.12'
|
||||
# macOS 26 (beta) - arm64
|
||||
- os: macos-26
|
||||
python: '3.10'
|
||||
- os: macos-26
|
||||
python: '3.11'
|
||||
- os: macos-26
|
||||
python: '3.12'
|
||||
- os: macos-26
|
||||
python: '3.13'
|
||||
# Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
|
||||
# (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
@@ -206,16 +204,13 @@ jobs:
|
||||
# Use system clang for better compatibility
|
||||
export CC=clang
|
||||
export CXX=clang++
|
||||
# Set deployment target based on runner
|
||||
# macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
|
||||
if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-14* ]]; then
|
||||
# Homebrew libraries on each macOS version require matching minimum version
|
||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=13.0
|
||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == macos-15* ]]; then
|
||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-26* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=26.0
|
||||
fi
|
||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||
else
|
||||
@@ -229,16 +224,14 @@ jobs:
|
||||
# Use system clang for better compatibility
|
||||
export CC=clang
|
||||
export CXX=clang++
|
||||
# Set deployment target based on runner
|
||||
# macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
|
||||
if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-14* ]]; then
|
||||
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
|
||||
# But Homebrew libraries on each macOS version require matching minimum version
|
||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == macos-15* ]]; then
|
||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-26* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=26.0
|
||||
fi
|
||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||
else
|
||||
@@ -276,19 +269,16 @@ jobs:
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Determine deployment target based on runner OS
|
||||
# macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
|
||||
if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
|
||||
HNSW_TARGET="15.0"
|
||||
DISKANN_TARGET="15.0"
|
||||
elif [[ "${{ matrix.os }}" == macos-14* ]]; then
|
||||
# Must match the Homebrew libraries for each macOS version
|
||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||
HNSW_TARGET="13.0"
|
||||
DISKANN_TARGET="13.3"
|
||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||
HNSW_TARGET="14.0"
|
||||
DISKANN_TARGET="14.0"
|
||||
elif [[ "${{ matrix.os }}" == macos-15* ]]; then
|
||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||
HNSW_TARGET="15.0"
|
||||
DISKANN_TARGET="15.0"
|
||||
elif [[ "${{ matrix.os }}" == macos-26* ]]; then
|
||||
HNSW_TARGET="26.0"
|
||||
DISKANN_TARGET="26.0"
|
||||
fi
|
||||
|
||||
# Repair HNSW wheel
|
||||
@@ -344,15 +334,12 @@ jobs:
|
||||
PY_TAG=$($UV_PY -c "import sys; print(f'cp{sys.version_info[0]}{sys.version_info[1]}')")
|
||||
|
||||
if [[ "$RUNNER_OS" == "macOS" ]]; then
|
||||
# macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
|
||||
if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-14* ]]; then
|
||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == macos-15* ]]; then
|
||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-26* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=26.0
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
2
.github/workflows/link-check.yml
vendored
2
.github/workflows/link-check.yml
vendored
@@ -14,6 +14,6 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: lycheeverse/lychee-action@v2
|
||||
with:
|
||||
args: --no-progress --insecure --user-agent 'curl/7.68.0' --exclude '.*api\.star-history\.com.*' --accept 200,201,202,203,204,205,206,207,208,226,300,301,302,303,304,305,306,307,308,503 README.md docs/ apps/ examples/ benchmarks/
|
||||
args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -91,8 +91,7 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
||||
|
||||
*.meta.json
|
||||
*.passages.json
|
||||
*.npy
|
||||
*.db
|
||||
|
||||
batchtest.py
|
||||
tests/__pytest_cache__/
|
||||
tests/__pycache__/
|
||||
@@ -106,6 +105,3 @@ apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weavia
|
||||
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
|
||||
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
|
||||
!apps/multimodal/vision-based-pdf-multi-vector/fig/*
|
||||
|
||||
# AUR build directory (Arch Linux)
|
||||
paru-bin/
|
||||
|
||||
118
README.md
118
README.md
@@ -8,35 +8,19 @@
|
||||
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
|
||||
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
||||
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
|
||||
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q">
|
||||
<img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
|
||||
</a>
|
||||
<a href="assets/wechat_user_group.JPG" title="Join WeChat group">
|
||||
<img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group">
|
||||
</a>
|
||||
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q"><img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
|
||||
<a href="assets/wechat_user_group.JPG" title="Join WeChat group"><img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group"></a>
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
<a href="https://forms.gle/rDbZf864gMNxhpTq8">
|
||||
<img src="https://img.shields.io/badge/📣_Community_Survey-Help_Shape_v0.4-007ec6?style=for-the-badge&logo=google-forms&logoColor=white" alt="Take Survey">
|
||||
</a>
|
||||
<p>
|
||||
We track <b>zero telemetry</b>. This survey is the ONLY way to tell us if you want <br>
|
||||
<b>GPU Acceleration</b> or <b>More Integrations</b> next.<br>
|
||||
👉 <a href="https://forms.gle/rDbZf864gMNxhpTq8"><b>Click here to cast your vote (2 mins)</b></a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
||||
The smallest vector index in the world. RAG Everything with LEANN!
|
||||
</h2>
|
||||
|
||||
LEANN is an innovative vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
||||
|
||||
|
||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#slack-messages-search-your-team-conversations), [Twitter](#-twitter-bookmarks-your-personal-tweet-library)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#slack-messages-search-your-team-conversations), [Twitter](#twitter-bookmarks-your-personal-tweet-library)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
|
||||
|
||||
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
|
||||
@@ -201,7 +185,7 @@ LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`,
|
||||
|
||||
#### LLM Backend
|
||||
|
||||
LEANN supports many LLM providers for text generation (HuggingFace, Ollama, Anthropic, and Any OpenAI compatible API).
|
||||
LEANN supports many LLM providers for text generation (HuggingFace, Ollama, and Any OpenAI compatible API).
|
||||
|
||||
|
||||
<details>
|
||||
@@ -269,7 +253,6 @@ Below is a list of base URLs for common providers to get you started.
|
||||
| **SiliconFlow** | `https://api.siliconflow.cn/v1` |
|
||||
| **Zhipu (BigModel)** | `https://open.bigmodel.cn/api/paas/v4/` |
|
||||
| **Mistral AI** | `https://api.mistral.ai/v1` |
|
||||
| **Anthropic** | `https://api.anthropic.com/v1` |
|
||||
|
||||
|
||||
|
||||
@@ -329,7 +312,7 @@ All RAG examples share these common parameters. **Interactive mode** is availabl
|
||||
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
|
||||
|
||||
# LLM Parameters (Text generation models)
|
||||
--llm TYPE # LLM backend: openai, ollama, hf, or anthropic (default: openai)
|
||||
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
|
||||
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
|
||||
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
|
||||
|
||||
@@ -392,54 +375,6 @@ python -m apps.code_rag --repo-dir "./my_codebase" --query "How does authenticat
|
||||
|
||||
</details>
|
||||
|
||||
### 🎨 ColQwen: Multimodal PDF Retrieval with Vision-Language Models
|
||||
|
||||
Search through PDFs using both text and visual understanding with ColQwen2/ColPali models. Perfect for research papers, technical documents, and any PDFs with complex layouts, figures, or diagrams.
|
||||
|
||||
> **🍎 Mac Users**: ColQwen is optimized for Apple Silicon with MPS acceleration for faster inference!
|
||||
|
||||
```bash
|
||||
# Build index from PDFs
|
||||
python -m apps.colqwen_rag build --pdfs ./my_papers/ --index research_papers
|
||||
|
||||
# Search with text queries
|
||||
python -m apps.colqwen_rag search research_papers "How does attention mechanism work?"
|
||||
|
||||
# Interactive Q&A
|
||||
python -m apps.colqwen_rag ask research_papers --interactive
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: ColQwen Setup & Usage</strong></summary>
|
||||
|
||||
#### Prerequisites
|
||||
```bash
|
||||
# Install dependencies
|
||||
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
|
||||
brew install poppler # macOS only, for PDF processing
|
||||
```
|
||||
|
||||
#### Build Index
|
||||
```bash
|
||||
python -m apps.colqwen_rag build \
|
||||
--pdfs ./pdf_directory/ \
|
||||
--index my_index \
|
||||
--model colqwen2 # or colpali
|
||||
```
|
||||
|
||||
#### Search
|
||||
```bash
|
||||
python -m apps.colqwen_rag search my_index "your question here" --top-k 5
|
||||
```
|
||||
|
||||
#### Models
|
||||
- **ColQwen2** (`colqwen2`): Latest vision-language model with improved performance
|
||||
- **ColPali** (`colpali`): Proven multimodal retriever
|
||||
|
||||
For detailed usage, see the [ColQwen Guide](docs/COLQWEN_GUIDE.md).
|
||||
|
||||
</details>
|
||||
|
||||
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
|
||||
|
||||
> **Note:** The examples below currently support macOS only. Windows support coming soon.
|
||||
@@ -842,7 +777,7 @@ Once your iMessage conversations are indexed, you can search with queries like:
|
||||
|
||||
### MCP Integration: RAG on Live Data from Any Platform
|
||||
|
||||
Connect to live data sources through the Model Context Protocol (MCP). LEANN now supports real-time RAG on platforms like Slack, Twitter, and more through standardized MCP servers.
|
||||
**NEW!** Connect to live data sources through the Model Context Protocol (MCP). LEANN now supports real-time RAG on platforms like Slack, Twitter, and more through standardized MCP servers.
|
||||
|
||||
**Key Benefits:**
|
||||
- **Live Data Access**: Fetch real-time data without manual exports
|
||||
@@ -850,7 +785,8 @@ Connect to live data sources through the Model Context Protocol (MCP). LEANN now
|
||||
- **Easy Extension**: Add new platforms with minimal code
|
||||
- **Secure Access**: MCP servers handle authentication
|
||||
|
||||
#### 💬 Slack Messages: Search Your Team Conversations
|
||||
<details>
|
||||
<summary><strong>Slack Messages: Search Your Team Conversations</strong></summary>
|
||||
|
||||
Transform your Slack workspace into a searchable knowledge base! Find discussions, decisions, and shared knowledge across all your channels.
|
||||
|
||||
@@ -866,17 +802,18 @@ python -m apps.slack_rag \
|
||||
--query "What did we decide about the product launch?"
|
||||
```
|
||||
|
||||
**📖 Comprehensive Setup Guide**: For detailed setup instructions, troubleshooting common issues (like "users cache is not ready yet"), and advanced configuration options, see our [**Slack Setup Guide**](docs/slack-setup-guide.md).
|
||||
|
||||
**Quick Setup:**
|
||||
**Setup Requirements:**
|
||||
1. Install a Slack MCP server (e.g., `npm install -g slack-mcp-server`)
|
||||
2. Create a Slack App and get API credentials (see detailed guide above)
|
||||
3. Set environment variables:
|
||||
2. Create a Slack App and get API credentials:
|
||||
- Go to [api.slack.com/apps](https://api.slack.com/apps) and create a new app
|
||||
- Under "OAuth & Permissions", add these Bot Token Scopes: `channels:read`, `channels:history`, `groups:read`, `groups:history`, `im:read`, `im:history`, `mpim:read`, `mpim:history`
|
||||
- Install the app to your workspace and copy the "Bot User OAuth Token" (starts with `xoxb-`)
|
||||
- Under "App-Level Tokens", create a token with `connections:write` scope (starts with `xapp-`)
|
||||
```bash
|
||||
export SLACK_BOT_TOKEN="xoxb-your-bot-token"
|
||||
export SLACK_APP_TOKEN="xapp-your-app-token" # Optional
|
||||
export SLACK_APP_TOKEN="xapp-your-app-token"
|
||||
```
|
||||
4. Test connection with `--test-connection` flag
|
||||
3. Test connection with `--test-connection` flag
|
||||
|
||||
**Arguments:**
|
||||
- `--mcp-server`: Command to start the Slack MCP server
|
||||
@@ -884,10 +821,11 @@ python -m apps.slack_rag \
|
||||
- `--channels`: Specific channels to index (optional)
|
||||
- `--concatenate-conversations`: Group messages by channel (default: true)
|
||||
- `--max-messages-per-channel`: Limit messages per channel (default: 100)
|
||||
- `--max-retries`: Maximum retries for cache sync issues (default: 5)
|
||||
- `--retry-delay`: Initial delay between retries in seconds (default: 2.0)
|
||||
|
||||
#### 🐦 Twitter Bookmarks: Your Personal Tweet Library
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>Twitter Bookmarks: Your Personal Tweet Library</strong></summary>
|
||||
|
||||
Search through your Twitter bookmarks! Find that perfect article, thread, or insight you saved for later.
|
||||
|
||||
@@ -942,6 +880,8 @@ python -m apps.twitter_rag \
|
||||
- "What Python tutorials did I save?"
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>🔧 Using MCP with CLI Commands</strong></summary>
|
||||
|
||||
**Want to use MCP data with regular LEANN CLI?** You can combine MCP apps with CLI commands:
|
||||
@@ -987,7 +927,7 @@ Want to add support for other platforms? LEANN's MCP integration is designed for
|
||||
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
||||
|
||||
<details>
|
||||
<summary><strong>AST‑Aware Code Chunking</strong></summary>
|
||||
<summary><strong>NEW!! AST‑Aware Code Chunking</strong></summary>
|
||||
|
||||
LEANN features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript, improving code understanding compared to text-based chunking.
|
||||
|
||||
@@ -1106,10 +1046,10 @@ Options:
|
||||
leann ask INDEX_NAME [OPTIONS]
|
||||
|
||||
Options:
|
||||
--llm {ollama,openai,hf,anthropic} LLM provider (default: ollama)
|
||||
--model MODEL Model name (default: qwen3:8b)
|
||||
--interactive Interactive chat mode
|
||||
--top-k N Retrieval count (default: 20)
|
||||
--llm {ollama,openai,hf} LLM provider (default: ollama)
|
||||
--model MODEL Model name (default: qwen3:8b)
|
||||
--interactive Interactive chat mode
|
||||
--top-k N Retrieval count (default: 20)
|
||||
```
|
||||
|
||||
**List Command:**
|
||||
@@ -1274,7 +1214,3 @@ This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.ed
|
||||
<p align="center">
|
||||
Made with ❤️ by the Leann team
|
||||
</p>
|
||||
|
||||
## 🤖 Explore LEANN with AI
|
||||
|
||||
LEANN is indexed on [DeepWiki](https://deepwiki.com/yichuan-w/LEANN), so you can ask questions to LLMs using Deep Research to explore the codebase and get help to add new features.
|
||||
|
||||
@@ -6,43 +6,12 @@ Provides common parameters and functionality for all RAG examples.
|
||||
import argparse
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
|
||||
# Optional import: older PyPI builds may not include interactive_utils
|
||||
try:
|
||||
from leann.interactive_utils import create_rag_session
|
||||
except ImportError:
|
||||
|
||||
def create_rag_session(app_name: str, data_description: str):
|
||||
class _SimpleSession:
|
||||
def run_interactive_loop(self, handler):
|
||||
print(f"Interactive session for {app_name}: {data_description}")
|
||||
print("Interactive mode not available in this build")
|
||||
|
||||
return _SimpleSession()
|
||||
|
||||
|
||||
from leann.registry import register_project_directory
|
||||
|
||||
# Optional import: older PyPI builds may not include settings
|
||||
try:
|
||||
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
except ImportError:
|
||||
# Minimal fallbacks if settings helpers are unavailable
|
||||
import os
|
||||
|
||||
def resolve_ollama_host(value: str | None) -> str | None:
|
||||
return value or os.getenv("LEANN_OLLAMA_HOST") or os.getenv("OLLAMA_HOST")
|
||||
|
||||
def resolve_openai_api_key(value: str | None) -> str | None:
|
||||
return value or os.getenv("OPENAI_API_KEY")
|
||||
|
||||
def resolve_openai_base_url(value: str | None) -> str | None:
|
||||
return value or os.getenv("OPENAI_BASE_URL")
|
||||
|
||||
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
@@ -180,14 +149,14 @@ class BaseRAGExample(ABC):
|
||||
ast_group.add_argument(
|
||||
"--ast-chunk-size",
|
||||
type=int,
|
||||
default=300,
|
||||
help="Maximum CHARACTERS per AST chunk (default: 300). Final chunks may be larger due to overlap. For 512 token models: recommended 300 chars",
|
||||
default=512,
|
||||
help="Maximum characters per AST chunk (default: 512)",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--ast-chunk-overlap",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Overlap between AST chunks in CHARACTERS (default: 64). Added to chunk size, not included in it",
|
||||
help="Overlap between AST chunks (default: 64)",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--code-file-extensions",
|
||||
@@ -257,8 +226,8 @@ class BaseRAGExample(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def load_data(self, args) -> list[Union[str, dict[str, Any]]]:
|
||||
"""Load data from the source. Returns list of text chunks (strings or dicts with 'text' key)."""
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load data from the source. Returns list of text chunks."""
|
||||
pass
|
||||
|
||||
def get_llm_config(self, args) -> dict[str, Any]:
|
||||
@@ -282,8 +251,8 @@ class BaseRAGExample(ABC):
|
||||
|
||||
return config
|
||||
|
||||
async def build_index(self, args, texts: list[Union[str, dict[str, Any]]]) -> str:
|
||||
"""Build LEANN index from texts (accepts strings or dicts with 'text' key)."""
|
||||
async def build_index(self, args, texts: list[str]) -> str:
|
||||
"""Build LEANN index from texts."""
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
|
||||
print(f"\n[Building Index] Creating {self.name} index...")
|
||||
@@ -314,14 +283,8 @@ class BaseRAGExample(ABC):
|
||||
batch_size = 1000
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
for item in batch:
|
||||
# Handle both dict format (from create_text_chunks) and plain strings
|
||||
if isinstance(item, dict):
|
||||
text = item.get("text", "")
|
||||
metadata = item.get("metadata")
|
||||
builder.add_text(text, metadata)
|
||||
else:
|
||||
builder.add_text(item)
|
||||
for text in batch:
|
||||
builder.add_text(text)
|
||||
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
|
||||
|
||||
print("Building index structure...")
|
||||
@@ -344,26 +307,37 @@ class BaseRAGExample(ABC):
|
||||
complexity=args.search_complexity,
|
||||
)
|
||||
|
||||
# Create interactive session
|
||||
session = create_rag_session(
|
||||
app_name=self.name.lower().replace(" ", "_"), data_description=self.name
|
||||
)
|
||||
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
|
||||
print("Type 'quit' or 'exit' to stop.\n")
|
||||
|
||||
def handle_query(query: str):
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
while True:
|
||||
try:
|
||||
query = input("You: ").strip()
|
||||
if query.lower() in ["quit", "exit", "q"]:
|
||||
print("Goodbye!")
|
||||
break
|
||||
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.search_complexity,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"\nAssistant: {response}\n")
|
||||
if not query:
|
||||
continue
|
||||
|
||||
session.run_interactive_loop(handle_query)
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.search_complexity,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"\nAssistant: {response}\n")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
async def run_single_query(self, args, index_path: str, query: str):
|
||||
"""Run a single query against the index."""
|
||||
|
||||
@@ -12,7 +12,6 @@ from pathlib import Path
|
||||
try:
|
||||
from leann.chunking_utils import (
|
||||
CODE_EXTENSIONS,
|
||||
_traditional_chunks_as_dicts,
|
||||
create_ast_chunks,
|
||||
create_text_chunks,
|
||||
create_traditional_chunks,
|
||||
@@ -26,7 +25,6 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
|
||||
sys.path.insert(0, str(leann_src))
|
||||
from leann.chunking_utils import (
|
||||
CODE_EXTENSIONS,
|
||||
_traditional_chunks_as_dicts,
|
||||
create_ast_chunks,
|
||||
create_text_chunks,
|
||||
create_traditional_chunks,
|
||||
@@ -38,7 +36,6 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
|
||||
|
||||
__all__ = [
|
||||
"CODE_EXTENSIONS",
|
||||
"_traditional_chunks_as_dicts",
|
||||
"create_ast_chunks",
|
||||
"create_text_chunks",
|
||||
"create_traditional_chunks",
|
||||
|
||||
@@ -1,364 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
ColQwen RAG - Easy-to-use multimodal PDF retrieval with ColQwen2/ColPali
|
||||
|
||||
Usage:
|
||||
python -m apps.colqwen_rag build --pdfs ./my_pdfs/ --index my_index
|
||||
python -m apps.colqwen_rag search my_index "How does attention work?"
|
||||
python -m apps.colqwen_rag ask my_index --interactive
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, cast
|
||||
|
||||
# Add LEANN packages to path
|
||||
_repo_root = Path(__file__).resolve().parents[1]
|
||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||
if str(_leann_core_src) not in sys.path:
|
||||
sys.path.append(str(_leann_core_src))
|
||||
if str(_leann_hnsw_pkg) not in sys.path:
|
||||
sys.path.append(str(_leann_hnsw_pkg))
|
||||
|
||||
import torch # noqa: E402
|
||||
from colpali_engine import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor # noqa: E402
|
||||
from colpali_engine.utils.torch_utils import ListDataset # noqa: E402
|
||||
from pdf2image import convert_from_path # noqa: E402
|
||||
from PIL import Image # noqa: E402
|
||||
from torch.utils.data import DataLoader # noqa: E402
|
||||
from tqdm import tqdm # noqa: E402
|
||||
|
||||
# Import the existing multi-vector implementation
|
||||
sys.path.append(str(_repo_root / "apps" / "multimodal" / "vision-based-pdf-multi-vector"))
|
||||
from leann_multi_vector import LeannMultiVector # noqa: E402
|
||||
|
||||
|
||||
class ColQwenRAG:
|
||||
"""Easy-to-use ColQwen RAG system for multimodal PDF retrieval."""
|
||||
|
||||
def __init__(self, model_type: str = "colpali"):
|
||||
"""
|
||||
Initialize ColQwen RAG system.
|
||||
|
||||
Args:
|
||||
model_type: "colqwen2" or "colpali"
|
||||
"""
|
||||
self.model_type = model_type
|
||||
self.device = self._get_device()
|
||||
# Use float32 on MPS to avoid memory issues, float16 on CUDA, bfloat16 on CPU
|
||||
if self.device.type == "mps":
|
||||
self.dtype = torch.float32
|
||||
elif self.device.type == "cuda":
|
||||
self.dtype = torch.float16
|
||||
else:
|
||||
self.dtype = torch.bfloat16
|
||||
|
||||
print(f"🚀 Initializing {model_type.upper()} on {self.device} with {self.dtype}")
|
||||
|
||||
# Load model and processor with MPS-optimized settings
|
||||
try:
|
||||
if model_type == "colqwen2":
|
||||
self.model_name = "vidore/colqwen2-v1.0"
|
||||
if self.device.type == "mps":
|
||||
# For MPS, load on CPU first then move to avoid memory allocation issues
|
||||
self.model = ColQwen2.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.model = self.model.to(self.device)
|
||||
else:
|
||||
self.model = ColQwen2.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map=self.device,
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.processor = ColQwen2Processor.from_pretrained(self.model_name)
|
||||
else: # colpali
|
||||
self.model_name = "vidore/colpali-v1.2"
|
||||
if self.device.type == "mps":
|
||||
# For MPS, load on CPU first then move to avoid memory allocation issues
|
||||
self.model = ColPali.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.model = self.model.to(self.device)
|
||||
else:
|
||||
self.model = ColPali.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map=self.device,
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.processor = ColPaliProcessor.from_pretrained(self.model_name)
|
||||
except Exception as e:
|
||||
if "memory" in str(e).lower() or "offload" in str(e).lower():
|
||||
print(f"⚠️ Memory constraint on {self.device}, using CPU with optimizations...")
|
||||
self.device = torch.device("cpu")
|
||||
self.dtype = torch.float32
|
||||
|
||||
if model_type == "colqwen2":
|
||||
self.model = ColQwen2.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
else:
|
||||
self.model = ColPali.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
else:
|
||||
raise
|
||||
|
||||
def _get_device(self):
|
||||
"""Auto-select best available device."""
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def build_index(self, pdf_paths: list[str], index_name: str, pages_dir: Optional[str] = None):
|
||||
"""
|
||||
Build multimodal index from PDF files.
|
||||
|
||||
Args:
|
||||
pdf_paths: List of PDF file paths
|
||||
index_name: Name for the index
|
||||
pages_dir: Directory to save page images (optional)
|
||||
"""
|
||||
print(f"Building index '{index_name}' from {len(pdf_paths)} PDFs...")
|
||||
|
||||
# Convert PDFs to images
|
||||
all_images = []
|
||||
all_metadata = []
|
||||
|
||||
if pages_dir:
|
||||
os.makedirs(pages_dir, exist_ok=True)
|
||||
|
||||
for pdf_path in tqdm(pdf_paths, desc="Converting PDFs"):
|
||||
try:
|
||||
images = convert_from_path(pdf_path, dpi=150)
|
||||
pdf_name = Path(pdf_path).stem
|
||||
|
||||
for i, image in enumerate(images):
|
||||
# Save image if pages_dir specified
|
||||
if pages_dir:
|
||||
image_path = Path(pages_dir) / f"{pdf_name}_page_{i + 1}.png"
|
||||
image.save(image_path)
|
||||
|
||||
all_images.append(image)
|
||||
all_metadata.append(
|
||||
{
|
||||
"pdf_path": pdf_path,
|
||||
"pdf_name": pdf_name,
|
||||
"page_number": i + 1,
|
||||
"image_path": str(image_path) if pages_dir else None,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error processing {pdf_path}: {e}")
|
||||
continue
|
||||
|
||||
print(f"📄 Converted {len(all_images)} pages from {len(pdf_paths)} PDFs")
|
||||
print(f"All metadata: {all_metadata}")
|
||||
|
||||
# Generate embeddings
|
||||
print("🧠 Generating embeddings...")
|
||||
embeddings = self._embed_images(all_images)
|
||||
|
||||
# Build LEANN index
|
||||
print("🔍 Building LEANN index...")
|
||||
leann_mv = LeannMultiVector(
|
||||
index_path=index_name,
|
||||
dim=embeddings.shape[-1],
|
||||
embedding_model_name=self.model_type,
|
||||
)
|
||||
|
||||
# Create collection and insert data
|
||||
leann_mv.create_collection()
|
||||
for i, (embedding, metadata) in enumerate(zip(embeddings, all_metadata)):
|
||||
data = {
|
||||
"doc_id": i,
|
||||
"filepath": metadata.get("image_path", ""),
|
||||
"colbert_vecs": embedding.numpy(), # Convert tensor to numpy
|
||||
}
|
||||
leann_mv.insert(data)
|
||||
|
||||
# Build the index
|
||||
leann_mv.create_index()
|
||||
print(f"✅ Index '{index_name}' built successfully!")
|
||||
|
||||
return leann_mv
|
||||
|
||||
def search(self, index_name: str, query: str, top_k: int = 5):
|
||||
"""
|
||||
Search the index with a text query.
|
||||
|
||||
Args:
|
||||
index_name: Name of the index to search
|
||||
query: Text query
|
||||
top_k: Number of results to return
|
||||
"""
|
||||
print(f"🔍 Searching '{index_name}' for: '{query}'")
|
||||
|
||||
# Load index
|
||||
leann_mv = LeannMultiVector(
|
||||
index_path=index_name,
|
||||
dim=128, # Will be updated when loading
|
||||
embedding_model_name=self.model_type,
|
||||
)
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = self._embed_query(query)
|
||||
|
||||
# Search (returns list of (score, doc_id) tuples)
|
||||
search_results = leann_mv.search(query_embedding.numpy(), topk=top_k)
|
||||
|
||||
# Display results
|
||||
print(f"\n📋 Top {len(search_results)} results:")
|
||||
for i, (score, doc_id) in enumerate(search_results, 1):
|
||||
# Get metadata for this doc_id (we need to load the metadata)
|
||||
print(f"{i}. Score: {score:.3f} | Doc ID: {doc_id}")
|
||||
|
||||
return search_results
|
||||
|
||||
def ask(self, index_name: str, interactive: bool = False):
|
||||
"""
|
||||
Interactive Q&A with the indexed documents.
|
||||
|
||||
Args:
|
||||
index_name: Name of the index to query
|
||||
interactive: Whether to run in interactive mode
|
||||
"""
|
||||
print(f"💬 ColQwen Chat with '{index_name}'")
|
||||
|
||||
if interactive:
|
||||
print("Type 'quit' to exit, 'help' for commands")
|
||||
while True:
|
||||
try:
|
||||
query = input("\n🤔 Your question: ").strip()
|
||||
if query.lower() in ["quit", "exit", "q"]:
|
||||
break
|
||||
elif query.lower() == "help":
|
||||
print("Commands: quit/exit/q (exit), help (this message)")
|
||||
continue
|
||||
elif not query:
|
||||
continue
|
||||
|
||||
self.search(index_name, query, top_k=3)
|
||||
|
||||
# TODO: Add answer generation with Qwen-VL
|
||||
print("\n💡 For detailed answers, we can integrate Qwen-VL here!")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Goodbye!")
|
||||
break
|
||||
else:
|
||||
query = input("🤔 Your question: ").strip()
|
||||
if query:
|
||||
self.search(index_name, query)
|
||||
|
||||
def _embed_images(self, images: list[Image.Image]) -> torch.Tensor:
|
||||
"""Generate embeddings for a list of images."""
|
||||
dataset = ListDataset(images)
|
||||
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x)
|
||||
|
||||
embeddings = []
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataloader, desc="Embedding images"):
|
||||
batch_images = cast(list, batch)
|
||||
batch_inputs = self.processor.process_images(batch_images).to(self.device)
|
||||
batch_embeddings = self.model(**batch_inputs)
|
||||
embeddings.append(batch_embeddings.cpu())
|
||||
|
||||
return torch.cat(embeddings, dim=0)
|
||||
|
||||
def _embed_query(self, query: str) -> torch.Tensor:
|
||||
"""Generate embedding for a text query."""
|
||||
with torch.no_grad():
|
||||
query_inputs = self.processor.process_queries([query]).to(self.device)
|
||||
query_embedding = self.model(**query_inputs)
|
||||
return query_embedding.cpu()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ColQwen RAG - Easy multimodal PDF retrieval")
|
||||
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
||||
|
||||
# Build command
|
||||
build_parser = subparsers.add_parser("build", help="Build index from PDFs")
|
||||
build_parser.add_argument("--pdfs", required=True, help="Directory containing PDF files")
|
||||
build_parser.add_argument("--index", required=True, help="Index name")
|
||||
build_parser.add_argument(
|
||||
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||
)
|
||||
build_parser.add_argument("--pages-dir", help="Directory to save page images")
|
||||
|
||||
# Search command
|
||||
search_parser = subparsers.add_parser("search", help="Search the index")
|
||||
search_parser.add_argument("index", help="Index name")
|
||||
search_parser.add_argument("query", help="Search query")
|
||||
search_parser.add_argument("--top-k", type=int, default=5, help="Number of results")
|
||||
search_parser.add_argument(
|
||||
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||
)
|
||||
|
||||
# Ask command
|
||||
ask_parser = subparsers.add_parser("ask", help="Interactive Q&A")
|
||||
ask_parser.add_argument("index", help="Index name")
|
||||
ask_parser.add_argument("--interactive", action="store_true", help="Interactive mode")
|
||||
ask_parser.add_argument(
|
||||
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.command:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
# Initialize ColQwen RAG
|
||||
if args.command == "build":
|
||||
colqwen = ColQwenRAG(args.model)
|
||||
|
||||
# Get PDF files
|
||||
pdf_dir = Path(args.pdfs)
|
||||
if pdf_dir.is_file() and pdf_dir.suffix.lower() == ".pdf":
|
||||
pdf_paths = [str(pdf_dir)]
|
||||
elif pdf_dir.is_dir():
|
||||
pdf_paths = [str(p) for p in pdf_dir.glob("*.pdf")]
|
||||
else:
|
||||
print(f"❌ Invalid PDF path: {args.pdfs}")
|
||||
return
|
||||
|
||||
if not pdf_paths:
|
||||
print(f"❌ No PDF files found in {args.pdfs}")
|
||||
return
|
||||
|
||||
colqwen.build_index(pdf_paths, args.index, args.pages_dir)
|
||||
|
||||
elif args.command == "search":
|
||||
colqwen = ColQwenRAG(args.model)
|
||||
colqwen.search(args.index, args.query, args.top_k)
|
||||
|
||||
elif args.command == "ask":
|
||||
colqwen = ColQwenRAG(args.model)
|
||||
colqwen.ask(args.index, args.interactive)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -5,7 +5,6 @@ Supports PDF, TXT, MD, and other document formats.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
@@ -52,7 +51,7 @@ class DocumentRAG(BaseRAGExample):
|
||||
help="Enable AST-aware chunking for code files in the data directory",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[Union[str, dict[str, Any]]]:
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load documents and convert to text chunks."""
|
||||
print(f"Loading documents from: {args.data_dir}")
|
||||
if args.file_types:
|
||||
|
||||
@@ -1,218 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
CLIP Image RAG Application
|
||||
|
||||
This application enables RAG (Retrieval-Augmented Generation) on images using CLIP embeddings.
|
||||
You can index a directory of images and search them using text queries.
|
||||
|
||||
Usage:
|
||||
python -m apps.image_rag --image-dir ./my_images/ --query "a sunset over mountains"
|
||||
python -m apps.image_rag --image-dir ./my_images/ --interactive
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pickle
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
|
||||
from apps.base_rag_example import BaseRAGExample
|
||||
|
||||
|
||||
class ImageRAG(BaseRAGExample):
|
||||
"""
|
||||
RAG application for images using CLIP embeddings.
|
||||
|
||||
This class provides a complete RAG pipeline for image data, including
|
||||
CLIP embedding generation, indexing, and text-based image search.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Image RAG",
|
||||
description="RAG application for images using CLIP embeddings",
|
||||
default_index_name="image_index",
|
||||
)
|
||||
# Override default embedding model to use CLIP
|
||||
self.embedding_model_default = "clip-ViT-L-14"
|
||||
self.embedding_mode_default = "sentence-transformers"
|
||||
self._image_data: list[dict] = []
|
||||
|
||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||
"""Add image-specific arguments."""
|
||||
image_group = parser.add_argument_group("Image Parameters")
|
||||
image_group.add_argument(
|
||||
"--image-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Directory containing images to index",
|
||||
)
|
||||
image_group.add_argument(
|
||||
"--image-extensions",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=[".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"],
|
||||
help="Image file extensions to process (default: .jpg .jpeg .png .gif .bmp .webp)",
|
||||
)
|
||||
image_group.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Batch size for CLIP embedding generation (default: 32)",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load images, generate CLIP embeddings, and return text descriptions."""
|
||||
self._image_data = self._load_images_and_embeddings(args)
|
||||
return [entry["text"] for entry in self._image_data]
|
||||
|
||||
def _load_images_and_embeddings(self, args) -> list[dict]:
|
||||
"""Helper to process images and produce embeddings/metadata."""
|
||||
image_dir = Path(args.image_dir)
|
||||
if not image_dir.exists():
|
||||
raise ValueError(f"Image directory does not exist: {image_dir}")
|
||||
|
||||
print(f"📸 Loading images from {image_dir}...")
|
||||
|
||||
# Find all image files
|
||||
image_files = []
|
||||
for ext in args.image_extensions:
|
||||
image_files.extend(image_dir.rglob(f"*{ext}"))
|
||||
image_files.extend(image_dir.rglob(f"*{ext.upper()}"))
|
||||
|
||||
if not image_files:
|
||||
raise ValueError(
|
||||
f"No images found in {image_dir} with extensions {args.image_extensions}"
|
||||
)
|
||||
|
||||
print(f"✅ Found {len(image_files)} images")
|
||||
|
||||
# Limit if max_items is set
|
||||
if args.max_items > 0:
|
||||
image_files = image_files[: args.max_items]
|
||||
print(f"📊 Processing {len(image_files)} images (limited by --max-items)")
|
||||
|
||||
# Load CLIP model
|
||||
print("🔍 Loading CLIP model...")
|
||||
model = SentenceTransformer(self.embedding_model_default)
|
||||
|
||||
# Process images and generate embeddings
|
||||
print("🖼️ Processing images and generating embeddings...")
|
||||
image_data = []
|
||||
batch_images = []
|
||||
batch_paths = []
|
||||
|
||||
for image_path in tqdm(image_files, desc="Processing images"):
|
||||
try:
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
batch_images.append(image)
|
||||
batch_paths.append(image_path)
|
||||
|
||||
# Process in batches
|
||||
if len(batch_images) >= args.batch_size:
|
||||
embeddings = model.encode(
|
||||
batch_images,
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=True,
|
||||
batch_size=args.batch_size,
|
||||
show_progress_bar=False,
|
||||
)
|
||||
|
||||
for img_path, embedding in zip(batch_paths, embeddings):
|
||||
image_data.append(
|
||||
{
|
||||
"text": f"Image: {img_path.name}\nPath: {img_path}",
|
||||
"metadata": {
|
||||
"image_path": str(img_path),
|
||||
"image_name": img_path.name,
|
||||
"image_dir": str(image_dir),
|
||||
},
|
||||
"embedding": embedding.astype(np.float32),
|
||||
}
|
||||
)
|
||||
|
||||
batch_images = []
|
||||
batch_paths = []
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to process {image_path}: {e}")
|
||||
continue
|
||||
|
||||
# Process remaining images
|
||||
if batch_images:
|
||||
embeddings = model.encode(
|
||||
batch_images,
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=True,
|
||||
batch_size=len(batch_images),
|
||||
show_progress_bar=False,
|
||||
)
|
||||
|
||||
for img_path, embedding in zip(batch_paths, embeddings):
|
||||
image_data.append(
|
||||
{
|
||||
"text": f"Image: {img_path.name}\nPath: {img_path}",
|
||||
"metadata": {
|
||||
"image_path": str(img_path),
|
||||
"image_name": img_path.name,
|
||||
"image_dir": str(image_dir),
|
||||
},
|
||||
"embedding": embedding.astype(np.float32),
|
||||
}
|
||||
)
|
||||
|
||||
print(f"✅ Processed {len(image_data)} images")
|
||||
return image_data
|
||||
|
||||
async def build_index(self, args, texts: list[str]) -> str:
|
||||
"""Build index using pre-computed CLIP embeddings."""
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
if not self._image_data or len(self._image_data) != len(texts):
|
||||
raise RuntimeError("No image data found. Make sure load_data() ran successfully.")
|
||||
|
||||
print("🔨 Building LEANN index with CLIP embeddings...")
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend_name,
|
||||
embedding_model=self.embedding_model_default,
|
||||
embedding_mode=self.embedding_mode_default,
|
||||
is_recompute=False,
|
||||
distance_metric="cosine",
|
||||
graph_degree=args.graph_degree,
|
||||
build_complexity=args.build_complexity,
|
||||
is_compact=not args.no_compact,
|
||||
)
|
||||
|
||||
for text, data in zip(texts, self._image_data):
|
||||
builder.add_text(text=text, metadata=data["metadata"])
|
||||
|
||||
ids = [str(i) for i in range(len(self._image_data))]
|
||||
embeddings = np.array([data["embedding"] for data in self._image_data], dtype=np.float32)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="wb", suffix=".pkl", delete=False) as f:
|
||||
pickle.dump((ids, embeddings), f)
|
||||
pkl_path = f.name
|
||||
|
||||
try:
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
builder.build_index_from_embeddings(index_path, pkl_path)
|
||||
print(f"✅ Index built successfully at {index_path}")
|
||||
return index_path
|
||||
finally:
|
||||
Path(pkl_path).unlink()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the image RAG application."""
|
||||
import asyncio
|
||||
|
||||
app = ImageRAG()
|
||||
asyncio.run(app.run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,132 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Simple test script to test colqwen2 forward pass with a single image."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the current directory to path to import leann_multi_vector
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
import torch
|
||||
from leann_multi_vector import _embed_images, _ensure_repo_paths_importable, _load_colvision
|
||||
from PIL import Image
|
||||
|
||||
# Ensure repo paths are importable
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# Set environment variable
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
def create_test_image():
|
||||
"""Create a simple test image."""
|
||||
# Create a simple RGB image (800x600)
|
||||
img = Image.new("RGB", (800, 600), color="white")
|
||||
return img
|
||||
|
||||
|
||||
def load_test_image_from_file():
|
||||
"""Try to load an image from the indexes directory if available."""
|
||||
# Try to find an existing image in the indexes directory
|
||||
indexes_dir = Path(__file__).parent / "indexes"
|
||||
|
||||
# Look for images in common locations
|
||||
possible_paths = [
|
||||
indexes_dir / "vidore_fastplaid" / "images",
|
||||
indexes_dir / "colvision_large.leann.images",
|
||||
indexes_dir / "colvision.leann.images",
|
||||
]
|
||||
|
||||
for img_dir in possible_paths:
|
||||
if img_dir.exists():
|
||||
# Find first image file
|
||||
for ext in [".png", ".jpg", ".jpeg"]:
|
||||
for img_file in img_dir.glob(f"*{ext}"):
|
||||
print(f"Loading test image from: {img_file}")
|
||||
return Image.open(img_file)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("Testing ColQwen2 Forward Pass")
|
||||
print("=" * 60)
|
||||
|
||||
# Step 1: Load or create test image
|
||||
print("\n[Step 1] Loading test image...")
|
||||
test_image = load_test_image_from_file()
|
||||
if test_image is None:
|
||||
print("No existing image found, creating a simple test image...")
|
||||
test_image = create_test_image()
|
||||
else:
|
||||
print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")
|
||||
|
||||
# Convert to RGB if needed
|
||||
if test_image.mode != "RGB":
|
||||
test_image = test_image.convert("RGB")
|
||||
print(f"✓ Converted to RGB: {test_image.size}")
|
||||
|
||||
# Step 2: Load model
|
||||
print("\n[Step 2] Loading ColQwen2 model...")
|
||||
try:
|
||||
model_name, model, processor, device_str, device, dtype = _load_colvision("colqwen2")
|
||||
print(f"✓ Model loaded: {model_name}")
|
||||
print(f"✓ Device: {device_str}, dtype: {dtype}")
|
||||
|
||||
# Print model info
|
||||
if hasattr(model, "device"):
|
||||
print(f"✓ Model device: {model.device}")
|
||||
if hasattr(model, "dtype"):
|
||||
print(f"✓ Model dtype: {model.dtype}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error loading model: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
# Step 3: Test forward pass
|
||||
print("\n[Step 3] Running forward pass...")
|
||||
try:
|
||||
# Use the _embed_images function which handles batching and forward pass
|
||||
images = [test_image]
|
||||
print(f"Processing {len(images)} image(s)...")
|
||||
|
||||
doc_vecs = _embed_images(model, processor, images)
|
||||
|
||||
print("✓ Forward pass completed!")
|
||||
print(f"✓ Number of embeddings: {len(doc_vecs)}")
|
||||
|
||||
if len(doc_vecs) > 0:
|
||||
emb = doc_vecs[0]
|
||||
print(f"✓ Embedding shape: {emb.shape}")
|
||||
print(f"✓ Embedding dtype: {emb.dtype}")
|
||||
print("✓ Embedding stats:")
|
||||
print(f" - Min: {emb.min().item():.4f}")
|
||||
print(f" - Max: {emb.max().item():.4f}")
|
||||
print(f" - Mean: {emb.mean().item():.4f}")
|
||||
print(f" - Std: {emb.std().item():.4f}")
|
||||
|
||||
# Check for NaN or Inf
|
||||
if torch.isnan(emb).any():
|
||||
print("⚠ Warning: Embedding contains NaN values!")
|
||||
if torch.isinf(emb).any():
|
||||
print("⚠ Warning: Embedding contains Inf values!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error during forward pass: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Test completed successfully!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,448 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Modular script to reproduce NDCG results for ViDoRe v1 benchmark.
|
||||
|
||||
This script uses the interface from leann_multi_vector.py to:
|
||||
1. Download ViDoRe v1 datasets
|
||||
2. Build indexes (LEANN or Fast-Plaid)
|
||||
3. Perform retrieval
|
||||
4. Evaluate using NDCG metrics
|
||||
|
||||
Usage:
|
||||
# Evaluate all ViDoRe v1 tasks
|
||||
python vidore_v1_benchmark.py --model colqwen2 --tasks all
|
||||
|
||||
# Evaluate specific task
|
||||
python vidore_v1_benchmark.py --model colqwen2 --task VidoreArxivQARetrieval
|
||||
|
||||
# Use Fast-Plaid index
|
||||
python vidore_v1_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
|
||||
|
||||
# Rebuild index
|
||||
python vidore_v1_benchmark.py --model colqwen2 --rebuild-index
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from leann_multi_vector import (
|
||||
ViDoReBenchmarkEvaluator,
|
||||
_ensure_repo_paths_importable,
|
||||
)
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# ViDoRe v1 task configurations
|
||||
# Prompts match MTEB task metadata prompts
|
||||
VIDORE_V1_TASKS = {
|
||||
"VidoreArxivQARetrieval": {
|
||||
"dataset_path": "vidore/arxivqa_test_subsampled_beir",
|
||||
"revision": "7d94d570960eac2408d3baa7a33f9de4822ae3e4",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreDocVQARetrieval": {
|
||||
"dataset_path": "vidore/docvqa_test_subsampled_beir",
|
||||
"revision": "162ba2fc1a8437eda8b6c37b240bc1c0f0deb092",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreInfoVQARetrieval": {
|
||||
"dataset_path": "vidore/infovqa_test_subsampled_beir",
|
||||
"revision": "b802cc5fd6c605df2d673a963667d74881d2c9a4",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreTabfquadRetrieval": {
|
||||
"dataset_path": "vidore/tabfquad_test_subsampled_beir",
|
||||
"revision": "61a2224bcd29b7b261a4892ff4c8bea353527a31",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreTatdqaRetrieval": {
|
||||
"dataset_path": "vidore/tatdqa_test_beir",
|
||||
"revision": "5feb5630fdff4d8d189ffedb2dba56862fdd45c0",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreShiftProjectRetrieval": {
|
||||
"dataset_path": "vidore/shiftproject_test_beir",
|
||||
"revision": "84a382e05c4473fed9cff2bbae95fe2379416117",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAAIRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_artificial_intelligence_test_beir",
|
||||
"revision": "2d9ebea5a1c6e9ef4a3b902a612f605dca11261c",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAEnergyRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_energy_test_beir",
|
||||
"revision": "9935aadbad5c8deec30910489db1b2c7133ae7a7",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAGovernmentReportsRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_government_reports_test_beir",
|
||||
"revision": "b4909afa930f81282fd20601e860668073ad02aa",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAHealthcareIndustryRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_healthcare_industry_test_beir",
|
||||
"revision": "f9e25d5b6e13e1ad9f5c3cce202565031b3ab164",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
}
|
||||
|
||||
# Task name aliases (short names -> full names)
|
||||
TASK_ALIASES = {
|
||||
"arxivqa": "VidoreArxivQARetrieval",
|
||||
"docvqa": "VidoreDocVQARetrieval",
|
||||
"infovqa": "VidoreInfoVQARetrieval",
|
||||
"tabfquad": "VidoreTabfquadRetrieval",
|
||||
"tatdqa": "VidoreTatdqaRetrieval",
|
||||
"shiftproject": "VidoreShiftProjectRetrieval",
|
||||
"syntheticdocqa_ai": "VidoreSyntheticDocQAAIRetrieval",
|
||||
"syntheticdocqa_energy": "VidoreSyntheticDocQAEnergyRetrieval",
|
||||
"syntheticdocqa_government": "VidoreSyntheticDocQAGovernmentReportsRetrieval",
|
||||
"syntheticdocqa_healthcare": "VidoreSyntheticDocQAHealthcareIndustryRetrieval",
|
||||
}
|
||||
|
||||
|
||||
def normalize_task_name(task_name: str) -> str:
|
||||
"""Normalize task name (handle aliases)."""
|
||||
task_name_lower = task_name.lower()
|
||||
if task_name in VIDORE_V1_TASKS:
|
||||
return task_name
|
||||
if task_name_lower in TASK_ALIASES:
|
||||
return TASK_ALIASES[task_name_lower]
|
||||
# Try partial match
|
||||
for alias, full_name in TASK_ALIASES.items():
|
||||
if alias in task_name_lower or task_name_lower in alias:
|
||||
return full_name
|
||||
return task_name
|
||||
|
||||
|
||||
def get_safe_model_name(model_name: str) -> str:
|
||||
"""Get a safe model name for use in file paths."""
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
# If it's a path, use basename or hash
|
||||
if os.path.exists(model_name) and os.path.isdir(model_name):
|
||||
# Use basename if it's reasonable, otherwise use hash
|
||||
basename = os.path.basename(model_name.rstrip("/"))
|
||||
if basename and len(basename) < 100 and not basename.startswith("."):
|
||||
return basename
|
||||
# Use hash for very long or problematic paths
|
||||
return hashlib.md5(model_name.encode()).hexdigest()[:16]
|
||||
# For HuggingFace model names, replace / with _
|
||||
return model_name.replace("/", "_").replace(":", "_")
|
||||
|
||||
|
||||
def load_vidore_v1_data(
|
||||
dataset_path: str,
|
||||
revision: Optional[str] = None,
|
||||
split: str = "test",
|
||||
):
|
||||
"""
|
||||
Load ViDoRe v1 dataset.
|
||||
|
||||
Returns:
|
||||
corpus: dict mapping corpus_id to PIL Image
|
||||
queries: dict mapping query_id to query text
|
||||
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
|
||||
"""
|
||||
print(f"Loading dataset: {dataset_path} (split={split})")
|
||||
|
||||
# Load queries
|
||||
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
|
||||
|
||||
queries = {}
|
||||
for row in query_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
queries[query_id] = row["query"]
|
||||
|
||||
# Load corpus (images)
|
||||
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
|
||||
|
||||
corpus = {}
|
||||
for row in corpus_ds:
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
# Extract image from the dataset row
|
||||
if "image" in row:
|
||||
corpus[corpus_id] = row["image"]
|
||||
elif "page_image" in row:
|
||||
corpus[corpus_id] = row["page_image"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No image field found in corpus. Available fields: {list(row.keys())}"
|
||||
)
|
||||
|
||||
# Load qrels (relevance judgments)
|
||||
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
|
||||
|
||||
qrels = {}
|
||||
for row in qrels_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
if query_id not in qrels:
|
||||
qrels[query_id] = {}
|
||||
qrels[query_id][corpus_id] = int(row["score"])
|
||||
|
||||
print(
|
||||
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||
)
|
||||
|
||||
# Filter qrels to only include queries that exist
|
||||
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
|
||||
|
||||
# Filter out queries without any relevant documents (matching MTEB behavior)
|
||||
# This is important for correct NDCG calculation
|
||||
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
|
||||
queries_filtered = {
|
||||
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
|
||||
}
|
||||
|
||||
print(
|
||||
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
|
||||
)
|
||||
|
||||
return corpus, queries_filtered, qrels_filtered
|
||||
|
||||
|
||||
def evaluate_task(
|
||||
task_name: str,
|
||||
model_name: str,
|
||||
index_path: str,
|
||||
use_fast_plaid: bool = False,
|
||||
fast_plaid_index_path: Optional[str] = None,
|
||||
rebuild_index: bool = False,
|
||||
top_k: int = 1000,
|
||||
first_stage_k: int = 500,
|
||||
k_values: Optional[list[int]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Evaluate a single ViDoRe v1 task.
|
||||
"""
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Evaluating task: {task_name}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Normalize task name (handle aliases)
|
||||
task_name = normalize_task_name(task_name)
|
||||
|
||||
# Get task config
|
||||
if task_name not in VIDORE_V1_TASKS:
|
||||
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
|
||||
|
||||
task_config = VIDORE_V1_TASKS[task_name]
|
||||
dataset_path = task_config["dataset_path"]
|
||||
revision = task_config["revision"]
|
||||
|
||||
# Load data
|
||||
corpus, queries, qrels = load_vidore_v1_data(
|
||||
dataset_path=dataset_path,
|
||||
revision=revision,
|
||||
split="test",
|
||||
)
|
||||
|
||||
# Initialize k_values if not provided
|
||||
if k_values is None:
|
||||
k_values = [1, 3, 5, 10, 20, 100, 1000]
|
||||
|
||||
# Check if we have any queries
|
||||
if len(queries) == 0:
|
||||
print(f"\nWarning: No queries found for task {task_name}. Skipping evaluation.")
|
||||
# Return zero scores
|
||||
scores = {}
|
||||
for k in k_values:
|
||||
scores[f"ndcg_at_{k}"] = 0.0
|
||||
scores[f"map_at_{k}"] = 0.0
|
||||
scores[f"recall_at_{k}"] = 0.0
|
||||
scores[f"precision_at_{k}"] = 0.0
|
||||
scores[f"mrr_at_{k}"] = 0.0
|
||||
return scores
|
||||
|
||||
# Initialize evaluator
|
||||
evaluator = ViDoReBenchmarkEvaluator(
|
||||
model_name=model_name,
|
||||
use_fast_plaid=use_fast_plaid,
|
||||
top_k=top_k,
|
||||
first_stage_k=first_stage_k,
|
||||
k_values=k_values,
|
||||
)
|
||||
|
||||
# Build or load index
|
||||
# Use safe model name for index path (different models need different indexes)
|
||||
safe_model_name = get_safe_model_name(model_name)
|
||||
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
||||
if index_path_full is None:
|
||||
index_path_full = f"./indexes/{task_name}_{safe_model_name}"
|
||||
if use_fast_plaid:
|
||||
index_path_full = f"./indexes/{task_name}_{safe_model_name}_fastplaid"
|
||||
|
||||
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||
corpus=corpus,
|
||||
index_path=index_path_full,
|
||||
rebuild=rebuild_index,
|
||||
)
|
||||
|
||||
# Search queries
|
||||
task_prompt = task_config.get("prompt")
|
||||
results = evaluator.search_queries(
|
||||
queries=queries,
|
||||
corpus_ids=corpus_ids_ordered,
|
||||
index_or_retriever=index_or_retriever,
|
||||
fast_plaid_index_path=fast_plaid_index_path,
|
||||
task_prompt=task_prompt,
|
||||
)
|
||||
|
||||
# Evaluate
|
||||
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||
|
||||
# Print results
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Results for {task_name}:")
|
||||
print(f"{'=' * 80}")
|
||||
for metric, value in scores.items():
|
||||
if isinstance(value, (int, float)):
|
||||
print(f" {metric}: {value:.5f}")
|
||||
|
||||
# Save results
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
results_file = os.path.join(output_dir, f"{task_name}_results.json")
|
||||
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
|
||||
|
||||
with open(results_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print(f"\nSaved results to: {results_file}")
|
||||
|
||||
with open(scores_file, "w") as f:
|
||||
json.dump(scores, f, indent=2)
|
||||
print(f"Saved scores to: {scores_file}")
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Evaluate ViDoRe v1 benchmark using LEANN/Fast-Plaid indexing"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="colqwen2",
|
||||
help="Model to use: 'colqwen2', 'colpali', or path to a model directory (supports LoRA adapters)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specific task to evaluate (or 'all' for all tasks)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tasks",
|
||||
type=str,
|
||||
default="all",
|
||||
help="Tasks to evaluate: 'all' or comma-separated list",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to LEANN index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-fast-plaid",
|
||||
action="store_true",
|
||||
help="Use Fast-Plaid instead of LEANN",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-plaid-index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Fast-Plaid index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rebuild-index",
|
||||
action="store_true",
|
||||
help="Rebuild index even if it exists",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Top-k results to retrieve (MTEB default is max(k_values)=1000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first-stage-k",
|
||||
type=int,
|
||||
default=500,
|
||||
help="First stage k for LEANN search",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--k-values",
|
||||
type=str,
|
||||
default="1,3,5,10,20,100,1000",
|
||||
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="./vidore_v1_results",
|
||||
help="Output directory for results",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse k_values
|
||||
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
||||
|
||||
# Determine tasks to evaluate
|
||||
if args.task:
|
||||
tasks_to_eval = [normalize_task_name(args.task)]
|
||||
elif args.tasks.lower() == "all":
|
||||
tasks_to_eval = list(VIDORE_V1_TASKS.keys())
|
||||
else:
|
||||
tasks_to_eval = [normalize_task_name(t.strip()) for t in args.tasks.split(",")]
|
||||
|
||||
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||
|
||||
# Evaluate each task
|
||||
all_scores = {}
|
||||
for task_name in tasks_to_eval:
|
||||
try:
|
||||
scores = evaluate_task(
|
||||
task_name=task_name,
|
||||
model_name=args.model,
|
||||
index_path=args.index_path,
|
||||
use_fast_plaid=args.use_fast_plaid,
|
||||
fast_plaid_index_path=args.fast_plaid_index_path,
|
||||
rebuild_index=args.rebuild_index,
|
||||
top_k=args.top_k,
|
||||
first_stage_k=args.first_stage_k,
|
||||
k_values=k_values,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
all_scores[task_name] = scores
|
||||
except Exception as e:
|
||||
print(f"\nError evaluating {task_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# Print summary
|
||||
if all_scores:
|
||||
print(f"\n{'=' * 80}")
|
||||
print("SUMMARY")
|
||||
print(f"{'=' * 80}")
|
||||
for task_name, scores in all_scores.items():
|
||||
print(f"\n{task_name}:")
|
||||
# Print main metrics
|
||||
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
|
||||
if metric in scores:
|
||||
print(f" {metric}: {scores[metric]:.5f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,439 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Modular script to reproduce NDCG results for ViDoRe v2 benchmark.
|
||||
|
||||
This script uses the interface from leann_multi_vector.py to:
|
||||
1. Download ViDoRe v2 datasets
|
||||
2. Build indexes (LEANN or Fast-Plaid)
|
||||
3. Perform retrieval
|
||||
4. Evaluate using NDCG metrics
|
||||
|
||||
Usage:
|
||||
# Evaluate all ViDoRe v2 tasks
|
||||
python vidore_v2_benchmark.py --model colqwen2 --tasks all
|
||||
|
||||
# Evaluate specific task
|
||||
python vidore_v2_benchmark.py --model colqwen2 --task Vidore2ESGReportsRetrieval
|
||||
|
||||
# Use Fast-Plaid index
|
||||
python vidore_v2_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
|
||||
|
||||
# Rebuild index
|
||||
python vidore_v2_benchmark.py --model colqwen2 --rebuild-index
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from leann_multi_vector import (
|
||||
ViDoReBenchmarkEvaluator,
|
||||
_ensure_repo_paths_importable,
|
||||
)
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# Language name to dataset language field value mapping
|
||||
# Dataset uses ISO 639-3 + ISO 15924 format (e.g., "eng-Latn")
|
||||
LANGUAGE_MAPPING = {
|
||||
"english": "eng-Latn",
|
||||
"french": "fra-Latn",
|
||||
"spanish": "spa-Latn",
|
||||
"german": "deu-Latn",
|
||||
}
|
||||
|
||||
# ViDoRe v2 task configurations
|
||||
# Prompts match MTEB task metadata prompts
|
||||
VIDORE_V2_TASKS = {
|
||||
"Vidore2ESGReportsRetrieval": {
|
||||
"dataset_path": "vidore/esg_reports_v2",
|
||||
"revision": "0542c0d03da0ec1c8cbc517c8d78e7e95c75d3d3",
|
||||
"languages": ["french", "spanish", "english", "german"],
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"Vidore2EconomicsReportsRetrieval": {
|
||||
"dataset_path": "vidore/economics_reports_v2",
|
||||
"revision": "b3e3a04b07fbbaffe79be49dabf92f691fbca252",
|
||||
"languages": ["french", "spanish", "english", "german"],
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"Vidore2BioMedicalLecturesRetrieval": {
|
||||
"dataset_path": "vidore/biomedical_lectures_v2",
|
||||
"revision": "a29202f0da409034d651614d87cd8938d254e2ea",
|
||||
"languages": ["french", "spanish", "english", "german"],
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"Vidore2ESGReportsHLRetrieval": {
|
||||
"dataset_path": "vidore/esg_reports_human_labeled_v2",
|
||||
"revision": "6d467dedb09a75144ede1421747e47cf036857dd",
|
||||
# Note: This dataset doesn't have language filtering - all queries are English
|
||||
"languages": None, # No language filtering needed
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def load_vidore_v2_data(
|
||||
dataset_path: str,
|
||||
revision: Optional[str] = None,
|
||||
split: str = "test",
|
||||
language: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Load ViDoRe v2 dataset.
|
||||
|
||||
Returns:
|
||||
corpus: dict mapping corpus_id to PIL Image
|
||||
queries: dict mapping query_id to query text
|
||||
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
|
||||
"""
|
||||
print(f"Loading dataset: {dataset_path} (split={split}, language={language})")
|
||||
|
||||
# Load queries
|
||||
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
|
||||
|
||||
# Check if dataset has language field before filtering
|
||||
has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
|
||||
|
||||
if language and has_language_field:
|
||||
# Map language name to dataset language field value (e.g., "english" -> "eng-Latn")
|
||||
dataset_language = LANGUAGE_MAPPING.get(language, language)
|
||||
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language)
|
||||
# Check if filtering resulted in empty dataset
|
||||
if len(query_ds_filtered) == 0:
|
||||
print(
|
||||
f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}')."
|
||||
)
|
||||
# Try with original language value (dataset might use simple names like 'english')
|
||||
print(f"Trying with original language value '{language}'...")
|
||||
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language)
|
||||
if len(query_ds_filtered) == 0:
|
||||
# Try to get a sample to see actual language values
|
||||
try:
|
||||
sample_ds = load_dataset(
|
||||
dataset_path, "queries", split=split, revision=revision
|
||||
)
|
||||
if len(sample_ds) > 0 and "language" in sample_ds.column_names:
|
||||
sample_langs = set(sample_ds["language"])
|
||||
print(f"Available language values in dataset: {sample_langs}")
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
print(
|
||||
f"Found {len(query_ds_filtered)} queries using original language value '{language}'"
|
||||
)
|
||||
query_ds = query_ds_filtered
|
||||
|
||||
queries = {}
|
||||
for row in query_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
queries[query_id] = row["query"]
|
||||
|
||||
# Load corpus (images)
|
||||
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
|
||||
|
||||
corpus = {}
|
||||
for row in corpus_ds:
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
# Extract image from the dataset row
|
||||
if "image" in row:
|
||||
corpus[corpus_id] = row["image"]
|
||||
elif "page_image" in row:
|
||||
corpus[corpus_id] = row["page_image"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No image field found in corpus. Available fields: {list(row.keys())}"
|
||||
)
|
||||
|
||||
# Load qrels (relevance judgments)
|
||||
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
|
||||
|
||||
qrels = {}
|
||||
for row in qrels_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
if query_id not in qrels:
|
||||
qrels[query_id] = {}
|
||||
qrels[query_id][corpus_id] = int(row["score"])
|
||||
|
||||
print(
|
||||
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||
)
|
||||
|
||||
# Filter qrels to only include queries that exist
|
||||
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
|
||||
|
||||
# Filter out queries without any relevant documents (matching MTEB behavior)
|
||||
# This is important for correct NDCG calculation
|
||||
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
|
||||
queries_filtered = {
|
||||
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
|
||||
}
|
||||
|
||||
print(
|
||||
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
|
||||
)
|
||||
|
||||
return corpus, queries_filtered, qrels_filtered
|
||||
|
||||
|
||||
def evaluate_task(
|
||||
task_name: str,
|
||||
model_name: str,
|
||||
index_path: str,
|
||||
use_fast_plaid: bool = False,
|
||||
fast_plaid_index_path: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
rebuild_index: bool = False,
|
||||
top_k: int = 100,
|
||||
first_stage_k: int = 500,
|
||||
k_values: Optional[list[int]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Evaluate a single ViDoRe v2 task.
|
||||
"""
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Evaluating task: {task_name}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Get task config
|
||||
if task_name not in VIDORE_V2_TASKS:
|
||||
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")
|
||||
|
||||
task_config = VIDORE_V2_TASKS[task_name]
|
||||
dataset_path = task_config["dataset_path"]
|
||||
revision = task_config["revision"]
|
||||
|
||||
# Determine language
|
||||
if language is None:
|
||||
# Use first language if multiple available
|
||||
languages = task_config.get("languages")
|
||||
if languages is None:
|
||||
# Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval)
|
||||
language = None
|
||||
elif len(languages) == 1:
|
||||
language = languages[0]
|
||||
else:
|
||||
language = None
|
||||
|
||||
# Initialize k_values if not provided
|
||||
if k_values is None:
|
||||
k_values = [1, 3, 5, 10, 100]
|
||||
|
||||
# Load data
|
||||
corpus, queries, qrels = load_vidore_v2_data(
|
||||
dataset_path=dataset_path,
|
||||
revision=revision,
|
||||
split="test",
|
||||
language=language,
|
||||
)
|
||||
|
||||
# Check if we have any queries
|
||||
if len(queries) == 0:
|
||||
print(
|
||||
f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation."
|
||||
)
|
||||
# Return zero scores
|
||||
scores = {}
|
||||
for k in k_values:
|
||||
scores[f"ndcg_at_{k}"] = 0.0
|
||||
scores[f"map_at_{k}"] = 0.0
|
||||
scores[f"recall_at_{k}"] = 0.0
|
||||
scores[f"precision_at_{k}"] = 0.0
|
||||
scores[f"mrr_at_{k}"] = 0.0
|
||||
return scores
|
||||
|
||||
# Initialize evaluator
|
||||
evaluator = ViDoReBenchmarkEvaluator(
|
||||
model_name=model_name,
|
||||
use_fast_plaid=use_fast_plaid,
|
||||
top_k=top_k,
|
||||
first_stage_k=first_stage_k,
|
||||
k_values=k_values,
|
||||
)
|
||||
|
||||
# Build or load index
|
||||
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
||||
if index_path_full is None:
|
||||
index_path_full = f"./indexes/{task_name}_{model_name}"
|
||||
if use_fast_plaid:
|
||||
index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
|
||||
|
||||
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||
corpus=corpus,
|
||||
index_path=index_path_full,
|
||||
rebuild=rebuild_index,
|
||||
)
|
||||
|
||||
# Search queries
|
||||
task_prompt = task_config.get("prompt")
|
||||
results = evaluator.search_queries(
|
||||
queries=queries,
|
||||
corpus_ids=corpus_ids_ordered,
|
||||
index_or_retriever=index_or_retriever,
|
||||
fast_plaid_index_path=fast_plaid_index_path,
|
||||
task_prompt=task_prompt,
|
||||
)
|
||||
|
||||
# Evaluate
|
||||
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||
|
||||
# Print results
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Results for {task_name}:")
|
||||
print(f"{'=' * 80}")
|
||||
for metric, value in scores.items():
|
||||
if isinstance(value, (int, float)):
|
||||
print(f" {metric}: {value:.5f}")
|
||||
|
||||
# Save results
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
results_file = os.path.join(output_dir, f"{task_name}_results.json")
|
||||
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
|
||||
|
||||
with open(results_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print(f"\nSaved results to: {results_file}")
|
||||
|
||||
with open(scores_file, "w") as f:
|
||||
json.dump(scores, f, indent=2)
|
||||
print(f"Saved scores to: {scores_file}")
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Evaluate ViDoRe v2 benchmark using LEANN/Fast-Plaid indexing"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="colqwen2",
|
||||
choices=["colqwen2", "colpali"],
|
||||
help="Model to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specific task to evaluate (or 'all' for all tasks)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tasks",
|
||||
type=str,
|
||||
default="all",
|
||||
help="Tasks to evaluate: 'all' or comma-separated list",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to LEANN index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-fast-plaid",
|
||||
action="store_true",
|
||||
help="Use Fast-Plaid instead of LEANN",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-plaid-index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Fast-Plaid index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rebuild-index",
|
||||
action="store_true",
|
||||
help="Rebuild index even if it exists",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--language",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Language to evaluate (default: first available)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Top-k results to retrieve",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first-stage-k",
|
||||
type=int,
|
||||
default=500,
|
||||
help="First stage k for LEANN search",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--k-values",
|
||||
type=str,
|
||||
default="1,3,5,10,100",
|
||||
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="./vidore_v2_results",
|
||||
help="Output directory for results",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse k_values
|
||||
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
||||
|
||||
# Determine tasks to evaluate
|
||||
if args.task:
|
||||
tasks_to_eval = [args.task]
|
||||
elif args.tasks.lower() == "all":
|
||||
tasks_to_eval = list(VIDORE_V2_TASKS.keys())
|
||||
else:
|
||||
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
|
||||
|
||||
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||
|
||||
# Evaluate each task
|
||||
all_scores = {}
|
||||
for task_name in tasks_to_eval:
|
||||
try:
|
||||
scores = evaluate_task(
|
||||
task_name=task_name,
|
||||
model_name=args.model,
|
||||
index_path=args.index_path,
|
||||
use_fast_plaid=args.use_fast_plaid,
|
||||
fast_plaid_index_path=args.fast_plaid_index_path,
|
||||
language=args.language,
|
||||
rebuild_index=args.rebuild_index,
|
||||
top_k=args.top_k,
|
||||
first_stage_k=args.first_stage_k,
|
||||
k_values=k_values,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
all_scores[task_name] = scores
|
||||
except Exception as e:
|
||||
print(f"\nError evaluating {task_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# Print summary
|
||||
if all_scores:
|
||||
print(f"\n{'=' * 80}")
|
||||
print("SUMMARY")
|
||||
print(f"{'=' * 80}")
|
||||
for task_name, scores in all_scores.items():
|
||||
print(f"\n{task_name}:")
|
||||
# Print main metrics
|
||||
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
|
||||
if metric in scores:
|
||||
print(f" {metric}: {scores[metric]:.5f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,183 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
from leann import LeannSearcher
|
||||
|
||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||
|
||||
|
||||
class TimeParser:
|
||||
def __init__(self):
|
||||
# Main pattern: captures optional fuzzy modifier, number, unit, and optional "ago"
|
||||
self.pattern = r"(?:(around|about|roughly|approximately)\s+)?(\d+)\s+(hour|day|week|month|year)s?(?:\s+ago)?"
|
||||
|
||||
# Compile for performance
|
||||
self.regex = re.compile(self.pattern, re.IGNORECASE)
|
||||
|
||||
# Stop words to remove before regex parsing
|
||||
self.stop_words = {
|
||||
"in",
|
||||
"at",
|
||||
"of",
|
||||
"by",
|
||||
"as",
|
||||
"me",
|
||||
"the",
|
||||
"a",
|
||||
"an",
|
||||
"and",
|
||||
"any",
|
||||
"find",
|
||||
"search",
|
||||
"list",
|
||||
"ago",
|
||||
"back",
|
||||
"past",
|
||||
"earlier",
|
||||
}
|
||||
|
||||
def clean_text(self, text):
|
||||
"""Remove stop words from text"""
|
||||
words = text.split()
|
||||
cleaned = " ".join(word for word in words if word.lower() not in self.stop_words)
|
||||
return cleaned
|
||||
|
||||
def parse(self, text):
|
||||
"""Extract all time expressions from text"""
|
||||
# Clean text first
|
||||
cleaned_text = self.clean_text(text)
|
||||
|
||||
matches = []
|
||||
for match in self.regex.finditer(cleaned_text):
|
||||
fuzzy = match.group(1) # "around", "about", etc.
|
||||
number = int(match.group(2))
|
||||
unit = match.group(3).lower()
|
||||
|
||||
matches.append(
|
||||
{
|
||||
"full_match": match.group(0),
|
||||
"fuzzy": bool(fuzzy),
|
||||
"number": number,
|
||||
"unit": unit,
|
||||
"range": self.calculate_range(number, unit, bool(fuzzy)),
|
||||
}
|
||||
)
|
||||
|
||||
return matches
|
||||
|
||||
def calculate_range(self, number, unit, is_fuzzy):
|
||||
"""Convert to actual datetime range and return ISO format strings"""
|
||||
units = {
|
||||
"hour": timedelta(hours=number),
|
||||
"day": timedelta(days=number),
|
||||
"week": timedelta(weeks=number),
|
||||
"month": timedelta(days=number * 30),
|
||||
"year": timedelta(days=number * 365),
|
||||
}
|
||||
|
||||
delta = units[unit]
|
||||
now = datetime.now()
|
||||
target = now - delta
|
||||
|
||||
if is_fuzzy:
|
||||
buffer = delta * 0.2 # 20% buffer for fuzzy
|
||||
start = (target - buffer).isoformat()
|
||||
end = (target + buffer).isoformat()
|
||||
else:
|
||||
start = target.isoformat()
|
||||
end = now.isoformat()
|
||||
|
||||
return (start, end)
|
||||
|
||||
|
||||
def search_files(query, top_k=15):
|
||||
"""Search the index and return results"""
|
||||
# Parse time expressions
|
||||
parser = TimeParser()
|
||||
time_matches = parser.parse(query)
|
||||
|
||||
# Remove time expressions from query for semantic search
|
||||
clean_query = query
|
||||
if time_matches:
|
||||
for match in time_matches:
|
||||
clean_query = clean_query.replace(match["full_match"], "").strip()
|
||||
|
||||
# Check if clean_query is less than 4 characters
|
||||
if len(clean_query) < 4:
|
||||
print("Error: add more input for accurate results.")
|
||||
return
|
||||
|
||||
# Single query to vector DB
|
||||
searcher = LeannSearcher(INDEX_PATH)
|
||||
results = searcher.search(
|
||||
clean_query if clean_query else query, top_k=top_k, recompute_embeddings=False
|
||||
)
|
||||
|
||||
# Filter by time if time expression found
|
||||
if time_matches:
|
||||
time_range = time_matches[0]["range"] # Use first time expression
|
||||
start_time, end_time = time_range
|
||||
|
||||
filtered_results = []
|
||||
for result in results:
|
||||
# Access metadata attribute directly (not .get())
|
||||
metadata = result.metadata if hasattr(result, "metadata") else {}
|
||||
|
||||
if metadata:
|
||||
# Check modification date first, fall back to creation date
|
||||
date_str = metadata.get("modification_date") or metadata.get("creation_date")
|
||||
|
||||
if date_str:
|
||||
# Convert strings to datetime objects for proper comparison
|
||||
try:
|
||||
file_date = datetime.fromisoformat(date_str)
|
||||
start_dt = datetime.fromisoformat(start_time)
|
||||
end_dt = datetime.fromisoformat(end_time)
|
||||
|
||||
# Compare dates properly
|
||||
if start_dt <= file_date <= end_dt:
|
||||
filtered_results.append(result)
|
||||
except (ValueError, TypeError):
|
||||
# Handle invalid date formats
|
||||
print(f"Warning: Invalid date format in metadata: {date_str}")
|
||||
continue
|
||||
|
||||
results = filtered_results
|
||||
|
||||
# Print results
|
||||
print(f"\nSearch results for: '{query}'")
|
||||
if time_matches:
|
||||
print(
|
||||
f"Time filter: {time_matches[0]['number']} {time_matches[0]['unit']}(s) {'(fuzzy)' if time_matches[0]['fuzzy'] else ''}"
|
||||
)
|
||||
print(
|
||||
f"Date range: {time_matches[0]['range'][0][:10]} to {time_matches[0]['range'][1][:10]}"
|
||||
)
|
||||
print("-" * 80)
|
||||
|
||||
for i, result in enumerate(results, 1):
|
||||
print(f"\n[{i}] Score: {result.score:.4f}")
|
||||
print(f"Content: {result.text}")
|
||||
|
||||
# Show metadata if present
|
||||
metadata = result.metadata if hasattr(result, "metadata") else None
|
||||
if metadata:
|
||||
if "creation_date" in metadata:
|
||||
print(f"Created: {metadata['creation_date']}")
|
||||
if "modification_date" in metadata:
|
||||
print(f"Modified: {metadata['modification_date']}")
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print('Usage: python search_index.py "<search query>" [top_k]')
|
||||
sys.exit(1)
|
||||
|
||||
query = sys.argv[1]
|
||||
top_k = int(sys.argv[2]) if len(sys.argv) > 2 else 15
|
||||
|
||||
search_files(query, top_k)
|
||||
@@ -1,82 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from leann import LeannBuilder
|
||||
|
||||
|
||||
def process_json_items(json_file_path):
|
||||
"""Load and process JSON file with metadata items"""
|
||||
|
||||
with open(json_file_path, encoding="utf-8") as f:
|
||||
items = json.load(f)
|
||||
|
||||
# Guard against empty JSON
|
||||
if not items:
|
||||
print("⚠️ No items found in the JSON file. Exiting gracefully.")
|
||||
return
|
||||
|
||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||
builder = LeannBuilder(backend_name="hnsw", is_recompute=False)
|
||||
|
||||
total_items = len(items)
|
||||
items_added = 0
|
||||
print(f"Processing {total_items} items...")
|
||||
|
||||
for idx, item in enumerate(items):
|
||||
try:
|
||||
# Create embedding text sentence
|
||||
embedding_text = f"{item.get('Name', 'unknown')} located at {item.get('Path', 'unknown')} and size {item.get('Size', 'unknown')} bytes with content type {item.get('ContentType', 'unknown')} and kind {item.get('Kind', 'unknown')}"
|
||||
|
||||
# Prepare metadata with dates
|
||||
metadata = {}
|
||||
if "CreationDate" in item:
|
||||
metadata["creation_date"] = item["CreationDate"]
|
||||
if "ContentChangeDate" in item:
|
||||
metadata["modification_date"] = item["ContentChangeDate"]
|
||||
|
||||
# Add to builder
|
||||
builder.add_text(embedding_text, metadata=metadata)
|
||||
items_added += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n⚠️ Warning: Failed to process item {idx}: {e}")
|
||||
continue
|
||||
|
||||
# Show progress
|
||||
progress = (idx + 1) / total_items * 100
|
||||
sys.stdout.write(f"\rProgress: {idx + 1}/{total_items} ({progress:.1f}%)")
|
||||
sys.stdout.flush()
|
||||
|
||||
print() # New line after progress
|
||||
|
||||
# Guard against no successfully added items
|
||||
if items_added == 0:
|
||||
print("⚠️ No items were successfully added to the index. Exiting gracefully.")
|
||||
return
|
||||
|
||||
print(f"\n✅ Successfully processed {items_added}/{total_items} items")
|
||||
print("Building index...")
|
||||
|
||||
try:
|
||||
builder.build_index(INDEX_PATH)
|
||||
print(f"✓ Index saved to {INDEX_PATH}")
|
||||
except ValueError as e:
|
||||
if "No chunks added" in str(e):
|
||||
print("⚠️ No chunks were added to the builder. Index not created.")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python build_index.py <json_file>")
|
||||
sys.exit(1)
|
||||
|
||||
json_file = sys.argv[1]
|
||||
if not Path(json_file).exists():
|
||||
print(f"Error: File {json_file} not found")
|
||||
sys.exit(1)
|
||||
|
||||
process_json_items(json_file)
|
||||
@@ -1,265 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Spotlight Metadata Dumper for Vector DB
|
||||
Extracts only essential metadata for semantic search embeddings
|
||||
Output is optimized for vector database storage with minimal fields
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
# Check platform before importing macOS-specific modules
|
||||
if sys.platform != "darwin":
|
||||
print("This script requires macOS (uses Spotlight)")
|
||||
sys.exit(1)
|
||||
|
||||
from Foundation import NSDate, NSMetadataQuery, NSPredicate, NSRunLoop
|
||||
|
||||
# EDIT THIS LIST: Add or remove folders to search
|
||||
# Can be either:
|
||||
# - Folder names relative to home directory (e.g., "Desktop", "Downloads")
|
||||
# - Absolute paths (e.g., "/Applications", "/System/Library")
|
||||
SEARCH_FOLDERS = [
|
||||
"Desktop",
|
||||
"Downloads",
|
||||
"Documents",
|
||||
"Music",
|
||||
"Pictures",
|
||||
"Movies",
|
||||
# "Library", # Uncomment to include
|
||||
# "/Applications", # Absolute path example
|
||||
# "Code/Projects", # Subfolder example
|
||||
# Add any other folders here
|
||||
]
|
||||
|
||||
|
||||
def convert_to_serializable(obj):
|
||||
"""Convert NS objects to Python serializable types"""
|
||||
if obj is None:
|
||||
return None
|
||||
|
||||
# Handle NSDate
|
||||
if hasattr(obj, "timeIntervalSince1970"):
|
||||
return datetime.fromtimestamp(obj.timeIntervalSince1970()).isoformat()
|
||||
|
||||
# Handle NSArray
|
||||
if hasattr(obj, "count") and hasattr(obj, "objectAtIndex_"):
|
||||
return [convert_to_serializable(obj.objectAtIndex_(i)) for i in range(obj.count())]
|
||||
|
||||
# Convert to string
|
||||
try:
|
||||
return str(obj)
|
||||
except Exception:
|
||||
return repr(obj)
|
||||
|
||||
|
||||
def dump_spotlight_data(max_items=10, output_file="spotlight_dump.json"):
|
||||
"""
|
||||
Dump Spotlight data using public.item predicate
|
||||
"""
|
||||
# Build full paths from SEARCH_FOLDERS
|
||||
import os
|
||||
|
||||
home_dir = os.path.expanduser("~")
|
||||
search_paths = []
|
||||
|
||||
print("Search locations:")
|
||||
for folder in SEARCH_FOLDERS:
|
||||
# Check if it's an absolute path or relative
|
||||
if folder.startswith("/"):
|
||||
full_path = folder
|
||||
else:
|
||||
full_path = os.path.join(home_dir, folder)
|
||||
|
||||
if os.path.exists(full_path):
|
||||
search_paths.append(full_path)
|
||||
print(f" ✓ {full_path}")
|
||||
else:
|
||||
print(f" ✗ {full_path} (not found)")
|
||||
|
||||
if not search_paths:
|
||||
print("No valid search paths found!")
|
||||
return []
|
||||
|
||||
print(f"\nDumping {max_items} items from Spotlight (public.item)...")
|
||||
|
||||
# Create query with public.item predicate
|
||||
query = NSMetadataQuery.alloc().init()
|
||||
predicate = NSPredicate.predicateWithFormat_("kMDItemContentTypeTree CONTAINS 'public.item'")
|
||||
query.setPredicate_(predicate)
|
||||
|
||||
# Set search scopes to our specific folders
|
||||
query.setSearchScopes_(search_paths)
|
||||
|
||||
print("Starting query...")
|
||||
query.startQuery()
|
||||
|
||||
# Wait for gathering to complete
|
||||
run_loop = NSRunLoop.currentRunLoop()
|
||||
print("Gathering results...")
|
||||
|
||||
# Let it gather for a few seconds
|
||||
for i in range(50): # 5 seconds max
|
||||
run_loop.runMode_beforeDate_(
|
||||
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
|
||||
)
|
||||
# Check gathering status periodically
|
||||
if i % 10 == 0:
|
||||
current_count = query.resultCount()
|
||||
if current_count > 0:
|
||||
print(f" Found {current_count} items so far...")
|
||||
|
||||
# Continue while still gathering (up to 2 more seconds)
|
||||
timeout = NSDate.dateWithTimeIntervalSinceNow_(2.0)
|
||||
while query.isGathering() and timeout.timeIntervalSinceNow() > 0:
|
||||
run_loop.runMode_beforeDate_(
|
||||
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
|
||||
)
|
||||
|
||||
query.stopQuery()
|
||||
|
||||
total_results = query.resultCount()
|
||||
print(f"Found {total_results} total items")
|
||||
|
||||
if total_results == 0:
|
||||
print("No results found")
|
||||
return []
|
||||
|
||||
# Process items
|
||||
items_to_process = min(total_results, max_items)
|
||||
results = []
|
||||
|
||||
# ONLY relevant attributes for vector embeddings
|
||||
# These provide essential context for semantic search without bloat
|
||||
attributes = [
|
||||
"kMDItemPath", # Full path for file retrieval
|
||||
"kMDItemFSName", # Filename for display & embedding
|
||||
"kMDItemFSSize", # Size for filtering/ranking
|
||||
"kMDItemContentType", # File type for categorization
|
||||
"kMDItemKind", # Human-readable type for embedding
|
||||
"kMDItemFSCreationDate", # Temporal context
|
||||
"kMDItemFSContentChangeDate", # Recency for ranking
|
||||
]
|
||||
|
||||
print(f"Processing {items_to_process} items...")
|
||||
|
||||
for i in range(items_to_process):
|
||||
try:
|
||||
item = query.resultAtIndex_(i)
|
||||
metadata = {}
|
||||
|
||||
# Extract ONLY the relevant attributes
|
||||
for attr in attributes:
|
||||
try:
|
||||
value = item.valueForAttribute_(attr)
|
||||
if value is not None:
|
||||
# Keep the attribute name clean (remove kMDItem prefix for cleaner JSON)
|
||||
clean_key = attr.replace("kMDItem", "").replace("FS", "")
|
||||
metadata[clean_key] = convert_to_serializable(value)
|
||||
except (AttributeError, ValueError, TypeError):
|
||||
continue
|
||||
|
||||
# Only add if we have at least a path
|
||||
if metadata.get("Path"):
|
||||
results.append(metadata)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing item {i}: {e}")
|
||||
continue
|
||||
|
||||
# Save to JSON
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(results, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"\n✓ Saved {len(results)} items to {output_file}")
|
||||
|
||||
# Show summary
|
||||
print("\nSample items:")
|
||||
import os
|
||||
|
||||
home_dir = os.path.expanduser("~")
|
||||
|
||||
for i, item in enumerate(results[:3]):
|
||||
print(f"\n[Item {i + 1}]")
|
||||
print(f" Path: {item.get('Path', 'N/A')}")
|
||||
print(f" Name: {item.get('Name', 'N/A')}")
|
||||
print(f" Type: {item.get('ContentType', 'N/A')}")
|
||||
print(f" Kind: {item.get('Kind', 'N/A')}")
|
||||
|
||||
# Handle size properly
|
||||
size = item.get("Size")
|
||||
if size:
|
||||
try:
|
||||
size_int = int(size)
|
||||
if size_int > 1024 * 1024:
|
||||
print(f" Size: {size_int / (1024 * 1024):.2f} MB")
|
||||
elif size_int > 1024:
|
||||
print(f" Size: {size_int / 1024:.2f} KB")
|
||||
else:
|
||||
print(f" Size: {size_int} bytes")
|
||||
except (ValueError, TypeError):
|
||||
print(f" Size: {size}")
|
||||
|
||||
# Show dates
|
||||
if "CreationDate" in item:
|
||||
print(f" Created: {item['CreationDate']}")
|
||||
if "ContentChangeDate" in item:
|
||||
print(f" Modified: {item['ContentChangeDate']}")
|
||||
|
||||
# Count by type
|
||||
type_counts = {}
|
||||
for item in results:
|
||||
content_type = item.get("ContentType", "unknown")
|
||||
type_counts[content_type] = type_counts.get(content_type, 0) + 1
|
||||
|
||||
print(f"\nTotal items saved: {len(results)}")
|
||||
|
||||
if type_counts:
|
||||
print("\nTop content types:")
|
||||
for ct, count in sorted(type_counts.items(), key=lambda x: x[1], reverse=True)[:5]:
|
||||
print(f" {ct}: {count} items")
|
||||
|
||||
# Count by folder
|
||||
folder_counts = {}
|
||||
for item in results:
|
||||
path = item.get("Path", "")
|
||||
for folder in SEARCH_FOLDERS:
|
||||
# Build the full folder path
|
||||
if folder.startswith("/"):
|
||||
folder_path = folder
|
||||
else:
|
||||
folder_path = os.path.join(home_dir, folder)
|
||||
|
||||
if path.startswith(folder_path):
|
||||
folder_counts[folder] = folder_counts.get(folder, 0) + 1
|
||||
break
|
||||
|
||||
if folder_counts:
|
||||
print("\nItems by location:")
|
||||
for folder, count in sorted(folder_counts.items(), key=lambda x: x[1], reverse=True):
|
||||
print(f" {folder}: {count} items")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
# Parse arguments
|
||||
if len(sys.argv) > 1:
|
||||
try:
|
||||
max_items = int(sys.argv[1])
|
||||
except ValueError:
|
||||
print("Usage: python spot.py [number_of_items]")
|
||||
print("Default: 10 items")
|
||||
sys.exit(1)
|
||||
else:
|
||||
max_items = 10
|
||||
|
||||
output_file = sys.argv[2] if len(sys.argv) > 2 else "spotlight_dump.json"
|
||||
|
||||
# Run dump
|
||||
dump_spotlight_data(max_items=max_items, output_file=output_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -7,7 +7,6 @@ for indexing in LEANN. It supports various Slack MCP server implementations and
|
||||
flexible message processing options.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
@@ -30,8 +29,6 @@ class SlackMCPReader:
|
||||
workspace_name: Optional[str] = None,
|
||||
concatenate_conversations: bool = True,
|
||||
max_messages_per_conversation: int = 100,
|
||||
max_retries: int = 5,
|
||||
retry_delay: float = 2.0,
|
||||
):
|
||||
"""
|
||||
Initialize the Slack MCP Reader.
|
||||
@@ -41,15 +38,11 @@ class SlackMCPReader:
|
||||
workspace_name: Optional workspace name to filter messages
|
||||
concatenate_conversations: Whether to group messages by channel/thread
|
||||
max_messages_per_conversation: Maximum messages to include per conversation
|
||||
max_retries: Maximum number of retries for failed operations
|
||||
retry_delay: Initial delay between retries in seconds
|
||||
"""
|
||||
self.mcp_server_command = mcp_server_command
|
||||
self.workspace_name = workspace_name
|
||||
self.concatenate_conversations = concatenate_conversations
|
||||
self.max_messages_per_conversation = max_messages_per_conversation
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.mcp_process = None
|
||||
|
||||
async def start_mcp_server(self):
|
||||
@@ -117,73 +110,11 @@ class SlackMCPReader:
|
||||
|
||||
return response.get("result", {}).get("tools", [])
|
||||
|
||||
def _is_cache_sync_error(self, error: dict) -> bool:
|
||||
"""Check if the error is related to users cache not being ready."""
|
||||
if isinstance(error, dict):
|
||||
message = error.get("message", "").lower()
|
||||
return (
|
||||
"users cache is not ready" in message or "sync process is still running" in message
|
||||
)
|
||||
return False
|
||||
|
||||
async def _retry_with_backoff(self, func, *args, **kwargs):
|
||||
"""Retry a function with exponential backoff, especially for cache sync issues."""
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
|
||||
# Check if this is a cache sync error
|
||||
error_dict = {}
|
||||
if hasattr(e, "args") and e.args and isinstance(e.args[0], dict):
|
||||
error_dict = e.args[0]
|
||||
elif "Failed to fetch messages" in str(e):
|
||||
# Try to extract error from the exception message
|
||||
import re
|
||||
|
||||
match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
|
||||
if match:
|
||||
try:
|
||||
error_dict = ast.literal_eval(match.group(1))
|
||||
except (ValueError, SyntaxError):
|
||||
pass
|
||||
else:
|
||||
# Try alternative format
|
||||
match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
|
||||
if match:
|
||||
try:
|
||||
error_dict = ast.literal_eval(match.group(1))
|
||||
except (ValueError, SyntaxError):
|
||||
pass
|
||||
|
||||
if self._is_cache_sync_error(error_dict):
|
||||
if attempt < self.max_retries:
|
||||
delay = self.retry_delay * (2**attempt) # Exponential backoff
|
||||
logger.info(
|
||||
f"Cache sync not ready, waiting {delay:.1f}s before retry {attempt + 1}/{self.max_retries}"
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
logger.warning(
|
||||
f"Cache sync still not ready after {self.max_retries} retries, giving up"
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Not a cache sync error, don't retry
|
||||
break
|
||||
|
||||
# If we get here, all retries failed or it's not a retryable error
|
||||
raise last_exception
|
||||
|
||||
async def fetch_slack_messages(
|
||||
self, channel: Optional[str] = None, limit: int = 100
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch Slack messages using MCP tools with retry logic for cache sync issues.
|
||||
Fetch Slack messages using MCP tools.
|
||||
|
||||
Args:
|
||||
channel: Optional channel name to filter messages
|
||||
@@ -192,59 +123,32 @@ class SlackMCPReader:
|
||||
Returns:
|
||||
List of message dictionaries
|
||||
"""
|
||||
return await self._retry_with_backoff(self._fetch_slack_messages_impl, channel, limit)
|
||||
|
||||
async def _fetch_slack_messages_impl(
|
||||
self, channel: Optional[str] = None, limit: int = 100
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Internal implementation of fetch_slack_messages without retry logic.
|
||||
"""
|
||||
# This is a generic implementation - specific MCP servers may have different tool names
|
||||
# Common tool names might be: 'get_messages', 'list_messages', 'fetch_channel_history'
|
||||
|
||||
tools = await self.list_available_tools()
|
||||
logger.info(f"Available tools: {[tool.get('name') for tool in tools]}")
|
||||
message_tool = None
|
||||
|
||||
# Look for a tool that can fetch messages - prioritize conversations_history
|
||||
message_tool = None
|
||||
|
||||
# First, try to find conversations_history specifically
|
||||
# Look for a tool that can fetch messages
|
||||
for tool in tools:
|
||||
tool_name = tool.get("name", "").lower()
|
||||
if "conversations_history" in tool_name:
|
||||
if any(
|
||||
keyword in tool_name
|
||||
for keyword in ["message", "history", "channel", "conversation"]
|
||||
):
|
||||
message_tool = tool
|
||||
logger.info(f"Found conversations_history tool: {tool}")
|
||||
break
|
||||
|
||||
# If not found, look for other message-fetching tools
|
||||
if not message_tool:
|
||||
for tool in tools:
|
||||
tool_name = tool.get("name", "").lower()
|
||||
if any(
|
||||
keyword in tool_name
|
||||
for keyword in ["conversations_search", "message", "history"]
|
||||
):
|
||||
message_tool = tool
|
||||
break
|
||||
|
||||
if not message_tool:
|
||||
raise RuntimeError("No message fetching tool found in MCP server")
|
||||
|
||||
# Prepare tool call parameters
|
||||
tool_params = {"limit": "180d"} # Use 180 days to get older messages
|
||||
tool_params = {"limit": limit}
|
||||
if channel:
|
||||
# For conversations_history, use channel_id parameter
|
||||
if message_tool["name"] == "conversations_history":
|
||||
tool_params["channel_id"] = channel
|
||||
else:
|
||||
# Try common parameter names for channel specification
|
||||
for param_name in ["channel", "channel_id", "channel_name"]:
|
||||
tool_params[param_name] = channel
|
||||
break
|
||||
|
||||
logger.info(f"Tool parameters: {tool_params}")
|
||||
# Try common parameter names for channel specification
|
||||
for param_name in ["channel", "channel_id", "channel_name"]:
|
||||
tool_params[param_name] = channel
|
||||
break
|
||||
|
||||
fetch_request = {
|
||||
"jsonrpc": "2.0",
|
||||
@@ -266,8 +170,8 @@ class SlackMCPReader:
|
||||
try:
|
||||
messages = json.loads(content["text"])
|
||||
except json.JSONDecodeError:
|
||||
# If not JSON, try to parse as CSV format (Slack MCP server format)
|
||||
messages = self._parse_csv_messages(content["text"], channel)
|
||||
# If not JSON, treat as plain text
|
||||
messages = [{"text": content["text"], "channel": channel or "unknown"}]
|
||||
else:
|
||||
messages = result["content"]
|
||||
else:
|
||||
@@ -276,56 +180,6 @@ class SlackMCPReader:
|
||||
|
||||
return messages if isinstance(messages, list) else [messages]
|
||||
|
||||
def _parse_csv_messages(self, csv_text: str, channel: str) -> list[dict[str, Any]]:
|
||||
"""Parse CSV format messages from Slack MCP server."""
|
||||
import csv
|
||||
import io
|
||||
|
||||
messages = []
|
||||
try:
|
||||
# Split by lines and process each line as a CSV row
|
||||
lines = csv_text.strip().split("\n")
|
||||
if not lines:
|
||||
return messages
|
||||
|
||||
# Skip header line if it exists
|
||||
start_idx = 0
|
||||
if lines[0].startswith("MsgID,UserID,UserName"):
|
||||
start_idx = 1
|
||||
|
||||
for line in lines[start_idx:]:
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
# Parse CSV line
|
||||
reader = csv.reader(io.StringIO(line))
|
||||
try:
|
||||
row = next(reader)
|
||||
if len(row) >= 7: # Ensure we have enough columns
|
||||
message = {
|
||||
"ts": row[0],
|
||||
"user": row[1],
|
||||
"username": row[2],
|
||||
"real_name": row[3],
|
||||
"channel": row[4],
|
||||
"thread_ts": row[5],
|
||||
"text": row[6],
|
||||
"time": row[7] if len(row) > 7 else "",
|
||||
"reactions": row[8] if len(row) > 8 else "",
|
||||
"cursor": row[9] if len(row) > 9 else "",
|
||||
}
|
||||
messages.append(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse CSV line: {line[:100]}... Error: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse CSV messages: {e}")
|
||||
# Fallback: treat entire text as one message
|
||||
messages = [{"text": csv_text, "channel": channel or "unknown"}]
|
||||
|
||||
return messages
|
||||
|
||||
def _format_message(self, message: dict[str, Any]) -> str:
|
||||
"""Format a single message for indexing."""
|
||||
text = message.get("text", "")
|
||||
@@ -397,40 +251,6 @@ class SlackMCPReader:
|
||||
|
||||
return "\n".join(content_parts)
|
||||
|
||||
async def get_all_channels(self) -> list[str]:
|
||||
"""Get list of all available channels."""
|
||||
try:
|
||||
channels_list_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 4,
|
||||
"method": "tools/call",
|
||||
"params": {"name": "channels_list", "arguments": {}},
|
||||
}
|
||||
channels_response = await self.send_mcp_request(channels_list_request)
|
||||
if "result" in channels_response:
|
||||
result = channels_response["result"]
|
||||
if "content" in result and isinstance(result["content"], list):
|
||||
content = result["content"][0] if result["content"] else {}
|
||||
if "text" in content:
|
||||
# Parse the channels from the response
|
||||
channels = []
|
||||
lines = content["text"].split("\n")
|
||||
for line in lines:
|
||||
if line.strip() and ("#" in line or "C" in line[:10]):
|
||||
# Extract channel ID or name
|
||||
parts = line.split()
|
||||
for part in parts:
|
||||
if part.startswith("C") and len(part) > 5:
|
||||
channels.append(part)
|
||||
elif part.startswith("#"):
|
||||
channels.append(part[1:]) # Remove #
|
||||
logger.info(f"Found {len(channels)} channels: {channels}")
|
||||
return channels
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get channels list: {e}")
|
||||
return []
|
||||
|
||||
async def read_slack_data(self, channels: Optional[list[str]] = None) -> list[str]:
|
||||
"""
|
||||
Read Slack data and return formatted text chunks.
|
||||
@@ -467,33 +287,36 @@ class SlackMCPReader:
|
||||
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
|
||||
continue
|
||||
else:
|
||||
# Fetch from all available channels
|
||||
logger.info("Fetching from all available channels...")
|
||||
all_channels = await self.get_all_channels()
|
||||
# Fetch from all available channels/conversations
|
||||
# This is a simplified approach - real implementation would need to
|
||||
# discover available channels first
|
||||
try:
|
||||
messages = await self.fetch_slack_messages(limit=1000)
|
||||
if messages:
|
||||
# Group messages by channel if concatenating
|
||||
if self.concatenate_conversations:
|
||||
channel_messages = {}
|
||||
for message in messages:
|
||||
channel = message.get(
|
||||
"channel", message.get("channel_name", "general")
|
||||
)
|
||||
if channel not in channel_messages:
|
||||
channel_messages[channel] = []
|
||||
channel_messages[channel].append(message)
|
||||
|
||||
if not all_channels:
|
||||
# Fallback to common channel names if we can't get the list
|
||||
all_channels = ["general", "random", "announcements", "C0GN5BX0F"]
|
||||
logger.info(f"Using fallback channels: {all_channels}")
|
||||
|
||||
for channel in all_channels:
|
||||
try:
|
||||
logger.info(f"Searching channel: {channel}")
|
||||
messages = await self.fetch_slack_messages(channel=channel, limit=1000)
|
||||
if messages:
|
||||
if self.concatenate_conversations:
|
||||
text_content = self._create_concatenated_content(messages, channel)
|
||||
# Create concatenated content for each channel
|
||||
for channel, msgs in channel_messages.items():
|
||||
text_content = self._create_concatenated_content(msgs, channel)
|
||||
if text_content.strip():
|
||||
all_texts.append(text_content)
|
||||
else:
|
||||
# Process individual messages
|
||||
for message in messages:
|
||||
formatted_msg = self._format_message(message)
|
||||
if formatted_msg.strip():
|
||||
all_texts.append(formatted_msg)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
|
||||
continue
|
||||
else:
|
||||
# Process individual messages
|
||||
for message in messages:
|
||||
formatted_msg = self._format_message(message)
|
||||
if formatted_msg.strip():
|
||||
all_texts.append(formatted_msg)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch messages: {e}")
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
@@ -78,20 +78,6 @@ class SlackMCPRAG(BaseRAGExample):
|
||||
help="Test MCP server connection and list available tools without indexing",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-retries",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Maximum number of retries for failed operations (default: 5)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--retry-delay",
|
||||
type=float,
|
||||
default=2.0,
|
||||
help="Initial delay between retries in seconds (default: 2.0)",
|
||||
)
|
||||
|
||||
async def test_mcp_connection(self, args) -> bool:
|
||||
"""Test the MCP server connection and display available tools."""
|
||||
print(f"Testing connection to MCP server: {args.mcp_server}")
|
||||
@@ -102,14 +88,12 @@ class SlackMCPRAG(BaseRAGExample):
|
||||
workspace_name=args.workspace_name,
|
||||
concatenate_conversations=not args.no_concatenate_conversations,
|
||||
max_messages_per_conversation=args.max_messages_per_channel,
|
||||
max_retries=args.max_retries,
|
||||
retry_delay=args.retry_delay,
|
||||
)
|
||||
|
||||
async with reader:
|
||||
tools = await reader.list_available_tools()
|
||||
|
||||
print("Successfully connected to MCP server!")
|
||||
print("\n✅ Successfully connected to MCP server!")
|
||||
print(f"Available tools ({len(tools)}):")
|
||||
|
||||
for i, tool in enumerate(tools, 1):
|
||||
@@ -131,7 +115,7 @@ class SlackMCPRAG(BaseRAGExample):
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to connect to MCP server: {e}")
|
||||
print(f"\n❌ Failed to connect to MCP server: {e}")
|
||||
print("\nTroubleshooting tips:")
|
||||
print("1. Make sure the MCP server is installed and accessible")
|
||||
print("2. Check if the server command is correct")
|
||||
@@ -146,11 +130,8 @@ class SlackMCPRAG(BaseRAGExample):
|
||||
if args.workspace_name:
|
||||
print(f"Workspace: {args.workspace_name}")
|
||||
|
||||
# Filter out empty strings from channels
|
||||
channels = [ch for ch in args.channels if ch.strip()] if args.channels else None
|
||||
|
||||
if channels:
|
||||
print(f"Channels: {', '.join(channels)}")
|
||||
if args.channels:
|
||||
print(f"Channels: {', '.join(args.channels)}")
|
||||
else:
|
||||
print("Fetching from all available channels")
|
||||
|
||||
@@ -165,20 +146,18 @@ class SlackMCPRAG(BaseRAGExample):
|
||||
workspace_name=args.workspace_name,
|
||||
concatenate_conversations=concatenate,
|
||||
max_messages_per_conversation=args.max_messages_per_channel,
|
||||
max_retries=args.max_retries,
|
||||
retry_delay=args.retry_delay,
|
||||
)
|
||||
|
||||
texts = await reader.read_slack_data(channels=channels)
|
||||
texts = await reader.read_slack_data(channels=args.channels)
|
||||
|
||||
if not texts:
|
||||
print("No messages found! This could mean:")
|
||||
print("❌ No messages found! This could mean:")
|
||||
print("- The MCP server couldn't fetch messages")
|
||||
print("- The specified channels don't exist or are empty")
|
||||
print("- Authentication issues with the Slack workspace")
|
||||
return []
|
||||
|
||||
print(f"Successfully loaded {len(texts)} text chunks from Slack")
|
||||
print(f"✅ Successfully loaded {len(texts)} text chunks from Slack")
|
||||
|
||||
# Show sample of what was loaded
|
||||
if texts:
|
||||
@@ -191,7 +170,7 @@ class SlackMCPRAG(BaseRAGExample):
|
||||
return texts
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading Slack data: {e}")
|
||||
print(f"❌ Error loading Slack data: {e}")
|
||||
print("\nThis might be due to:")
|
||||
print("- MCP server connection issues")
|
||||
print("- Authentication problems")
|
||||
@@ -209,7 +188,7 @@ class SlackMCPRAG(BaseRAGExample):
|
||||
if not success:
|
||||
return
|
||||
print(
|
||||
"MCP server is working! You can now run without --test-connection to start indexing."
|
||||
"\n🎉 MCP server is working! You can now run without --test-connection to start indexing."
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@@ -54,51 +54,29 @@ def extract_thinking_answer(response):
|
||||
return response.strip()
|
||||
|
||||
|
||||
def load_hf_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
|
||||
"""Load HuggingFace model
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the model to load
|
||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||
Defaults to False for security. Only enable for trusted models.
|
||||
"""
|
||||
def load_hf_model(model_name="Qwen/Qwen3-8B"):
|
||||
"""Load HuggingFace model"""
|
||||
if not HF_AVAILABLE:
|
||||
raise ImportError("transformers not available")
|
||||
|
||||
if trust_remote_code:
|
||||
print(
|
||||
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
||||
)
|
||||
|
||||
print(f"Loading HF: {model_name}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
||||
device_map="auto",
|
||||
trust_remote_code=trust_remote_code,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
return tokenizer, model
|
||||
|
||||
|
||||
def load_vllm_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
|
||||
"""Load vLLM model
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the model to load
|
||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||
Defaults to False for security. Only enable for trusted models.
|
||||
"""
|
||||
def load_vllm_model(model_name="Qwen/Qwen3-8B"):
|
||||
"""Load vLLM model"""
|
||||
if not VLLM_AVAILABLE:
|
||||
raise ImportError("vllm not available")
|
||||
|
||||
if trust_remote_code:
|
||||
print(
|
||||
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
||||
)
|
||||
|
||||
print(f"Loading vLLM: {model_name}")
|
||||
llm = LLM(model=model_name, trust_remote_code=trust_remote_code)
|
||||
llm = LLM(model=model_name, trust_remote_code=True)
|
||||
|
||||
# Qwen3 specific config
|
||||
if is_qwen3_model(model_name):
|
||||
@@ -200,33 +178,19 @@ def evaluate_rag(searcher, llm_func, queries, domain="default", top_k=3, complex
|
||||
}
|
||||
|
||||
|
||||
def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=False):
|
||||
"""Load Qwen2.5-VL multimodal model
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the model to load
|
||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||
Defaults to False for security. Only enable for trusted models.
|
||||
"""
|
||||
def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct"):
|
||||
"""Load Qwen2.5-VL multimodal model"""
|
||||
if not HF_AVAILABLE:
|
||||
raise ImportError("transformers not available")
|
||||
|
||||
if trust_remote_code:
|
||||
print(
|
||||
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
||||
)
|
||||
|
||||
print(f"Loading Qwen2.5-VL: {model_name}")
|
||||
|
||||
try:
|
||||
from transformers import AutoModelForVision2Seq, AutoProcessor
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=trust_remote_code)
|
||||
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
model = AutoModelForVision2Seq.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
trust_remote_code=trust_remote_code,
|
||||
model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
|
||||
)
|
||||
|
||||
return processor, model
|
||||
@@ -238,14 +202,9 @@ def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_co
|
||||
try:
|
||||
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
||||
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_name, trust_remote_code=trust_remote_code
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
trust_remote_code=trust_remote_code,
|
||||
model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
|
||||
)
|
||||
|
||||
return processor, model
|
||||
|
||||
@@ -1,143 +0,0 @@
|
||||
# Update Benchmarks
|
||||
|
||||
This directory hosts two benchmark suites that exercise LEANN’s HNSW “update +
|
||||
search” pipeline under different assumptions:
|
||||
|
||||
1. **RNG recompute latency** – measure how random-neighbour pruning and cache
|
||||
settings influence incremental `add()` latency when embeddings are fetched
|
||||
over the ZMQ embedding server.
|
||||
2. **Update strategy comparison** – compare a fully sequential update pipeline
|
||||
against an offline approach that keeps the graph static and fuses results.
|
||||
|
||||
Both suites build a non-compact, `is_recompute=True` index so that new
|
||||
embeddings are pulled from the embedding server. Benchmark outputs are written
|
||||
under `.leann/bench/` by default and appended to CSV files for later plotting.
|
||||
|
||||
## Benchmarks
|
||||
|
||||
### 1. HNSW RNG Recompute Benchmark
|
||||
|
||||
`bench_hnsw_rng_recompute.py` evaluates incremental update latency under four
|
||||
random-neighbour (RNG) configurations. Each scenario uses the same dataset but
|
||||
changes the forward / reverse RNG pruning flags and whether the embedding cache
|
||||
is enabled:
|
||||
|
||||
| Scenario name | Forward RNG | Reverse RNG | ZMQ embedding cache |
|
||||
| ---------------------------------- | ----------- | ----------- | ------------------- |
|
||||
| `baseline` | Enabled | Enabled | Enabled |
|
||||
| `no_cache_baseline` | Enabled | Enabled | **Disabled** |
|
||||
| `disable_forward_rng` | **Disabled**| Enabled | Enabled |
|
||||
| `disable_forward_and_reverse_rng` | **Disabled**| **Disabled**| Enabled |
|
||||
|
||||
For each scenario the script:
|
||||
1. (Re)builds a `is_recompute=True` index and writes it to `.leann/bench/`.
|
||||
2. Starts `leann_backend_hnsw.hnsw_embedding_server` for remote embeddings.
|
||||
3. Appends the requested updates using the scenario’s RNG flags.
|
||||
4. Records total time, latency per passage, ZMQ fetch counts, and stage-level
|
||||
timings before appending a row to the CSV output.
|
||||
|
||||
**Run:**
|
||||
```bash
|
||||
LEANN_HNSW_LOG_PATH=.leann/bench/hnsw_server.log \
|
||||
LEANN_LOG_LEVEL=INFO \
|
||||
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
|
||||
--runs 1 \
|
||||
--index-path .leann/bench/test.leann \
|
||||
--initial-files data/PrideandPrejudice.txt \
|
||||
--update-files data/huawei_pangu.md \
|
||||
--max-initial 300 \
|
||||
--max-updates 1 \
|
||||
--add-timeout 120
|
||||
```
|
||||
|
||||
**Output:**
|
||||
- `benchmarks/update/bench_results.csv` – per-scenario timing statistics
|
||||
(including ms/passage) for each run.
|
||||
- `.leann/bench/hnsw_server.log` – detailed ZMQ/server logs (path controlled by
|
||||
`LEANN_HNSW_LOG_PATH`).
|
||||
_The reference CSVs checked into this branch were generated on a workstation with an NVIDIA RTX 4090 GPU; throughput numbers will differ on other hardware._
|
||||
|
||||
### 2. Sequential vs. Offline Update Benchmark
|
||||
|
||||
`bench_update_vs_offline_search.py` compares two end-to-end strategies on the
|
||||
same dataset:
|
||||
|
||||
- **Scenario A – Sequential Update**
|
||||
- Start an embedding server.
|
||||
- Sequentially call `index.add()`; each call fetches embeddings via ZMQ and
|
||||
mutates the HNSW graph.
|
||||
- After all inserts, run a search on the updated graph.
|
||||
- Metrics recorded: update time (`add_total_s`), post-update search time
|
||||
(`search_time_s`), combined total (`total_time_s`), and per-passage
|
||||
latency.
|
||||
|
||||
- **Scenario B – Offline Embedding + Concurrent Search**
|
||||
- Stop Scenario A’s server and start a fresh embedding server.
|
||||
- Spawn two threads: one generates embeddings for the new passages offline
|
||||
(graph unchanged); the other computes the query embedding and searches the
|
||||
existing graph.
|
||||
- Merge offline similarities with the graph search results to emulate late
|
||||
fusion, then report the merged top‑k preview.
|
||||
- Metrics recorded: embedding time (`emb_time_s`), search time
|
||||
(`search_time_s`), concurrent makespan (`makespan_s`), and scenario total.
|
||||
|
||||
**Run (both scenarios):**
|
||||
```bash
|
||||
uv run -m benchmarks.update.bench_update_vs_offline_search \
|
||||
--index-path .leann/bench/offline_vs_update.leann \
|
||||
--max-initial 300 \
|
||||
--num-updates 1
|
||||
```
|
||||
|
||||
You can pass `--only A` or `--only B` to run a single scenario. The script will
|
||||
print timing summaries to stdout and append the results to CSV.
|
||||
|
||||
**Output:**
|
||||
- `benchmarks/update/offline_vs_update.csv` – per-scenario timing statistics for
|
||||
Scenario A and B.
|
||||
- Console output includes Scenario B’s merged top‑k preview for quick sanity
|
||||
checks.
|
||||
_The sample results committed here come from runs on an RTX 4090-equipped machine; expect variations if you benchmark on different GPUs._
|
||||
|
||||
### 3. Visualisation
|
||||
|
||||
`plot_bench_results.py` combines the RNG benchmark and the update strategy
|
||||
benchmark into a single two-panel plot.
|
||||
|
||||
**Run:**
|
||||
```bash
|
||||
uv run -m benchmarks.update.plot_bench_results \
|
||||
--csv benchmarks/update/bench_results.csv \
|
||||
--csv-right benchmarks/update/offline_vs_update.csv \
|
||||
--out benchmarks/update/bench_latency_from_csv.png
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- `--broken-y` – Enable a broken Y-axis (default: true when appropriate).
|
||||
- `--csv` – RNG benchmark results CSV (left panel).
|
||||
- `--csv-right` – Update strategy results CSV (right panel).
|
||||
- `--out` – Output image path (PNG/PDF supported).
|
||||
|
||||
**Output:**
|
||||
- `benchmarks/update/bench_latency_from_csv.png` – visual comparison of the two
|
||||
suites.
|
||||
- `benchmarks/update/bench_latency_from_csv.pdf` – PDF version, suitable for
|
||||
slides/papers.
|
||||
|
||||
## Parameters & Environment
|
||||
|
||||
### Common CLI Flags
|
||||
- `--max-initial` – Number of initial passages used to seed the index.
|
||||
- `--max-updates` / `--num-updates` – Number of passages to treat as updates.
|
||||
- `--index-path` – Base path (without extension) where the LEANN index is stored.
|
||||
- `--runs` – Number of repetitions (RNG benchmark only).
|
||||
|
||||
### Environment Variables
|
||||
- `LEANN_HNSW_LOG_PATH` – File to receive embedding-server logs (optional).
|
||||
- `LEANN_LOG_LEVEL` – Logging verbosity (DEBUG/INFO/WARNING/ERROR).
|
||||
- `CUDA_VISIBLE_DEVICES` – Set to empty string if you want to force CPU
|
||||
execution of the embedding model.
|
||||
|
||||
With these scripts you can easily replicate LEANN’s update benchmarks, compare
|
||||
multiple RNG strategies, and evaluate whether sequential updates or offline
|
||||
fusion better match your latency/accuracy trade-offs.
|
||||
@@ -1,16 +0,0 @@
|
||||
"""Benchmarks for LEANN update workflows."""
|
||||
|
||||
# Expose helper to locate repository root for other modules that need it.
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def find_repo_root() -> Path:
|
||||
"""Return the project root containing pyproject.toml."""
|
||||
current = Path(__file__).resolve()
|
||||
for parent in current.parents:
|
||||
if (parent / "pyproject.toml").exists():
|
||||
return parent
|
||||
return current.parents[1]
|
||||
|
||||
|
||||
__all__ = ["find_repo_root"]
|
||||
@@ -1,804 +0,0 @@
|
||||
"""Benchmark incremental HNSW add() under different RNG pruning modes with real
|
||||
embedding recomputation.
|
||||
|
||||
This script clones the structure of ``examples/dynamic_update_no_recompute.py``
|
||||
so that we build a non-compact ``is_recompute=True`` index, spin up the
|
||||
standard HNSW embedding server, and measure how long incremental ``add`` takes
|
||||
when RNG pruning is fully enabled vs. partially/fully disabled.
|
||||
|
||||
Example usage (run from the repo root; downloads the model on first run)::
|
||||
|
||||
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
|
||||
--index-path .leann/bench/leann-demo.leann \
|
||||
--runs 1
|
||||
|
||||
You can tweak the input documents with ``--initial-files`` / ``--update-files``
|
||||
if you want a larger or different workload, and change the embedding model via
|
||||
``--model-name``.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import msgpack
|
||||
import numpy as np
|
||||
import zmq
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
|
||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
|
||||
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
from leann.embedding_server_manager import EmbeddingServerManager
|
||||
from leann.registry import register_project_directory
|
||||
from leann_backend_hnsw import faiss # type: ignore
|
||||
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
if not logging.getLogger().handlers:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def _find_repo_root() -> Path:
|
||||
"""Locate project root by walking up until pyproject.toml is found."""
|
||||
current = Path(__file__).resolve()
|
||||
for parent in current.parents:
|
||||
if (parent / "pyproject.toml").exists():
|
||||
return parent
|
||||
# Fallback: assume repo is two levels up (../..)
|
||||
return current.parents[2]
|
||||
|
||||
|
||||
REPO_ROOT = _find_repo_root()
|
||||
if str(REPO_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
from apps.chunking import create_text_chunks # noqa: E402
|
||||
|
||||
DEFAULT_INITIAL_FILES = [
|
||||
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
||||
REPO_ROOT / "data" / "huawei_pangu.md",
|
||||
]
|
||||
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||
|
||||
DEFAULT_HNSW_LOG = Path(".leann/bench/hnsw_server.log")
|
||||
|
||||
|
||||
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
documents = []
|
||||
for path in paths:
|
||||
p = path.expanduser().resolve()
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"Input path not found: {p}")
|
||||
if p.is_dir():
|
||||
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
else:
|
||||
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
chunks = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=512,
|
||||
chunk_overlap=128,
|
||||
use_ast_chunking=False,
|
||||
)
|
||||
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||
if limit is not None:
|
||||
cleaned = cleaned[:limit]
|
||||
return cleaned
|
||||
|
||||
|
||||
def ensure_index_dir(index_path: Path) -> None:
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def cleanup_index_files(index_path: Path) -> None:
|
||||
parent = index_path.parent
|
||||
if not parent.exists():
|
||||
return
|
||||
stem = index_path.stem
|
||||
for file in parent.glob(f"{stem}*"):
|
||||
if file.is_file():
|
||||
file.unlink()
|
||||
|
||||
|
||||
def build_initial_index(
|
||||
index_path: Path,
|
||||
paragraphs: list[str],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
distance_metric: str,
|
||||
ef_construction: int,
|
||||
) -> None:
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
is_compact=False,
|
||||
is_recompute=True,
|
||||
distance_metric=distance_metric,
|
||||
backend_kwargs={
|
||||
"distance_metric": distance_metric,
|
||||
"is_compact": False,
|
||||
"is_recompute": True,
|
||||
"efConstruction": ef_construction,
|
||||
},
|
||||
)
|
||||
for idx, passage in enumerate(paragraphs):
|
||||
builder.add_text(passage, metadata={"id": str(idx)})
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
|
||||
def prepare_new_chunks(paragraphs: list[str]) -> list[dict[str, Any]]:
|
||||
return [{"text": text, "metadata": {}} for text in paragraphs]
|
||||
|
||||
|
||||
def benchmark_update_with_mode(
|
||||
index_path: Path,
|
||||
new_chunks: list[dict[str, Any]],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
distance_metric: str,
|
||||
disable_forward_rng: bool,
|
||||
disable_reverse_rng: bool,
|
||||
server_port: int,
|
||||
add_timeout: int,
|
||||
ef_construction: int,
|
||||
) -> tuple[float, float]:
|
||||
meta_path = index_path.parent / f"{index_path.name}.meta.json"
|
||||
passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
|
||||
offset_file = index_path.parent / f"{index_path.name}.passages.idx"
|
||||
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
with open(offset_file, "rb") as f:
|
||||
offset_map: dict[str, int] = pickle.load(f)
|
||||
existing_ids = set(offset_map.keys())
|
||||
|
||||
valid_chunks: list[dict[str, Any]] = []
|
||||
for chunk in new_chunks:
|
||||
text = chunk.get("text", "")
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
continue
|
||||
metadata = chunk.setdefault("metadata", {})
|
||||
passage_id = chunk.get("id") or metadata.get("id")
|
||||
if passage_id and passage_id in existing_ids:
|
||||
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
|
||||
valid_chunks.append(chunk)
|
||||
|
||||
if not valid_chunks:
|
||||
raise ValueError("No valid chunks to append.")
|
||||
|
||||
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
|
||||
embeddings = compute_embeddings(
|
||||
texts_to_embed,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
is_build=False,
|
||||
batch_size=16,
|
||||
)
|
||||
|
||||
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||
if distance_metric == "cosine":
|
||||
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1
|
||||
embeddings = embeddings / norms
|
||||
|
||||
index = faiss.read_index(str(index_file))
|
||||
index.is_recompute = True
|
||||
if getattr(index, "storage", None) is None:
|
||||
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||
storage_index = faiss.IndexFlatIP(index.d)
|
||||
else:
|
||||
storage_index = faiss.IndexFlatL2(index.d)
|
||||
index.storage = storage_index
|
||||
index.own_fields = True
|
||||
try:
|
||||
storage_index.ntotal = index.ntotal
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
index.hnsw.set_disable_rng_during_add(disable_forward_rng)
|
||||
index.hnsw.set_disable_reverse_prune(disable_reverse_rng)
|
||||
if ef_construction is not None:
|
||||
index.hnsw.efConstruction = ef_construction
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
applied_forward = getattr(index.hnsw, "disable_rng_during_add", None)
|
||||
applied_reverse = getattr(index.hnsw, "disable_reverse_prune", None)
|
||||
logger.info(
|
||||
"HNSW RNG config -> requested forward=%s, reverse=%s | applied forward=%s, reverse=%s",
|
||||
disable_forward_rng,
|
||||
disable_reverse_rng,
|
||||
applied_forward,
|
||||
applied_reverse,
|
||||
)
|
||||
|
||||
base_id = index.ntotal
|
||||
for offset, chunk in enumerate(valid_chunks):
|
||||
new_id = str(base_id + offset)
|
||||
chunk.setdefault("metadata", {})["id"] = new_id
|
||||
chunk["id"] = new_id
|
||||
|
||||
rollback_size = passages_file.stat().st_size if passages_file.exists() else 0
|
||||
offset_map_backup = offset_map.copy()
|
||||
|
||||
try:
|
||||
with open(passages_file, "a", encoding="utf-8") as f:
|
||||
for chunk in valid_chunks:
|
||||
offset = f.tell()
|
||||
json.dump(
|
||||
{
|
||||
"id": chunk["id"],
|
||||
"text": chunk["text"],
|
||||
"metadata": chunk.get("metadata", {}),
|
||||
},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
f.write("\n")
|
||||
offset_map[chunk["id"]] = offset
|
||||
|
||||
with open(offset_file, "wb") as f:
|
||||
pickle.dump(offset_map, f)
|
||||
|
||||
server_manager = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
server_started, actual_port = server_manager.start_server(
|
||||
port=server_port,
|
||||
model_name=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
passages_file=str(meta_path),
|
||||
distance_metric=distance_metric,
|
||||
)
|
||||
if not server_started:
|
||||
raise RuntimeError("Failed to start embedding server.")
|
||||
|
||||
if hasattr(index.hnsw, "set_zmq_port"):
|
||||
index.hnsw.set_zmq_port(actual_port)
|
||||
elif hasattr(index, "set_zmq_port"):
|
||||
index.set_zmq_port(actual_port)
|
||||
|
||||
_warmup_embedding_server(actual_port)
|
||||
|
||||
total_start = time.time()
|
||||
add_elapsed = 0.0
|
||||
|
||||
try:
|
||||
import signal
|
||||
|
||||
def _timeout_handler(signum, frame):
|
||||
raise TimeoutError("incremental add timed out")
|
||||
|
||||
if add_timeout > 0:
|
||||
signal.signal(signal.SIGALRM, _timeout_handler)
|
||||
signal.alarm(add_timeout)
|
||||
|
||||
add_start = time.time()
|
||||
for i in range(embeddings.shape[0]):
|
||||
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
|
||||
add_elapsed = time.time() - add_start
|
||||
if add_timeout > 0:
|
||||
signal.alarm(0)
|
||||
faiss.write_index(index, str(index_file))
|
||||
finally:
|
||||
server_manager.stop_server()
|
||||
|
||||
except TimeoutError:
|
||||
raise
|
||||
except Exception:
|
||||
if passages_file.exists():
|
||||
with open(passages_file, "rb+") as f:
|
||||
f.truncate(rollback_size)
|
||||
with open(offset_file, "wb") as f:
|
||||
pickle.dump(offset_map_backup, f)
|
||||
raise
|
||||
|
||||
prune_hnsw_embeddings_inplace(str(index_file))
|
||||
|
||||
meta["total_passages"] = len(offset_map)
|
||||
with open(meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
# Reset toggles so the index on disk returns to baseline behaviour.
|
||||
try:
|
||||
index.hnsw.set_disable_rng_during_add(False)
|
||||
index.hnsw.set_disable_reverse_prune(False)
|
||||
except AttributeError:
|
||||
pass
|
||||
faiss.write_index(index, str(index_file))
|
||||
|
||||
total_elapsed = time.time() - total_start
|
||||
|
||||
return total_elapsed, add_elapsed
|
||||
|
||||
|
||||
def _total_zmq_nodes(log_path: Path) -> int:
|
||||
if not log_path.exists():
|
||||
return 0
|
||||
with log_path.open("r", encoding="utf-8") as log_file:
|
||||
text = log_file.read()
|
||||
return sum(int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", text))
|
||||
|
||||
|
||||
def _warmup_embedding_server(port: int) -> None:
|
||||
"""Send a dummy REQ so the embedding server loads its model."""
|
||||
ctx = zmq.Context()
|
||||
try:
|
||||
sock = ctx.socket(zmq.REQ)
|
||||
sock.setsockopt(zmq.LINGER, 0)
|
||||
sock.setsockopt(zmq.RCVTIMEO, 5000)
|
||||
sock.setsockopt(zmq.SNDTIMEO, 5000)
|
||||
sock.connect(f"tcp://127.0.0.1:{port}")
|
||||
payload = msgpack.packb(["__WARMUP__"], use_bin_type=True)
|
||||
sock.send(payload)
|
||||
try:
|
||||
sock.recv()
|
||||
except zmq.error.Again:
|
||||
pass
|
||||
finally:
|
||||
sock.close()
|
||||
ctx.term()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=Path,
|
||||
default=Path(".leann/bench/leann-demo.leann"),
|
||||
help="Output index base path (without extension).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial-files",
|
||||
nargs="*",
|
||||
type=Path,
|
||||
default=DEFAULT_INITIAL_FILES,
|
||||
help="Files used to build the initial index.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-files",
|
||||
nargs="*",
|
||||
type=Path,
|
||||
default=DEFAULT_UPDATE_FILES,
|
||||
help="Files appended during the benchmark.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--runs", type=int, default=1, help="How many times to repeat each scenario."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||
help="Embedding model used for build/update.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-mode",
|
||||
default="sentence-transformers",
|
||||
help="Embedding mode passed to LeannBuilder/embedding server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distance-metric",
|
||||
default="mips",
|
||||
choices=["mips", "l2", "cosine"],
|
||||
help="Distance metric for HNSW backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ef-construction",
|
||||
type=int,
|
||||
default=200,
|
||||
help="efConstruction setting for initial build.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--server-port",
|
||||
type=int,
|
||||
default=5557,
|
||||
help="Port for the real embedding server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-initial",
|
||||
type=int,
|
||||
default=300,
|
||||
help="Optional cap on initial passages (after chunking).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-updates",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Optional cap on update passages (after chunking).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--add-timeout",
|
||||
type=int,
|
||||
default=900,
|
||||
help="Timeout in seconds for the incremental add loop (0 = no timeout).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plot-path",
|
||||
type=Path,
|
||||
default=Path("bench_latency.png"),
|
||||
help="Where to save the latency bar plot.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cap-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Cap Y-axis (ms). Bars above are hatched and annotated.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--broken-y",
|
||||
action="store_true",
|
||||
help="Use broken Y-axis (two stacked axes with gap). Overrides --cap-y unless both provided.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lower-cap-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Lower axes upper bound for broken Y (ms). Default=1.1x second-highest.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upper-start-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Upper axes lower bound for broken Y (ms). Default=1.2x second-highest.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--csv-path",
|
||||
type=Path,
|
||||
default=Path("benchmarks/update/bench_results.csv"),
|
||||
help="Where to append per-scenario results as CSV.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
register_project_directory(REPO_ROOT)
|
||||
|
||||
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
|
||||
update_paragraphs = load_chunks_from_files(args.update_files, args.max_updates)
|
||||
if not update_paragraphs:
|
||||
raise ValueError("No update passages found; please provide --update-files with content.")
|
||||
|
||||
update_chunks = prepare_new_chunks(update_paragraphs)
|
||||
ensure_index_dir(args.index_path)
|
||||
|
||||
scenarios = [
|
||||
("baseline", False, False, True),
|
||||
("no_cache_baseline", False, False, False),
|
||||
("disable_forward_rng", True, False, True),
|
||||
("disable_forward_and_reverse_rng", True, True, True),
|
||||
]
|
||||
|
||||
log_path = Path(os.environ.get("LEANN_HNSW_LOG_PATH", DEFAULT_HNSW_LOG))
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
os.environ["LEANN_HNSW_LOG_PATH"] = str(log_path.resolve())
|
||||
os.environ.setdefault("LEANN_LOG_LEVEL", "INFO")
|
||||
|
||||
results_total: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
results_add: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
results_zmq: dict[str, list[int]] = {name: [] for name, *_ in scenarios}
|
||||
results_stageA: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
results_stageBC: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
results_ms_per_passage: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
|
||||
# CSV setup
|
||||
import csv
|
||||
|
||||
run_id = time.strftime("%Y%m%d-%H%M%S")
|
||||
csv_fields = [
|
||||
"run_id",
|
||||
"scenario",
|
||||
"cache_enabled",
|
||||
"ef_construction",
|
||||
"max_initial",
|
||||
"max_updates",
|
||||
"total_time_s",
|
||||
"add_only_s",
|
||||
"latency_ms_per_passage",
|
||||
"zmq_nodes",
|
||||
"stageA_time_s",
|
||||
"stageBC_time_s",
|
||||
"model_name",
|
||||
"embedding_mode",
|
||||
"distance_metric",
|
||||
]
|
||||
# Create CSV with header if missing
|
||||
if args.csv_path:
|
||||
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
|
||||
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writeheader()
|
||||
|
||||
for run in range(args.runs):
|
||||
print(f"\n=== Benchmark run {run + 1}/{args.runs} ===")
|
||||
for name, disable_forward, disable_reverse, cache_enabled in scenarios:
|
||||
print(f"\nScenario: {name}")
|
||||
cleanup_index_files(args.index_path)
|
||||
if log_path.exists():
|
||||
try:
|
||||
log_path.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
os.environ["LEANN_ZMQ_EMBED_CACHE"] = "1" if cache_enabled else "0"
|
||||
build_initial_index(
|
||||
args.index_path,
|
||||
initial_paragraphs,
|
||||
args.model_name,
|
||||
args.embedding_mode,
|
||||
args.distance_metric,
|
||||
args.ef_construction,
|
||||
)
|
||||
|
||||
prev_size = log_path.stat().st_size if log_path.exists() else 0
|
||||
|
||||
try:
|
||||
total_elapsed, add_elapsed = benchmark_update_with_mode(
|
||||
args.index_path,
|
||||
update_chunks,
|
||||
args.model_name,
|
||||
args.embedding_mode,
|
||||
args.distance_metric,
|
||||
disable_forward,
|
||||
disable_reverse,
|
||||
args.server_port,
|
||||
args.add_timeout,
|
||||
args.ef_construction,
|
||||
)
|
||||
except TimeoutError as exc:
|
||||
print(f"Scenario {name} timed out: {exc}")
|
||||
continue
|
||||
|
||||
curr_size = log_path.stat().st_size if log_path.exists() else 0
|
||||
if curr_size < prev_size:
|
||||
prev_size = 0
|
||||
zmq_count = 0
|
||||
if log_path.exists():
|
||||
with log_path.open("r", encoding="utf-8") as log_file:
|
||||
log_file.seek(prev_size)
|
||||
new_entries = log_file.read()
|
||||
zmq_count = sum(
|
||||
int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", new_entries)
|
||||
)
|
||||
stageA = sum(
|
||||
float(x)
|
||||
for x in re.findall(r"Distance calculation E2E time: ([0-9.]+)s", new_entries)
|
||||
)
|
||||
stageBC = sum(
|
||||
float(x) for x in re.findall(r"ZMQ E2E time: ([0-9.]+)s", new_entries)
|
||||
)
|
||||
else:
|
||||
stageA = 0.0
|
||||
stageBC = 0.0
|
||||
|
||||
per_chunk = add_elapsed / len(update_chunks)
|
||||
print(
|
||||
f"Total time: {total_elapsed:.3f} s | add-only: {add_elapsed:.3f} s "
|
||||
f"for {len(update_chunks)} passages => {per_chunk * 1e3:.3f} ms/passage"
|
||||
)
|
||||
print(f"ZMQ node fetch total: {zmq_count}")
|
||||
results_total[name].append(total_elapsed)
|
||||
results_add[name].append(add_elapsed)
|
||||
results_zmq[name].append(zmq_count)
|
||||
results_ms_per_passage[name].append(per_chunk * 1e3)
|
||||
results_stageA[name].append(stageA)
|
||||
results_stageBC[name].append(stageBC)
|
||||
|
||||
# Append row to CSV
|
||||
if args.csv_path:
|
||||
row = {
|
||||
"run_id": run_id,
|
||||
"scenario": name,
|
||||
"cache_enabled": 1 if cache_enabled else 0,
|
||||
"ef_construction": args.ef_construction,
|
||||
"max_initial": args.max_initial,
|
||||
"max_updates": args.max_updates,
|
||||
"total_time_s": round(total_elapsed, 6),
|
||||
"add_only_s": round(add_elapsed, 6),
|
||||
"latency_ms_per_passage": round(per_chunk * 1e3, 6),
|
||||
"zmq_nodes": int(zmq_count),
|
||||
"stageA_time_s": round(stageA, 6),
|
||||
"stageBC_time_s": round(stageBC, 6),
|
||||
"model_name": args.model_name,
|
||||
"embedding_mode": args.embedding_mode,
|
||||
"distance_metric": args.distance_metric,
|
||||
}
|
||||
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writerow(row)
|
||||
|
||||
print("\n=== Summary ===")
|
||||
for name in results_add:
|
||||
add_values = results_add[name]
|
||||
total_values = results_total[name]
|
||||
zmq_values = results_zmq[name]
|
||||
latency_values = results_ms_per_passage[name]
|
||||
if not add_values:
|
||||
print(f"{name}: no successful runs")
|
||||
continue
|
||||
avg_add = sum(add_values) / len(add_values)
|
||||
avg_total = sum(total_values) / len(total_values)
|
||||
avg_zmq = sum(zmq_values) / len(zmq_values) if zmq_values else 0.0
|
||||
avg_latency = sum(latency_values) / len(latency_values) if latency_values else 0.0
|
||||
runs = len(add_values)
|
||||
print(
|
||||
f"{name}: add-only avg {avg_add:.3f} s | total avg {avg_total:.3f} s "
|
||||
f"| ZMQ avg {avg_zmq:.1f} node fetches | latency {avg_latency:.2f} ms/passage over {runs} run(s)"
|
||||
)
|
||||
|
||||
if args.plot_path:
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
labels = [name for name, *_ in scenarios]
|
||||
values = [
|
||||
sum(results_ms_per_passage[name]) / len(results_ms_per_passage[name])
|
||||
if results_ms_per_passage[name]
|
||||
else 0.0
|
||||
for name in labels
|
||||
]
|
||||
|
||||
def _auto_cap(vals: list[float]) -> float | None:
|
||||
s = sorted(vals, reverse=True)
|
||||
if len(s) < 2:
|
||||
return None
|
||||
if s[1] > 0 and s[0] >= 2.5 * s[1]:
|
||||
return s[1] * 1.1
|
||||
return None
|
||||
|
||||
def _fmt_ms(v: float) -> str:
|
||||
return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.1f}"
|
||||
|
||||
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
|
||||
|
||||
if args.broken_y:
|
||||
s = sorted(values, reverse=True)
|
||||
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
|
||||
upper_start = (
|
||||
args.upper_start_y
|
||||
if args.upper_start_y is not None
|
||||
else max(second * 1.2, lower_cap * 1.02)
|
||||
)
|
||||
ymax = max(values) * 1.10 if values else 1.0
|
||||
fig, (ax_top, ax_bottom) = plt.subplots(
|
||||
2,
|
||||
1,
|
||||
sharex=True,
|
||||
figsize=(7.4, 5.0),
|
||||
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.05},
|
||||
)
|
||||
x = list(range(len(labels)))
|
||||
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
ax_bottom.set_ylim(0, lower_cap)
|
||||
ax_top.set_ylim(upper_start, ymax)
|
||||
for i, v in enumerate(values):
|
||||
if v <= lower_cap:
|
||||
ax_bottom.text(
|
||||
i,
|
||||
v + lower_cap * 0.02,
|
||||
_fmt_ms(v),
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=9,
|
||||
)
|
||||
else:
|
||||
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||
ax_top.spines["bottom"].set_visible(False)
|
||||
ax_bottom.spines["top"].set_visible(False)
|
||||
ax_top.tick_params(labeltop=False)
|
||||
ax_bottom.xaxis.tick_bottom()
|
||||
d = 0.015
|
||||
kwargs = {"transform": ax_top.transAxes, "color": "k", "clip_on": False}
|
||||
ax_top.plot((-d, +d), (-d, +d), **kwargs)
|
||||
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
|
||||
kwargs.update({"transform": ax_bottom.transAxes})
|
||||
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
|
||||
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
|
||||
ax_bottom.set_xticks(range(len(labels)))
|
||||
ax_bottom.set_xticklabels(labels)
|
||||
ax = ax_bottom
|
||||
else:
|
||||
cap = args.cap_y or _auto_cap(values)
|
||||
plt.figure(figsize=(7.2, 4.2))
|
||||
ax = plt.gca()
|
||||
if cap is not None:
|
||||
show_vals = [min(v, cap) for v in values]
|
||||
bars = []
|
||||
for i, (v, show) in enumerate(zip(values, show_vals)):
|
||||
b = ax.bar(i, show, color=colors[i], width=0.8)
|
||||
bars.append(b[0])
|
||||
if v > cap:
|
||||
bars[-1].set_hatch("//")
|
||||
ax.text(i, cap * 1.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||
else:
|
||||
ax.text(
|
||||
i,
|
||||
show + max(1.0, 0.01 * (cap or show)),
|
||||
_fmt_ms(v),
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=9,
|
||||
)
|
||||
ax.set_ylim(0, cap * 1.10)
|
||||
ax.plot(
|
||||
[0.02 - 0.02, 0.02 + 0.02],
|
||||
[0.98 + 0.02, 0.98 - 0.02],
|
||||
transform=ax.transAxes,
|
||||
color="k",
|
||||
lw=1,
|
||||
)
|
||||
ax.plot(
|
||||
[0.98 - 0.02, 0.98 + 0.02],
|
||||
[0.98 + 0.02, 0.98 - 0.02],
|
||||
transform=ax.transAxes,
|
||||
color="k",
|
||||
lw=1,
|
||||
)
|
||||
if any(v > cap for v in values):
|
||||
ax.legend(
|
||||
[bars[0]], ["capped"], fontsize=8, frameon=False, loc="upper right"
|
||||
)
|
||||
ax.set_xticks(range(len(labels)))
|
||||
ax.set_xticklabels(labels)
|
||||
else:
|
||||
ax.bar(labels, values, color=colors[: len(labels)])
|
||||
for idx, val in enumerate(values):
|
||||
ax.text(idx, val + 1.0, f"{val:.1f}", ha="center", va="bottom")
|
||||
|
||||
plt.ylabel("Average add latency (ms per passage)")
|
||||
plt.title(f"Initial passages {args.max_initial}, updates {args.max_updates}")
|
||||
plt.tight_layout()
|
||||
plt.savefig(args.plot_path)
|
||||
print(f"Saved latency bar plot to {args.plot_path}")
|
||||
# ZMQ time split (Stage A vs B/C)
|
||||
try:
|
||||
plt.figure(figsize=(6, 4))
|
||||
a_vals = [sum(results_stageA[n]) / max(1, len(results_stageA[n])) for n in labels]
|
||||
bc_vals = [
|
||||
sum(results_stageBC[n]) / max(1, len(results_stageBC[n])) for n in labels
|
||||
]
|
||||
ind = range(len(labels))
|
||||
plt.bar(ind, a_vals, color="#4e79a7", label="Stage A distance (s)")
|
||||
plt.bar(
|
||||
ind, bc_vals, bottom=a_vals, color="#e15759", label="Stage B/C embed-by-id (s)"
|
||||
)
|
||||
plt.xticks(list(ind), labels, rotation=10)
|
||||
plt.ylabel("Server ZMQ time (s)")
|
||||
plt.title(
|
||||
f"ZMQ time split (initial {args.max_initial}, updates {args.max_updates})"
|
||||
)
|
||||
plt.legend()
|
||||
out2 = args.plot_path.with_name(
|
||||
args.plot_path.stem + "_zmq_split" + args.plot_path.suffix
|
||||
)
|
||||
plt.tight_layout()
|
||||
plt.savefig(out2)
|
||||
print(f"Saved ZMQ time split plot to {out2}")
|
||||
except Exception as e:
|
||||
print("Failed to plot ZMQ split:", e)
|
||||
except ImportError:
|
||||
print("matplotlib not available; skipping plot generation")
|
||||
|
||||
# leave the last build on disk for inspection
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,5 +0,0 @@
|
||||
run_id,scenario,cache_enabled,ef_construction,max_initial,max_updates,total_time_s,add_only_s,latency_ms_per_passage,zmq_nodes,stageA_time_s,stageBC_time_s,model_name,embedding_mode,distance_metric
|
||||
20251024-133101,baseline,1,200,300,1,3.391856,1.120359,1120.359421,126,0.507821,0.601608,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251024-133101,no_cache_baseline,0,200,300,1,34.941514,32.91376,32913.760185,4033,0.506933,32.159928,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251024-133101,disable_forward_rng,1,200,300,1,2.746756,0.8202,820.200443,66,0.474354,0.338454,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251024-133101,disable_forward_and_reverse_rng,1,200,300,1,2.396566,0.521478,521.478415,1,0.508973,0.006938,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
|
@@ -1,704 +0,0 @@
|
||||
"""
|
||||
Compare two latency models for small incremental updates vs. search:
|
||||
|
||||
Scenario A (sequential update then search):
|
||||
- Build initial HNSW (is_recompute=True)
|
||||
- Start embedding server (ZMQ) for recompute
|
||||
- Add N passages one-by-one (each triggers recompute over ZMQ)
|
||||
- Then run a search query on the updated index
|
||||
- Report total time = sum(add_i) + search_time, with breakdowns
|
||||
|
||||
Scenario B (offline embeds + concurrent search; no graph updates):
|
||||
- Do NOT insert the N passages into the graph
|
||||
- In parallel: (1) compute embeddings for the N passages; (2) compute query
|
||||
embedding and run a search on the existing index
|
||||
- After both finish, compute similarity between the query embedding and the N
|
||||
new passage embeddings, merge with the index search results by score, and
|
||||
report time = max(embed_time, search_time) (i.e., no blocking on updates)
|
||||
|
||||
This script reuses the model/data loading conventions of
|
||||
examples/bench_hnsw_rng_recompute.py but focuses on end-to-end latency
|
||||
comparison for the two execution strategies above.
|
||||
|
||||
Example (from the repository root):
|
||||
uv run -m benchmarks.update.bench_update_vs_offline_search \
|
||||
--index-path .leann/bench/offline_vs_update.leann \
|
||||
--max-initial 300 --num-updates 5 --k 10
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import psutil # type: ignore
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
|
||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
|
||||
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
from leann.embedding_server_manager import EmbeddingServerManager
|
||||
from leann.registry import register_project_directory
|
||||
from leann_backend_hnsw import faiss # type: ignore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
if not logging.getLogger().handlers:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def _find_repo_root() -> Path:
|
||||
"""Locate project root by walking up until pyproject.toml is found."""
|
||||
current = Path(__file__).resolve()
|
||||
for parent in current.parents:
|
||||
if (parent / "pyproject.toml").exists():
|
||||
return parent
|
||||
# Fallback: assume repo is two levels up (../..)
|
||||
return current.parents[2]
|
||||
|
||||
|
||||
REPO_ROOT = _find_repo_root()
|
||||
if str(REPO_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
from apps.chunking import create_text_chunks # noqa: E402
|
||||
|
||||
DEFAULT_INITIAL_FILES = [
|
||||
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
||||
REPO_ROOT / "data" / "huawei_pangu.md",
|
||||
]
|
||||
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||
|
||||
|
||||
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
documents = []
|
||||
for path in paths:
|
||||
p = path.expanduser().resolve()
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"Input path not found: {p}")
|
||||
if p.is_dir():
|
||||
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
else:
|
||||
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
chunks = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=512,
|
||||
chunk_overlap=128,
|
||||
use_ast_chunking=False,
|
||||
)
|
||||
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||
if limit is not None:
|
||||
cleaned = cleaned[:limit]
|
||||
return cleaned
|
||||
|
||||
|
||||
def ensure_index_dir(index_path: Path) -> None:
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def cleanup_index_files(index_path: Path) -> None:
|
||||
parent = index_path.parent
|
||||
if not parent.exists():
|
||||
return
|
||||
stem = index_path.stem
|
||||
for file in parent.glob(f"{stem}*"):
|
||||
if file.is_file():
|
||||
file.unlink()
|
||||
|
||||
|
||||
def build_initial_index(
|
||||
index_path: Path,
|
||||
paragraphs: list[str],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
distance_metric: str,
|
||||
ef_construction: int,
|
||||
) -> None:
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
is_compact=False,
|
||||
is_recompute=True,
|
||||
distance_metric=distance_metric,
|
||||
backend_kwargs={
|
||||
"distance_metric": distance_metric,
|
||||
"is_compact": False,
|
||||
"is_recompute": True,
|
||||
"efConstruction": ef_construction,
|
||||
},
|
||||
)
|
||||
for idx, passage in enumerate(paragraphs):
|
||||
builder.add_text(passage, metadata={"id": str(idx)})
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
|
||||
def _maybe_norm_cosine(vecs: np.ndarray, metric: str) -> np.ndarray:
|
||||
if metric == "cosine":
|
||||
vecs = np.ascontiguousarray(vecs, dtype=np.float32)
|
||||
norms = np.linalg.norm(vecs, axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1
|
||||
vecs = vecs / norms
|
||||
return vecs
|
||||
|
||||
|
||||
def _read_index_for_search(index_path: Path) -> Any:
|
||||
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||
# Force-disable experimental disk cache when loading the index so that
|
||||
# incremental benchmarks don't pick up stale top-degree bitmaps.
|
||||
cfg = faiss.HNSWIndexConfig()
|
||||
cfg.is_recompute = True
|
||||
if hasattr(cfg, "disk_cache_ratio"):
|
||||
cfg.disk_cache_ratio = 0.0
|
||||
if hasattr(cfg, "external_storage_path"):
|
||||
cfg.external_storage_path = None
|
||||
io_flags = getattr(faiss, "IO_FLAG_MMAP", 0)
|
||||
index = faiss.read_index(str(index_file), io_flags, cfg)
|
||||
# ensure recompute mode persists after reload
|
||||
try:
|
||||
index.is_recompute = True
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
actual_ntotal = index.hnsw.levels.size()
|
||||
except AttributeError:
|
||||
actual_ntotal = index.ntotal
|
||||
if actual_ntotal != index.ntotal:
|
||||
print(
|
||||
f"[bench_update_vs_offline_search] Correcting ntotal from {index.ntotal} to {actual_ntotal}",
|
||||
flush=True,
|
||||
)
|
||||
index.ntotal = actual_ntotal
|
||||
if getattr(index, "storage", None) is None:
|
||||
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||
storage_index = faiss.IndexFlatIP(index.d)
|
||||
else:
|
||||
storage_index = faiss.IndexFlatL2(index.d)
|
||||
index.storage = storage_index
|
||||
index.own_fields = True
|
||||
return index
|
||||
|
||||
|
||||
def _append_passages_for_updates(
|
||||
meta_path: Path,
|
||||
start_id: int,
|
||||
texts: list[str],
|
||||
) -> list[str]:
|
||||
"""Append update passages so the embedding server can serve recompute fetches."""
|
||||
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
index_dir = meta_path.parent
|
||||
meta_name = meta_path.name
|
||||
if not meta_name.endswith(".meta.json"):
|
||||
raise ValueError(f"Unexpected meta filename: {meta_path}")
|
||||
index_base = meta_name[: -len(".meta.json")]
|
||||
|
||||
passages_file = index_dir / f"{index_base}.passages.jsonl"
|
||||
offsets_file = index_dir / f"{index_base}.passages.idx"
|
||||
|
||||
if not passages_file.exists() or not offsets_file.exists():
|
||||
raise FileNotFoundError(
|
||||
"Passage store missing; cannot register update passages for recompute mode."
|
||||
)
|
||||
|
||||
with open(offsets_file, "rb") as f:
|
||||
offset_map: dict[str, int] = pickle.load(f)
|
||||
|
||||
assigned_ids: list[str] = []
|
||||
with open(passages_file, "a", encoding="utf-8") as f:
|
||||
for i, text in enumerate(texts):
|
||||
passage_id = str(start_id + i)
|
||||
offset = f.tell()
|
||||
json.dump({"id": passage_id, "text": text, "metadata": {}}, f, ensure_ascii=False)
|
||||
f.write("\n")
|
||||
offset_map[passage_id] = offset
|
||||
assigned_ids.append(passage_id)
|
||||
|
||||
with open(offsets_file, "wb") as f:
|
||||
pickle.dump(offset_map, f)
|
||||
|
||||
try:
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
meta = {}
|
||||
meta["total_passages"] = len(offset_map)
|
||||
with open(meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
return assigned_ids
|
||||
|
||||
|
||||
def _search(index: Any, q: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
|
||||
q = np.ascontiguousarray(q, dtype=np.float32)
|
||||
distances = np.zeros((1, k), dtype=np.float32)
|
||||
indices = np.zeros((1, k), dtype=np.int64)
|
||||
index.search(
|
||||
1,
|
||||
faiss.swig_ptr(q),
|
||||
k,
|
||||
faiss.swig_ptr(distances),
|
||||
faiss.swig_ptr(indices),
|
||||
)
|
||||
return distances[0], indices[0]
|
||||
|
||||
|
||||
def _score_for_metric(dist: float, metric: str) -> float:
|
||||
# Convert FAISS distance to a "higher is better" score
|
||||
if metric in ("mips", "cosine"):
|
||||
return float(dist)
|
||||
# l2 distance (smaller better) -> negative distance as score
|
||||
return -float(dist)
|
||||
|
||||
|
||||
def _merge_results(
|
||||
index_results: tuple[np.ndarray, np.ndarray],
|
||||
offline_scores: list[tuple[int, float]],
|
||||
k: int,
|
||||
metric: str,
|
||||
) -> list[tuple[str, float]]:
|
||||
distances, indices = index_results
|
||||
merged: list[tuple[str, float]] = []
|
||||
for distance, idx in zip(distances.tolist(), indices.tolist()):
|
||||
merged.append((f"idx:{idx}", _score_for_metric(distance, metric)))
|
||||
for j, s in offline_scores:
|
||||
merged.append((f"offline:{j}", s))
|
||||
merged.sort(key=lambda x: x[1], reverse=True)
|
||||
return merged[:k]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScenarioResult:
|
||||
name: str
|
||||
update_total_s: float
|
||||
search_s: float
|
||||
overall_s: float
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=Path,
|
||||
default=Path(".leann/bench/offline-vs-update.leann"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial-files",
|
||||
nargs="*",
|
||||
type=Path,
|
||||
default=DEFAULT_INITIAL_FILES,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-files",
|
||||
nargs="*",
|
||||
type=Path,
|
||||
default=DEFAULT_UPDATE_FILES,
|
||||
)
|
||||
parser.add_argument("--max-initial", type=int, default=300)
|
||||
parser.add_argument("--num-updates", type=int, default=5)
|
||||
parser.add_argument("--k", type=int, default=10, help="Top-k for search/merge")
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default="neural network",
|
||||
help="Query text used for the search benchmark.",
|
||||
)
|
||||
parser.add_argument("--server-port", type=int, default=5557)
|
||||
parser.add_argument("--add-timeout", type=int, default=600)
|
||||
parser.add_argument("--model-name", default="sentence-transformers/all-MiniLM-L6-v2")
|
||||
parser.add_argument("--embedding-mode", default="sentence-transformers")
|
||||
parser.add_argument(
|
||||
"--distance-metric",
|
||||
default="mips",
|
||||
choices=["mips", "l2", "cosine"],
|
||||
)
|
||||
parser.add_argument("--ef-construction", type=int, default=200)
|
||||
parser.add_argument(
|
||||
"--only",
|
||||
choices=["A", "B", "both"],
|
||||
default="both",
|
||||
help="Run only Scenario A, Scenario B, or both",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--csv-path",
|
||||
type=Path,
|
||||
default=Path("benchmarks/update/offline_vs_update.csv"),
|
||||
help="Where to append results (CSV).",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
register_project_directory(REPO_ROOT)
|
||||
|
||||
# Load data
|
||||
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
|
||||
update_paragraphs = load_chunks_from_files(args.update_files, None)
|
||||
if not update_paragraphs:
|
||||
raise ValueError("No update passages loaded from --update-files")
|
||||
update_paragraphs = update_paragraphs[: args.num_updates]
|
||||
if len(update_paragraphs) < args.num_updates:
|
||||
raise ValueError(
|
||||
f"Not enough update passages ({len(update_paragraphs)}) for --num-updates={args.num_updates}"
|
||||
)
|
||||
|
||||
ensure_index_dir(args.index_path)
|
||||
cleanup_index_files(args.index_path)
|
||||
|
||||
# Build initial index
|
||||
build_initial_index(
|
||||
args.index_path,
|
||||
initial_paragraphs,
|
||||
args.model_name,
|
||||
args.embedding_mode,
|
||||
args.distance_metric,
|
||||
args.ef_construction,
|
||||
)
|
||||
|
||||
# Prepare index object and meta
|
||||
meta_path = args.index_path.parent / f"{args.index_path.name}.meta.json"
|
||||
index = _read_index_for_search(args.index_path)
|
||||
|
||||
# CSV setup
|
||||
run_id = time.strftime("%Y%m%d-%H%M%S")
|
||||
if args.csv_path:
|
||||
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
csv_fields = [
|
||||
"run_id",
|
||||
"scenario",
|
||||
"max_initial",
|
||||
"num_updates",
|
||||
"k",
|
||||
"total_time_s",
|
||||
"add_total_s",
|
||||
"search_time_s",
|
||||
"emb_time_s",
|
||||
"makespan_s",
|
||||
"model_name",
|
||||
"embedding_mode",
|
||||
"distance_metric",
|
||||
]
|
||||
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
|
||||
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writeheader()
|
||||
|
||||
# Debug: list existing HNSW server PIDs before starting
|
||||
try:
|
||||
existing = [
|
||||
p
|
||||
for p in psutil.process_iter(attrs=["pid", "cmdline"])
|
||||
if any(
|
||||
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
|
||||
for arg in (p.info.get("cmdline") or [])
|
||||
)
|
||||
]
|
||||
if existing:
|
||||
print("[debug] Found existing hnsw_embedding_server processes before run:")
|
||||
for p in existing:
|
||||
print(f"[debug] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}")
|
||||
except Exception as _e:
|
||||
pass
|
||||
|
||||
add_total = 0.0
|
||||
search_after_add = 0.0
|
||||
total_seq = 0.0
|
||||
port_a = None
|
||||
if args.only in ("A", "both"):
|
||||
# Scenario A: sequential update then search
|
||||
start_id = index.ntotal
|
||||
assigned_ids = _append_passages_for_updates(meta_path, start_id, update_paragraphs)
|
||||
if assigned_ids:
|
||||
logger.debug(
|
||||
"Registered %d update passages starting at id %s",
|
||||
len(assigned_ids),
|
||||
assigned_ids[0],
|
||||
)
|
||||
server_manager = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
ok, port = server_manager.start_server(
|
||||
port=args.server_port,
|
||||
model_name=args.model_name,
|
||||
embedding_mode=args.embedding_mode,
|
||||
passages_file=str(meta_path),
|
||||
distance_metric=args.distance_metric,
|
||||
)
|
||||
if not ok:
|
||||
raise RuntimeError("Failed to start embedding server")
|
||||
try:
|
||||
# Set ZMQ port for recompute mode
|
||||
if hasattr(index.hnsw, "set_zmq_port"):
|
||||
index.hnsw.set_zmq_port(port)
|
||||
elif hasattr(index, "set_zmq_port"):
|
||||
index.set_zmq_port(port)
|
||||
|
||||
# Start A overall timer BEFORE computing update embeddings
|
||||
t0 = time.time()
|
||||
|
||||
# Compute embeddings for updates (counted into A's overall)
|
||||
t_emb0 = time.time()
|
||||
upd_embs = compute_embeddings(
|
||||
update_paragraphs,
|
||||
args.model_name,
|
||||
mode=args.embedding_mode,
|
||||
is_build=False,
|
||||
batch_size=16,
|
||||
)
|
||||
emb_time_updates = time.time() - t_emb0
|
||||
upd_embs = np.asarray(upd_embs, dtype=np.float32)
|
||||
upd_embs = _maybe_norm_cosine(upd_embs, args.distance_metric)
|
||||
|
||||
# Perform sequential adds
|
||||
for i in range(upd_embs.shape[0]):
|
||||
t_add0 = time.time()
|
||||
index.add(1, faiss.swig_ptr(upd_embs[i : i + 1]))
|
||||
add_total += time.time() - t_add0
|
||||
# Don't persist index after adds to avoid contaminating Scenario B
|
||||
# index_file = args.index_path.parent / f"{args.index_path.stem}.index"
|
||||
# faiss.write_index(index, str(index_file))
|
||||
|
||||
# Search after updates
|
||||
q_emb = compute_embeddings(
|
||||
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
|
||||
)
|
||||
q_emb = np.asarray(q_emb, dtype=np.float32)
|
||||
q_emb = _maybe_norm_cosine(q_emb, args.distance_metric)
|
||||
|
||||
# Warm up search with a dummy query first
|
||||
print("[DEBUG] Warming up search...")
|
||||
_ = _search(index, q_emb, 1)
|
||||
|
||||
t_s0 = time.time()
|
||||
D_upd, I_upd = _search(index, q_emb, args.k)
|
||||
search_after_add = time.time() - t_s0
|
||||
total_seq = time.time() - t0
|
||||
finally:
|
||||
server_manager.stop_server()
|
||||
port_a = port
|
||||
|
||||
print("\n=== Scenario A: update->search (sequential) ===")
|
||||
# emb_time_updates is defined only when A runs
|
||||
try:
|
||||
_emb_a = emb_time_updates
|
||||
except NameError:
|
||||
_emb_a = 0.0
|
||||
print(
|
||||
f"Adds: {args.num_updates} passages; embeds={_emb_a:.3f}s; add_total={add_total:.3f}s; "
|
||||
f"search={search_after_add:.3f}s; overall={total_seq:.3f}s"
|
||||
)
|
||||
# CSV row for A
|
||||
if args.csv_path:
|
||||
row_a = {
|
||||
"run_id": run_id,
|
||||
"scenario": "A",
|
||||
"max_initial": args.max_initial,
|
||||
"num_updates": args.num_updates,
|
||||
"k": args.k,
|
||||
"total_time_s": round(total_seq, 6),
|
||||
"add_total_s": round(add_total, 6),
|
||||
"search_time_s": round(search_after_add, 6),
|
||||
"emb_time_s": round(_emb_a, 6),
|
||||
"makespan_s": 0.0,
|
||||
"model_name": args.model_name,
|
||||
"embedding_mode": args.embedding_mode,
|
||||
"distance_metric": args.distance_metric,
|
||||
}
|
||||
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writerow(row_a)
|
||||
|
||||
# Verify server cleanup
|
||||
try:
|
||||
# short sleep to allow signal handling to finish
|
||||
time.sleep(0.5)
|
||||
leftovers = [
|
||||
p
|
||||
for p in psutil.process_iter(attrs=["pid", "cmdline"])
|
||||
if any(
|
||||
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
|
||||
for arg in (p.info.get("cmdline") or [])
|
||||
)
|
||||
]
|
||||
if leftovers:
|
||||
print("[warn] hnsw_embedding_server process(es) still alive after A-stop:")
|
||||
for p in leftovers:
|
||||
print(
|
||||
f"[warn] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}"
|
||||
)
|
||||
else:
|
||||
print("[debug] server cleanup confirmed: no hnsw_embedding_server found")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Scenario B: offline embeds + concurrent search (no graph updates)
|
||||
if args.only in ("B", "both"):
|
||||
# ensure a server is available for recompute search
|
||||
server_manager_b = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
requested_port = args.server_port if port_a is None else port_a
|
||||
ok_b, port_b = server_manager_b.start_server(
|
||||
port=requested_port,
|
||||
model_name=args.model_name,
|
||||
embedding_mode=args.embedding_mode,
|
||||
passages_file=str(meta_path),
|
||||
distance_metric=args.distance_metric,
|
||||
)
|
||||
if not ok_b:
|
||||
raise RuntimeError("Failed to start embedding server for Scenario B")
|
||||
|
||||
# Wait for server to fully initialize
|
||||
print("[DEBUG] Waiting 2s for embedding server to fully initialize...")
|
||||
time.sleep(2)
|
||||
|
||||
try:
|
||||
# Read the index first
|
||||
index_no_update = _read_index_for_search(args.index_path) # unchanged index
|
||||
|
||||
# Then configure ZMQ port on the correct index object
|
||||
if hasattr(index_no_update.hnsw, "set_zmq_port"):
|
||||
index_no_update.hnsw.set_zmq_port(port_b)
|
||||
elif hasattr(index_no_update, "set_zmq_port"):
|
||||
index_no_update.set_zmq_port(port_b)
|
||||
|
||||
# Warmup the embedding model before benchmarking (do this for both --only B and --only both)
|
||||
# This ensures fair comparison as Scenario A has warmed up the model during update embeddings
|
||||
logger.info("Warming up embedding model for Scenario B...")
|
||||
_ = compute_embeddings(
|
||||
["warmup text"], args.model_name, mode=args.embedding_mode, is_build=False
|
||||
)
|
||||
|
||||
# Prepare worker A: compute embeddings for the same N passages
|
||||
emb_time = 0.0
|
||||
updates_embs_offline: np.ndarray | None = None
|
||||
|
||||
def _worker_emb():
|
||||
nonlocal emb_time, updates_embs_offline
|
||||
t = time.time()
|
||||
updates_embs_offline = compute_embeddings(
|
||||
update_paragraphs,
|
||||
args.model_name,
|
||||
mode=args.embedding_mode,
|
||||
is_build=False,
|
||||
batch_size=16,
|
||||
)
|
||||
emb_time = time.time() - t
|
||||
|
||||
# Pre-compute query embedding and warm up search outside of timed section.
|
||||
q_vec = compute_embeddings(
|
||||
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
|
||||
)
|
||||
q_vec = np.asarray(q_vec, dtype=np.float32)
|
||||
q_vec = _maybe_norm_cosine(q_vec, args.distance_metric)
|
||||
print("[DEBUG B] Warming up search...")
|
||||
_ = _search(index_no_update, q_vec, 1)
|
||||
|
||||
# Worker B: timed search on the warmed index
|
||||
search_time = 0.0
|
||||
offline_elapsed = 0.0
|
||||
index_results: tuple[np.ndarray, np.ndarray] | None = None
|
||||
|
||||
def _worker_search():
|
||||
nonlocal search_time, index_results
|
||||
t = time.time()
|
||||
distances, indices = _search(index_no_update, q_vec, args.k)
|
||||
search_time = time.time() - t
|
||||
index_results = (distances, indices)
|
||||
|
||||
# Run two workers concurrently
|
||||
t0 = time.time()
|
||||
th1 = threading.Thread(target=_worker_emb)
|
||||
th2 = threading.Thread(target=_worker_search)
|
||||
th1.start()
|
||||
th2.start()
|
||||
th1.join()
|
||||
th2.join()
|
||||
offline_elapsed = time.time() - t0
|
||||
|
||||
# For mixing: compute query vs. offline update similarities (pure client-side)
|
||||
offline_scores: list[tuple[int, float]] = []
|
||||
if updates_embs_offline is not None:
|
||||
upd2 = np.asarray(updates_embs_offline, dtype=np.float32)
|
||||
upd2 = _maybe_norm_cosine(upd2, args.distance_metric)
|
||||
# For mips/cosine, score = dot; for l2, score = -||x-y||^2
|
||||
for j in range(upd2.shape[0]):
|
||||
if args.distance_metric in ("mips", "cosine"):
|
||||
s = float(np.dot(q_vec[0], upd2[j]))
|
||||
else:
|
||||
diff = q_vec[0] - upd2[j]
|
||||
s = -float(np.dot(diff, diff))
|
||||
offline_scores.append((j, s))
|
||||
|
||||
merged_topk = (
|
||||
_merge_results(index_results, offline_scores, args.k, args.distance_metric)
|
||||
if index_results
|
||||
else []
|
||||
)
|
||||
|
||||
print("\n=== Scenario B: offline embeds + concurrent search (no add) ===")
|
||||
print(
|
||||
f"embeddings({args.num_updates})={emb_time:.3f}s; search={search_time:.3f}s; makespan≈{offline_elapsed:.3f}s (≈max)"
|
||||
)
|
||||
if merged_topk:
|
||||
preview = ", ".join([f"{lab}:{score:.3f}" for lab, score in merged_topk[:5]])
|
||||
print(f"Merged top-5 preview: {preview}")
|
||||
# CSV row for B
|
||||
if args.csv_path:
|
||||
row_b = {
|
||||
"run_id": run_id,
|
||||
"scenario": "B",
|
||||
"max_initial": args.max_initial,
|
||||
"num_updates": args.num_updates,
|
||||
"k": args.k,
|
||||
"total_time_s": 0.0,
|
||||
"add_total_s": 0.0,
|
||||
"search_time_s": round(search_time, 6),
|
||||
"emb_time_s": round(emb_time, 6),
|
||||
"makespan_s": round(offline_elapsed, 6),
|
||||
"model_name": args.model_name,
|
||||
"embedding_mode": args.embedding_mode,
|
||||
"distance_metric": args.distance_metric,
|
||||
}
|
||||
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writerow(row_b)
|
||||
|
||||
finally:
|
||||
server_manager_b.stop_server()
|
||||
|
||||
# Summary
|
||||
print("\n=== Summary ===")
|
||||
msg_a = (
|
||||
f"A: seq-add+search overall={total_seq:.3f}s (adds={add_total:.3f}s, search={search_after_add:.3f}s)"
|
||||
if args.only in ("A", "both")
|
||||
else "A: skipped"
|
||||
)
|
||||
msg_b = (
|
||||
f"B: offline+concurrent overall≈{offline_elapsed:.3f}s (emb={emb_time:.3f}s, search={search_time:.3f}s)"
|
||||
if args.only in ("B", "both")
|
||||
else "B: skipped"
|
||||
)
|
||||
print(msg_a + "\n" + msg_b)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,5 +0,0 @@
|
||||
run_id,scenario,max_initial,num_updates,k,total_time_s,add_total_s,search_time_s,emb_time_s,makespan_s,model_name,embedding_mode,distance_metric
|
||||
20251024-141607,A,300,1,10,3.273957,3.050168,0.097825,0.017339,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251024-141607,B,300,1,10,0.0,0.0,0.111892,0.007869,0.112635,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251025-160652,A,300,5,10,5.061945,4.805962,0.123271,0.015008,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251025-160652,B,300,5,10,0.0,0.0,0.101809,0.008817,0.102447,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
|
@@ -1,645 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Plot latency bars from the benchmark CSV produced by
|
||||
benchmarks/update/bench_hnsw_rng_recompute.py.
|
||||
|
||||
If you also provide an offline_vs_update.csv via --csv-right
|
||||
(from benchmarks/update/bench_update_vs_offline_search.py), this script will
|
||||
output a side-by-side figure:
|
||||
- Left: ms/passage bars (four RNG scenarios).
|
||||
- Right: seconds bars (Scenario A seq add+search vs Scenario B offline+search).
|
||||
|
||||
Usage:
|
||||
uv run python benchmarks/update/plot_bench_results.py \
|
||||
--csv benchmarks/update/bench_results.csv \
|
||||
--out benchmarks/update/bench_latency_from_csv.png
|
||||
|
||||
The script selects the latest run_id in the CSV and plots four bars for
|
||||
the default scenarios:
|
||||
- baseline
|
||||
- no_cache_baseline
|
||||
- disable_forward_rng
|
||||
- disable_forward_and_reverse_rng
|
||||
|
||||
If multiple rows exist per scenario for that run_id, the script averages
|
||||
their latency_ms_per_passage values.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
DEFAULT_SCENARIOS = [
|
||||
"no_cache_baseline",
|
||||
"baseline",
|
||||
"disable_forward_rng",
|
||||
"disable_forward_and_reverse_rng",
|
||||
]
|
||||
|
||||
SCENARIO_LABELS = {
|
||||
"baseline": "+ Cache",
|
||||
"no_cache_baseline": "Naive \n Recompute",
|
||||
"disable_forward_rng": "+ w/o \n Fwd RNG",
|
||||
"disable_forward_and_reverse_rng": "+ w/o \n Bwd RNG",
|
||||
}
|
||||
|
||||
# Paper-style colors and hatches for scenarios
|
||||
SCENARIO_STYLES = {
|
||||
"no_cache_baseline": {"edgecolor": "dimgrey", "hatch": "/////"},
|
||||
"baseline": {"edgecolor": "#63B8B6", "hatch": "xxxxx"},
|
||||
"disable_forward_rng": {"edgecolor": "green", "hatch": "....."},
|
||||
"disable_forward_and_reverse_rng": {"edgecolor": "tomato", "hatch": "\\\\\\\\\\"},
|
||||
}
|
||||
|
||||
|
||||
def load_latest_run(csv_path: Path):
|
||||
rows = []
|
||||
with csv_path.open("r", encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
rows.append(row)
|
||||
if not rows:
|
||||
raise SystemExit("CSV is empty: no rows to plot")
|
||||
# Choose latest run_id lexicographically (YYYYMMDD-HHMMSS)
|
||||
run_ids = [r.get("run_id", "") for r in rows]
|
||||
latest = max(run_ids)
|
||||
latest_rows = [r for r in rows if r.get("run_id", "") == latest]
|
||||
if not latest_rows:
|
||||
# Fallback: take last 4 rows
|
||||
latest_rows = rows[-4:]
|
||||
latest = latest_rows[-1].get("run_id", "unknown")
|
||||
return latest, latest_rows
|
||||
|
||||
|
||||
def aggregate_latency(rows):
|
||||
acc = defaultdict(list)
|
||||
for r in rows:
|
||||
sc = r.get("scenario", "")
|
||||
try:
|
||||
val = float(r.get("latency_ms_per_passage", "nan"))
|
||||
except ValueError:
|
||||
continue
|
||||
acc[sc].append(val)
|
||||
avg = {k: (sum(v) / len(v) if v else 0.0) for k, v in acc.items()}
|
||||
return avg
|
||||
|
||||
|
||||
def _auto_cap(values: list[float]) -> float | None:
|
||||
if not values:
|
||||
return None
|
||||
sorted_vals = sorted(values, reverse=True)
|
||||
if len(sorted_vals) < 2:
|
||||
return None
|
||||
max_v, second = sorted_vals[0], sorted_vals[1]
|
||||
if second <= 0:
|
||||
return None
|
||||
# If the tallest bar dwarfs the second by 2.5x+, cap near the second
|
||||
if max_v >= 2.5 * second:
|
||||
return second * 1.1
|
||||
return None
|
||||
|
||||
|
||||
def _add_break_marker(ax, y, rel_x0=0.02, rel_x1=0.98, size=0.02):
|
||||
# Draw small diagonal ticks near left/right to signal cap
|
||||
x0, x1 = rel_x0, rel_x1
|
||||
ax.plot([x0 - size, x0 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
|
||||
ax.plot([x1 - size, x1 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
|
||||
|
||||
|
||||
def _fmt_ms(v: float) -> str:
|
||||
if v >= 1000:
|
||||
return f"{v / 1000:.1f}k"
|
||||
return f"{v:.1f}"
|
||||
|
||||
|
||||
def main():
|
||||
# Set LaTeX style for paper figures (matching paper_fig.py)
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
plt.rcParams["font.family"] = "Helvetica"
|
||||
plt.rcParams["ytick.direction"] = "in"
|
||||
plt.rcParams["hatch.linewidth"] = 1.5
|
||||
plt.rcParams["font.weight"] = "bold"
|
||||
plt.rcParams["axes.labelweight"] = "bold"
|
||||
plt.rcParams["text.usetex"] = True
|
||||
|
||||
ap = argparse.ArgumentParser(description=__doc__)
|
||||
ap.add_argument(
|
||||
"--csv",
|
||||
type=Path,
|
||||
default=Path("benchmarks/update/bench_results.csv"),
|
||||
help="Path to results CSV (defaults to bench_results.csv)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--out",
|
||||
type=Path,
|
||||
default=Path("add_ablation.pdf"),
|
||||
help="Output image path",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--csv-right",
|
||||
type=Path,
|
||||
default=Path("benchmarks/update/offline_vs_update.csv"),
|
||||
help="Optional: offline_vs_update.csv to render right subplot (A vs B)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--cap-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Cap Y-axis at this ms value; bars above are hatched and annotated.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--no-auto-cap",
|
||||
action="store_true",
|
||||
help="Disable auto-cap heuristic when --cap-y is not provided.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--broken-y",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Use a broken Y-axis (two stacked axes with a gap). Overrides --cap-y unless both provided.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--lower-cap-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Lower axes upper bound for broken Y (ms). Default = 1.1x second-highest.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--upper-start-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Upper axes lower bound for broken Y (ms). Default = 1.2x second-highest.",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
latest_run, latest_rows = load_latest_run(args.csv)
|
||||
avg = aggregate_latency(latest_rows)
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
except Exception as e:
|
||||
raise SystemExit(f"matplotlib not available: {e}")
|
||||
|
||||
scenarios = DEFAULT_SCENARIOS
|
||||
values = [avg.get(name, 0.0) for name in scenarios]
|
||||
labels = [SCENARIO_LABELS.get(name, name) for name in scenarios]
|
||||
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
|
||||
|
||||
# If right CSV is provided, build side-by-side figure
|
||||
if args.csv_right is not None:
|
||||
try:
|
||||
right_rows_all = []
|
||||
with args.csv_right.open("r", encoding="utf-8") as f:
|
||||
rreader = csv.DictReader(f)
|
||||
right_rows_all = list(rreader)
|
||||
if right_rows_all:
|
||||
r_latest = max(r.get("run_id", "") for r in right_rows_all)
|
||||
right_rows = [r for r in right_rows_all if r.get("run_id", "") == r_latest]
|
||||
else:
|
||||
r_latest = None
|
||||
right_rows = []
|
||||
except Exception:
|
||||
r_latest = None
|
||||
right_rows = []
|
||||
|
||||
a_total = 0.0
|
||||
b_makespan = 0.0
|
||||
for r in right_rows:
|
||||
sc = (r.get("scenario", "") or "").strip().upper()
|
||||
if sc == "A":
|
||||
try:
|
||||
a_total = float(r.get("total_time_s", 0.0))
|
||||
except Exception:
|
||||
pass
|
||||
elif sc == "B":
|
||||
try:
|
||||
b_makespan = float(r.get("makespan_s", 0.0))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib import gridspec
|
||||
|
||||
# Left subplot (reuse current style, with optional cap)
|
||||
cap = args.cap_y
|
||||
if cap is None and not args.no_auto_cap:
|
||||
cap = _auto_cap(values)
|
||||
x = list(range(len(labels)))
|
||||
|
||||
if args.broken_y:
|
||||
# Use broken axis for left subplot
|
||||
# Auto-adjust width ratios: left has 4 bars, right has 2 bars
|
||||
fig = plt.figure(figsize=(4.8, 1.8)) # Scaled down to 80%
|
||||
gs = gridspec.GridSpec(
|
||||
2, 2, height_ratios=[1, 3], width_ratios=[1.5, 1], hspace=0.08, wspace=0.35
|
||||
)
|
||||
ax_left_top = fig.add_subplot(gs[0, 0])
|
||||
ax_left_bottom = fig.add_subplot(gs[1, 0], sharex=ax_left_top)
|
||||
ax_right = fig.add_subplot(gs[:, 1])
|
||||
|
||||
# Determine break points
|
||||
s = sorted(values, reverse=True)
|
||||
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||
lower_cap = (
|
||||
args.lower_cap_y if args.lower_cap_y is not None else second * 1.4
|
||||
) # Increased to show more range
|
||||
upper_start = (
|
||||
args.upper_start_y
|
||||
if args.upper_start_y is not None
|
||||
else max(second * 1.5, lower_cap * 1.02)
|
||||
)
|
||||
ymax = (
|
||||
max(values) * 1.90 if values else 1.0
|
||||
) # Increase headroom to 1.90 for text label and tick range
|
||||
|
||||
# Draw bars on both axes
|
||||
ax_left_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
ax_left_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
|
||||
# Set limits
|
||||
ax_left_bottom.set_ylim(0, lower_cap)
|
||||
ax_left_top.set_ylim(upper_start, ymax)
|
||||
|
||||
# Annotate values (convert ms to s)
|
||||
values_s = [v / 1000.0 for v in values]
|
||||
lower_cap_s = lower_cap / 1000.0
|
||||
upper_start_s = upper_start / 1000.0
|
||||
ymax_s = ymax / 1000.0
|
||||
|
||||
ax_left_bottom.set_ylim(0, lower_cap_s)
|
||||
ax_left_top.set_ylim(upper_start_s, ymax_s)
|
||||
|
||||
# Redraw bars with s values (paper style: white fill + colored edge + hatch)
|
||||
ax_left_bottom.clear()
|
||||
ax_left_top.clear()
|
||||
bar_width = 0.50 # Reduced for wider spacing between bars
|
||||
for i, (scenario_name, v) in enumerate(zip(scenarios, values_s)):
|
||||
style = SCENARIO_STYLES.get(scenario_name, {"edgecolor": "black", "hatch": ""})
|
||||
# Draw in bottom axis for all bars
|
||||
ax_left_bottom.bar(
|
||||
i,
|
||||
v,
|
||||
width=bar_width,
|
||||
color="white",
|
||||
edgecolor=style["edgecolor"],
|
||||
hatch=style["hatch"],
|
||||
linewidth=1.2,
|
||||
)
|
||||
# Only draw in top axis if the bar is tall enough to reach the upper range
|
||||
if v > upper_start_s:
|
||||
ax_left_top.bar(
|
||||
i,
|
||||
v,
|
||||
width=bar_width,
|
||||
color="white",
|
||||
edgecolor=style["edgecolor"],
|
||||
hatch=style["hatch"],
|
||||
linewidth=1.2,
|
||||
)
|
||||
ax_left_bottom.set_ylim(0, lower_cap_s)
|
||||
ax_left_top.set_ylim(upper_start_s, ymax_s)
|
||||
|
||||
for i, v in enumerate(values_s):
|
||||
if v <= lower_cap_s:
|
||||
ax_left_bottom.text(
|
||||
i,
|
||||
v + lower_cap_s * 0.02,
|
||||
f"{v:.2f}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
)
|
||||
else:
|
||||
ax_left_top.text(
|
||||
i,
|
||||
v + (ymax_s - upper_start_s) * 0.02,
|
||||
f"{v:.2f}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Hide spines between axes
|
||||
ax_left_top.spines["bottom"].set_visible(False)
|
||||
ax_left_bottom.spines["top"].set_visible(False)
|
||||
ax_left_top.tick_params(
|
||||
labeltop=False, labelbottom=False, bottom=False
|
||||
) # Hide tick marks
|
||||
ax_left_bottom.xaxis.tick_bottom()
|
||||
ax_left_bottom.tick_params(top=False) # Hide top tick marks
|
||||
|
||||
# Draw break marks (matching paper_fig.py style)
|
||||
d = 0.015
|
||||
kwargs = {
|
||||
"transform": ax_left_top.transAxes,
|
||||
"color": "k",
|
||||
"clip_on": False,
|
||||
"linewidth": 0.8,
|
||||
"zorder": 10,
|
||||
}
|
||||
ax_left_top.plot((-d, +d), (-d, +d), **kwargs)
|
||||
ax_left_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
|
||||
kwargs.update({"transform": ax_left_bottom.transAxes})
|
||||
ax_left_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
|
||||
ax_left_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
|
||||
|
||||
ax_left_bottom.set_xticks(x)
|
||||
ax_left_bottom.set_xticklabels(labels, rotation=0, fontsize=7)
|
||||
# Don't set ylabel here - will use fig.text for alignment
|
||||
ax_left_bottom.tick_params(axis="y", labelsize=10)
|
||||
ax_left_top.tick_params(axis="y", labelsize=10)
|
||||
# Add subtle grid for better readability
|
||||
ax_left_bottom.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||
ax_left_top.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||
ax_left_top.set_title("Single Add Operation", fontsize=11, pad=10, fontweight="bold")
|
||||
|
||||
# Set x-axis limits to match bar width with right subplot
|
||||
ax_left_bottom.set_xlim(-0.6, 3.6)
|
||||
ax_left_top.set_xlim(-0.6, 3.6)
|
||||
|
||||
ax_left = ax_left_bottom # for compatibility
|
||||
else:
|
||||
# Regular side-by-side layout
|
||||
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8.4, 3.15))
|
||||
|
||||
if cap is not None:
|
||||
show_vals = [min(v, cap) for v in values]
|
||||
bars = ax_left.bar(x, show_vals, color=colors[: len(labels)], width=0.8)
|
||||
for i, (val, show) in enumerate(zip(values, show_vals)):
|
||||
if val > cap:
|
||||
bars[i].set_hatch("//")
|
||||
ax_left.text(
|
||||
i, cap * 1.02, _fmt_ms(val), ha="center", va="bottom", fontsize=9
|
||||
)
|
||||
else:
|
||||
ax_left.text(
|
||||
i,
|
||||
show + max(1.0, 0.01 * (cap or show)),
|
||||
_fmt_ms(val),
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=9,
|
||||
)
|
||||
ax_left.set_ylim(0, cap * 1.10)
|
||||
_add_break_marker(ax_left, y=0.98)
|
||||
ax_left.set_xticks(x)
|
||||
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||
else:
|
||||
ax_left.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
for i, v in enumerate(values):
|
||||
ax_left.text(i, v + 1.0, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||
ax_left.set_xticks(x)
|
||||
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||
ax_left.set_ylabel("Latency (ms per passage)")
|
||||
max_initial = latest_rows[0].get("max_initial", "?")
|
||||
max_updates = latest_rows[0].get("max_updates", "?")
|
||||
ax_left.set_title(
|
||||
f"HNSW RNG (run {latest_run}) | init={max_initial}, upd={max_updates}"
|
||||
)
|
||||
|
||||
# Right subplot (A vs B, seconds) - paper style
|
||||
r_labels = ["Sequential", "Delayed \n Add+Search"]
|
||||
r_values = [a_total or 0.0, b_makespan or 0.0]
|
||||
r_styles = [
|
||||
{"edgecolor": "#59a14f", "hatch": "xxxxx"},
|
||||
{"edgecolor": "#edc948", "hatch": "/////"},
|
||||
]
|
||||
# 2 bars, centered with proper spacing
|
||||
xr = [0, 1]
|
||||
bar_width = 0.50 # Reduced for wider spacing between bars
|
||||
for i, (v, style) in enumerate(zip(r_values, r_styles)):
|
||||
ax_right.bar(
|
||||
xr[i],
|
||||
v,
|
||||
width=bar_width,
|
||||
color="white",
|
||||
edgecolor=style["edgecolor"],
|
||||
hatch=style["hatch"],
|
||||
linewidth=1.2,
|
||||
)
|
||||
for i, v in enumerate(r_values):
|
||||
max_v = max(r_values) if r_values else 1.0
|
||||
offset = max(0.0002, 0.02 * max_v)
|
||||
ax_right.text(
|
||||
xr[i],
|
||||
v + offset,
|
||||
f"{v:.2f}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax_right.set_xticks(xr)
|
||||
ax_right.set_xticklabels(r_labels, rotation=0, fontsize=7)
|
||||
# Don't set ylabel here - will use fig.text for alignment
|
||||
ax_right.tick_params(axis="y", labelsize=10)
|
||||
# Add subtle grid for better readability
|
||||
ax_right.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||
ax_right.set_title("Batched Add Operation", fontsize=11, pad=10, fontweight="bold")
|
||||
|
||||
# Set x-axis limits to match left subplot's bar width visually
|
||||
# Accounting for width_ratios=[1.5, 1]:
|
||||
# Left: 4 bars, xlim(-0.6, 3.6), range=4.2, physical_width=1.5*unit
|
||||
# bar_width_visual = 0.72 * (1.5*unit / 4.2)
|
||||
# Right: 2 bars, need same visual width
|
||||
# 0.72 * (1.0*unit / range_right) = 0.72 * (1.5*unit / 4.2)
|
||||
# range_right = 4.2 / 1.5 = 2.8
|
||||
# For bars at 0, 1: padding = (2.8 - 1) / 2 = 0.9
|
||||
ax_right.set_xlim(-0.9, 1.9)
|
||||
|
||||
# Set y-axis limit with headroom for text labels
|
||||
if r_values:
|
||||
max_v = max(r_values)
|
||||
ax_right.set_ylim(0, max_v * 1.15)
|
||||
|
||||
# Format y-axis to avoid scientific notation
|
||||
ax_right.ticklabel_format(style="plain", axis="y")
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# Add aligned ylabels using fig.text (after tight_layout)
|
||||
# Get the vertical center of the entire figure
|
||||
fig_center_y = 0.5
|
||||
# Left ylabel - closer to left plot
|
||||
left_x = 0.05
|
||||
fig.text(
|
||||
left_x,
|
||||
fig_center_y,
|
||||
"Latency (s)",
|
||||
va="center",
|
||||
rotation="vertical",
|
||||
fontsize=11,
|
||||
fontweight="bold",
|
||||
)
|
||||
# Right ylabel - closer to right plot
|
||||
right_bbox = ax_right.get_position()
|
||||
right_x = right_bbox.x0 - 0.07
|
||||
fig.text(
|
||||
right_x,
|
||||
fig_center_y,
|
||||
"Latency (s)",
|
||||
va="center",
|
||||
rotation="vertical",
|
||||
fontsize=11,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
|
||||
# Also save PDF for paper
|
||||
pdf_out = args.out.with_suffix(".pdf")
|
||||
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
|
||||
print(f"Saved: {args.out}")
|
||||
print(f"Saved: {pdf_out}")
|
||||
return
|
||||
|
||||
# Broken-Y mode
|
||||
if args.broken_y:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, (ax_top, ax_bottom) = plt.subplots(
|
||||
2,
|
||||
1,
|
||||
sharex=True,
|
||||
figsize=(7.5, 6.75),
|
||||
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.08},
|
||||
)
|
||||
|
||||
# Determine default breaks from second-highest
|
||||
s = sorted(values, reverse=True)
|
||||
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
|
||||
upper_start = (
|
||||
args.upper_start_y
|
||||
if args.upper_start_y is not None
|
||||
else max(second * 1.2, lower_cap * 1.02)
|
||||
)
|
||||
ymax = max(values) * 1.10 if values else 1.0
|
||||
|
||||
x = list(range(len(labels)))
|
||||
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
|
||||
# Limits
|
||||
ax_bottom.set_ylim(0, lower_cap)
|
||||
ax_top.set_ylim(upper_start, ymax)
|
||||
|
||||
# Annotate values
|
||||
for i, v in enumerate(values):
|
||||
if v <= lower_cap:
|
||||
ax_bottom.text(
|
||||
i, v + lower_cap * 0.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9
|
||||
)
|
||||
else:
|
||||
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||
|
||||
# Hide spines between axes and draw diagonal break marks
|
||||
ax_top.spines["bottom"].set_visible(False)
|
||||
ax_bottom.spines["top"].set_visible(False)
|
||||
ax_top.tick_params(labeltop=False) # don't put tick labels at the top
|
||||
ax_bottom.xaxis.tick_bottom()
|
||||
|
||||
# Diagonal lines at the break (matching paper_fig.py style)
|
||||
d = 0.015
|
||||
kwargs = {
|
||||
"transform": ax_top.transAxes,
|
||||
"color": "k",
|
||||
"clip_on": False,
|
||||
"linewidth": 0.8,
|
||||
"zorder": 10,
|
||||
}
|
||||
ax_top.plot((-d, +d), (-d, +d), **kwargs) # top-left diagonal
|
||||
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs) # top-right diagonal
|
||||
kwargs.update({"transform": ax_bottom.transAxes})
|
||||
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs) # bottom-left diagonal
|
||||
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs) # bottom-right diagonal
|
||||
|
||||
ax_bottom.set_xticks(x)
|
||||
ax_bottom.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||
ax = ax_bottom # for labeling below
|
||||
else:
|
||||
cap = args.cap_y
|
||||
if cap is None and not args.no_auto_cap:
|
||||
cap = _auto_cap(values)
|
||||
|
||||
plt.figure(figsize=(5.4, 3.15))
|
||||
ax = plt.gca()
|
||||
|
||||
if cap is not None:
|
||||
show_vals = [min(v, cap) for v in values]
|
||||
bars = []
|
||||
for i, (_label, val, show) in enumerate(zip(labels, values, show_vals)):
|
||||
bar = ax.bar(i, show, color=colors[i], width=0.8)
|
||||
bars.append(bar[0])
|
||||
# Hatch and annotate when capped
|
||||
if val > cap:
|
||||
bars[-1].set_hatch("//")
|
||||
ax.text(i, cap * 1.02, f"{_fmt_ms(val)}", ha="center", va="bottom", fontsize=9)
|
||||
else:
|
||||
ax.text(
|
||||
i,
|
||||
show + max(1.0, 0.01 * (cap or show)),
|
||||
f"{_fmt_ms(val)}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=9,
|
||||
)
|
||||
ax.set_ylim(0, cap * 1.10)
|
||||
_add_break_marker(ax, y=0.98)
|
||||
ax.legend([bars[1]], ["capped"], fontsize=8, frameon=False, loc="upper right") if any(
|
||||
v > cap for v in values
|
||||
) else None
|
||||
ax.set_xticks(range(len(labels)))
|
||||
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
|
||||
else:
|
||||
ax.bar(labels, values, color=colors[: len(labels)])
|
||||
for idx, val in enumerate(values):
|
||||
ax.text(
|
||||
idx,
|
||||
val + 1.0,
|
||||
f"{_fmt_ms(val)}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=10,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
|
||||
# Try to extract some context for title
|
||||
max_initial = latest_rows[0].get("max_initial", "?")
|
||||
max_updates = latest_rows[0].get("max_updates", "?")
|
||||
|
||||
if args.broken_y:
|
||||
fig.text(
|
||||
0.02,
|
||||
0.5,
|
||||
"Latency (s)",
|
||||
va="center",
|
||||
rotation="vertical",
|
||||
fontsize=11,
|
||||
fontweight="bold",
|
||||
)
|
||||
fig.suptitle(
|
||||
"Add Operation Latency",
|
||||
fontsize=11,
|
||||
y=0.98,
|
||||
fontweight="bold",
|
||||
)
|
||||
plt.tight_layout(rect=(0.03, 0.04, 1, 0.96))
|
||||
else:
|
||||
plt.ylabel("Latency (s)", fontsize=11, fontweight="bold")
|
||||
plt.title("Add Operation Latency", fontsize=11, fontweight="bold")
|
||||
plt.tight_layout()
|
||||
|
||||
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
|
||||
# Also save PDF for paper
|
||||
pdf_out = args.out.with_suffix(".pdf")
|
||||
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
|
||||
print(f"Saved: {args.out}")
|
||||
print(f"Saved: {pdf_out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,200 +0,0 @@
|
||||
# ColQwen Integration Guide
|
||||
|
||||
Easy-to-use multimodal PDF retrieval with ColQwen2/ColPali models.
|
||||
|
||||
## Quick Start
|
||||
|
||||
> **🍎 Mac Users**: ColQwen is optimized for Apple Silicon with MPS acceleration for faster inference!
|
||||
|
||||
### 1. Install Dependencies
|
||||
```bash
|
||||
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
|
||||
brew install poppler # macOS only, for PDF processing
|
||||
```
|
||||
|
||||
### 2. Basic Usage
|
||||
```bash
|
||||
# Build index from PDFs
|
||||
python -m apps.colqwen_rag build --pdfs ./my_papers/ --index research_papers
|
||||
|
||||
# Search with text queries
|
||||
python -m apps.colqwen_rag search research_papers "How does attention mechanism work?"
|
||||
|
||||
# Interactive Q&A
|
||||
python -m apps.colqwen_rag ask research_papers --interactive
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
### Build Index
|
||||
```bash
|
||||
python -m apps.colqwen_rag build \
|
||||
--pdfs ./pdf_directory/ \
|
||||
--index my_index \
|
||||
--model colqwen2 \
|
||||
--pages-dir ./page_images/ # Optional: save page images
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- `--pdfs`: Directory containing PDF files (or single PDF path)
|
||||
- `--index`: Name for the index (required)
|
||||
- `--model`: `colqwen2` (default) or `colpali`
|
||||
- `--pages-dir`: Directory to save page images (optional)
|
||||
|
||||
### Search Index
|
||||
```bash
|
||||
python -m apps.colqwen_rag search my_index "your question here" --top-k 5
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- `--top-k`: Number of results to return (default: 5)
|
||||
- `--model`: Model used for search (should match build model)
|
||||
|
||||
### Interactive Q&A
|
||||
```bash
|
||||
python -m apps.colqwen_rag ask my_index --interactive
|
||||
```
|
||||
|
||||
**Commands in interactive mode:**
|
||||
- Type your questions naturally
|
||||
- `help`: Show available commands
|
||||
- `quit`/`exit`/`q`: Exit interactive mode
|
||||
|
||||
## 🧪 Test & Reproduce Results
|
||||
|
||||
Run the reproduction test for issue #119:
|
||||
```bash
|
||||
python test_colqwen_reproduction.py
|
||||
```
|
||||
|
||||
This will:
|
||||
1. ✅ Check dependencies
|
||||
2. 📥 Download sample PDF (Attention Is All You Need paper)
|
||||
3. 🏗️ Build test index
|
||||
4. 🔍 Run sample queries
|
||||
5. 📊 Show how to generate similarity maps
|
||||
|
||||
## 🎨 Advanced: Similarity Maps
|
||||
|
||||
For visual similarity analysis, use the existing advanced script:
|
||||
```bash
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector/
|
||||
python multi-vector-leann-similarity-map.py
|
||||
```
|
||||
|
||||
Edit the script to customize:
|
||||
- `QUERY`: Your question
|
||||
- `MODEL`: "colqwen2" or "colpali"
|
||||
- `USE_HF_DATASET`: Use HuggingFace dataset or local PDFs
|
||||
- `SIMILARITY_MAP`: Generate heatmaps
|
||||
- `ANSWER`: Enable Qwen-VL answer generation
|
||||
|
||||
## 🔧 How It Works
|
||||
|
||||
### ColQwen2 vs ColPali
|
||||
- **ColQwen2** (`vidore/colqwen2-v1.0`): Latest vision-language model
|
||||
- **ColPali** (`vidore/colpali-v1.2`): Proven multimodal retriever
|
||||
|
||||
### Architecture
|
||||
1. **PDF → Images**: Convert PDF pages to images (150 DPI)
|
||||
2. **Vision Encoding**: Process images with ColQwen2/ColPali
|
||||
3. **Multi-Vector Index**: Build LEANN HNSW index with multiple embeddings per page
|
||||
4. **Query Processing**: Encode text queries with same model
|
||||
5. **Similarity Search**: Find most relevant pages/regions
|
||||
6. **Visual Maps**: Generate attention heatmaps (optional)
|
||||
|
||||
### Device Support
|
||||
- **CUDA**: Best performance with GPU acceleration
|
||||
- **MPS**: Apple Silicon Mac support
|
||||
- **CPU**: Fallback for any system (slower)
|
||||
|
||||
Auto-detection: CUDA > MPS > CPU
|
||||
|
||||
## 📊 Performance Tips
|
||||
|
||||
### For Best Performance:
|
||||
```bash
|
||||
# Use ColQwen2 for latest features
|
||||
--model colqwen2
|
||||
|
||||
# Save page images for reuse
|
||||
--pages-dir ./cached_pages/
|
||||
|
||||
# Adjust batch size based on GPU memory
|
||||
# (automatically handled)
|
||||
```
|
||||
|
||||
### For Large Document Sets:
|
||||
- Process PDFs in batches
|
||||
- Use SSD storage for index files
|
||||
- Consider using CUDA if available
|
||||
|
||||
## 🔗 Related Resources
|
||||
|
||||
- **Fast-PLAID**: https://github.com/lightonai/fast-plaid
|
||||
- **Pylate**: https://github.com/lightonai/pylate
|
||||
- **ColBERT**: https://github.com/stanford-futuredata/ColBERT
|
||||
- **ColPali Paper**: Vision-Language Models for Document Retrieval
|
||||
- **Issue #119**: https://github.com/yichuan-w/LEANN/issues/119
|
||||
|
||||
## 🐛 Troubleshooting
|
||||
|
||||
### PDF Conversion Issues (macOS)
|
||||
```bash
|
||||
# Install poppler
|
||||
brew install poppler
|
||||
which pdfinfo && pdfinfo -v
|
||||
```
|
||||
|
||||
### Memory Issues
|
||||
- Reduce batch size (automatically handled)
|
||||
- Use CPU instead of GPU: `export CUDA_VISIBLE_DEVICES=""`
|
||||
- Process fewer PDFs at once
|
||||
|
||||
### Model Download Issues
|
||||
- Ensure internet connection for first run
|
||||
- Models are cached after first download
|
||||
- Use HuggingFace mirrors if needed
|
||||
|
||||
### Import Errors
|
||||
```bash
|
||||
# Ensure all dependencies installed
|
||||
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
|
||||
|
||||
# Check PyTorch installation
|
||||
python -c "import torch; print(torch.__version__)"
|
||||
```
|
||||
|
||||
## 💡 Examples
|
||||
|
||||
### Research Paper Analysis
|
||||
```bash
|
||||
# Index your research papers
|
||||
python -m apps.colqwen_rag build --pdfs ~/Papers/AI/ --index ai_papers
|
||||
|
||||
# Ask research questions
|
||||
python -m apps.colqwen_rag search ai_papers "What are the limitations of transformer models?"
|
||||
python -m apps.colqwen_rag search ai_papers "How does BERT compare to GPT?"
|
||||
```
|
||||
|
||||
### Document Q&A
|
||||
```bash
|
||||
# Index business documents
|
||||
python -m apps.colqwen_rag build --pdfs ~/Documents/Reports/ --index reports
|
||||
|
||||
# Interactive analysis
|
||||
python -m apps.colqwen_rag ask reports --interactive
|
||||
```
|
||||
|
||||
### Visual Analysis
|
||||
```bash
|
||||
# Generate similarity maps for specific queries
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector/
|
||||
# Edit multi-vector-leann-similarity-map.py with your query
|
||||
python multi-vector-leann-similarity-map.py
|
||||
# Check ./figures/ for generated heatmaps
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
**🎯 This integration makes ColQwen as easy to use as other LEANN features while maintaining the full power of multimodal document understanding!**
|
||||
@@ -158,95 +158,6 @@ builder.build_index("./indexes/my-notes", chunks)
|
||||
|
||||
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
|
||||
|
||||
## Optional Embedding Features
|
||||
|
||||
### Task-Specific Prompt Templates
|
||||
|
||||
Some embedding models are trained with task-specific prompts to differentiate between documents and queries. The most notable example is **Google's EmbeddingGemma**, which requires different prompts depending on the use case:
|
||||
|
||||
- **Indexing documents**: `"title: none | text: "`
|
||||
- **Search queries**: `"task: search result | query: "`
|
||||
|
||||
LEANN supports automatic prompt prepending via the `--embedding-prompt-template` flag:
|
||||
|
||||
```bash
|
||||
# Build index with EmbeddingGemma (via LM Studio or Ollama)
|
||||
leann build my-docs \
|
||||
--docs ./documents \
|
||||
--embedding-mode openai \
|
||||
--embedding-model text-embedding-embeddinggemma-300m-qat \
|
||||
--embedding-api-base http://localhost:1234/v1 \
|
||||
--embedding-prompt-template "title: none | text: " \
|
||||
--force
|
||||
|
||||
# Search with query-specific prompt
|
||||
leann search my-docs \
|
||||
--query "What is quantum computing?" \
|
||||
--embedding-prompt-template "task: search result | query: "
|
||||
```
|
||||
|
||||
**Important Notes:**
|
||||
- **Only use with compatible models**: EmbeddingGemma and similar task-specific models
|
||||
- **NOT for regular models**: Adding prompts to models like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` will corrupt embeddings
|
||||
- **Template is saved**: Build-time templates are saved to `.meta.json` for reference
|
||||
- **Flexible prompts**: You can use any prompt string, or leave it empty (`""`)
|
||||
|
||||
**Python API:**
|
||||
```python
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
builder = LeannBuilder(
|
||||
embedding_mode="openai",
|
||||
embedding_model="text-embedding-embeddinggemma-300m-qat",
|
||||
embedding_options={
|
||||
"base_url": "http://localhost:1234/v1",
|
||||
"api_key": "lm-studio",
|
||||
"prompt_template": "title: none | text: ",
|
||||
},
|
||||
)
|
||||
builder.build_index("./indexes/my-docs", chunks)
|
||||
```
|
||||
|
||||
**References:**
|
||||
- [HuggingFace Blog: EmbeddingGemma](https://huggingface.co/blog/embeddinggemma) - Technical details
|
||||
|
||||
### LM Studio Auto-Detection (Optional)
|
||||
|
||||
When using LM Studio with the OpenAI-compatible API, LEANN can optionally auto-detect model context lengths via the LM Studio SDK. This eliminates manual configuration for token limits.
|
||||
|
||||
**Prerequisites:**
|
||||
```bash
|
||||
# Install Node.js (if not already installed)
|
||||
# Then install the LM Studio SDK globally
|
||||
npm install -g @lmstudio/sdk
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
1. LEANN detects LM Studio URLs (`:1234`, `lmstudio` in URL)
|
||||
2. Queries model metadata via Node.js subprocess
|
||||
3. Automatically unloads model after query (respects your JIT auto-evict settings)
|
||||
4. Falls back to static registry if SDK unavailable
|
||||
|
||||
**No configuration needed** - it works automatically when SDK is installed:
|
||||
|
||||
```bash
|
||||
leann build my-docs \
|
||||
--docs ./documents \
|
||||
--embedding-mode openai \
|
||||
--embedding-model text-embedding-nomic-embed-text-v1.5 \
|
||||
--embedding-api-base http://localhost:1234/v1
|
||||
# Context length auto-detected if SDK available
|
||||
# Falls back to registry (2048) if not
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- ✅ Automatic token limit detection
|
||||
- ✅ Respects LM Studio JIT auto-evict settings
|
||||
- ✅ No manual registry maintenance
|
||||
- ✅ Graceful fallback if SDK unavailable
|
||||
|
||||
**Note:** This is completely optional. LEANN works perfectly fine without the SDK using the built-in token limit registry.
|
||||
|
||||
## Index Selection: Matching Your Scale
|
||||
|
||||
### HNSW (Hierarchical Navigable Small World)
|
||||
@@ -454,7 +365,7 @@ leann search my-index "your query" \
|
||||
|
||||
### 2) Run remote builds with SkyPilot (cloud GPU)
|
||||
|
||||
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://docs.skypilot.co/en/latest/docs/index.html). A template is provided at `sky/leann-build.yaml`.
|
||||
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
|
||||
|
||||
```bash
|
||||
# One-time: install and configure SkyPilot
|
||||
@@ -544,5 +455,5 @@ Conclusion:
|
||||
|
||||
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
|
||||
- [DiskANN Original Paper](https://suhasjs.github.io/files/diskann_neurips19.pdf)
|
||||
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
|
||||
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)
|
||||
|
||||
48
docs/faq.md
48
docs/faq.md
@@ -8,51 +8,3 @@ You can speed up the process by using a lightweight embedding model. Add this to
|
||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||
```
|
||||
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
||||
|
||||
## 2. When should I use prompt templates?
|
||||
|
||||
**Use prompt templates ONLY with task-specific embedding models** like Google's EmbeddingGemma. These models are specially trained to use different prompts for documents vs queries.
|
||||
|
||||
**DO NOT use with regular models** like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` - adding prompts to these models will corrupt the embeddings.
|
||||
|
||||
**Example usage with EmbeddingGemma:**
|
||||
```bash
|
||||
# Build with document prompt
|
||||
leann build my-docs --embedding-prompt-template "title: none | text: "
|
||||
|
||||
# Search with query prompt
|
||||
leann search my-docs --query "your question" --embedding-prompt-template "task: search result | query: "
|
||||
```
|
||||
|
||||
See the [Configuration Guide: Task-Specific Prompt Templates](configuration-guide.md#task-specific-prompt-templates) for detailed usage.
|
||||
|
||||
## 3. Why is LM Studio loading multiple copies of my model?
|
||||
|
||||
This was fixed in recent versions. LEANN now properly unloads models after querying metadata, respecting your LM Studio JIT auto-evict settings.
|
||||
|
||||
**If you still see duplicates:**
|
||||
- Update to the latest LEANN version
|
||||
- Restart LM Studio to clear loaded models
|
||||
- Check that you have JIT auto-evict enabled in LM Studio settings
|
||||
|
||||
**How it works now:**
|
||||
1. LEANN loads model temporarily to get context length
|
||||
2. Immediately unloads after query
|
||||
3. LM Studio JIT loads model on-demand for actual embeddings
|
||||
4. Auto-evicts per your settings
|
||||
|
||||
## 4. Do I need Node.js and @lmstudio/sdk?
|
||||
|
||||
**No, it's completely optional.** LEANN works perfectly fine without them using a built-in token limit registry.
|
||||
|
||||
**Benefits if you install it:**
|
||||
- Automatic context length detection for LM Studio models
|
||||
- No manual registry maintenance
|
||||
- Always gets accurate token limits from the model itself
|
||||
|
||||
**To install (optional):**
|
||||
```bash
|
||||
npm install -g @lmstudio/sdk
|
||||
```
|
||||
|
||||
See [Configuration Guide: LM Studio Auto-Detection](configuration-guide.md#lm-studio-auto-detection-optional) for details.
|
||||
|
||||
@@ -1,395 +0,0 @@
|
||||
# Slack Integration Setup Guide
|
||||
|
||||
This guide provides step-by-step instructions for setting up Slack integration with LEANN.
|
||||
|
||||
## Overview
|
||||
|
||||
LEANN's Slack integration uses MCP (Model Context Protocol) servers to fetch and index your Slack messages for RAG (Retrieval-Augmented Generation). This allows you to search through your Slack conversations using natural language queries.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. **Slack Workspace Access**: You need admin or owner permissions in your Slack workspace to create apps and configure OAuth tokens.
|
||||
|
||||
2. **Slack MCP Server**: Install a Slack MCP server (e.g., `slack-mcp-server` via npm)
|
||||
|
||||
3. **LEANN**: Ensure you have LEANN installed and working
|
||||
|
||||
## Step 1: Create a Slack App
|
||||
|
||||
### 1.1 Go to Slack API Dashboard
|
||||
|
||||
1. Visit [https://api.slack.com/apps](https://api.slack.com/apps)
|
||||
2. Click **"Create New App"**
|
||||
3. Choose **"From scratch"**
|
||||
4. Enter your app name (e.g., "LEANN Slack Integration")
|
||||
5. Select your workspace
|
||||
6. Click **"Create App"**
|
||||
|
||||
### 1.2 Configure App Permissions
|
||||
|
||||
#### Token Scopes
|
||||
|
||||
1. In your app dashboard, go to **"OAuth & Permissions"** in the left sidebar
|
||||
2. Scroll down to **"Scopes"** section
|
||||
3. Under **"Bot Token Scopes & OAuth Scope"**, click **"Add an OAuth Scope"**
|
||||
4. Add the following scopes:
|
||||
- `channels:read` - Read public channel information
|
||||
- `channels:history` - Read messages in public channels
|
||||
- `groups:read` - Read private channel information
|
||||
- `groups:history` - Read messages in private channels
|
||||
- `im:read` - Read direct message information
|
||||
- `im:history` - Read direct messages
|
||||
- `mpim:read` - Read group direct message information
|
||||
- `mpim:history` - Read group direct messages
|
||||
- `users:read` - Read user information
|
||||
- `team:read` - Read workspace information
|
||||
|
||||
#### App-Level Tokens (Optional)
|
||||
|
||||
Some MCP servers may require app-level tokens:
|
||||
|
||||
1. Go to **"Basic Information"** in the left sidebar
|
||||
2. Scroll down to **"App-Level Tokens"**
|
||||
3. Click **"Generate Token and Scopes"**
|
||||
4. Enter a name (e.g., "LEANN Integration")
|
||||
5. Add the `connections:write` scope
|
||||
6. Click **"Generate"**
|
||||
7. Copy the token (starts with `xapp-`)
|
||||
|
||||
### 1.3 Install App to Workspace
|
||||
|
||||
1. Go to **"OAuth & Permissions"** in the left sidebar
|
||||
2. Click **"Install to Workspace"**
|
||||
3. Review the permissions and click **"Allow"**
|
||||
4. Copy the **"Bot User OAuth Token"** (starts with `xoxb-`)
|
||||
5. Copy the **"User OAuth Token"** (starts with `xoxp-`)
|
||||
|
||||
## Step 2: Install Slack MCP Server
|
||||
|
||||
### Option A: Using npm (Recommended)
|
||||
|
||||
```bash
|
||||
# Install globally
|
||||
npm install -g slack-mcp-server
|
||||
|
||||
# Or install locally
|
||||
npm install slack-mcp-server
|
||||
```
|
||||
|
||||
### Option B: Using npx (No installation required)
|
||||
|
||||
```bash
|
||||
# Use directly without installation
|
||||
npx slack-mcp-server
|
||||
```
|
||||
|
||||
## Step 3: Install and Configure Ollama (for Real LLM Responses)
|
||||
|
||||
### 3.1 Install Ollama
|
||||
|
||||
```bash
|
||||
# Install Ollama using Homebrew (macOS)
|
||||
brew install ollama
|
||||
|
||||
# Or download from https://ollama.ai/
|
||||
```
|
||||
|
||||
### 3.2 Start Ollama Service
|
||||
|
||||
```bash
|
||||
# Start Ollama as a service
|
||||
brew services start ollama
|
||||
|
||||
# Or start manually
|
||||
ollama serve
|
||||
```
|
||||
|
||||
### 3.3 Pull a Model
|
||||
|
||||
```bash
|
||||
# Pull a lightweight model for testing
|
||||
ollama pull llama3.2:1b
|
||||
|
||||
# Verify the model is available
|
||||
ollama list
|
||||
```
|
||||
|
||||
## Step 4: Configure Environment Variables
|
||||
|
||||
Create a `.env` file or set environment variables:
|
||||
|
||||
```bash
|
||||
# Required: User OAuth Token
|
||||
SLACK_OAUTH_TOKEN=xoxp-your-user-oauth-token-here
|
||||
|
||||
# Optional: App-Level Token (if your MCP server requires it)
|
||||
SLACK_APP_TOKEN=xapp-your-app-token-here
|
||||
|
||||
# Optional: Workspace-specific settings
|
||||
SLACK_WORKSPACE_ID=T1234567890 # Your workspace ID (optional)
|
||||
```
|
||||
|
||||
## Step 5: Test the Setup
|
||||
|
||||
### 5.1 Test MCP Server Connection
|
||||
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--test-connection \
|
||||
--workspace-name "Your Workspace Name"
|
||||
```
|
||||
|
||||
This will test the connection and list available tools without indexing any data.
|
||||
|
||||
### 5.2 Index a Specific Channel
|
||||
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "Your Workspace Name" \
|
||||
--channels general \
|
||||
--query "What did we discuss about the project?"
|
||||
```
|
||||
|
||||
### 5.3 Real RAG Query Examples
|
||||
|
||||
This section demonstrates successful Slack RAG integration queries against the Sky Lab Computing workspace's "random" channel. The system successfully retrieves actual conversation messages and performs semantic search with high relevance scores, including finding specific research paper announcements and technical discussions.
|
||||
|
||||
### Example 1: Advisor Models Query
|
||||
|
||||
**Query:** "train black-box models to adopt to your personal data"
|
||||
|
||||
This query demonstrates the system's ability to find specific research announcements about training black-box models for personal data adaptation.
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
### Example 2: Barbarians at the Gate Query
|
||||
|
||||
**Query:** "AI-driven research systems ADRS"
|
||||
|
||||
This query demonstrates the system's ability to find specific research announcements about AI-driven research systems and algorithm discovery.
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Bot is installed in the Sky Lab Computing workspace and invited to the target channel (run `/invite @YourBotName` in the channel if needed)
|
||||
- Bot token available and exported in the same terminal session
|
||||
|
||||
### Commands
|
||||
|
||||
1) Set the workspace token for this shell
|
||||
|
||||
```bash
|
||||
export SLACK_MCP_XOXP_TOKEN="xoxp-***-redacted-***"
|
||||
```
|
||||
|
||||
2) Run queries against the "random" channel by channel ID (C0GN5BX0F)
|
||||
|
||||
**Advisor Models Query:**
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "Sky Lab Computing" \
|
||||
--channels C0GN5BX0F \
|
||||
--max-messages-per-channel 100000 \
|
||||
--query "train black-box models to adopt to your personal data" \
|
||||
--llm ollama \
|
||||
--llm-model "llama3.2:1b" \
|
||||
--llm-host "http://localhost:11434" \
|
||||
--no-concatenate-conversations
|
||||
```
|
||||
|
||||
**Barbarians at the Gate Query:**
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "Sky Lab Computing" \
|
||||
--channels C0GN5BX0F \
|
||||
--max-messages-per-channel 100000 \
|
||||
--query "AI-driven research systems ADRS" \
|
||||
--llm ollama \
|
||||
--llm-model "llama3.2:1b" \
|
||||
--llm-host "http://localhost:11434" \
|
||||
--no-concatenate-conversations
|
||||
```
|
||||
|
||||
These examples demonstrate the system's ability to find and retrieve specific research announcements and technical discussions from the conversation history, showcasing the power of semantic search in Slack data.
|
||||
|
||||
3) Optional: Ask a broader question
|
||||
|
||||
```bash
|
||||
python test_channel_by_id_or_name.py \
|
||||
--channel-id C0GN5BX0F \
|
||||
--workspace-name "Sky Lab Computing" \
|
||||
--query "What is LEANN about?"
|
||||
```
|
||||
|
||||
Notes:
|
||||
- If you see `not_in_channel`, invite the bot to the channel and re-run.
|
||||
- If you see `channel_not_found`, confirm the channel ID and workspace.
|
||||
- Deep search via server-side “search” tools may require additional Slack scopes; the example above performs client-side filtering over retrieved history.
|
||||
|
||||
## Common Issues and Solutions
|
||||
|
||||
### Issue 1: "users cache is not ready yet" Error
|
||||
|
||||
**Problem**: You see this warning:
|
||||
```
|
||||
WARNING - Failed to fetch messages from channel random: Failed to fetch messages: {'code': -32603, 'message': 'users cache is not ready yet, sync process is still running... please wait'}
|
||||
```
|
||||
|
||||
**Solution**: This is a common timing issue. The LEANN integration now includes automatic retry logic:
|
||||
|
||||
1. **Wait and Retry**: The system will automatically retry with exponential backoff (2s, 4s, 8s, etc.)
|
||||
2. **Increase Retry Parameters**: If needed, you can customize retry behavior:
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--max-retries 10 \
|
||||
--retry-delay 3.0 \
|
||||
--channels general \
|
||||
--query "Your query here"
|
||||
```
|
||||
3. **Keep MCP Server Running**: Start the MCP server separately and keep it running:
|
||||
```bash
|
||||
# Terminal 1: Start MCP server
|
||||
slack-mcp-server
|
||||
|
||||
# Terminal 2: Run LEANN (it will connect to the running server)
|
||||
python -m apps.slack_rag --mcp-server "slack-mcp-server" --channels general --query "test"
|
||||
```
|
||||
|
||||
### Issue 2: "No message fetching tool found"
|
||||
|
||||
**Problem**: The MCP server doesn't have the expected tools.
|
||||
|
||||
**Solution**:
|
||||
1. Check if your MCP server is properly installed and configured
|
||||
2. Verify your Slack tokens are correct
|
||||
3. Try a different MCP server implementation
|
||||
4. Check the MCP server documentation for required configuration
|
||||
|
||||
### Issue 3: Permission Denied Errors
|
||||
|
||||
**Problem**: You get permission errors when trying to access channels.
|
||||
|
||||
**Solutions**:
|
||||
1. **Check Bot Permissions**: Ensure your bot has been added to the channels you want to access
|
||||
2. **Verify Token Scopes**: Make sure you have all required scopes configured
|
||||
3. **Channel Access**: For private channels, the bot needs to be explicitly invited
|
||||
4. **Workspace Permissions**: Ensure your Slack app has the necessary workspace permissions
|
||||
|
||||
### Issue 4: Empty Results
|
||||
|
||||
**Problem**: No messages are returned even though the channel has messages.
|
||||
|
||||
**Solutions**:
|
||||
1. **Check Channel Names**: Ensure channel names are correct (without the # symbol)
|
||||
2. **Verify Bot Access**: Make sure the bot can access the channels
|
||||
3. **Check Date Ranges**: Some MCP servers have limitations on message history
|
||||
4. **Increase Message Limits**: Try increasing the message limit:
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--channels general \
|
||||
--max-messages-per-channel 1000 \
|
||||
--query "test"
|
||||
```
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### Custom MCP Server Commands
|
||||
|
||||
If you need to pass additional parameters to your MCP server:
|
||||
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server --token-file /path/to/tokens.json" \
|
||||
--workspace-name "Your Workspace" \
|
||||
--channels general \
|
||||
--query "Your query"
|
||||
```
|
||||
|
||||
### Multiple Workspaces
|
||||
|
||||
To work with multiple Slack workspaces, you can:
|
||||
|
||||
1. Create separate apps for each workspace
|
||||
2. Use different environment variables
|
||||
3. Run separate instances with different configurations
|
||||
|
||||
### Performance Optimization
|
||||
|
||||
For better performance with large workspaces:
|
||||
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "Your Workspace" \
|
||||
--max-messages-per-channel 500 \
|
||||
--no-concatenate-conversations \
|
||||
--query "Your query"
|
||||
```
|
||||
---
|
||||
|
||||
## Troubleshooting Checklist
|
||||
|
||||
- [ ] Slack app created with proper permissions
|
||||
- [ ] Bot token (xoxb-) copied correctly
|
||||
- [ ] App-level token (xapp-) created if needed
|
||||
- [ ] MCP server installed and accessible
|
||||
- [ ] Ollama installed and running (`brew services start ollama`)
|
||||
- [ ] Ollama model pulled (`ollama pull llama3.2:1b`)
|
||||
- [ ] Environment variables set correctly
|
||||
- [ ] Bot invited to relevant channels
|
||||
- [ ] Channel names specified without # symbol
|
||||
- [ ] Sufficient retry attempts configured
|
||||
- [ ] Network connectivity to Slack APIs
|
||||
|
||||
## Getting Help
|
||||
|
||||
If you continue to have issues:
|
||||
|
||||
1. **Check Logs**: Look for detailed error messages in the console output
|
||||
2. **Test MCP Server**: Use `--test-connection` to verify the MCP server is working
|
||||
3. **Verify Tokens**: Double-check that your Slack tokens are valid and have the right scopes
|
||||
4. **Check Ollama**: Ensure Ollama is running (`ollama serve`) and the model is available (`ollama list`)
|
||||
5. **Community Support**: Reach out to the LEANN community for help
|
||||
|
||||
## Example Commands
|
||||
|
||||
### Basic Usage
|
||||
```bash
|
||||
# Test connection
|
||||
python -m apps.slack_rag --mcp-server "slack-mcp-server" --test-connection
|
||||
|
||||
# Index specific channels
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "My Company" \
|
||||
--channels general random \
|
||||
--query "What did we decide about the project timeline?"
|
||||
```
|
||||
|
||||
### Advanced Usage
|
||||
```bash
|
||||
# With custom retry settings
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "My Company" \
|
||||
--channels general \
|
||||
--max-retries 10 \
|
||||
--retry-delay 5.0 \
|
||||
--max-messages-per-channel 2000 \
|
||||
--query "Show me all decisions made in the last month"
|
||||
```
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 445 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 508 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 437 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 474 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 501 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 454 KiB |
@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "leann-backend-diskann"
|
||||
version = "0.3.5"
|
||||
dependencies = ["leann-core==0.3.5", "numpy", "protobuf>=3.19.0"]
|
||||
version = "0.3.4"
|
||||
dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
|
||||
|
||||
[tool.scikit-build]
|
||||
# Key: simplified CMake path
|
||||
|
||||
@@ -29,25 +29,12 @@ if(APPLE)
|
||||
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
|
||||
endif()
|
||||
|
||||
# Find ZMQ using pkg-config with IMPORTED_TARGET for automatic target creation
|
||||
# Use system ZeroMQ instead of building from source
|
||||
find_package(PkgConfig REQUIRED)
|
||||
|
||||
# On ARM64 macOS, ensure pkg-config finds ARM64 Homebrew packages first
|
||||
if(APPLE AND CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
|
||||
set(ENV{PKG_CONFIG_PATH} "/opt/homebrew/lib/pkgconfig:/opt/homebrew/share/pkgconfig:$ENV{PKG_CONFIG_PATH}")
|
||||
endif()
|
||||
|
||||
pkg_check_modules(ZMQ REQUIRED IMPORTED_TARGET libzmq)
|
||||
|
||||
# This creates PkgConfig::ZMQ target automatically with correct properties
|
||||
if(TARGET PkgConfig::ZMQ)
|
||||
message(STATUS "Found and configured ZMQ target: PkgConfig::ZMQ")
|
||||
else()
|
||||
message(FATAL_ERROR "pkg_check_modules did not create IMPORTED target for ZMQ.")
|
||||
endif()
|
||||
pkg_check_modules(ZMQ REQUIRED libzmq)
|
||||
|
||||
# Add cppzmq headers
|
||||
include_directories(SYSTEM third_party/cppzmq)
|
||||
include_directories(third_party/cppzmq)
|
||||
|
||||
# Configure msgpack-c - disable boost dependency
|
||||
set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
|
||||
|
||||
@@ -215,8 +215,6 @@ class HNSWSearcher(BaseSearcher):
|
||||
if recompute_embeddings:
|
||||
if zmq_port is None:
|
||||
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
||||
if hasattr(self._index, "set_zmq_port"):
|
||||
self._index.set_zmq_port(zmq_port)
|
||||
|
||||
if query.dtype != np.float32:
|
||||
query = query.astype(np.float32)
|
||||
|
||||
@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "leann-backend-hnsw"
|
||||
version = "0.3.5"
|
||||
version = "0.3.4"
|
||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||
dependencies = [
|
||||
"leann-core==0.3.5",
|
||||
"leann-core==0.3.4",
|
||||
"numpy",
|
||||
"pyzmq>=23.0.0",
|
||||
"msgpack>=1.0.0",
|
||||
|
||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: e2d243c40d...5952745237
@@ -4,10 +4,10 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "leann-core"
|
||||
version = "0.3.5"
|
||||
version = "0.3.4"
|
||||
description = "Core API and plugin system for LEANN"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.9"
|
||||
license = { text = "MIT" }
|
||||
|
||||
# All required dependencies included
|
||||
@@ -18,16 +18,14 @@ dependencies = [
|
||||
"pyzmq>=23.0.0",
|
||||
"msgpack>=1.0.0",
|
||||
"torch>=2.0.0",
|
||||
"sentence-transformers>=3.0.0",
|
||||
"sentence-transformers>=2.2.0",
|
||||
"llama-index-core>=0.12.0",
|
||||
"llama-index-readers-file>=0.4.0", # Essential for document reading
|
||||
"llama-index-embeddings-huggingface>=0.5.5", # For embeddings
|
||||
"python-dotenv>=1.0.0",
|
||||
"openai>=1.0.0",
|
||||
"huggingface-hub>=0.20.0",
|
||||
# Keep transformers below 4.46: 4.46.0 adds Python 3.10-only return type syntax and
|
||||
# breaks Python 3.9 environments.
|
||||
"transformers>=4.30.0,<4.46",
|
||||
"transformers>=4.30.0",
|
||||
"requests>=2.25.0",
|
||||
"accelerate>=0.20.0",
|
||||
"PyPDF2>=3.0.0",
|
||||
@@ -42,7 +40,7 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
colab = [
|
||||
"torch>=2.0.0,<3.0.0", # Limit torch version to avoid conflicts
|
||||
"transformers>=4.30.0,<4.46", # 4.46.0 switches to PEP 604 typing (int | None), breaks Py3.9
|
||||
"transformers>=4.30.0,<5.0.0", # Limit transformers version
|
||||
"accelerate>=0.20.0,<1.0.0", # Limit accelerate version
|
||||
]
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ from typing import Any, Literal, Optional, Union
|
||||
import numpy as np
|
||||
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
|
||||
|
||||
from leann.interactive_utils import create_api_session
|
||||
from leann.interface import LeannBackendSearcherInterface
|
||||
|
||||
from .chat import get_llm
|
||||
@@ -820,10 +819,10 @@ class LeannBuilder:
|
||||
actual_port,
|
||||
requested_zmq_port,
|
||||
)
|
||||
if hasattr(index.hnsw, "set_zmq_port"):
|
||||
index.hnsw.set_zmq_port(actual_port)
|
||||
elif hasattr(index, "set_zmq_port"):
|
||||
index.set_zmq_port(actual_port)
|
||||
try:
|
||||
index.hnsw.zmq_port = actual_port
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
if needs_recompute:
|
||||
for i in range(embeddings.shape[0]):
|
||||
@@ -916,7 +915,6 @@ class LeannSearcher:
|
||||
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
||||
batch_size: int = 0,
|
||||
use_grep: bool = False,
|
||||
provider_options: Optional[dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> list[SearchResult]:
|
||||
"""
|
||||
@@ -980,24 +978,10 @@ class LeannSearcher:
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Extract query template from stored embedding_options with fallback chain:
|
||||
# 1. Check provider_options override (highest priority)
|
||||
# 2. Check query_prompt_template (new format)
|
||||
# 3. Check prompt_template (old format for backward compat)
|
||||
# 4. None (no template)
|
||||
query_template = None
|
||||
if provider_options and "prompt_template" in provider_options:
|
||||
query_template = provider_options["prompt_template"]
|
||||
elif "query_prompt_template" in self.embedding_options:
|
||||
query_template = self.embedding_options["query_prompt_template"]
|
||||
elif "prompt_template" in self.embedding_options:
|
||||
query_template = self.embedding_options["prompt_template"]
|
||||
|
||||
query_embedding = self.backend_impl.compute_query_embedding(
|
||||
query,
|
||||
use_server_if_available=recompute_embeddings,
|
||||
zmq_port=zmq_port,
|
||||
query_template=query_template,
|
||||
)
|
||||
logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
||||
embedding_time = time.time() - start_time
|
||||
@@ -1251,17 +1235,6 @@ class LeannChat:
|
||||
"Please provide the best answer you can based on this context and your knowledge."
|
||||
)
|
||||
|
||||
logger.info("The context provided to the LLM is:")
|
||||
logger.info(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
|
||||
logger.info("-" * 150)
|
||||
for r in results:
|
||||
chunk_relevance = f"{r.score:.3f}"
|
||||
chunk_id = r.id
|
||||
chunk_content = r.text[:60]
|
||||
chunk_source = r.metadata.get("source", "")[:80]
|
||||
logger.info(
|
||||
f"{chunk_relevance:<10} | {chunk_id:<10} | {chunk_content:<60} | {chunk_source:<80}"
|
||||
)
|
||||
ask_time = time.time()
|
||||
ans = self.llm.ask(prompt, **llm_kwargs)
|
||||
ask_time = time.time() - ask_time
|
||||
@@ -1269,14 +1242,19 @@ class LeannChat:
|
||||
return ans
|
||||
|
||||
def start_interactive(self):
|
||||
"""Start interactive chat session."""
|
||||
session = create_api_session()
|
||||
|
||||
def handle_query(user_input: str):
|
||||
response = self.ask(user_input)
|
||||
print(f"Leann: {response}")
|
||||
|
||||
session.run_interactive_loop(handle_query)
|
||||
print("\nLeann Chat started (type 'quit' to exit)")
|
||||
while True:
|
||||
try:
|
||||
user_input = input("You: ").strip()
|
||||
if user_input.lower() in ["quit", "exit"]:
|
||||
break
|
||||
if not user_input:
|
||||
continue
|
||||
response = self.ask(user_input)
|
||||
print(f"Leann: {response}")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
|
||||
def cleanup(self):
|
||||
"""Explicitly cleanup embedding server resources.
|
||||
|
||||
@@ -12,13 +12,7 @@ from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .settings import (
|
||||
resolve_anthropic_api_key,
|
||||
resolve_anthropic_base_url,
|
||||
resolve_ollama_host,
|
||||
resolve_openai_api_key,
|
||||
resolve_openai_base_url,
|
||||
)
|
||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -552,30 +546,11 @@ class OllamaChat(LLMInterface):
|
||||
|
||||
|
||||
class HFChat(LLMInterface):
|
||||
"""LLM interface for local Hugging Face Transformers models with proper chat templates.
|
||||
"""LLM interface for local Hugging Face Transformers models with proper chat templates."""
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the Hugging Face model to load.
|
||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||
Defaults to False for security. Only enable for trusted models as this can pose
|
||||
a security risk if the model repository is compromised.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat", trust_remote_code: bool = False
|
||||
):
|
||||
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
|
||||
logger.info(f"Initializing HFChat with model='{model_name}'")
|
||||
|
||||
# Security warning when trust_remote_code is enabled
|
||||
if trust_remote_code:
|
||||
logger.warning(
|
||||
"SECURITY WARNING: trust_remote_code=True allows execution of arbitrary code from the model repository. "
|
||||
"Only enable this for models from trusted sources. This creates a potential security risk if the model "
|
||||
"repository is compromised."
|
||||
)
|
||||
|
||||
self.trust_remote_code = trust_remote_code
|
||||
|
||||
# Pre-check model availability with helpful suggestions
|
||||
model_error = validate_model_and_suggest(model_name, "hf")
|
||||
if model_error:
|
||||
@@ -613,16 +588,14 @@ class HFChat(LLMInterface):
|
||||
|
||||
try:
|
||||
logger.info(f"Loading tokenizer for {model_name}...")
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name, trust_remote_code=self.trust_remote_code
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
logger.info(f"Loading model {model_name}...")
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
||||
device_map="auto" if self.device != "cpu" else None,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
logger.info(f"Successfully loaded {model_name}")
|
||||
finally:
|
||||
@@ -840,92 +813,12 @@ class OpenAIChat(LLMInterface):
|
||||
|
||||
try:
|
||||
response = self.client.chat.completions.create(**params)
|
||||
print(
|
||||
f"Total tokens = {response.usage.total_tokens}, prompt tokens = {response.usage.prompt_tokens}, completion tokens = {response.usage.completion_tokens}"
|
||||
)
|
||||
if response.choices[0].finish_reason == "length":
|
||||
print("The query is exceeding the maximum allowed number of tokens")
|
||||
return response.choices[0].message.content.strip()
|
||||
except Exception as e:
|
||||
logger.error(f"Error communicating with OpenAI: {e}")
|
||||
return f"Error: Could not get a response from OpenAI. Details: {e}"
|
||||
|
||||
|
||||
class AnthropicChat(LLMInterface):
|
||||
"""LLM interface for Anthropic Claude models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "claude-haiku-4-5",
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
):
|
||||
self.model = model
|
||||
self.base_url = resolve_anthropic_base_url(base_url)
|
||||
self.api_key = resolve_anthropic_api_key(api_key)
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"Anthropic API key is required. Set ANTHROPIC_API_KEY environment variable or pass api_key parameter."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Initializing Anthropic Chat with model='%s' and base_url='%s'",
|
||||
model,
|
||||
self.base_url,
|
||||
)
|
||||
|
||||
try:
|
||||
import anthropic
|
||||
|
||||
# Allow custom Anthropic-compatible endpoints via base_url
|
||||
self.client = anthropic.Anthropic(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'anthropic' library is required for Anthropic models. Please install it with 'pip install anthropic'."
|
||||
)
|
||||
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
logger.info(f"Sending request to Anthropic with model {self.model}")
|
||||
|
||||
try:
|
||||
# Anthropic API parameters
|
||||
params = {
|
||||
"model": self.model,
|
||||
"max_tokens": kwargs.get("max_tokens", 1000),
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if "temperature" in kwargs:
|
||||
params["temperature"] = kwargs["temperature"]
|
||||
if "top_p" in kwargs:
|
||||
params["top_p"] = kwargs["top_p"]
|
||||
|
||||
response = self.client.messages.create(**params)
|
||||
|
||||
# Extract text from response
|
||||
response_text = response.content[0].text
|
||||
|
||||
# Log token usage
|
||||
print(
|
||||
f"Total tokens = {response.usage.input_tokens + response.usage.output_tokens}, "
|
||||
f"input tokens = {response.usage.input_tokens}, "
|
||||
f"output tokens = {response.usage.output_tokens}"
|
||||
)
|
||||
|
||||
if response.stop_reason == "max_tokens":
|
||||
print("The query is exceeding the maximum allowed number of tokens")
|
||||
|
||||
return response_text.strip()
|
||||
except Exception as e:
|
||||
logger.error(f"Error communicating with Anthropic: {e}")
|
||||
return f"Error: Could not get a response from Anthropic. Details: {e}"
|
||||
|
||||
|
||||
class SimulatedChat(LLMInterface):
|
||||
"""A simple simulated chat for testing and development."""
|
||||
|
||||
@@ -966,10 +859,7 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
|
||||
host=llm_config.get("host"),
|
||||
)
|
||||
elif llm_type == "hf":
|
||||
return HFChat(
|
||||
model_name=model or "deepseek-ai/deepseek-llm-7b-chat",
|
||||
trust_remote_code=llm_config.get("trust_remote_code", False),
|
||||
)
|
||||
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
|
||||
elif llm_type == "openai":
|
||||
return OpenAIChat(
|
||||
model=model or "gpt-4o",
|
||||
@@ -978,12 +868,6 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
|
||||
)
|
||||
elif llm_type == "gemini":
|
||||
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
|
||||
elif llm_type == "anthropic":
|
||||
return AnthropicChat(
|
||||
model=model or "claude-3-5-sonnet-20241022",
|
||||
api_key=llm_config.get("api_key"),
|
||||
base_url=llm_config.get("base_url"),
|
||||
)
|
||||
elif llm_type == "simulated":
|
||||
return SimulatedChat()
|
||||
else:
|
||||
|
||||
@@ -5,128 +5,12 @@ Packaged within leann-core so installed wheels can import it reliably.
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Flag to ensure AST token warning only shown once per session
|
||||
_ast_token_warning_shown = False
|
||||
|
||||
|
||||
def estimate_token_count(text: str) -> int:
|
||||
"""
|
||||
Estimate token count for a text string.
|
||||
Uses conservative estimation: ~4 characters per token for natural text,
|
||||
~1.2 tokens per character for code (worse tokenization).
|
||||
|
||||
Args:
|
||||
text: Input text to estimate tokens for
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
try:
|
||||
import tiktoken
|
||||
|
||||
encoder = tiktoken.get_encoding("cl100k_base")
|
||||
return len(encoder.encode(text))
|
||||
except ImportError:
|
||||
# Fallback: Conservative character-based estimation
|
||||
# Assume worst case for code: 1.2 tokens per character
|
||||
return int(len(text) * 1.2)
|
||||
|
||||
|
||||
def calculate_safe_chunk_size(
|
||||
model_token_limit: int,
|
||||
overlap_tokens: int,
|
||||
chunking_mode: str = "traditional",
|
||||
safety_factor: float = 0.9,
|
||||
) -> int:
|
||||
"""
|
||||
Calculate safe chunk size accounting for overlap and safety margin.
|
||||
|
||||
Args:
|
||||
model_token_limit: Maximum tokens supported by embedding model
|
||||
overlap_tokens: Overlap size (tokens for traditional, chars for AST)
|
||||
chunking_mode: "traditional" (tokens) or "ast" (characters)
|
||||
safety_factor: Safety margin (0.9 = 10% safety margin)
|
||||
|
||||
Returns:
|
||||
Safe chunk size: tokens for traditional, characters for AST
|
||||
"""
|
||||
safe_limit = int(model_token_limit * safety_factor)
|
||||
|
||||
if chunking_mode == "traditional":
|
||||
# Traditional chunking uses tokens
|
||||
# Max chunk = chunk_size + overlap, so chunk_size = limit - overlap
|
||||
return max(1, safe_limit - overlap_tokens)
|
||||
else: # AST chunking
|
||||
# AST uses characters, need to convert
|
||||
# Conservative estimate: 1.2 tokens per char for code
|
||||
overlap_chars = int(overlap_tokens * 3) # ~3 chars per token for code
|
||||
safe_chars = int(safe_limit / 1.2)
|
||||
return max(1, safe_chars - overlap_chars)
|
||||
|
||||
|
||||
def validate_chunk_token_limits(chunks: list[str], max_tokens: int = 512) -> tuple[list[str], int]:
|
||||
"""
|
||||
Validate that chunks don't exceed token limits and truncate if necessary.
|
||||
|
||||
Args:
|
||||
chunks: List of text chunks to validate
|
||||
max_tokens: Maximum tokens allowed per chunk
|
||||
|
||||
Returns:
|
||||
Tuple of (validated_chunks, num_truncated)
|
||||
"""
|
||||
validated_chunks = []
|
||||
num_truncated = 0
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
estimated_tokens = estimate_token_count(chunk)
|
||||
|
||||
if estimated_tokens > max_tokens:
|
||||
# Truncate chunk to fit token limit
|
||||
try:
|
||||
import tiktoken
|
||||
|
||||
encoder = tiktoken.get_encoding("cl100k_base")
|
||||
tokens = encoder.encode(chunk)
|
||||
if len(tokens) > max_tokens:
|
||||
truncated_tokens = tokens[:max_tokens]
|
||||
truncated_chunk = encoder.decode(truncated_tokens)
|
||||
validated_chunks.append(truncated_chunk)
|
||||
num_truncated += 1
|
||||
logger.warning(
|
||||
f"Truncated chunk {i} from {len(tokens)} to {max_tokens} tokens "
|
||||
f"(from {len(chunk)} to {len(truncated_chunk)} characters)"
|
||||
)
|
||||
else:
|
||||
validated_chunks.append(chunk)
|
||||
except ImportError:
|
||||
# Fallback: Conservative character truncation
|
||||
char_limit = int(max_tokens / 1.2) # Conservative for code
|
||||
if len(chunk) > char_limit:
|
||||
truncated_chunk = chunk[:char_limit]
|
||||
validated_chunks.append(truncated_chunk)
|
||||
num_truncated += 1
|
||||
logger.warning(
|
||||
f"Truncated chunk {i} from {len(chunk)} to {char_limit} characters "
|
||||
f"(conservative estimate for {max_tokens} tokens)"
|
||||
)
|
||||
else:
|
||||
validated_chunks.append(chunk)
|
||||
else:
|
||||
validated_chunks.append(chunk)
|
||||
|
||||
if num_truncated > 0:
|
||||
logger.warning(f"Truncated {num_truncated}/{len(chunks)} chunks to fit token limits")
|
||||
|
||||
return validated_chunks, num_truncated
|
||||
|
||||
|
||||
# Code file extensions supported by astchunk
|
||||
CODE_EXTENSIONS = {
|
||||
".py": "python",
|
||||
@@ -177,45 +61,27 @@ def create_ast_chunks(
|
||||
max_chunk_size: int = 512,
|
||||
chunk_overlap: int = 64,
|
||||
metadata_template: str = "default",
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> list[str]:
|
||||
"""Create AST-aware chunks from code documents using astchunk.
|
||||
|
||||
Falls back to traditional chunking if astchunk is unavailable.
|
||||
|
||||
Returns:
|
||||
List of dicts with {"text": str, "metadata": dict}
|
||||
"""
|
||||
try:
|
||||
from astchunk import ASTChunkBuilder # optional dependency
|
||||
except ImportError as e:
|
||||
logger.error(f"astchunk not available: {e}")
|
||||
logger.info("Falling back to traditional chunking for code files")
|
||||
return _traditional_chunks_as_dicts(documents, max_chunk_size, chunk_overlap)
|
||||
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
|
||||
|
||||
all_chunks = []
|
||||
for doc in documents:
|
||||
language = doc.metadata.get("language")
|
||||
if not language:
|
||||
logger.warning("No language detected; falling back to traditional chunking")
|
||||
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
|
||||
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
|
||||
continue
|
||||
|
||||
try:
|
||||
# Warn once if AST chunk size + overlap might exceed common token limits
|
||||
# Note: Actual truncation happens at embedding time with dynamic model limits
|
||||
global _ast_token_warning_shown
|
||||
estimated_max_tokens = int(
|
||||
(max_chunk_size + chunk_overlap) * 1.2
|
||||
) # Conservative estimate
|
||||
if estimated_max_tokens > 512 and not _ast_token_warning_shown:
|
||||
logger.warning(
|
||||
f"AST chunk size ({max_chunk_size}) + overlap ({chunk_overlap}) = {max_chunk_size + chunk_overlap} chars "
|
||||
f"may exceed 512 token limit (~{estimated_max_tokens} tokens estimated). "
|
||||
f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}. "
|
||||
f"Note: Chunks will be auto-truncated at embedding time based on your model's actual token limit."
|
||||
)
|
||||
_ast_token_warning_shown = True
|
||||
|
||||
configs = {
|
||||
"max_chunk_size": max_chunk_size,
|
||||
"language": language,
|
||||
@@ -239,40 +105,17 @@ def create_ast_chunks(
|
||||
|
||||
chunks = chunk_builder.chunkify(code_content)
|
||||
for chunk in chunks:
|
||||
chunk_text = None
|
||||
astchunk_metadata = {}
|
||||
|
||||
if hasattr(chunk, "text"):
|
||||
chunk_text = chunk.text
|
||||
elif isinstance(chunk, dict) and "text" in chunk:
|
||||
chunk_text = chunk["text"]
|
||||
elif isinstance(chunk, str):
|
||||
chunk_text = chunk
|
||||
elif isinstance(chunk, dict):
|
||||
# Handle astchunk format: {"content": "...", "metadata": {...}}
|
||||
if "content" in chunk:
|
||||
chunk_text = chunk["content"]
|
||||
astchunk_metadata = chunk.get("metadata", {})
|
||||
elif "text" in chunk:
|
||||
chunk_text = chunk["text"]
|
||||
else:
|
||||
chunk_text = str(chunk) # Last resort
|
||||
else:
|
||||
chunk_text = str(chunk)
|
||||
|
||||
if chunk_text and chunk_text.strip():
|
||||
# Extract document-level metadata
|
||||
doc_metadata = {
|
||||
"file_path": doc.metadata.get("file_path", ""),
|
||||
"file_name": doc.metadata.get("file_name", ""),
|
||||
}
|
||||
if "creation_date" in doc.metadata:
|
||||
doc_metadata["creation_date"] = doc.metadata["creation_date"]
|
||||
if "last_modified_date" in doc.metadata:
|
||||
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
|
||||
|
||||
# Merge document metadata + astchunk metadata
|
||||
combined_metadata = {**doc_metadata, **astchunk_metadata}
|
||||
|
||||
all_chunks.append({"text": chunk_text.strip(), "metadata": combined_metadata})
|
||||
all_chunks.append(chunk_text.strip())
|
||||
|
||||
logger.info(
|
||||
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
|
||||
@@ -280,19 +123,15 @@ def create_ast_chunks(
|
||||
except Exception as e:
|
||||
logger.warning(f"AST chunking failed for {language} file: {e}")
|
||||
logger.info("Falling back to traditional chunking")
|
||||
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
|
||||
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
|
||||
|
||||
return all_chunks
|
||||
|
||||
|
||||
def create_traditional_chunks(
|
||||
documents, chunk_size: int = 256, chunk_overlap: int = 128
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Create traditional text chunks using LlamaIndex SentenceSplitter.
|
||||
|
||||
Returns:
|
||||
List of dicts with {"text": str, "metadata": dict}
|
||||
"""
|
||||
) -> list[str]:
|
||||
"""Create traditional text chunks using LlamaIndex SentenceSplitter."""
|
||||
if chunk_size <= 0:
|
||||
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
|
||||
chunk_size = 256
|
||||
@@ -308,40 +147,19 @@ def create_traditional_chunks(
|
||||
paragraph_separator="\n\n",
|
||||
)
|
||||
|
||||
result = []
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
# Extract document-level metadata
|
||||
doc_metadata = {
|
||||
"file_path": doc.metadata.get("file_path", ""),
|
||||
"file_name": doc.metadata.get("file_name", ""),
|
||||
}
|
||||
if "creation_date" in doc.metadata:
|
||||
doc_metadata["creation_date"] = doc.metadata["creation_date"]
|
||||
if "last_modified_date" in doc.metadata:
|
||||
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
|
||||
|
||||
try:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
if nodes:
|
||||
for node in nodes:
|
||||
result.append({"text": node.get_content(), "metadata": doc_metadata})
|
||||
all_texts.extend(node.get_content() for node in nodes)
|
||||
except Exception as e:
|
||||
logger.error(f"Traditional chunking failed for document: {e}")
|
||||
content = doc.get_content()
|
||||
if content and content.strip():
|
||||
result.append({"text": content.strip(), "metadata": doc_metadata})
|
||||
all_texts.append(content.strip())
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _traditional_chunks_as_dicts(
|
||||
documents, chunk_size: int = 256, chunk_overlap: int = 128
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Helper: Traditional chunking that returns dict format for consistency.
|
||||
|
||||
This is now just an alias for create_traditional_chunks for backwards compatibility.
|
||||
"""
|
||||
return create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
||||
return all_texts
|
||||
|
||||
|
||||
def create_text_chunks(
|
||||
@@ -353,12 +171,8 @@ def create_text_chunks(
|
||||
ast_chunk_overlap: int = 64,
|
||||
code_file_extensions: Optional[list[str]] = None,
|
||||
ast_fallback_traditional: bool = True,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Create text chunks from documents with optional AST support for code files.
|
||||
|
||||
Returns:
|
||||
List of dicts with {"text": str, "metadata": dict}
|
||||
"""
|
||||
) -> list[str]:
|
||||
"""Create text chunks from documents with optional AST support for code files."""
|
||||
if not documents:
|
||||
logger.warning("No documents provided for chunking")
|
||||
return []
|
||||
@@ -393,17 +207,14 @@ def create_text_chunks(
|
||||
logger.error(f"AST chunking failed: {e}")
|
||||
if ast_fallback_traditional:
|
||||
all_chunks.extend(
|
||||
_traditional_chunks_as_dicts(code_docs, chunk_size, chunk_overlap)
|
||||
create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
|
||||
)
|
||||
else:
|
||||
raise
|
||||
if text_docs:
|
||||
all_chunks.extend(_traditional_chunks_as_dicts(text_docs, chunk_size, chunk_overlap))
|
||||
all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
|
||||
else:
|
||||
all_chunks = _traditional_chunks_as_dicts(documents, chunk_size, chunk_overlap)
|
||||
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
||||
|
||||
logger.info(f"Total chunks created: {len(all_chunks)}")
|
||||
|
||||
# Note: Token truncation is now handled at embedding time with dynamic model limits
|
||||
# See get_model_token_limit() and truncate_to_token_limit() in embedding_compute.py
|
||||
return all_chunks
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
@@ -9,14 +8,8 @@ from llama_index.core.node_parser import SentenceSplitter
|
||||
from tqdm import tqdm
|
||||
|
||||
from .api import LeannBuilder, LeannChat, LeannSearcher
|
||||
from .interactive_utils import create_cli_session
|
||||
from .registry import register_project_directory
|
||||
from .settings import (
|
||||
resolve_anthropic_base_url,
|
||||
resolve_ollama_host,
|
||||
resolve_openai_api_key,
|
||||
resolve_openai_base_url,
|
||||
)
|
||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
|
||||
|
||||
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
|
||||
@@ -112,7 +105,7 @@ Examples:
|
||||
help="Documents directories and/or files (default: current directory)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--backend-name",
|
||||
"--backend",
|
||||
type=str,
|
||||
default="hnsw",
|
||||
choices=["hnsw", "diskann"],
|
||||
@@ -149,18 +142,6 @@ Examples:
|
||||
default=None,
|
||||
help="API key for embedding service (defaults to OPENAI_API_KEY)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--embedding-prompt-template",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Prompt template to prepend to all texts for embedding (e.g., 'query: ' for search)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--query-prompt-template",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Prompt template for queries (different from build template for task-specific models)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--force", "-f", action="store_true", help="Force rebuild existing index"
|
||||
)
|
||||
@@ -198,25 +179,25 @@ Examples:
|
||||
"--doc-chunk-size",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Document chunk size in TOKENS (default: 256). Final chunks may be larger due to overlap. For 512 token models: recommended 350 tokens (350 + 128 overlap = 478 max)",
|
||||
help="Document chunk size in tokens/characters (default: 256)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--doc-chunk-overlap",
|
||||
type=int,
|
||||
default=128,
|
||||
help="Document chunk overlap in TOKENS (default: 128). Added to chunk size, not included in it",
|
||||
help="Document chunk overlap (default: 128)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--code-chunk-size",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Code chunk size in TOKENS (default: 512). Final chunks may be larger due to overlap. For 512 token models: recommended 400 tokens (400 + 50 overlap = 450 max)",
|
||||
help="Code chunk size in tokens/lines (default: 512)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--code-chunk-overlap",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Code chunk overlap in TOKENS (default: 50). Added to chunk size, not included in it",
|
||||
help="Code chunk overlap (default: 50)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--use-ast-chunking",
|
||||
@@ -226,14 +207,14 @@ Examples:
|
||||
build_parser.add_argument(
|
||||
"--ast-chunk-size",
|
||||
type=int,
|
||||
default=300,
|
||||
help="AST chunk size in CHARACTERS (non-whitespace) (default: 300). Final chunks may be larger due to overlap and expansion. For 512 token models: recommended 300 chars (300 + 64 overlap ~= 480 tokens)",
|
||||
default=768,
|
||||
help="AST chunk size in characters (default: 768)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--ast-chunk-overlap",
|
||||
type=int,
|
||||
default=64,
|
||||
help="AST chunk overlap in CHARACTERS (default: 64). Added to chunk size, not included in it. ~1.2 tokens per character for code",
|
||||
default=96,
|
||||
help="AST chunk overlap in characters (default: 96)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--ast-fallback-traditional",
|
||||
@@ -272,17 +253,6 @@ Examples:
|
||||
action="store_true",
|
||||
help="Non-interactive mode: automatically select index without prompting",
|
||||
)
|
||||
search_parser.add_argument(
|
||||
"--show-metadata",
|
||||
action="store_true",
|
||||
help="Display file paths and metadata in search results",
|
||||
)
|
||||
search_parser.add_argument(
|
||||
"--embedding-prompt-template",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Prompt template to prepend to query for embedding (e.g., 'query: ' for search)",
|
||||
)
|
||||
|
||||
# Ask command
|
||||
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
||||
@@ -296,7 +266,7 @@ Examples:
|
||||
"--llm",
|
||||
type=str,
|
||||
default="ollama",
|
||||
choices=["simulated", "ollama", "hf", "openai", "anthropic"],
|
||||
choices=["simulated", "ollama", "hf", "openai"],
|
||||
help="LLM provider (default: ollama)",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
@@ -346,7 +316,7 @@ Examples:
|
||||
"--api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API key for cloud LLM providers (OpenAI, Anthropic)",
|
||||
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
|
||||
)
|
||||
|
||||
# List command
|
||||
@@ -1185,11 +1155,6 @@ Examples:
|
||||
print(f"Warning: Could not process {file_path}: {e}")
|
||||
|
||||
# Load other file types with default reader
|
||||
# Exclude PDFs from code_extensions if they were already processed separately
|
||||
other_file_extensions = code_extensions
|
||||
if should_process_pdfs and ".pdf" in code_extensions:
|
||||
other_file_extensions = [ext for ext in code_extensions if ext != ".pdf"]
|
||||
|
||||
try:
|
||||
# Create a custom file filter function using our PathSpec
|
||||
def file_filter(
|
||||
@@ -1205,26 +1170,21 @@ Examples:
|
||||
except (ValueError, OSError):
|
||||
return True # Include files that can't be processed
|
||||
|
||||
# Only load other file types if there are extensions to process
|
||||
if other_file_extensions:
|
||||
other_docs = SimpleDirectoryReader(
|
||||
docs_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=other_file_extensions,
|
||||
file_extractor={}, # Use default extractors
|
||||
exclude_hidden=not include_hidden,
|
||||
filename_as_id=True,
|
||||
).load_data(show_progress=True)
|
||||
else:
|
||||
other_docs = []
|
||||
other_docs = SimpleDirectoryReader(
|
||||
docs_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=code_extensions,
|
||||
file_extractor={}, # Use default extractors
|
||||
exclude_hidden=not include_hidden,
|
||||
filename_as_id=True,
|
||||
).load_data(show_progress=True)
|
||||
|
||||
# Filter documents after loading based on gitignore rules
|
||||
filtered_docs = []
|
||||
for doc in other_docs:
|
||||
file_path = doc.metadata.get("file_path", "")
|
||||
if file_filter(file_path):
|
||||
doc.metadata["source"] = file_path
|
||||
filtered_docs.append(doc)
|
||||
|
||||
documents.extend(filtered_docs)
|
||||
@@ -1300,7 +1260,7 @@ Examples:
|
||||
from .chunking_utils import create_text_chunks
|
||||
|
||||
# Use enhanced chunking with AST support
|
||||
chunk_texts = create_text_chunks(
|
||||
all_texts = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=self.node_parser.chunk_size,
|
||||
chunk_overlap=self.node_parser.chunk_overlap,
|
||||
@@ -1311,9 +1271,6 @@ Examples:
|
||||
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
||||
)
|
||||
|
||||
# create_text_chunks now returns list[dict] with metadata preserved
|
||||
all_texts.extend(chunk_texts)
|
||||
|
||||
except ImportError as e:
|
||||
print(
|
||||
f"⚠️ AST chunking utilities not available in package ({e}), falling back to traditional chunking"
|
||||
@@ -1325,27 +1282,14 @@ Examples:
|
||||
for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
|
||||
# Check if this is a code file based on source path
|
||||
source_path = doc.metadata.get("source", "")
|
||||
file_path = doc.metadata.get("file_path", "")
|
||||
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
|
||||
|
||||
# Extract metadata to preserve with chunks
|
||||
chunk_metadata = {
|
||||
"file_path": file_path or source_path,
|
||||
"file_name": doc.metadata.get("file_name", ""),
|
||||
}
|
||||
|
||||
# Add optional metadata if available
|
||||
if "creation_date" in doc.metadata:
|
||||
chunk_metadata["creation_date"] = doc.metadata["creation_date"]
|
||||
if "last_modified_date" in doc.metadata:
|
||||
chunk_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
|
||||
|
||||
# Use appropriate parser based on file type
|
||||
parser = self.code_parser if is_code_file else self.node_parser
|
||||
nodes = parser.get_nodes_from_documents([doc])
|
||||
|
||||
for node in nodes:
|
||||
all_texts.append({"text": node.get_content(), "metadata": chunk_metadata})
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
|
||||
return all_texts
|
||||
@@ -1420,7 +1364,7 @@ Examples:
|
||||
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Building index '{index_name}' with {args.backend_name} backend...")
|
||||
print(f"Building index '{index_name}' with {args.backend} backend...")
|
||||
|
||||
embedding_options: dict[str, Any] = {}
|
||||
if args.embedding_mode == "ollama":
|
||||
@@ -1430,17 +1374,9 @@ Examples:
|
||||
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
|
||||
if resolved_embedding_key:
|
||||
embedding_options["api_key"] = resolved_embedding_key
|
||||
if args.query_prompt_template:
|
||||
# New format: separate templates
|
||||
if args.embedding_prompt_template:
|
||||
embedding_options["build_prompt_template"] = args.embedding_prompt_template
|
||||
embedding_options["query_prompt_template"] = args.query_prompt_template
|
||||
elif args.embedding_prompt_template:
|
||||
# Old format: single template (backward compat)
|
||||
embedding_options["prompt_template"] = args.embedding_prompt_template
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend_name,
|
||||
backend_name=args.backend,
|
||||
embedding_model=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
embedding_options=embedding_options or None,
|
||||
@@ -1451,8 +1387,8 @@ Examples:
|
||||
num_threads=args.num_threads,
|
||||
)
|
||||
|
||||
for chunk in all_texts:
|
||||
builder.add_text(chunk["text"], metadata=chunk["metadata"])
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"Index built at {index_path}")
|
||||
@@ -1559,11 +1495,6 @@ Examples:
|
||||
print("Invalid input. Aborting search.")
|
||||
return
|
||||
|
||||
# Build provider_options for runtime override
|
||||
provider_options = {}
|
||||
if args.embedding_prompt_template:
|
||||
provider_options["prompt_template"] = args.embedding_prompt_template
|
||||
|
||||
searcher = LeannSearcher(index_path=index_path)
|
||||
results = searcher.search(
|
||||
query,
|
||||
@@ -1573,31 +1504,12 @@ Examples:
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
provider_options=provider_options if provider_options else None,
|
||||
)
|
||||
|
||||
print(f"Search results for '{query}' (top {len(results)}):")
|
||||
for i, result in enumerate(results, 1):
|
||||
print(f"{i}. Score: {result.score:.3f}")
|
||||
|
||||
# Display metadata if flag is set
|
||||
if args.show_metadata and result.metadata:
|
||||
file_path = result.metadata.get("file_path", "")
|
||||
if file_path:
|
||||
print(f" 📄 File: {file_path}")
|
||||
|
||||
file_name = result.metadata.get("file_name", "")
|
||||
if file_name and file_name != file_path:
|
||||
print(f" 📝 Name: {file_name}")
|
||||
|
||||
# Show timestamps if available
|
||||
if "creation_date" in result.metadata:
|
||||
print(f" 🕐 Created: {result.metadata['creation_date']}")
|
||||
if "last_modified_date" in result.metadata:
|
||||
print(f" 🕑 Modified: {result.metadata['last_modified_date']}")
|
||||
|
||||
print(f" {result.text[:200]}...")
|
||||
print(f" Source: {result.metadata.get('source', '')}")
|
||||
print()
|
||||
|
||||
async def ask_questions(self, args):
|
||||
@@ -1621,12 +1533,6 @@ Examples:
|
||||
resolved_api_key = resolve_openai_api_key(args.api_key)
|
||||
if resolved_api_key:
|
||||
llm_config["api_key"] = resolved_api_key
|
||||
elif args.llm == "anthropic":
|
||||
# For Anthropic, pass base_url and API key if provided
|
||||
if args.api_base:
|
||||
llm_config["base_url"] = resolve_anthropic_base_url(args.api_base)
|
||||
if args.api_key:
|
||||
llm_config["api_key"] = args.api_key
|
||||
|
||||
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
||||
|
||||
@@ -1635,7 +1541,6 @@ Examples:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
def _ask_once(prompt: str) -> None:
|
||||
query_start_time = time.time()
|
||||
response = chat.ask(
|
||||
prompt,
|
||||
top_k=args.top_k,
|
||||
@@ -1646,20 +1551,27 @@ Examples:
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
query_completion_time = time.time() - query_start_time
|
||||
print(f"LEANN: {response}")
|
||||
print(f"The query took {query_completion_time:.3f} seconds to finish")
|
||||
|
||||
initial_query = (args.query or "").strip()
|
||||
|
||||
if args.interactive:
|
||||
# Create interactive session
|
||||
session = create_cli_session(index_name)
|
||||
|
||||
if initial_query:
|
||||
_ask_once(initial_query)
|
||||
|
||||
session.run_interactive_loop(_ask_once)
|
||||
print("LEANN Assistant ready! Type 'quit' to exit")
|
||||
print("=" * 40)
|
||||
|
||||
while True:
|
||||
user_input = input("\nYou: ").strip()
|
||||
if user_input.lower() in ["quit", "exit", "q"]:
|
||||
print("Goodbye!")
|
||||
break
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
_ask_once(user_input)
|
||||
else:
|
||||
query = initial_query or input("Enter your question: ").strip()
|
||||
if not query:
|
||||
|
||||
@@ -4,15 +4,12 @@ Consolidates all embedding computation logic using SentenceTransformer
|
||||
Preserves all optimization parameters to ensure performance
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
import torch
|
||||
|
||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
@@ -23,288 +20,6 @@ LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# Token limit registry for embedding models
|
||||
# Used as fallback when dynamic discovery fails (e.g., LM Studio, OpenAI)
|
||||
# Ollama models use dynamic discovery via /api/show
|
||||
EMBEDDING_MODEL_LIMITS = {
|
||||
# Nomic models (common across servers)
|
||||
"nomic-embed-text": 2048, # Corrected from 512 - verified via /api/show
|
||||
"nomic-embed-text-v1.5": 2048,
|
||||
"nomic-embed-text-v2": 512,
|
||||
# Other embedding models
|
||||
"mxbai-embed-large": 512,
|
||||
"all-minilm": 512,
|
||||
"bge-m3": 8192,
|
||||
"snowflake-arctic-embed": 512,
|
||||
# OpenAI models
|
||||
"text-embedding-3-small": 8192,
|
||||
"text-embedding-3-large": 8192,
|
||||
"text-embedding-ada-002": 8192,
|
||||
}
|
||||
|
||||
# Runtime cache for dynamically discovered token limits
|
||||
# Key: (model_name, base_url), Value: token_limit
|
||||
# Prevents repeated SDK/API calls for the same model
|
||||
_token_limit_cache: dict[tuple[str, str], int] = {}
|
||||
|
||||
|
||||
def get_model_token_limit(
|
||||
model_name: str,
|
||||
base_url: Optional[str] = None,
|
||||
default: int = 2048,
|
||||
) -> int:
|
||||
"""
|
||||
Get token limit for a given embedding model.
|
||||
Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others.
|
||||
Caches discovered limits to prevent repeated API/SDK calls.
|
||||
|
||||
Args:
|
||||
model_name: Name of the embedding model
|
||||
base_url: Base URL of the embedding server (for dynamic discovery)
|
||||
default: Default token limit if model not found
|
||||
|
||||
Returns:
|
||||
Token limit for the model in tokens
|
||||
"""
|
||||
# Check cache first to avoid repeated SDK/API calls
|
||||
cache_key = (model_name, base_url or "")
|
||||
if cache_key in _token_limit_cache:
|
||||
cached_limit = _token_limit_cache[cache_key]
|
||||
logger.debug(f"Using cached token limit for {model_name}: {cached_limit}")
|
||||
return cached_limit
|
||||
|
||||
# Try Ollama dynamic discovery if base_url provided
|
||||
if base_url:
|
||||
# Detect Ollama servers by port or "ollama" in URL
|
||||
if "11434" in base_url or "ollama" in base_url.lower():
|
||||
limit = _query_ollama_context_limit(model_name, base_url)
|
||||
if limit:
|
||||
_token_limit_cache[cache_key] = limit
|
||||
return limit
|
||||
|
||||
# Try LM Studio SDK discovery
|
||||
if "1234" in base_url or "lmstudio" in base_url.lower() or "lm.studio" in base_url.lower():
|
||||
# Convert HTTP to WebSocket URL
|
||||
ws_url = base_url.replace("https://", "wss://").replace("http://", "ws://")
|
||||
# Remove /v1 suffix if present
|
||||
if ws_url.endswith("/v1"):
|
||||
ws_url = ws_url[:-3]
|
||||
|
||||
limit = _query_lmstudio_context_limit(model_name, ws_url)
|
||||
if limit:
|
||||
_token_limit_cache[cache_key] = limit
|
||||
return limit
|
||||
|
||||
# Fallback to known model registry with version handling (from PR #154)
|
||||
# Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
|
||||
base_model_name = model_name.split(":")[0]
|
||||
|
||||
# Check exact match first
|
||||
if model_name in EMBEDDING_MODEL_LIMITS:
|
||||
limit = EMBEDDING_MODEL_LIMITS[model_name]
|
||||
_token_limit_cache[cache_key] = limit
|
||||
return limit
|
||||
|
||||
# Check base name match
|
||||
if base_model_name in EMBEDDING_MODEL_LIMITS:
|
||||
limit = EMBEDDING_MODEL_LIMITS[base_model_name]
|
||||
_token_limit_cache[cache_key] = limit
|
||||
return limit
|
||||
|
||||
# Check partial matches for common patterns
|
||||
for known_model, registry_limit in EMBEDDING_MODEL_LIMITS.items():
|
||||
if known_model in base_model_name or base_model_name in known_model:
|
||||
_token_limit_cache[cache_key] = registry_limit
|
||||
return registry_limit
|
||||
|
||||
# Default fallback
|
||||
logger.warning(f"Unknown model '{model_name}', using default {default} token limit")
|
||||
_token_limit_cache[cache_key] = default
|
||||
return default
|
||||
|
||||
|
||||
def truncate_to_token_limit(texts: list[str], token_limit: int) -> list[str]:
|
||||
"""
|
||||
Truncate texts to fit within token limit using tiktoken.
|
||||
|
||||
Args:
|
||||
texts: List of text strings to truncate
|
||||
token_limit: Maximum number of tokens allowed
|
||||
|
||||
Returns:
|
||||
List of truncated texts (same length as input)
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
# Use tiktoken with cl100k_base encoding
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
truncated_texts = []
|
||||
truncation_count = 0
|
||||
total_tokens_removed = 0
|
||||
max_original_length = 0
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
tokens = enc.encode(text)
|
||||
original_length = len(tokens)
|
||||
|
||||
if original_length <= token_limit:
|
||||
# Text is within limit, keep as is
|
||||
truncated_texts.append(text)
|
||||
else:
|
||||
# Truncate to token_limit
|
||||
truncated_tokens = tokens[:token_limit]
|
||||
truncated_text = enc.decode(truncated_tokens)
|
||||
truncated_texts.append(truncated_text)
|
||||
|
||||
# Track truncation statistics
|
||||
truncation_count += 1
|
||||
tokens_removed = original_length - token_limit
|
||||
total_tokens_removed += tokens_removed
|
||||
max_original_length = max(max_original_length, original_length)
|
||||
|
||||
# Log individual truncation at WARNING level (first few only)
|
||||
if truncation_count <= 3:
|
||||
logger.warning(
|
||||
f"Text {i + 1} truncated: {original_length} → {token_limit} tokens "
|
||||
f"({tokens_removed} tokens removed)"
|
||||
)
|
||||
elif truncation_count == 4:
|
||||
logger.warning("Further truncation warnings suppressed...")
|
||||
|
||||
# Log summary at INFO level
|
||||
if truncation_count > 0:
|
||||
logger.warning(
|
||||
f"Truncation summary: {truncation_count}/{len(texts)} texts truncated "
|
||||
f"(removed {total_tokens_removed} tokens total, longest was {max_original_length} tokens)"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"No truncation needed - all {len(texts)} texts within {token_limit} token limit"
|
||||
)
|
||||
|
||||
return truncated_texts
|
||||
|
||||
|
||||
def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int]:
|
||||
"""
|
||||
Query Ollama /api/show for model context limit.
|
||||
|
||||
Args:
|
||||
model_name: Name of the Ollama model
|
||||
base_url: Base URL of the Ollama server
|
||||
|
||||
Returns:
|
||||
Context limit in tokens if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
import requests
|
||||
|
||||
response = requests.post(
|
||||
f"{base_url}/api/show",
|
||||
json={"name": model_name},
|
||||
timeout=5,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if "model_info" in data:
|
||||
# Look for *.context_length in model_info
|
||||
for key, value in data["model_info"].items():
|
||||
if "context_length" in key and isinstance(value, int):
|
||||
logger.info(f"Detected {model_name} context limit: {value} tokens")
|
||||
return value
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to query Ollama context limit: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _query_lmstudio_context_limit(model_name: str, base_url: str) -> Optional[int]:
|
||||
"""
|
||||
Query LM Studio SDK for model context length via Node.js subprocess.
|
||||
|
||||
Args:
|
||||
model_name: Name of the LM Studio model
|
||||
base_url: Base URL of the LM Studio server (WebSocket format, e.g., "ws://localhost:1234")
|
||||
|
||||
Returns:
|
||||
Context limit in tokens if found, None otherwise
|
||||
"""
|
||||
# Inline JavaScript using @lmstudio/sdk
|
||||
# Note: Load model temporarily for metadata, then unload to respect JIT auto-evict
|
||||
js_code = f"""
|
||||
const {{ LMStudioClient }} = require('@lmstudio/sdk');
|
||||
(async () => {{
|
||||
try {{
|
||||
const client = new LMStudioClient({{ baseUrl: '{base_url}' }});
|
||||
const model = await client.embedding.load('{model_name}', {{ verbose: false }});
|
||||
const contextLength = await model.getContextLength();
|
||||
await model.unload(); // Unload immediately to respect JIT auto-evict settings
|
||||
console.log(JSON.stringify({{ contextLength, identifier: '{model_name}' }}));
|
||||
}} catch (error) {{
|
||||
console.error(JSON.stringify({{ error: error.message }}));
|
||||
process.exit(1);
|
||||
}}
|
||||
}})();
|
||||
"""
|
||||
|
||||
try:
|
||||
# Set NODE_PATH to include global modules for @lmstudio/sdk resolution
|
||||
env = os.environ.copy()
|
||||
|
||||
# Try to get npm global root (works with nvm, brew node, etc.)
|
||||
try:
|
||||
npm_root = subprocess.run(
|
||||
["npm", "root", "-g"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
if npm_root.returncode == 0:
|
||||
global_modules = npm_root.stdout.strip()
|
||||
# Append to existing NODE_PATH if present
|
||||
existing_node_path = env.get("NODE_PATH", "")
|
||||
env["NODE_PATH"] = (
|
||||
f"{global_modules}:{existing_node_path}"
|
||||
if existing_node_path
|
||||
else global_modules
|
||||
)
|
||||
except Exception:
|
||||
# If npm not available, continue with existing NODE_PATH
|
||||
pass
|
||||
|
||||
result = subprocess.run(
|
||||
["node", "-e", js_code],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
env=env,
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
logger.debug(f"LM Studio SDK error: {result.stderr}")
|
||||
return None
|
||||
|
||||
data = json.loads(result.stdout)
|
||||
context_length = data.get("contextLength")
|
||||
|
||||
if context_length and context_length > 0:
|
||||
logger.info(f"LM Studio SDK detected {model_name} context length: {context_length}")
|
||||
return context_length
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.debug("Node.js not found - install Node.js for LM Studio SDK features")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.debug("LM Studio SDK query timeout")
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("LM Studio SDK returned invalid JSON")
|
||||
except Exception as e:
|
||||
logger.debug(f"LM Studio SDK query failed: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Global model cache to avoid repeated loading
|
||||
_model_cache: dict[str, Any] = {}
|
||||
|
||||
@@ -352,7 +67,6 @@ def compute_embeddings(
|
||||
model_name,
|
||||
base_url=provider_options.get("base_url"),
|
||||
api_key=provider_options.get("api_key"),
|
||||
provider_options=provider_options,
|
||||
)
|
||||
elif mode == "mlx":
|
||||
return compute_embeddings_mlx(texts, model_name)
|
||||
@@ -362,7 +76,6 @@ def compute_embeddings(
|
||||
model_name,
|
||||
is_build=is_build,
|
||||
host=provider_options.get("host"),
|
||||
provider_options=provider_options,
|
||||
)
|
||||
elif mode == "gemini":
|
||||
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
|
||||
@@ -701,7 +414,6 @@ def compute_embeddings_openai(
|
||||
model_name: str,
|
||||
base_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
provider_options: Optional[dict[str, Any]] = None,
|
||||
) -> np.ndarray:
|
||||
# TODO: @yichuan-w add progress bar only in build mode
|
||||
"""Compute embeddings using OpenAI API"""
|
||||
@@ -720,40 +432,26 @@ def compute_embeddings_openai(
|
||||
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
|
||||
)
|
||||
|
||||
# Extract base_url and api_key from provider_options if not provided directly
|
||||
provider_options = provider_options or {}
|
||||
effective_base_url = base_url or provider_options.get("base_url")
|
||||
effective_api_key = api_key or provider_options.get("api_key")
|
||||
|
||||
resolved_base_url = resolve_openai_base_url(effective_base_url)
|
||||
resolved_api_key = resolve_openai_api_key(effective_api_key)
|
||||
resolved_base_url = resolve_openai_base_url(base_url)
|
||||
resolved_api_key = resolve_openai_api_key(api_key)
|
||||
|
||||
if not resolved_api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
# Create OpenAI client
|
||||
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
|
||||
# Cache OpenAI client
|
||||
cache_key = f"openai_client::{resolved_base_url}"
|
||||
if cache_key in _model_cache:
|
||||
client = _model_cache[cache_key]
|
||||
else:
|
||||
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
|
||||
_model_cache[cache_key] = client
|
||||
logger.info("OpenAI client cached")
|
||||
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
|
||||
)
|
||||
print(f"len of texts: {len(texts)}")
|
||||
|
||||
# Apply prompt template if provided
|
||||
# Priority: build_prompt_template (new format) > prompt_template (old format)
|
||||
prompt_template = provider_options.get("build_prompt_template") or provider_options.get(
|
||||
"prompt_template"
|
||||
)
|
||||
|
||||
if prompt_template:
|
||||
logger.warning(f"Applying prompt template: '{prompt_template}'")
|
||||
texts = [f"{prompt_template}{text}" for text in texts]
|
||||
|
||||
# Query token limit and apply truncation
|
||||
token_limit = get_model_token_limit(model_name, base_url=effective_base_url)
|
||||
logger.info(f"Using token limit: {token_limit} for model '{model_name}'")
|
||||
texts = truncate_to_token_limit(texts, token_limit)
|
||||
|
||||
# OpenAI has limits on batch size and input length
|
||||
max_batch_size = 800 # Conservative batch size because the token limit is 300K
|
||||
all_embeddings = []
|
||||
@@ -784,15 +482,7 @@ def compute_embeddings_openai(
|
||||
try:
|
||||
response = client.embeddings.create(model=model_name, input=batch_texts)
|
||||
batch_embeddings = [embedding.embedding for embedding in response.data]
|
||||
|
||||
# Verify we got the expected number of embeddings
|
||||
if len(batch_embeddings) != len(batch_texts):
|
||||
logger.warning(
|
||||
f"Expected {len(batch_texts)} embeddings but got {len(batch_embeddings)}"
|
||||
)
|
||||
|
||||
# Only take the number of embeddings that match the batch size
|
||||
all_embeddings.extend(batch_embeddings[: len(batch_texts)])
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
except Exception as e:
|
||||
logger.error(f"Batch {i} failed: {e}")
|
||||
raise
|
||||
@@ -882,20 +572,17 @@ def compute_embeddings_ollama(
|
||||
model_name: str,
|
||||
is_build: bool = False,
|
||||
host: Optional[str] = None,
|
||||
provider_options: Optional[dict[str, Any]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute embeddings using Ollama API with true batch processing.
|
||||
Compute embeddings using Ollama API with simplified batch processing.
|
||||
|
||||
Uses the /api/embed endpoint which supports batch inputs.
|
||||
Batch size: 32 for MPS/CPU, 128 for CUDA to optimize performance.
|
||||
Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
|
||||
|
||||
Args:
|
||||
texts: List of texts to compute embeddings for
|
||||
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
|
||||
is_build: Whether this is a build operation (shows progress bar)
|
||||
host: Ollama host URL (defaults to environment or http://localhost:11434)
|
||||
provider_options: Optional provider-specific options (e.g., prompt_template)
|
||||
|
||||
Returns:
|
||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||
@@ -994,11 +681,11 @@ def compute_embeddings_ollama(
|
||||
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
|
||||
model_name = resolved_model_name
|
||||
|
||||
# Verify the model supports embeddings by testing it with /api/embed
|
||||
# Verify the model supports embeddings by testing it
|
||||
try:
|
||||
test_response = requests.post(
|
||||
f"{resolved_host}/api/embed",
|
||||
json={"model": model_name, "input": "test"},
|
||||
f"{resolved_host}/api/embeddings",
|
||||
json={"model": model_name, "prompt": "test"},
|
||||
timeout=10,
|
||||
)
|
||||
if test_response.status_code != 200:
|
||||
@@ -1030,82 +717,56 @@ def compute_embeddings_ollama(
|
||||
# If torch is not available, use conservative batch size
|
||||
batch_size = 32
|
||||
|
||||
logger.info(f"Using batch size: {batch_size} for true batch processing")
|
||||
|
||||
# Apply prompt template if provided
|
||||
provider_options = provider_options or {}
|
||||
# Priority: build_prompt_template (new format) > prompt_template (old format)
|
||||
prompt_template = provider_options.get("build_prompt_template") or provider_options.get(
|
||||
"prompt_template"
|
||||
)
|
||||
|
||||
if prompt_template:
|
||||
logger.warning(f"Applying prompt template: '{prompt_template}'")
|
||||
texts = [f"{prompt_template}{text}" for text in texts]
|
||||
|
||||
# Get model token limit and apply truncation before batching
|
||||
token_limit = get_model_token_limit(model_name, base_url=resolved_host)
|
||||
logger.info(f"Model '{model_name}' token limit: {token_limit}")
|
||||
|
||||
# Apply truncation to all texts before batch processing
|
||||
# Function logs truncation details internally
|
||||
texts = truncate_to_token_limit(texts, token_limit)
|
||||
logger.info(f"Using batch size: {batch_size}")
|
||||
|
||||
def get_batch_embeddings(batch_texts):
|
||||
"""Get embeddings for a batch of texts using /api/embed endpoint."""
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
"""Get embeddings for a batch of texts."""
|
||||
all_embeddings = []
|
||||
failed_indices = []
|
||||
|
||||
# Texts are already truncated to token limit by the outer function
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
# Use /api/embed endpoint with "input" parameter for batch processing
|
||||
response = requests.post(
|
||||
f"{resolved_host}/api/embed",
|
||||
json={"model": model_name, "input": batch_texts},
|
||||
timeout=60, # Increased timeout for batch processing
|
||||
)
|
||||
response.raise_for_status()
|
||||
for i, text in enumerate(batch_texts):
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
|
||||
result = response.json()
|
||||
batch_embeddings = result.get("embeddings")
|
||||
|
||||
if batch_embeddings is None:
|
||||
raise ValueError("No embeddings returned from API")
|
||||
|
||||
if not isinstance(batch_embeddings, list):
|
||||
raise ValueError(f"Invalid embeddings format: {type(batch_embeddings)}")
|
||||
|
||||
if len(batch_embeddings) != len(batch_texts):
|
||||
raise ValueError(
|
||||
f"Mismatch: requested {len(batch_texts)} embeddings, got {len(batch_embeddings)}"
|
||||
# Truncate very long texts to avoid API issues
|
||||
truncated_text = text[:8000] if len(text) > 8000 else text
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{resolved_host}/api/embeddings",
|
||||
json={"model": model_name, "prompt": truncated_text},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return batch_embeddings, []
|
||||
result = response.json()
|
||||
embedding = result.get("embedding")
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
retry_count += 1
|
||||
if retry_count >= max_retries:
|
||||
logger.warning(f"Timeout for batch after {max_retries} retries")
|
||||
return None, list(range(len(batch_texts)))
|
||||
if embedding is None:
|
||||
raise ValueError(f"No embedding returned for text {i}")
|
||||
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
if retry_count >= max_retries:
|
||||
# Enhanced error detection for token limit violations
|
||||
error_msg = str(e).lower()
|
||||
if "token" in error_msg and (
|
||||
"limit" in error_msg or "exceed" in error_msg or "length" in error_msg
|
||||
):
|
||||
logger.error(
|
||||
f"Token limit exceeded for batch. Error: {e}. "
|
||||
f"Consider reducing chunk sizes or check token truncation."
|
||||
)
|
||||
else:
|
||||
logger.error(f"Failed to get embeddings for batch: {e}")
|
||||
return None, list(range(len(batch_texts)))
|
||||
if not isinstance(embedding, list) or len(embedding) == 0:
|
||||
raise ValueError(f"Invalid embedding format for text {i}")
|
||||
|
||||
return None, list(range(len(batch_texts)))
|
||||
all_embeddings.append(embedding)
|
||||
break
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
retry_count += 1
|
||||
if retry_count >= max_retries:
|
||||
logger.warning(f"Timeout for text {i} after {max_retries} retries")
|
||||
failed_indices.append(i)
|
||||
all_embeddings.append(None)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
if retry_count >= max_retries:
|
||||
logger.error(f"Failed to get embedding for text {i}: {e}")
|
||||
failed_indices.append(i)
|
||||
all_embeddings.append(None)
|
||||
break
|
||||
return all_embeddings, failed_indices
|
||||
|
||||
# Process texts in batches
|
||||
all_embeddings = []
|
||||
@@ -1123,7 +784,7 @@ def compute_embeddings_ollama(
|
||||
num_batches = (len(texts) + batch_size - 1) // batch_size
|
||||
|
||||
if show_progress:
|
||||
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
|
||||
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
|
||||
else:
|
||||
batch_iterator = range(num_batches)
|
||||
|
||||
@@ -1134,14 +795,10 @@ def compute_embeddings_ollama(
|
||||
|
||||
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
|
||||
|
||||
if batch_embeddings is not None:
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
else:
|
||||
# Entire batch failed, add None placeholders
|
||||
all_embeddings.extend([None] * len(batch_texts))
|
||||
# Adjust failed indices to global indices
|
||||
global_failed = [start_idx + idx for idx in batch_failed]
|
||||
all_failed_indices.extend(global_failed)
|
||||
# Adjust failed indices to global indices
|
||||
global_failed = [start_idx + idx for idx in batch_failed]
|
||||
all_failed_indices.extend(global_failed)
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
|
||||
# Handle failed embeddings
|
||||
if all_failed_indices:
|
||||
|
||||
@@ -1,189 +0,0 @@
|
||||
"""
|
||||
Interactive session utilities for LEANN applications.
|
||||
|
||||
Provides shared readline functionality and command handling across
|
||||
CLI, API, and RAG example interactive modes.
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
# Try to import readline with fallback for Windows
|
||||
try:
|
||||
import readline
|
||||
|
||||
HAS_READLINE = True
|
||||
except ImportError:
|
||||
# Windows doesn't have readline by default
|
||||
HAS_READLINE = False
|
||||
readline = None
|
||||
|
||||
|
||||
class InteractiveSession:
|
||||
"""Manages interactive session with optional readline support and common commands."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
history_name: str,
|
||||
prompt: str = "You: ",
|
||||
welcome_message: str = "",
|
||||
):
|
||||
"""
|
||||
Initialize interactive session with optional readline support.
|
||||
|
||||
Args:
|
||||
history_name: Name for history file (e.g., "cli", "api_chat")
|
||||
(ignored if readline not available)
|
||||
prompt: Input prompt to display
|
||||
welcome_message: Message to show when starting session
|
||||
|
||||
Note:
|
||||
On systems without readline (e.g., Windows), falls back to basic input()
|
||||
with limited functionality (no history, no line editing).
|
||||
"""
|
||||
self.history_name = history_name
|
||||
self.prompt = prompt
|
||||
self.welcome_message = welcome_message
|
||||
self._setup_complete = False
|
||||
|
||||
def setup_readline(self):
|
||||
"""Setup readline with history support (if available)."""
|
||||
if self._setup_complete:
|
||||
return
|
||||
|
||||
if not HAS_READLINE:
|
||||
# Readline not available (likely Windows), skip setup
|
||||
self._setup_complete = True
|
||||
return
|
||||
|
||||
# History file setup
|
||||
history_dir = Path.home() / ".leann" / "history"
|
||||
history_dir.mkdir(parents=True, exist_ok=True)
|
||||
history_file = history_dir / f"{self.history_name}.history"
|
||||
|
||||
# Load history if exists
|
||||
try:
|
||||
readline.read_history_file(str(history_file))
|
||||
readline.set_history_length(1000)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
# Save history on exit
|
||||
atexit.register(readline.write_history_file, str(history_file))
|
||||
|
||||
# Optional: Enable vi editing mode (commented out by default)
|
||||
# readline.parse_and_bind("set editing-mode vi")
|
||||
|
||||
self._setup_complete = True
|
||||
|
||||
def _show_help(self):
|
||||
"""Show available commands."""
|
||||
print("Commands:")
|
||||
print(" quit/exit/q - Exit the chat")
|
||||
print(" help - Show this help message")
|
||||
print(" clear - Clear screen")
|
||||
print(" history - Show command history")
|
||||
|
||||
def _show_history(self):
|
||||
"""Show command history."""
|
||||
if not HAS_READLINE:
|
||||
print(" History not available (readline not supported on this system)")
|
||||
return
|
||||
|
||||
history_length = readline.get_current_history_length()
|
||||
if history_length == 0:
|
||||
print(" No history available")
|
||||
return
|
||||
|
||||
for i in range(history_length):
|
||||
item = readline.get_history_item(i + 1)
|
||||
if item:
|
||||
print(f" {i + 1}: {item}")
|
||||
|
||||
def get_user_input(self) -> Optional[str]:
|
||||
"""
|
||||
Get user input with readline support.
|
||||
|
||||
Returns:
|
||||
User input string, or None if EOF (Ctrl+D)
|
||||
"""
|
||||
try:
|
||||
return input(self.prompt).strip()
|
||||
except KeyboardInterrupt:
|
||||
print("\n(Use 'quit' to exit)")
|
||||
return "" # Return empty string to continue
|
||||
except EOFError:
|
||||
print("\nGoodbye!")
|
||||
return None
|
||||
|
||||
def run_interactive_loop(self, handler_func: Callable[[str], None]):
|
||||
"""
|
||||
Run the interactive loop with a custom handler function.
|
||||
|
||||
Args:
|
||||
handler_func: Function to handle user input that's not a built-in command
|
||||
Should accept a string and handle the user's query
|
||||
"""
|
||||
self.setup_readline()
|
||||
|
||||
if self.welcome_message:
|
||||
print(self.welcome_message)
|
||||
|
||||
while True:
|
||||
user_input = self.get_user_input()
|
||||
|
||||
if user_input is None: # EOF (Ctrl+D)
|
||||
break
|
||||
|
||||
if not user_input: # Empty input or KeyboardInterrupt
|
||||
continue
|
||||
|
||||
# Handle built-in commands
|
||||
command = user_input.lower()
|
||||
if command in ["quit", "exit", "q"]:
|
||||
print("Goodbye!")
|
||||
break
|
||||
elif command == "help":
|
||||
self._show_help()
|
||||
elif command == "clear":
|
||||
os.system("clear" if os.name != "nt" else "cls")
|
||||
elif command == "history":
|
||||
self._show_history()
|
||||
else:
|
||||
# Regular user input - pass to handler
|
||||
try:
|
||||
handler_func(user_input)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
|
||||
def create_cli_session(index_name: str) -> InteractiveSession:
|
||||
"""Create an interactive session for CLI usage."""
|
||||
return InteractiveSession(
|
||||
history_name=index_name,
|
||||
prompt="\nYou: ",
|
||||
welcome_message="LEANN Assistant ready! Type 'quit' to exit, 'help' for commands\n"
|
||||
+ "=" * 40,
|
||||
)
|
||||
|
||||
|
||||
def create_api_session() -> InteractiveSession:
|
||||
"""Create an interactive session for API chat."""
|
||||
return InteractiveSession(
|
||||
history_name="api_chat",
|
||||
prompt="You: ",
|
||||
welcome_message="Leann Chat started (type 'quit' to exit, 'help' for commands)\n"
|
||||
+ "=" * 40,
|
||||
)
|
||||
|
||||
|
||||
def create_rag_session(app_name: str, data_description: str) -> InteractiveSession:
|
||||
"""Create an interactive session for RAG examples."""
|
||||
return InteractiveSession(
|
||||
history_name=f"{app_name}_rag",
|
||||
prompt="You: ",
|
||||
welcome_message=f"[Interactive Mode] Chat with your {data_description} data!\nType 'quit' or 'exit' to stop, 'help' for commands.\n"
|
||||
+ "=" * 40,
|
||||
)
|
||||
@@ -77,7 +77,6 @@ class LeannBackendSearcherInterface(ABC):
|
||||
query: str,
|
||||
use_server_if_available: bool = True,
|
||||
zmq_port: Optional[int] = None,
|
||||
query_template: Optional[str] = None,
|
||||
) -> np.ndarray:
|
||||
"""Compute embedding for a query string
|
||||
|
||||
@@ -85,7 +84,6 @@ class LeannBackendSearcherInterface(ABC):
|
||||
query: The query string to embed
|
||||
zmq_port: ZMQ port for embedding server
|
||||
use_server_if_available: Whether to try using embedding server first
|
||||
query_template: Optional prompt template to prepend to query
|
||||
|
||||
Returns:
|
||||
Query embedding as numpy array with shape (1, D)
|
||||
|
||||
@@ -60,11 +60,6 @@ def handle_request(request):
|
||||
"maximum": 128,
|
||||
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
|
||||
},
|
||||
"show_metadata": {
|
||||
"type": "boolean",
|
||||
"default": False,
|
||||
"description": "Include file paths and metadata in search results. Useful for understanding which files contain the results.",
|
||||
},
|
||||
},
|
||||
"required": ["index_name", "query"],
|
||||
},
|
||||
@@ -109,8 +104,6 @@ def handle_request(request):
|
||||
f"--complexity={args.get('complexity', 32)}",
|
||||
"--non-interactive",
|
||||
]
|
||||
if args.get("show_metadata", False):
|
||||
cmd.append("--show-metadata")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
elif tool_name == "leann_list":
|
||||
|
||||
@@ -33,8 +33,6 @@ def autodiscover_backends():
|
||||
discovered_backends = []
|
||||
for dist in importlib.metadata.distributions():
|
||||
dist_name = dist.metadata["name"]
|
||||
if dist_name is None:
|
||||
continue
|
||||
if dist_name.startswith("leann-backend-"):
|
||||
backend_module_name = dist_name.replace("-", "_")
|
||||
discovered_backends.append(backend_module_name)
|
||||
|
||||
@@ -71,15 +71,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
or "mips"
|
||||
)
|
||||
|
||||
# Filter out ALL prompt templates from provider_options during search
|
||||
# Templates are applied in compute_query_embedding (line 109-110) BEFORE server call
|
||||
# The server should never apply templates during search to avoid double-templating
|
||||
search_provider_options = {
|
||||
k: v
|
||||
for k, v in self.embedding_options.items()
|
||||
if k not in ("build_prompt_template", "query_prompt_template", "prompt_template")
|
||||
}
|
||||
|
||||
server_started, actual_port = self.embedding_server_manager.start_server(
|
||||
port=port,
|
||||
model_name=self.embedding_model,
|
||||
@@ -87,7 +78,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
passages_file=passages_source_file,
|
||||
distance_metric=distance_metric,
|
||||
enable_warmup=kwargs.get("enable_warmup", False),
|
||||
provider_options=search_provider_options,
|
||||
provider_options=self.embedding_options,
|
||||
)
|
||||
if not server_started:
|
||||
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
|
||||
@@ -99,7 +90,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
query: str,
|
||||
use_server_if_available: bool = True,
|
||||
zmq_port: int = 5557,
|
||||
query_template: Optional[str] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute embedding for a query string.
|
||||
@@ -108,16 +98,10 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
query: The query string to embed
|
||||
zmq_port: ZMQ port for embedding server
|
||||
use_server_if_available: Whether to try using embedding server first
|
||||
query_template: Optional prompt template to prepend to query
|
||||
|
||||
Returns:
|
||||
Query embedding as numpy array
|
||||
"""
|
||||
# Apply query template BEFORE any computation path
|
||||
# This ensures template is applied consistently for both server and fallback paths
|
||||
if query_template:
|
||||
query = f"{query_template}{query}"
|
||||
|
||||
# Try to use embedding server if available and requested
|
||||
if use_server_if_available:
|
||||
try:
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Any
|
||||
# Default fallbacks to preserve current behaviour while keeping them in one place.
|
||||
_DEFAULT_OLLAMA_HOST = "http://localhost:11434"
|
||||
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
|
||||
_DEFAULT_ANTHROPIC_BASE_URL = "https://api.anthropic.com"
|
||||
|
||||
|
||||
def _clean_url(value: str) -> str:
|
||||
@@ -53,23 +52,6 @@ def resolve_openai_base_url(explicit: str | None = None) -> str:
|
||||
return _clean_url(_DEFAULT_OPENAI_BASE_URL)
|
||||
|
||||
|
||||
def resolve_anthropic_base_url(explicit: str | None = None) -> str:
|
||||
"""Resolve the base URL for Anthropic-compatible services."""
|
||||
|
||||
candidates = (
|
||||
explicit,
|
||||
os.getenv("LEANN_ANTHROPIC_BASE_URL"),
|
||||
os.getenv("ANTHROPIC_BASE_URL"),
|
||||
os.getenv("LOCAL_ANTHROPIC_BASE_URL"),
|
||||
)
|
||||
|
||||
for candidate in candidates:
|
||||
if candidate:
|
||||
return _clean_url(candidate)
|
||||
|
||||
return _clean_url(_DEFAULT_ANTHROPIC_BASE_URL)
|
||||
|
||||
|
||||
def resolve_openai_api_key(explicit: str | None = None) -> str | None:
|
||||
"""Resolve the API key for OpenAI-compatible services."""
|
||||
|
||||
@@ -79,15 +61,6 @@ def resolve_openai_api_key(explicit: str | None = None) -> str | None:
|
||||
return os.getenv("OPENAI_API_KEY")
|
||||
|
||||
|
||||
def resolve_anthropic_api_key(explicit: str | None = None) -> str | None:
|
||||
"""Resolve the API key for Anthropic services."""
|
||||
|
||||
if explicit:
|
||||
return explicit
|
||||
|
||||
return os.getenv("ANTHROPIC_API_KEY")
|
||||
|
||||
|
||||
def encode_provider_options(options: dict[str, Any] | None) -> str | None:
|
||||
"""Serialize provider options for child processes."""
|
||||
|
||||
|
||||
@@ -53,11 +53,6 @@ leann build my-project --docs $(git ls-files)
|
||||
# Start Claude Code
|
||||
claude
|
||||
```
|
||||
**Performance tip**: For maximum speed when storage space is not a concern, add the `--no-recompute` flag to your build command. This materializes all tensors and stores them on disk, avoiding recomputation on subsequent builds:
|
||||
|
||||
```bash
|
||||
leann build my-project --docs $(git ls-files) --no-recompute
|
||||
```
|
||||
|
||||
## 🚀 Advanced Usage Examples to build the index
|
||||
|
||||
|
||||
@@ -4,10 +4,10 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "leann"
|
||||
version = "0.3.5"
|
||||
version = "0.3.4"
|
||||
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.9"
|
||||
license = { text = "MIT" }
|
||||
authors = [
|
||||
{ name = "LEANN Team" }
|
||||
@@ -18,10 +18,10 @@ classifiers = [
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
]
|
||||
|
||||
# Default installation: core + hnsw + diskann
|
||||
|
||||
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
|
||||
[project]
|
||||
name = "leann-workspace"
|
||||
version = "0.1.0"
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.9"
|
||||
|
||||
dependencies = [
|
||||
"leann-core",
|
||||
@@ -22,10 +22,7 @@ dependencies = [
|
||||
"sglang",
|
||||
"ollama",
|
||||
"requests>=2.25.0",
|
||||
"sentence-transformers>=3.0.0",
|
||||
# Pin transformers below 4.46: 4.46.0 introduced Python 3.10-only typing (PEP 604) and
|
||||
# breaks our Python 3.9 test matrix when pulled in by sentence-transformers.
|
||||
"transformers<4.46",
|
||||
"sentence-transformers>=2.2.0",
|
||||
"openai>=1.0.0",
|
||||
# PDF parsing dependencies - essential for document processing
|
||||
"PyPDF2>=3.0.0",
|
||||
@@ -57,8 +54,6 @@ dependencies = [
|
||||
"tree-sitter-c-sharp>=0.20.0",
|
||||
"tree-sitter-typescript>=0.20.0",
|
||||
"torchvision>=0.23.0",
|
||||
"einops",
|
||||
"seaborn",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
@@ -69,8 +64,7 @@ diskann = [
|
||||
# Add a new optional dependency group for document processing
|
||||
documents = [
|
||||
"beautifulsoup4>=4.13.0", # For HTML parsing
|
||||
"python-docx>=0.8.11", # For Word documents (creating/editing)
|
||||
"docx2txt>=0.9", # For Word documents (text extraction)
|
||||
"python-docx>=0.8.11", # For Word documents
|
||||
"openpyxl>=3.1.0", # For Excel files
|
||||
"pandas>=2.2.0", # For data processing
|
||||
]
|
||||
@@ -165,7 +159,6 @@ python_functions = ["test_*"]
|
||||
markers = [
|
||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||
"openai: marks tests that require OpenAI API key",
|
||||
"integration: marks tests that require live services (Ollama, LM Studio, etc.)",
|
||||
]
|
||||
timeout = 300 # Reduced from 600s (10min) to 300s (5min) for CI safety
|
||||
addopts = [
|
||||
|
||||
@@ -36,14 +36,6 @@ Tests DiskANN graph partitioning functionality:
|
||||
- Includes performance comparison between DiskANN (with partition) and HNSW
|
||||
- **Note**: These tests are skipped in CI due to hardware requirements and computation time
|
||||
|
||||
### `test_prompt_template_e2e.py`
|
||||
Integration tests for prompt template feature with live embedding services:
|
||||
- Tests prompt template prepending with EmbeddingGemma (OpenAI-compatible API via LM Studio)
|
||||
- Tests hybrid token limit discovery (Ollama dynamic detection, registry fallback, default)
|
||||
- Tests LM Studio SDK bridge for automatic context length detection (requires Node.js + @lmstudio/sdk)
|
||||
- **Note**: These tests require live services (LM Studio, Ollama) and are marked with `@pytest.mark.integration`
|
||||
- **Important**: Prompt templates are ONLY for EmbeddingGemma and similar task-specific models, NOT regular embedding models
|
||||
|
||||
## Running Tests
|
||||
|
||||
### Install test dependencies:
|
||||
@@ -74,12 +66,6 @@ pytest tests/ -m "not openai"
|
||||
# Skip slow tests
|
||||
pytest tests/ -m "not slow"
|
||||
|
||||
# Skip integration tests (that require live services)
|
||||
pytest tests/ -m "not integration"
|
||||
|
||||
# Run only integration tests (requires LM Studio or Ollama running)
|
||||
pytest tests/test_prompt_template_e2e.py -v -s
|
||||
|
||||
# Run DiskANN partition tests (requires local machine, not CI)
|
||||
pytest tests/test_diskann_partition.py
|
||||
```
|
||||
@@ -115,20 +101,6 @@ The `pytest.ini` file configures:
|
||||
- Custom markers for slow and OpenAI tests
|
||||
- Verbose output with short tracebacks
|
||||
|
||||
### Integration Test Prerequisites
|
||||
|
||||
Integration tests (`test_prompt_template_e2e.py`) require live services:
|
||||
|
||||
**Required:**
|
||||
- LM Studio running at `http://localhost:1234` with EmbeddingGemma model loaded
|
||||
|
||||
**Optional:**
|
||||
- Ollama running at `http://localhost:11434` for token limit detection tests
|
||||
- Node.js + @lmstudio/sdk installed (`npm install -g @lmstudio/sdk`) for SDK bridge tests
|
||||
|
||||
Tests gracefully skip if services are unavailable.
|
||||
|
||||
### Known Issues
|
||||
|
||||
- OpenAI tests are automatically skipped if no API key is provided
|
||||
- Integration tests require live embedding services and may fail due to proxy settings (set `unset ALL_PROXY all_proxy` if needed)
|
||||
|
||||
@@ -8,7 +8,7 @@ import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -116,10 +116,8 @@ class TestChunkingFunctions:
|
||||
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
|
||||
|
||||
assert len(chunks) > 0
|
||||
# Traditional chunks now return dict format for consistency
|
||||
assert all(isinstance(chunk, dict) for chunk in chunks)
|
||||
assert all("text" in chunk and "metadata" in chunk for chunk in chunks)
|
||||
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks)
|
||||
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||
assert all(len(chunk.strip()) > 0 for chunk in chunks)
|
||||
|
||||
def test_create_traditional_chunks_empty_docs(self):
|
||||
"""Test traditional chunking with empty documents."""
|
||||
@@ -160,22 +158,11 @@ class Calculator:
|
||||
|
||||
# Should have multiple chunks due to different functions/classes
|
||||
assert len(chunks) > 0
|
||||
# R3: Expect dict format with "text" and "metadata" keys
|
||||
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
|
||||
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
|
||||
"Each chunk should have 'text' and 'metadata' keys"
|
||||
)
|
||||
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks), (
|
||||
"Each chunk text should be non-empty"
|
||||
)
|
||||
|
||||
# Check metadata is present
|
||||
assert all("file_path" in chunk["metadata"] for chunk in chunks), (
|
||||
"Each chunk should have file_path metadata"
|
||||
)
|
||||
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||
assert all(len(chunk.strip()) > 0 for chunk in chunks)
|
||||
|
||||
# Check that code structure is somewhat preserved
|
||||
combined_content = " ".join([c["text"] for c in chunks])
|
||||
combined_content = " ".join(chunks)
|
||||
assert "def hello_world" in combined_content
|
||||
assert "class Calculator" in combined_content
|
||||
|
||||
@@ -207,11 +194,7 @@ class Calculator:
|
||||
chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10)
|
||||
|
||||
assert len(chunks) > 0
|
||||
# R3: Traditional chunking should also return dict format for consistency
|
||||
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
|
||||
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
|
||||
"Each chunk should have 'text' and 'metadata' keys"
|
||||
)
|
||||
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||
|
||||
def test_create_text_chunks_ast_mode(self):
|
||||
"""Test text chunking in AST mode."""
|
||||
@@ -230,11 +213,7 @@ class Calculator:
|
||||
)
|
||||
|
||||
assert len(chunks) > 0
|
||||
# R3: AST mode should also return dict format
|
||||
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
|
||||
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
|
||||
"Each chunk should have 'text' and 'metadata' keys"
|
||||
)
|
||||
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||
|
||||
def test_create_text_chunks_custom_extensions(self):
|
||||
"""Test text chunking with custom code file extensions."""
|
||||
@@ -374,552 +353,6 @@ class MathUtils:
|
||||
pytest.skip("Test timed out - likely due to model download in CI")
|
||||
|
||||
|
||||
class TestASTContentExtraction:
|
||||
"""Test AST content extraction bug fix.
|
||||
|
||||
These tests verify that astchunk's dict format with 'content' key is handled correctly,
|
||||
and that the extraction logic doesn't fall through to stringifying entire dicts.
|
||||
"""
|
||||
|
||||
def test_extract_content_from_astchunk_dict(self):
|
||||
"""Test that astchunk dict format with 'content' key is handled correctly.
|
||||
|
||||
Bug: Current code checks for chunk["text"] but astchunk returns chunk["content"].
|
||||
This causes fallthrough to str(chunk), stringifying the entire dict.
|
||||
|
||||
This test will FAIL until the bug is fixed because:
|
||||
- Current code will stringify the dict: "{'content': '...', 'metadata': {...}}"
|
||||
- Fixed code should extract just the content value
|
||||
"""
|
||||
# Mock the ASTChunkBuilder class
|
||||
mock_builder = Mock()
|
||||
|
||||
# Astchunk returns this format
|
||||
astchunk_format_chunk = {
|
||||
"content": "def hello():\n print('world')",
|
||||
"metadata": {
|
||||
"filepath": "test.py",
|
||||
"line_count": 2,
|
||||
"start_line_no": 0,
|
||||
"end_line_no": 1,
|
||||
"node_count": 1,
|
||||
},
|
||||
}
|
||||
mock_builder.chunkify.return_value = [astchunk_format_chunk]
|
||||
|
||||
# Create mock document
|
||||
doc = MockDocument(
|
||||
"def hello():\n print('world')", "/test/test.py", {"language": "python"}
|
||||
)
|
||||
|
||||
# Mock the astchunk module and its ASTChunkBuilder class
|
||||
mock_astchunk = Mock()
|
||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||
|
||||
# Patch sys.modules to inject our mock before the import
|
||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||
# Call create_ast_chunks
|
||||
chunks = create_ast_chunks([doc])
|
||||
|
||||
# R3: Should return dict format with proper metadata
|
||||
assert len(chunks) > 0, "Should return at least one chunk"
|
||||
|
||||
# R3: Each chunk should be a dict
|
||||
chunk = chunks[0]
|
||||
assert isinstance(chunk, dict), "Chunk should be a dict"
|
||||
assert "text" in chunk, "Chunk should have 'text' key"
|
||||
assert "metadata" in chunk, "Chunk should have 'metadata' key"
|
||||
|
||||
chunk_text = chunk["text"]
|
||||
|
||||
# CRITICAL: Should NOT contain stringified dict markers in the text field
|
||||
# These assertions will FAIL with current buggy code
|
||||
assert "'content':" not in chunk_text, (
|
||||
f"Chunk text contains stringified dict - extraction failed! Got: {chunk_text[:100]}..."
|
||||
)
|
||||
assert "'metadata':" not in chunk_text, (
|
||||
"Chunk text contains stringified metadata - extraction failed! "
|
||||
f"Got: {chunk_text[:100]}..."
|
||||
)
|
||||
assert "{" not in chunk_text or "def hello" in chunk_text.split("{")[0], (
|
||||
"Chunk text appears to be a stringified dict"
|
||||
)
|
||||
|
||||
# Should contain actual content
|
||||
assert "def hello()" in chunk_text, "Should extract actual code content"
|
||||
assert "print('world')" in chunk_text, "Should extract complete code content"
|
||||
|
||||
# R3: Should preserve astchunk metadata
|
||||
assert "filepath" in chunk["metadata"] or "file_path" in chunk["metadata"], (
|
||||
"Should preserve file path metadata"
|
||||
)
|
||||
|
||||
def test_extract_text_key_fallback(self):
|
||||
"""Test that 'text' key still works for backward compatibility.
|
||||
|
||||
Some chunks might use 'text' instead of 'content' - ensure backward compatibility.
|
||||
This test should PASS even with current code.
|
||||
"""
|
||||
mock_builder = Mock()
|
||||
|
||||
# Some chunks might use "text" key
|
||||
text_key_chunk = {"text": "def legacy_function():\n return True"}
|
||||
mock_builder.chunkify.return_value = [text_key_chunk]
|
||||
|
||||
# Create mock document
|
||||
doc = MockDocument(
|
||||
"def legacy_function():\n return True", "/test/legacy.py", {"language": "python"}
|
||||
)
|
||||
|
||||
# Mock the astchunk module
|
||||
mock_astchunk = Mock()
|
||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||
|
||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||
# Call create_ast_chunks
|
||||
chunks = create_ast_chunks([doc])
|
||||
|
||||
# R3: Should extract text correctly as dict format
|
||||
assert len(chunks) > 0
|
||||
chunk = chunks[0]
|
||||
assert isinstance(chunk, dict), "Chunk should be a dict"
|
||||
assert "text" in chunk, "Chunk should have 'text' key"
|
||||
|
||||
chunk_text = chunk["text"]
|
||||
|
||||
# Should NOT be stringified
|
||||
assert "'text':" not in chunk_text, "Should not stringify dict with 'text' key"
|
||||
|
||||
# Should contain actual content
|
||||
assert "def legacy_function()" in chunk_text
|
||||
assert "return True" in chunk_text
|
||||
|
||||
def test_handles_string_chunks(self):
|
||||
"""Test that plain string chunks still work.
|
||||
|
||||
Some chunkers might return plain strings - verify these are preserved.
|
||||
This test should PASS with current code.
|
||||
"""
|
||||
mock_builder = Mock()
|
||||
|
||||
# Plain string chunk
|
||||
plain_string_chunk = "def simple_function():\n pass"
|
||||
mock_builder.chunkify.return_value = [plain_string_chunk]
|
||||
|
||||
# Create mock document
|
||||
doc = MockDocument(
|
||||
"def simple_function():\n pass", "/test/simple.py", {"language": "python"}
|
||||
)
|
||||
|
||||
# Mock the astchunk module
|
||||
mock_astchunk = Mock()
|
||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||
|
||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||
# Call create_ast_chunks
|
||||
chunks = create_ast_chunks([doc])
|
||||
|
||||
# R3: Should wrap string in dict format
|
||||
assert len(chunks) > 0
|
||||
chunk = chunks[0]
|
||||
assert isinstance(chunk, dict), "Even string chunks should be wrapped in dict"
|
||||
assert "text" in chunk, "Chunk should have 'text' key"
|
||||
|
||||
chunk_text = chunk["text"]
|
||||
|
||||
assert chunk_text == plain_string_chunk.strip(), (
|
||||
"Should preserve plain string chunk content"
|
||||
)
|
||||
assert "def simple_function()" in chunk_text
|
||||
assert "pass" in chunk_text
|
||||
|
||||
def test_multiple_chunks_with_mixed_formats(self):
|
||||
"""Test handling of multiple chunks with different formats.
|
||||
|
||||
Real-world scenario: astchunk might return a mix of formats.
|
||||
This test will FAIL if any chunk with 'content' key gets stringified.
|
||||
"""
|
||||
mock_builder = Mock()
|
||||
|
||||
# Mix of formats
|
||||
mixed_chunks = [
|
||||
{"content": "def first():\n return 1", "metadata": {"line_count": 2}},
|
||||
"def second():\n return 2", # Plain string
|
||||
{"text": "def third():\n return 3"}, # Old format
|
||||
{"content": "class MyClass:\n pass", "metadata": {"node_count": 1}},
|
||||
]
|
||||
mock_builder.chunkify.return_value = mixed_chunks
|
||||
|
||||
# Create mock document
|
||||
code = "def first():\n return 1\n\ndef second():\n return 2\n\ndef third():\n return 3\n\nclass MyClass:\n pass"
|
||||
doc = MockDocument(code, "/test/mixed.py", {"language": "python"})
|
||||
|
||||
# Mock the astchunk module
|
||||
mock_astchunk = Mock()
|
||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||
|
||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||
# Call create_ast_chunks
|
||||
chunks = create_ast_chunks([doc])
|
||||
|
||||
# R3: Should extract all chunks correctly as dicts
|
||||
assert len(chunks) == 4, "Should extract all 4 chunks"
|
||||
|
||||
# Check each chunk
|
||||
for i, chunk in enumerate(chunks):
|
||||
assert isinstance(chunk, dict), f"Chunk {i} should be a dict"
|
||||
assert "text" in chunk, f"Chunk {i} should have 'text' key"
|
||||
assert "metadata" in chunk, f"Chunk {i} should have 'metadata' key"
|
||||
|
||||
chunk_text = chunk["text"]
|
||||
# None should be stringified dicts
|
||||
assert "'content':" not in chunk_text, f"Chunk {i} text is stringified (has 'content':)"
|
||||
assert "'metadata':" not in chunk_text, (
|
||||
f"Chunk {i} text is stringified (has 'metadata':)"
|
||||
)
|
||||
assert "'text':" not in chunk_text, f"Chunk {i} text is stringified (has 'text':)"
|
||||
|
||||
# Verify actual content is present
|
||||
combined = "\n".join([c["text"] for c in chunks])
|
||||
assert "def first()" in combined
|
||||
assert "def second()" in combined
|
||||
assert "def third()" in combined
|
||||
assert "class MyClass:" in combined
|
||||
|
||||
def test_empty_content_value_handling(self):
|
||||
"""Test handling of chunks with empty content values.
|
||||
|
||||
Edge case: chunk has 'content' key but value is empty.
|
||||
Should skip these chunks, not stringify them.
|
||||
"""
|
||||
mock_builder = Mock()
|
||||
|
||||
chunks_with_empty = [
|
||||
{"content": "", "metadata": {"line_count": 0}}, # Empty content
|
||||
{"content": " ", "metadata": {"line_count": 1}}, # Whitespace only
|
||||
{"content": "def valid():\n return True", "metadata": {"line_count": 2}}, # Valid
|
||||
]
|
||||
mock_builder.chunkify.return_value = chunks_with_empty
|
||||
|
||||
doc = MockDocument(
|
||||
"def valid():\n return True", "/test/empty.py", {"language": "python"}
|
||||
)
|
||||
|
||||
# Mock the astchunk module
|
||||
mock_astchunk = Mock()
|
||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||
|
||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||
chunks = create_ast_chunks([doc])
|
||||
|
||||
# R3: Should only have the valid chunk (empty ones filtered out)
|
||||
assert len(chunks) == 1, "Should filter out empty content chunks"
|
||||
|
||||
chunk = chunks[0]
|
||||
assert isinstance(chunk, dict), "Chunk should be a dict"
|
||||
assert "text" in chunk, "Chunk should have 'text' key"
|
||||
assert "def valid()" in chunk["text"]
|
||||
|
||||
# Should not have stringified the empty dict
|
||||
assert "'content': ''" not in chunk["text"]
|
||||
|
||||
|
||||
class TestASTMetadataPreservation:
|
||||
"""Test metadata preservation in AST chunk dictionaries.
|
||||
|
||||
R3: These tests define the contract for metadata preservation when returning
|
||||
chunk dictionaries instead of plain strings. Each chunk dict should have:
|
||||
- "text": str - the actual chunk content
|
||||
- "metadata": dict - all metadata from document AND astchunk
|
||||
|
||||
These tests will FAIL until G3 implementation changes return type to list[dict].
|
||||
"""
|
||||
|
||||
def test_ast_chunks_preserve_file_metadata(self):
|
||||
"""Test that document metadata is preserved in chunk metadata.
|
||||
|
||||
This test verifies that all document-level metadata (file_path, file_name,
|
||||
creation_date, last_modified_date) is included in each chunk's metadata dict.
|
||||
|
||||
This will FAIL because current code returns list[str], not list[dict].
|
||||
"""
|
||||
# Create mock document with rich metadata
|
||||
python_code = '''
|
||||
def calculate_sum(numbers):
|
||||
"""Calculate sum of numbers."""
|
||||
return sum(numbers)
|
||||
|
||||
class DataProcessor:
|
||||
"""Process data records."""
|
||||
|
||||
def process(self, data):
|
||||
return [x * 2 for x in data]
|
||||
'''
|
||||
doc = MockDocument(
|
||||
python_code,
|
||||
file_path="/project/src/utils.py",
|
||||
metadata={
|
||||
"language": "python",
|
||||
"file_path": "/project/src/utils.py",
|
||||
"file_name": "utils.py",
|
||||
"creation_date": "2024-01-15T10:30:00",
|
||||
"last_modified_date": "2024-10-31T15:45:00",
|
||||
},
|
||||
)
|
||||
|
||||
# Mock astchunk to return chunks with metadata
|
||||
mock_builder = Mock()
|
||||
astchunk_chunks = [
|
||||
{
|
||||
"content": "def calculate_sum(numbers):\n return sum(numbers)",
|
||||
"metadata": {
|
||||
"filepath": "/project/src/utils.py",
|
||||
"line_count": 2,
|
||||
"start_line_no": 1,
|
||||
"end_line_no": 2,
|
||||
"node_count": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
"content": "class DataProcessor:\n def process(self, data):\n return [x * 2 for x in data]",
|
||||
"metadata": {
|
||||
"filepath": "/project/src/utils.py",
|
||||
"line_count": 3,
|
||||
"start_line_no": 5,
|
||||
"end_line_no": 7,
|
||||
"node_count": 2,
|
||||
},
|
||||
},
|
||||
]
|
||||
mock_builder.chunkify.return_value = astchunk_chunks
|
||||
|
||||
mock_astchunk = Mock()
|
||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||
|
||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||
chunks = create_ast_chunks([doc])
|
||||
|
||||
# CRITICAL: These assertions will FAIL with current list[str] return type
|
||||
assert len(chunks) == 2, "Should return 2 chunks"
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Structure assertions - WILL FAIL: current code returns strings
|
||||
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
|
||||
assert "text" in chunk, f"Chunk {i} must have 'text' key"
|
||||
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
|
||||
assert isinstance(chunk["metadata"], dict), f"Chunk {i} metadata should be dict"
|
||||
|
||||
# Document metadata preservation - WILL FAIL
|
||||
metadata = chunk["metadata"]
|
||||
assert "file_path" in metadata, f"Chunk {i} should preserve file_path"
|
||||
assert metadata["file_path"] == "/project/src/utils.py", (
|
||||
f"Chunk {i} file_path incorrect"
|
||||
)
|
||||
|
||||
assert "file_name" in metadata, f"Chunk {i} should preserve file_name"
|
||||
assert metadata["file_name"] == "utils.py", f"Chunk {i} file_name incorrect"
|
||||
|
||||
assert "creation_date" in metadata, f"Chunk {i} should preserve creation_date"
|
||||
assert metadata["creation_date"] == "2024-01-15T10:30:00", (
|
||||
f"Chunk {i} creation_date incorrect"
|
||||
)
|
||||
|
||||
assert "last_modified_date" in metadata, f"Chunk {i} should preserve last_modified_date"
|
||||
assert metadata["last_modified_date"] == "2024-10-31T15:45:00", (
|
||||
f"Chunk {i} last_modified_date incorrect"
|
||||
)
|
||||
|
||||
# Verify metadata is consistent across chunks from same document
|
||||
assert chunks[0]["metadata"]["file_path"] == chunks[1]["metadata"]["file_path"], (
|
||||
"All chunks from same document should have same file_path"
|
||||
)
|
||||
|
||||
# Verify text content is present and not stringified
|
||||
assert "def calculate_sum" in chunks[0]["text"]
|
||||
assert "class DataProcessor" in chunks[1]["text"]
|
||||
|
||||
def test_ast_chunks_include_astchunk_metadata(self):
|
||||
"""Test that astchunk-specific metadata is merged into chunk metadata.
|
||||
|
||||
This test verifies that astchunk's metadata (line_count, start_line_no,
|
||||
end_line_no, node_count) is merged with document metadata.
|
||||
|
||||
This will FAIL because current code returns list[str], not list[dict].
|
||||
"""
|
||||
python_code = '''
|
||||
def function_one():
|
||||
"""First function."""
|
||||
x = 1
|
||||
y = 2
|
||||
return x + y
|
||||
|
||||
def function_two():
|
||||
"""Second function."""
|
||||
return 42
|
||||
'''
|
||||
doc = MockDocument(
|
||||
python_code,
|
||||
file_path="/test/code.py",
|
||||
metadata={
|
||||
"language": "python",
|
||||
"file_path": "/test/code.py",
|
||||
"file_name": "code.py",
|
||||
},
|
||||
)
|
||||
|
||||
# Mock astchunk with detailed metadata
|
||||
mock_builder = Mock()
|
||||
astchunk_chunks = [
|
||||
{
|
||||
"content": "def function_one():\n x = 1\n y = 2\n return x + y",
|
||||
"metadata": {
|
||||
"filepath": "/test/code.py",
|
||||
"line_count": 4,
|
||||
"start_line_no": 1,
|
||||
"end_line_no": 4,
|
||||
"node_count": 5, # function, assignments, return
|
||||
},
|
||||
},
|
||||
{
|
||||
"content": "def function_two():\n return 42",
|
||||
"metadata": {
|
||||
"filepath": "/test/code.py",
|
||||
"line_count": 2,
|
||||
"start_line_no": 7,
|
||||
"end_line_no": 8,
|
||||
"node_count": 2, # function, return
|
||||
},
|
||||
},
|
||||
]
|
||||
mock_builder.chunkify.return_value = astchunk_chunks
|
||||
|
||||
mock_astchunk = Mock()
|
||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||
|
||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||
chunks = create_ast_chunks([doc])
|
||||
|
||||
# CRITICAL: These will FAIL with current list[str] return
|
||||
assert len(chunks) == 2
|
||||
|
||||
# First chunk - function_one
|
||||
chunk1 = chunks[0]
|
||||
assert isinstance(chunk1, dict), "Chunk should be dict"
|
||||
assert "metadata" in chunk1
|
||||
|
||||
metadata1 = chunk1["metadata"]
|
||||
|
||||
# Check astchunk metadata is present
|
||||
assert "line_count" in metadata1, "Should include astchunk line_count"
|
||||
assert metadata1["line_count"] == 4, "line_count should be 4"
|
||||
|
||||
assert "start_line_no" in metadata1, "Should include astchunk start_line_no"
|
||||
assert metadata1["start_line_no"] == 1, "start_line_no should be 1"
|
||||
|
||||
assert "end_line_no" in metadata1, "Should include astchunk end_line_no"
|
||||
assert metadata1["end_line_no"] == 4, "end_line_no should be 4"
|
||||
|
||||
assert "node_count" in metadata1, "Should include astchunk node_count"
|
||||
assert metadata1["node_count"] == 5, "node_count should be 5"
|
||||
|
||||
# Second chunk - function_two
|
||||
chunk2 = chunks[1]
|
||||
metadata2 = chunk2["metadata"]
|
||||
|
||||
assert metadata2["line_count"] == 2, "line_count should be 2"
|
||||
assert metadata2["start_line_no"] == 7, "start_line_no should be 7"
|
||||
assert metadata2["end_line_no"] == 8, "end_line_no should be 8"
|
||||
assert metadata2["node_count"] == 2, "node_count should be 2"
|
||||
|
||||
# Verify document metadata is ALSO present (merged, not replaced)
|
||||
assert metadata1["file_path"] == "/test/code.py"
|
||||
assert metadata1["file_name"] == "code.py"
|
||||
assert metadata2["file_path"] == "/test/code.py"
|
||||
assert metadata2["file_name"] == "code.py"
|
||||
|
||||
# Verify text content is correct
|
||||
assert "def function_one" in chunk1["text"]
|
||||
assert "def function_two" in chunk2["text"]
|
||||
|
||||
def test_traditional_chunks_as_dicts_helper(self):
|
||||
"""Test the helper function that wraps traditional chunks as dicts.
|
||||
|
||||
This test verifies that when create_traditional_chunks is called,
|
||||
its plain string chunks are wrapped into dict format with metadata.
|
||||
|
||||
This will FAIL because the helper function _traditional_chunks_as_dicts()
|
||||
doesn't exist yet, and create_traditional_chunks returns list[str].
|
||||
"""
|
||||
# Create documents with various metadata
|
||||
docs = [
|
||||
MockDocument(
|
||||
"This is the first paragraph of text. It contains multiple sentences. "
|
||||
"This should be split into chunks based on size.",
|
||||
file_path="/docs/readme.txt",
|
||||
metadata={
|
||||
"file_path": "/docs/readme.txt",
|
||||
"file_name": "readme.txt",
|
||||
"creation_date": "2024-01-01",
|
||||
},
|
||||
),
|
||||
MockDocument(
|
||||
"Second document with different metadata. It also has content that needs chunking.",
|
||||
file_path="/docs/guide.md",
|
||||
metadata={
|
||||
"file_path": "/docs/guide.md",
|
||||
"file_name": "guide.md",
|
||||
"last_modified_date": "2024-10-31",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
# Call create_traditional_chunks (which should now return list[dict])
|
||||
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
|
||||
|
||||
# CRITICAL: Will FAIL - current code returns list[str]
|
||||
assert len(chunks) > 0, "Should return chunks"
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Structure assertions - WILL FAIL
|
||||
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
|
||||
assert "text" in chunk, f"Chunk {i} must have 'text' key"
|
||||
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
|
||||
|
||||
# Text should be non-empty
|
||||
assert len(chunk["text"].strip()) > 0, f"Chunk {i} text should be non-empty"
|
||||
|
||||
# Metadata should include document info
|
||||
metadata = chunk["metadata"]
|
||||
assert "file_path" in metadata, f"Chunk {i} should have file_path in metadata"
|
||||
assert "file_name" in metadata, f"Chunk {i} should have file_name in metadata"
|
||||
|
||||
# Verify metadata tracking works correctly
|
||||
# At least one chunk should be from readme.txt
|
||||
readme_chunks = [c for c in chunks if "readme.txt" in c["metadata"]["file_name"]]
|
||||
assert len(readme_chunks) > 0, "Should have chunks from readme.txt"
|
||||
|
||||
# At least one chunk should be from guide.md
|
||||
guide_chunks = [c for c in chunks if "guide.md" in c["metadata"]["file_name"]]
|
||||
assert len(guide_chunks) > 0, "Should have chunks from guide.md"
|
||||
|
||||
# Verify creation_date is preserved for readme chunks
|
||||
for chunk in readme_chunks:
|
||||
assert chunk["metadata"].get("creation_date") == "2024-01-01", (
|
||||
"readme.txt chunks should preserve creation_date"
|
||||
)
|
||||
|
||||
# Verify last_modified_date is preserved for guide chunks
|
||||
for chunk in guide_chunks:
|
||||
assert chunk["metadata"].get("last_modified_date") == "2024-10-31", (
|
||||
"guide.md chunks should preserve last_modified_date"
|
||||
)
|
||||
|
||||
# Verify text content is present
|
||||
all_text = " ".join([c["text"] for c in chunks])
|
||||
assert "first paragraph" in all_text
|
||||
assert "Second document" in all_text
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Test error handling and edge cases."""
|
||||
|
||||
|
||||
@@ -1,533 +0,0 @@
|
||||
"""
|
||||
Tests for CLI argument integration of --embedding-prompt-template.
|
||||
|
||||
These tests verify that:
|
||||
1. The --embedding-prompt-template flag is properly registered on build and search commands
|
||||
2. The template value flows from CLI args to embedding_options dict
|
||||
3. The template is passed through to compute_embeddings() function
|
||||
4. Default behavior (no flag) is handled correctly
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from leann.cli import LeannCLI
|
||||
|
||||
|
||||
class TestCLIPromptTemplateArgument:
|
||||
"""Tests for --embedding-prompt-template on build and search commands."""
|
||||
|
||||
def test_commands_accept_prompt_template_argument(self):
|
||||
"""Verify that build and search parsers accept --embedding-prompt-template flag."""
|
||||
cli = LeannCLI()
|
||||
parser = cli.create_parser()
|
||||
template_value = "search_query: "
|
||||
|
||||
# Test build command
|
||||
build_args = parser.parse_args(
|
||||
[
|
||||
"build",
|
||||
"test-index",
|
||||
"--docs",
|
||||
"/tmp/test-docs",
|
||||
"--embedding-prompt-template",
|
||||
template_value,
|
||||
]
|
||||
)
|
||||
assert build_args.command == "build"
|
||||
assert hasattr(build_args, "embedding_prompt_template"), (
|
||||
"build command should have embedding_prompt_template attribute"
|
||||
)
|
||||
assert build_args.embedding_prompt_template == template_value
|
||||
|
||||
# Test search command
|
||||
search_args = parser.parse_args(
|
||||
["search", "test-index", "my query", "--embedding-prompt-template", template_value]
|
||||
)
|
||||
assert search_args.command == "search"
|
||||
assert hasattr(search_args, "embedding_prompt_template"), (
|
||||
"search command should have embedding_prompt_template attribute"
|
||||
)
|
||||
assert search_args.embedding_prompt_template == template_value
|
||||
|
||||
def test_commands_default_to_none(self):
|
||||
"""Verify default value is None when flag not provided (backward compatibility)."""
|
||||
cli = LeannCLI()
|
||||
parser = cli.create_parser()
|
||||
|
||||
# Test build command default
|
||||
build_args = parser.parse_args(["build", "test-index", "--docs", "/tmp/test-docs"])
|
||||
assert hasattr(build_args, "embedding_prompt_template"), (
|
||||
"build command should have embedding_prompt_template attribute"
|
||||
)
|
||||
assert build_args.embedding_prompt_template is None, (
|
||||
"Build default value should be None when flag not provided"
|
||||
)
|
||||
|
||||
# Test search command default
|
||||
search_args = parser.parse_args(["search", "test-index", "my query"])
|
||||
assert hasattr(search_args, "embedding_prompt_template"), (
|
||||
"search command should have embedding_prompt_template attribute"
|
||||
)
|
||||
assert search_args.embedding_prompt_template is None, (
|
||||
"Search default value should be None when flag not provided"
|
||||
)
|
||||
|
||||
|
||||
class TestBuildCommandPromptTemplateArgumentExtras:
|
||||
"""Additional build-specific tests for prompt template argument."""
|
||||
|
||||
def test_build_command_prompt_template_with_multiword_value(self):
|
||||
"""
|
||||
Verify that template values with spaces are handled correctly.
|
||||
|
||||
Templates like "search_document: " or "Represent this sentence for searching: "
|
||||
should be accepted as a single string argument.
|
||||
"""
|
||||
cli = LeannCLI()
|
||||
parser = cli.create_parser()
|
||||
|
||||
template = "Represent this sentence for searching: "
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"build",
|
||||
"test-index",
|
||||
"--docs",
|
||||
"/tmp/test-docs",
|
||||
"--embedding-prompt-template",
|
||||
template,
|
||||
]
|
||||
)
|
||||
|
||||
assert args.embedding_prompt_template == template
|
||||
|
||||
|
||||
class TestPromptTemplateStoredInEmbeddingOptions:
|
||||
"""Tests for template storage in embedding_options dict."""
|
||||
|
||||
@patch("leann.cli.LeannBuilder")
|
||||
def test_prompt_template_stored_in_embedding_options_on_build(
|
||||
self, mock_builder_class, tmp_path
|
||||
):
|
||||
"""
|
||||
Verify that when --embedding-prompt-template is provided to build command,
|
||||
the value is stored in embedding_options dict passed to LeannBuilder.
|
||||
|
||||
This test will fail because the CLI doesn't currently process this argument
|
||||
and add it to embedding_options.
|
||||
"""
|
||||
# Setup mocks
|
||||
mock_builder = Mock()
|
||||
mock_builder_class.return_value = mock_builder
|
||||
|
||||
# Create CLI and run build command
|
||||
cli = LeannCLI()
|
||||
|
||||
# Mock load_documents to return a document so builder is created
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||
|
||||
parser = cli.create_parser()
|
||||
|
||||
template = "search_query: "
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"build",
|
||||
"test-index",
|
||||
"--docs",
|
||||
str(tmp_path),
|
||||
"--embedding-prompt-template",
|
||||
template,
|
||||
"--force", # Force rebuild to ensure LeannBuilder is called
|
||||
]
|
||||
)
|
||||
|
||||
# Run the build command
|
||||
import asyncio
|
||||
|
||||
asyncio.run(cli.build_index(args))
|
||||
|
||||
# Check that LeannBuilder was called with embedding_options containing prompt_template
|
||||
call_kwargs = mock_builder_class.call_args.kwargs
|
||||
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
|
||||
|
||||
embedding_options = call_kwargs["embedding_options"]
|
||||
assert embedding_options is not None, (
|
||||
"embedding_options should not be None when template provided"
|
||||
)
|
||||
assert "prompt_template" in embedding_options, (
|
||||
"embedding_options should contain 'prompt_template' key"
|
||||
)
|
||||
assert embedding_options["prompt_template"] == template, (
|
||||
f"Template should be '{template}', got {embedding_options.get('prompt_template')}"
|
||||
)
|
||||
|
||||
@patch("leann.cli.LeannBuilder")
|
||||
def test_prompt_template_not_in_options_when_not_provided(self, mock_builder_class, tmp_path):
|
||||
"""
|
||||
Verify that when --embedding-prompt-template is NOT provided,
|
||||
embedding_options either doesn't have the key or it's None.
|
||||
|
||||
This ensures we don't pass empty/None values unnecessarily.
|
||||
"""
|
||||
# Setup mocks
|
||||
mock_builder = Mock()
|
||||
mock_builder_class.return_value = mock_builder
|
||||
|
||||
cli = LeannCLI()
|
||||
|
||||
# Mock load_documents to return a document so builder is created
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||
|
||||
parser = cli.create_parser()
|
||||
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"build",
|
||||
"test-index",
|
||||
"--docs",
|
||||
str(tmp_path),
|
||||
"--force", # Force rebuild to ensure LeannBuilder is called
|
||||
]
|
||||
)
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(cli.build_index(args))
|
||||
|
||||
# Check that if embedding_options is passed, it doesn't have prompt_template
|
||||
call_kwargs = mock_builder_class.call_args.kwargs
|
||||
if call_kwargs.get("embedding_options"):
|
||||
embedding_options = call_kwargs["embedding_options"]
|
||||
# Either the key shouldn't exist, or it should be None
|
||||
assert (
|
||||
"prompt_template" not in embedding_options
|
||||
or embedding_options["prompt_template"] is None
|
||||
), "prompt_template should not be set when flag not provided"
|
||||
|
||||
# R1 Tests: Build-time separate template storage
|
||||
@patch("leann.cli.LeannBuilder")
|
||||
def test_build_stores_separate_templates(self, mock_builder_class, tmp_path):
|
||||
"""
|
||||
R1 Test 1: Verify that when both --embedding-prompt-template and
|
||||
--query-prompt-template are provided to build command, both values
|
||||
are stored separately in embedding_options dict as build_prompt_template
|
||||
and query_prompt_template.
|
||||
|
||||
This test will fail because:
|
||||
1. CLI doesn't accept --query-prompt-template flag yet
|
||||
2. CLI doesn't store templates as separate build_prompt_template and
|
||||
query_prompt_template keys
|
||||
|
||||
Expected behavior after implementation:
|
||||
- .meta.json contains: {"embedding_options": {
|
||||
"build_prompt_template": "doc: ",
|
||||
"query_prompt_template": "query: "
|
||||
}}
|
||||
"""
|
||||
# Setup mocks
|
||||
mock_builder = Mock()
|
||||
mock_builder_class.return_value = mock_builder
|
||||
|
||||
cli = LeannCLI()
|
||||
|
||||
# Mock load_documents to return a document so builder is created
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||
|
||||
parser = cli.create_parser()
|
||||
|
||||
build_template = "doc: "
|
||||
query_template = "query: "
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"build",
|
||||
"test-index",
|
||||
"--docs",
|
||||
str(tmp_path),
|
||||
"--embedding-prompt-template",
|
||||
build_template,
|
||||
"--query-prompt-template",
|
||||
query_template,
|
||||
"--force",
|
||||
]
|
||||
)
|
||||
|
||||
# Run the build command
|
||||
import asyncio
|
||||
|
||||
asyncio.run(cli.build_index(args))
|
||||
|
||||
# Check that LeannBuilder was called with separate template keys
|
||||
call_kwargs = mock_builder_class.call_args.kwargs
|
||||
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
|
||||
|
||||
embedding_options = call_kwargs["embedding_options"]
|
||||
assert embedding_options is not None, (
|
||||
"embedding_options should not be None when templates provided"
|
||||
)
|
||||
|
||||
assert "build_prompt_template" in embedding_options, (
|
||||
"embedding_options should contain 'build_prompt_template' key"
|
||||
)
|
||||
assert embedding_options["build_prompt_template"] == build_template, (
|
||||
f"build_prompt_template should be '{build_template}'"
|
||||
)
|
||||
|
||||
assert "query_prompt_template" in embedding_options, (
|
||||
"embedding_options should contain 'query_prompt_template' key"
|
||||
)
|
||||
assert embedding_options["query_prompt_template"] == query_template, (
|
||||
f"query_prompt_template should be '{query_template}'"
|
||||
)
|
||||
|
||||
# Old key should NOT be present when using new separate template format
|
||||
assert "prompt_template" not in embedding_options, (
|
||||
"Old 'prompt_template' key should not be present with separate templates"
|
||||
)
|
||||
|
||||
@patch("leann.cli.LeannBuilder")
|
||||
def test_build_backward_compat_single_template(self, mock_builder_class, tmp_path):
|
||||
"""
|
||||
R1 Test 2: Verify backward compatibility - when only
|
||||
--embedding-prompt-template is provided (old behavior), it should
|
||||
still be stored as 'prompt_template' in embedding_options.
|
||||
|
||||
This ensures existing workflows continue to work unchanged.
|
||||
|
||||
This test currently passes because it matches existing behavior, but it
|
||||
documents the requirement that this behavior must be preserved after
|
||||
implementing the separate template feature.
|
||||
|
||||
Expected behavior:
|
||||
- .meta.json contains: {"embedding_options": {"prompt_template": "prompt: "}}
|
||||
- No build_prompt_template or query_prompt_template keys
|
||||
"""
|
||||
# Setup mocks
|
||||
mock_builder = Mock()
|
||||
mock_builder_class.return_value = mock_builder
|
||||
|
||||
cli = LeannCLI()
|
||||
|
||||
# Mock load_documents to return a document so builder is created
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||
|
||||
parser = cli.create_parser()
|
||||
|
||||
template = "prompt: "
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"build",
|
||||
"test-index",
|
||||
"--docs",
|
||||
str(tmp_path),
|
||||
"--embedding-prompt-template",
|
||||
template,
|
||||
"--force",
|
||||
]
|
||||
)
|
||||
|
||||
# Run the build command
|
||||
import asyncio
|
||||
|
||||
asyncio.run(cli.build_index(args))
|
||||
|
||||
# Check that LeannBuilder was called with old format
|
||||
call_kwargs = mock_builder_class.call_args.kwargs
|
||||
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
|
||||
|
||||
embedding_options = call_kwargs["embedding_options"]
|
||||
assert embedding_options is not None, (
|
||||
"embedding_options should not be None when template provided"
|
||||
)
|
||||
|
||||
assert "prompt_template" in embedding_options, (
|
||||
"embedding_options should contain old 'prompt_template' key for backward compat"
|
||||
)
|
||||
assert embedding_options["prompt_template"] == template, (
|
||||
f"prompt_template should be '{template}'"
|
||||
)
|
||||
|
||||
# New keys should NOT be present in backward compat mode
|
||||
assert "build_prompt_template" not in embedding_options, (
|
||||
"build_prompt_template should not be present with single template flag"
|
||||
)
|
||||
assert "query_prompt_template" not in embedding_options, (
|
||||
"query_prompt_template should not be present with single template flag"
|
||||
)
|
||||
|
||||
@patch("leann.cli.LeannBuilder")
|
||||
def test_build_no_templates(self, mock_builder_class, tmp_path):
|
||||
"""
|
||||
R1 Test 3: Verify that when no template flags are provided,
|
||||
embedding_options has no prompt template keys.
|
||||
|
||||
This ensures clean defaults and no unnecessary keys in .meta.json.
|
||||
|
||||
This test currently passes because it matches existing behavior, but it
|
||||
documents the requirement that this behavior must be preserved after
|
||||
implementing the separate template feature.
|
||||
|
||||
Expected behavior:
|
||||
- .meta.json has no prompt_template, build_prompt_template, or
|
||||
query_prompt_template keys (or embedding_options is empty/None)
|
||||
"""
|
||||
# Setup mocks
|
||||
mock_builder = Mock()
|
||||
mock_builder_class.return_value = mock_builder
|
||||
|
||||
cli = LeannCLI()
|
||||
|
||||
# Mock load_documents to return a document so builder is created
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||
|
||||
parser = cli.create_parser()
|
||||
|
||||
args = parser.parse_args(["build", "test-index", "--docs", str(tmp_path), "--force"])
|
||||
|
||||
# Run the build command
|
||||
import asyncio
|
||||
|
||||
asyncio.run(cli.build_index(args))
|
||||
|
||||
# Check that no template keys are present
|
||||
call_kwargs = mock_builder_class.call_args.kwargs
|
||||
if call_kwargs.get("embedding_options"):
|
||||
embedding_options = call_kwargs["embedding_options"]
|
||||
|
||||
# None of the template keys should be present
|
||||
assert "prompt_template" not in embedding_options, (
|
||||
"prompt_template should not be present when no flags provided"
|
||||
)
|
||||
assert "build_prompt_template" not in embedding_options, (
|
||||
"build_prompt_template should not be present when no flags provided"
|
||||
)
|
||||
assert "query_prompt_template" not in embedding_options, (
|
||||
"query_prompt_template should not be present when no flags provided"
|
||||
)
|
||||
|
||||
|
||||
class TestPromptTemplateFlowsToComputeEmbeddings:
|
||||
"""Tests for template flowing through to compute_embeddings function."""
|
||||
|
||||
@patch("leann.api.compute_embeddings")
|
||||
def test_prompt_template_flows_to_compute_embeddings_via_provider_options(
|
||||
self, mock_compute_embeddings, tmp_path
|
||||
):
|
||||
"""
|
||||
Verify that the prompt template flows from CLI args through LeannBuilder
|
||||
to compute_embeddings() function via provider_options parameter.
|
||||
|
||||
This is an integration test that verifies the complete flow:
|
||||
CLI → embedding_options → LeannBuilder → compute_embeddings(provider_options)
|
||||
|
||||
This test will fail because:
|
||||
1. CLI doesn't capture the argument yet
|
||||
2. embedding_options doesn't include prompt_template
|
||||
3. LeannBuilder doesn't pass it through to compute_embeddings
|
||||
"""
|
||||
# Mock compute_embeddings to return dummy embeddings as numpy array
|
||||
import numpy as np
|
||||
|
||||
mock_compute_embeddings.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||
|
||||
# Use real LeannBuilder (not mocked) to test the actual flow
|
||||
cli = LeannCLI()
|
||||
|
||||
# Mock load_documents to return a simple document
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||
|
||||
parser = cli.create_parser()
|
||||
|
||||
template = "search_document: "
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"build",
|
||||
"test-index",
|
||||
"--docs",
|
||||
str(tmp_path),
|
||||
"--embedding-prompt-template",
|
||||
template,
|
||||
"--backend-name",
|
||||
"hnsw", # Use hnsw backend
|
||||
"--force", # Force rebuild to ensure index is created
|
||||
]
|
||||
)
|
||||
|
||||
# This should fail because the flow isn't implemented yet
|
||||
import asyncio
|
||||
|
||||
asyncio.run(cli.build_index(args))
|
||||
|
||||
# Verify compute_embeddings was called with provider_options containing prompt_template
|
||||
assert mock_compute_embeddings.called, "compute_embeddings should have been called"
|
||||
|
||||
# Check the call arguments
|
||||
call_kwargs = mock_compute_embeddings.call_args.kwargs
|
||||
assert "provider_options" in call_kwargs, (
|
||||
"compute_embeddings should receive provider_options parameter"
|
||||
)
|
||||
|
||||
provider_options = call_kwargs["provider_options"]
|
||||
assert provider_options is not None, "provider_options should not be None"
|
||||
assert "prompt_template" in provider_options, (
|
||||
"provider_options should contain prompt_template key"
|
||||
)
|
||||
assert provider_options["prompt_template"] == template, (
|
||||
f"Template should be '{template}', got {provider_options.get('prompt_template')}"
|
||||
)
|
||||
|
||||
|
||||
class TestPromptTemplateArgumentHelp:
|
||||
"""Tests for argument help text and documentation."""
|
||||
|
||||
def test_build_command_prompt_template_has_help_text(self):
|
||||
"""
|
||||
Verify that --embedding-prompt-template has descriptive help text.
|
||||
|
||||
Good help text is crucial for CLI usability.
|
||||
"""
|
||||
cli = LeannCLI()
|
||||
parser = cli.create_parser()
|
||||
|
||||
# Get the build subparser
|
||||
# This is a bit tricky - we need to parse to get the help
|
||||
# We'll check that the help includes relevant keywords
|
||||
import io
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
f = io.StringIO()
|
||||
try:
|
||||
with redirect_stdout(f):
|
||||
parser.parse_args(["build", "--help"])
|
||||
except SystemExit:
|
||||
pass # --help causes sys.exit(0)
|
||||
|
||||
help_text = f.getvalue()
|
||||
assert "--embedding-prompt-template" in help_text, (
|
||||
"Help text should mention --embedding-prompt-template"
|
||||
)
|
||||
# Check for keywords that should be in the help
|
||||
help_lower = help_text.lower()
|
||||
assert any(keyword in help_lower for keyword in ["template", "prompt", "prepend"]), (
|
||||
"Help text should explain what the prompt template does"
|
||||
)
|
||||
|
||||
def test_search_command_prompt_template_has_help_text(self):
|
||||
"""
|
||||
Verify that search command also has help text for --embedding-prompt-template.
|
||||
"""
|
||||
cli = LeannCLI()
|
||||
parser = cli.create_parser()
|
||||
|
||||
import io
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
f = io.StringIO()
|
||||
try:
|
||||
with redirect_stdout(f):
|
||||
parser.parse_args(["search", "--help"])
|
||||
except SystemExit:
|
||||
pass # --help causes sys.exit(0)
|
||||
|
||||
help_text = f.getvalue()
|
||||
assert "--embedding-prompt-template" in help_text, (
|
||||
"Search help text should mention --embedding-prompt-template"
|
||||
)
|
||||
@@ -1,281 +0,0 @@
|
||||
"""Unit tests for prompt template prepending in OpenAI embeddings.
|
||||
|
||||
This test suite defines the contract for prompt template functionality that allows
|
||||
users to prepend a consistent prompt to all embedding inputs. These tests verify:
|
||||
|
||||
1. Template prepending to all input texts before embedding computation
|
||||
2. Graceful handling of None/missing provider_options
|
||||
3. Empty string template behavior (no-op)
|
||||
4. Logging of template application for observability
|
||||
5. Template application before token truncation
|
||||
|
||||
All tests are written in Red Phase - they should FAIL initially because the
|
||||
implementation does not exist yet.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from leann.embedding_compute import compute_embeddings_openai
|
||||
|
||||
|
||||
class TestPromptTemplatePrepending:
|
||||
"""Tests for prompt template prepending in compute_embeddings_openai."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_client(self):
|
||||
"""Create mock OpenAI client that captures input texts."""
|
||||
mock_client = MagicMock()
|
||||
|
||||
# Mock the embeddings.create response
|
||||
mock_response = Mock()
|
||||
mock_response.data = [
|
||||
Mock(embedding=[0.1, 0.2, 0.3]),
|
||||
Mock(embedding=[0.4, 0.5, 0.6]),
|
||||
]
|
||||
mock_client.embeddings.create.return_value = mock_response
|
||||
|
||||
return mock_client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_module(self, mock_openai_client, monkeypatch):
|
||||
"""Mock the openai module to return our mock client."""
|
||||
# Mock the API key environment variable
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "fake-test-key-for-mocking")
|
||||
|
||||
# openai is imported inside the function, so we need to patch it there
|
||||
with patch("openai.OpenAI", return_value=mock_openai_client) as mock_openai:
|
||||
yield mock_openai
|
||||
|
||||
def test_prompt_template_prepended_to_all_texts(self, mock_openai_module, mock_openai_client):
|
||||
"""Verify template is prepended to all input texts.
|
||||
|
||||
When provider_options contains "prompt_template", that template should
|
||||
be prepended to every text in the input list before sending to OpenAI API.
|
||||
|
||||
This is the core functionality: the template acts as a consistent prefix
|
||||
that provides context or instruction for the embedding model.
|
||||
"""
|
||||
texts = ["First document", "Second document"]
|
||||
template = "search_document: "
|
||||
provider_options = {"prompt_template": template}
|
||||
|
||||
# Call compute_embeddings_openai with provider_options
|
||||
result = compute_embeddings_openai(
|
||||
texts=texts,
|
||||
model_name="text-embedding-3-small",
|
||||
provider_options=provider_options,
|
||||
)
|
||||
|
||||
# Verify embeddings.create was called with templated texts
|
||||
mock_openai_client.embeddings.create.assert_called_once()
|
||||
call_args = mock_openai_client.embeddings.create.call_args
|
||||
|
||||
# Extract the input texts sent to API
|
||||
sent_texts = call_args.kwargs["input"]
|
||||
|
||||
# Verify template was prepended to all texts
|
||||
assert len(sent_texts) == 2, "Should send same number of texts"
|
||||
assert sent_texts[0] == "search_document: First document", (
|
||||
"Template should be prepended to first text"
|
||||
)
|
||||
assert sent_texts[1] == "search_document: Second document", (
|
||||
"Template should be prepended to second text"
|
||||
)
|
||||
|
||||
# Verify result is valid embeddings array
|
||||
assert isinstance(result, np.ndarray)
|
||||
assert result.shape == (2, 3), "Should return correct shape"
|
||||
|
||||
def test_template_not_applied_when_missing_or_empty(
|
||||
self, mock_openai_module, mock_openai_client
|
||||
):
|
||||
"""Verify template not applied when provider_options is None, missing key, or empty string.
|
||||
|
||||
This consolidated test covers three scenarios where templates should NOT be applied:
|
||||
1. provider_options is None (default behavior)
|
||||
2. provider_options exists but missing 'prompt_template' key
|
||||
3. prompt_template is explicitly set to empty string ""
|
||||
|
||||
In all cases, texts should be sent to the API unchanged.
|
||||
"""
|
||||
# Scenario 1: None provider_options
|
||||
texts = ["Original text one", "Original text two"]
|
||||
result = compute_embeddings_openai(
|
||||
texts=texts,
|
||||
model_name="text-embedding-3-small",
|
||||
provider_options=None,
|
||||
)
|
||||
call_args = mock_openai_client.embeddings.create.call_args
|
||||
sent_texts = call_args.kwargs["input"]
|
||||
assert sent_texts[0] == "Original text one", (
|
||||
"Text should be unchanged with None provider_options"
|
||||
)
|
||||
assert sent_texts[1] == "Original text two"
|
||||
assert isinstance(result, np.ndarray)
|
||||
assert result.shape == (2, 3)
|
||||
|
||||
# Reset mock for next scenario
|
||||
mock_openai_client.reset_mock()
|
||||
mock_response = Mock()
|
||||
mock_response.data = [
|
||||
Mock(embedding=[0.1, 0.2, 0.3]),
|
||||
Mock(embedding=[0.4, 0.5, 0.6]),
|
||||
]
|
||||
mock_openai_client.embeddings.create.return_value = mock_response
|
||||
|
||||
# Scenario 2: Missing 'prompt_template' key
|
||||
texts = ["Text without template", "Another text"]
|
||||
provider_options = {"base_url": "https://api.openai.com/v1"}
|
||||
result = compute_embeddings_openai(
|
||||
texts=texts,
|
||||
model_name="text-embedding-3-small",
|
||||
provider_options=provider_options,
|
||||
)
|
||||
call_args = mock_openai_client.embeddings.create.call_args
|
||||
sent_texts = call_args.kwargs["input"]
|
||||
assert sent_texts[0] == "Text without template", "Text should be unchanged with missing key"
|
||||
assert sent_texts[1] == "Another text"
|
||||
assert isinstance(result, np.ndarray)
|
||||
|
||||
# Reset mock for next scenario
|
||||
mock_openai_client.reset_mock()
|
||||
mock_openai_client.embeddings.create.return_value = mock_response
|
||||
|
||||
# Scenario 3: Empty string template
|
||||
texts = ["Text one", "Text two"]
|
||||
provider_options = {"prompt_template": ""}
|
||||
result = compute_embeddings_openai(
|
||||
texts=texts,
|
||||
model_name="text-embedding-3-small",
|
||||
provider_options=provider_options,
|
||||
)
|
||||
call_args = mock_openai_client.embeddings.create.call_args
|
||||
sent_texts = call_args.kwargs["input"]
|
||||
assert sent_texts[0] == "Text one", "Empty template should not modify text"
|
||||
assert sent_texts[1] == "Text two"
|
||||
assert isinstance(result, np.ndarray)
|
||||
|
||||
def test_prompt_template_with_multiple_batches(self, mock_openai_module, mock_openai_client):
|
||||
"""Verify template is prepended in all batches when texts exceed batch size.
|
||||
|
||||
OpenAI API has batch size limits. When input texts are split into
|
||||
multiple batches, the template should be prepended to texts in every batch.
|
||||
|
||||
This ensures consistency across all API calls.
|
||||
"""
|
||||
# Create many texts that will be split into multiple batches
|
||||
texts = [f"Document {i}" for i in range(1000)]
|
||||
template = "passage: "
|
||||
provider_options = {"prompt_template": template}
|
||||
|
||||
# Mock multiple batch responses
|
||||
mock_response = Mock()
|
||||
mock_response.data = [Mock(embedding=[0.1, 0.2, 0.3]) for _ in range(1000)]
|
||||
mock_openai_client.embeddings.create.return_value = mock_response
|
||||
|
||||
result = compute_embeddings_openai(
|
||||
texts=texts,
|
||||
model_name="text-embedding-3-small",
|
||||
provider_options=provider_options,
|
||||
)
|
||||
|
||||
# Verify embeddings.create was called multiple times (batching)
|
||||
assert mock_openai_client.embeddings.create.call_count >= 2, (
|
||||
"Should make multiple API calls for large text list"
|
||||
)
|
||||
|
||||
# Verify template was prepended in ALL batches
|
||||
for call in mock_openai_client.embeddings.create.call_args_list:
|
||||
sent_texts = call.kwargs["input"]
|
||||
for text in sent_texts:
|
||||
assert text.startswith(template), (
|
||||
f"All texts in all batches should start with template. Got: {text}"
|
||||
)
|
||||
|
||||
# Verify result shape
|
||||
assert result.shape[0] == 1000, "Should return embeddings for all texts"
|
||||
|
||||
def test_prompt_template_with_special_characters(self, mock_openai_module, mock_openai_client):
|
||||
"""Verify template with special characters is handled correctly.
|
||||
|
||||
Templates may contain special characters, Unicode, newlines, etc.
|
||||
These should all be prepended correctly without encoding issues.
|
||||
"""
|
||||
texts = ["Document content"]
|
||||
# Template with various special characters
|
||||
template = "🔍 Search query [EN]: "
|
||||
provider_options = {"prompt_template": template}
|
||||
|
||||
result = compute_embeddings_openai(
|
||||
texts=texts,
|
||||
model_name="text-embedding-3-small",
|
||||
provider_options=provider_options,
|
||||
)
|
||||
|
||||
# Verify special characters in template were preserved
|
||||
call_args = mock_openai_client.embeddings.create.call_args
|
||||
sent_texts = call_args.kwargs["input"]
|
||||
|
||||
assert sent_texts[0] == "🔍 Search query [EN]: Document content", (
|
||||
"Special characters in template should be preserved"
|
||||
)
|
||||
|
||||
assert isinstance(result, np.ndarray)
|
||||
|
||||
def test_prompt_template_integration_with_existing_validation(
|
||||
self, mock_openai_module, mock_openai_client
|
||||
):
|
||||
"""Verify template works with existing input validation.
|
||||
|
||||
compute_embeddings_openai has validation for empty texts and whitespace.
|
||||
Template prepending should happen AFTER validation, so validation errors
|
||||
are thrown based on original texts, not templated texts.
|
||||
|
||||
This ensures users get clear error messages about their input.
|
||||
"""
|
||||
# Empty text should still raise ValueError even with template
|
||||
texts = [""]
|
||||
provider_options = {"prompt_template": "prefix: "}
|
||||
|
||||
with pytest.raises(ValueError, match="empty/invalid"):
|
||||
compute_embeddings_openai(
|
||||
texts=texts,
|
||||
model_name="text-embedding-3-small",
|
||||
provider_options=provider_options,
|
||||
)
|
||||
|
||||
def test_prompt_template_with_api_key_and_base_url(
|
||||
self, mock_openai_module, mock_openai_client
|
||||
):
|
||||
"""Verify template works alongside other provider_options.
|
||||
|
||||
provider_options may contain multiple settings: prompt_template,
|
||||
base_url, api_key. All should work together correctly.
|
||||
"""
|
||||
texts = ["Test document"]
|
||||
provider_options = {
|
||||
"prompt_template": "embed: ",
|
||||
"base_url": "https://custom.api.com/v1",
|
||||
"api_key": "test-key-123",
|
||||
}
|
||||
|
||||
result = compute_embeddings_openai(
|
||||
texts=texts,
|
||||
model_name="text-embedding-3-small",
|
||||
provider_options=provider_options,
|
||||
)
|
||||
|
||||
# Verify template was applied
|
||||
call_args = mock_openai_client.embeddings.create.call_args
|
||||
sent_texts = call_args.kwargs["input"]
|
||||
assert sent_texts[0] == "embed: Test document"
|
||||
|
||||
# Verify OpenAI client was created with correct base_url
|
||||
mock_openai_module.assert_called()
|
||||
client_init_kwargs = mock_openai_module.call_args.kwargs
|
||||
assert client_init_kwargs["base_url"] == "https://custom.api.com/v1"
|
||||
assert client_init_kwargs["api_key"] == "test-key-123"
|
||||
|
||||
assert isinstance(result, np.ndarray)
|
||||
@@ -1,315 +0,0 @@
|
||||
"""Unit tests for LM Studio TypeScript SDK bridge functionality.
|
||||
|
||||
This test suite defines the contract for the LM Studio SDK bridge that queries
|
||||
model context length via Node.js subprocess. These tests verify:
|
||||
|
||||
1. Successful SDK query returns context length
|
||||
2. Graceful fallback when Node.js not installed (FileNotFoundError)
|
||||
3. Graceful fallback when SDK not installed (npm error)
|
||||
4. Timeout handling (subprocess.TimeoutExpired)
|
||||
5. Invalid JSON response handling
|
||||
|
||||
All tests are written in Red Phase - they should FAIL initially because the
|
||||
`_query_lmstudio_context_limit` function does not exist yet.
|
||||
|
||||
The function contract:
|
||||
- Inputs: model_name (str), base_url (str, WebSocket format "ws://localhost:1234")
|
||||
- Outputs: context_length (int) or None on error
|
||||
- Requirements:
|
||||
1. Call Node.js with inline JavaScript using @lmstudio/sdk
|
||||
2. 10-second timeout (accounts for Node.js startup)
|
||||
3. Graceful fallback on any error (returns None, doesn't raise)
|
||||
4. Parse JSON response with contextLength field
|
||||
5. Log errors at debug level (not warning/error)
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
# Try to import the function - if it doesn't exist, tests will fail as expected
|
||||
try:
|
||||
from leann.embedding_compute import _query_lmstudio_context_limit
|
||||
except ImportError:
|
||||
# Function doesn't exist yet (Red Phase) - create a placeholder that will fail
|
||||
def _query_lmstudio_context_limit(*args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"_query_lmstudio_context_limit not implemented yet - this is the Red Phase"
|
||||
)
|
||||
|
||||
|
||||
class TestLMStudioBridge:
|
||||
"""Tests for LM Studio TypeScript SDK bridge integration."""
|
||||
|
||||
def test_query_lmstudio_success(self, monkeypatch):
|
||||
"""Verify successful SDK query returns context length.
|
||||
|
||||
When the Node.js subprocess successfully queries the LM Studio SDK,
|
||||
it should return a JSON response with contextLength field. The function
|
||||
should parse this and return the integer context length.
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
# Verify timeout is set to 10 seconds
|
||||
assert kwargs.get("timeout") == 10, "Should use 10-second timeout for Node.js startup"
|
||||
|
||||
# Verify capture_output and text=True are set
|
||||
assert kwargs.get("capture_output") is True, "Should capture stdout/stderr"
|
||||
assert kwargs.get("text") is True, "Should decode output as text"
|
||||
|
||||
# Return successful JSON response
|
||||
mock_result = Mock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = '{"contextLength": 8192, "identifier": "custom-model"}'
|
||||
mock_result.stderr = ""
|
||||
return mock_result
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
# Test with typical LM Studio model
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="custom-model", base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
assert limit == 8192, "Should return context length from SDK response"
|
||||
|
||||
def test_query_lmstudio_nodejs_not_found(self, monkeypatch):
|
||||
"""Verify graceful fallback when Node.js not installed.
|
||||
|
||||
When Node.js is not installed, subprocess.run will raise FileNotFoundError.
|
||||
The function should catch this and return None (graceful fallback to registry).
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
raise FileNotFoundError("node: command not found")
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="custom-model", base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
assert limit is None, "Should return None when Node.js not installed"
|
||||
|
||||
def test_query_lmstudio_sdk_not_installed(self, monkeypatch):
|
||||
"""Verify graceful fallback when @lmstudio/sdk not installed.
|
||||
|
||||
When the SDK npm package is not installed, Node.js will return non-zero
|
||||
exit code with error message in stderr. The function should detect this
|
||||
and return None.
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
mock_result = Mock()
|
||||
mock_result.returncode = 1
|
||||
mock_result.stdout = ""
|
||||
mock_result.stderr = (
|
||||
"Error: Cannot find module '@lmstudio/sdk'\nRequire stack:\n- /path/to/script.js"
|
||||
)
|
||||
return mock_result
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="custom-model", base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
assert limit is None, "Should return None when SDK not installed"
|
||||
|
||||
def test_query_lmstudio_timeout(self, monkeypatch):
|
||||
"""Verify graceful fallback when subprocess times out.
|
||||
|
||||
When the Node.js process takes longer than 10 seconds (e.g., LM Studio
|
||||
not responding), subprocess.TimeoutExpired should be raised. The function
|
||||
should catch this and return None.
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
raise subprocess.TimeoutExpired(cmd=["node", "lmstudio_bridge.js"], timeout=10)
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="custom-model", base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
assert limit is None, "Should return None on timeout"
|
||||
|
||||
def test_query_lmstudio_invalid_json(self, monkeypatch):
|
||||
"""Verify graceful fallback when response is invalid JSON.
|
||||
|
||||
When the subprocess returns malformed JSON (e.g., due to SDK error),
|
||||
json.loads will raise ValueError/JSONDecodeError. The function should
|
||||
catch this and return None.
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
mock_result = Mock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = "This is not valid JSON"
|
||||
mock_result.stderr = ""
|
||||
return mock_result
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="custom-model", base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
assert limit is None, "Should return None when JSON parsing fails"
|
||||
|
||||
def test_query_lmstudio_missing_context_length_field(self, monkeypatch):
|
||||
"""Verify graceful fallback when JSON lacks contextLength field.
|
||||
|
||||
When the SDK returns valid JSON but without the expected contextLength
|
||||
field (e.g., error response), the function should return None.
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
mock_result = Mock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = '{"identifier": "test-model", "error": "Model not found"}'
|
||||
mock_result.stderr = ""
|
||||
return mock_result
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="nonexistent-model", base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
assert limit is None, "Should return None when contextLength field missing"
|
||||
|
||||
def test_query_lmstudio_null_context_length(self, monkeypatch):
|
||||
"""Verify graceful fallback when contextLength is null.
|
||||
|
||||
When the SDK returns contextLength: null (model couldn't be loaded),
|
||||
the function should return None for registry fallback.
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
mock_result = Mock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = '{"contextLength": null, "identifier": "test-model"}'
|
||||
mock_result.stderr = ""
|
||||
return mock_result
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="test-model", base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
assert limit is None, "Should return None when contextLength is null"
|
||||
|
||||
def test_query_lmstudio_zero_context_length(self, monkeypatch):
|
||||
"""Verify graceful fallback when contextLength is zero.
|
||||
|
||||
When the SDK returns contextLength: 0 (invalid value), the function
|
||||
should return None to trigger registry fallback.
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
mock_result = Mock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = '{"contextLength": 0, "identifier": "test-model"}'
|
||||
mock_result.stderr = ""
|
||||
return mock_result
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="test-model", base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
assert limit is None, "Should return None when contextLength is zero"
|
||||
|
||||
def test_query_lmstudio_with_custom_port(self, monkeypatch):
|
||||
"""Verify SDK query works with non-default WebSocket port.
|
||||
|
||||
LM Studio can run on custom ports. The function should pass the
|
||||
provided base_url to the Node.js subprocess.
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
# Verify the base_url argument is passed correctly
|
||||
command = args[0] if args else kwargs.get("args", [])
|
||||
assert "ws://localhost:8080" in " ".join(command), (
|
||||
"Should pass custom port to subprocess"
|
||||
)
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = '{"contextLength": 4096, "identifier": "custom-model"}'
|
||||
mock_result.stderr = ""
|
||||
return mock_result
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="custom-model", base_url="ws://localhost:8080"
|
||||
)
|
||||
|
||||
assert limit == 4096, "Should work with custom WebSocket port"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"context_length,expected",
|
||||
[
|
||||
(512, 512), # Small context
|
||||
(2048, 2048), # Common context
|
||||
(8192, 8192), # Large context
|
||||
(32768, 32768), # Very large context
|
||||
],
|
||||
)
|
||||
def test_query_lmstudio_various_context_lengths(self, monkeypatch, context_length, expected):
|
||||
"""Verify SDK query handles various context length values.
|
||||
|
||||
Different models have different context lengths. The function should
|
||||
correctly parse and return any positive integer value.
|
||||
"""
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
mock_result = Mock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = f'{{"contextLength": {context_length}, "identifier": "test"}}'
|
||||
mock_result.stderr = ""
|
||||
return mock_result
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
limit = _query_lmstudio_context_limit(
|
||||
model_name="test-model", base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
assert limit == expected, f"Should return {expected} for context length {context_length}"
|
||||
|
||||
def test_query_lmstudio_logs_at_debug_level(self, monkeypatch, caplog):
|
||||
"""Verify errors are logged at DEBUG level, not WARNING/ERROR.
|
||||
|
||||
Following the graceful fallback pattern from Ollama implementation,
|
||||
errors should be logged at debug level to avoid alarming users when
|
||||
fallback to registry works fine.
|
||||
"""
|
||||
import logging
|
||||
|
||||
caplog.set_level(logging.DEBUG, logger="leann.embedding_compute")
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
raise FileNotFoundError("node: command not found")
|
||||
|
||||
monkeypatch.setattr("subprocess.run", mock_run)
|
||||
|
||||
_query_lmstudio_context_limit(model_name="test-model", base_url="ws://localhost:1234")
|
||||
|
||||
# Check that debug logging occurred (not warning/error)
|
||||
debug_logs = [record for record in caplog.records if record.levelname == "DEBUG"]
|
||||
assert len(debug_logs) > 0, "Should log error at DEBUG level"
|
||||
|
||||
# Verify no WARNING or ERROR logs
|
||||
warning_or_error_logs = [
|
||||
record for record in caplog.records if record.levelname in ["WARNING", "ERROR"]
|
||||
]
|
||||
assert len(warning_or_error_logs) == 0, (
|
||||
"Should not log at WARNING/ERROR level for expected failures"
|
||||
)
|
||||
@@ -1,400 +0,0 @@
|
||||
"""End-to-end integration tests for prompt template and token limit features.
|
||||
|
||||
These tests verify real-world functionality with live services:
|
||||
- OpenAI-compatible APIs (OpenAI, LM Studio) with prompt template support
|
||||
- Ollama with dynamic token limit detection
|
||||
- Hybrid token limit discovery mechanism
|
||||
|
||||
Run with: pytest tests/test_prompt_template_e2e.py -v -s
|
||||
Skip if services unavailable: pytest tests/test_prompt_template_e2e.py -m "not integration"
|
||||
|
||||
Prerequisites:
|
||||
1. LM Studio running with embedding model: http://localhost:1234
|
||||
2. [Optional] Ollama running: ollama serve
|
||||
3. [Optional] Ollama model: ollama pull nomic-embed-text
|
||||
4. [Optional] Node.js + @lmstudio/sdk for context length detection
|
||||
"""
|
||||
|
||||
import logging
|
||||
import socket
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import requests
|
||||
from leann.embedding_compute import (
|
||||
compute_embeddings_ollama,
|
||||
compute_embeddings_openai,
|
||||
get_model_token_limit,
|
||||
)
|
||||
|
||||
# Test markers for conditional execution
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_service_available(host: str, port: int, timeout: float = 2.0) -> bool:
|
||||
"""Check if a service is available on the given host:port."""
|
||||
try:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(timeout)
|
||||
result = sock.connect_ex((host, port))
|
||||
sock.close()
|
||||
return result == 0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def check_ollama_available() -> bool:
|
||||
"""Check if Ollama service is available."""
|
||||
if not check_service_available("localhost", 11434):
|
||||
return False
|
||||
try:
|
||||
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def check_lmstudio_available() -> bool:
|
||||
"""Check if LM Studio service is available."""
|
||||
if not check_service_available("localhost", 1234):
|
||||
return False
|
||||
try:
|
||||
response = requests.get("http://localhost:1234/v1/models", timeout=2.0)
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def get_lmstudio_first_model() -> str:
|
||||
"""Get the first available model from LM Studio."""
|
||||
try:
|
||||
response = requests.get("http://localhost:1234/v1/models", timeout=5.0)
|
||||
data = response.json()
|
||||
models = data.get("data", [])
|
||||
if models:
|
||||
return models[0]["id"]
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
class TestPromptTemplateOpenAI:
|
||||
"""End-to-end tests for prompt template with OpenAI-compatible APIs (LM Studio)."""
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not check_lmstudio_available(), reason="LM Studio service not available on localhost:1234"
|
||||
)
|
||||
def test_lmstudio_embedding_with_prompt_template(self):
|
||||
"""Test prompt templates with LM Studio using OpenAI-compatible API."""
|
||||
model_name = get_lmstudio_first_model()
|
||||
if not model_name:
|
||||
pytest.skip("No models loaded in LM Studio")
|
||||
|
||||
texts = ["artificial intelligence", "machine learning"]
|
||||
prompt_template = "search_query: "
|
||||
|
||||
# Get embeddings with prompt template via provider_options
|
||||
provider_options = {"prompt_template": prompt_template}
|
||||
embeddings = compute_embeddings_openai(
|
||||
texts=texts,
|
||||
model_name=model_name,
|
||||
base_url="http://localhost:1234/v1",
|
||||
api_key="lm-studio", # LM Studio doesn't require real key
|
||||
provider_options=provider_options,
|
||||
)
|
||||
|
||||
assert embeddings is not None
|
||||
assert len(embeddings) == 2
|
||||
assert all(isinstance(emb, np.ndarray) for emb in embeddings)
|
||||
assert all(len(emb) > 0 for emb in embeddings)
|
||||
|
||||
logger.info(
|
||||
f"✓ LM Studio embeddings with prompt template: {len(embeddings)} vectors, {len(embeddings[0])} dimensions"
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
|
||||
def test_lmstudio_prompt_template_affects_embeddings(self):
|
||||
"""Verify that prompt templates actually change embedding values."""
|
||||
model_name = get_lmstudio_first_model()
|
||||
if not model_name:
|
||||
pytest.skip("No models loaded in LM Studio")
|
||||
|
||||
text = "machine learning"
|
||||
base_url = "http://localhost:1234/v1"
|
||||
api_key = "lm-studio"
|
||||
|
||||
# Get embeddings without template
|
||||
embeddings_no_template = compute_embeddings_openai(
|
||||
texts=[text],
|
||||
model_name=model_name,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
provider_options={},
|
||||
)
|
||||
|
||||
# Get embeddings with template
|
||||
embeddings_with_template = compute_embeddings_openai(
|
||||
texts=[text],
|
||||
model_name=model_name,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
provider_options={"prompt_template": "search_query: "},
|
||||
)
|
||||
|
||||
# Embeddings should be different when template is applied
|
||||
assert not np.allclose(embeddings_no_template[0], embeddings_with_template[0])
|
||||
|
||||
logger.info("✓ Prompt template changes embedding values as expected")
|
||||
|
||||
|
||||
class TestPromptTemplateOllama:
|
||||
"""End-to-end tests for prompt template with Ollama."""
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not check_ollama_available(), reason="Ollama service not available on localhost:11434"
|
||||
)
|
||||
def test_ollama_embedding_with_prompt_template(self):
|
||||
"""Test prompt templates with Ollama using any available embedding model."""
|
||||
# Get any available embedding model
|
||||
try:
|
||||
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
|
||||
models = response.json().get("models", [])
|
||||
|
||||
embedding_models = []
|
||||
for model in models:
|
||||
name = model["name"]
|
||||
base_name = name.split(":")[0]
|
||||
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
|
||||
embedding_models.append(name)
|
||||
|
||||
if not embedding_models:
|
||||
pytest.skip("No embedding models available in Ollama")
|
||||
|
||||
model_name = embedding_models[0]
|
||||
|
||||
texts = ["artificial intelligence", "machine learning"]
|
||||
prompt_template = "search_query: "
|
||||
|
||||
# Get embeddings with prompt template via provider_options
|
||||
provider_options = {"prompt_template": prompt_template}
|
||||
embeddings = compute_embeddings_ollama(
|
||||
texts=texts,
|
||||
model_name=model_name,
|
||||
is_build=False,
|
||||
host="http://localhost:11434",
|
||||
provider_options=provider_options,
|
||||
)
|
||||
|
||||
assert embeddings is not None
|
||||
assert len(embeddings) == 2
|
||||
assert all(isinstance(emb, np.ndarray) for emb in embeddings)
|
||||
assert all(len(emb) > 0 for emb in embeddings)
|
||||
|
||||
logger.info(
|
||||
f"✓ Ollama embeddings with prompt template: {len(embeddings)} vectors, {len(embeddings[0])} dimensions"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
pytest.skip(f"Could not test Ollama prompt template: {e}")
|
||||
|
||||
@pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
|
||||
def test_ollama_prompt_template_affects_embeddings(self):
|
||||
"""Verify that prompt templates actually change embedding values with Ollama."""
|
||||
# Get any available embedding model
|
||||
try:
|
||||
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
|
||||
models = response.json().get("models", [])
|
||||
|
||||
embedding_models = []
|
||||
for model in models:
|
||||
name = model["name"]
|
||||
base_name = name.split(":")[0]
|
||||
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
|
||||
embedding_models.append(name)
|
||||
|
||||
if not embedding_models:
|
||||
pytest.skip("No embedding models available in Ollama")
|
||||
|
||||
model_name = embedding_models[0]
|
||||
text = "machine learning"
|
||||
host = "http://localhost:11434"
|
||||
|
||||
# Get embeddings without template
|
||||
embeddings_no_template = compute_embeddings_ollama(
|
||||
texts=[text], model_name=model_name, is_build=False, host=host, provider_options={}
|
||||
)
|
||||
|
||||
# Get embeddings with template
|
||||
embeddings_with_template = compute_embeddings_ollama(
|
||||
texts=[text],
|
||||
model_name=model_name,
|
||||
is_build=False,
|
||||
host=host,
|
||||
provider_options={"prompt_template": "search_query: "},
|
||||
)
|
||||
|
||||
# Embeddings should be different when template is applied
|
||||
assert not np.allclose(embeddings_no_template[0], embeddings_with_template[0])
|
||||
|
||||
logger.info("✓ Ollama prompt template changes embedding values as expected")
|
||||
|
||||
except Exception as e:
|
||||
pytest.skip(f"Could not test Ollama prompt template: {e}")
|
||||
|
||||
|
||||
class TestLMStudioSDK:
|
||||
"""End-to-end tests for LM Studio SDK integration."""
|
||||
|
||||
@pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
|
||||
def test_lmstudio_model_listing(self):
|
||||
"""Test that we can list models from LM Studio."""
|
||||
try:
|
||||
response = requests.get("http://localhost:1234/v1/models", timeout=5.0)
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
|
||||
models = data["data"]
|
||||
logger.info(f"✓ LM Studio models available: {len(models)}")
|
||||
|
||||
if models:
|
||||
logger.info(f" First model: {models[0].get('id', 'unknown')}")
|
||||
except Exception as e:
|
||||
pytest.skip(f"LM Studio API error: {e}")
|
||||
|
||||
@pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
|
||||
def test_lmstudio_sdk_context_length_detection(self):
|
||||
"""Test context length detection via LM Studio SDK bridge (requires Node.js + SDK)."""
|
||||
model_name = get_lmstudio_first_model()
|
||||
if not model_name:
|
||||
pytest.skip("No models loaded in LM Studio")
|
||||
|
||||
try:
|
||||
from leann.embedding_compute import _query_lmstudio_context_limit
|
||||
|
||||
# SDK requires WebSocket URL (ws://)
|
||||
context_length = _query_lmstudio_context_limit(
|
||||
model_name=model_name, base_url="ws://localhost:1234"
|
||||
)
|
||||
|
||||
if context_length is None:
|
||||
logger.warning(
|
||||
"⚠ LM Studio SDK bridge returned None (Node.js or SDK may not be available)"
|
||||
)
|
||||
pytest.skip("Node.js or @lmstudio/sdk not available - SDK bridge unavailable")
|
||||
else:
|
||||
assert context_length > 0
|
||||
logger.info(
|
||||
f"✓ LM Studio context length detected via SDK: {context_length} for {model_name}"
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
pytest.skip("_query_lmstudio_context_limit not implemented yet")
|
||||
except Exception as e:
|
||||
logger.error(f"LM Studio SDK test error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
class TestOllamaTokenLimit:
|
||||
"""End-to-end tests for Ollama token limit discovery."""
|
||||
|
||||
@pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
|
||||
def test_ollama_token_limit_detection(self):
|
||||
"""Test dynamic token limit detection from Ollama /api/show endpoint."""
|
||||
# Get any available embedding model
|
||||
try:
|
||||
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
|
||||
models = response.json().get("models", [])
|
||||
|
||||
embedding_models = []
|
||||
for model in models:
|
||||
name = model["name"]
|
||||
base_name = name.split(":")[0]
|
||||
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
|
||||
embedding_models.append(name)
|
||||
|
||||
if not embedding_models:
|
||||
pytest.skip("No embedding models available in Ollama")
|
||||
|
||||
test_model = embedding_models[0]
|
||||
|
||||
# Test token limit detection
|
||||
limit = get_model_token_limit(model_name=test_model, base_url="http://localhost:11434")
|
||||
|
||||
assert limit > 0
|
||||
logger.info(f"✓ Ollama token limit detected: {limit} for {test_model}")
|
||||
|
||||
except Exception as e:
|
||||
pytest.skip(f"Could not test Ollama token detection: {e}")
|
||||
|
||||
|
||||
class TestHybridTokenLimit:
|
||||
"""End-to-end tests for hybrid token limit discovery mechanism."""
|
||||
|
||||
def test_hybrid_discovery_registry_fallback(self):
|
||||
"""Test fallback to static registry for known OpenAI models."""
|
||||
# Use a known OpenAI model (should be in registry)
|
||||
limit = get_model_token_limit(
|
||||
model_name="text-embedding-3-small",
|
||||
base_url="http://fake-server:9999", # Fake URL to force registry lookup
|
||||
)
|
||||
|
||||
# text-embedding-3-small should have 8192 in registry
|
||||
assert limit == 8192
|
||||
logger.info(f"✓ Hybrid discovery (registry fallback): {limit} tokens")
|
||||
|
||||
def test_hybrid_discovery_default_fallback(self):
|
||||
"""Test fallback to safe default for completely unknown models."""
|
||||
limit = get_model_token_limit(
|
||||
model_name="completely-unknown-model-xyz-12345",
|
||||
base_url="http://fake-server:9999",
|
||||
default=512,
|
||||
)
|
||||
|
||||
# Should get the specified default
|
||||
assert limit == 512
|
||||
logger.info(f"✓ Hybrid discovery (default fallback): {limit} tokens")
|
||||
|
||||
@pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
|
||||
def test_hybrid_discovery_ollama_dynamic_first(self):
|
||||
"""Test that Ollama models use dynamic discovery first."""
|
||||
# Get any available embedding model
|
||||
try:
|
||||
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
|
||||
models = response.json().get("models", [])
|
||||
|
||||
embedding_models = []
|
||||
for model in models:
|
||||
name = model["name"]
|
||||
base_name = name.split(":")[0]
|
||||
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
|
||||
embedding_models.append(name)
|
||||
|
||||
if not embedding_models:
|
||||
pytest.skip("No embedding models available in Ollama")
|
||||
|
||||
test_model = embedding_models[0]
|
||||
|
||||
# Should query Ollama /api/show dynamically
|
||||
limit = get_model_token_limit(model_name=test_model, base_url="http://localhost:11434")
|
||||
|
||||
assert limit > 0
|
||||
logger.info(f"✓ Hybrid discovery (Ollama dynamic): {limit} tokens for {test_model}")
|
||||
|
||||
except Exception as e:
|
||||
pytest.skip(f"Could not test hybrid Ollama discovery: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("\n" + "=" * 70)
|
||||
print("INTEGRATION TEST SUITE - Real Service Testing")
|
||||
print("=" * 70)
|
||||
print("\nThese tests require live services:")
|
||||
print(" • LM Studio: http://localhost:1234 (with embedding model loaded)")
|
||||
print(" • [Optional] Ollama: http://localhost:11434")
|
||||
print(" • [Optional] Node.js + @lmstudio/sdk for SDK bridge tests")
|
||||
print("\nRun with: pytest tests/test_prompt_template_e2e.py -v -s")
|
||||
print("=" * 70 + "\n")
|
||||
@@ -1,808 +0,0 @@
|
||||
"""
|
||||
Integration tests for prompt template metadata persistence and reuse.
|
||||
|
||||
These tests verify the complete lifecycle of prompt template persistence:
|
||||
1. Template is saved to .meta.json during index build
|
||||
2. Template is automatically loaded during search operations
|
||||
3. Template can be overridden with explicit flag during search
|
||||
4. Template is reused during chat/ask operations
|
||||
|
||||
These are integration tests that:
|
||||
- Use real file system with temporary directories
|
||||
- Run actual build and search operations
|
||||
- Inspect .meta.json file contents directly
|
||||
- Mock embedding servers to avoid external dependencies
|
||||
- Use small test codebases for fast execution
|
||||
|
||||
Expected to FAIL in Red Phase because metadata persistence verification is not yet implemented.
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
|
||||
class TestPromptTemplateMetadataPersistence:
|
||||
"""Tests for prompt template storage in .meta.json during build."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_index_dir(self):
|
||||
"""Create temporary directory for test indexes."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield Path(tmpdir)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embeddings(self):
|
||||
"""Mock compute_embeddings to return dummy embeddings."""
|
||||
with patch("leann.api.compute_embeddings") as mock_compute:
|
||||
# Return dummy embeddings as numpy array
|
||||
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||
yield mock_compute
|
||||
|
||||
def test_prompt_template_saved_to_metadata(self, temp_index_dir, mock_embeddings):
|
||||
"""
|
||||
Verify that when build is run with embedding_options containing prompt_template,
|
||||
the template value is saved to .meta.json file.
|
||||
|
||||
This is the core persistence requirement - templates must be saved to allow
|
||||
reuse in subsequent search operations without re-specifying the flag.
|
||||
|
||||
Expected failure: .meta.json exists but doesn't contain embedding_options
|
||||
with prompt_template, or the value is not persisted correctly.
|
||||
"""
|
||||
# Setup test data
|
||||
index_path = temp_index_dir / "test_index.leann"
|
||||
template = "search_document: "
|
||||
|
||||
# Build index with prompt template in embedding_options
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
embedding_options={"prompt_template": template},
|
||||
)
|
||||
|
||||
# Add a simple document
|
||||
builder.add_text("This is a test document for indexing")
|
||||
|
||||
# Build the index
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
# Verify .meta.json was created and contains the template
|
||||
meta_path = temp_index_dir / "test_index.leann.meta.json"
|
||||
assert meta_path.exists(), ".meta.json file should be created during build"
|
||||
|
||||
# Read and parse metadata
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta_data = json.load(f)
|
||||
|
||||
# Verify embedding_options exists in metadata
|
||||
assert "embedding_options" in meta_data, (
|
||||
"embedding_options should be saved to .meta.json when provided"
|
||||
)
|
||||
|
||||
# Verify prompt_template is in embedding_options
|
||||
embedding_options = meta_data["embedding_options"]
|
||||
assert "prompt_template" in embedding_options, (
|
||||
"prompt_template should be saved within embedding_options"
|
||||
)
|
||||
|
||||
# Verify the template value matches what we provided
|
||||
assert embedding_options["prompt_template"] == template, (
|
||||
f"Template should be '{template}', got '{embedding_options.get('prompt_template')}'"
|
||||
)
|
||||
|
||||
def test_prompt_template_absent_when_not_provided(self, temp_index_dir, mock_embeddings):
|
||||
"""
|
||||
Verify that when no prompt template is provided during build,
|
||||
.meta.json either doesn't have embedding_options or prompt_template key.
|
||||
|
||||
This ensures clean metadata without unnecessary keys when features aren't used.
|
||||
|
||||
Expected behavior: Build succeeds, .meta.json doesn't contain prompt_template.
|
||||
"""
|
||||
index_path = temp_index_dir / "test_no_template.leann"
|
||||
|
||||
# Build index WITHOUT prompt template
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
# No embedding_options provided
|
||||
)
|
||||
|
||||
builder.add_text("Document without template")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
# Verify metadata
|
||||
meta_path = temp_index_dir / "test_no_template.leann.meta.json"
|
||||
assert meta_path.exists()
|
||||
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta_data = json.load(f)
|
||||
|
||||
# If embedding_options exists, it should not contain prompt_template
|
||||
if "embedding_options" in meta_data:
|
||||
embedding_options = meta_data["embedding_options"]
|
||||
assert "prompt_template" not in embedding_options, (
|
||||
"prompt_template should not be in metadata when not provided"
|
||||
)
|
||||
|
||||
|
||||
class TestPromptTemplateAutoLoadOnSearch:
|
||||
"""Tests for automatic loading of prompt template during search operations.
|
||||
|
||||
NOTE: Over-mocked test removed (test_prompt_template_auto_loaded_on_search).
|
||||
This functionality is now comprehensively tested by TestQueryPromptTemplateAutoLoad
|
||||
which uses simpler mocking and doesn't hang.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_index_dir(self):
|
||||
"""Create temporary directory for test indexes."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield Path(tmpdir)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embeddings(self):
|
||||
"""Mock compute_embeddings to capture calls and return dummy embeddings."""
|
||||
with patch("leann.api.compute_embeddings") as mock_compute:
|
||||
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||
yield mock_compute
|
||||
|
||||
def test_search_without_template_in_metadata(self, temp_index_dir, mock_embeddings):
|
||||
"""
|
||||
Verify that searching an index built WITHOUT a prompt template
|
||||
works correctly (backward compatibility).
|
||||
|
||||
The searcher should handle missing prompt_template gracefully.
|
||||
|
||||
Expected behavior: Search succeeds, no template is used.
|
||||
"""
|
||||
# Build index without template
|
||||
index_path = temp_index_dir / "no_template.leann"
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
)
|
||||
builder.add_text("Document without template")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
# Reset mocks
|
||||
mock_embeddings.reset_mock()
|
||||
|
||||
# Create searcher and search
|
||||
searcher = LeannSearcher(index_path=str(index_path))
|
||||
|
||||
# Verify no template in embedding_options
|
||||
assert "prompt_template" not in searcher.embedding_options, (
|
||||
"Searcher should not have prompt_template when not in metadata"
|
||||
)
|
||||
|
||||
|
||||
class TestQueryPromptTemplateAutoLoad:
|
||||
"""Tests for automatic loading of separate query_prompt_template during search (R2).
|
||||
|
||||
These tests verify the new two-template system where:
|
||||
- build_prompt_template: Applied during index building
|
||||
- query_prompt_template: Applied during search operations
|
||||
|
||||
Expected to FAIL in Red Phase (R2) because query template extraction
|
||||
and application is not yet implemented in LeannSearcher.search().
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_index_dir(self):
|
||||
"""Create temporary directory for test indexes."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield Path(tmpdir)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_compute_embeddings(self):
|
||||
"""Mock compute_embeddings to capture calls and return dummy embeddings."""
|
||||
with patch("leann.embedding_compute.compute_embeddings") as mock_compute:
|
||||
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||
yield mock_compute
|
||||
|
||||
def test_search_auto_loads_query_template(self, temp_index_dir, mock_compute_embeddings):
|
||||
"""
|
||||
Verify that search() automatically loads and applies query_prompt_template from .meta.json.
|
||||
|
||||
Given: Index built with separate build_prompt_template and query_prompt_template
|
||||
When: LeannSearcher.search("my query") is called
|
||||
Then: Query embedding is computed with "query: my query" (query template applied)
|
||||
|
||||
This is the core R2 requirement - query templates must be auto-loaded and applied
|
||||
during search without user intervention.
|
||||
|
||||
Expected failure: compute_embeddings called with raw "my query" instead of
|
||||
"query: my query" because query template extraction is not implemented.
|
||||
"""
|
||||
# Setup: Build index with separate templates in new format
|
||||
index_path = temp_index_dir / "query_template.leann"
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
embedding_options={
|
||||
"build_prompt_template": "doc: ",
|
||||
"query_prompt_template": "query: ",
|
||||
},
|
||||
)
|
||||
builder.add_text("Test document")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
# Reset mock to ignore build calls
|
||||
mock_compute_embeddings.reset_mock()
|
||||
|
||||
# Act: Search with query
|
||||
searcher = LeannSearcher(index_path=str(index_path))
|
||||
|
||||
# Mock the backend search to avoid actual search
|
||||
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
|
||||
mock_backend_search.return_value = {
|
||||
"labels": [["test_id_0"]], # IDs (nested list for batch support)
|
||||
"distances": [[0.9]], # Distances (nested list for batch support)
|
||||
}
|
||||
|
||||
searcher.search("my query", top_k=1, recompute_embeddings=False)
|
||||
|
||||
# Assert: compute_embeddings was called with query template applied
|
||||
assert mock_compute_embeddings.called, "compute_embeddings should be called during search"
|
||||
|
||||
# Get the actual text passed to compute_embeddings
|
||||
call_args = mock_compute_embeddings.call_args
|
||||
texts_arg = call_args[0][0] # First positional arg (list of texts)
|
||||
|
||||
assert len(texts_arg) == 1, "Should compute embedding for one query"
|
||||
assert texts_arg[0] == "query: my query", (
|
||||
f"Query template should be applied: expected 'query: my query', got '{texts_arg[0]}'"
|
||||
)
|
||||
|
||||
def test_search_backward_compat_single_template(self, temp_index_dir, mock_compute_embeddings):
|
||||
"""
|
||||
Verify backward compatibility with old single prompt_template format.
|
||||
|
||||
Given: Index with old format (single prompt_template, no query_prompt_template)
|
||||
When: LeannSearcher.search("my query") is called
|
||||
Then: Query embedding computed with "doc: my query" (old template applied)
|
||||
|
||||
This ensures indexes built with the old single-template system continue
|
||||
to work correctly with the new search implementation.
|
||||
|
||||
Expected failure: Old template not recognized/applied because backward
|
||||
compatibility logic is not implemented.
|
||||
"""
|
||||
# Setup: Build index with old single-template format
|
||||
index_path = temp_index_dir / "old_template.leann"
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
embedding_options={"prompt_template": "doc: "}, # Old format
|
||||
)
|
||||
builder.add_text("Test document")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
# Reset mock
|
||||
mock_compute_embeddings.reset_mock()
|
||||
|
||||
# Act: Search
|
||||
searcher = LeannSearcher(index_path=str(index_path))
|
||||
|
||||
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
|
||||
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
|
||||
|
||||
searcher.search("my query", top_k=1, recompute_embeddings=False)
|
||||
|
||||
# Assert: Old template was applied
|
||||
call_args = mock_compute_embeddings.call_args
|
||||
texts_arg = call_args[0][0]
|
||||
|
||||
assert texts_arg[0] == "doc: my query", (
|
||||
f"Old prompt_template should be applied for backward compatibility: "
|
||||
f"expected 'doc: my query', got '{texts_arg[0]}'"
|
||||
)
|
||||
|
||||
def test_search_backward_compat_no_template(self, temp_index_dir, mock_compute_embeddings):
|
||||
"""
|
||||
Verify backward compatibility when no template is present in .meta.json.
|
||||
|
||||
Given: Index with no template in .meta.json (very old indexes)
|
||||
When: LeannSearcher.search("my query") is called
|
||||
Then: Query embedding computed with "my query" (no template, raw query)
|
||||
|
||||
This ensures the most basic backward compatibility - indexes without
|
||||
any template support continue to work as before.
|
||||
|
||||
Expected failure: May fail if default template is incorrectly applied,
|
||||
or if missing template causes error.
|
||||
"""
|
||||
# Setup: Build index without any template
|
||||
index_path = temp_index_dir / "no_template.leann"
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
# No embedding_options at all
|
||||
)
|
||||
builder.add_text("Test document")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
# Reset mock
|
||||
mock_compute_embeddings.reset_mock()
|
||||
|
||||
# Act: Search
|
||||
searcher = LeannSearcher(index_path=str(index_path))
|
||||
|
||||
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
|
||||
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
|
||||
|
||||
searcher.search("my query", top_k=1, recompute_embeddings=False)
|
||||
|
||||
# Assert: No template applied (raw query)
|
||||
call_args = mock_compute_embeddings.call_args
|
||||
texts_arg = call_args[0][0]
|
||||
|
||||
assert texts_arg[0] == "my query", (
|
||||
f"No template should be applied when missing from metadata: "
|
||||
f"expected 'my query', got '{texts_arg[0]}'"
|
||||
)
|
||||
|
||||
def test_search_override_via_provider_options(self, temp_index_dir, mock_compute_embeddings):
|
||||
"""
|
||||
Verify that explicit provider_options can override stored query template.
|
||||
|
||||
Given: Index with query_prompt_template: "query: "
|
||||
When: search() called with provider_options={"prompt_template": "override: "}
|
||||
Then: Query embedding computed with "override: test" (override takes precedence)
|
||||
|
||||
This enables users to experiment with different query templates without
|
||||
rebuilding the index, or to handle special query types differently.
|
||||
|
||||
Expected failure: provider_options parameter is accepted via **kwargs but
|
||||
not used. Query embedding computed with raw "test" instead of "override: test"
|
||||
because override logic is not implemented.
|
||||
"""
|
||||
# Setup: Build index with query template
|
||||
index_path = temp_index_dir / "override_template.leann"
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
embedding_options={
|
||||
"build_prompt_template": "doc: ",
|
||||
"query_prompt_template": "query: ",
|
||||
},
|
||||
)
|
||||
builder.add_text("Test document")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
# Reset mock
|
||||
mock_compute_embeddings.reset_mock()
|
||||
|
||||
# Act: Search with override
|
||||
searcher = LeannSearcher(index_path=str(index_path))
|
||||
|
||||
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
|
||||
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
|
||||
|
||||
# This should accept provider_options parameter
|
||||
searcher.search(
|
||||
"test",
|
||||
top_k=1,
|
||||
recompute_embeddings=False,
|
||||
provider_options={"prompt_template": "override: "},
|
||||
)
|
||||
|
||||
# Assert: Override template was applied
|
||||
call_args = mock_compute_embeddings.call_args
|
||||
texts_arg = call_args[0][0]
|
||||
|
||||
assert texts_arg[0] == "override: test", (
|
||||
f"Override template should take precedence: "
|
||||
f"expected 'override: test', got '{texts_arg[0]}'"
|
||||
)
|
||||
|
||||
|
||||
class TestPromptTemplateReuseInChat:
|
||||
"""Tests for prompt template reuse in chat/ask operations."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_index_dir(self):
|
||||
"""Create temporary directory for test indexes."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield Path(tmpdir)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embeddings(self):
|
||||
"""Mock compute_embeddings to return dummy embeddings."""
|
||||
with patch("leann.api.compute_embeddings") as mock_compute:
|
||||
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||
yield mock_compute
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding_server_manager(self):
|
||||
"""Mock EmbeddingServerManager for chat tests."""
|
||||
with patch("leann.searcher_base.EmbeddingServerManager") as mock_manager_class:
|
||||
mock_manager = Mock()
|
||||
mock_manager.start_server.return_value = (True, 5557)
|
||||
mock_manager_class.return_value = mock_manager
|
||||
yield mock_manager
|
||||
|
||||
@pytest.fixture
|
||||
def index_with_template(self, temp_index_dir, mock_embeddings):
|
||||
"""Build an index with a prompt template."""
|
||||
index_path = temp_index_dir / "chat_template_index.leann"
|
||||
template = "document_query: "
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
embedding_options={"prompt_template": template},
|
||||
)
|
||||
|
||||
builder.add_text("Test document for chat")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
return str(index_path), template
|
||||
|
||||
|
||||
class TestPromptTemplateIntegrationWithEmbeddingModes:
|
||||
"""Tests for prompt template compatibility with different embedding modes."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_index_dir(self):
|
||||
"""Create temporary directory for test indexes."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield Path(tmpdir)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mode,model,template,filename_prefix",
|
||||
[
|
||||
(
|
||||
"openai",
|
||||
"text-embedding-3-small",
|
||||
"Represent this for searching: ",
|
||||
"openai_template",
|
||||
),
|
||||
("ollama", "nomic-embed-text", "search_query: ", "ollama_template"),
|
||||
("sentence-transformers", "facebook/contriever", "query: ", "st_template"),
|
||||
],
|
||||
)
|
||||
def test_prompt_template_metadata_with_embedding_modes(
|
||||
self, temp_index_dir, mode, model, template, filename_prefix
|
||||
):
|
||||
"""Verify prompt template is saved correctly across different embedding modes.
|
||||
|
||||
Tests that prompt templates are persisted to .meta.json for:
|
||||
- OpenAI mode (primary use case)
|
||||
- Ollama mode (also supports templates)
|
||||
- Sentence-transformers mode (saved for forward compatibility)
|
||||
|
||||
Expected behavior: Template is saved to .meta.json regardless of mode.
|
||||
"""
|
||||
with patch("leann.api.compute_embeddings") as mock_compute:
|
||||
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||
|
||||
index_path = temp_index_dir / f"{filename_prefix}.leann"
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=model,
|
||||
embedding_mode=mode,
|
||||
embedding_options={"prompt_template": template},
|
||||
)
|
||||
|
||||
builder.add_text(f"{mode.capitalize()} test document")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
# Verify metadata
|
||||
meta_path = temp_index_dir / f"{filename_prefix}.leann.meta.json"
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta_data = json.load(f)
|
||||
|
||||
assert meta_data["embedding_mode"] == mode
|
||||
# Template should be saved for all modes (even if not used by some)
|
||||
if "embedding_options" in meta_data:
|
||||
assert meta_data["embedding_options"]["prompt_template"] == template
|
||||
|
||||
|
||||
class TestQueryTemplateApplicationInComputeEmbedding:
|
||||
"""Tests for query template application in compute_query_embedding() (Bug Fix).
|
||||
|
||||
These tests verify that query templates are applied consistently in BOTH
|
||||
code paths (server and fallback) when computing query embeddings.
|
||||
|
||||
This addresses the bug where query templates were only applied in the
|
||||
fallback path, not when using the embedding server (the default path).
|
||||
|
||||
Bug Context:
|
||||
- Issue: Query templates were stored in metadata but only applied during
|
||||
fallback (direct) computation, not when using embedding server
|
||||
- Fix: Move template application to BEFORE any computation path in
|
||||
compute_query_embedding() (searcher_base.py:107-110)
|
||||
- Impact: Critical for models like EmbeddingGemma that require task-specific
|
||||
templates for optimal performance
|
||||
|
||||
These tests ensure the fix works correctly and prevent regression.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_index_with_template(self):
|
||||
"""Create a temporary index with query template in metadata"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
index_dir = Path(tmpdir)
|
||||
index_file = index_dir / "test.leann"
|
||||
meta_file = index_dir / "test.leann.meta.json"
|
||||
|
||||
# Create minimal metadata with query template
|
||||
metadata = {
|
||||
"version": "1.0",
|
||||
"backend_name": "hnsw",
|
||||
"embedding_model": "text-embedding-embeddinggemma-300m-qat",
|
||||
"dimensions": 768,
|
||||
"embedding_mode": "openai",
|
||||
"backend_kwargs": {
|
||||
"graph_degree": 32,
|
||||
"complexity": 64,
|
||||
"distance_metric": "cosine",
|
||||
},
|
||||
"embedding_options": {
|
||||
"base_url": "http://localhost:1234/v1",
|
||||
"api_key": "test-key",
|
||||
"build_prompt_template": "title: none | text: ",
|
||||
"query_prompt_template": "task: search result | query: ",
|
||||
},
|
||||
}
|
||||
|
||||
meta_file.write_text(json.dumps(metadata, indent=2))
|
||||
|
||||
# Create minimal HNSW index file (empty is okay for this test)
|
||||
index_file.write_bytes(b"")
|
||||
|
||||
yield str(index_file)
|
||||
|
||||
def test_query_template_applied_in_fallback_path(self, temp_index_with_template):
|
||||
"""Test that query template is applied when using fallback (direct) path"""
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
# Create a concrete implementation for testing
|
||||
class TestSearcher(BaseSearcher):
|
||||
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||
return {"labels": [], "distances": []}
|
||||
|
||||
searcher = object.__new__(TestSearcher)
|
||||
searcher.index_path = Path(temp_index_with_template)
|
||||
searcher.index_dir = searcher.index_path.parent
|
||||
|
||||
# Load metadata
|
||||
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||
with open(meta_file) as f:
|
||||
searcher.meta = json.load(f)
|
||||
|
||||
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||
|
||||
# Mock compute_embeddings to capture the query text
|
||||
captured_queries = []
|
||||
|
||||
def mock_compute_embeddings(texts, model, mode, provider_options=None):
|
||||
captured_queries.extend(texts)
|
||||
return np.random.rand(len(texts), 768).astype(np.float32)
|
||||
|
||||
with patch(
|
||||
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
|
||||
):
|
||||
# Call compute_query_embedding with template (fallback path)
|
||||
result = searcher.compute_query_embedding(
|
||||
query="vector database",
|
||||
use_server_if_available=False, # Force fallback path
|
||||
query_template="task: search result | query: ",
|
||||
)
|
||||
|
||||
# Verify template was applied
|
||||
assert len(captured_queries) == 1
|
||||
assert captured_queries[0] == "task: search result | query: vector database"
|
||||
assert result.shape == (1, 768)
|
||||
|
||||
def test_query_template_applied_in_server_path(self, temp_index_with_template):
|
||||
"""Test that query template is applied when using server path"""
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
# Create a concrete implementation for testing
|
||||
class TestSearcher(BaseSearcher):
|
||||
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||
return {"labels": [], "distances": []}
|
||||
|
||||
searcher = object.__new__(TestSearcher)
|
||||
searcher.index_path = Path(temp_index_with_template)
|
||||
searcher.index_dir = searcher.index_path.parent
|
||||
|
||||
# Load metadata
|
||||
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||
with open(meta_file) as f:
|
||||
searcher.meta = json.load(f)
|
||||
|
||||
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||
|
||||
# Mock the server methods to capture the query text
|
||||
captured_queries = []
|
||||
|
||||
def mock_ensure_server_running(passages_file, port):
|
||||
return port
|
||||
|
||||
def mock_compute_embedding_via_server(chunks, port):
|
||||
captured_queries.extend(chunks)
|
||||
return np.random.rand(len(chunks), 768).astype(np.float32)
|
||||
|
||||
searcher._ensure_server_running = mock_ensure_server_running
|
||||
searcher._compute_embedding_via_server = mock_compute_embedding_via_server
|
||||
|
||||
# Call compute_query_embedding with template (server path)
|
||||
result = searcher.compute_query_embedding(
|
||||
query="vector database",
|
||||
use_server_if_available=True, # Use server path
|
||||
query_template="task: search result | query: ",
|
||||
)
|
||||
|
||||
# Verify template was applied BEFORE calling server
|
||||
assert len(captured_queries) == 1
|
||||
assert captured_queries[0] == "task: search result | query: vector database"
|
||||
assert result.shape == (1, 768)
|
||||
|
||||
def test_query_template_without_template_parameter(self, temp_index_with_template):
|
||||
"""Test that query is unchanged when no template is provided"""
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
class TestSearcher(BaseSearcher):
|
||||
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||
return {"labels": [], "distances": []}
|
||||
|
||||
searcher = object.__new__(TestSearcher)
|
||||
searcher.index_path = Path(temp_index_with_template)
|
||||
searcher.index_dir = searcher.index_path.parent
|
||||
|
||||
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||
with open(meta_file) as f:
|
||||
searcher.meta = json.load(f)
|
||||
|
||||
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||
|
||||
captured_queries = []
|
||||
|
||||
def mock_compute_embeddings(texts, model, mode, provider_options=None):
|
||||
captured_queries.extend(texts)
|
||||
return np.random.rand(len(texts), 768).astype(np.float32)
|
||||
|
||||
with patch(
|
||||
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
|
||||
):
|
||||
searcher.compute_query_embedding(
|
||||
query="vector database",
|
||||
use_server_if_available=False,
|
||||
query_template=None, # No template
|
||||
)
|
||||
|
||||
# Verify query is unchanged
|
||||
assert len(captured_queries) == 1
|
||||
assert captured_queries[0] == "vector database"
|
||||
|
||||
def test_query_template_consistency_between_paths(self, temp_index_with_template):
|
||||
"""Test that both paths apply template identically"""
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
class TestSearcher(BaseSearcher):
|
||||
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||
return {"labels": [], "distances": []}
|
||||
|
||||
searcher = object.__new__(TestSearcher)
|
||||
searcher.index_path = Path(temp_index_with_template)
|
||||
searcher.index_dir = searcher.index_path.parent
|
||||
|
||||
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||
with open(meta_file) as f:
|
||||
searcher.meta = json.load(f)
|
||||
|
||||
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||
|
||||
query_template = "task: search result | query: "
|
||||
original_query = "vector database"
|
||||
|
||||
# Capture queries from fallback path
|
||||
fallback_queries = []
|
||||
|
||||
def mock_compute_embeddings(texts, model, mode, provider_options=None):
|
||||
fallback_queries.extend(texts)
|
||||
return np.random.rand(len(texts), 768).astype(np.float32)
|
||||
|
||||
with patch(
|
||||
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
|
||||
):
|
||||
searcher.compute_query_embedding(
|
||||
query=original_query,
|
||||
use_server_if_available=False,
|
||||
query_template=query_template,
|
||||
)
|
||||
|
||||
# Capture queries from server path
|
||||
server_queries = []
|
||||
|
||||
def mock_ensure_server_running(passages_file, port):
|
||||
return port
|
||||
|
||||
def mock_compute_embedding_via_server(chunks, port):
|
||||
server_queries.extend(chunks)
|
||||
return np.random.rand(len(chunks), 768).astype(np.float32)
|
||||
|
||||
searcher._ensure_server_running = mock_ensure_server_running
|
||||
searcher._compute_embedding_via_server = mock_compute_embedding_via_server
|
||||
|
||||
searcher.compute_query_embedding(
|
||||
query=original_query,
|
||||
use_server_if_available=True,
|
||||
query_template=query_template,
|
||||
)
|
||||
|
||||
# Verify both paths produced identical templated queries
|
||||
assert len(fallback_queries) == 1
|
||||
assert len(server_queries) == 1
|
||||
assert fallback_queries[0] == server_queries[0]
|
||||
assert fallback_queries[0] == f"{query_template}{original_query}"
|
||||
|
||||
def test_query_template_with_empty_string(self, temp_index_with_template):
|
||||
"""Test behavior with empty template string"""
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
class TestSearcher(BaseSearcher):
|
||||
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||
return {"labels": [], "distances": []}
|
||||
|
||||
searcher = object.__new__(TestSearcher)
|
||||
searcher.index_path = Path(temp_index_with_template)
|
||||
searcher.index_dir = searcher.index_path.parent
|
||||
|
||||
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||
with open(meta_file) as f:
|
||||
searcher.meta = json.load(f)
|
||||
|
||||
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||
|
||||
captured_queries = []
|
||||
|
||||
def mock_compute_embeddings(texts, model, mode, provider_options=None):
|
||||
captured_queries.extend(texts)
|
||||
return np.random.rand(len(texts), 768).astype(np.float32)
|
||||
|
||||
with patch(
|
||||
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
|
||||
):
|
||||
searcher.compute_query_embedding(
|
||||
query="vector database",
|
||||
use_server_if_available=False,
|
||||
query_template="", # Empty string
|
||||
)
|
||||
|
||||
# Empty string is falsy, so no template should be applied
|
||||
assert captured_queries[0] == "vector database"
|
||||
@@ -1,643 +0,0 @@
|
||||
"""Unit tests for token-aware truncation functionality.
|
||||
|
||||
This test suite defines the contract for token truncation functions that prevent
|
||||
500 errors from Ollama when text exceeds model token limits. These tests verify:
|
||||
|
||||
1. Model token limit retrieval (known and unknown models)
|
||||
2. Text truncation behavior for single and multiple texts
|
||||
3. Token counting and truncation accuracy using tiktoken
|
||||
|
||||
All tests are written in Red Phase - they should FAIL initially because the
|
||||
implementation does not exist yet.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tiktoken
|
||||
from leann.embedding_compute import (
|
||||
EMBEDDING_MODEL_LIMITS,
|
||||
get_model_token_limit,
|
||||
truncate_to_token_limit,
|
||||
)
|
||||
|
||||
|
||||
class TestModelTokenLimits:
|
||||
"""Tests for retrieving model-specific token limits."""
|
||||
|
||||
def test_get_model_token_limit_known_model(self):
|
||||
"""Verify correct token limit is returned for known models.
|
||||
|
||||
Known models should return their specific token limits from
|
||||
EMBEDDING_MODEL_LIMITS dictionary.
|
||||
"""
|
||||
# Test nomic-embed-text (2048 tokens)
|
||||
limit = get_model_token_limit("nomic-embed-text")
|
||||
assert limit == 2048, "nomic-embed-text should have 2048 token limit"
|
||||
|
||||
# Test nomic-embed-text-v1.5 (2048 tokens)
|
||||
limit = get_model_token_limit("nomic-embed-text-v1.5")
|
||||
assert limit == 2048, "nomic-embed-text-v1.5 should have 2048 token limit"
|
||||
|
||||
# Test nomic-embed-text-v2 (512 tokens)
|
||||
limit = get_model_token_limit("nomic-embed-text-v2")
|
||||
assert limit == 512, "nomic-embed-text-v2 should have 512 token limit"
|
||||
|
||||
# Test OpenAI models (8192 tokens)
|
||||
limit = get_model_token_limit("text-embedding-3-small")
|
||||
assert limit == 8192, "text-embedding-3-small should have 8192 token limit"
|
||||
|
||||
def test_get_model_token_limit_unknown_model(self):
|
||||
"""Verify default token limit is returned for unknown models.
|
||||
|
||||
Unknown models should return the default limit (2048) to allow
|
||||
operation with reasonable safety margin.
|
||||
"""
|
||||
# Test with completely unknown model
|
||||
limit = get_model_token_limit("unknown-model-xyz")
|
||||
assert limit == 2048, "Unknown models should return default 2048"
|
||||
|
||||
# Test with empty string
|
||||
limit = get_model_token_limit("")
|
||||
assert limit == 2048, "Empty model name should return default 2048"
|
||||
|
||||
def test_get_model_token_limit_custom_default(self):
|
||||
"""Verify custom default can be specified for unknown models.
|
||||
|
||||
Allow callers to specify their own default token limit when
|
||||
model is not in the known models dictionary.
|
||||
"""
|
||||
limit = get_model_token_limit("unknown-model", default=4096)
|
||||
assert limit == 4096, "Should return custom default for unknown models"
|
||||
|
||||
# Known model should ignore custom default
|
||||
limit = get_model_token_limit("nomic-embed-text", default=4096)
|
||||
assert limit == 2048, "Known model should ignore custom default"
|
||||
|
||||
def test_embedding_model_limits_dictionary_exists(self):
|
||||
"""Verify EMBEDDING_MODEL_LIMITS dictionary contains expected models.
|
||||
|
||||
The dictionary should be importable and contain at least the
|
||||
known nomic models with correct token limits.
|
||||
"""
|
||||
assert isinstance(EMBEDDING_MODEL_LIMITS, dict), "Should be a dictionary"
|
||||
assert "nomic-embed-text" in EMBEDDING_MODEL_LIMITS, "Should contain nomic-embed-text"
|
||||
assert "nomic-embed-text-v1.5" in EMBEDDING_MODEL_LIMITS, (
|
||||
"Should contain nomic-embed-text-v1.5"
|
||||
)
|
||||
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text"] == 2048
|
||||
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v1.5"] == 2048
|
||||
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v2"] == 512
|
||||
# OpenAI models
|
||||
assert EMBEDDING_MODEL_LIMITS["text-embedding-3-small"] == 8192
|
||||
|
||||
|
||||
class TestTokenTruncation:
|
||||
"""Tests for truncating texts to token limits."""
|
||||
|
||||
@pytest.fixture
|
||||
def tokenizer(self):
|
||||
"""Provide tiktoken tokenizer for token counting verification."""
|
||||
return tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
def test_truncate_single_text_under_limit(self, tokenizer):
|
||||
"""Verify text under token limit remains unchanged.
|
||||
|
||||
When text is already within the token limit, it should be
|
||||
returned unchanged with no truncation.
|
||||
"""
|
||||
text = "This is a short text that is well under the token limit."
|
||||
token_count = len(tokenizer.encode(text))
|
||||
assert token_count < 100, f"Test setup: text should be short (has {token_count} tokens)"
|
||||
|
||||
# Truncate with generous limit
|
||||
result = truncate_to_token_limit([text], token_limit=512)
|
||||
|
||||
assert len(result) == 1, "Should return same number of texts"
|
||||
assert result[0] == text, "Text under limit should be unchanged"
|
||||
|
||||
def test_truncate_single_text_over_limit(self, tokenizer):
|
||||
"""Verify text over token limit is truncated correctly.
|
||||
|
||||
When text exceeds the token limit, it should be truncated to
|
||||
fit within the limit while maintaining valid token boundaries.
|
||||
"""
|
||||
# Create a text that definitely exceeds limit
|
||||
text = "word " * 200 # ~200 tokens (each "word " is typically 1-2 tokens)
|
||||
original_token_count = len(tokenizer.encode(text))
|
||||
assert original_token_count > 50, (
|
||||
f"Test setup: text should be long (has {original_token_count} tokens)"
|
||||
)
|
||||
|
||||
# Truncate to 50 tokens
|
||||
result = truncate_to_token_limit([text], token_limit=50)
|
||||
|
||||
assert len(result) == 1, "Should return same number of texts"
|
||||
assert result[0] != text, "Text over limit should be truncated"
|
||||
assert len(result[0]) < len(text), "Truncated text should be shorter"
|
||||
|
||||
# Verify truncated text is within token limit
|
||||
truncated_token_count = len(tokenizer.encode(result[0]))
|
||||
assert truncated_token_count <= 50, (
|
||||
f"Truncated text should be ≤50 tokens, got {truncated_token_count}"
|
||||
)
|
||||
|
||||
def test_truncate_multiple_texts_mixed_lengths(self, tokenizer):
|
||||
"""Verify multiple texts with mixed lengths are handled correctly.
|
||||
|
||||
When processing multiple texts:
|
||||
- Texts under limit should remain unchanged
|
||||
- Texts over limit should be truncated independently
|
||||
- Output list should maintain same order and length
|
||||
"""
|
||||
texts = [
|
||||
"Short text.", # Under limit
|
||||
"word " * 200, # Over limit
|
||||
"Another short one.", # Under limit
|
||||
"token " * 150, # Over limit
|
||||
]
|
||||
|
||||
# Verify test setup
|
||||
for i, text in enumerate(texts):
|
||||
token_count = len(tokenizer.encode(text))
|
||||
if i in [1, 3]:
|
||||
assert token_count > 50, f"Text {i} should be over limit (has {token_count} tokens)"
|
||||
else:
|
||||
assert token_count < 50, (
|
||||
f"Text {i} should be under limit (has {token_count} tokens)"
|
||||
)
|
||||
|
||||
# Truncate with 50 token limit
|
||||
result = truncate_to_token_limit(texts, token_limit=50)
|
||||
|
||||
assert len(result) == len(texts), "Should return same number of texts"
|
||||
|
||||
# Verify each text individually
|
||||
for i, (original, truncated) in enumerate(zip(texts, result)):
|
||||
token_count = len(tokenizer.encode(truncated))
|
||||
assert token_count <= 50, f"Text {i} should be ≤50 tokens, got {token_count}"
|
||||
|
||||
# Short texts should be unchanged
|
||||
if i in [0, 2]:
|
||||
assert truncated == original, f"Short text {i} should be unchanged"
|
||||
# Long texts should be truncated
|
||||
else:
|
||||
assert len(truncated) < len(original), f"Long text {i} should be truncated"
|
||||
|
||||
def test_truncate_empty_list(self):
|
||||
"""Verify empty input list returns empty output list.
|
||||
|
||||
Edge case: empty list should return empty list without errors.
|
||||
"""
|
||||
result = truncate_to_token_limit([], token_limit=512)
|
||||
assert result == [], "Empty input should return empty output"
|
||||
|
||||
def test_truncate_preserves_order(self, tokenizer):
|
||||
"""Verify truncation preserves original text order.
|
||||
|
||||
Output list should maintain the same order as input list,
|
||||
regardless of which texts were truncated.
|
||||
"""
|
||||
texts = [
|
||||
"First text " * 50, # Will be truncated
|
||||
"Second text.", # Won't be truncated
|
||||
"Third text " * 50, # Will be truncated
|
||||
]
|
||||
|
||||
result = truncate_to_token_limit(texts, token_limit=20)
|
||||
|
||||
assert len(result) == 3, "Should preserve list length"
|
||||
# Check that order is maintained by looking for distinctive words
|
||||
assert "First" in result[0], "First text should remain in first position"
|
||||
assert "Second" in result[1], "Second text should remain in second position"
|
||||
assert "Third" in result[2], "Third text should remain in third position"
|
||||
|
||||
def test_truncate_extremely_long_text(self, tokenizer):
|
||||
"""Verify extremely long texts are truncated efficiently.
|
||||
|
||||
Test with text that far exceeds token limit to ensure
|
||||
truncation handles extreme cases without performance issues.
|
||||
"""
|
||||
# Create very long text (simulate real-world scenario)
|
||||
text = "token " * 5000 # ~5000+ tokens
|
||||
original_token_count = len(tokenizer.encode(text))
|
||||
assert original_token_count > 1000, "Test setup: text should be very long"
|
||||
|
||||
# Truncate to small limit
|
||||
result = truncate_to_token_limit([text], token_limit=100)
|
||||
|
||||
assert len(result) == 1
|
||||
truncated_token_count = len(tokenizer.encode(result[0]))
|
||||
assert truncated_token_count <= 100, (
|
||||
f"Should truncate to ≤100 tokens, got {truncated_token_count}"
|
||||
)
|
||||
assert len(result[0]) < len(text) // 10, "Should significantly reduce text length"
|
||||
|
||||
def test_truncate_exact_token_limit(self, tokenizer):
|
||||
"""Verify text at exactly token limit is handled correctly.
|
||||
|
||||
Edge case: text with exactly the token limit should either
|
||||
remain unchanged or be safely truncated by 1 token.
|
||||
"""
|
||||
# Create text with approximately 50 tokens
|
||||
# We'll adjust to get exactly 50
|
||||
target_tokens = 50
|
||||
text = "word " * 50
|
||||
tokens = tokenizer.encode(text)
|
||||
|
||||
# Adjust to get exactly target_tokens
|
||||
if len(tokens) > target_tokens:
|
||||
tokens = tokens[:target_tokens]
|
||||
text = tokenizer.decode(tokens)
|
||||
elif len(tokens) < target_tokens:
|
||||
# Add more words
|
||||
while len(tokenizer.encode(text)) < target_tokens:
|
||||
text += "word "
|
||||
tokens = tokenizer.encode(text)[:target_tokens]
|
||||
text = tokenizer.decode(tokens)
|
||||
|
||||
# Verify we have exactly target_tokens
|
||||
assert len(tokenizer.encode(text)) == target_tokens, (
|
||||
"Test setup: should have exactly 50 tokens"
|
||||
)
|
||||
|
||||
result = truncate_to_token_limit([text], token_limit=target_tokens)
|
||||
|
||||
assert len(result) == 1
|
||||
result_tokens = len(tokenizer.encode(result[0]))
|
||||
assert result_tokens <= target_tokens, (
|
||||
f"Should be ≤{target_tokens} tokens, got {result_tokens}"
|
||||
)
|
||||
|
||||
|
||||
class TestLMStudioHybridDiscovery:
|
||||
"""Tests for LM Studio integration in get_model_token_limit() hybrid discovery.
|
||||
|
||||
These tests verify that get_model_token_limit() properly integrates with
|
||||
the LM Studio SDK bridge for dynamic token limit discovery. The integration
|
||||
should:
|
||||
|
||||
1. Detect LM Studio URLs (port 1234 or 'lmstudio'/'lm.studio' in URL)
|
||||
2. Convert HTTP URLs to WebSocket format for SDK queries
|
||||
3. Query LM Studio SDK and use discovered limit
|
||||
4. Fall back to registry when SDK returns None
|
||||
5. Execute AFTER Ollama detection but BEFORE registry fallback
|
||||
|
||||
All tests are written in Red Phase - they should FAIL initially because the
|
||||
LM Studio detection and integration logic does not exist yet in get_model_token_limit().
|
||||
"""
|
||||
|
||||
def test_get_model_token_limit_lmstudio_success(self, monkeypatch):
|
||||
"""Verify LM Studio SDK query succeeds and returns detected limit.
|
||||
|
||||
When a LM Studio base_url is detected and the SDK query succeeds,
|
||||
get_model_token_limit() should return the dynamically discovered
|
||||
context length without falling back to the registry.
|
||||
"""
|
||||
|
||||
# Mock _query_lmstudio_context_limit to return successful SDK query
|
||||
def mock_query_lmstudio(model_name, base_url):
|
||||
# Verify WebSocket URL was passed (not HTTP)
|
||||
assert base_url.startswith("ws://"), (
|
||||
f"Should convert HTTP to WebSocket format, got: {base_url}"
|
||||
)
|
||||
return 8192 # Successful SDK query
|
||||
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||
mock_query_lmstudio,
|
||||
)
|
||||
|
||||
# Test with HTTP URL that should be converted to WebSocket
|
||||
limit = get_model_token_limit(
|
||||
model_name="custom-model", base_url="http://localhost:1234/v1"
|
||||
)
|
||||
|
||||
assert limit == 8192, "Should return limit from LM Studio SDK query"
|
||||
|
||||
def test_get_model_token_limit_lmstudio_fallback_to_registry(self, monkeypatch):
|
||||
"""Verify fallback to registry when LM Studio SDK returns None.
|
||||
|
||||
When LM Studio SDK query fails (returns None), get_model_token_limit()
|
||||
should fall back to the EMBEDDING_MODEL_LIMITS registry.
|
||||
"""
|
||||
|
||||
# Mock _query_lmstudio_context_limit to return None (SDK failure)
|
||||
def mock_query_lmstudio(model_name, base_url):
|
||||
return None # SDK query failed
|
||||
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||
mock_query_lmstudio,
|
||||
)
|
||||
|
||||
# Test with known model that exists in registry
|
||||
limit = get_model_token_limit(
|
||||
model_name="nomic-embed-text", base_url="http://localhost:1234/v1"
|
||||
)
|
||||
|
||||
# Should fall back to registry value
|
||||
assert limit == 2048, "Should fall back to registry when SDK returns None"
|
||||
|
||||
def test_get_model_token_limit_lmstudio_port_detection(self, monkeypatch):
|
||||
"""Verify detection of LM Studio via port 1234.
|
||||
|
||||
get_model_token_limit() should recognize port 1234 as a LM Studio
|
||||
server and attempt SDK query, regardless of hostname.
|
||||
"""
|
||||
query_called = False
|
||||
|
||||
def mock_query_lmstudio(model_name, base_url):
|
||||
nonlocal query_called
|
||||
query_called = True
|
||||
return 4096
|
||||
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||
mock_query_lmstudio,
|
||||
)
|
||||
|
||||
# Test with port 1234 (default LM Studio port)
|
||||
limit = get_model_token_limit(model_name="test-model", base_url="http://127.0.0.1:1234/v1")
|
||||
|
||||
assert query_called, "Should detect port 1234 and call LM Studio SDK query"
|
||||
assert limit == 4096, "Should return SDK query result"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_url,expected_limit,keyword",
|
||||
[
|
||||
("http://lmstudio.local:8080/v1", 16384, "lmstudio"),
|
||||
("http://api.lm.studio:5000/v1", 32768, "lm.studio"),
|
||||
],
|
||||
)
|
||||
def test_get_model_token_limit_lmstudio_url_keyword_detection(
|
||||
self, monkeypatch, test_url, expected_limit, keyword
|
||||
):
|
||||
"""Verify detection of LM Studio via keywords in URL.
|
||||
|
||||
get_model_token_limit() should recognize 'lmstudio' or 'lm.studio'
|
||||
in the URL as indicating a LM Studio server.
|
||||
"""
|
||||
query_called = False
|
||||
|
||||
def mock_query_lmstudio(model_name, base_url):
|
||||
nonlocal query_called
|
||||
query_called = True
|
||||
return expected_limit
|
||||
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||
mock_query_lmstudio,
|
||||
)
|
||||
|
||||
limit = get_model_token_limit(model_name="test-model", base_url=test_url)
|
||||
|
||||
assert query_called, f"Should detect '{keyword}' keyword and call SDK query"
|
||||
assert limit == expected_limit, f"Should return SDK query result for {keyword}"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_url,expected_protocol,expected_host",
|
||||
[
|
||||
("http://localhost:1234/v1", "ws://", "localhost:1234"),
|
||||
("https://lmstudio.example.com:1234/v1", "wss://", "lmstudio.example.com:1234"),
|
||||
],
|
||||
)
|
||||
def test_get_model_token_limit_protocol_conversion(
|
||||
self, monkeypatch, input_url, expected_protocol, expected_host
|
||||
):
|
||||
"""Verify HTTP/HTTPS URL is converted to WebSocket format for SDK query.
|
||||
|
||||
LM Studio SDK requires WebSocket URLs. get_model_token_limit() should:
|
||||
1. Convert 'http://' to 'ws://'
|
||||
2. Convert 'https://' to 'wss://'
|
||||
3. Remove '/v1' or other path suffixes (SDK expects base URL)
|
||||
4. Preserve host and port
|
||||
"""
|
||||
conversions_tested = []
|
||||
|
||||
def mock_query_lmstudio(model_name, base_url):
|
||||
conversions_tested.append(base_url)
|
||||
return 8192
|
||||
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||
mock_query_lmstudio,
|
||||
)
|
||||
|
||||
get_model_token_limit(model_name="test-model", base_url=input_url)
|
||||
|
||||
# Verify conversion happened
|
||||
assert len(conversions_tested) == 1, "Should have called SDK query once"
|
||||
assert conversions_tested[0].startswith(expected_protocol), (
|
||||
f"Should convert to {expected_protocol}"
|
||||
)
|
||||
assert expected_host in conversions_tested[0], (
|
||||
f"Should preserve host and port: {expected_host}"
|
||||
)
|
||||
|
||||
def test_get_model_token_limit_lmstudio_executes_after_ollama(self, monkeypatch):
|
||||
"""Verify LM Studio detection happens AFTER Ollama detection.
|
||||
|
||||
The hybrid discovery order should be:
|
||||
1. Ollama dynamic discovery (port 11434 or 'ollama' in URL)
|
||||
2. LM Studio dynamic discovery (port 1234 or 'lmstudio' in URL)
|
||||
3. Registry fallback
|
||||
|
||||
If both Ollama and LM Studio patterns match, Ollama should take precedence.
|
||||
This test verifies that LM Studio is checked but doesn't interfere with Ollama.
|
||||
"""
|
||||
ollama_called = False
|
||||
lmstudio_called = False
|
||||
|
||||
def mock_query_ollama(model_name, base_url):
|
||||
nonlocal ollama_called
|
||||
ollama_called = True
|
||||
return 2048 # Ollama query succeeds
|
||||
|
||||
def mock_query_lmstudio(model_name, base_url):
|
||||
nonlocal lmstudio_called
|
||||
lmstudio_called = True
|
||||
return None # Should not be reached if Ollama succeeds
|
||||
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_compute._query_ollama_context_limit",
|
||||
mock_query_ollama,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||
mock_query_lmstudio,
|
||||
)
|
||||
|
||||
# Test with Ollama URL
|
||||
limit = get_model_token_limit(
|
||||
model_name="test-model", base_url="http://localhost:11434/api"
|
||||
)
|
||||
|
||||
assert ollama_called, "Should attempt Ollama query first"
|
||||
assert not lmstudio_called, "Should not attempt LM Studio query when Ollama succeeds"
|
||||
assert limit == 2048, "Should return Ollama result"
|
||||
|
||||
def test_get_model_token_limit_lmstudio_not_detected_for_non_lmstudio_urls(self, monkeypatch):
|
||||
"""Verify LM Studio SDK query is NOT called for non-LM Studio URLs.
|
||||
|
||||
Only URLs with port 1234 or 'lmstudio'/'lm.studio' keywords should
|
||||
trigger LM Studio SDK queries. Other URLs should skip to registry fallback.
|
||||
"""
|
||||
lmstudio_called = False
|
||||
|
||||
def mock_query_lmstudio(model_name, base_url):
|
||||
nonlocal lmstudio_called
|
||||
lmstudio_called = True
|
||||
return 8192
|
||||
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||
mock_query_lmstudio,
|
||||
)
|
||||
|
||||
# Test with non-LM Studio URLs
|
||||
test_cases = [
|
||||
"http://localhost:8080/v1", # Different port
|
||||
"http://openai.example.com/v1", # Different service
|
||||
"http://localhost:3000/v1", # Another port
|
||||
]
|
||||
|
||||
for base_url in test_cases:
|
||||
lmstudio_called = False # Reset for each test
|
||||
get_model_token_limit(model_name="nomic-embed-text", base_url=base_url)
|
||||
assert not lmstudio_called, f"Should NOT call LM Studio SDK for URL: {base_url}"
|
||||
|
||||
def test_get_model_token_limit_lmstudio_case_insensitive_detection(self, monkeypatch):
|
||||
"""Verify LM Studio detection is case-insensitive for keywords.
|
||||
|
||||
Keywords 'lmstudio' and 'lm.studio' should be detected regardless
|
||||
of case (LMStudio, LMSTUDIO, LmStudio, etc.).
|
||||
"""
|
||||
query_called = False
|
||||
|
||||
def mock_query_lmstudio(model_name, base_url):
|
||||
nonlocal query_called
|
||||
query_called = True
|
||||
return 8192
|
||||
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||
mock_query_lmstudio,
|
||||
)
|
||||
|
||||
# Test various case variations
|
||||
test_cases = [
|
||||
"http://LMStudio.local:8080/v1",
|
||||
"http://LMSTUDIO.example.com/v1",
|
||||
"http://LmStudio.local/v1",
|
||||
"http://api.LM.STUDIO:5000/v1",
|
||||
]
|
||||
|
||||
for base_url in test_cases:
|
||||
query_called = False # Reset for each test
|
||||
limit = get_model_token_limit(model_name="test-model", base_url=base_url)
|
||||
assert query_called, f"Should detect LM Studio in URL: {base_url}"
|
||||
assert limit == 8192, f"Should return SDK result for URL: {base_url}"
|
||||
|
||||
|
||||
class TestTokenLimitCaching:
|
||||
"""Tests for token limit caching to prevent repeated SDK/API calls.
|
||||
|
||||
Caching prevents duplicate SDK/API calls within the same Python process,
|
||||
which is important because:
|
||||
1. LM Studio SDK load() can load duplicate model instances
|
||||
2. Ollama /api/show queries add latency
|
||||
3. Registry lookups are pure overhead
|
||||
|
||||
Cache is process-scoped and resets between leann build invocations.
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear cache before each test."""
|
||||
from leann.embedding_compute import _token_limit_cache
|
||||
|
||||
_token_limit_cache.clear()
|
||||
|
||||
def test_registry_lookup_is_cached(self):
|
||||
"""Verify that registry lookups are cached."""
|
||||
from leann.embedding_compute import _token_limit_cache
|
||||
|
||||
# First call
|
||||
limit1 = get_model_token_limit("text-embedding-3-small")
|
||||
assert limit1 == 8192
|
||||
|
||||
# Verify it's in cache
|
||||
cache_key = ("text-embedding-3-small", "")
|
||||
assert cache_key in _token_limit_cache
|
||||
assert _token_limit_cache[cache_key] == 8192
|
||||
|
||||
# Second call should use cache
|
||||
limit2 = get_model_token_limit("text-embedding-3-small")
|
||||
assert limit2 == 8192
|
||||
|
||||
def test_default_fallback_is_cached(self):
|
||||
"""Verify that default fallbacks are cached."""
|
||||
from leann.embedding_compute import _token_limit_cache
|
||||
|
||||
# First call with unknown model
|
||||
limit1 = get_model_token_limit("unknown-model-xyz", default=512)
|
||||
assert limit1 == 512
|
||||
|
||||
# Verify it's in cache
|
||||
cache_key = ("unknown-model-xyz", "")
|
||||
assert cache_key in _token_limit_cache
|
||||
assert _token_limit_cache[cache_key] == 512
|
||||
|
||||
# Second call should use cache
|
||||
limit2 = get_model_token_limit("unknown-model-xyz", default=512)
|
||||
assert limit2 == 512
|
||||
|
||||
def test_different_urls_create_separate_cache_entries(self):
|
||||
"""Verify that different base_urls create separate cache entries."""
|
||||
from leann.embedding_compute import _token_limit_cache
|
||||
|
||||
# Same model, different URLs
|
||||
limit1 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:11434")
|
||||
limit2 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:1234/v1")
|
||||
|
||||
# Both should find the model in registry (2048)
|
||||
assert limit1 == 2048
|
||||
assert limit2 == 2048
|
||||
|
||||
# But they should be separate cache entries
|
||||
cache_key1 = ("nomic-embed-text", "http://localhost:11434")
|
||||
cache_key2 = ("nomic-embed-text", "http://localhost:1234/v1")
|
||||
|
||||
assert cache_key1 in _token_limit_cache
|
||||
assert cache_key2 in _token_limit_cache
|
||||
assert len(_token_limit_cache) == 2
|
||||
|
||||
def test_cache_prevents_repeated_lookups(self):
|
||||
"""Verify that cache prevents repeated registry/API lookups."""
|
||||
from leann.embedding_compute import _token_limit_cache
|
||||
|
||||
model_name = "text-embedding-ada-002"
|
||||
|
||||
# First call - should add to cache
|
||||
assert len(_token_limit_cache) == 0
|
||||
limit1 = get_model_token_limit(model_name)
|
||||
|
||||
cache_size_after_first = len(_token_limit_cache)
|
||||
assert cache_size_after_first == 1
|
||||
|
||||
# Multiple subsequent calls - cache size should not change
|
||||
for _ in range(5):
|
||||
limit = get_model_token_limit(model_name)
|
||||
assert limit == limit1
|
||||
assert len(_token_limit_cache) == cache_size_after_first
|
||||
|
||||
def test_versioned_model_names_cached_correctly(self):
|
||||
"""Verify that versioned model names (e.g., model:tag) are cached."""
|
||||
from leann.embedding_compute import _token_limit_cache
|
||||
|
||||
# Model with version tag
|
||||
limit = get_model_token_limit("nomic-embed-text:latest", base_url="http://localhost:11434")
|
||||
assert limit == 2048
|
||||
|
||||
# Should be cached with full name including version
|
||||
cache_key = ("nomic-embed-text:latest", "http://localhost:11434")
|
||||
assert cache_key in _token_limit_cache
|
||||
assert _token_limit_cache[cache_key] == 2048
|
||||
Reference in New Issue
Block a user