Compare commits

Comparing `fix/chunki`...`fix/drop-p` (46 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 8e6aa34afd | |
| | 5791367d13 | |
| | 674977a950 | |
| | 56785d30ee | |
| | a73640f95e | |
| | 47b91f7313 | |
| | 7601e0b112 | |
| | 2a22ec1b26 | |
| | 530507d39d | |
| | 8a2ea37871 | |
| | 7ddb4772c0 | |
| | a1c21adbce | |
| | d1b3c93a5a | |
| | a6ee95b18a | |
| | 17cbd07b25 | |
| | 3629ccf8f7 | |
| | 0175bc9c20 | |
| | af47dfdde7 | |
| | f13bd02fbd | |
| | a0bbf831db | |
| | 86287d8832 | |
| | 76cc798e3e | |
| | d599566fd7 | |
| | 00770aebbb | |
| | e268392d5b | |
| | eb909ccec5 | |
| | 13beb98164 | |
| | 969f514564 | |
| | 1ef9cba7de | |
| | a63550944b | |
| | 97493a2896 | |
| | f7d2dc6e7c | |
| | ea86b283cb | |
| | e7519bceaa | |
| | abf0b2c676 | |
| | 3c4785bb63 | |
| | 930b79cc98 | |
| | 9b7353f336 | |
| | 9dd0e0b26f | |
| | 3766ad1fd2 | |
| | c3aceed1e0 | |
| | dc6c9f696e | |
| | 2406c41eef | |
| | d4f5f2896f | |
| | 366984e92e | |
| | 64b92a04a7 | |
.github/workflows/build-reusable.yml (vendored, 95 changed lines)

```diff
@@ -35,8 +35,8 @@ jobs:
     strategy:
       matrix:
         include:
-          - os: ubuntu-22.04
-            python: '3.9'
+          # Note: Python 3.9 dropped - uses PEP 604 union syntax (str | None)
+          # which requires Python 3.10+
           - os: ubuntu-22.04
             python: '3.10'
           - os: ubuntu-22.04
@@ -46,8 +46,6 @@ jobs:
           - os: ubuntu-22.04
             python: '3.13'
           # ARM64 Linux builds
-          - os: ubuntu-24.04-arm
-            python: '3.9'
           - os: ubuntu-24.04-arm
             python: '3.10'
           - os: ubuntu-24.04-arm
@@ -56,8 +54,6 @@ jobs:
             python: '3.12'
           - os: ubuntu-24.04-arm
             python: '3.13'
-          - os: macos-14
-            python: '3.9'
           - os: macos-14
             python: '3.10'
           - os: macos-14
@@ -66,8 +62,6 @@ jobs:
             python: '3.12'
           - os: macos-14
             python: '3.13'
-          - os: macos-15
-            python: '3.9'
           - os: macos-15
             python: '3.10'
           - os: macos-15
@@ -76,16 +70,24 @@ jobs:
             python: '3.12'
           - os: macos-15
             python: '3.13'
-          - os: macos-13
-            python: '3.9'
-          - os: macos-13
+          # Intel Mac builds (x86_64) - replaces deprecated macos-13
+          # Note: Python 3.13 excluded - PyTorch has no wheels for macOS x86_64 + Python 3.13
+          # (PyTorch <=2.4.1 lacks cp313, PyTorch >=2.5.0 dropped Intel Mac support)
+          - os: macos-15-intel
             python: '3.10'
-          - os: macos-13
+          - os: macos-15-intel
             python: '3.11'
-          - os: macos-13
+          - os: macos-15-intel
             python: '3.12'
-          # Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
-          # (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
+          # macOS 26 (beta) - arm64
+          - os: macos-26
+            python: '3.10'
+          - os: macos-26
+            python: '3.11'
+          - os: macos-26
+            python: '3.12'
+          - os: macos-26
+            python: '3.13'
     runs-on: ${{ matrix.os }}

     steps:
@@ -204,13 +206,16 @@ jobs:
           # Use system clang for better compatibility
           export CC=clang
           export CXX=clang++
-          # Homebrew libraries on each macOS version require matching minimum version
-          if [[ "${{ matrix.os }}" == "macos-13" ]]; then
-            export MACOSX_DEPLOYMENT_TARGET=13.0
-          elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
-            export MACOSX_DEPLOYMENT_TARGET=14.0
-          elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
+          # Set deployment target based on runner
+          # macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
+          if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
             export MACOSX_DEPLOYMENT_TARGET=15.0
+          elif [[ "${{ matrix.os }}" == macos-14* ]]; then
+            export MACOSX_DEPLOYMENT_TARGET=14.0
+          elif [[ "${{ matrix.os }}" == macos-15* ]]; then
+            export MACOSX_DEPLOYMENT_TARGET=15.0
+          elif [[ "${{ matrix.os }}" == macos-26* ]]; then
+            export MACOSX_DEPLOYMENT_TARGET=26.0
           fi
           uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
         else
@@ -224,14 +229,16 @@ jobs:
           # Use system clang for better compatibility
           export CC=clang
           export CXX=clang++
-          # DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
-          # But Homebrew libraries on each macOS version require matching minimum version
-          if [[ "${{ matrix.os }}" == "macos-13" ]]; then
-            export MACOSX_DEPLOYMENT_TARGET=13.3
-          elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
-            export MACOSX_DEPLOYMENT_TARGET=14.0
-          elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
+          # Set deployment target based on runner
+          # macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
+          if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
             export MACOSX_DEPLOYMENT_TARGET=15.0
+          elif [[ "${{ matrix.os }}" == macos-14* ]]; then
+            export MACOSX_DEPLOYMENT_TARGET=14.0
+          elif [[ "${{ matrix.os }}" == macos-15* ]]; then
+            export MACOSX_DEPLOYMENT_TARGET=15.0
+          elif [[ "${{ matrix.os }}" == macos-26* ]]; then
+            export MACOSX_DEPLOYMENT_TARGET=26.0
           fi
           uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
         else
@@ -269,16 +276,19 @@ jobs:
        if: runner.os == 'macOS'
        run: |
          # Determine deployment target based on runner OS
-          # Must match the Homebrew libraries for each macOS version
-          if [[ "${{ matrix.os }}" == "macos-13" ]]; then
-            HNSW_TARGET="13.0"
-            DISKANN_TARGET="13.3"
-          elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
-            HNSW_TARGET="14.0"
-            DISKANN_TARGET="14.0"
-          elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
+          # macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
+          if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
            HNSW_TARGET="15.0"
            DISKANN_TARGET="15.0"
+          elif [[ "${{ matrix.os }}" == macos-14* ]]; then
+            HNSW_TARGET="14.0"
+            DISKANN_TARGET="14.0"
+          elif [[ "${{ matrix.os }}" == macos-15* ]]; then
+            HNSW_TARGET="15.0"
+            DISKANN_TARGET="15.0"
+          elif [[ "${{ matrix.os }}" == macos-26* ]]; then
+            HNSW_TARGET="26.0"
+            DISKANN_TARGET="26.0"
          fi

          # Repair HNSW wheel
@@ -334,12 +344,15 @@ jobs:
          PY_TAG=$($UV_PY -c "import sys; print(f'cp{sys.version_info[0]}{sys.version_info[1]}')")

          if [[ "$RUNNER_OS" == "macOS" ]]; then
-            if [[ "${{ matrix.os }}" == "macos-13" ]]; then
-              export MACOSX_DEPLOYMENT_TARGET=13.3
-            elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
-              export MACOSX_DEPLOYMENT_TARGET=14.0
-            elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
+            # macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
+            if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
              export MACOSX_DEPLOYMENT_TARGET=15.0
+            elif [[ "${{ matrix.os }}" == macos-14* ]]; then
+              export MACOSX_DEPLOYMENT_TARGET=14.0
+            elif [[ "${{ matrix.os }}" == macos-15* ]]; then
+              export MACOSX_DEPLOYMENT_TARGET=15.0
+            elif [[ "${{ matrix.os }}" == macos-26* ]]; then
+              export MACOSX_DEPLOYMENT_TARGET=26.0
            fi
          fi

```
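The new matrix comments attribute the dropped Python 3.9 jobs to PEP 604 union syntax. A minimal illustration of that incompatibility (the function name here is only for illustration):

```python
# PEP 604 (Python 3.10+): the `X | Y` union in the annotation is evaluated when
# the function is defined, so importing this module on Python 3.9 raises
# "TypeError: unsupported operand type(s) for |" unless
# `from __future__ import annotations` is in effect.
def resolve_index(name: str) -> str | None:
    return name if name.endswith(".leann") else None


# Python 3.9-compatible spelling of the same annotation.
from typing import Optional


def resolve_index_legacy(name: str) -> Optional[str]:
    return name if name.endswith(".leann") else None
```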
.github/workflows/link-check.yml (vendored, 2 changed lines)

```diff
@@ -14,6 +14,6 @@ jobs:
       - uses: actions/checkout@v4
       - uses: lycheeverse/lychee-action@v2
         with:
-          args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
+          args: --no-progress --insecure --user-agent 'curl/7.68.0' --exclude '.*api\.star-history\.com.*' --accept 200,201,202,203,204,205,206,207,208,226,300,301,302,303,304,305,306,307,308,503 README.md docs/ apps/ examples/ benchmarks/
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
```
.gitignore (vendored, 3 changed lines)

```diff
@@ -91,7 +91,8 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/

 *.meta.json
 *.passages.json
+*.npy
+*.db
 batchtest.py
 tests/__pytest_cache__/
 tests/__pycache__/
```
README.md (75 changed lines)

````diff
@@ -16,15 +16,27 @@
   </a>
 </p>

+<div align="center">
+<a href="https://forms.gle/rDbZf864gMNxhpTq8">
+<img src="https://img.shields.io/badge/📣_Community_Survey-Help_Shape_v0.4-007ec6?style=for-the-badge&logo=google-forms&logoColor=white" alt="Take Survey">
+</a>
+<p>
+We track <b>zero telemetry</b>. This survey is the ONLY way to tell us if you want <br>
+<b>GPU Acceleration</b> or <b>More Integrations</b> next.<br>
+👉 <a href="https://forms.gle/rDbZf864gMNxhpTq8"><b>Click here to cast your vote (2 mins)</b></a>
+</p>
+</div>
+
 <h2 align="center" tabindex="-1" class="heading-element" dir="auto">
 The smallest vector index in the world. RAG Everything with LEANN!
 </h2>

 LEANN is an innovative vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.

 LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)

-**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#mcp-integration-rag-on-live-data-from-any-platform), [Twitter](#mcp-integration-rag-on-live-data-from-any-platform)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
+**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#slack-messages-search-your-team-conversations), [Twitter](#-twitter-bookmarks-your-personal-tweet-library)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.


 \* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
@@ -189,7 +201,7 @@ LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`,

 #### LLM Backend

-LEANN supports many LLM providers for text generation (HuggingFace, Ollama, and Any OpenAI compatible API).
+LEANN supports many LLM providers for text generation (HuggingFace, Ollama, Anthropic, and Any OpenAI compatible API).


 <details>
@@ -257,6 +269,7 @@ Below is a list of base URLs for common providers to get you started.
 | **SiliconFlow** | `https://api.siliconflow.cn/v1` |
 | **Zhipu (BigModel)** | `https://open.bigmodel.cn/api/paas/v4/` |
 | **Mistral AI** | `https://api.mistral.ai/v1` |
+| **Anthropic** | `https://api.anthropic.com/v1` |


@@ -316,7 +329,7 @@ All RAG examples share these common parameters. **Interactive mode** is availabl
 --embedding-mode MODE # sentence-transformers, openai, mlx, or ollama

 # LLM Parameters (Text generation models)
---llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
+--llm TYPE # LLM backend: openai, ollama, hf, or anthropic (default: openai)
 --llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
 --thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)

@@ -379,6 +392,54 @@ python -m apps.code_rag --repo-dir "./my_codebase" --query "How does authenticat

 </details>

+### 🎨 ColQwen: Multimodal PDF Retrieval with Vision-Language Models
+
+Search through PDFs using both text and visual understanding with ColQwen2/ColPali models. Perfect for research papers, technical documents, and any PDFs with complex layouts, figures, or diagrams.
+
+> **🍎 Mac Users**: ColQwen is optimized for Apple Silicon with MPS acceleration for faster inference!
+
+```bash
+# Build index from PDFs
+python -m apps.colqwen_rag build --pdfs ./my_papers/ --index research_papers
+
+# Search with text queries
+python -m apps.colqwen_rag search research_papers "How does attention mechanism work?"
+
+# Interactive Q&A
+python -m apps.colqwen_rag ask research_papers --interactive
+```
+
+<details>
+<summary><strong>📋 Click to expand: ColQwen Setup & Usage</strong></summary>
+
+#### Prerequisites
+```bash
+# Install dependencies
+uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
+brew install poppler # macOS only, for PDF processing
+```
+
+#### Build Index
+```bash
+python -m apps.colqwen_rag build \
+    --pdfs ./pdf_directory/ \
+    --index my_index \
+    --model colqwen2 # or colpali
+```
+
+#### Search
+```bash
+python -m apps.colqwen_rag search my_index "your question here" --top-k 5
+```
+
+#### Models
+- **ColQwen2** (`colqwen2`): Latest vision-language model with improved performance
+- **ColPali** (`colpali`): Proven multimodal retriever
+
+For detailed usage, see the [ColQwen Guide](docs/COLQWEN_GUIDE.md).
+
+</details>
+
 ### 📧 Your Personal Email Secretary: RAG on Apple Mail!

 > **Note:** The examples below currently support macOS only. Windows support coming soon.
@@ -1045,10 +1106,10 @@ Options:
 leann ask INDEX_NAME [OPTIONS]

 Options:
-  --llm {ollama,openai,hf} LLM provider (default: ollama)
+  --llm {ollama,openai,hf,anthropic} LLM provider (default: ollama)
   --model MODEL Model name (default: qwen3:8b)
   --interactive Interactive chat mode
   --top-k N Retrieval count (default: 20)
 ```

 **List Command:**
````
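The README diff adds Anthropic both as an `--llm` choice and to the table of OpenAI-compatible base URLs. A hedged sketch of pointing a generic OpenAI-style client at that base URL; the model name and environment variable below are assumptions, not taken from the repo:

```python
import os

from openai import OpenAI  # generic OpenAI-compatible client

client = OpenAI(
    base_url="https://api.anthropic.com/v1",  # base URL from the provider table
    api_key=os.environ["ANTHROPIC_API_KEY"],  # assumed environment variable name
)

response = client.chat.completions.create(
    model="claude-3-5-haiku-latest",  # assumed model identifier
    messages=[{"role": "user", "content": "Summarize the retrieved passages."}],
)
print(response.choices[0].message.content)
```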
apps/base_rag_example.py

```diff
@@ -6,7 +6,7 @@ Provides common parameters and functionality for all RAG examples.
 import argparse
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any
+from typing import Any, Union

 import dotenv
 from leann.api import LeannBuilder, LeannChat
@@ -257,8 +257,8 @@ class BaseRAGExample(ABC):
         pass

     @abstractmethod
-    async def load_data(self, args) -> list[str]:
-        """Load data from the source. Returns list of text chunks."""
+    async def load_data(self, args) -> list[Union[str, dict[str, Any]]]:
+        """Load data from the source. Returns list of text chunks (strings or dicts with 'text' key)."""
         pass

     def get_llm_config(self, args) -> dict[str, Any]:
@@ -282,8 +282,8 @@ class BaseRAGExample(ABC):

         return config

-    async def build_index(self, args, texts: list[str]) -> str:
-        """Build LEANN index from texts."""
+    async def build_index(self, args, texts: list[Union[str, dict[str, Any]]]) -> str:
+        """Build LEANN index from texts (accepts strings or dicts with 'text' key)."""
         index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")

         print(f"\n[Building Index] Creating {self.name} index...")
@@ -314,8 +314,14 @@ class BaseRAGExample(ABC):
         batch_size = 1000
         for i in range(0, len(texts), batch_size):
             batch = texts[i : i + batch_size]
-            for text in batch:
-                builder.add_text(text)
+            for item in batch:
+                # Handle both dict format (from create_text_chunks) and plain strings
+                if isinstance(item, dict):
+                    text = item.get("text", "")
+                    metadata = item.get("metadata")
+                    builder.add_text(text, metadata)
+                else:
+                    builder.add_text(item)
             print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")

         print("Building index structure...")
```
|
|||||||
try:
|
try:
|
||||||
from leann.chunking_utils import (
|
from leann.chunking_utils import (
|
||||||
CODE_EXTENSIONS,
|
CODE_EXTENSIONS,
|
||||||
|
_traditional_chunks_as_dicts,
|
||||||
create_ast_chunks,
|
create_ast_chunks,
|
||||||
create_text_chunks,
|
create_text_chunks,
|
||||||
create_traditional_chunks,
|
create_traditional_chunks,
|
||||||
@@ -25,6 +26,7 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
|
|||||||
sys.path.insert(0, str(leann_src))
|
sys.path.insert(0, str(leann_src))
|
||||||
from leann.chunking_utils import (
|
from leann.chunking_utils import (
|
||||||
CODE_EXTENSIONS,
|
CODE_EXTENSIONS,
|
||||||
|
_traditional_chunks_as_dicts,
|
||||||
create_ast_chunks,
|
create_ast_chunks,
|
||||||
create_text_chunks,
|
create_text_chunks,
|
||||||
create_traditional_chunks,
|
create_traditional_chunks,
|
||||||
@@ -36,6 +38,7 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CODE_EXTENSIONS",
|
"CODE_EXTENSIONS",
|
||||||
|
"_traditional_chunks_as_dicts",
|
||||||
"create_ast_chunks",
|
"create_ast_chunks",
|
||||||
"create_text_chunks",
|
"create_text_chunks",
|
||||||
"create_traditional_chunks",
|
"create_traditional_chunks",
|
||||||
|
|||||||
apps/colqwen_rag.py (new file, 364 lines)

```python
#!/usr/bin/env python3
"""
ColQwen RAG - Easy-to-use multimodal PDF retrieval with ColQwen2/ColPali

Usage:
    python -m apps.colqwen_rag build --pdfs ./my_pdfs/ --index my_index
    python -m apps.colqwen_rag search my_index "How does attention work?"
    python -m apps.colqwen_rag ask my_index --interactive
"""

import argparse
import os
import sys
from pathlib import Path
from typing import Optional, cast

# Add LEANN packages to path
_repo_root = Path(__file__).resolve().parents[1]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
    sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
    sys.path.append(str(_leann_hnsw_pkg))

import torch  # noqa: E402
from colpali_engine import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor  # noqa: E402
from colpali_engine.utils.torch_utils import ListDataset  # noqa: E402
from pdf2image import convert_from_path  # noqa: E402
from PIL import Image  # noqa: E402
from torch.utils.data import DataLoader  # noqa: E402
from tqdm import tqdm  # noqa: E402

# Import the existing multi-vector implementation
sys.path.append(str(_repo_root / "apps" / "multimodal" / "vision-based-pdf-multi-vector"))
from leann_multi_vector import LeannMultiVector  # noqa: E402


class ColQwenRAG:
    """Easy-to-use ColQwen RAG system for multimodal PDF retrieval."""

    def __init__(self, model_type: str = "colpali"):
        """
        Initialize ColQwen RAG system.

        Args:
            model_type: "colqwen2" or "colpali"
        """
        self.model_type = model_type
        self.device = self._get_device()
        # Use float32 on MPS to avoid memory issues, float16 on CUDA, bfloat16 on CPU
        if self.device.type == "mps":
            self.dtype = torch.float32
        elif self.device.type == "cuda":
            self.dtype = torch.float16
        else:
            self.dtype = torch.bfloat16

        print(f"🚀 Initializing {model_type.upper()} on {self.device} with {self.dtype}")

        # Load model and processor with MPS-optimized settings
        try:
            if model_type == "colqwen2":
                self.model_name = "vidore/colqwen2-v1.0"
                if self.device.type == "mps":
                    # For MPS, load on CPU first then move to avoid memory allocation issues
                    self.model = ColQwen2.from_pretrained(
                        self.model_name,
                        torch_dtype=self.dtype,
                        device_map="cpu",
                        low_cpu_mem_usage=True,
                    ).eval()
                    self.model = self.model.to(self.device)
                else:
                    self.model = ColQwen2.from_pretrained(
                        self.model_name,
                        torch_dtype=self.dtype,
                        device_map=self.device,
                        low_cpu_mem_usage=True,
                    ).eval()
                self.processor = ColQwen2Processor.from_pretrained(self.model_name)
            else:  # colpali
                self.model_name = "vidore/colpali-v1.2"
                if self.device.type == "mps":
                    # For MPS, load on CPU first then move to avoid memory allocation issues
                    self.model = ColPali.from_pretrained(
                        self.model_name,
                        torch_dtype=self.dtype,
                        device_map="cpu",
                        low_cpu_mem_usage=True,
                    ).eval()
                    self.model = self.model.to(self.device)
                else:
                    self.model = ColPali.from_pretrained(
                        self.model_name,
                        torch_dtype=self.dtype,
                        device_map=self.device,
                        low_cpu_mem_usage=True,
                    ).eval()
                self.processor = ColPaliProcessor.from_pretrained(self.model_name)
        except Exception as e:
            if "memory" in str(e).lower() or "offload" in str(e).lower():
                print(f"⚠️ Memory constraint on {self.device}, using CPU with optimizations...")
                self.device = torch.device("cpu")
                self.dtype = torch.float32

                if model_type == "colqwen2":
                    self.model = ColQwen2.from_pretrained(
                        self.model_name,
                        torch_dtype=self.dtype,
                        device_map="cpu",
                        low_cpu_mem_usage=True,
                    ).eval()
                else:
                    self.model = ColPali.from_pretrained(
                        self.model_name,
                        torch_dtype=self.dtype,
                        device_map="cpu",
                        low_cpu_mem_usage=True,
                    ).eval()
            else:
                raise

    def _get_device(self):
        """Auto-select best available device."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        else:
            return torch.device("cpu")

    def build_index(self, pdf_paths: list[str], index_name: str, pages_dir: Optional[str] = None):
        """
        Build multimodal index from PDF files.

        Args:
            pdf_paths: List of PDF file paths
            index_name: Name for the index
            pages_dir: Directory to save page images (optional)
        """
        print(f"Building index '{index_name}' from {len(pdf_paths)} PDFs...")

        # Convert PDFs to images
        all_images = []
        all_metadata = []

        if pages_dir:
            os.makedirs(pages_dir, exist_ok=True)

        for pdf_path in tqdm(pdf_paths, desc="Converting PDFs"):
            try:
                images = convert_from_path(pdf_path, dpi=150)
                pdf_name = Path(pdf_path).stem

                for i, image in enumerate(images):
                    # Save image if pages_dir specified
                    if pages_dir:
                        image_path = Path(pages_dir) / f"{pdf_name}_page_{i + 1}.png"
                        image.save(image_path)

                    all_images.append(image)
                    all_metadata.append(
                        {
                            "pdf_path": pdf_path,
                            "pdf_name": pdf_name,
                            "page_number": i + 1,
                            "image_path": str(image_path) if pages_dir else None,
                        }
                    )

            except Exception as e:
                print(f"❌ Error processing {pdf_path}: {e}")
                continue

        print(f"📄 Converted {len(all_images)} pages from {len(pdf_paths)} PDFs")
        print(f"All metadata: {all_metadata}")

        # Generate embeddings
        print("🧠 Generating embeddings...")
        embeddings = self._embed_images(all_images)

        # Build LEANN index
        print("🔍 Building LEANN index...")
        leann_mv = LeannMultiVector(
            index_path=index_name,
            dim=embeddings.shape[-1],
            embedding_model_name=self.model_type,
        )

        # Create collection and insert data
        leann_mv.create_collection()
        for i, (embedding, metadata) in enumerate(zip(embeddings, all_metadata)):
            data = {
                "doc_id": i,
                "filepath": metadata.get("image_path", ""),
                "colbert_vecs": embedding.numpy(),  # Convert tensor to numpy
            }
            leann_mv.insert(data)

        # Build the index
        leann_mv.create_index()
        print(f"✅ Index '{index_name}' built successfully!")

        return leann_mv

    def search(self, index_name: str, query: str, top_k: int = 5):
        """
        Search the index with a text query.

        Args:
            index_name: Name of the index to search
            query: Text query
            top_k: Number of results to return
        """
        print(f"🔍 Searching '{index_name}' for: '{query}'")

        # Load index
        leann_mv = LeannMultiVector(
            index_path=index_name,
            dim=128,  # Will be updated when loading
            embedding_model_name=self.model_type,
        )

        # Generate query embedding
        query_embedding = self._embed_query(query)

        # Search (returns list of (score, doc_id) tuples)
        search_results = leann_mv.search(query_embedding.numpy(), topk=top_k)

        # Display results
        print(f"\n📋 Top {len(search_results)} results:")
        for i, (score, doc_id) in enumerate(search_results, 1):
            # Get metadata for this doc_id (we need to load the metadata)
            print(f"{i}. Score: {score:.3f} | Doc ID: {doc_id}")

        return search_results

    def ask(self, index_name: str, interactive: bool = False):
        """
        Interactive Q&A with the indexed documents.

        Args:
            index_name: Name of the index to query
            interactive: Whether to run in interactive mode
        """
        print(f"💬 ColQwen Chat with '{index_name}'")

        if interactive:
            print("Type 'quit' to exit, 'help' for commands")
            while True:
                try:
                    query = input("\n🤔 Your question: ").strip()
                    if query.lower() in ["quit", "exit", "q"]:
                        break
                    elif query.lower() == "help":
                        print("Commands: quit/exit/q (exit), help (this message)")
                        continue
                    elif not query:
                        continue

                    self.search(index_name, query, top_k=3)

                    # TODO: Add answer generation with Qwen-VL
                    print("\n💡 For detailed answers, we can integrate Qwen-VL here!")

                except KeyboardInterrupt:
                    print("\n👋 Goodbye!")
                    break
        else:
            query = input("🤔 Your question: ").strip()
            if query:
                self.search(index_name, query)

    def _embed_images(self, images: list[Image.Image]) -> torch.Tensor:
        """Generate embeddings for a list of images."""
        dataset = ListDataset(images)
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x)

        embeddings = []
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Embedding images"):
                batch_images = cast(list, batch)
                batch_inputs = self.processor.process_images(batch_images).to(self.device)
                batch_embeddings = self.model(**batch_inputs)
                embeddings.append(batch_embeddings.cpu())

        return torch.cat(embeddings, dim=0)

    def _embed_query(self, query: str) -> torch.Tensor:
        """Generate embedding for a text query."""
        with torch.no_grad():
            query_inputs = self.processor.process_queries([query]).to(self.device)
            query_embedding = self.model(**query_inputs)
        return query_embedding.cpu()


def main():
    parser = argparse.ArgumentParser(description="ColQwen RAG - Easy multimodal PDF retrieval")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Build command
    build_parser = subparsers.add_parser("build", help="Build index from PDFs")
    build_parser.add_argument("--pdfs", required=True, help="Directory containing PDF files")
    build_parser.add_argument("--index", required=True, help="Index name")
    build_parser.add_argument(
        "--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
    )
    build_parser.add_argument("--pages-dir", help="Directory to save page images")

    # Search command
    search_parser = subparsers.add_parser("search", help="Search the index")
    search_parser.add_argument("index", help="Index name")
    search_parser.add_argument("query", help="Search query")
    search_parser.add_argument("--top-k", type=int, default=5, help="Number of results")
    search_parser.add_argument(
        "--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
    )

    # Ask command
    ask_parser = subparsers.add_parser("ask", help="Interactive Q&A")
    ask_parser.add_argument("index", help="Index name")
    ask_parser.add_argument("--interactive", action="store_true", help="Interactive mode")
    ask_parser.add_argument(
        "--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
    )

    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        return

    # Initialize ColQwen RAG
    if args.command == "build":
        colqwen = ColQwenRAG(args.model)

        # Get PDF files
        pdf_dir = Path(args.pdfs)
        if pdf_dir.is_file() and pdf_dir.suffix.lower() == ".pdf":
            pdf_paths = [str(pdf_dir)]
        elif pdf_dir.is_dir():
            pdf_paths = [str(p) for p in pdf_dir.glob("*.pdf")]
        else:
            print(f"❌ Invalid PDF path: {args.pdfs}")
            return

        if not pdf_paths:
            print(f"❌ No PDF files found in {args.pdfs}")
            return

        colqwen.build_index(pdf_paths, args.index, args.pages_dir)

    elif args.command == "search":
        colqwen = ColQwenRAG(args.model)
        colqwen.search(args.index, args.query, args.top_k)

    elif args.command == "ask":
        colqwen = ColQwenRAG(args.model)
        colqwen.ask(args.index, args.interactive)


if __name__ == "__main__":
    main()
```
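Besides the `python -m apps.colqwen_rag ...` CLI shown in the README, the class above can be driven directly. A short sketch (the PDF paths and index name are placeholders):

```python
from apps.colqwen_rag import ColQwenRAG

rag = ColQwenRAG(model_type="colqwen2")

# Index a couple of PDFs, optionally saving the rendered page images.
rag.build_index(
    pdf_paths=["./papers/attention.pdf", "./papers/leann.pdf"],
    index_name="research_papers",
    pages_dir="./research_papers_pages",
)

# Retrieve the top pages for a text query; results are (score, doc_id) tuples.
results = rag.search("research_papers", "How does the attention mechanism work?", top_k=3)
```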
```diff
@@ -5,6 +5,7 @@ Supports PDF, TXT, MD, and other document formats.

 import sys
 from pathlib import Path
+from typing import Any, Union

 # Add parent directory to path for imports
 sys.path.insert(0, str(Path(__file__).parent))
@@ -51,7 +52,7 @@ class DocumentRAG(BaseRAGExample):
             help="Enable AST-aware chunking for code files in the data directory",
         )

-    async def load_data(self, args) -> list[str]:
+    async def load_data(self, args) -> list[Union[str, dict[str, Any]]]:
         """Load documents and convert to text chunks."""
         print(f"Loading documents from: {args.data_dir}")
         if args.file_types:
```
apps/image_rag.py (new file, 218 lines)

```python
#!/usr/bin/env python3
"""
CLIP Image RAG Application

This application enables RAG (Retrieval-Augmented Generation) on images using CLIP embeddings.
You can index a directory of images and search them using text queries.

Usage:
    python -m apps.image_rag --image-dir ./my_images/ --query "a sunset over mountains"
    python -m apps.image_rag --image-dir ./my_images/ --interactive
"""

import argparse
import pickle
import tempfile
from pathlib import Path

import numpy as np
from PIL import Image
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from apps.base_rag_example import BaseRAGExample


class ImageRAG(BaseRAGExample):
    """
    RAG application for images using CLIP embeddings.

    This class provides a complete RAG pipeline for image data, including
    CLIP embedding generation, indexing, and text-based image search.
    """

    def __init__(self):
        super().__init__(
            name="Image RAG",
            description="RAG application for images using CLIP embeddings",
            default_index_name="image_index",
        )
        # Override default embedding model to use CLIP
        self.embedding_model_default = "clip-ViT-L-14"
        self.embedding_mode_default = "sentence-transformers"
        self._image_data: list[dict] = []

    def _add_specific_arguments(self, parser: argparse.ArgumentParser):
        """Add image-specific arguments."""
        image_group = parser.add_argument_group("Image Parameters")
        image_group.add_argument(
            "--image-dir",
            type=str,
            required=True,
            help="Directory containing images to index",
        )
        image_group.add_argument(
            "--image-extensions",
            type=str,
            nargs="+",
            default=[".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"],
            help="Image file extensions to process (default: .jpg .jpeg .png .gif .bmp .webp)",
        )
        image_group.add_argument(
            "--batch-size",
            type=int,
            default=32,
            help="Batch size for CLIP embedding generation (default: 32)",
        )

    async def load_data(self, args) -> list[str]:
        """Load images, generate CLIP embeddings, and return text descriptions."""
        self._image_data = self._load_images_and_embeddings(args)
        return [entry["text"] for entry in self._image_data]

    def _load_images_and_embeddings(self, args) -> list[dict]:
        """Helper to process images and produce embeddings/metadata."""
        image_dir = Path(args.image_dir)
        if not image_dir.exists():
            raise ValueError(f"Image directory does not exist: {image_dir}")

        print(f"📸 Loading images from {image_dir}...")

        # Find all image files
        image_files = []
        for ext in args.image_extensions:
            image_files.extend(image_dir.rglob(f"*{ext}"))
            image_files.extend(image_dir.rglob(f"*{ext.upper()}"))

        if not image_files:
            raise ValueError(
                f"No images found in {image_dir} with extensions {args.image_extensions}"
            )

        print(f"✅ Found {len(image_files)} images")

        # Limit if max_items is set
        if args.max_items > 0:
            image_files = image_files[: args.max_items]
            print(f"📊 Processing {len(image_files)} images (limited by --max-items)")

        # Load CLIP model
        print("🔍 Loading CLIP model...")
        model = SentenceTransformer(self.embedding_model_default)

        # Process images and generate embeddings
        print("🖼️ Processing images and generating embeddings...")
        image_data = []
        batch_images = []
        batch_paths = []

        for image_path in tqdm(image_files, desc="Processing images"):
            try:
                image = Image.open(image_path).convert("RGB")
                batch_images.append(image)
                batch_paths.append(image_path)

                # Process in batches
                if len(batch_images) >= args.batch_size:
                    embeddings = model.encode(
                        batch_images,
                        convert_to_numpy=True,
                        normalize_embeddings=True,
                        batch_size=args.batch_size,
                        show_progress_bar=False,
                    )

                    for img_path, embedding in zip(batch_paths, embeddings):
                        image_data.append(
                            {
                                "text": f"Image: {img_path.name}\nPath: {img_path}",
                                "metadata": {
                                    "image_path": str(img_path),
                                    "image_name": img_path.name,
                                    "image_dir": str(image_dir),
                                },
                                "embedding": embedding.astype(np.float32),
                            }
                        )

                    batch_images = []
                    batch_paths = []

            except Exception as e:
                print(f"⚠️ Failed to process {image_path}: {e}")
                continue

        # Process remaining images
        if batch_images:
            embeddings = model.encode(
                batch_images,
                convert_to_numpy=True,
                normalize_embeddings=True,
                batch_size=len(batch_images),
                show_progress_bar=False,
            )

            for img_path, embedding in zip(batch_paths, embeddings):
                image_data.append(
                    {
                        "text": f"Image: {img_path.name}\nPath: {img_path}",
                        "metadata": {
                            "image_path": str(img_path),
                            "image_name": img_path.name,
                            "image_dir": str(image_dir),
                        },
                        "embedding": embedding.astype(np.float32),
                    }
                )

        print(f"✅ Processed {len(image_data)} images")
        return image_data

    async def build_index(self, args, texts: list[str]) -> str:
        """Build index using pre-computed CLIP embeddings."""
        from leann.api import LeannBuilder

        if not self._image_data or len(self._image_data) != len(texts):
            raise RuntimeError("No image data found. Make sure load_data() ran successfully.")

        print("🔨 Building LEANN index with CLIP embeddings...")
        builder = LeannBuilder(
            backend_name=args.backend_name,
            embedding_model=self.embedding_model_default,
            embedding_mode=self.embedding_mode_default,
            is_recompute=False,
            distance_metric="cosine",
            graph_degree=args.graph_degree,
            build_complexity=args.build_complexity,
            is_compact=not args.no_compact,
        )

        for text, data in zip(texts, self._image_data):
            builder.add_text(text=text, metadata=data["metadata"])

        ids = [str(i) for i in range(len(self._image_data))]
        embeddings = np.array([data["embedding"] for data in self._image_data], dtype=np.float32)

        with tempfile.NamedTemporaryFile(mode="wb", suffix=".pkl", delete=False) as f:
            pickle.dump((ids, embeddings), f)
            pkl_path = f.name

        try:
            index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
            builder.build_index_from_embeddings(index_path, pkl_path)
            print(f"✅ Index built successfully at {index_path}")
            return index_path
        finally:
            Path(pkl_path).unlink()


def main():
    """Main entry point for the image RAG application."""
    import asyncio

    app = ImageRAG()
    asyncio.run(app.run())


if __name__ == "__main__":
    main()
```
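`ImageRAG` stores the images' normalized `clip-ViT-L-14` embeddings with `is_recompute=False`, so a text query only needs the matching CLIP text embedding in the same space. A hedged sketch of that query side; the saved-embeddings path is a placeholder:

```python
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("clip-ViT-L-14")

# CLIP text and image embeddings live in the same space; both sides are L2-normalized
# here, so a dot product is cosine similarity.
query_vec = model.encode(
    ["a sunset over mountains"], convert_to_numpy=True, normalize_embeddings=True
)[0]

image_vecs = np.load("image_embeddings.npy")  # (N, d) float32 array, placeholder path
scores = image_vecs @ query_vec
top5 = np.argsort(-scores)[:5]
print(list(zip(top5.tolist(), scores[top5].round(3).tolist())))
```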
apps/multimodal/vision-based-pdf-multi-vector/colqwen_forward.py (new executable file, 132 lines)

```python
#!/usr/bin/env python3
"""Simple test script to test colqwen2 forward pass with a single image."""

import os
import sys
from pathlib import Path

# Add the current directory to path to import leann_multi_vector
sys.path.insert(0, str(Path(__file__).parent))

import torch
from leann_multi_vector import _embed_images, _ensure_repo_paths_importable, _load_colvision
from PIL import Image

# Ensure repo paths are importable
_ensure_repo_paths_importable(__file__)

# Set environment variable
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def create_test_image():
    """Create a simple test image."""
    # Create a simple RGB image (800x600)
    img = Image.new("RGB", (800, 600), color="white")
    return img


def load_test_image_from_file():
    """Try to load an image from the indexes directory if available."""
    # Try to find an existing image in the indexes directory
    indexes_dir = Path(__file__).parent / "indexes"

    # Look for images in common locations
    possible_paths = [
        indexes_dir / "vidore_fastplaid" / "images",
        indexes_dir / "colvision_large.leann.images",
        indexes_dir / "colvision.leann.images",
    ]

    for img_dir in possible_paths:
        if img_dir.exists():
            # Find first image file
            for ext in [".png", ".jpg", ".jpeg"]:
                for img_file in img_dir.glob(f"*{ext}"):
                    print(f"Loading test image from: {img_file}")
                    return Image.open(img_file)

    return None


def main():
    print("=" * 60)
    print("Testing ColQwen2 Forward Pass")
    print("=" * 60)

    # Step 1: Load or create test image
    print("\n[Step 1] Loading test image...")
    test_image = load_test_image_from_file()
    if test_image is None:
        print("No existing image found, creating a simple test image...")
        test_image = create_test_image()
    else:
        print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")

    # Convert to RGB if needed
    if test_image.mode != "RGB":
        test_image = test_image.convert("RGB")
        print(f"✓ Converted to RGB: {test_image.size}")

    # Step 2: Load model
    print("\n[Step 2] Loading ColQwen2 model...")
    try:
        model_name, model, processor, device_str, device, dtype = _load_colvision("colqwen2")
        print(f"✓ Model loaded: {model_name}")
        print(f"✓ Device: {device_str}, dtype: {dtype}")

        # Print model info
        if hasattr(model, "device"):
            print(f"✓ Model device: {model.device}")
        if hasattr(model, "dtype"):
            print(f"✓ Model dtype: {model.dtype}")

    except Exception as e:
        print(f"✗ Error loading model: {e}")
        import traceback

        traceback.print_exc()
        return

    # Step 3: Test forward pass
    print("\n[Step 3] Running forward pass...")
    try:
        # Use the _embed_images function which handles batching and forward pass
        images = [test_image]
        print(f"Processing {len(images)} image(s)...")

        doc_vecs = _embed_images(model, processor, images)

        print("✓ Forward pass completed!")
        print(f"✓ Number of embeddings: {len(doc_vecs)}")

        if len(doc_vecs) > 0:
            emb = doc_vecs[0]
            print(f"✓ Embedding shape: {emb.shape}")
            print(f"✓ Embedding dtype: {emb.dtype}")
            print("✓ Embedding stats:")
            print(f"  - Min: {emb.min().item():.4f}")
            print(f"  - Max: {emb.max().item():.4f}")
            print(f"  - Mean: {emb.mean().item():.4f}")
            print(f"  - Std: {emb.std().item():.4f}")

            # Check for NaN or Inf
            if torch.isnan(emb).any():
                print("⚠ Warning: Embedding contains NaN values!")
            if torch.isinf(emb).any():
                print("⚠ Warning: Embedding contains Inf values!")

    except Exception as e:
        print(f"✗ Error during forward pass: {e}")
        import traceback

        traceback.print_exc()
        return

    print("\n" + "=" * 60)
    print("Test completed successfully!")
    print("=" * 60)


if __name__ == "__main__":
    main()
```
File diff suppressed because it is too large.
vidore_v1_benchmark.py (new file, 448 lines)

```python
#!/usr/bin/env python3
"""
Modular script to reproduce NDCG results for ViDoRe v1 benchmark.

This script uses the interface from leann_multi_vector.py to:
1. Download ViDoRe v1 datasets
2. Build indexes (LEANN or Fast-Plaid)
3. Perform retrieval
4. Evaluate using NDCG metrics

Usage:
    # Evaluate all ViDoRe v1 tasks
    python vidore_v1_benchmark.py --model colqwen2 --tasks all

    # Evaluate specific task
    python vidore_v1_benchmark.py --model colqwen2 --task VidoreArxivQARetrieval

    # Use Fast-Plaid index
    python vidore_v1_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid

    # Rebuild index
    python vidore_v1_benchmark.py --model colqwen2 --rebuild-index
"""

import argparse
import json
import os
from typing import Optional

from datasets import load_dataset
from leann_multi_vector import (
    ViDoReBenchmarkEvaluator,
    _ensure_repo_paths_importable,
)

_ensure_repo_paths_importable(__file__)

# ViDoRe v1 task configurations
# Prompts match MTEB task metadata prompts
VIDORE_V1_TASKS = {
    "VidoreArxivQARetrieval": {
        "dataset_path": "vidore/arxivqa_test_subsampled_beir",
        "revision": "7d94d570960eac2408d3baa7a33f9de4822ae3e4",
        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
    },
    "VidoreDocVQARetrieval": {
        "dataset_path": "vidore/docvqa_test_subsampled_beir",
        "revision": "162ba2fc1a8437eda8b6c37b240bc1c0f0deb092",
        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
    },
    "VidoreInfoVQARetrieval": {
        "dataset_path": "vidore/infovqa_test_subsampled_beir",
        "revision": "b802cc5fd6c605df2d673a963667d74881d2c9a4",
        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
    },
    "VidoreTabfquadRetrieval": {
        "dataset_path": "vidore/tabfquad_test_subsampled_beir",
        "revision": "61a2224bcd29b7b261a4892ff4c8bea353527a31",
        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
    },
    "VidoreTatdqaRetrieval": {
        "dataset_path": "vidore/tatdqa_test_beir",
        "revision": "5feb5630fdff4d8d189ffedb2dba56862fdd45c0",
        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
    },
    "VidoreShiftProjectRetrieval": {
        "dataset_path": "vidore/shiftproject_test_beir",
        "revision": "84a382e05c4473fed9cff2bbae95fe2379416117",
        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
    },
    "VidoreSyntheticDocQAAIRetrieval": {
        "dataset_path": "vidore/syntheticDocQA_artificial_intelligence_test_beir",
        "revision": "2d9ebea5a1c6e9ef4a3b902a612f605dca11261c",
        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
    },
    "VidoreSyntheticDocQAEnergyRetrieval": {
        "dataset_path": "vidore/syntheticDocQA_energy_test_beir",
        "revision": "9935aadbad5c8deec30910489db1b2c7133ae7a7",
        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
    },
    "VidoreSyntheticDocQAGovernmentReportsRetrieval": {
        "dataset_path": "vidore/syntheticDocQA_government_reports_test_beir",
        "revision": "b4909afa930f81282fd20601e860668073ad02aa",
        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
    },
    "VidoreSyntheticDocQAHealthcareIndustryRetrieval": {
        "dataset_path": "vidore/syntheticDocQA_healthcare_industry_test_beir",
        "revision": "f9e25d5b6e13e1ad9f5c3cce202565031b3ab164",
        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
    },
}

# Task name aliases (short names -> full names)
TASK_ALIASES = {
    "arxivqa": "VidoreArxivQARetrieval",
    "docvqa": "VidoreDocVQARetrieval",
    "infovqa": "VidoreInfoVQARetrieval",
    "tabfquad": "VidoreTabfquadRetrieval",
    "tatdqa": "VidoreTatdqaRetrieval",
    "shiftproject": "VidoreShiftProjectRetrieval",
    "syntheticdocqa_ai": "VidoreSyntheticDocQAAIRetrieval",
    "syntheticdocqa_energy": "VidoreSyntheticDocQAEnergyRetrieval",
    "syntheticdocqa_government": "VidoreSyntheticDocQAGovernmentReportsRetrieval",
    "syntheticdocqa_healthcare": "VidoreSyntheticDocQAHealthcareIndustryRetrieval",
}


def normalize_task_name(task_name: str) -> str:
    """Normalize task name (handle aliases)."""
    task_name_lower = task_name.lower()
    if task_name in VIDORE_V1_TASKS:
        return task_name
    if task_name_lower in TASK_ALIASES:
        return TASK_ALIASES[task_name_lower]
    # Try partial match
    for alias, full_name in TASK_ALIASES.items():
        if alias in task_name_lower or task_name_lower in alias:
            return full_name
    return task_name


def get_safe_model_name(model_name: str) -> str:
    """Get a safe model name for use in file paths."""
    import hashlib
    import os

    # If it's a path, use basename or hash
    if os.path.exists(model_name) and os.path.isdir(model_name):
        # Use basename if it's reasonable, otherwise use hash
        basename = os.path.basename(model_name.rstrip("/"))
        if basename and len(basename) < 100 and not basename.startswith("."):
            return basename
        # Use hash for very long or problematic paths
        return hashlib.md5(model_name.encode()).hexdigest()[:16]
    # For HuggingFace model names, replace / with _
    return model_name.replace("/", "_").replace(":", "_")


def load_vidore_v1_data(
    dataset_path: str,
    revision: Optional[str] = None,
    split: str = "test",
):
    """
    Load ViDoRe v1 dataset.

    Returns:
```
|
||||||
|
corpus: dict mapping corpus_id to PIL Image
|
||||||
|
queries: dict mapping query_id to query text
|
||||||
|
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
|
||||||
|
"""
|
||||||
|
print(f"Loading dataset: {dataset_path} (split={split})")
|
||||||
|
|
||||||
|
# Load queries
|
||||||
|
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
|
||||||
|
|
||||||
|
queries = {}
|
||||||
|
for row in query_ds:
|
||||||
|
query_id = f"query-{split}-{row['query-id']}"
|
||||||
|
queries[query_id] = row["query"]
|
||||||
|
|
||||||
|
# Load corpus (images)
|
||||||
|
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
|
||||||
|
|
||||||
|
corpus = {}
|
||||||
|
for row in corpus_ds:
|
||||||
|
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||||
|
# Extract image from the dataset row
|
||||||
|
if "image" in row:
|
||||||
|
corpus[corpus_id] = row["image"]
|
||||||
|
elif "page_image" in row:
|
||||||
|
corpus[corpus_id] = row["page_image"]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"No image field found in corpus. Available fields: {list(row.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load qrels (relevance judgments)
|
||||||
|
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
|
||||||
|
|
||||||
|
qrels = {}
|
||||||
|
for row in qrels_ds:
|
||||||
|
query_id = f"query-{split}-{row['query-id']}"
|
||||||
|
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||||
|
if query_id not in qrels:
|
||||||
|
qrels[query_id] = {}
|
||||||
|
qrels[query_id][corpus_id] = int(row["score"])
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter qrels to only include queries that exist
|
||||||
|
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
|
||||||
|
|
||||||
|
# Filter out queries without any relevant documents (matching MTEB behavior)
|
||||||
|
# This is important for correct NDCG calculation
|
||||||
|
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
|
||||||
|
queries_filtered = {
|
||||||
|
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
|
||||||
|
)
|
||||||
|
|
||||||
|
return corpus, queries_filtered, qrels_filtered
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_task(
|
||||||
|
task_name: str,
|
||||||
|
model_name: str,
|
||||||
|
index_path: str,
|
||||||
|
use_fast_plaid: bool = False,
|
||||||
|
fast_plaid_index_path: Optional[str] = None,
|
||||||
|
rebuild_index: bool = False,
|
||||||
|
top_k: int = 1000,
|
||||||
|
first_stage_k: int = 500,
|
||||||
|
k_values: Optional[list[int]] = None,
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Evaluate a single ViDoRe v1 task.
|
||||||
|
"""
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print(f"Evaluating task: {task_name}")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
|
||||||
|
# Normalize task name (handle aliases)
|
||||||
|
task_name = normalize_task_name(task_name)
|
||||||
|
|
||||||
|
# Get task config
|
||||||
|
if task_name not in VIDORE_V1_TASKS:
|
||||||
|
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
|
||||||
|
|
||||||
|
task_config = VIDORE_V1_TASKS[task_name]
|
||||||
|
dataset_path = task_config["dataset_path"]
|
||||||
|
revision = task_config["revision"]
|
||||||
|
|
||||||
|
# Load data
|
||||||
|
corpus, queries, qrels = load_vidore_v1_data(
|
||||||
|
dataset_path=dataset_path,
|
||||||
|
revision=revision,
|
||||||
|
split="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize k_values if not provided
|
||||||
|
if k_values is None:
|
||||||
|
k_values = [1, 3, 5, 10, 20, 100, 1000]
|
||||||
|
|
||||||
|
# Check if we have any queries
|
||||||
|
if len(queries) == 0:
|
||||||
|
print(f"\nWarning: No queries found for task {task_name}. Skipping evaluation.")
|
||||||
|
# Return zero scores
|
||||||
|
scores = {}
|
||||||
|
for k in k_values:
|
||||||
|
scores[f"ndcg_at_{k}"] = 0.0
|
||||||
|
scores[f"map_at_{k}"] = 0.0
|
||||||
|
scores[f"recall_at_{k}"] = 0.0
|
||||||
|
scores[f"precision_at_{k}"] = 0.0
|
||||||
|
scores[f"mrr_at_{k}"] = 0.0
|
||||||
|
return scores
|
||||||
|
|
||||||
|
# Initialize evaluator
|
||||||
|
evaluator = ViDoReBenchmarkEvaluator(
|
||||||
|
model_name=model_name,
|
||||||
|
use_fast_plaid=use_fast_plaid,
|
||||||
|
top_k=top_k,
|
||||||
|
first_stage_k=first_stage_k,
|
||||||
|
k_values=k_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build or load index
|
||||||
|
# Use safe model name for index path (different models need different indexes)
|
||||||
|
safe_model_name = get_safe_model_name(model_name)
|
||||||
|
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
||||||
|
if index_path_full is None:
|
||||||
|
index_path_full = f"./indexes/{task_name}_{safe_model_name}"
|
||||||
|
if use_fast_plaid:
|
||||||
|
index_path_full = f"./indexes/{task_name}_{safe_model_name}_fastplaid"
|
||||||
|
|
||||||
|
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||||
|
corpus=corpus,
|
||||||
|
index_path=index_path_full,
|
||||||
|
rebuild=rebuild_index,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search queries
|
||||||
|
task_prompt = task_config.get("prompt")
|
||||||
|
results = evaluator.search_queries(
|
||||||
|
queries=queries,
|
||||||
|
corpus_ids=corpus_ids_ordered,
|
||||||
|
index_or_retriever=index_or_retriever,
|
||||||
|
fast_plaid_index_path=fast_plaid_index_path,
|
||||||
|
task_prompt=task_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Evaluate
|
||||||
|
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print(f"Results for {task_name}:")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
for metric, value in scores.items():
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
print(f" {metric}: {value:.5f}")
|
||||||
|
|
||||||
|
# Save results
|
||||||
|
if output_dir:
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
results_file = os.path.join(output_dir, f"{task_name}_results.json")
|
||||||
|
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
|
||||||
|
|
||||||
|
with open(results_file, "w") as f:
|
||||||
|
json.dump(results, f, indent=2)
|
||||||
|
print(f"\nSaved results to: {results_file}")
|
||||||
|
|
||||||
|
with open(scores_file, "w") as f:
|
||||||
|
json.dump(scores, f, indent=2)
|
||||||
|
print(f"Saved scores to: {scores_file}")
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Evaluate ViDoRe v1 benchmark using LEANN/Fast-Plaid indexing"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
default="colqwen2",
|
||||||
|
help="Model to use: 'colqwen2', 'colpali', or path to a model directory (supports LoRA adapters)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--task",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Specific task to evaluate (or 'all' for all tasks)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tasks",
|
||||||
|
type=str,
|
||||||
|
default="all",
|
||||||
|
help="Tasks to evaluate: 'all' or comma-separated list",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to LEANN index (auto-generated if not provided)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-fast-plaid",
|
||||||
|
action="store_true",
|
||||||
|
help="Use Fast-Plaid instead of LEANN",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fast-plaid-index-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to Fast-Plaid index (auto-generated if not provided)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rebuild-index",
|
||||||
|
action="store_true",
|
||||||
|
help="Rebuild index even if it exists",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-k",
|
||||||
|
type=int,
|
||||||
|
default=1000,
|
||||||
|
help="Top-k results to retrieve (MTEB default is max(k_values)=1000)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--first-stage-k",
|
||||||
|
type=int,
|
||||||
|
default=500,
|
||||||
|
help="First stage k for LEANN search",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--k-values",
|
||||||
|
type=str,
|
||||||
|
default="1,3,5,10,20,100,1000",
|
||||||
|
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
type=str,
|
||||||
|
default="./vidore_v1_results",
|
||||||
|
help="Output directory for results",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Parse k_values
|
||||||
|
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
||||||
|
|
||||||
|
# Determine tasks to evaluate
|
||||||
|
if args.task:
|
||||||
|
tasks_to_eval = [normalize_task_name(args.task)]
|
||||||
|
elif args.tasks.lower() == "all":
|
||||||
|
tasks_to_eval = list(VIDORE_V1_TASKS.keys())
|
||||||
|
else:
|
||||||
|
tasks_to_eval = [normalize_task_name(t.strip()) for t in args.tasks.split(",")]
|
||||||
|
|
||||||
|
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||||
|
|
||||||
|
# Evaluate each task
|
||||||
|
all_scores = {}
|
||||||
|
for task_name in tasks_to_eval:
|
||||||
|
try:
|
||||||
|
scores = evaluate_task(
|
||||||
|
task_name=task_name,
|
||||||
|
model_name=args.model,
|
||||||
|
index_path=args.index_path,
|
||||||
|
use_fast_plaid=args.use_fast_plaid,
|
||||||
|
fast_plaid_index_path=args.fast_plaid_index_path,
|
||||||
|
rebuild_index=args.rebuild_index,
|
||||||
|
top_k=args.top_k,
|
||||||
|
first_stage_k=args.first_stage_k,
|
||||||
|
k_values=k_values,
|
||||||
|
output_dir=args.output_dir,
|
||||||
|
)
|
||||||
|
all_scores[task_name] = scores
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\nError evaluating {task_name}: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Print summary
|
||||||
|
if all_scores:
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print("SUMMARY")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
for task_name, scores in all_scores.items():
|
||||||
|
print(f"\n{task_name}:")
|
||||||
|
# Print main metrics
|
||||||
|
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
|
||||||
|
if metric in scores:
|
||||||
|
print(f" {metric}: {scores[metric]:.5f}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,439 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Modular script to reproduce NDCG results for ViDoRe v2 benchmark.
|
||||||
|
|
||||||
|
This script uses the interface from leann_multi_vector.py to:
|
||||||
|
1. Download ViDoRe v2 datasets
|
||||||
|
2. Build indexes (LEANN or Fast-Plaid)
|
||||||
|
3. Perform retrieval
|
||||||
|
4. Evaluate using NDCG metrics
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Evaluate all ViDoRe v2 tasks
|
||||||
|
python vidore_v2_benchmark.py --model colqwen2 --tasks all
|
||||||
|
|
||||||
|
# Evaluate specific task
|
||||||
|
python vidore_v2_benchmark.py --model colqwen2 --task Vidore2ESGReportsRetrieval
|
||||||
|
|
||||||
|
# Use Fast-Plaid index
|
||||||
|
python vidore_v2_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
|
||||||
|
|
||||||
|
# Rebuild index
|
||||||
|
python vidore_v2_benchmark.py --model colqwen2 --rebuild-index
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
from leann_multi_vector import (
|
||||||
|
ViDoReBenchmarkEvaluator,
|
||||||
|
_ensure_repo_paths_importable,
|
||||||
|
)
|
||||||
|
|
||||||
|
_ensure_repo_paths_importable(__file__)
|
||||||
|
|
||||||
|
# Language name to dataset language field value mapping
|
||||||
|
# Dataset uses ISO 639-3 + ISO 15924 format (e.g., "eng-Latn")
|
||||||
|
LANGUAGE_MAPPING = {
|
||||||
|
"english": "eng-Latn",
|
||||||
|
"french": "fra-Latn",
|
||||||
|
"spanish": "spa-Latn",
|
||||||
|
"german": "deu-Latn",
|
||||||
|
}
|
||||||
|
|
||||||
|
# ViDoRe v2 task configurations
|
||||||
|
# Prompts match MTEB task metadata prompts
|
||||||
|
VIDORE_V2_TASKS = {
|
||||||
|
"Vidore2ESGReportsRetrieval": {
|
||||||
|
"dataset_path": "vidore/esg_reports_v2",
|
||||||
|
"revision": "0542c0d03da0ec1c8cbc517c8d78e7e95c75d3d3",
|
||||||
|
"languages": ["french", "spanish", "english", "german"],
|
||||||
|
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||||
|
},
|
||||||
|
"Vidore2EconomicsReportsRetrieval": {
|
||||||
|
"dataset_path": "vidore/economics_reports_v2",
|
||||||
|
"revision": "b3e3a04b07fbbaffe79be49dabf92f691fbca252",
|
||||||
|
"languages": ["french", "spanish", "english", "german"],
|
||||||
|
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||||
|
},
|
||||||
|
"Vidore2BioMedicalLecturesRetrieval": {
|
||||||
|
"dataset_path": "vidore/biomedical_lectures_v2",
|
||||||
|
"revision": "a29202f0da409034d651614d87cd8938d254e2ea",
|
||||||
|
"languages": ["french", "spanish", "english", "german"],
|
||||||
|
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||||
|
},
|
||||||
|
"Vidore2ESGReportsHLRetrieval": {
|
||||||
|
"dataset_path": "vidore/esg_reports_human_labeled_v2",
|
||||||
|
"revision": "6d467dedb09a75144ede1421747e47cf036857dd",
|
||||||
|
# Note: This dataset doesn't have language filtering - all queries are English
|
||||||
|
"languages": None, # No language filtering needed
|
||||||
|
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_vidore_v2_data(
|
||||||
|
dataset_path: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
split: str = "test",
|
||||||
|
language: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Load ViDoRe v2 dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
corpus: dict mapping corpus_id to PIL Image
|
||||||
|
queries: dict mapping query_id to query text
|
||||||
|
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
|
||||||
|
"""
|
||||||
|
print(f"Loading dataset: {dataset_path} (split={split}, language={language})")
|
||||||
|
|
||||||
|
# Load queries
|
||||||
|
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
|
||||||
|
|
||||||
|
# Check if dataset has language field before filtering
|
||||||
|
has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
|
||||||
|
|
||||||
|
if language and has_language_field:
|
||||||
|
# Map language name to dataset language field value (e.g., "english" -> "eng-Latn")
|
||||||
|
dataset_language = LANGUAGE_MAPPING.get(language, language)
|
||||||
|
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language)
|
||||||
|
# Check if filtering resulted in empty dataset
|
||||||
|
if len(query_ds_filtered) == 0:
|
||||||
|
print(
|
||||||
|
f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}')."
|
||||||
|
)
|
||||||
|
# Try with original language value (dataset might use simple names like 'english')
|
||||||
|
print(f"Trying with original language value '{language}'...")
|
||||||
|
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language)
|
||||||
|
if len(query_ds_filtered) == 0:
|
||||||
|
# Try to get a sample to see actual language values
|
||||||
|
try:
|
||||||
|
sample_ds = load_dataset(
|
||||||
|
dataset_path, "queries", split=split, revision=revision
|
||||||
|
)
|
||||||
|
if len(sample_ds) > 0 and "language" in sample_ds.column_names:
|
||||||
|
sample_langs = set(sample_ds["language"])
|
||||||
|
print(f"Available language values in dataset: {sample_langs}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Found {len(query_ds_filtered)} queries using original language value '{language}'"
|
||||||
|
)
|
||||||
|
query_ds = query_ds_filtered
|
||||||
|
|
||||||
|
queries = {}
|
||||||
|
for row in query_ds:
|
||||||
|
query_id = f"query-{split}-{row['query-id']}"
|
||||||
|
queries[query_id] = row["query"]
|
||||||
|
|
||||||
|
# Load corpus (images)
|
||||||
|
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
|
||||||
|
|
||||||
|
corpus = {}
|
||||||
|
for row in corpus_ds:
|
||||||
|
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||||
|
# Extract image from the dataset row
|
||||||
|
if "image" in row:
|
||||||
|
corpus[corpus_id] = row["image"]
|
||||||
|
elif "page_image" in row:
|
||||||
|
corpus[corpus_id] = row["page_image"]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"No image field found in corpus. Available fields: {list(row.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load qrels (relevance judgments)
|
||||||
|
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
|
||||||
|
|
||||||
|
qrels = {}
|
||||||
|
for row in qrels_ds:
|
||||||
|
query_id = f"query-{split}-{row['query-id']}"
|
||||||
|
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||||
|
if query_id not in qrels:
|
||||||
|
qrels[query_id] = {}
|
||||||
|
qrels[query_id][corpus_id] = int(row["score"])
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter qrels to only include queries that exist
|
||||||
|
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
|
||||||
|
|
||||||
|
# Filter out queries without any relevant documents (matching MTEB behavior)
|
||||||
|
# This is important for correct NDCG calculation
|
||||||
|
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
|
||||||
|
queries_filtered = {
|
||||||
|
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
|
||||||
|
)
|
||||||
|
|
||||||
|
return corpus, queries_filtered, qrels_filtered
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_task(
|
||||||
|
task_name: str,
|
||||||
|
model_name: str,
|
||||||
|
index_path: str,
|
||||||
|
use_fast_plaid: bool = False,
|
||||||
|
fast_plaid_index_path: Optional[str] = None,
|
||||||
|
language: Optional[str] = None,
|
||||||
|
rebuild_index: bool = False,
|
||||||
|
top_k: int = 100,
|
||||||
|
first_stage_k: int = 500,
|
||||||
|
k_values: Optional[list[int]] = None,
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Evaluate a single ViDoRe v2 task.
|
||||||
|
"""
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print(f"Evaluating task: {task_name}")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
|
||||||
|
# Get task config
|
||||||
|
if task_name not in VIDORE_V2_TASKS:
|
||||||
|
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")
|
||||||
|
|
||||||
|
task_config = VIDORE_V2_TASKS[task_name]
|
||||||
|
dataset_path = task_config["dataset_path"]
|
||||||
|
revision = task_config["revision"]
|
||||||
|
|
||||||
|
# Determine language
|
||||||
|
if language is None:
|
||||||
|
# Use first language if multiple available
|
||||||
|
languages = task_config.get("languages")
|
||||||
|
if languages is None:
|
||||||
|
# Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval)
|
||||||
|
language = None
|
||||||
|
elif len(languages) == 1:
|
||||||
|
language = languages[0]
|
||||||
|
else:
|
||||||
|
language = None
|
||||||
|
|
||||||
|
# Initialize k_values if not provided
|
||||||
|
if k_values is None:
|
||||||
|
k_values = [1, 3, 5, 10, 100]
|
||||||
|
|
||||||
|
# Load data
|
||||||
|
corpus, queries, qrels = load_vidore_v2_data(
|
||||||
|
dataset_path=dataset_path,
|
||||||
|
revision=revision,
|
||||||
|
split="test",
|
||||||
|
language=language,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if we have any queries
|
||||||
|
if len(queries) == 0:
|
||||||
|
print(
|
||||||
|
f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation."
|
||||||
|
)
|
||||||
|
# Return zero scores
|
||||||
|
scores = {}
|
||||||
|
for k in k_values:
|
||||||
|
scores[f"ndcg_at_{k}"] = 0.0
|
||||||
|
scores[f"map_at_{k}"] = 0.0
|
||||||
|
scores[f"recall_at_{k}"] = 0.0
|
||||||
|
scores[f"precision_at_{k}"] = 0.0
|
||||||
|
scores[f"mrr_at_{k}"] = 0.0
|
||||||
|
return scores
|
||||||
|
|
||||||
|
# Initialize evaluator
|
||||||
|
evaluator = ViDoReBenchmarkEvaluator(
|
||||||
|
model_name=model_name,
|
||||||
|
use_fast_plaid=use_fast_plaid,
|
||||||
|
top_k=top_k,
|
||||||
|
first_stage_k=first_stage_k,
|
||||||
|
k_values=k_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build or load index
|
||||||
|
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
||||||
|
if index_path_full is None:
|
||||||
|
index_path_full = f"./indexes/{task_name}_{model_name}"
|
||||||
|
if use_fast_plaid:
|
||||||
|
index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
|
||||||
|
|
||||||
|
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||||
|
corpus=corpus,
|
||||||
|
index_path=index_path_full,
|
||||||
|
rebuild=rebuild_index,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search queries
|
||||||
|
task_prompt = task_config.get("prompt")
|
||||||
|
results = evaluator.search_queries(
|
||||||
|
queries=queries,
|
||||||
|
corpus_ids=corpus_ids_ordered,
|
||||||
|
index_or_retriever=index_or_retriever,
|
||||||
|
fast_plaid_index_path=fast_plaid_index_path,
|
||||||
|
task_prompt=task_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Evaluate
|
||||||
|
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print(f"Results for {task_name}:")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
for metric, value in scores.items():
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
print(f" {metric}: {value:.5f}")
|
||||||
|
|
||||||
|
# Save results
|
||||||
|
if output_dir:
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
results_file = os.path.join(output_dir, f"{task_name}_results.json")
|
||||||
|
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
|
||||||
|
|
||||||
|
with open(results_file, "w") as f:
|
||||||
|
json.dump(results, f, indent=2)
|
||||||
|
print(f"\nSaved results to: {results_file}")
|
||||||
|
|
||||||
|
with open(scores_file, "w") as f:
|
||||||
|
json.dump(scores, f, indent=2)
|
||||||
|
print(f"Saved scores to: {scores_file}")
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Evaluate ViDoRe v2 benchmark using LEANN/Fast-Plaid indexing"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
default="colqwen2",
|
||||||
|
choices=["colqwen2", "colpali"],
|
||||||
|
help="Model to use",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--task",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Specific task to evaluate (or 'all' for all tasks)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tasks",
|
||||||
|
type=str,
|
||||||
|
default="all",
|
||||||
|
help="Tasks to evaluate: 'all' or comma-separated list",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to LEANN index (auto-generated if not provided)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-fast-plaid",
|
||||||
|
action="store_true",
|
||||||
|
help="Use Fast-Plaid instead of LEANN",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fast-plaid-index-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to Fast-Plaid index (auto-generated if not provided)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rebuild-index",
|
||||||
|
action="store_true",
|
||||||
|
help="Rebuild index even if it exists",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--language",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Language to evaluate (default: first available)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-k",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="Top-k results to retrieve",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--first-stage-k",
|
||||||
|
type=int,
|
||||||
|
default=500,
|
||||||
|
help="First stage k for LEANN search",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--k-values",
|
||||||
|
type=str,
|
||||||
|
default="1,3,5,10,100",
|
||||||
|
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
type=str,
|
||||||
|
default="./vidore_v2_results",
|
||||||
|
help="Output directory for results",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Parse k_values
|
||||||
|
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
||||||
|
|
||||||
|
# Determine tasks to evaluate
|
||||||
|
if args.task:
|
||||||
|
tasks_to_eval = [args.task]
|
||||||
|
elif args.tasks.lower() == "all":
|
||||||
|
tasks_to_eval = list(VIDORE_V2_TASKS.keys())
|
||||||
|
else:
|
||||||
|
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
|
||||||
|
|
||||||
|
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||||
|
|
||||||
|
# Evaluate each task
|
||||||
|
all_scores = {}
|
||||||
|
for task_name in tasks_to_eval:
|
||||||
|
try:
|
||||||
|
scores = evaluate_task(
|
||||||
|
task_name=task_name,
|
||||||
|
model_name=args.model,
|
||||||
|
index_path=args.index_path,
|
||||||
|
use_fast_plaid=args.use_fast_plaid,
|
||||||
|
fast_plaid_index_path=args.fast_plaid_index_path,
|
||||||
|
language=args.language,
|
||||||
|
rebuild_index=args.rebuild_index,
|
||||||
|
top_k=args.top_k,
|
||||||
|
first_stage_k=args.first_stage_k,
|
||||||
|
k_values=k_values,
|
||||||
|
output_dir=args.output_dir,
|
||||||
|
)
|
||||||
|
all_scores[task_name] = scores
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\nError evaluating {task_name}: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Print summary
|
||||||
|
if all_scores:
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print("SUMMARY")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
for task_name, scores in all_scores.items():
|
||||||
|
print(f"\n{task_name}:")
|
||||||
|
# Print main metrics
|
||||||
|
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
|
||||||
|
if metric in scores:
|
||||||
|
print(f" {metric}: {scores[metric]:.5f}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -7,6 +7,7 @@ for indexing in LEANN. It supports various Slack MCP server implementations and
|
|||||||
flexible message processing options.
|
flexible message processing options.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import ast
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@@ -146,16 +147,16 @@ class SlackMCPReader:
|
|||||||
match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
|
match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
|
||||||
if match:
|
if match:
|
||||||
try:
|
try:
|
||||||
error_dict = eval(match.group(1))
|
error_dict = ast.literal_eval(match.group(1))
|
||||||
except (ValueError, SyntaxError, NameError):
|
except (ValueError, SyntaxError):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
# Try alternative format
|
# Try alternative format
|
||||||
match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
|
match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
|
||||||
if match:
|
if match:
|
||||||
try:
|
try:
|
||||||
error_dict = eval(match.group(1))
|
error_dict = ast.literal_eval(match.group(1))
|
||||||
except (ValueError, SyntaxError, NameError):
|
except (ValueError, SyntaxError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if self._is_cache_sync_error(error_dict):
|
if self._is_cache_sync_error(error_dict):
|
||||||
|
|||||||
143
benchmarks/update/README.md
Normal file
143
benchmarks/update/README.md
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
# Update Benchmarks
|
||||||
|
|
||||||
|
This directory hosts two benchmark suites that exercise LEANN’s HNSW “update +
|
||||||
|
search” pipeline under different assumptions:
|
||||||
|
|
||||||
|
1. **RNG recompute latency** – measure how random-neighbour pruning and cache
|
||||||
|
settings influence incremental `add()` latency when embeddings are fetched
|
||||||
|
over the ZMQ embedding server.
|
||||||
|
2. **Update strategy comparison** – compare a fully sequential update pipeline
|
||||||
|
against an offline approach that keeps the graph static and fuses results.
|
||||||
|
|
||||||
|
Both suites build a non-compact, `is_recompute=True` index so that new
|
||||||
|
embeddings are pulled from the embedding server. Benchmark outputs are written
|
||||||
|
under `.leann/bench/` by default and appended to CSV files for later plotting.
|
||||||
|
|
||||||
|
## Benchmarks
|
||||||
|
|
||||||
|
### 1. HNSW RNG Recompute Benchmark
|
||||||
|
|
||||||
|
`bench_hnsw_rng_recompute.py` evaluates incremental update latency under four
|
||||||
|
random-neighbour (RNG) configurations. Each scenario uses the same dataset but
|
||||||
|
changes the forward / reverse RNG pruning flags and whether the embedding cache
|
||||||
|
is enabled:
|
||||||
|
|
||||||
|
| Scenario name | Forward RNG | Reverse RNG | ZMQ embedding cache |
|
||||||
|
| ---------------------------------- | ----------- | ----------- | ------------------- |
|
||||||
|
| `baseline` | Enabled | Enabled | Enabled |
|
||||||
|
| `no_cache_baseline` | Enabled | Enabled | **Disabled** |
|
||||||
|
| `disable_forward_rng` | **Disabled**| Enabled | Enabled |
|
||||||
|
| `disable_forward_and_reverse_rng` | **Disabled**| **Disabled**| Enabled |
|
||||||
|
|
||||||
|
For each scenario the script:
|
||||||
|
1. (Re)builds a `is_recompute=True` index and writes it to `.leann/bench/`.
|
||||||
|
2. Starts `leann_backend_hnsw.hnsw_embedding_server` for remote embeddings.
|
||||||
|
3. Appends the requested updates using the scenario’s RNG flags.
|
||||||
|
4. Records total time, latency per passage, ZMQ fetch counts, and stage-level
|
||||||
|
timings before appending a row to the CSV output.
|
||||||
|
|
||||||
|
**Run:**
|
||||||
|
```bash
|
||||||
|
LEANN_HNSW_LOG_PATH=.leann/bench/hnsw_server.log \
|
||||||
|
LEANN_LOG_LEVEL=INFO \
|
||||||
|
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
|
||||||
|
--runs 1 \
|
||||||
|
--index-path .leann/bench/test.leann \
|
||||||
|
--initial-files data/PrideandPrejudice.txt \
|
||||||
|
--update-files data/huawei_pangu.md \
|
||||||
|
--max-initial 300 \
|
||||||
|
--max-updates 1 \
|
||||||
|
--add-timeout 120
|
||||||
|
```
|
||||||
|
|
||||||
|
**Output:**
|
||||||
|
- `benchmarks/update/bench_results.csv` – per-scenario timing statistics
|
||||||
|
(including ms/passage) for each run.
|
||||||
|
- `.leann/bench/hnsw_server.log` – detailed ZMQ/server logs (path controlled by
|
||||||
|
`LEANN_HNSW_LOG_PATH`).
|
||||||
|
_The reference CSVs checked into this branch were generated on a workstation with an NVIDIA RTX 4090 GPU; throughput numbers will differ on other hardware._
|
||||||
|
|
||||||
|
### 2. Sequential vs. Offline Update Benchmark
|
||||||
|
|
||||||
|
`bench_update_vs_offline_search.py` compares two end-to-end strategies on the
|
||||||
|
same dataset:
|
||||||
|
|
||||||
|
- **Scenario A – Sequential Update**
|
||||||
|
- Start an embedding server.
|
||||||
|
- Sequentially call `index.add()`; each call fetches embeddings via ZMQ and
|
||||||
|
mutates the HNSW graph.
|
||||||
|
- After all inserts, run a search on the updated graph.
|
||||||
|
- Metrics recorded: update time (`add_total_s`), post-update search time
|
||||||
|
(`search_time_s`), combined total (`total_time_s`), and per-passage
|
||||||
|
latency.
|
||||||
|
|
||||||
|
- **Scenario B – Offline Embedding + Concurrent Search**
|
||||||
|
- Stop Scenario A’s server and start a fresh embedding server.
|
||||||
|
- Spawn two threads: one generates embeddings for the new passages offline
|
||||||
|
(graph unchanged); the other computes the query embedding and searches the
|
||||||
|
existing graph.
|
||||||
|
- Merge offline similarities with the graph search results to emulate late
|
||||||
|
fusion, then report the merged top‑k preview.
|
||||||
|
- Metrics recorded: embedding time (`emb_time_s`), search time
|
||||||
|
(`search_time_s`), concurrent makespan (`makespan_s`), and scenario total.
|
||||||
|
|
||||||
|
**Run (both scenarios):**
|
||||||
|
```bash
|
||||||
|
uv run -m benchmarks.update.bench_update_vs_offline_search \
|
||||||
|
--index-path .leann/bench/offline_vs_update.leann \
|
||||||
|
--max-initial 300 \
|
||||||
|
--num-updates 1
|
||||||
|
```
|
||||||
|
|
||||||
|
You can pass `--only A` or `--only B` to run a single scenario. The script will
|
||||||
|
print timing summaries to stdout and append the results to CSV.
|
||||||
|
|
||||||
|
**Output:**
|
||||||
|
- `benchmarks/update/offline_vs_update.csv` – per-scenario timing statistics for
|
||||||
|
Scenario A and B.
|
||||||
|
- Console output includes Scenario B’s merged top‑k preview for quick sanity
|
||||||
|
checks.
|
||||||
|
_The sample results committed here come from runs on an RTX 4090-equipped machine; expect variations if you benchmark on different GPUs._
|
||||||
|
|
||||||
|
### 3. Visualisation
|
||||||
|
|
||||||
|
`plot_bench_results.py` combines the RNG benchmark and the update strategy
|
||||||
|
benchmark into a single two-panel plot.
|
||||||
|
|
||||||
|
**Run:**
|
||||||
|
```bash
|
||||||
|
uv run -m benchmarks.update.plot_bench_results \
|
||||||
|
--csv benchmarks/update/bench_results.csv \
|
||||||
|
--csv-right benchmarks/update/offline_vs_update.csv \
|
||||||
|
--out benchmarks/update/bench_latency_from_csv.png
|
||||||
|
```
|
||||||
|
|
||||||
|
**Options:**
|
||||||
|
- `--broken-y` – Enable a broken Y-axis (default: true when appropriate).
|
||||||
|
- `--csv` – RNG benchmark results CSV (left panel).
|
||||||
|
- `--csv-right` – Update strategy results CSV (right panel).
|
||||||
|
- `--out` – Output image path (PNG/PDF supported).
|
||||||
|
|
||||||
|
**Output:**
|
||||||
|
- `benchmarks/update/bench_latency_from_csv.png` – visual comparison of the two
|
||||||
|
suites.
|
||||||
|
- `benchmarks/update/bench_latency_from_csv.pdf` – PDF version, suitable for
|
||||||
|
slides/papers.
|
||||||
|
|
||||||
|
## Parameters & Environment
|
||||||
|
|
||||||
|
### Common CLI Flags
|
||||||
|
- `--max-initial` – Number of initial passages used to seed the index.
|
||||||
|
- `--max-updates` / `--num-updates` – Number of passages to treat as updates.
|
||||||
|
- `--index-path` – Base path (without extension) where the LEANN index is stored.
|
||||||
|
- `--runs` – Number of repetitions (RNG benchmark only).
|
||||||
|
|
||||||
|
### Environment Variables
|
||||||
|
- `LEANN_HNSW_LOG_PATH` – File to receive embedding-server logs (optional).
|
||||||
|
- `LEANN_LOG_LEVEL` – Logging verbosity (DEBUG/INFO/WARNING/ERROR).
|
||||||
|
- `CUDA_VISIBLE_DEVICES` – Set to empty string if you want to force CPU
|
||||||
|
execution of the embedding model.
|
||||||
|
|
||||||
|
With these scripts you can easily replicate LEANN’s update benchmarks, compare
|
||||||
|
multiple RNG strategies, and evaluate whether sequential updates or offline
|
||||||
|
fusion better match your latency/accuracy trade-offs.
|
||||||
16
benchmarks/update/__init__.py
Normal file
16
benchmarks/update/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""Benchmarks for LEANN update workflows."""
|
||||||
|
|
||||||
|
# Expose helper to locate repository root for other modules that need it.
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def find_repo_root() -> Path:
|
||||||
|
"""Return the project root containing pyproject.toml."""
|
||||||
|
current = Path(__file__).resolve()
|
||||||
|
for parent in current.parents:
|
||||||
|
if (parent / "pyproject.toml").exists():
|
||||||
|
return parent
|
||||||
|
return current.parents[1]
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["find_repo_root"]
|
||||||
804
benchmarks/update/bench_hnsw_rng_recompute.py
Normal file
804
benchmarks/update/bench_hnsw_rng_recompute.py
Normal file
@@ -0,0 +1,804 @@
|
|||||||
|
"""Benchmark incremental HNSW add() under different RNG pruning modes with real
|
||||||
|
embedding recomputation.
|
||||||
|
|
||||||
|
This script clones the structure of ``examples/dynamic_update_no_recompute.py``
|
||||||
|
so that we build a non-compact ``is_recompute=True`` index, spin up the
|
||||||
|
standard HNSW embedding server, and measure how long incremental ``add`` takes
|
||||||
|
when RNG pruning is fully enabled vs. partially/fully disabled.
|
||||||
|
|
||||||
|
Example usage (run from the repo root; downloads the model on first run)::
|
||||||
|
|
||||||
|
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
|
||||||
|
--index-path .leann/bench/leann-demo.leann \
|
||||||
|
--runs 1
|
||||||
|
|
||||||
|
You can tweak the input documents with ``--initial-files`` / ``--update-files``
|
||||||
|
if you want a larger or different workload, and change the embedding model via
|
||||||
|
``--model-name``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import msgpack
|
||||||
|
import numpy as np
|
||||||
|
import zmq
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
|
||||||
|
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
|
||||||
|
|
||||||
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
from leann.embedding_server_manager import EmbeddingServerManager
|
||||||
|
from leann.registry import register_project_directory
|
||||||
|
from leann_backend_hnsw import faiss # type: ignore
|
||||||
|
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
if not logging.getLogger().handlers:
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_repo_root() -> Path:
|
||||||
|
"""Locate project root by walking up until pyproject.toml is found."""
|
||||||
|
current = Path(__file__).resolve()
|
||||||
|
for parent in current.parents:
|
||||||
|
if (parent / "pyproject.toml").exists():
|
||||||
|
return parent
|
||||||
|
# Fallback: assume repo is two levels up (../..)
|
||||||
|
return current.parents[2]
|
||||||
|
|
||||||
|
|
||||||
|
REPO_ROOT = _find_repo_root()
|
||||||
|
if str(REPO_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(REPO_ROOT))
|
||||||
|
|
||||||
|
from apps.chunking import create_text_chunks # noqa: E402
|
||||||
|
|
||||||
|
DEFAULT_INITIAL_FILES = [
|
||||||
|
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
||||||
|
REPO_ROOT / "data" / "huawei_pangu.md",
|
||||||
|
]
|
||||||
|
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||||
|
|
||||||
|
DEFAULT_HNSW_LOG = Path(".leann/bench/hnsw_server.log")
|
||||||
|
|
||||||
|
|
||||||
|
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
documents = []
|
||||||
|
for path in paths:
|
||||||
|
p = path.expanduser().resolve()
|
||||||
|
if not p.exists():
|
||||||
|
raise FileNotFoundError(f"Input path not found: {p}")
|
||||||
|
if p.is_dir():
|
||||||
|
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
else:
|
||||||
|
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
return []
|
||||||
|
|
||||||
|
chunks = create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size=512,
|
||||||
|
chunk_overlap=128,
|
||||||
|
use_ast_chunking=False,
|
||||||
|
)
|
||||||
|
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||||
|
if limit is not None:
|
||||||
|
cleaned = cleaned[:limit]
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_index_dir(index_path: Path) -> None:
|
||||||
|
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_index_files(index_path: Path) -> None:
|
||||||
|
parent = index_path.parent
|
||||||
|
if not parent.exists():
|
||||||
|
return
|
||||||
|
stem = index_path.stem
|
||||||
|
for file in parent.glob(f"{stem}*"):
|
||||||
|
if file.is_file():
|
||||||
|
file.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def build_initial_index(
|
||||||
|
index_path: Path,
|
||||||
|
paragraphs: list[str],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
distance_metric: str,
|
||||||
|
ef_construction: int,
|
||||||
|
) -> None:
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
is_compact=False,
|
||||||
|
is_recompute=True,
|
||||||
|
distance_metric=distance_metric,
|
||||||
|
backend_kwargs={
|
||||||
|
"distance_metric": distance_metric,
|
||||||
|
"is_compact": False,
|
||||||
|
"is_recompute": True,
|
||||||
|
"efConstruction": ef_construction,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for idx, passage in enumerate(paragraphs):
|
||||||
|
builder.add_text(passage, metadata={"id": str(idx)})
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_new_chunks(paragraphs: list[str]) -> list[dict[str, Any]]:
|
||||||
|
return [{"text": text, "metadata": {}} for text in paragraphs]
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_update_with_mode(
|
||||||
|
index_path: Path,
|
||||||
|
new_chunks: list[dict[str, Any]],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
distance_metric: str,
|
||||||
|
disable_forward_rng: bool,
|
||||||
|
disable_reverse_rng: bool,
|
||||||
|
server_port: int,
|
||||||
|
add_timeout: int,
|
||||||
|
ef_construction: int,
|
||||||
|
) -> tuple[float, float]:
|
||||||
|
meta_path = index_path.parent / f"{index_path.name}.meta.json"
|
||||||
|
passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
|
||||||
|
offset_file = index_path.parent / f"{index_path.name}.passages.idx"
|
||||||
|
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||||
|
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
with open(offset_file, "rb") as f:
|
||||||
|
offset_map: dict[str, int] = pickle.load(f)
|
||||||
|
existing_ids = set(offset_map.keys())
|
||||||
|
|
||||||
|
valid_chunks: list[dict[str, Any]] = []
|
||||||
|
for chunk in new_chunks:
|
||||||
|
text = chunk.get("text", "")
|
||||||
|
if not isinstance(text, str) or not text.strip():
|
||||||
|
continue
|
||||||
|
metadata = chunk.setdefault("metadata", {})
|
||||||
|
passage_id = chunk.get("id") or metadata.get("id")
|
||||||
|
if passage_id and passage_id in existing_ids:
|
||||||
|
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
|
||||||
|
valid_chunks.append(chunk)
|
||||||
|
|
||||||
|
if not valid_chunks:
|
||||||
|
raise ValueError("No valid chunks to append.")
|
||||||
|
|
||||||
|
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
texts_to_embed,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
is_build=False,
|
||||||
|
batch_size=16,
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
|
if distance_metric == "cosine":
|
||||||
|
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||||
|
norms[norms == 0] = 1
|
||||||
|
embeddings = embeddings / norms
|
||||||
|
|
||||||
|
index = faiss.read_index(str(index_file))
|
||||||
|
index.is_recompute = True
|
||||||
|
if getattr(index, "storage", None) is None:
|
||||||
|
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||||
|
storage_index = faiss.IndexFlatIP(index.d)
|
||||||
|
else:
|
||||||
|
storage_index = faiss.IndexFlatL2(index.d)
|
||||||
|
index.storage = storage_index
|
||||||
|
index.own_fields = True
|
||||||
|
try:
|
||||||
|
storage_index.ntotal = index.ntotal
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
index.hnsw.set_disable_rng_during_add(disable_forward_rng)
|
||||||
|
index.hnsw.set_disable_reverse_prune(disable_reverse_rng)
|
||||||
|
if ef_construction is not None:
|
||||||
|
index.hnsw.efConstruction = ef_construction
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
applied_forward = getattr(index.hnsw, "disable_rng_during_add", None)
|
||||||
|
applied_reverse = getattr(index.hnsw, "disable_reverse_prune", None)
|
||||||
|
logger.info(
|
||||||
|
"HNSW RNG config -> requested forward=%s, reverse=%s | applied forward=%s, reverse=%s",
|
||||||
|
disable_forward_rng,
|
||||||
|
disable_reverse_rng,
|
||||||
|
applied_forward,
|
||||||
|
applied_reverse,
|
||||||
|
)
|
||||||
|
|
||||||
|
base_id = index.ntotal
|
||||||
|
for offset, chunk in enumerate(valid_chunks):
|
||||||
|
new_id = str(base_id + offset)
|
||||||
|
chunk.setdefault("metadata", {})["id"] = new_id
|
||||||
|
chunk["id"] = new_id
|
||||||
|
|
||||||
|
rollback_size = passages_file.stat().st_size if passages_file.exists() else 0
|
||||||
|
offset_map_backup = offset_map.copy()
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(passages_file, "a", encoding="utf-8") as f:
|
||||||
|
for chunk in valid_chunks:
|
||||||
|
offset = f.tell()
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"id": chunk["id"],
|
||||||
|
"text": chunk["text"],
|
||||||
|
"metadata": chunk.get("metadata", {}),
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
f.write("\n")
|
||||||
|
offset_map[chunk["id"]] = offset
|
||||||
|
|
||||||
|
with open(offset_file, "wb") as f:
|
||||||
|
pickle.dump(offset_map, f)
|
||||||
|
|
||||||
|
server_manager = EmbeddingServerManager(
|
||||||
|
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||||
|
)
|
||||||
|
server_started, actual_port = server_manager.start_server(
|
||||||
|
port=server_port,
|
||||||
|
model_name=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
passages_file=str(meta_path),
|
||||||
|
distance_metric=distance_metric,
|
||||||
|
)
|
||||||
|
if not server_started:
|
||||||
|
raise RuntimeError("Failed to start embedding server.")
|
||||||
|
|
||||||
|
if hasattr(index.hnsw, "set_zmq_port"):
|
||||||
|
index.hnsw.set_zmq_port(actual_port)
|
||||||
|
elif hasattr(index, "set_zmq_port"):
|
||||||
|
index.set_zmq_port(actual_port)
|
||||||
|
|
||||||
|
_warmup_embedding_server(actual_port)
|
||||||
|
|
||||||
|
total_start = time.time()
|
||||||
|
add_elapsed = 0.0
|
||||||
|
|
||||||
|
try:
|
||||||
|
import signal
|
||||||
|
|
||||||
|
def _timeout_handler(signum, frame):
|
||||||
|
raise TimeoutError("incremental add timed out")
|
||||||
|
|
||||||
|
if add_timeout > 0:
|
||||||
|
signal.signal(signal.SIGALRM, _timeout_handler)
|
||||||
|
signal.alarm(add_timeout)
|
||||||
|
|
||||||
|
add_start = time.time()
|
||||||
|
for i in range(embeddings.shape[0]):
|
||||||
|
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
|
||||||
|
add_elapsed = time.time() - add_start
|
||||||
|
if add_timeout > 0:
|
||||||
|
signal.alarm(0)
|
||||||
|
faiss.write_index(index, str(index_file))
|
||||||
|
finally:
|
||||||
|
server_manager.stop_server()
|
||||||
|
|
||||||
|
except TimeoutError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
if passages_file.exists():
|
||||||
|
with open(passages_file, "rb+") as f:
|
||||||
|
f.truncate(rollback_size)
|
||||||
|
with open(offset_file, "wb") as f:
|
||||||
|
pickle.dump(offset_map_backup, f)
|
||||||
|
raise
|
||||||
|
|
||||||
|
prune_hnsw_embeddings_inplace(str(index_file))
|
||||||
|
|
||||||
|
meta["total_passages"] = len(offset_map)
|
||||||
|
with open(meta_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(meta, f, indent=2)
|
||||||
|
|
||||||
|
# Reset toggles so the index on disk returns to baseline behaviour.
|
||||||
|
try:
|
||||||
|
index.hnsw.set_disable_rng_during_add(False)
|
||||||
|
index.hnsw.set_disable_reverse_prune(False)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
faiss.write_index(index, str(index_file))
|
||||||
|
|
||||||
|
total_elapsed = time.time() - total_start
|
||||||
|
|
||||||
|
return total_elapsed, add_elapsed
|
||||||
|
|
||||||
|
|
||||||
|
def _total_zmq_nodes(log_path: Path) -> int:
|
||||||
|
if not log_path.exists():
|
||||||
|
return 0
|
||||||
|
with log_path.open("r", encoding="utf-8") as log_file:
|
||||||
|
text = log_file.read()
|
||||||
|
return sum(int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", text))
|
||||||
|
|
||||||
|
|
||||||
|
def _warmup_embedding_server(port: int) -> None:
|
||||||
|
"""Send a dummy REQ so the embedding server loads its model."""
|
||||||
|
ctx = zmq.Context()
|
||||||
|
try:
|
||||||
|
sock = ctx.socket(zmq.REQ)
|
||||||
|
sock.setsockopt(zmq.LINGER, 0)
|
||||||
|
sock.setsockopt(zmq.RCVTIMEO, 5000)
|
||||||
|
sock.setsockopt(zmq.SNDTIMEO, 5000)
|
||||||
|
sock.connect(f"tcp://127.0.0.1:{port}")
|
||||||
|
payload = msgpack.packb(["__WARMUP__"], use_bin_type=True)
|
||||||
|
sock.send(payload)
|
||||||
|
try:
|
||||||
|
sock.recv()
|
||||||
|
except zmq.error.Again:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
sock.close()
|
||||||
|
ctx.term()
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path(".leann/bench/leann-demo.leann"),
|
||||||
|
help="Output index base path (without extension).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--initial-files",
|
||||||
|
nargs="*",
|
||||||
|
type=Path,
|
||||||
|
default=DEFAULT_INITIAL_FILES,
|
||||||
|
help="Files used to build the initial index.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--update-files",
|
||||||
|
nargs="*",
|
||||||
|
type=Path,
|
||||||
|
default=DEFAULT_UPDATE_FILES,
|
||||||
|
help="Files appended during the benchmark.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--runs", type=int, default=1, help="How many times to repeat each scenario."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-name",
|
||||||
|
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
help="Embedding model used for build/update.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-mode",
|
||||||
|
default="sentence-transformers",
|
||||||
|
help="Embedding mode passed to LeannBuilder/embedding server.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--distance-metric",
|
||||||
|
default="mips",
|
||||||
|
choices=["mips", "l2", "cosine"],
|
||||||
|
help="Distance metric for HNSW backend.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ef-construction",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="efConstruction setting for initial build.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--server-port",
|
||||||
|
type=int,
|
||||||
|
default=5557,
|
||||||
|
help="Port for the real embedding server.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-initial",
|
||||||
|
type=int,
|
||||||
|
default=300,
|
||||||
|
help="Optional cap on initial passages (after chunking).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-updates",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Optional cap on update passages (after chunking).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--add-timeout",
|
||||||
|
type=int,
|
||||||
|
default=900,
|
||||||
|
help="Timeout in seconds for the incremental add loop (0 = no timeout).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--plot-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path("bench_latency.png"),
|
||||||
|
help="Where to save the latency bar plot.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cap-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Cap Y-axis (ms). Bars above are hatched and annotated.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--broken-y",
|
||||||
|
action="store_true",
|
||||||
|
help="Use broken Y-axis (two stacked axes with gap). Overrides --cap-y unless both provided.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lower-cap-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Lower axes upper bound for broken Y (ms). Default=1.1x second-highest.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--upper-start-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Upper axes lower bound for broken Y (ms). Default=1.2x second-highest.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--csv-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path("benchmarks/update/bench_results.csv"),
|
||||||
|
help="Where to append per-scenario results as CSV.",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
register_project_directory(REPO_ROOT)
|
||||||
|
|
||||||
|
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
|
||||||
|
update_paragraphs = load_chunks_from_files(args.update_files, args.max_updates)
|
||||||
|
if not update_paragraphs:
|
||||||
|
raise ValueError("No update passages found; please provide --update-files with content.")
|
||||||
|
|
||||||
|
update_chunks = prepare_new_chunks(update_paragraphs)
|
||||||
|
ensure_index_dir(args.index_path)
|
||||||
|
|
||||||
|
scenarios = [
|
||||||
|
("baseline", False, False, True),
|
||||||
|
("no_cache_baseline", False, False, False),
|
||||||
|
("disable_forward_rng", True, False, True),
|
||||||
|
("disable_forward_and_reverse_rng", True, True, True),
|
||||||
|
]
|
||||||
|
|
||||||
|
log_path = Path(os.environ.get("LEANN_HNSW_LOG_PATH", DEFAULT_HNSW_LOG))
|
||||||
|
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
os.environ["LEANN_HNSW_LOG_PATH"] = str(log_path.resolve())
|
||||||
|
os.environ.setdefault("LEANN_LOG_LEVEL", "INFO")
|
||||||
|
|
||||||
|
results_total: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||||
|
results_add: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||||
|
results_zmq: dict[str, list[int]] = {name: [] for name, *_ in scenarios}
|
||||||
|
results_stageA: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||||
|
results_stageBC: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||||
|
results_ms_per_passage: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||||
|
|
||||||
|
# CSV setup
|
||||||
|
import csv
|
||||||
|
|
||||||
|
run_id = time.strftime("%Y%m%d-%H%M%S")
|
||||||
|
csv_fields = [
|
||||||
|
"run_id",
|
||||||
|
"scenario",
|
||||||
|
"cache_enabled",
|
||||||
|
"ef_construction",
|
||||||
|
"max_initial",
|
||||||
|
"max_updates",
|
||||||
|
"total_time_s",
|
||||||
|
"add_only_s",
|
||||||
|
"latency_ms_per_passage",
|
||||||
|
"zmq_nodes",
|
||||||
|
"stageA_time_s",
|
||||||
|
"stageBC_time_s",
|
||||||
|
"model_name",
|
||||||
|
"embedding_mode",
|
||||||
|
"distance_metric",
|
||||||
|
]
|
||||||
|
# Create CSV with header if missing
|
||||||
|
if args.csv_path:
|
||||||
|
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
|
||||||
|
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||||
|
writer.writeheader()
|
||||||
|
|
||||||
|
for run in range(args.runs):
|
||||||
|
print(f"\n=== Benchmark run {run + 1}/{args.runs} ===")
|
||||||
|
for name, disable_forward, disable_reverse, cache_enabled in scenarios:
|
||||||
|
print(f"\nScenario: {name}")
|
||||||
|
cleanup_index_files(args.index_path)
|
||||||
|
if log_path.exists():
|
||||||
|
try:
|
||||||
|
log_path.unlink()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
os.environ["LEANN_ZMQ_EMBED_CACHE"] = "1" if cache_enabled else "0"
|
||||||
|
build_initial_index(
|
||||||
|
args.index_path,
|
||||||
|
initial_paragraphs,
|
||||||
|
args.model_name,
|
||||||
|
args.embedding_mode,
|
||||||
|
args.distance_metric,
|
||||||
|
args.ef_construction,
|
||||||
|
)
|
||||||
|
|
||||||
|
prev_size = log_path.stat().st_size if log_path.exists() else 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
total_elapsed, add_elapsed = benchmark_update_with_mode(
|
||||||
|
args.index_path,
|
||||||
|
update_chunks,
|
||||||
|
args.model_name,
|
||||||
|
args.embedding_mode,
|
||||||
|
args.distance_metric,
|
||||||
|
disable_forward,
|
||||||
|
disable_reverse,
|
||||||
|
args.server_port,
|
||||||
|
args.add_timeout,
|
||||||
|
args.ef_construction,
|
||||||
|
)
|
||||||
|
except TimeoutError as exc:
|
||||||
|
print(f"Scenario {name} timed out: {exc}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
curr_size = log_path.stat().st_size if log_path.exists() else 0
|
||||||
|
if curr_size < prev_size:
|
||||||
|
prev_size = 0
|
||||||
|
zmq_count = 0
|
||||||
|
if log_path.exists():
|
||||||
|
with log_path.open("r", encoding="utf-8") as log_file:
|
||||||
|
log_file.seek(prev_size)
|
||||||
|
new_entries = log_file.read()
|
||||||
|
zmq_count = sum(
|
||||||
|
int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", new_entries)
|
||||||
|
)
|
||||||
|
stageA = sum(
|
||||||
|
float(x)
|
||||||
|
for x in re.findall(r"Distance calculation E2E time: ([0-9.]+)s", new_entries)
|
||||||
|
)
|
||||||
|
stageBC = sum(
|
||||||
|
float(x) for x in re.findall(r"ZMQ E2E time: ([0-9.]+)s", new_entries)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
stageA = 0.0
|
||||||
|
stageBC = 0.0
|
||||||
|
|
||||||
|
per_chunk = add_elapsed / len(update_chunks)
|
||||||
|
print(
|
||||||
|
f"Total time: {total_elapsed:.3f} s | add-only: {add_elapsed:.3f} s "
|
||||||
|
f"for {len(update_chunks)} passages => {per_chunk * 1e3:.3f} ms/passage"
|
||||||
|
)
|
||||||
|
print(f"ZMQ node fetch total: {zmq_count}")
|
||||||
|
results_total[name].append(total_elapsed)
|
||||||
|
results_add[name].append(add_elapsed)
|
||||||
|
results_zmq[name].append(zmq_count)
|
||||||
|
results_ms_per_passage[name].append(per_chunk * 1e3)
|
||||||
|
results_stageA[name].append(stageA)
|
||||||
|
results_stageBC[name].append(stageBC)
|
||||||
|
|
||||||
|
# Append row to CSV
|
||||||
|
if args.csv_path:
|
||||||
|
row = {
|
||||||
|
"run_id": run_id,
|
||||||
|
"scenario": name,
|
||||||
|
"cache_enabled": 1 if cache_enabled else 0,
|
||||||
|
"ef_construction": args.ef_construction,
|
||||||
|
"max_initial": args.max_initial,
|
||||||
|
"max_updates": args.max_updates,
|
||||||
|
"total_time_s": round(total_elapsed, 6),
|
||||||
|
"add_only_s": round(add_elapsed, 6),
|
||||||
|
"latency_ms_per_passage": round(per_chunk * 1e3, 6),
|
||||||
|
"zmq_nodes": int(zmq_count),
|
||||||
|
"stageA_time_s": round(stageA, 6),
|
||||||
|
"stageBC_time_s": round(stageBC, 6),
|
||||||
|
"model_name": args.model_name,
|
||||||
|
"embedding_mode": args.embedding_mode,
|
||||||
|
"distance_metric": args.distance_metric,
|
||||||
|
}
|
||||||
|
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||||
|
writer.writerow(row)
|
||||||
|
|
||||||
|
print("\n=== Summary ===")
|
||||||
|
for name in results_add:
|
||||||
|
add_values = results_add[name]
|
||||||
|
total_values = results_total[name]
|
||||||
|
zmq_values = results_zmq[name]
|
||||||
|
latency_values = results_ms_per_passage[name]
|
||||||
|
if not add_values:
|
||||||
|
print(f"{name}: no successful runs")
|
||||||
|
continue
|
||||||
|
avg_add = sum(add_values) / len(add_values)
|
||||||
|
avg_total = sum(total_values) / len(total_values)
|
||||||
|
avg_zmq = sum(zmq_values) / len(zmq_values) if zmq_values else 0.0
|
||||||
|
avg_latency = sum(latency_values) / len(latency_values) if latency_values else 0.0
|
||||||
|
runs = len(add_values)
|
||||||
|
print(
|
||||||
|
f"{name}: add-only avg {avg_add:.3f} s | total avg {avg_total:.3f} s "
|
||||||
|
f"| ZMQ avg {avg_zmq:.1f} node fetches | latency {avg_latency:.2f} ms/passage over {runs} run(s)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.plot_path:
|
||||||
|
try:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
labels = [name for name, *_ in scenarios]
|
||||||
|
values = [
|
||||||
|
sum(results_ms_per_passage[name]) / len(results_ms_per_passage[name])
|
||||||
|
if results_ms_per_passage[name]
|
||||||
|
else 0.0
|
||||||
|
for name in labels
|
||||||
|
]
|
||||||
|
|
||||||
|
def _auto_cap(vals: list[float]) -> float | None:
|
||||||
|
s = sorted(vals, reverse=True)
|
||||||
|
if len(s) < 2:
|
||||||
|
return None
|
||||||
|
if s[1] > 0 and s[0] >= 2.5 * s[1]:
|
||||||
|
return s[1] * 1.1
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _fmt_ms(v: float) -> str:
|
||||||
|
return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.1f}"
|
||||||
|
|
||||||
|
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
|
||||||
|
|
||||||
|
if args.broken_y:
|
||||||
|
s = sorted(values, reverse=True)
|
||||||
|
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||||
|
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
|
||||||
|
upper_start = (
|
||||||
|
args.upper_start_y
|
||||||
|
if args.upper_start_y is not None
|
||||||
|
else max(second * 1.2, lower_cap * 1.02)
|
||||||
|
)
|
||||||
|
ymax = max(values) * 1.10 if values else 1.0
|
||||||
|
fig, (ax_top, ax_bottom) = plt.subplots(
|
||||||
|
2,
|
||||||
|
1,
|
||||||
|
sharex=True,
|
||||||
|
figsize=(7.4, 5.0),
|
||||||
|
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.05},
|
||||||
|
)
|
||||||
|
x = list(range(len(labels)))
|
||||||
|
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
ax_bottom.set_ylim(0, lower_cap)
|
||||||
|
ax_top.set_ylim(upper_start, ymax)
|
||||||
|
for i, v in enumerate(values):
|
||||||
|
if v <= lower_cap:
|
||||||
|
ax_bottom.text(
|
||||||
|
i,
|
||||||
|
v + lower_cap * 0.02,
|
||||||
|
_fmt_ms(v),
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=9,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||||
|
ax_top.spines["bottom"].set_visible(False)
|
||||||
|
ax_bottom.spines["top"].set_visible(False)
|
||||||
|
ax_top.tick_params(labeltop=False)
|
||||||
|
ax_bottom.xaxis.tick_bottom()
|
||||||
|
d = 0.015
|
||||||
|
kwargs = {"transform": ax_top.transAxes, "color": "k", "clip_on": False}
|
||||||
|
ax_top.plot((-d, +d), (-d, +d), **kwargs)
|
||||||
|
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
|
||||||
|
kwargs.update({"transform": ax_bottom.transAxes})
|
||||||
|
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
|
||||||
|
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
|
||||||
|
ax_bottom.set_xticks(range(len(labels)))
|
||||||
|
ax_bottom.set_xticklabels(labels)
|
||||||
|
ax = ax_bottom
|
||||||
|
else:
|
||||||
|
cap = args.cap_y or _auto_cap(values)
|
||||||
|
plt.figure(figsize=(7.2, 4.2))
|
||||||
|
ax = plt.gca()
|
||||||
|
if cap is not None:
|
||||||
|
show_vals = [min(v, cap) for v in values]
|
||||||
|
bars = []
|
||||||
|
for i, (v, show) in enumerate(zip(values, show_vals)):
|
||||||
|
b = ax.bar(i, show, color=colors[i], width=0.8)
|
||||||
|
bars.append(b[0])
|
||||||
|
if v > cap:
|
||||||
|
bars[-1].set_hatch("//")
|
||||||
|
ax.text(i, cap * 1.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||||
|
else:
|
||||||
|
ax.text(
|
||||||
|
i,
|
||||||
|
show + max(1.0, 0.01 * (cap or show)),
|
||||||
|
_fmt_ms(v),
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=9,
|
||||||
|
)
|
||||||
|
ax.set_ylim(0, cap * 1.10)
|
||||||
|
ax.plot(
|
||||||
|
[0.02 - 0.02, 0.02 + 0.02],
|
||||||
|
[0.98 + 0.02, 0.98 - 0.02],
|
||||||
|
transform=ax.transAxes,
|
||||||
|
color="k",
|
||||||
|
lw=1,
|
||||||
|
)
|
||||||
|
ax.plot(
|
||||||
|
[0.98 - 0.02, 0.98 + 0.02],
|
||||||
|
[0.98 + 0.02, 0.98 - 0.02],
|
||||||
|
transform=ax.transAxes,
|
||||||
|
color="k",
|
||||||
|
lw=1,
|
||||||
|
)
|
||||||
|
if any(v > cap for v in values):
|
||||||
|
ax.legend(
|
||||||
|
[bars[0]], ["capped"], fontsize=8, frameon=False, loc="upper right"
|
||||||
|
)
|
||||||
|
ax.set_xticks(range(len(labels)))
|
||||||
|
ax.set_xticklabels(labels)
|
||||||
|
else:
|
||||||
|
ax.bar(labels, values, color=colors[: len(labels)])
|
||||||
|
for idx, val in enumerate(values):
|
||||||
|
ax.text(idx, val + 1.0, f"{val:.1f}", ha="center", va="bottom")
|
||||||
|
|
||||||
|
plt.ylabel("Average add latency (ms per passage)")
|
||||||
|
plt.title(f"Initial passages {args.max_initial}, updates {args.max_updates}")
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(args.plot_path)
|
||||||
|
print(f"Saved latency bar plot to {args.plot_path}")
|
||||||
|
# ZMQ time split (Stage A vs B/C)
|
||||||
|
try:
|
||||||
|
plt.figure(figsize=(6, 4))
|
||||||
|
a_vals = [sum(results_stageA[n]) / max(1, len(results_stageA[n])) for n in labels]
|
||||||
|
bc_vals = [
|
||||||
|
sum(results_stageBC[n]) / max(1, len(results_stageBC[n])) for n in labels
|
||||||
|
]
|
||||||
|
ind = range(len(labels))
|
||||||
|
plt.bar(ind, a_vals, color="#4e79a7", label="Stage A distance (s)")
|
||||||
|
plt.bar(
|
||||||
|
ind, bc_vals, bottom=a_vals, color="#e15759", label="Stage B/C embed-by-id (s)"
|
||||||
|
)
|
||||||
|
plt.xticks(list(ind), labels, rotation=10)
|
||||||
|
plt.ylabel("Server ZMQ time (s)")
|
||||||
|
plt.title(
|
||||||
|
f"ZMQ time split (initial {args.max_initial}, updates {args.max_updates})"
|
||||||
|
)
|
||||||
|
plt.legend()
|
||||||
|
out2 = args.plot_path.with_name(
|
||||||
|
args.plot_path.stem + "_zmq_split" + args.plot_path.suffix
|
||||||
|
)
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(out2)
|
||||||
|
print(f"Saved ZMQ time split plot to {out2}")
|
||||||
|
except Exception as e:
|
||||||
|
print("Failed to plot ZMQ split:", e)
|
||||||
|
except ImportError:
|
||||||
|
print("matplotlib not available; skipping plot generation")
|
||||||
|
|
||||||
|
# leave the last build on disk for inspection
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
5  benchmarks/update/bench_results.csv  Normal file
@@ -0,0 +1,5 @@
run_id,scenario,cache_enabled,ef_construction,max_initial,max_updates,total_time_s,add_only_s,latency_ms_per_passage,zmq_nodes,stageA_time_s,stageBC_time_s,model_name,embedding_mode,distance_metric
20251024-133101,baseline,1,200,300,1,3.391856,1.120359,1120.359421,126,0.507821,0.601608,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,no_cache_baseline,0,200,300,1,34.941514,32.91376,32913.760185,4033,0.506933,32.159928,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,disable_forward_rng,1,200,300,1,2.746756,0.8202,820.200443,66,0.474354,0.338454,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,disable_forward_and_reverse_rng,1,200,300,1,2.396566,0.521478,521.478415,1,0.508973,0.006938,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
704  benchmarks/update/bench_update_vs_offline_search.py  Normal file
@@ -0,0 +1,704 @@
"""
Compare two latency models for small incremental updates vs. search:

Scenario A (sequential update then search):
- Build initial HNSW (is_recompute=True)
- Start embedding server (ZMQ) for recompute
- Add N passages one-by-one (each triggers recompute over ZMQ)
- Then run a search query on the updated index
- Report total time = sum(add_i) + search_time, with breakdowns

Scenario B (offline embeds + concurrent search; no graph updates):
- Do NOT insert the N passages into the graph
- In parallel: (1) compute embeddings for the N passages; (2) compute query
  embedding and run a search on the existing index
- After both finish, compute similarity between the query embedding and the N
  new passage embeddings, merge with the index search results by score, and
  report time = max(embed_time, search_time) (i.e., no blocking on updates)

This script reuses the model/data loading conventions of
examples/bench_hnsw_rng_recompute.py but focuses on end-to-end latency
comparison for the two execution strategies above.

Example (from the repository root):
    uv run -m benchmarks.update.bench_update_vs_offline_search \
        --index-path .leann/bench/offline_vs_update.leann \
        --max-initial 300 --num-updates 5 --k 10
"""

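# Illustrative sketch only (hypothetical helper, not used by this benchmark): the
# Scenario B "makespan" described in the docstring is the wall-clock time of running
# the embedding worker and the search worker concurrently, which is roughly
# max(embed_time, search_time) when neither worker blocks the other.
def _makespan_sketch(embed_fn, search_fn) -> float:
    """Run two callables on separate threads and return the elapsed wall-clock time."""
    import threading
    import time

    start = time.time()
    workers = [threading.Thread(target=embed_fn), threading.Thread(target=search_fn)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    return time.time() - start
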
import argparse
import csv
import json
import logging
import os
import pickle
import sys
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np
import psutil  # type: ignore
from leann.api import LeannBuilder

if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")

from leann.embedding_compute import compute_embeddings
from leann.embedding_server_manager import EmbeddingServerManager
from leann.registry import register_project_directory
from leann_backend_hnsw import faiss  # type: ignore

logger = logging.getLogger(__name__)
if not logging.getLogger().handlers:
    logging.basicConfig(level=logging.INFO)


def _find_repo_root() -> Path:
    """Locate project root by walking up until pyproject.toml is found."""
    current = Path(__file__).resolve()
    for parent in current.parents:
        if (parent / "pyproject.toml").exists():
            return parent
    # Fallback: assume repo is two levels up (../..)
    return current.parents[2]


REPO_ROOT = _find_repo_root()
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from apps.chunking import create_text_chunks  # noqa: E402

DEFAULT_INITIAL_FILES = [
    REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
    REPO_ROOT / "data" / "huawei_pangu.md",
]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]


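# Shared data-loading helper: documents are read with llama_index's SimpleDirectoryReader
# and split with apps.chunking.create_text_chunks (chunk_size=512, chunk_overlap=128);
# empty chunks are dropped and the result is optionally truncated to `limit` entries.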
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
    from llama_index.core import SimpleDirectoryReader

    documents = []
    for path in paths:
        p = path.expanduser().resolve()
        if not p.exists():
            raise FileNotFoundError(f"Input path not found: {p}")
        if p.is_dir():
            reader = SimpleDirectoryReader(str(p), recursive=False)
            documents.extend(reader.load_data(show_progress=True))
        else:
            reader = SimpleDirectoryReader(input_files=[str(p)])
            documents.extend(reader.load_data(show_progress=True))

    if not documents:
        return []

    chunks = create_text_chunks(
        documents,
        chunk_size=512,
        chunk_overlap=128,
        use_ast_chunking=False,
    )
    cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
    if limit is not None:
        cleaned = cleaned[:limit]
    return cleaned


def ensure_index_dir(index_path: Path) -> None:
    index_path.parent.mkdir(parents=True, exist_ok=True)


def cleanup_index_files(index_path: Path) -> None:
    parent = index_path.parent
    if not parent.exists():
        return
    stem = index_path.stem
    for file in parent.glob(f"{stem}*"):
        if file.is_file():
            file.unlink()


def build_initial_index(
|
||||||
|
index_path: Path,
|
||||||
|
paragraphs: list[str],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
distance_metric: str,
|
||||||
|
ef_construction: int,
|
||||||
|
) -> None:
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
is_compact=False,
|
||||||
|
is_recompute=True,
|
||||||
|
distance_metric=distance_metric,
|
||||||
|
backend_kwargs={
|
||||||
|
"distance_metric": distance_metric,
|
||||||
|
"is_compact": False,
|
||||||
|
"is_recompute": True,
|
||||||
|
"efConstruction": ef_construction,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for idx, passage in enumerate(paragraphs):
|
||||||
|
builder.add_text(passage, metadata={"id": str(idx)})
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_norm_cosine(vecs: np.ndarray, metric: str) -> np.ndarray:
|
||||||
|
if metric == "cosine":
|
||||||
|
vecs = np.ascontiguousarray(vecs, dtype=np.float32)
|
||||||
|
norms = np.linalg.norm(vecs, axis=1, keepdims=True)
|
||||||
|
norms[norms == 0] = 1
|
||||||
|
vecs = vecs / norms
|
||||||
|
return vecs
|
||||||
|
|
||||||
|
|
||||||
|
def _read_index_for_search(index_path: Path) -> Any:
|
||||||
|
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||||
|
# Force-disable experimental disk cache when loading the index so that
|
||||||
|
# incremental benchmarks don't pick up stale top-degree bitmaps.
|
||||||
|
cfg = faiss.HNSWIndexConfig()
|
||||||
|
cfg.is_recompute = True
|
||||||
|
if hasattr(cfg, "disk_cache_ratio"):
|
||||||
|
cfg.disk_cache_ratio = 0.0
|
||||||
|
if hasattr(cfg, "external_storage_path"):
|
||||||
|
cfg.external_storage_path = None
|
||||||
|
io_flags = getattr(faiss, "IO_FLAG_MMAP", 0)
|
||||||
|
index = faiss.read_index(str(index_file), io_flags, cfg)
|
||||||
|
# ensure recompute mode persists after reload
|
||||||
|
try:
|
||||||
|
index.is_recompute = True
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
actual_ntotal = index.hnsw.levels.size()
|
||||||
|
except AttributeError:
|
||||||
|
actual_ntotal = index.ntotal
|
||||||
|
if actual_ntotal != index.ntotal:
|
||||||
|
print(
|
||||||
|
f"[bench_update_vs_offline_search] Correcting ntotal from {index.ntotal} to {actual_ntotal}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
index.ntotal = actual_ntotal
|
||||||
|
if getattr(index, "storage", None) is None:
|
||||||
|
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||||
|
storage_index = faiss.IndexFlatIP(index.d)
|
||||||
|
else:
|
||||||
|
storage_index = faiss.IndexFlatL2(index.d)
|
||||||
|
index.storage = storage_index
|
||||||
|
index.own_fields = True
|
||||||
|
return index
|
||||||
|
|
||||||
|
|
||||||
|
def _append_passages_for_updates(
|
||||||
|
meta_path: Path,
|
||||||
|
start_id: int,
|
||||||
|
texts: list[str],
|
||||||
|
) -> list[str]:
|
||||||
|
"""Append update passages so the embedding server can serve recompute fetches."""
|
||||||
|
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
index_dir = meta_path.parent
|
||||||
|
meta_name = meta_path.name
|
||||||
|
if not meta_name.endswith(".meta.json"):
|
||||||
|
raise ValueError(f"Unexpected meta filename: {meta_path}")
|
||||||
|
index_base = meta_name[: -len(".meta.json")]
|
||||||
|
|
||||||
|
passages_file = index_dir / f"{index_base}.passages.jsonl"
|
||||||
|
offsets_file = index_dir / f"{index_base}.passages.idx"
|
||||||
|
|
||||||
|
if not passages_file.exists() or not offsets_file.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
"Passage store missing; cannot register update passages for recompute mode."
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(offsets_file, "rb") as f:
|
||||||
|
offset_map: dict[str, int] = pickle.load(f)
|
||||||
|
|
||||||
|
assigned_ids: list[str] = []
|
||||||
|
with open(passages_file, "a", encoding="utf-8") as f:
|
||||||
|
for i, text in enumerate(texts):
|
||||||
|
passage_id = str(start_id + i)
|
||||||
|
offset = f.tell()
|
||||||
|
json.dump({"id": passage_id, "text": text, "metadata": {}}, f, ensure_ascii=False)
|
||||||
|
f.write("\n")
|
||||||
|
offset_map[passage_id] = offset
|
||||||
|
assigned_ids.append(passage_id)
|
||||||
|
|
||||||
|
with open(offsets_file, "wb") as f:
|
||||||
|
pickle.dump(offset_map, f)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
meta = {}
|
||||||
|
meta["total_passages"] = len(offset_map)
|
||||||
|
with open(meta_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(meta, f, indent=2)
|
||||||
|
|
||||||
|
return assigned_ids
|
||||||
|
|
||||||
|
|
||||||
|
def _search(index: Any, q: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
q = np.ascontiguousarray(q, dtype=np.float32)
|
||||||
|
distances = np.zeros((1, k), dtype=np.float32)
|
||||||
|
indices = np.zeros((1, k), dtype=np.int64)
|
||||||
|
index.search(
|
||||||
|
1,
|
||||||
|
faiss.swig_ptr(q),
|
||||||
|
k,
|
||||||
|
faiss.swig_ptr(distances),
|
||||||
|
faiss.swig_ptr(indices),
|
||||||
|
)
|
||||||
|
return distances[0], indices[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _score_for_metric(dist: float, metric: str) -> float:
|
||||||
|
# Convert FAISS distance to a "higher is better" score
|
||||||
|
if metric in ("mips", "cosine"):
|
||||||
|
return float(dist)
|
||||||
|
# l2 distance (smaller better) -> negative distance as score
|
||||||
|
return -float(dist)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_results(
|
||||||
|
index_results: tuple[np.ndarray, np.ndarray],
|
||||||
|
offline_scores: list[tuple[int, float]],
|
||||||
|
k: int,
|
||||||
|
metric: str,
|
||||||
|
) -> list[tuple[str, float]]:
|
||||||
|
distances, indices = index_results
|
||||||
|
merged: list[tuple[str, float]] = []
|
||||||
|
for distance, idx in zip(distances.tolist(), indices.tolist()):
|
||||||
|
merged.append((f"idx:{idx}", _score_for_metric(distance, metric)))
|
||||||
|
for j, s in offline_scores:
|
||||||
|
merged.append((f"offline:{j}", s))
|
||||||
|
merged.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
return merged[:k]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ScenarioResult:
|
||||||
|
name: str
|
||||||
|
update_total_s: float
|
||||||
|
search_s: float
|
||||||
|
overall_s: float
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path(".leann/bench/offline-vs-update.leann"),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--initial-files",
|
||||||
|
nargs="*",
|
||||||
|
type=Path,
|
||||||
|
default=DEFAULT_INITIAL_FILES,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--update-files",
|
||||||
|
nargs="*",
|
||||||
|
type=Path,
|
||||||
|
default=DEFAULT_UPDATE_FILES,
|
||||||
|
)
|
||||||
|
parser.add_argument("--max-initial", type=int, default=300)
|
||||||
|
parser.add_argument("--num-updates", type=int, default=5)
|
||||||
|
parser.add_argument("--k", type=int, default=10, help="Top-k for search/merge")
|
||||||
|
parser.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default="neural network",
|
||||||
|
help="Query text used for the search benchmark.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--server-port", type=int, default=5557)
|
||||||
|
parser.add_argument("--add-timeout", type=int, default=600)
|
||||||
|
parser.add_argument("--model-name", default="sentence-transformers/all-MiniLM-L6-v2")
|
||||||
|
parser.add_argument("--embedding-mode", default="sentence-transformers")
|
||||||
|
parser.add_argument(
|
||||||
|
"--distance-metric",
|
||||||
|
default="mips",
|
||||||
|
choices=["mips", "l2", "cosine"],
|
||||||
|
)
|
||||||
|
parser.add_argument("--ef-construction", type=int, default=200)
|
||||||
|
parser.add_argument(
|
||||||
|
"--only",
|
||||||
|
choices=["A", "B", "both"],
|
||||||
|
default="both",
|
||||||
|
help="Run only Scenario A, Scenario B, or both",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--csv-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path("benchmarks/update/offline_vs_update.csv"),
|
||||||
|
help="Where to append results (CSV).",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
register_project_directory(REPO_ROOT)
|
||||||
|
|
||||||
|
# Load data
|
||||||
|
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
|
||||||
|
update_paragraphs = load_chunks_from_files(args.update_files, None)
|
||||||
|
if not update_paragraphs:
|
||||||
|
raise ValueError("No update passages loaded from --update-files")
|
||||||
|
update_paragraphs = update_paragraphs[: args.num_updates]
|
||||||
|
if len(update_paragraphs) < args.num_updates:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough update passages ({len(update_paragraphs)}) for --num-updates={args.num_updates}"
|
||||||
|
)
|
||||||
|
|
||||||
|
ensure_index_dir(args.index_path)
|
||||||
|
cleanup_index_files(args.index_path)
|
||||||
|
|
||||||
|
# Build initial index
|
||||||
|
build_initial_index(
|
||||||
|
args.index_path,
|
||||||
|
initial_paragraphs,
|
||||||
|
args.model_name,
|
||||||
|
args.embedding_mode,
|
||||||
|
args.distance_metric,
|
||||||
|
args.ef_construction,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare index object and meta
|
||||||
|
meta_path = args.index_path.parent / f"{args.index_path.name}.meta.json"
|
||||||
|
index = _read_index_for_search(args.index_path)
|
||||||
|
|
||||||
|
# CSV setup
|
||||||
|
run_id = time.strftime("%Y%m%d-%H%M%S")
|
||||||
|
if args.csv_path:
|
||||||
|
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
csv_fields = [
|
||||||
|
"run_id",
|
||||||
|
"scenario",
|
||||||
|
"max_initial",
|
||||||
|
"num_updates",
|
||||||
|
"k",
|
||||||
|
"total_time_s",
|
||||||
|
"add_total_s",
|
||||||
|
"search_time_s",
|
||||||
|
"emb_time_s",
|
||||||
|
"makespan_s",
|
||||||
|
"model_name",
|
||||||
|
"embedding_mode",
|
||||||
|
"distance_metric",
|
||||||
|
]
|
||||||
|
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
|
||||||
|
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||||
|
writer.writeheader()
|
||||||
|
|
||||||
|
# Debug: list existing HNSW server PIDs before starting
|
||||||
|
try:
|
||||||
|
existing = [
|
||||||
|
p
|
||||||
|
for p in psutil.process_iter(attrs=["pid", "cmdline"])
|
||||||
|
if any(
|
||||||
|
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
|
||||||
|
for arg in (p.info.get("cmdline") or [])
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if existing:
|
||||||
|
print("[debug] Found existing hnsw_embedding_server processes before run:")
|
||||||
|
for p in existing:
|
||||||
|
print(f"[debug] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}")
|
||||||
|
except Exception as _e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
add_total = 0.0
|
||||||
|
search_after_add = 0.0
|
||||||
|
total_seq = 0.0
|
||||||
|
port_a = None
|
||||||
|
if args.only in ("A", "both"):
|
||||||
|
# Scenario A: sequential update then search
|
||||||
|
start_id = index.ntotal
|
||||||
|
assigned_ids = _append_passages_for_updates(meta_path, start_id, update_paragraphs)
|
||||||
|
if assigned_ids:
|
||||||
|
logger.debug(
|
||||||
|
"Registered %d update passages starting at id %s",
|
||||||
|
len(assigned_ids),
|
||||||
|
assigned_ids[0],
|
||||||
|
)
|
||||||
|
server_manager = EmbeddingServerManager(
|
||||||
|
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||||
|
)
|
||||||
|
ok, port = server_manager.start_server(
|
||||||
|
port=args.server_port,
|
||||||
|
model_name=args.model_name,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
passages_file=str(meta_path),
|
||||||
|
distance_metric=args.distance_metric,
|
||||||
|
)
|
||||||
|
if not ok:
|
||||||
|
raise RuntimeError("Failed to start embedding server")
|
||||||
|
try:
|
||||||
|
# Set ZMQ port for recompute mode
|
||||||
|
if hasattr(index.hnsw, "set_zmq_port"):
|
||||||
|
index.hnsw.set_zmq_port(port)
|
||||||
|
elif hasattr(index, "set_zmq_port"):
|
||||||
|
index.set_zmq_port(port)
|
||||||
|
|
||||||
|
# Start A overall timer BEFORE computing update embeddings
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
# Compute embeddings for updates (counted into A's overall)
|
||||||
|
t_emb0 = time.time()
|
||||||
|
upd_embs = compute_embeddings(
|
||||||
|
update_paragraphs,
|
||||||
|
args.model_name,
|
||||||
|
mode=args.embedding_mode,
|
||||||
|
is_build=False,
|
||||||
|
batch_size=16,
|
||||||
|
)
|
||||||
|
emb_time_updates = time.time() - t_emb0
|
||||||
|
upd_embs = np.asarray(upd_embs, dtype=np.float32)
|
||||||
|
upd_embs = _maybe_norm_cosine(upd_embs, args.distance_metric)
|
||||||
|
|
||||||
|
# Perform sequential adds
|
||||||
|
for i in range(upd_embs.shape[0]):
|
||||||
|
t_add0 = time.time()
|
||||||
|
index.add(1, faiss.swig_ptr(upd_embs[i : i + 1]))
|
||||||
|
add_total += time.time() - t_add0
|
||||||
|
# Don't persist index after adds to avoid contaminating Scenario B
|
||||||
|
# index_file = args.index_path.parent / f"{args.index_path.stem}.index"
|
||||||
|
# faiss.write_index(index, str(index_file))
|
||||||
|
|
||||||
|
# Search after updates
|
||||||
|
q_emb = compute_embeddings(
|
||||||
|
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
|
||||||
|
)
|
||||||
|
q_emb = np.asarray(q_emb, dtype=np.float32)
|
||||||
|
q_emb = _maybe_norm_cosine(q_emb, args.distance_metric)
|
||||||
|
|
||||||
|
# Warm up search with a dummy query first
|
||||||
|
print("[DEBUG] Warming up search...")
|
||||||
|
_ = _search(index, q_emb, 1)
|
||||||
|
|
||||||
|
t_s0 = time.time()
|
||||||
|
D_upd, I_upd = _search(index, q_emb, args.k)
|
||||||
|
search_after_add = time.time() - t_s0
|
||||||
|
total_seq = time.time() - t0
|
||||||
|
finally:
|
||||||
|
server_manager.stop_server()
|
||||||
|
port_a = port
|
||||||
|
|
||||||
|
print("\n=== Scenario A: update->search (sequential) ===")
|
||||||
|
# emb_time_updates is defined only when A runs
|
||||||
|
try:
|
||||||
|
_emb_a = emb_time_updates
|
||||||
|
except NameError:
|
||||||
|
_emb_a = 0.0
|
||||||
|
print(
|
||||||
|
f"Adds: {args.num_updates} passages; embeds={_emb_a:.3f}s; add_total={add_total:.3f}s; "
|
||||||
|
f"search={search_after_add:.3f}s; overall={total_seq:.3f}s"
|
||||||
|
)
|
||||||
|
# CSV row for A
|
||||||
|
if args.csv_path:
|
||||||
|
row_a = {
|
||||||
|
"run_id": run_id,
|
||||||
|
"scenario": "A",
|
||||||
|
"max_initial": args.max_initial,
|
||||||
|
"num_updates": args.num_updates,
|
||||||
|
"k": args.k,
|
||||||
|
"total_time_s": round(total_seq, 6),
|
||||||
|
"add_total_s": round(add_total, 6),
|
||||||
|
"search_time_s": round(search_after_add, 6),
|
||||||
|
"emb_time_s": round(_emb_a, 6),
|
||||||
|
"makespan_s": 0.0,
|
||||||
|
"model_name": args.model_name,
|
||||||
|
"embedding_mode": args.embedding_mode,
|
||||||
|
"distance_metric": args.distance_metric,
|
||||||
|
}
|
||||||
|
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||||
|
writer.writerow(row_a)
|
||||||
|
|
||||||
|
# Verify server cleanup
|
||||||
|
try:
|
||||||
|
# short sleep to allow signal handling to finish
|
||||||
|
time.sleep(0.5)
|
||||||
|
leftovers = [
|
||||||
|
p
|
||||||
|
for p in psutil.process_iter(attrs=["pid", "cmdline"])
|
||||||
|
if any(
|
||||||
|
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
|
||||||
|
for arg in (p.info.get("cmdline") or [])
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if leftovers:
|
||||||
|
print("[warn] hnsw_embedding_server process(es) still alive after A-stop:")
|
||||||
|
for p in leftovers:
|
||||||
|
print(
|
||||||
|
f"[warn] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("[debug] server cleanup confirmed: no hnsw_embedding_server found")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Scenario B: offline embeds + concurrent search (no graph updates)
|
||||||
|
if args.only in ("B", "both"):
|
||||||
|
# ensure a server is available for recompute search
|
||||||
|
server_manager_b = EmbeddingServerManager(
|
||||||
|
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||||
|
)
|
||||||
|
requested_port = args.server_port if port_a is None else port_a
|
||||||
|
ok_b, port_b = server_manager_b.start_server(
|
||||||
|
port=requested_port,
|
||||||
|
model_name=args.model_name,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
passages_file=str(meta_path),
|
||||||
|
distance_metric=args.distance_metric,
|
||||||
|
)
|
||||||
|
if not ok_b:
|
||||||
|
raise RuntimeError("Failed to start embedding server for Scenario B")
|
||||||
|
|
||||||
|
# Wait for server to fully initialize
|
||||||
|
print("[DEBUG] Waiting 2s for embedding server to fully initialize...")
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Read the index first
|
||||||
|
index_no_update = _read_index_for_search(args.index_path) # unchanged index
|
||||||
|
|
||||||
|
# Then configure ZMQ port on the correct index object
|
||||||
|
if hasattr(index_no_update.hnsw, "set_zmq_port"):
|
||||||
|
index_no_update.hnsw.set_zmq_port(port_b)
|
||||||
|
elif hasattr(index_no_update, "set_zmq_port"):
|
||||||
|
index_no_update.set_zmq_port(port_b)
|
||||||
|
|
||||||
|
# Warmup the embedding model before benchmarking (do this for both --only B and --only both)
|
||||||
|
# This ensures fair comparison as Scenario A has warmed up the model during update embeddings
|
||||||
|
logger.info("Warming up embedding model for Scenario B...")
|
||||||
|
_ = compute_embeddings(
|
||||||
|
["warmup text"], args.model_name, mode=args.embedding_mode, is_build=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare worker A: compute embeddings for the same N passages
|
||||||
|
emb_time = 0.0
|
||||||
|
updates_embs_offline: np.ndarray | None = None
|
||||||
|
|
||||||
|
def _worker_emb():
|
||||||
|
nonlocal emb_time, updates_embs_offline
|
||||||
|
t = time.time()
|
||||||
|
updates_embs_offline = compute_embeddings(
|
||||||
|
update_paragraphs,
|
||||||
|
args.model_name,
|
||||||
|
mode=args.embedding_mode,
|
||||||
|
is_build=False,
|
||||||
|
batch_size=16,
|
||||||
|
)
|
||||||
|
emb_time = time.time() - t
|
||||||
|
|
||||||
|
# Pre-compute query embedding and warm up search outside of timed section.
|
||||||
|
q_vec = compute_embeddings(
|
||||||
|
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
|
||||||
|
)
|
||||||
|
q_vec = np.asarray(q_vec, dtype=np.float32)
|
||||||
|
q_vec = _maybe_norm_cosine(q_vec, args.distance_metric)
|
||||||
|
print("[DEBUG B] Warming up search...")
|
||||||
|
_ = _search(index_no_update, q_vec, 1)
|
||||||
|
|
||||||
|
# Worker B: timed search on the warmed index
|
||||||
|
search_time = 0.0
|
||||||
|
offline_elapsed = 0.0
|
||||||
|
index_results: tuple[np.ndarray, np.ndarray] | None = None
|
||||||
|
|
||||||
|
def _worker_search():
|
||||||
|
nonlocal search_time, index_results
|
||||||
|
t = time.time()
|
||||||
|
distances, indices = _search(index_no_update, q_vec, args.k)
|
||||||
|
search_time = time.time() - t
|
||||||
|
index_results = (distances, indices)
|
||||||
|
|
||||||
|
# Run two workers concurrently
|
||||||
|
t0 = time.time()
|
||||||
|
th1 = threading.Thread(target=_worker_emb)
|
||||||
|
th2 = threading.Thread(target=_worker_search)
|
||||||
|
th1.start()
|
||||||
|
th2.start()
|
||||||
|
th1.join()
|
||||||
|
th2.join()
|
||||||
|
offline_elapsed = time.time() - t0
|
||||||
|
|
||||||
|
# For mixing: compute query vs. offline update similarities (pure client-side)
|
||||||
|
offline_scores: list[tuple[int, float]] = []
|
||||||
|
if updates_embs_offline is not None:
|
||||||
|
upd2 = np.asarray(updates_embs_offline, dtype=np.float32)
|
||||||
|
upd2 = _maybe_norm_cosine(upd2, args.distance_metric)
|
||||||
|
# For mips/cosine, score = dot; for l2, score = -||x-y||^2
|
||||||
|
for j in range(upd2.shape[0]):
|
||||||
|
if args.distance_metric in ("mips", "cosine"):
|
||||||
|
s = float(np.dot(q_vec[0], upd2[j]))
|
||||||
|
else:
|
||||||
|
diff = q_vec[0] - upd2[j]
|
||||||
|
s = -float(np.dot(diff, diff))
|
||||||
|
offline_scores.append((j, s))
|
||||||
|
|
||||||
|
merged_topk = (
|
||||||
|
_merge_results(index_results, offline_scores, args.k, args.distance_metric)
|
||||||
|
if index_results
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n=== Scenario B: offline embeds + concurrent search (no add) ===")
|
||||||
|
print(
|
||||||
|
f"embeddings({args.num_updates})={emb_time:.3f}s; search={search_time:.3f}s; makespan≈{offline_elapsed:.3f}s (≈max)"
|
||||||
|
)
|
||||||
|
if merged_topk:
|
||||||
|
preview = ", ".join([f"{lab}:{score:.3f}" for lab, score in merged_topk[:5]])
|
||||||
|
print(f"Merged top-5 preview: {preview}")
|
||||||
|
# CSV row for B
|
||||||
|
if args.csv_path:
|
||||||
|
row_b = {
|
||||||
|
"run_id": run_id,
|
||||||
|
"scenario": "B",
|
||||||
|
"max_initial": args.max_initial,
|
||||||
|
"num_updates": args.num_updates,
|
||||||
|
"k": args.k,
|
||||||
|
"total_time_s": 0.0,
|
||||||
|
"add_total_s": 0.0,
|
||||||
|
"search_time_s": round(search_time, 6),
|
||||||
|
"emb_time_s": round(emb_time, 6),
|
||||||
|
"makespan_s": round(offline_elapsed, 6),
|
||||||
|
"model_name": args.model_name,
|
||||||
|
"embedding_mode": args.embedding_mode,
|
||||||
|
"distance_metric": args.distance_metric,
|
||||||
|
}
|
||||||
|
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||||
|
writer.writerow(row_b)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
server_manager_b.stop_server()
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
print("\n=== Summary ===")
|
||||||
|
msg_a = (
|
||||||
|
f"A: seq-add+search overall={total_seq:.3f}s (adds={add_total:.3f}s, search={search_after_add:.3f}s)"
|
||||||
|
if args.only in ("A", "both")
|
||||||
|
else "A: skipped"
|
||||||
|
)
|
||||||
|
msg_b = (
|
||||||
|
f"B: offline+concurrent overall≈{offline_elapsed:.3f}s (emb={emb_time:.3f}s, search={search_time:.3f}s)"
|
||||||
|
if args.only in ("B", "both")
|
||||||
|
else "B: skipped"
|
||||||
|
)
|
||||||
|
print(msg_a + "\n" + msg_b)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
5  benchmarks/update/offline_vs_update.csv  Normal file
@@ -0,0 +1,5 @@
run_id,scenario,max_initial,num_updates,k,total_time_s,add_total_s,search_time_s,emb_time_s,makespan_s,model_name,embedding_mode,distance_metric
20251024-141607,A,300,1,10,3.273957,3.050168,0.097825,0.017339,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-141607,B,300,1,10,0.0,0.0,0.111892,0.007869,0.112635,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251025-160652,A,300,5,10,5.061945,4.805962,0.123271,0.015008,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251025-160652,B,300,5,10,0.0,0.0,0.101809,0.008817,0.102447,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
645  benchmarks/update/plot_bench_results.py  Normal file
@@ -0,0 +1,645 @@
#!/usr/bin/env python3
"""
Plot latency bars from the benchmark CSV produced by
benchmarks/update/bench_hnsw_rng_recompute.py.

If you also provide an offline_vs_update.csv via --csv-right
(from benchmarks/update/bench_update_vs_offline_search.py), this script will
output a side-by-side figure:
- Left: ms/passage bars (four RNG scenarios).
- Right: seconds bars (Scenario A seq add+search vs Scenario B offline+search).

Usage:
    uv run python benchmarks/update/plot_bench_results.py \
        --csv benchmarks/update/bench_results.csv \
        --out benchmarks/update/bench_latency_from_csv.png

The script selects the latest run_id in the CSV and plots four bars for
the default scenarios:
- baseline
- no_cache_baseline
- disable_forward_rng
- disable_forward_and_reverse_rng

If multiple rows exist per scenario for that run_id, the script averages
their latency_ms_per_passage values.
"""

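# Example invocation for the side-by-side figure described above (the CSV paths are the
# defaults written by the two benchmark scripts; the output name matches --out's default):
#   uv run python benchmarks/update/plot_bench_results.py \
#       --csv benchmarks/update/bench_results.csv \
#       --csv-right benchmarks/update/offline_vs_update.csv \
#       --out add_ablation.pdf
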
import argparse
|
||||||
|
import csv
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
DEFAULT_SCENARIOS = [
|
||||||
|
"no_cache_baseline",
|
||||||
|
"baseline",
|
||||||
|
"disable_forward_rng",
|
||||||
|
"disable_forward_and_reverse_rng",
|
||||||
|
]
|
||||||
|
|
||||||
|
SCENARIO_LABELS = {
|
||||||
|
"baseline": "+ Cache",
|
||||||
|
"no_cache_baseline": "Naive \n Recompute",
|
||||||
|
"disable_forward_rng": "+ w/o \n Fwd RNG",
|
||||||
|
"disable_forward_and_reverse_rng": "+ w/o \n Bwd RNG",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Paper-style colors and hatches for scenarios
|
||||||
|
SCENARIO_STYLES = {
|
||||||
|
"no_cache_baseline": {"edgecolor": "dimgrey", "hatch": "/////"},
|
||||||
|
"baseline": {"edgecolor": "#63B8B6", "hatch": "xxxxx"},
|
||||||
|
"disable_forward_rng": {"edgecolor": "green", "hatch": "....."},
|
||||||
|
"disable_forward_and_reverse_rng": {"edgecolor": "tomato", "hatch": "\\\\\\\\\\"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_latest_run(csv_path: Path):
|
||||||
|
rows = []
|
||||||
|
with csv_path.open("r", encoding="utf-8") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for row in reader:
|
||||||
|
rows.append(row)
|
||||||
|
if not rows:
|
||||||
|
raise SystemExit("CSV is empty: no rows to plot")
|
||||||
|
# Choose latest run_id lexicographically (YYYYMMDD-HHMMSS)
|
||||||
|
run_ids = [r.get("run_id", "") for r in rows]
|
||||||
|
latest = max(run_ids)
|
||||||
|
latest_rows = [r for r in rows if r.get("run_id", "") == latest]
|
||||||
|
if not latest_rows:
|
||||||
|
# Fallback: take last 4 rows
|
||||||
|
latest_rows = rows[-4:]
|
||||||
|
latest = latest_rows[-1].get("run_id", "unknown")
|
||||||
|
return latest, latest_rows
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_latency(rows):
|
||||||
|
acc = defaultdict(list)
|
||||||
|
for r in rows:
|
||||||
|
sc = r.get("scenario", "")
|
||||||
|
try:
|
||||||
|
val = float(r.get("latency_ms_per_passage", "nan"))
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
acc[sc].append(val)
|
||||||
|
avg = {k: (sum(v) / len(v) if v else 0.0) for k, v in acc.items()}
|
||||||
|
return avg
|
||||||
|
|
||||||
|
|
||||||
|
def _auto_cap(values: list[float]) -> float | None:
|
||||||
|
if not values:
|
||||||
|
return None
|
||||||
|
sorted_vals = sorted(values, reverse=True)
|
||||||
|
if len(sorted_vals) < 2:
|
||||||
|
return None
|
||||||
|
max_v, second = sorted_vals[0], sorted_vals[1]
|
||||||
|
if second <= 0:
|
||||||
|
return None
|
||||||
|
# If the tallest bar dwarfs the second by 2.5x+, cap near the second
|
||||||
|
if max_v >= 2.5 * second:
|
||||||
|
return second * 1.1
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _add_break_marker(ax, y, rel_x0=0.02, rel_x1=0.98, size=0.02):
|
||||||
|
# Draw small diagonal ticks near left/right to signal cap
|
||||||
|
x0, x1 = rel_x0, rel_x1
|
||||||
|
ax.plot([x0 - size, x0 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
|
||||||
|
ax.plot([x1 - size, x1 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
|
||||||
|
|
||||||
|
|
||||||
|
def _fmt_ms(v: float) -> str:
|
||||||
|
if v >= 1000:
|
||||||
|
return f"{v / 1000:.1f}k"
|
||||||
|
return f"{v:.1f}"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Set LaTeX style for paper figures (matching paper_fig.py)
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
plt.rcParams["font.family"] = "Helvetica"
|
||||||
|
plt.rcParams["ytick.direction"] = "in"
|
||||||
|
plt.rcParams["hatch.linewidth"] = 1.5
|
||||||
|
plt.rcParams["font.weight"] = "bold"
|
||||||
|
plt.rcParams["axes.labelweight"] = "bold"
|
||||||
|
plt.rcParams["text.usetex"] = True
|
||||||
|
|
||||||
|
ap = argparse.ArgumentParser(description=__doc__)
|
||||||
|
ap.add_argument(
|
||||||
|
"--csv",
|
||||||
|
type=Path,
|
||||||
|
default=Path("benchmarks/update/bench_results.csv"),
|
||||||
|
help="Path to results CSV (defaults to bench_results.csv)",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--out",
|
||||||
|
type=Path,
|
||||||
|
default=Path("add_ablation.pdf"),
|
||||||
|
help="Output image path",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--csv-right",
|
||||||
|
type=Path,
|
||||||
|
default=Path("benchmarks/update/offline_vs_update.csv"),
|
||||||
|
help="Optional: offline_vs_update.csv to render right subplot (A vs B)",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--cap-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Cap Y-axis at this ms value; bars above are hatched and annotated.",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--no-auto-cap",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable auto-cap heuristic when --cap-y is not provided.",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--broken-y",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Use a broken Y-axis (two stacked axes with a gap). Overrides --cap-y unless both provided.",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--lower-cap-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Lower axes upper bound for broken Y (ms). Default = 1.1x second-highest.",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--upper-start-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Upper axes lower bound for broken Y (ms). Default = 1.2x second-highest.",
|
||||||
|
)
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
latest_run, latest_rows = load_latest_run(args.csv)
|
||||||
|
avg = aggregate_latency(latest_rows)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
except Exception as e:
|
||||||
|
raise SystemExit(f"matplotlib not available: {e}")
|
||||||
|
|
||||||
|
scenarios = DEFAULT_SCENARIOS
|
||||||
|
values = [avg.get(name, 0.0) for name in scenarios]
|
||||||
|
labels = [SCENARIO_LABELS.get(name, name) for name in scenarios]
|
||||||
|
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
|
||||||
|
|
||||||
|
# If right CSV is provided, build side-by-side figure
|
||||||
|
if args.csv_right is not None:
|
||||||
|
try:
|
||||||
|
right_rows_all = []
|
||||||
|
with args.csv_right.open("r", encoding="utf-8") as f:
|
||||||
|
rreader = csv.DictReader(f)
|
||||||
|
right_rows_all = list(rreader)
|
||||||
|
if right_rows_all:
|
||||||
|
r_latest = max(r.get("run_id", "") for r in right_rows_all)
|
||||||
|
right_rows = [r for r in right_rows_all if r.get("run_id", "") == r_latest]
|
||||||
|
else:
|
||||||
|
r_latest = None
|
||||||
|
right_rows = []
|
||||||
|
except Exception:
|
||||||
|
r_latest = None
|
||||||
|
right_rows = []
|
||||||
|
|
||||||
|
a_total = 0.0
|
||||||
|
b_makespan = 0.0
|
||||||
|
for r in right_rows:
|
||||||
|
sc = (r.get("scenario", "") or "").strip().upper()
|
||||||
|
if sc == "A":
|
||||||
|
try:
|
||||||
|
a_total = float(r.get("total_time_s", 0.0))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
elif sc == "B":
|
||||||
|
try:
|
||||||
|
b_makespan = float(r.get("makespan_s", 0.0))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib import gridspec
|
        # Left subplot (reuse current style, with optional cap)
        cap = args.cap_y
        if cap is None and not args.no_auto_cap:
            cap = _auto_cap(values)
        x = list(range(len(labels)))

        if args.broken_y:
            # Use broken axis for left subplot
            # Auto-adjust width ratios: left has 4 bars, right has 2 bars
            fig = plt.figure(figsize=(4.8, 1.8))  # Scaled down to 80%
            gs = gridspec.GridSpec(
                2, 2, height_ratios=[1, 3], width_ratios=[1.5, 1], hspace=0.08, wspace=0.35
            )
            ax_left_top = fig.add_subplot(gs[0, 0])
            ax_left_bottom = fig.add_subplot(gs[1, 0], sharex=ax_left_top)
            ax_right = fig.add_subplot(gs[:, 1])

            # Determine break points
            s = sorted(values, reverse=True)
            second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
            lower_cap = (
                args.lower_cap_y if args.lower_cap_y is not None else second * 1.4
            )  # Increased to show more range
            upper_start = (
                args.upper_start_y
                if args.upper_start_y is not None
                else max(second * 1.5, lower_cap * 1.02)
            )
            ymax = (
                max(values) * 1.90 if values else 1.0
            )  # Increase headroom to 1.90 for text label and tick range

            # Draw bars on both axes
            ax_left_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
            ax_left_top.bar(x, values, color=colors[: len(labels)], width=0.8)

            # Set limits
            ax_left_bottom.set_ylim(0, lower_cap)
            ax_left_top.set_ylim(upper_start, ymax)

            # Annotate values (convert ms to s)
            values_s = [v / 1000.0 for v in values]
            lower_cap_s = lower_cap / 1000.0
            upper_start_s = upper_start / 1000.0
            ymax_s = ymax / 1000.0

            ax_left_bottom.set_ylim(0, lower_cap_s)
            ax_left_top.set_ylim(upper_start_s, ymax_s)

            # Redraw bars with s values (paper style: white fill + colored edge + hatch)
            ax_left_bottom.clear()
            ax_left_top.clear()
            bar_width = 0.50  # Reduced for wider spacing between bars
            for i, (scenario_name, v) in enumerate(zip(scenarios, values_s)):
                style = SCENARIO_STYLES.get(scenario_name, {"edgecolor": "black", "hatch": ""})
                # Draw in bottom axis for all bars
                ax_left_bottom.bar(
                    i,
                    v,
                    width=bar_width,
                    color="white",
                    edgecolor=style["edgecolor"],
                    hatch=style["hatch"],
                    linewidth=1.2,
                )
                # Only draw in top axis if the bar is tall enough to reach the upper range
                if v > upper_start_s:
                    ax_left_top.bar(
                        i,
                        v,
                        width=bar_width,
                        color="white",
                        edgecolor=style["edgecolor"],
                        hatch=style["hatch"],
                        linewidth=1.2,
                    )
            ax_left_bottom.set_ylim(0, lower_cap_s)
            ax_left_top.set_ylim(upper_start_s, ymax_s)

            for i, v in enumerate(values_s):
                if v <= lower_cap_s:
                    ax_left_bottom.text(
                        i,
                        v + lower_cap_s * 0.02,
                        f"{v:.2f}",
                        ha="center",
                        va="bottom",
                        fontsize=8,
                        fontweight="bold",
                    )
                else:
                    ax_left_top.text(
                        i,
                        v + (ymax_s - upper_start_s) * 0.02,
                        f"{v:.2f}",
                        ha="center",
                        va="bottom",
                        fontsize=8,
                        fontweight="bold",
                    )

            # Hide spines between axes
            ax_left_top.spines["bottom"].set_visible(False)
            ax_left_bottom.spines["top"].set_visible(False)
            ax_left_top.tick_params(
                labeltop=False, labelbottom=False, bottom=False
            )  # Hide tick marks
            ax_left_bottom.xaxis.tick_bottom()
            ax_left_bottom.tick_params(top=False)  # Hide top tick marks

            # Draw break marks (matching paper_fig.py style)
            d = 0.015
            kwargs = {
                "transform": ax_left_top.transAxes,
                "color": "k",
                "clip_on": False,
                "linewidth": 0.8,
                "zorder": 10,
            }
            ax_left_top.plot((-d, +d), (-d, +d), **kwargs)
            ax_left_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
            kwargs.update({"transform": ax_left_bottom.transAxes})
            ax_left_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
            ax_left_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)

            ax_left_bottom.set_xticks(x)
            ax_left_bottom.set_xticklabels(labels, rotation=0, fontsize=7)
            # Don't set ylabel here - will use fig.text for alignment
            ax_left_bottom.tick_params(axis="y", labelsize=10)
            ax_left_top.tick_params(axis="y", labelsize=10)
            # Add subtle grid for better readability
            ax_left_bottom.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
            ax_left_top.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
            ax_left_top.set_title("Single Add Operation", fontsize=11, pad=10, fontweight="bold")

            # Set x-axis limits to match bar width with right subplot
            ax_left_bottom.set_xlim(-0.6, 3.6)
            ax_left_top.set_xlim(-0.6, 3.6)

            ax_left = ax_left_bottom  # for compatibility
        else:
            # Regular side-by-side layout
            fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8.4, 3.15))

            if cap is not None:
                show_vals = [min(v, cap) for v in values]
                bars = ax_left.bar(x, show_vals, color=colors[: len(labels)], width=0.8)
                for i, (val, show) in enumerate(zip(values, show_vals)):
                    if val > cap:
                        bars[i].set_hatch("//")
                        ax_left.text(
                            i, cap * 1.02, _fmt_ms(val), ha="center", va="bottom", fontsize=9
                        )
                    else:
                        ax_left.text(
                            i,
                            show + max(1.0, 0.01 * (cap or show)),
                            _fmt_ms(val),
                            ha="center",
                            va="bottom",
                            fontsize=9,
                        )
                ax_left.set_ylim(0, cap * 1.10)
                _add_break_marker(ax_left, y=0.98)
                ax_left.set_xticks(x)
                ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
            else:
                ax_left.bar(x, values, color=colors[: len(labels)], width=0.8)
                for i, v in enumerate(values):
                    ax_left.text(i, v + 1.0, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
                ax_left.set_xticks(x)
                ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
            ax_left.set_ylabel("Latency (ms per passage)")
            max_initial = latest_rows[0].get("max_initial", "?")
            max_updates = latest_rows[0].get("max_updates", "?")
            ax_left.set_title(
                f"HNSW RNG (run {latest_run}) | init={max_initial}, upd={max_updates}"
            )

        # Right subplot (A vs B, seconds) - paper style
        r_labels = ["Sequential", "Delayed \n Add+Search"]
        r_values = [a_total or 0.0, b_makespan or 0.0]
        r_styles = [
            {"edgecolor": "#59a14f", "hatch": "xxxxx"},
            {"edgecolor": "#edc948", "hatch": "/////"},
        ]
        # 2 bars, centered with proper spacing
        xr = [0, 1]
        bar_width = 0.50  # Reduced for wider spacing between bars
        for i, (v, style) in enumerate(zip(r_values, r_styles)):
            ax_right.bar(
                xr[i],
                v,
                width=bar_width,
                color="white",
                edgecolor=style["edgecolor"],
                hatch=style["hatch"],
                linewidth=1.2,
            )
        for i, v in enumerate(r_values):
            max_v = max(r_values) if r_values else 1.0
            offset = max(0.0002, 0.02 * max_v)
            ax_right.text(
                xr[i],
                v + offset,
                f"{v:.2f}",
                ha="center",
                va="bottom",
                fontsize=8,
                fontweight="bold",
            )
        ax_right.set_xticks(xr)
        ax_right.set_xticklabels(r_labels, rotation=0, fontsize=7)
        # Don't set ylabel here - will use fig.text for alignment
        ax_right.tick_params(axis="y", labelsize=10)
        # Add subtle grid for better readability
        ax_right.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
        ax_right.set_title("Batched Add Operation", fontsize=11, pad=10, fontweight="bold")

        # Set x-axis limits to match left subplot's bar width visually
        # Accounting for width_ratios=[1.5, 1]:
        # Left: 4 bars, xlim(-0.6, 3.6), range=4.2, physical_width=1.5*unit
        # bar_width_visual = 0.72 * (1.5*unit / 4.2)
        # Right: 2 bars, need same visual width
        # 0.72 * (1.0*unit / range_right) = 0.72 * (1.5*unit / 4.2)
        # range_right = 4.2 / 1.5 = 2.8
        # For bars at 0, 1: padding = (2.8 - 1) / 2 = 0.9
        ax_right.set_xlim(-0.9, 1.9)

        # Set y-axis limit with headroom for text labels
        if r_values:
            max_v = max(r_values)
            ax_right.set_ylim(0, max_v * 1.15)

        # Format y-axis to avoid scientific notation
        ax_right.ticklabel_format(style="plain", axis="y")

        plt.tight_layout()

        # Add aligned ylabels using fig.text (after tight_layout)
        # Get the vertical center of the entire figure
        fig_center_y = 0.5
        # Left ylabel - closer to left plot
        left_x = 0.05
        fig.text(
            left_x,
            fig_center_y,
            "Latency (s)",
            va="center",
            rotation="vertical",
            fontsize=11,
            fontweight="bold",
        )
        # Right ylabel - closer to right plot
        right_bbox = ax_right.get_position()
        right_x = right_bbox.x0 - 0.07
        fig.text(
            right_x,
            fig_center_y,
            "Latency (s)",
            va="center",
            rotation="vertical",
            fontsize=11,
            fontweight="bold",
        )

        plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
        # Also save PDF for paper
        pdf_out = args.out.with_suffix(".pdf")
        plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
        print(f"Saved: {args.out}")
        print(f"Saved: {pdf_out}")
        return

    # Broken-Y mode
    if args.broken_y:
        import matplotlib.pyplot as plt

        fig, (ax_top, ax_bottom) = plt.subplots(
            2,
            1,
            sharex=True,
            figsize=(7.5, 6.75),
            gridspec_kw={"height_ratios": [1, 3], "hspace": 0.08},
        )

        # Determine default breaks from second-highest
        s = sorted(values, reverse=True)
        second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
        lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
        upper_start = (
            args.upper_start_y
            if args.upper_start_y is not None
            else max(second * 1.2, lower_cap * 1.02)
        )
        ymax = max(values) * 1.10 if values else 1.0

        x = list(range(len(labels)))
        ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
        ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)

        # Limits
        ax_bottom.set_ylim(0, lower_cap)
        ax_top.set_ylim(upper_start, ymax)

        # Annotate values
        for i, v in enumerate(values):
            if v <= lower_cap:
                ax_bottom.text(
                    i, v + lower_cap * 0.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9
                )
            else:
                ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)

        # Hide spines between axes and draw diagonal break marks
        ax_top.spines["bottom"].set_visible(False)
        ax_bottom.spines["top"].set_visible(False)
        ax_top.tick_params(labeltop=False)  # don't put tick labels at the top
        ax_bottom.xaxis.tick_bottom()

        # Diagonal lines at the break (matching paper_fig.py style)
        d = 0.015
        kwargs = {
            "transform": ax_top.transAxes,
            "color": "k",
            "clip_on": False,
            "linewidth": 0.8,
            "zorder": 10,
        }
        ax_top.plot((-d, +d), (-d, +d), **kwargs)  # top-left diagonal
        ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)  # top-right diagonal
        kwargs.update({"transform": ax_bottom.transAxes})
        ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)  # bottom-left diagonal
        ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)  # bottom-right diagonal

        ax_bottom.set_xticks(x)
        ax_bottom.set_xticklabels(labels, rotation=0, fontsize=10)
        ax = ax_bottom  # for labeling below
    else:
        cap = args.cap_y
        if cap is None and not args.no_auto_cap:
            cap = _auto_cap(values)

        plt.figure(figsize=(5.4, 3.15))
        ax = plt.gca()

        if cap is not None:
            show_vals = [min(v, cap) for v in values]
            bars = []
            for i, (_label, val, show) in enumerate(zip(labels, values, show_vals)):
                bar = ax.bar(i, show, color=colors[i], width=0.8)
                bars.append(bar[0])
                # Hatch and annotate when capped
                if val > cap:
                    bars[-1].set_hatch("//")
                    ax.text(i, cap * 1.02, f"{_fmt_ms(val)}", ha="center", va="bottom", fontsize=9)
                else:
                    ax.text(
                        i,
                        show + max(1.0, 0.01 * (cap or show)),
                        f"{_fmt_ms(val)}",
                        ha="center",
                        va="bottom",
                        fontsize=9,
                    )
            ax.set_ylim(0, cap * 1.10)
            _add_break_marker(ax, y=0.98)
            ax.legend([bars[1]], ["capped"], fontsize=8, frameon=False, loc="upper right") if any(
                v > cap for v in values
            ) else None
            ax.set_xticks(range(len(labels)))
            ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
        else:
            ax.bar(labels, values, color=colors[: len(labels)])
            for idx, val in enumerate(values):
                ax.text(
                    idx,
                    val + 1.0,
                    f"{_fmt_ms(val)}",
                    ha="center",
                    va="bottom",
                    fontsize=10,
                    fontweight="bold",
                )
            ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
    # Try to extract some context for title
    max_initial = latest_rows[0].get("max_initial", "?")
    max_updates = latest_rows[0].get("max_updates", "?")

    if args.broken_y:
        fig.text(
            0.02,
            0.5,
            "Latency (s)",
            va="center",
            rotation="vertical",
            fontsize=11,
            fontweight="bold",
        )
        fig.suptitle(
            "Add Operation Latency",
            fontsize=11,
            y=0.98,
            fontweight="bold",
        )
        plt.tight_layout(rect=(0.03, 0.04, 1, 0.96))
    else:
        plt.ylabel("Latency (s)", fontsize=11, fontweight="bold")
        plt.title("Add Operation Latency", fontsize=11, fontweight="bold")
        plt.tight_layout()

    plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
    # Also save PDF for paper
    pdf_out = args.out.with_suffix(".pdf")
    plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
    print(f"Saved: {args.out}")
    print(f"Saved: {pdf_out}")


if __name__ == "__main__":
    main()
200 docs/COLQWEN_GUIDE.md Normal file
@@ -0,0 +1,200 @@
# ColQwen Integration Guide

Easy-to-use multimodal PDF retrieval with ColQwen2/ColPali models.

## Quick Start

> **🍎 Mac Users**: ColQwen is optimized for Apple Silicon with MPS acceleration for faster inference!

### 1. Install Dependencies
```bash
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
brew install poppler  # macOS only, for PDF processing
```

### 2. Basic Usage
```bash
# Build index from PDFs
python -m apps.colqwen_rag build --pdfs ./my_papers/ --index research_papers

# Search with text queries
python -m apps.colqwen_rag search research_papers "How does attention mechanism work?"

# Interactive Q&A
python -m apps.colqwen_rag ask research_papers --interactive
```

## Commands

### Build Index
```bash
python -m apps.colqwen_rag build \
  --pdfs ./pdf_directory/ \
  --index my_index \
  --model colqwen2 \
  --pages-dir ./page_images/  # Optional: save page images
```

**Options:**
- `--pdfs`: Directory containing PDF files (or single PDF path)
- `--index`: Name for the index (required)
- `--model`: `colqwen2` (default) or `colpali`
- `--pages-dir`: Directory to save page images (optional)

### Search Index
```bash
python -m apps.colqwen_rag search my_index "your question here" --top-k 5
```

**Options:**
- `--top-k`: Number of results to return (default: 5)
- `--model`: Model used for search (should match build model)

### Interactive Q&A
```bash
python -m apps.colqwen_rag ask my_index --interactive
```

**Commands in interactive mode:**
- Type your questions naturally
- `help`: Show available commands
- `quit`/`exit`/`q`: Exit interactive mode

## 🧪 Test & Reproduce Results

Run the reproduction test for issue #119:
```bash
python test_colqwen_reproduction.py
```

This will:
1. ✅ Check dependencies
2. 📥 Download sample PDF (Attention Is All You Need paper)
3. 🏗️ Build test index
4. 🔍 Run sample queries
5. 📊 Show how to generate similarity maps

## 🎨 Advanced: Similarity Maps

For visual similarity analysis, use the existing advanced script:
```bash
cd apps/multimodal/vision-based-pdf-multi-vector/
python multi-vector-leann-similarity-map.py
```

Edit the script to customize:
- `QUERY`: Your question
- `MODEL`: "colqwen2" or "colpali"
- `USE_HF_DATASET`: Use HuggingFace dataset or local PDFs
- `SIMILARITY_MAP`: Generate heatmaps
- `ANSWER`: Enable Qwen-VL answer generation

## 🔧 How It Works

### ColQwen2 vs ColPali
- **ColQwen2** (`vidore/colqwen2-v1.0`): Latest vision-language model
- **ColPali** (`vidore/colpali-v1.2`): Proven multimodal retriever

### Architecture
1. **PDF → Images**: Convert PDF pages to images (150 DPI)
2. **Vision Encoding**: Process images with ColQwen2/ColPali
3. **Multi-Vector Index**: Build LEANN HNSW index with multiple embeddings per page
4. **Query Processing**: Encode text queries with same model
5. **Similarity Search**: Find most relevant pages/regions (see the scoring sketch below)
6. **Visual Maps**: Generate attention heatmaps (optional)
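
ColQwen2/ColPali-style retrievers score pages with ColBERT-style late interaction (MaxSim): each query token embedding is compared against every page-patch embedding, the best match per query token is kept, and those maxima are summed. The sketch below illustrates only that scoring idea — it is not LEANN's exact implementation, and the tensor names are illustrative.

```python
import torch

def maxsim_score(query_emb: torch.Tensor, page_emb: torch.Tensor) -> float:
    """Late-interaction (MaxSim) score for one query against one page.

    query_emb: (num_query_tokens, dim); page_emb: (num_page_patches, dim).
    Both are assumed L2-normalized, so the dot product is cosine similarity.
    """
    sim = query_emb @ page_emb.T                # token-to-patch similarity matrix
    return sim.max(dim=1).values.sum().item()   # best patch per query token, summed

# Ranking: score every page and sort descending (page_embs is a list of per-page matrices).
# ranked = sorted(range(len(page_embs)), key=lambda i: -maxsim_score(q, page_embs[i]))
```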
### Device Support
- **CUDA**: Best performance with GPU acceleration
- **MPS**: Apple Silicon Mac support
- **CPU**: Fallback for any system (slower)

Auto-detection: CUDA > MPS > CPU
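
The preference order above maps to a simple PyTorch check. This is a minimal sketch of that selection logic, not the exact code used by the app:

```python
import torch

def pick_device() -> str:
    """Pick the best available device in the CUDA > MPS > CPU order described above."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```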
## 📊 Performance Tips

### For Best Performance:
```bash
# Use ColQwen2 for latest features
--model colqwen2

# Save page images for reuse
--pages-dir ./cached_pages/

# Adjust batch size based on GPU memory
# (automatically handled)
```

### For Large Document Sets:
- Process PDFs in batches
- Use SSD storage for index files
- Consider using CUDA if available

## 🔗 Related Resources

- **Fast-PLAID**: https://github.com/lightonai/fast-plaid
- **Pylate**: https://github.com/lightonai/pylate
- **ColBERT**: https://github.com/stanford-futuredata/ColBERT
- **ColPali Paper**: Vision-Language Models for Document Retrieval
- **Issue #119**: https://github.com/yichuan-w/LEANN/issues/119

## 🐛 Troubleshooting

### PDF Conversion Issues (macOS)
```bash
# Install poppler
brew install poppler
which pdfinfo && pdfinfo -v
```

### Memory Issues
- Reduce batch size (automatically handled)
- Use CPU instead of GPU: `export CUDA_VISIBLE_DEVICES=""`
- Process fewer PDFs at once

### Model Download Issues
- Ensure internet connection for first run
- Models are cached after first download
- Use HuggingFace mirrors if needed

### Import Errors
```bash
# Ensure all dependencies installed
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn

# Check PyTorch installation
python -c "import torch; print(torch.__version__)"
```

## 💡 Examples

### Research Paper Analysis
```bash
# Index your research papers
python -m apps.colqwen_rag build --pdfs ~/Papers/AI/ --index ai_papers

# Ask research questions
python -m apps.colqwen_rag search ai_papers "What are the limitations of transformer models?"
python -m apps.colqwen_rag search ai_papers "How does BERT compare to GPT?"
```

### Document Q&A
```bash
# Index business documents
python -m apps.colqwen_rag build --pdfs ~/Documents/Reports/ --index reports

# Interactive analysis
python -m apps.colqwen_rag ask reports --interactive
```

### Visual Analysis
```bash
# Generate similarity maps for specific queries
cd apps/multimodal/vision-based-pdf-multi-vector/
# Edit multi-vector-leann-similarity-map.py with your query
python multi-vector-leann-similarity-map.py
# Check ./figures/ for generated heatmaps
```

---

**🎯 This integration makes ColQwen as easy to use as other LEANN features while maintaining the full power of multimodal document understanding!**
@@ -158,6 +158,95 @@ builder.build_index("./indexes/my-notes", chunks)

`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).

## Optional Embedding Features

### Task-Specific Prompt Templates

Some embedding models are trained with task-specific prompts to differentiate between documents and queries. The most notable example is **Google's EmbeddingGemma**, which requires different prompts depending on the use case:

- **Indexing documents**: `"title: none | text: "`
- **Search queries**: `"task: search result | query: "`

LEANN supports automatic prompt prepending via the `--embedding-prompt-template` flag:

```bash
# Build index with EmbeddingGemma (via LM Studio or Ollama)
leann build my-docs \
  --docs ./documents \
  --embedding-mode openai \
  --embedding-model text-embedding-embeddinggemma-300m-qat \
  --embedding-api-base http://localhost:1234/v1 \
  --embedding-prompt-template "title: none | text: " \
  --force

# Search with query-specific prompt
leann search my-docs \
  --query "What is quantum computing?" \
  --embedding-prompt-template "task: search result | query: "
```

**Important Notes:**
- **Only use with compatible models**: EmbeddingGemma and similar task-specific models
- **NOT for regular models**: Adding prompts to models like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` will corrupt embeddings
- **Template is saved**: Build-time templates are saved to `.meta.json` for reference
- **Flexible prompts**: You can use any prompt string, or leave it empty (`""`)

**Python API:**
```python
from leann.api import LeannBuilder

builder = LeannBuilder(
    embedding_mode="openai",
    embedding_model="text-embedding-embeddinggemma-300m-qat",
    embedding_options={
        "base_url": "http://localhost:1234/v1",
        "api_key": "lm-studio",
        "prompt_template": "title: none | text: ",
    },
)
builder.build_index("./indexes/my-docs", chunks)
```
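
At search time the template can also be overridden from Python. The snippet below is a minimal sketch that assumes the `provider_options` parameter this PR adds to `LeannSearcher.search`; the index path is illustrative:

```python
from leann.api import LeannSearcher

searcher = LeannSearcher(index_path="./indexes/my-docs")
# Override the stored template for this query only; without provider_options,
# the template saved in the index meta.json is reused automatically.
results = searcher.search(
    "What is quantum computing?",
    provider_options={"prompt_template": "task: search result | query: "},
)
for r in results:
    print(r.score, r.text[:60])
```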
**References:**
- [HuggingFace Blog: EmbeddingGemma](https://huggingface.co/blog/embeddinggemma) - Technical details

### LM Studio Auto-Detection (Optional)

When using LM Studio with the OpenAI-compatible API, LEANN can optionally auto-detect model context lengths via the LM Studio SDK. This eliminates manual configuration for token limits.

**Prerequisites:**
```bash
# Install Node.js (if not already installed)
# Then install the LM Studio SDK globally
npm install -g @lmstudio/sdk
```

**How it works:**
1. LEANN detects LM Studio URLs (`:1234`, `lmstudio` in URL)
2. Queries model metadata via Node.js subprocess
3. Automatically unloads model after query (respects your JIT auto-evict settings)
4. Falls back to static registry if SDK unavailable
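
Conceptually, the four steps above reduce to a URL heuristic plus a guarded fallback. The sketch below is illustrative only: `query_context_length_via_sdk` stands in for LEANN's actual Node.js subprocess call and is not a real LEANN function.

```python
def query_context_length_via_sdk(base_url: str) -> int:
    """Placeholder for the Node.js/@lmstudio/sdk metadata query (step 2)."""
    raise NotImplementedError

def resolve_context_length(base_url: str, registry_default: int = 2048) -> int:
    # Step 1: treat ":1234" or "lmstudio" in the URL as an LM Studio endpoint.
    if ":1234" not in base_url and "lmstudio" not in base_url.lower():
        return registry_default
    try:
        return query_context_length_via_sdk(base_url)  # steps 2-3
    except Exception:
        return registry_default  # step 4: fall back to the static registry
```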
**No configuration needed** - it works automatically when SDK is installed:

```bash
leann build my-docs \
  --docs ./documents \
  --embedding-mode openai \
  --embedding-model text-embedding-nomic-embed-text-v1.5 \
  --embedding-api-base http://localhost:1234/v1
# Context length auto-detected if SDK available
# Falls back to registry (2048) if not
```

**Benefits:**
- ✅ Automatic token limit detection
- ✅ Respects LM Studio JIT auto-evict settings
- ✅ No manual registry maintenance
- ✅ Graceful fallback if SDK unavailable

**Note:** This is completely optional. LEANN works perfectly fine without the SDK using the built-in token limit registry.

## Index Selection: Matching Your Scale

### HNSW (Hierarchical Navigable Small World)

@@ -365,7 +454,7 @@ leann search my-index "your query" \

### 2) Run remote builds with SkyPilot (cloud GPU)

-Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
+Offload embedding generation and index building to a GPU VM using [SkyPilot](https://docs.skypilot.co/en/latest/docs/index.html). A template is provided at `sky/leann-build.yaml`.

```bash
# One-time: install and configure SkyPilot
48 docs/faq.md
@@ -8,3 +8,51 @@ You can speed up the process by using a lightweight embedding model. Add this to
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)

## 2. When should I use prompt templates?

**Use prompt templates ONLY with task-specific embedding models** like Google's EmbeddingGemma. These models are specially trained to use different prompts for documents vs queries.

**DO NOT use with regular models** like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` - adding prompts to these models will corrupt the embeddings.

**Example usage with EmbeddingGemma:**
```bash
# Build with document prompt
leann build my-docs --embedding-prompt-template "title: none | text: "

# Search with query prompt
leann search my-docs --query "your question" --embedding-prompt-template "task: search result | query: "
```

See the [Configuration Guide: Task-Specific Prompt Templates](configuration-guide.md#task-specific-prompt-templates) for detailed usage.

## 3. Why is LM Studio loading multiple copies of my model?

This was fixed in recent versions. LEANN now properly unloads models after querying metadata, respecting your LM Studio JIT auto-evict settings.

**If you still see duplicates:**
- Update to the latest LEANN version
- Restart LM Studio to clear loaded models
- Check that you have JIT auto-evict enabled in LM Studio settings

**How it works now:**
1. LEANN loads model temporarily to get context length
2. Immediately unloads after query
3. LM Studio JIT loads model on-demand for actual embeddings
4. Auto-evicts per your settings

## 4. Do I need Node.js and @lmstudio/sdk?

**No, it's completely optional.** LEANN works perfectly fine without them using a built-in token limit registry.

**Benefits if you install it:**
- Automatic context length detection for LM Studio models
- No manual registry maintenance
- Always gets accurate token limits from the model itself

**To install (optional):**
```bash
npm install -g @lmstudio/sdk
```

See [Configuration Guide: LM Studio Auto-Detection](configuration-guide.md#lm-studio-auto-detection-optional) for details.
@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"

[project]
name = "leann-backend-diskann"
-version = "0.3.4"
+version = "0.3.5"
-dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
+dependencies = ["leann-core==0.3.5", "numpy", "protobuf>=3.19.0"]

[tool.scikit-build]
# Key: simplified CMake path
@@ -215,6 +215,8 @@ class HNSWSearcher(BaseSearcher):
        if recompute_embeddings:
            if zmq_port is None:
                raise ValueError("zmq_port must be provided if recompute_embeddings is True")
+            if hasattr(self._index, "set_zmq_port"):
+                self._index.set_zmq_port(zmq_port)

        if query.dtype != np.float32:
            query = query.astype(np.float32)
@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"

[project]
name = "leann-backend-hnsw"
-version = "0.3.4"
+version = "0.3.5"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = [
-    "leann-core==0.3.4",
+    "leann-core==0.3.5",
    "numpy",
    "pyzmq>=23.0.0",
    "msgpack>=1.0.0",

Submodule packages/leann-backend-hnsw/third_party/faiss updated: c69511a99c...e2d243c40d
@@ -4,10 +4,10 @@ build-backend = "setuptools.build_meta"

[project]
name = "leann-core"
-version = "0.3.4"
+version = "0.3.5"
description = "Core API and plugin system for LEANN"
readme = "README.md"
-requires-python = ">=3.9"
+requires-python = ">=3.10"
license = { text = "MIT" }

# All required dependencies included
@@ -820,10 +820,10 @@ class LeannBuilder:
                actual_port,
                requested_zmq_port,
            )
-            try:
-                index.hnsw.zmq_port = actual_port
-            except AttributeError:
-                pass
+            if hasattr(index.hnsw, "set_zmq_port"):
+                index.hnsw.set_zmq_port(actual_port)
+            elif hasattr(index, "set_zmq_port"):
+                index.set_zmq_port(actual_port)

            if needs_recompute:
                for i in range(embeddings.shape[0]):
@@ -916,6 +916,7 @@ class LeannSearcher:
        metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
        batch_size: int = 0,
        use_grep: bool = False,
+        provider_options: Optional[dict[str, Any]] = None,
        **kwargs,
    ) -> list[SearchResult]:
        """
@@ -979,10 +980,24 @@ class LeannSearcher:
        start_time = time.time()

+        # Extract query template from stored embedding_options with fallback chain:
+        # 1. Check provider_options override (highest priority)
+        # 2. Check query_prompt_template (new format)
+        # 3. Check prompt_template (old format for backward compat)
+        # 4. None (no template)
+        query_template = None
+        if provider_options and "prompt_template" in provider_options:
+            query_template = provider_options["prompt_template"]
+        elif "query_prompt_template" in self.embedding_options:
+            query_template = self.embedding_options["query_prompt_template"]
+        elif "prompt_template" in self.embedding_options:
+            query_template = self.embedding_options["prompt_template"]
+
        query_embedding = self.backend_impl.compute_query_embedding(
            query,
            use_server_if_available=recompute_embeddings,
            zmq_port=zmq_port,
+            query_template=query_template,
        )
        logger.info(f" Generated embedding shape: {query_embedding.shape}")
        embedding_time = time.time() - start_time
@@ -1236,15 +1251,15 @@ class LeannChat:
            "Please provide the best answer you can based on this context and your knowledge."
        )

-        print("The context provided to the LLM is:")
-        print(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
-        print("-" * 150)
+        logger.info("The context provided to the LLM is:")
+        logger.info(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
+        logger.info("-" * 150)
        for r in results:
            chunk_relevance = f"{r.score:.3f}"
            chunk_id = r.id
            chunk_content = r.text[:60]
            chunk_source = r.metadata.get("source", "")[:80]
-            print(
+            logger.info(
                f"{chunk_relevance:<10} | {chunk_id:<10} | {chunk_content:<60} | {chunk_source:<80}"
            )
        ask_time = time.time()
@@ -12,7 +12,13 @@ from typing import Any, Optional

import torch

-from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
+from .settings import (
+    resolve_anthropic_api_key,
+    resolve_anthropic_base_url,
+    resolve_ollama_host,
+    resolve_openai_api_key,
+    resolve_openai_base_url,
+)

# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -845,6 +851,81 @@ class OpenAIChat(LLMInterface):
            return f"Error: Could not get a response from OpenAI. Details: {e}"


+class AnthropicChat(LLMInterface):
+    """LLM interface for Anthropic Claude models."""
+
+    def __init__(
+        self,
+        model: str = "claude-haiku-4-5",
+        api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
+    ):
+        self.model = model
+        self.base_url = resolve_anthropic_base_url(base_url)
+        self.api_key = resolve_anthropic_api_key(api_key)
+
+        if not self.api_key:
+            raise ValueError(
+                "Anthropic API key is required. Set ANTHROPIC_API_KEY environment variable or pass api_key parameter."
+            )
+
+        logger.info(
+            "Initializing Anthropic Chat with model='%s' and base_url='%s'",
+            model,
+            self.base_url,
+        )
+
+        try:
+            import anthropic
+
+            # Allow custom Anthropic-compatible endpoints via base_url
+            self.client = anthropic.Anthropic(
+                api_key=self.api_key,
+                base_url=self.base_url,
+            )
+        except ImportError:
+            raise ImportError(
+                "The 'anthropic' library is required for Anthropic models. Please install it with 'pip install anthropic'."
+            )
+
+    def ask(self, prompt: str, **kwargs) -> str:
+        logger.info(f"Sending request to Anthropic with model {self.model}")
+
+        try:
+            # Anthropic API parameters
+            params = {
+                "model": self.model,
+                "max_tokens": kwargs.get("max_tokens", 1000),
+                "messages": [{"role": "user", "content": prompt}],
+            }
+
+            # Add optional parameters
+            if "temperature" in kwargs:
+                params["temperature"] = kwargs["temperature"]
+            if "top_p" in kwargs:
+                params["top_p"] = kwargs["top_p"]
+
+            response = self.client.messages.create(**params)
+
+            # Extract text from response
+            response_text = response.content[0].text
+
+            # Log token usage
+            print(
+                f"Total tokens = {response.usage.input_tokens + response.usage.output_tokens}, "
+                f"input tokens = {response.usage.input_tokens}, "
+                f"output tokens = {response.usage.output_tokens}"
+            )
+
+            if response.stop_reason == "max_tokens":
+                print("The query is exceeding the maximum allowed number of tokens")
+
+            return response_text.strip()
+        except Exception as e:
+            logger.error(f"Error communicating with Anthropic: {e}")
+            return f"Error: Could not get a response from Anthropic. Details: {e}"
+
+
class SimulatedChat(LLMInterface):
    """A simple simulated chat for testing and development."""

@@ -897,6 +978,12 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
        )
    elif llm_type == "gemini":
        return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
+    elif llm_type == "anthropic":
+        return AnthropicChat(
+            model=model or "claude-3-5-sonnet-20241022",
+            api_key=llm_config.get("api_key"),
+            base_url=llm_config.get("base_url"),
+        )
    elif llm_type == "simulated":
        return SimulatedChat()
    else:
@@ -5,12 +5,15 @@ Packaged within leann-core so installed wheels can import it reliably.

import logging
from pathlib import Path
-from typing import Optional
+from typing import Any, Optional

from llama_index.core.node_parser import SentenceSplitter

logger = logging.getLogger(__name__)

+# Flag to ensure AST token warning only shown once per session
+_ast_token_warning_shown = False
+

def estimate_token_count(text: str) -> int:
    """
@@ -174,37 +177,44 @@ def create_ast_chunks(
    max_chunk_size: int = 512,
    chunk_overlap: int = 64,
    metadata_template: str = "default",
-) -> list[str]:
+) -> list[dict[str, Any]]:
    """Create AST-aware chunks from code documents using astchunk.

    Falls back to traditional chunking if astchunk is unavailable.
+
+    Returns:
+        List of dicts with {"text": str, "metadata": dict}
    """
    try:
        from astchunk import ASTChunkBuilder  # optional dependency
    except ImportError as e:
        logger.error(f"astchunk not available: {e}")
        logger.info("Falling back to traditional chunking for code files")
-        return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
+        return _traditional_chunks_as_dicts(documents, max_chunk_size, chunk_overlap)

    all_chunks = []
    for doc in documents:
        language = doc.metadata.get("language")
        if not language:
            logger.warning("No language detected; falling back to traditional chunking")
-            all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
+            all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
            continue

        try:
-            # Warn if AST chunk size + overlap might exceed common token limits
+            # Warn once if AST chunk size + overlap might exceed common token limits
+            # Note: Actual truncation happens at embedding time with dynamic model limits
+            global _ast_token_warning_shown
            estimated_max_tokens = int(
                (max_chunk_size + chunk_overlap) * 1.2
            )  # Conservative estimate
-            if estimated_max_tokens > 512:
+            if estimated_max_tokens > 512 and not _ast_token_warning_shown:
                logger.warning(
                    f"AST chunk size ({max_chunk_size}) + overlap ({chunk_overlap}) = {max_chunk_size + chunk_overlap} chars "
                    f"may exceed 512 token limit (~{estimated_max_tokens} tokens estimated). "
-                    f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}"
+                    f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}. "
+                    f"Note: Chunks will be auto-truncated at embedding time based on your model's actual token limit."
                )
+                _ast_token_warning_shown = True

            configs = {
                "max_chunk_size": max_chunk_size,
@@ -229,17 +239,40 @@ def create_ast_chunks(

            chunks = chunk_builder.chunkify(code_content)
            for chunk in chunks:
+                chunk_text = None
+                astchunk_metadata = {}
+
                if hasattr(chunk, "text"):
                    chunk_text = chunk.text
-                elif isinstance(chunk, dict) and "text" in chunk:
-                    chunk_text = chunk["text"]
                elif isinstance(chunk, str):
                    chunk_text = chunk
+                elif isinstance(chunk, dict):
+                    # Handle astchunk format: {"content": "...", "metadata": {...}}
+                    if "content" in chunk:
+                        chunk_text = chunk["content"]
+                        astchunk_metadata = chunk.get("metadata", {})
+                    elif "text" in chunk:
+                        chunk_text = chunk["text"]
+                    else:
+                        chunk_text = str(chunk)  # Last resort
                else:
                    chunk_text = str(chunk)

                if chunk_text and chunk_text.strip():
-                    all_chunks.append(chunk_text.strip())
+                    # Extract document-level metadata
+                    doc_metadata = {
+                        "file_path": doc.metadata.get("file_path", ""),
+                        "file_name": doc.metadata.get("file_name", ""),
+                    }
+                    if "creation_date" in doc.metadata:
+                        doc_metadata["creation_date"] = doc.metadata["creation_date"]
+                    if "last_modified_date" in doc.metadata:
+                        doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
+
+                    # Merge document metadata + astchunk metadata
+                    combined_metadata = {**doc_metadata, **astchunk_metadata}
+
+                    all_chunks.append({"text": chunk_text.strip(), "metadata": combined_metadata})

            logger.info(
                f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
@@ -247,15 +280,19 @@ def create_ast_chunks(
        except Exception as e:
            logger.warning(f"AST chunking failed for {language} file: {e}")
            logger.info("Falling back to traditional chunking")
-            all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
+            all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))

    return all_chunks


def create_traditional_chunks(
    documents, chunk_size: int = 256, chunk_overlap: int = 128
-) -> list[str]:
+) -> list[dict[str, Any]]:
-    """Create traditional text chunks using LlamaIndex SentenceSplitter."""
+    """Create traditional text chunks using LlamaIndex SentenceSplitter.
+
+    Returns:
+        List of dicts with {"text": str, "metadata": dict}
+    """
    if chunk_size <= 0:
        logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
        chunk_size = 256
@@ -271,19 +308,40 @@ def create_traditional_chunks(
        paragraph_separator="\n\n",
    )

-    all_texts = []
+    result = []
    for doc in documents:
+        # Extract document-level metadata
+        doc_metadata = {
+            "file_path": doc.metadata.get("file_path", ""),
+            "file_name": doc.metadata.get("file_name", ""),
+        }
+        if "creation_date" in doc.metadata:
+            doc_metadata["creation_date"] = doc.metadata["creation_date"]
+        if "last_modified_date" in doc.metadata:
+            doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
+
        try:
            nodes = node_parser.get_nodes_from_documents([doc])
            if nodes:
-                all_texts.extend(node.get_content() for node in nodes)
+                for node in nodes:
+                    result.append({"text": node.get_content(), "metadata": doc_metadata})
        except Exception as e:
            logger.error(f"Traditional chunking failed for document: {e}")
            content = doc.get_content()
            if content and content.strip():
-                all_texts.append(content.strip())
+                result.append({"text": content.strip(), "metadata": doc_metadata})

-    return all_texts
+    return result
+
+
+def _traditional_chunks_as_dicts(
+    documents, chunk_size: int = 256, chunk_overlap: int = 128
+) -> list[dict[str, Any]]:
+    """Helper: Traditional chunking that returns dict format for consistency.
+
+    This is now just an alias for create_traditional_chunks for backwards compatibility.
+    """
+    return create_traditional_chunks(documents, chunk_size, chunk_overlap)


def create_text_chunks(
@@ -295,8 +353,12 @@ def create_text_chunks(
    ast_chunk_overlap: int = 64,
    code_file_extensions: Optional[list[str]] = None,
    ast_fallback_traditional: bool = True,
-) -> list[str]:
+) -> list[dict[str, Any]]:
-    """Create text chunks from documents with optional AST support for code files."""
+    """Create text chunks from documents with optional AST support for code files.
+
+    Returns:
+        List of dicts with {"text": str, "metadata": dict}
+    """
    if not documents:
        logger.warning("No documents provided for chunking")
        return []
@@ -331,24 +393,17 @@ def create_text_chunks(
            logger.error(f"AST chunking failed: {e}")
            if ast_fallback_traditional:
                all_chunks.extend(
-                    create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
+                    _traditional_chunks_as_dicts(code_docs, chunk_size, chunk_overlap)
                )
            else:
                raise
        if text_docs:
-            all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
+            all_chunks.extend(_traditional_chunks_as_dicts(text_docs, chunk_size, chunk_overlap))
    else:
-        all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
+        all_chunks = _traditional_chunks_as_dicts(documents, chunk_size, chunk_overlap)

    logger.info(f"Total chunks created: {len(all_chunks)}")

-    # Validate chunk token limits (default to 512 for safety)
-    # This provides a safety net for embedding models with token limits
-    validated_chunks, num_truncated = validate_chunk_token_limits(all_chunks, max_tokens=512)
-
-    if num_truncated > 0:
-        logger.info(
-            f"Post-chunking validation: {num_truncated} chunks were truncated to fit 512 token limit"
-        )
-
-    return validated_chunks
+    # Note: Token truncation is now handled at embedding time with dynamic model limits
+    # See get_model_token_limit() and truncate_to_token_limit() in embedding_compute.py
+    return all_chunks
@@ -11,7 +11,12 @@ from tqdm import tqdm
from .api import LeannBuilder, LeannChat, LeannSearcher
from .interactive_utils import create_cli_session
from .registry import register_project_directory
-from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
+from .settings import (
+    resolve_anthropic_base_url,
+    resolve_ollama_host,
+    resolve_openai_api_key,
+    resolve_openai_base_url,
+)


def extract_pdf_text_with_pymupdf(file_path: str) -> str:
@@ -144,6 +149,18 @@ Examples:
        default=None,
        help="API key for embedding service (defaults to OPENAI_API_KEY)",
    )
+    build_parser.add_argument(
+        "--embedding-prompt-template",
+        type=str,
+        default=None,
+        help="Prompt template to prepend to all texts for embedding (e.g., 'query: ' for search)",
+    )
+    build_parser.add_argument(
+        "--query-prompt-template",
+        type=str,
+        default=None,
+        help="Prompt template for queries (different from build template for task-specific models)",
+    )
    build_parser.add_argument(
        "--force", "-f", action="store_true", help="Force rebuild existing index"
    )
@@ -260,6 +277,12 @@ Examples:
        action="store_true",
        help="Display file paths and metadata in search results",
    )
+    search_parser.add_argument(
+        "--embedding-prompt-template",
+        type=str,
+        default=None,
+        help="Prompt template to prepend to query for embedding (e.g., 'query: ' for search)",
+    )

    # Ask command
    ask_parser = subparsers.add_parser("ask", help="Ask questions")
@@ -273,7 +296,7 @@ Examples:
        "--llm",
        type=str,
        default="ollama",
-        choices=["simulated", "ollama", "hf", "openai"],
+        choices=["simulated", "ollama", "hf", "openai", "anthropic"],
        help="LLM provider (default: ollama)",
    )
    ask_parser.add_argument(
@@ -323,7 +346,7 @@ Examples:
        "--api-key",
        type=str,
        default=None,
-        help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
+        help="API key for cloud LLM providers (OpenAI, Anthropic)",
    )

    # List command
@@ -1162,6 +1185,11 @@ Examples:
                print(f"Warning: Could not process {file_path}: {e}")

        # Load other file types with default reader
+        # Exclude PDFs from code_extensions if they were already processed separately
+        other_file_extensions = code_extensions
+        if should_process_pdfs and ".pdf" in code_extensions:
+            other_file_extensions = [ext for ext in code_extensions if ext != ".pdf"]
+
        try:
            # Create a custom file filter function using our PathSpec
            def file_filter(
@@ -1177,15 +1205,19 @@ Examples:
                except (ValueError, OSError):
                    return True  # Include files that can't be processed

-            other_docs = SimpleDirectoryReader(
-                docs_dir,
-                recursive=True,
-                encoding="utf-8",
-                required_exts=code_extensions,
-                file_extractor={},  # Use default extractors
-                exclude_hidden=not include_hidden,
-                filename_as_id=True,
-            ).load_data(show_progress=True)
+            # Only load other file types if there are extensions to process
+            if other_file_extensions:
+                other_docs = SimpleDirectoryReader(
+                    docs_dir,
+                    recursive=True,
+                    encoding="utf-8",
+                    required_exts=other_file_extensions,
+                    file_extractor={},  # Use default extractors
+                    exclude_hidden=not include_hidden,
+                    filename_as_id=True,
+                ).load_data(show_progress=True)
+            else:
+                other_docs = []

            # Filter documents after loading based on gitignore rules
            filtered_docs = []
@@ -1279,13 +1311,8 @@ Examples:
                ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
            )

-            # Note: AST chunking currently returns plain text chunks without metadata
-            # We preserve basic file info by associating chunks with their source documents
-            # For better metadata preservation, documents list order should be maintained
-            for chunk_text in chunk_texts:
-                # TODO: Enhance create_text_chunks to return metadata alongside text
-                # For now, we store chunks with empty metadata
-                all_texts.append({"text": chunk_text, "metadata": {}})
+            # create_text_chunks now returns list[dict] with metadata preserved
+            all_texts.extend(chunk_texts)

        except ImportError as e:
            print(
@@ -1403,6 +1430,14 @@ Examples:
    resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
    if resolved_embedding_key:
        embedding_options["api_key"] = resolved_embedding_key
+    if args.query_prompt_template:
+        # New format: separate templates
+        if args.embedding_prompt_template:
+            embedding_options["build_prompt_template"] = args.embedding_prompt_template
+        embedding_options["query_prompt_template"] = args.query_prompt_template
+    elif args.embedding_prompt_template:
+        # Old format: single template (backward compat)
+        embedding_options["prompt_template"] = args.embedding_prompt_template

    builder = LeannBuilder(
        backend_name=args.backend_name,
@@ -1524,6 +1559,11 @@ Examples:
            print("Invalid input. Aborting search.")
            return

+    # Build provider_options for runtime override
+    provider_options = {}
+    if args.embedding_prompt_template:
+        provider_options["prompt_template"] = args.embedding_prompt_template
+
    searcher = LeannSearcher(index_path=index_path)
    results = searcher.search(
        query,
@@ -1533,6 +1573,7 @@ Examples:
        prune_ratio=args.prune_ratio,
        recompute_embeddings=args.recompute_embeddings,
|
||||||
pruning_strategy=args.pruning_strategy,
|
pruning_strategy=args.pruning_strategy,
|
||||||
|
provider_options=provider_options if provider_options else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Search results for '{query}' (top {len(results)}):")
|
print(f"Search results for '{query}' (top {len(results)}):")
|
||||||
@@ -1580,6 +1621,12 @@ Examples:
|
|||||||
resolved_api_key = resolve_openai_api_key(args.api_key)
|
resolved_api_key = resolve_openai_api_key(args.api_key)
|
||||||
if resolved_api_key:
|
if resolved_api_key:
|
||||||
llm_config["api_key"] = resolved_api_key
|
llm_config["api_key"] = resolved_api_key
|
||||||
|
elif args.llm == "anthropic":
|
||||||
|
# For Anthropic, pass base_url and API key if provided
|
||||||
|
if args.api_base:
|
||||||
|
llm_config["base_url"] = resolve_anthropic_base_url(args.api_base)
|
||||||
|
if args.api_key:
|
||||||
|
llm_config["api_key"] = args.api_key
|
||||||
|
|
||||||
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
||||||
|
|
||||||
|
|||||||
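The two new build flags interact in a non-obvious way, so here is a minimal sketch (not part of the diff) of how `--embedding-prompt-template` and `--query-prompt-template` are folded into `embedding_options`, mirroring the CLI wiring added above; the helper name and example template strings are illustrative only.

```python
# Illustrative sketch of the template-selection logic added to the CLI above.
# build_embedding_options is a hypothetical helper; the dict keys match the diff.
def build_embedding_options(embedding_prompt_template=None, query_prompt_template=None):
    embedding_options = {}
    if query_prompt_template:
        # New format: separate build/query templates for task-specific models
        if embedding_prompt_template:
            embedding_options["build_prompt_template"] = embedding_prompt_template
        embedding_options["query_prompt_template"] = query_prompt_template
    elif embedding_prompt_template:
        # Old format: a single template, kept for backward compatibility
        embedding_options["prompt_template"] = embedding_prompt_template
    return embedding_options


print(build_embedding_options("passage: ", "query: "))
# {'build_prompt_template': 'passage: ', 'query_prompt_template': 'query: '}
```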
@@ -4,118 +4,310 @@ Consolidates all embedding computation logic using SentenceTransformer
 Preserves all optimization parameters to ensure performance
 """

+import json
 import logging
 import os
+import subprocess
 import time
 from typing import Any, Optional

 import numpy as np
+import tiktoken
 import torch

 from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url


-def truncate_to_token_limit(texts: list[str], max_tokens: int = 512) -> list[str]:
-    """
-    Truncate texts to token limit using tiktoken or conservative character truncation.
-
-    Args:
-        texts: List of texts to truncate
-        max_tokens: Maximum tokens allowed per text
-
-    Returns:
-        List of truncated texts that should fit within token limit
-    """
-    try:
-        import tiktoken
-
-        encoder = tiktoken.get_encoding("cl100k_base")
-        truncated = []
-
-        for text in texts:
-            tokens = encoder.encode(text)
-            if len(tokens) > max_tokens:
-                # Truncate to max_tokens and decode back to text
-                truncated_tokens = tokens[:max_tokens]
-                truncated_text = encoder.decode(truncated_tokens)
-                truncated.append(truncated_text)
-                logger.warning(
-                    f"Truncated text from {len(tokens)} to {max_tokens} tokens "
-                    f"(from {len(text)} to {len(truncated_text)} characters)"
-                )
-            else:
-                truncated.append(text)
-        return truncated
-
-    except ImportError:
-        # Fallback: Conservative character truncation
-        # Assume worst case: 1.5 tokens per character for code content
-        char_limit = int(max_tokens / 1.5)
-        truncated = []
-
-        for text in texts:
-            if len(text) > char_limit:
-                truncated_text = text[:char_limit]
-                truncated.append(truncated_text)
-                logger.warning(
-                    f"Truncated text from {len(text)} to {char_limit} characters "
-                    f"(conservative estimate for {max_tokens} tokens)"
-                )
-            else:
-                truncated.append(text)
-        return truncated
-
-
-def get_model_token_limit(model_name: str) -> int:
-    """
-    Get token limit for a given embedding model.
-
-    Args:
-        model_name: Name of the embedding model
-
-    Returns:
-        Token limit for the model, defaults to 512 if unknown
-    """
-    # Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
-    base_model_name = model_name.split(":")[0]
-
-    # Check exact match first
-    if model_name in EMBEDDING_MODEL_LIMITS:
-        return EMBEDDING_MODEL_LIMITS[model_name]
-
-    # Check base name match
-    if base_model_name in EMBEDDING_MODEL_LIMITS:
-        return EMBEDDING_MODEL_LIMITS[base_model_name]
-
-    # Check partial matches for common patterns
-    for known_model, limit in EMBEDDING_MODEL_LIMITS.items():
-        if known_model in base_model_name or base_model_name in known_model:
-            return limit
-
-    # Default to conservative 512 token limit
-    logger.warning(f"Unknown model '{model_name}', using default 512 token limit")
-    return 512
-
 # Set up logger with proper level
 logger = logging.getLogger(__name__)
 LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
 log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
 logger.setLevel(log_level)

-# Global model cache to avoid repeated loading
-_model_cache: dict[str, Any] = {}
-
-# Known embedding model token limits
+# Token limit registry for embedding models
+# Used as fallback when dynamic discovery fails (e.g., LM Studio, OpenAI)
+# Ollama models use dynamic discovery via /api/show
 EMBEDDING_MODEL_LIMITS = {
-    "nomic-embed-text": 512,
+    # Nomic models (common across servers)
+    "nomic-embed-text": 2048,  # Corrected from 512 - verified via /api/show
+    "nomic-embed-text-v1.5": 2048,
     "nomic-embed-text-v2": 512,
+    # Other embedding models
     "mxbai-embed-large": 512,
     "all-minilm": 512,
     "bge-m3": 8192,
     "snowflake-arctic-embed": 512,
-    # Add more models as needed
+    # OpenAI models
+    "text-embedding-3-small": 8192,
+    "text-embedding-3-large": 8192,
+    "text-embedding-ada-002": 8192,
 }

+# Runtime cache for dynamically discovered token limits
+# Key: (model_name, base_url), Value: token_limit
+# Prevents repeated SDK/API calls for the same model
+_token_limit_cache: dict[tuple[str, str], int] = {}
+
+
+def get_model_token_limit(
+    model_name: str,
+    base_url: Optional[str] = None,
+    default: int = 2048,
+) -> int:
+    """
+    Get token limit for a given embedding model.
+    Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others.
+    Caches discovered limits to prevent repeated API/SDK calls.
+
+    Args:
+        model_name: Name of the embedding model
+        base_url: Base URL of the embedding server (for dynamic discovery)
+        default: Default token limit if model not found
+
+    Returns:
+        Token limit for the model in tokens
+    """
+    # Check cache first to avoid repeated SDK/API calls
+    cache_key = (model_name, base_url or "")
+    if cache_key in _token_limit_cache:
+        cached_limit = _token_limit_cache[cache_key]
+        logger.debug(f"Using cached token limit for {model_name}: {cached_limit}")
+        return cached_limit
+
+    # Try Ollama dynamic discovery if base_url provided
+    if base_url:
+        # Detect Ollama servers by port or "ollama" in URL
+        if "11434" in base_url or "ollama" in base_url.lower():
+            limit = _query_ollama_context_limit(model_name, base_url)
+            if limit:
+                _token_limit_cache[cache_key] = limit
+                return limit
+
+        # Try LM Studio SDK discovery
+        if "1234" in base_url or "lmstudio" in base_url.lower() or "lm.studio" in base_url.lower():
+            # Convert HTTP to WebSocket URL
+            ws_url = base_url.replace("https://", "wss://").replace("http://", "ws://")
+            # Remove /v1 suffix if present
+            if ws_url.endswith("/v1"):
+                ws_url = ws_url[:-3]
+
+            limit = _query_lmstudio_context_limit(model_name, ws_url)
+            if limit:
+                _token_limit_cache[cache_key] = limit
+                return limit
+
+    # Fallback to known model registry with version handling (from PR #154)
+    # Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
+    base_model_name = model_name.split(":")[0]
+
+    # Check exact match first
+    if model_name in EMBEDDING_MODEL_LIMITS:
+        limit = EMBEDDING_MODEL_LIMITS[model_name]
+        _token_limit_cache[cache_key] = limit
+        return limit
+
+    # Check base name match
+    if base_model_name in EMBEDDING_MODEL_LIMITS:
+        limit = EMBEDDING_MODEL_LIMITS[base_model_name]
+        _token_limit_cache[cache_key] = limit
+        return limit
+
+    # Check partial matches for common patterns
+    for known_model, registry_limit in EMBEDDING_MODEL_LIMITS.items():
+        if known_model in base_model_name or base_model_name in known_model:
+            _token_limit_cache[cache_key] = registry_limit
+            return registry_limit
+
+    # Default fallback
+    logger.warning(f"Unknown model '{model_name}', using default {default} token limit")
+    _token_limit_cache[cache_key] = default
+    return default
+
+
+def truncate_to_token_limit(texts: list[str], token_limit: int) -> list[str]:
+    """
+    Truncate texts to fit within token limit using tiktoken.
+
+    Args:
+        texts: List of text strings to truncate
+        token_limit: Maximum number of tokens allowed
+
+    Returns:
+        List of truncated texts (same length as input)
+    """
+    if not texts:
+        return []
+
+    # Use tiktoken with cl100k_base encoding
+    enc = tiktoken.get_encoding("cl100k_base")
+
+    truncated_texts = []
+    truncation_count = 0
+    total_tokens_removed = 0
+    max_original_length = 0
+
+    for i, text in enumerate(texts):
+        tokens = enc.encode(text)
+        original_length = len(tokens)
+
+        if original_length <= token_limit:
+            # Text is within limit, keep as is
+            truncated_texts.append(text)
+        else:
+            # Truncate to token_limit
+            truncated_tokens = tokens[:token_limit]
+            truncated_text = enc.decode(truncated_tokens)
+            truncated_texts.append(truncated_text)
+
+            # Track truncation statistics
+            truncation_count += 1
+            tokens_removed = original_length - token_limit
+            total_tokens_removed += tokens_removed
+            max_original_length = max(max_original_length, original_length)
+
+            # Log individual truncation at WARNING level (first few only)
+            if truncation_count <= 3:
+                logger.warning(
+                    f"Text {i + 1} truncated: {original_length} → {token_limit} tokens "
+                    f"({tokens_removed} tokens removed)"
+                )
+            elif truncation_count == 4:
+                logger.warning("Further truncation warnings suppressed...")
+
+    # Log summary at INFO level
+    if truncation_count > 0:
+        logger.warning(
+            f"Truncation summary: {truncation_count}/{len(texts)} texts truncated "
+            f"(removed {total_tokens_removed} tokens total, longest was {max_original_length} tokens)"
+        )
+    else:
+        logger.debug(
+            f"No truncation needed - all {len(texts)} texts within {token_limit} token limit"
+        )
+
+    return truncated_texts
+
+
+def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int]:
+    """
+    Query Ollama /api/show for model context limit.
+
+    Args:
+        model_name: Name of the Ollama model
+        base_url: Base URL of the Ollama server
+
+    Returns:
+        Context limit in tokens if found, None otherwise
+    """
+    try:
+        import requests
+
+        response = requests.post(
+            f"{base_url}/api/show",
+            json={"name": model_name},
+            timeout=5,
+        )
+        if response.status_code == 200:
+            data = response.json()
+            if "model_info" in data:
+                # Look for *.context_length in model_info
+                for key, value in data["model_info"].items():
+                    if "context_length" in key and isinstance(value, int):
+                        logger.info(f"Detected {model_name} context limit: {value} tokens")
+                        return value
+    except Exception as e:
+        logger.debug(f"Failed to query Ollama context limit: {e}")
+
+    return None
+
+
+def _query_lmstudio_context_limit(model_name: str, base_url: str) -> Optional[int]:
+    """
+    Query LM Studio SDK for model context length via Node.js subprocess.
+
+    Args:
+        model_name: Name of the LM Studio model
+        base_url: Base URL of the LM Studio server (WebSocket format, e.g., "ws://localhost:1234")
+
+    Returns:
+        Context limit in tokens if found, None otherwise
+    """
+    # Inline JavaScript using @lmstudio/sdk
+    # Note: Load model temporarily for metadata, then unload to respect JIT auto-evict
+    js_code = f"""
+    const {{ LMStudioClient }} = require('@lmstudio/sdk');
+    (async () => {{
+        try {{
+            const client = new LMStudioClient({{ baseUrl: '{base_url}' }});
+            const model = await client.embedding.load('{model_name}', {{ verbose: false }});
+            const contextLength = await model.getContextLength();
+            await model.unload(); // Unload immediately to respect JIT auto-evict settings
+            console.log(JSON.stringify({{ contextLength, identifier: '{model_name}' }}));
+        }} catch (error) {{
+            console.error(JSON.stringify({{ error: error.message }}));
+            process.exit(1);
+        }}
+    }})();
+    """
+
+    try:
+        # Set NODE_PATH to include global modules for @lmstudio/sdk resolution
+        env = os.environ.copy()
+
+        # Try to get npm global root (works with nvm, brew node, etc.)
+        try:
+            npm_root = subprocess.run(
+                ["npm", "root", "-g"],
+                capture_output=True,
+                text=True,
+                timeout=5,
+            )
+            if npm_root.returncode == 0:
+                global_modules = npm_root.stdout.strip()
+                # Append to existing NODE_PATH if present
+                existing_node_path = env.get("NODE_PATH", "")
+                env["NODE_PATH"] = (
+                    f"{global_modules}:{existing_node_path}"
+                    if existing_node_path
+                    else global_modules
+                )
+        except Exception:
+            # If npm not available, continue with existing NODE_PATH
+            pass
+
+        result = subprocess.run(
+            ["node", "-e", js_code],
+            capture_output=True,
+            text=True,
+            timeout=10,
+            env=env,
+        )
+
+        if result.returncode != 0:
+            logger.debug(f"LM Studio SDK error: {result.stderr}")
+            return None
+
+        data = json.loads(result.stdout)
+        context_length = data.get("contextLength")
+
+        if context_length and context_length > 0:
+            logger.info(f"LM Studio SDK detected {model_name} context length: {context_length}")
+            return context_length
+
+    except FileNotFoundError:
+        logger.debug("Node.js not found - install Node.js for LM Studio SDK features")
+    except subprocess.TimeoutExpired:
+        logger.debug("LM Studio SDK query timeout")
+    except json.JSONDecodeError:
+        logger.debug("LM Studio SDK returned invalid JSON")
+    except Exception as e:
+        logger.debug(f"LM Studio SDK query failed: {e}")
+
+    return None
+
+
+# Global model cache to avoid repeated loading
+_model_cache: dict[str, Any] = {}
+
+
 def compute_embeddings(
     texts: list[str],
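As a quick orientation, the sketch below shows how the helpers introduced in this hunk are expected to compose at call time; the model names and URLs are examples, not values taken from the diff.

```python
# Sketch only: composing get_model_token_limit and truncate_to_token_limit (added above).
texts = ["first long chunk ...", "second chunk ..."]

# Ollama on its default port: the limit is discovered via /api/show and cached.
limit = get_model_token_limit("nomic-embed-text:latest", base_url="http://localhost:11434")

# No base_url and an unknown model: falls back to the registry, then to the 2048 default.
fallback_limit = get_model_token_limit("my-custom-embedder")

texts = truncate_to_token_limit(texts, limit)  # same length; each item now fits the limit
```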
@@ -160,6 +352,7 @@ def compute_embeddings(
            model_name,
            base_url=provider_options.get("base_url"),
            api_key=provider_options.get("api_key"),
+           provider_options=provider_options,
        )
    elif mode == "mlx":
        return compute_embeddings_mlx(texts, model_name)
@@ -169,6 +362,7 @@ def compute_embeddings(
            model_name,
            is_build=is_build,
            host=provider_options.get("host"),
+           provider_options=provider_options,
        )
    elif mode == "gemini":
        return compute_embeddings_gemini(texts, model_name, is_build=is_build)
@@ -507,6 +701,7 @@ def compute_embeddings_openai(
    model_name: str,
    base_url: Optional[str] = None,
    api_key: Optional[str] = None,
+   provider_options: Optional[dict[str, Any]] = None,
 ) -> np.ndarray:
    # TODO: @yichuan-w add progress bar only in build mode
    """Compute embeddings using OpenAI API"""
@@ -525,26 +720,40 @@ def compute_embeddings_openai(
            f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
        )

-   resolved_base_url = resolve_openai_base_url(base_url)
-   resolved_api_key = resolve_openai_api_key(api_key)
+   # Extract base_url and api_key from provider_options if not provided directly
+   provider_options = provider_options or {}
+   effective_base_url = base_url or provider_options.get("base_url")
+   effective_api_key = api_key or provider_options.get("api_key")
+
+   resolved_base_url = resolve_openai_base_url(effective_base_url)
+   resolved_api_key = resolve_openai_api_key(effective_api_key)

    if not resolved_api_key:
        raise RuntimeError("OPENAI_API_KEY environment variable not set")

-   # Cache OpenAI client
-   cache_key = f"openai_client::{resolved_base_url}"
-   if cache_key in _model_cache:
-       client = _model_cache[cache_key]
-   else:
-       client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
-       _model_cache[cache_key] = client
-       logger.info("OpenAI client cached")
+   # Create OpenAI client
+   client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)

    logger.info(
        f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
    )
    print(f"len of texts: {len(texts)}")

+   # Apply prompt template if provided
+   # Priority: build_prompt_template (new format) > prompt_template (old format)
+   prompt_template = provider_options.get("build_prompt_template") or provider_options.get(
+       "prompt_template"
+   )
+
+   if prompt_template:
+       logger.warning(f"Applying prompt template: '{prompt_template}'")
+       texts = [f"{prompt_template}{text}" for text in texts]
+
+   # Query token limit and apply truncation
+   token_limit = get_model_token_limit(model_name, base_url=effective_base_url)
+   logger.info(f"Using token limit: {token_limit} for model '{model_name}'")
+   texts = truncate_to_token_limit(texts, token_limit)
+
    # OpenAI has limits on batch size and input length
    max_batch_size = 800  # Conservative batch size because the token limit is 300K
    all_embeddings = []
@@ -575,7 +784,15 @@ def compute_embeddings_openai(
        try:
            response = client.embeddings.create(model=model_name, input=batch_texts)
            batch_embeddings = [embedding.embedding for embedding in response.data]
-           all_embeddings.extend(batch_embeddings)
+           # Verify we got the expected number of embeddings
+           if len(batch_embeddings) != len(batch_texts):
+               logger.warning(
+                   f"Expected {len(batch_texts)} embeddings but got {len(batch_embeddings)}"
+               )
+
+           # Only take the number of embeddings that match the batch size
+           all_embeddings.extend(batch_embeddings[: len(batch_texts)])
        except Exception as e:
            logger.error(f"Batch {i} failed: {e}")
            raise
@@ -665,6 +882,7 @@ def compute_embeddings_ollama(
    model_name: str,
    is_build: bool = False,
    host: Optional[str] = None,
+   provider_options: Optional[dict[str, Any]] = None,
 ) -> np.ndarray:
    """
    Compute embeddings using Ollama API with true batch processing.
@@ -677,6 +895,7 @@ def compute_embeddings_ollama(
        model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
        is_build: Whether this is a build operation (shows progress bar)
        host: Ollama host URL (defaults to environment or http://localhost:11434)
+       provider_options: Optional provider-specific options (e.g., prompt_template)

    Returns:
        Normalized embeddings array, shape: (len(texts), embedding_dim)
@@ -813,15 +1032,24 @@ def compute_embeddings_ollama(

    logger.info(f"Using batch size: {batch_size} for true batch processing")

-   # Get model token limit and apply truncation
-   token_limit = get_model_token_limit(model_name)
+   # Apply prompt template if provided
+   provider_options = provider_options or {}
+   # Priority: build_prompt_template (new format) > prompt_template (old format)
+   prompt_template = provider_options.get("build_prompt_template") or provider_options.get(
+       "prompt_template"
+   )
+
+   if prompt_template:
+       logger.warning(f"Applying prompt template: '{prompt_template}'")
+       texts = [f"{prompt_template}{text}" for text in texts]
+
+   # Get model token limit and apply truncation before batching
+   token_limit = get_model_token_limit(model_name, base_url=resolved_host)
    logger.info(f"Model '{model_name}' token limit: {token_limit}")

-   # Apply token-aware truncation to all texts
-   truncated_texts = truncate_to_token_limit(texts, token_limit)
-   if len(truncated_texts) != len(texts):
-       logger.error("Truncation failed - text count mismatch")
-       truncated_texts = texts  # Fallback to original texts
+   # Apply truncation to all texts before batch processing
+   # Function logs truncation details internally
+   texts = truncate_to_token_limit(texts, token_limit)

    def get_batch_embeddings(batch_texts):
        """Get embeddings for a batch of texts using /api/embed endpoint."""
@@ -866,7 +1094,9 @@ def compute_embeddings_ollama(
                if retry_count >= max_retries:
                    # Enhanced error detection for token limit violations
                    error_msg = str(e).lower()
-                   if "token" in error_msg and ("limit" in error_msg or "exceed" in error_msg or "length" in error_msg):
+                   if "token" in error_msg and (
+                       "limit" in error_msg or "exceed" in error_msg or "length" in error_msg
+                   ):
                        logger.error(
                            f"Token limit exceeded for batch. Error: {e}. "
                            f"Consider reducing chunk sizes or check token truncation."
@@ -877,12 +1107,12 @@ def compute_embeddings_ollama(

                return None, list(range(len(batch_texts)))

-   # Process truncated texts in batches
+   # Process texts in batches
    all_embeddings = []
    all_failed_indices = []

    # Setup progress bar if needed
-   show_progress = is_build or len(truncated_texts) > 10
+   show_progress = is_build or len(texts) > 10
    try:
        if show_progress:
            from tqdm import tqdm
@@ -890,7 +1120,7 @@ def compute_embeddings_ollama(
            show_progress = False

    # Process batches
-   num_batches = (len(truncated_texts) + batch_size - 1) // batch_size
+   num_batches = (len(texts) + batch_size - 1) // batch_size

    if show_progress:
        batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
@@ -899,8 +1129,8 @@ def compute_embeddings_ollama(

    for batch_idx in batch_iterator:
        start_idx = batch_idx * batch_size
-       end_idx = min(start_idx + batch_size, len(truncated_texts))
-       batch_texts = truncated_texts[start_idx:end_idx]
+       end_idx = min(start_idx + batch_size, len(texts))
+       batch_texts = texts[start_idx:end_idx]

        batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)

@@ -915,11 +1145,11 @@ def compute_embeddings_ollama(

    # Handle failed embeddings
    if all_failed_indices:
-       if len(all_failed_indices) == len(truncated_texts):
+       if len(all_failed_indices) == len(texts):
            raise RuntimeError("Failed to compute any embeddings")

        logger.warning(
-           f"Failed to compute embeddings for {len(all_failed_indices)}/{len(truncated_texts)} texts"
+           f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts"
        )

        # Use zero embeddings as fallback for failed ones
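To tie the embedding hunks together: `provider_options` is now threaded from `compute_embeddings` into the OpenAI and Ollama helpers, carrying connection details and prompt templates. Below is a hedged sketch of what such a dict might look like for a local OpenAI-compatible server; only the dict keys come from the diff, the concrete values and model name are assumptions.

```python
# Assumed example values; only the dict keys come from the diff above.
provider_options = {
    "base_url": "http://localhost:1234/v1",   # e.g. an LM Studio OpenAI-compatible endpoint
    "api_key": "not-needed-locally",          # placeholder key for local servers
    "build_prompt_template": "passage: ",     # applied to documents before embedding
}

embeddings = compute_embeddings_openai(
    ["first chunk", "second chunk"],
    model_name="text-embedding-model",        # hypothetical model id
    provider_options=provider_options,
)
```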
@@ -77,6 +77,7 @@ class LeannBackendSearcherInterface(ABC):
        query: str,
        use_server_if_available: bool = True,
        zmq_port: Optional[int] = None,
+       query_template: Optional[str] = None,
    ) -> np.ndarray:
        """Compute embedding for a query string

@@ -84,6 +85,7 @@ class LeannBackendSearcherInterface(ABC):
            query: The query string to embed
            zmq_port: ZMQ port for embedding server
            use_server_if_available: Whether to try using embedding server first
+           query_template: Optional prompt template to prepend to query

        Returns:
            Query embedding as numpy array with shape (1, D)
@@ -33,6 +33,8 @@ def autodiscover_backends():
    discovered_backends = []
    for dist in importlib.metadata.distributions():
        dist_name = dist.metadata["name"]
+       if dist_name is None:
+           continue
        if dist_name.startswith("leann-backend-"):
            backend_module_name = dist_name.replace("-", "_")
            discovered_backends.append(backend_module_name)
@@ -71,6 +71,15 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
            or "mips"
        )

+       # Filter out ALL prompt templates from provider_options during search
+       # Templates are applied in compute_query_embedding (line 109-110) BEFORE server call
+       # The server should never apply templates during search to avoid double-templating
+       search_provider_options = {
+           k: v
+           for k, v in self.embedding_options.items()
+           if k not in ("build_prompt_template", "query_prompt_template", "prompt_template")
+       }
+
        server_started, actual_port = self.embedding_server_manager.start_server(
            port=port,
            model_name=self.embedding_model,
@@ -78,7 +87,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
            passages_file=passages_source_file,
            distance_metric=distance_metric,
            enable_warmup=kwargs.get("enable_warmup", False),
-           provider_options=self.embedding_options,
+           provider_options=search_provider_options,
        )
        if not server_started:
            raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
@@ -90,6 +99,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
        query: str,
        use_server_if_available: bool = True,
        zmq_port: int = 5557,
+       query_template: Optional[str] = None,
    ) -> np.ndarray:
        """
        Compute embedding for a query string.
@@ -98,10 +108,16 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
            query: The query string to embed
            zmq_port: ZMQ port for embedding server
            use_server_if_available: Whether to try using embedding server first
+           query_template: Optional prompt template to prepend to query

        Returns:
            Query embedding as numpy array
        """
+       # Apply query template BEFORE any computation path
+       # This ensures template is applied consistently for both server and fallback paths
+       if query_template:
+           query = f"{query_template}{query}"
+
        # Try to use embedding server if available and requested
        if use_server_if_available:
            try:
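The searcher changes above are about applying a query template exactly once: it is prepended in `compute_query_embedding`, and all template keys are stripped from the options handed to the embedding server. A minimal usage sketch, consistent with the `searcher.search(...)` call added to the CLI earlier in this compare; the index path and query text are illustrative.

```python
# Sketch: runtime override of the prompt template at search time (values are examples).
searcher = LeannSearcher(index_path="./my-project.leann")
results = searcher.search(
    "how are token limits discovered?",
    provider_options={"prompt_template": "query: "},  # applied client-side, not by the server
)
```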
@@ -9,6 +9,7 @@ from typing import Any
 # Default fallbacks to preserve current behaviour while keeping them in one place.
 _DEFAULT_OLLAMA_HOST = "http://localhost:11434"
 _DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
+_DEFAULT_ANTHROPIC_BASE_URL = "https://api.anthropic.com"


 def _clean_url(value: str) -> str:
@@ -52,6 +53,23 @@ def resolve_openai_base_url(explicit: str | None = None) -> str:
     return _clean_url(_DEFAULT_OPENAI_BASE_URL)


+def resolve_anthropic_base_url(explicit: str | None = None) -> str:
+    """Resolve the base URL for Anthropic-compatible services."""
+
+    candidates = (
+        explicit,
+        os.getenv("LEANN_ANTHROPIC_BASE_URL"),
+        os.getenv("ANTHROPIC_BASE_URL"),
+        os.getenv("LOCAL_ANTHROPIC_BASE_URL"),
+    )
+
+    for candidate in candidates:
+        if candidate:
+            return _clean_url(candidate)
+
+    return _clean_url(_DEFAULT_ANTHROPIC_BASE_URL)
+
+
 def resolve_openai_api_key(explicit: str | None = None) -> str | None:
     """Resolve the API key for OpenAI-compatible services."""

@@ -61,6 +79,15 @@ def resolve_openai_api_key(explicit: str | None = None) -> str | None:
     return os.getenv("OPENAI_API_KEY")


+def resolve_anthropic_api_key(explicit: str | None = None) -> str | None:
+    """Resolve the API key for Anthropic services."""
+
+    if explicit:
+        return explicit
+
+    return os.getenv("ANTHROPIC_API_KEY")
+
+
 def encode_provider_options(options: dict[str, Any] | None) -> str | None:
     """Serialize provider options for child processes."""

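For reference, a short sketch of how the new Anthropic resolvers behave; the environment variable names are taken from the code above, while the URLs themselves are placeholders.

```python
import os

# Precedence: explicit argument, then LEANN_ANTHROPIC_BASE_URL, ANTHROPIC_BASE_URL,
# LOCAL_ANTHROPIC_BASE_URL, and finally the https://api.anthropic.com default.
os.environ["ANTHROPIC_BASE_URL"] = "https://proxy.example.com/anthropic"

print(resolve_anthropic_base_url())                         # https://proxy.example.com/anthropic
print(resolve_anthropic_base_url("http://localhost:8080"))  # explicit argument wins
print(resolve_anthropic_api_key())                          # value of ANTHROPIC_API_KEY, or None
```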
@@ -53,6 +53,11 @@ leann build my-project --docs $(git ls-files)
 # Start Claude Code
 claude
 ```

+**Performance tip**: For maximum speed when storage space is not a concern, add the `--no-recompute` flag to your build command. This materializes all tensors and stores them on disk, avoiding recomputation on subsequent builds:
+
+```bash
+leann build my-project --docs $(git ls-files) --no-recompute
+```
+
 ## 🚀 Advanced Usage Examples to build the index

@@ -4,10 +4,10 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "leann"
-version = "0.3.4"
+version = "0.3.5"
 description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
 readme = "README.md"
-requires-python = ">=3.9"
+requires-python = ">=3.10"
 license = { text = "MIT" }
 authors = [
     { name = "LEANN Team" }
@@ -18,10 +18,10 @@ classifiers = [
     "Intended Audience :: Developers",
     "License :: OSI Approved :: MIT License",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
 ]

 # Default installation: core + hnsw + diskann
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "leann-workspace"
 version = "0.1.0"
-requires-python = ">=3.9"
+requires-python = ">=3.10"

 dependencies = [
     "leann-core",
@@ -57,6 +57,8 @@ dependencies = [
     "tree-sitter-c-sharp>=0.20.0",
     "tree-sitter-typescript>=0.20.0",
     "torchvision>=0.23.0",
+    "einops",
+    "seaborn",
 ]

 [project.optional-dependencies]
@@ -67,7 +69,8 @@ diskann = [
 # Add a new optional dependency group for document processing
 documents = [
     "beautifulsoup4>=4.13.0",  # For HTML parsing
-    "python-docx>=0.8.11",  # For Word documents
+    "python-docx>=0.8.11",  # For Word documents (creating/editing)
+    "docx2txt>=0.9",  # For Word documents (text extraction)
     "openpyxl>=3.1.0",  # For Excel files
     "pandas>=2.2.0",  # For data processing
 ]
@@ -162,6 +165,7 @@ python_functions = ["test_*"]
 markers = [
     "slow: marks tests as slow (deselect with '-m \"not slow\"')",
     "openai: marks tests that require OpenAI API key",
+    "integration: marks tests that require live services (Ollama, LM Studio, etc.)",
 ]
 timeout = 300  # Reduced from 600s (10min) to 300s (5min) for CI safety
 addopts = [
@@ -36,6 +36,14 @@ Tests DiskANN graph partitioning functionality:
 - Includes performance comparison between DiskANN (with partition) and HNSW
 - **Note**: These tests are skipped in CI due to hardware requirements and computation time

+### `test_prompt_template_e2e.py`
+Integration tests for prompt template feature with live embedding services:
+- Tests prompt template prepending with EmbeddingGemma (OpenAI-compatible API via LM Studio)
+- Tests hybrid token limit discovery (Ollama dynamic detection, registry fallback, default)
+- Tests LM Studio SDK bridge for automatic context length detection (requires Node.js + @lmstudio/sdk)
+- **Note**: These tests require live services (LM Studio, Ollama) and are marked with `@pytest.mark.integration`
+- **Important**: Prompt templates are ONLY for EmbeddingGemma and similar task-specific models, NOT regular embedding models
+
 ## Running Tests

 ### Install test dependencies:
@@ -66,6 +74,12 @@ pytest tests/ -m "not openai"
 # Skip slow tests
 pytest tests/ -m "not slow"

+# Skip integration tests (that require live services)
+pytest tests/ -m "not integration"
+
+# Run only integration tests (requires LM Studio or Ollama running)
+pytest tests/test_prompt_template_e2e.py -v -s
+
 # Run DiskANN partition tests (requires local machine, not CI)
 pytest tests/test_diskann_partition.py
 ```
@@ -101,6 +115,20 @@ The `pytest.ini` file configures:
 - Custom markers for slow and OpenAI tests
 - Verbose output with short tracebacks

+### Integration Test Prerequisites
+
+Integration tests (`test_prompt_template_e2e.py`) require live services:
+
+**Required:**
+- LM Studio running at `http://localhost:1234` with EmbeddingGemma model loaded
+
+**Optional:**
+- Ollama running at `http://localhost:11434` for token limit detection tests
+- Node.js + @lmstudio/sdk installed (`npm install -g @lmstudio/sdk`) for SDK bridge tests
+
+Tests gracefully skip if services are unavailable.
+
 ### Known Issues

 - OpenAI tests are automatically skipped if no API key is provided
+- Integration tests require live embedding services and may fail due to proxy settings (set `unset ALL_PROXY all_proxy` if needed)
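The "gracefully skip" behaviour described above usually comes down to probing the service before the test runs; the sketch below shows one way such an integration test can be guarded. The helper, probe URL, and test name are illustrative and are not the contents of `test_prompt_template_e2e.py`.

```python
import pytest
import requests

LMSTUDIO_URL = "http://localhost:1234/v1/models"  # assumed probe endpoint


def _service_available(url: str) -> bool:
    """Return True if the embedding server answers at all."""
    try:
        requests.get(url, timeout=2)
        return True
    except requests.RequestException:
        return False


@pytest.mark.integration
@pytest.mark.skipif(not _service_available(LMSTUDIO_URL), reason="LM Studio is not running")
def test_prompt_template_roundtrip():
    ...  # exercise the prompt-template path against the live server
```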
@@ -8,7 +8,7 @@ import subprocess
 import sys
 import tempfile
 from pathlib import Path
-from unittest.mock import patch
+from unittest.mock import Mock, patch

 import pytest

@@ -116,8 +116,10 @@ class TestChunkingFunctions:
        chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)

        assert len(chunks) > 0
-       assert all(isinstance(chunk, str) for chunk in chunks)
-       assert all(len(chunk.strip()) > 0 for chunk in chunks)
+       # Traditional chunks now return dict format for consistency
+       assert all(isinstance(chunk, dict) for chunk in chunks)
+       assert all("text" in chunk and "metadata" in chunk for chunk in chunks)
+       assert all(len(chunk["text"].strip()) > 0 for chunk in chunks)

    def test_create_traditional_chunks_empty_docs(self):
        """Test traditional chunking with empty documents."""
@@ -158,11 +160,22 @@ class Calculator:

        # Should have multiple chunks due to different functions/classes
        assert len(chunks) > 0
-       assert all(isinstance(chunk, str) for chunk in chunks)
-       assert all(len(chunk.strip()) > 0 for chunk in chunks)
+       # R3: Expect dict format with "text" and "metadata" keys
+       assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
+       assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
+           "Each chunk should have 'text' and 'metadata' keys"
+       )
+       assert all(len(chunk["text"].strip()) > 0 for chunk in chunks), (
+           "Each chunk text should be non-empty"
+       )
+
+       # Check metadata is present
+       assert all("file_path" in chunk["metadata"] for chunk in chunks), (
+           "Each chunk should have file_path metadata"
+       )

        # Check that code structure is somewhat preserved
-       combined_content = " ".join(chunks)
+       combined_content = " ".join([c["text"] for c in chunks])
        assert "def hello_world" in combined_content
        assert "class Calculator" in combined_content

@@ -194,7 +207,11 @@ class Calculator:
        chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10)

        assert len(chunks) > 0
-       assert all(isinstance(chunk, str) for chunk in chunks)
+       # R3: Traditional chunking should also return dict format for consistency
+       assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
+       assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
+           "Each chunk should have 'text' and 'metadata' keys"
+       )

    def test_create_text_chunks_ast_mode(self):
        """Test text chunking in AST mode."""
@@ -213,7 +230,11 @@ class Calculator:
        )

        assert len(chunks) > 0
-       assert all(isinstance(chunk, str) for chunk in chunks)
+       # R3: AST mode should also return dict format
+       assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
+       assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
+           "Each chunk should have 'text' and 'metadata' keys"
+       )

    def test_create_text_chunks_custom_extensions(self):
        """Test text chunking with custom code file extensions."""
@@ -353,6 +374,552 @@ class MathUtils:
            pytest.skip("Test timed out - likely due to model download in CI")


+class TestASTContentExtraction:
+    """Test AST content extraction bug fix.
+
+    These tests verify that astchunk's dict format with 'content' key is handled correctly,
+    and that the extraction logic doesn't fall through to stringifying entire dicts.
+    """
+
+    def test_extract_content_from_astchunk_dict(self):
+        """Test that astchunk dict format with 'content' key is handled correctly.
+
+        Bug: Current code checks for chunk["text"] but astchunk returns chunk["content"].
+        This causes fallthrough to str(chunk), stringifying the entire dict.
+
+        This test will FAIL until the bug is fixed because:
+        - Current code will stringify the dict: "{'content': '...', 'metadata': {...}}"
+        - Fixed code should extract just the content value
+        """
+        # Mock the ASTChunkBuilder class
+        mock_builder = Mock()
+
+        # Astchunk returns this format
+        astchunk_format_chunk = {
+            "content": "def hello():\n print('world')",
+            "metadata": {
+                "filepath": "test.py",
+                "line_count": 2,
+                "start_line_no": 0,
+                "end_line_no": 1,
+                "node_count": 1,
+            },
+        }
+        mock_builder.chunkify.return_value = [astchunk_format_chunk]
+
+        # Create mock document
+        doc = MockDocument(
+            "def hello():\n print('world')", "/test/test.py", {"language": "python"}
+        )
+
+        # Mock the astchunk module and its ASTChunkBuilder class
+        mock_astchunk = Mock()
+        mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
+
+        # Patch sys.modules to inject our mock before the import
+        with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
+            # Call create_ast_chunks
+            chunks = create_ast_chunks([doc])
+
+        # R3: Should return dict format with proper metadata
+        assert len(chunks) > 0, "Should return at least one chunk"
+
+        # R3: Each chunk should be a dict
+        chunk = chunks[0]
+        assert isinstance(chunk, dict), "Chunk should be a dict"
+        assert "text" in chunk, "Chunk should have 'text' key"
+        assert "metadata" in chunk, "Chunk should have 'metadata' key"
+
+        chunk_text = chunk["text"]
+
+        # CRITICAL: Should NOT contain stringified dict markers in the text field
+        # These assertions will FAIL with current buggy code
+        assert "'content':" not in chunk_text, (
+            f"Chunk text contains stringified dict - extraction failed! Got: {chunk_text[:100]}..."
+        )
+        assert "'metadata':" not in chunk_text, (
+            "Chunk text contains stringified metadata - extraction failed! "
+            f"Got: {chunk_text[:100]}..."
+        )
+        assert "{" not in chunk_text or "def hello" in chunk_text.split("{")[0], (
+            "Chunk text appears to be a stringified dict"
+        )
+
+        # Should contain actual content
+        assert "def hello()" in chunk_text, "Should extract actual code content"
+        assert "print('world')" in chunk_text, "Should extract complete code content"
+
+        # R3: Should preserve astchunk metadata
+        assert "filepath" in chunk["metadata"] or "file_path" in chunk["metadata"], (
+            "Should preserve file path metadata"
+        )
+
+    def test_extract_text_key_fallback(self):
+        """Test that 'text' key still works for backward compatibility.
+
+        Some chunks might use 'text' instead of 'content' - ensure backward compatibility.
+        This test should PASS even with current code.
+        """
+        mock_builder = Mock()
+
+        # Some chunks might use "text" key
+        text_key_chunk = {"text": "def legacy_function():\n return True"}
+        mock_builder.chunkify.return_value = [text_key_chunk]
+
+        # Create mock document
+        doc = MockDocument(
+            "def legacy_function():\n return True", "/test/legacy.py", {"language": "python"}
+        )
+
+        # Mock the astchunk module
+        mock_astchunk = Mock()
+        mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
+
+        with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
+            # Call create_ast_chunks
+            chunks = create_ast_chunks([doc])
+
+        # R3: Should extract text correctly as dict format
+        assert len(chunks) > 0
+        chunk = chunks[0]
+        assert isinstance(chunk, dict), "Chunk should be a dict"
+        assert "text" in chunk, "Chunk should have 'text' key"
+
+        chunk_text = chunk["text"]
+
+        # Should NOT be stringified
+        assert "'text':" not in chunk_text, "Should not stringify dict with 'text' key"
+
+        # Should contain actual content
+        assert "def legacy_function()" in chunk_text
+        assert "return True" in chunk_text
+
+    def test_handles_string_chunks(self):
+        """Test that plain string chunks still work.
+
+        Some chunkers might return plain strings - verify these are preserved.
+        This test should PASS with current code.
+        """
+        mock_builder = Mock()
+
+        # Plain string chunk
+        plain_string_chunk = "def simple_function():\n pass"
+        mock_builder.chunkify.return_value = [plain_string_chunk]
+
+        # Create mock document
+        doc = MockDocument(
+            "def simple_function():\n pass", "/test/simple.py", {"language": "python"}
+        )
+
+        # Mock the astchunk module
+        mock_astchunk = Mock()
+        mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
+
+        with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
+            # Call create_ast_chunks
+            chunks = create_ast_chunks([doc])
+
+        # R3: Should wrap string in dict format
+        assert len(chunks) > 0
+        chunk = chunks[0]
+        assert isinstance(chunk, dict), "Even string chunks should be wrapped in dict"
+        assert "text" in chunk, "Chunk should have 'text' key"
+
+        chunk_text = chunk["text"]
+
+        assert chunk_text == plain_string_chunk.strip(), (
+            "Should preserve plain string chunk content"
+        )
+        assert "def simple_function()" in chunk_text
+        assert "pass" in chunk_text
+
+    def test_multiple_chunks_with_mixed_formats(self):
+        """Test handling of multiple chunks with different formats.
+
+        Real-world scenario: astchunk might return a mix of formats.
+        This test will FAIL if any chunk with 'content' key gets stringified.
+        """
+        mock_builder = Mock()
+
+        # Mix of formats
+        mixed_chunks = [
+            {"content": "def first():\n return 1", "metadata": {"line_count": 2}},
+            "def second():\n return 2",  # Plain string
+            {"text": "def third():\n return 3"},  # Old format
+            {"content": "class MyClass:\n pass", "metadata": {"node_count": 1}},
+        ]
+        mock_builder.chunkify.return_value = mixed_chunks
+
+        # Create mock document
+        code = "def first():\n return 1\n\ndef second():\n return 2\n\ndef third():\n return 3\n\nclass MyClass:\n pass"
+        doc = MockDocument(code, "/test/mixed.py", {"language": "python"})
+
+        # Mock the astchunk module
|
||||||
|
mock_astchunk = Mock()
|
||||||
|
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||||
|
|
||||||
|
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||||
|
# Call create_ast_chunks
|
||||||
|
chunks = create_ast_chunks([doc])
|
||||||
|
|
||||||
|
# R3: Should extract all chunks correctly as dicts
|
||||||
|
assert len(chunks) == 4, "Should extract all 4 chunks"
|
||||||
|
|
||||||
|
# Check each chunk
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
assert isinstance(chunk, dict), f"Chunk {i} should be a dict"
|
||||||
|
assert "text" in chunk, f"Chunk {i} should have 'text' key"
|
||||||
|
assert "metadata" in chunk, f"Chunk {i} should have 'metadata' key"
|
||||||
|
|
||||||
|
chunk_text = chunk["text"]
|
||||||
|
# None should be stringified dicts
|
||||||
|
assert "'content':" not in chunk_text, f"Chunk {i} text is stringified (has 'content':)"
|
||||||
|
assert "'metadata':" not in chunk_text, (
|
||||||
|
f"Chunk {i} text is stringified (has 'metadata':)"
|
||||||
|
)
|
||||||
|
assert "'text':" not in chunk_text, f"Chunk {i} text is stringified (has 'text':)"
|
||||||
|
|
||||||
|
# Verify actual content is present
|
||||||
|
combined = "\n".join([c["text"] for c in chunks])
|
||||||
|
assert "def first()" in combined
|
||||||
|
assert "def second()" in combined
|
||||||
|
assert "def third()" in combined
|
||||||
|
assert "class MyClass:" in combined
|
||||||
|
|
||||||
|
def test_empty_content_value_handling(self):
|
||||||
|
"""Test handling of chunks with empty content values.
|
||||||
|
|
||||||
|
Edge case: chunk has 'content' key but value is empty.
|
||||||
|
Should skip these chunks, not stringify them.
|
||||||
|
"""
|
||||||
|
mock_builder = Mock()
|
||||||
|
|
||||||
|
chunks_with_empty = [
|
||||||
|
{"content": "", "metadata": {"line_count": 0}}, # Empty content
|
||||||
|
{"content": " ", "metadata": {"line_count": 1}}, # Whitespace only
|
||||||
|
{"content": "def valid():\n return True", "metadata": {"line_count": 2}}, # Valid
|
||||||
|
]
|
||||||
|
mock_builder.chunkify.return_value = chunks_with_empty
|
||||||
|
|
||||||
|
doc = MockDocument(
|
||||||
|
"def valid():\n return True", "/test/empty.py", {"language": "python"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the astchunk module
|
||||||
|
mock_astchunk = Mock()
|
||||||
|
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||||
|
|
||||||
|
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||||
|
chunks = create_ast_chunks([doc])
|
||||||
|
|
||||||
|
# R3: Should only have the valid chunk (empty ones filtered out)
|
||||||
|
assert len(chunks) == 1, "Should filter out empty content chunks"
|
||||||
|
|
||||||
|
chunk = chunks[0]
|
||||||
|
assert isinstance(chunk, dict), "Chunk should be a dict"
|
||||||
|
assert "text" in chunk, "Chunk should have 'text' key"
|
||||||
|
assert "def valid()" in chunk["text"]
|
||||||
|
|
||||||
|
# Should not have stringified the empty dict
|
||||||
|
assert "'content': ''" not in chunk["text"]
|
||||||
|
|
||||||
|
|
||||||
|
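
# --- Illustrative sketch (not part of the original test file) ---------------
# A minimal version of the chunk-normalization logic the tests above pin down
# for create_ast_chunks: extract the "content"/"text" value instead of
# stringifying the whole dict, wrap plain strings, and skip empty chunks.
# The helper name `_normalize_ast_chunk` is hypothetical, not LEANN API.
from typing import Any, Optional


def _normalize_ast_chunk(raw: Any, doc_metadata: dict) -> Optional[dict]:
    """Return a {"text", "metadata"} dict, or None if the chunk is empty."""
    if isinstance(raw, dict):
        # Prefer the astchunk "content" key, fall back to the legacy "text" key.
        text = raw.get("content") or raw.get("text") or ""
        chunk_metadata = raw.get("metadata", {}) or {}
    else:
        # Plain string chunks are wrapped as-is.
        text = str(raw)
        chunk_metadata = {}

    text = text.strip()
    if not text:
        # Empty / whitespace-only chunks are filtered out, never stringified.
        return None

    # Document metadata first, astchunk metadata merged on top.
    metadata = {**doc_metadata, **chunk_metadata}
    return {"text": text, "metadata": metadata}


# Example: the astchunk-style dict from the tests above normalizes cleanly.
_example = _normalize_ast_chunk(
    {"content": "def hello():\n    print('world')", "metadata": {"line_count": 2}},
    {"file_path": "/test/test.py", "file_name": "test.py"},
)
assert _example is not None and _example["text"].startswith("def hello()")
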
class TestASTMetadataPreservation:
|
||||||
|
"""Test metadata preservation in AST chunk dictionaries.
|
||||||
|
|
||||||
|
R3: These tests define the contract for metadata preservation when returning
|
||||||
|
chunk dictionaries instead of plain strings. Each chunk dict should have:
|
||||||
|
- "text": str - the actual chunk content
|
||||||
|
- "metadata": dict - all metadata from document AND astchunk
|
||||||
|
|
||||||
|
These tests will FAIL until G3 implementation changes return type to list[dict].
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_ast_chunks_preserve_file_metadata(self):
|
||||||
|
"""Test that document metadata is preserved in chunk metadata.
|
||||||
|
|
||||||
|
This test verifies that all document-level metadata (file_path, file_name,
|
||||||
|
creation_date, last_modified_date) is included in each chunk's metadata dict.
|
||||||
|
|
||||||
|
This will FAIL because current code returns list[str], not list[dict].
|
||||||
|
"""
|
||||||
|
# Create mock document with rich metadata
|
||||||
|
python_code = '''
|
||||||
|
def calculate_sum(numbers):
|
||||||
|
"""Calculate sum of numbers."""
|
||||||
|
return sum(numbers)
|
||||||
|
|
||||||
|
class DataProcessor:
|
||||||
|
"""Process data records."""
|
||||||
|
|
||||||
|
def process(self, data):
|
||||||
|
return [x * 2 for x in data]
|
||||||
|
'''
|
||||||
|
doc = MockDocument(
|
||||||
|
python_code,
|
||||||
|
file_path="/project/src/utils.py",
|
||||||
|
metadata={
|
||||||
|
"language": "python",
|
||||||
|
"file_path": "/project/src/utils.py",
|
||||||
|
"file_name": "utils.py",
|
||||||
|
"creation_date": "2024-01-15T10:30:00",
|
||||||
|
"last_modified_date": "2024-10-31T15:45:00",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock astchunk to return chunks with metadata
|
||||||
|
mock_builder = Mock()
|
||||||
|
astchunk_chunks = [
|
||||||
|
{
|
||||||
|
"content": "def calculate_sum(numbers):\n return sum(numbers)",
|
||||||
|
"metadata": {
|
||||||
|
"filepath": "/project/src/utils.py",
|
||||||
|
"line_count": 2,
|
||||||
|
"start_line_no": 1,
|
||||||
|
"end_line_no": 2,
|
||||||
|
"node_count": 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"content": "class DataProcessor:\n def process(self, data):\n return [x * 2 for x in data]",
|
||||||
|
"metadata": {
|
||||||
|
"filepath": "/project/src/utils.py",
|
||||||
|
"line_count": 3,
|
||||||
|
"start_line_no": 5,
|
||||||
|
"end_line_no": 7,
|
||||||
|
"node_count": 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
mock_builder.chunkify.return_value = astchunk_chunks
|
||||||
|
|
||||||
|
mock_astchunk = Mock()
|
||||||
|
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||||
|
|
||||||
|
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||||
|
chunks = create_ast_chunks([doc])
|
||||||
|
|
||||||
|
# CRITICAL: These assertions will FAIL with current list[str] return type
|
||||||
|
assert len(chunks) == 2, "Should return 2 chunks"
|
||||||
|
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
# Structure assertions - WILL FAIL: current code returns strings
|
||||||
|
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
|
||||||
|
assert "text" in chunk, f"Chunk {i} must have 'text' key"
|
||||||
|
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
|
||||||
|
assert isinstance(chunk["metadata"], dict), f"Chunk {i} metadata should be dict"
|
||||||
|
|
||||||
|
# Document metadata preservation - WILL FAIL
|
||||||
|
metadata = chunk["metadata"]
|
||||||
|
assert "file_path" in metadata, f"Chunk {i} should preserve file_path"
|
||||||
|
assert metadata["file_path"] == "/project/src/utils.py", (
|
||||||
|
f"Chunk {i} file_path incorrect"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "file_name" in metadata, f"Chunk {i} should preserve file_name"
|
||||||
|
assert metadata["file_name"] == "utils.py", f"Chunk {i} file_name incorrect"
|
||||||
|
|
||||||
|
assert "creation_date" in metadata, f"Chunk {i} should preserve creation_date"
|
||||||
|
assert metadata["creation_date"] == "2024-01-15T10:30:00", (
|
||||||
|
f"Chunk {i} creation_date incorrect"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "last_modified_date" in metadata, f"Chunk {i} should preserve last_modified_date"
|
||||||
|
assert metadata["last_modified_date"] == "2024-10-31T15:45:00", (
|
||||||
|
f"Chunk {i} last_modified_date incorrect"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify metadata is consistent across chunks from same document
|
||||||
|
assert chunks[0]["metadata"]["file_path"] == chunks[1]["metadata"]["file_path"], (
|
||||||
|
"All chunks from same document should have same file_path"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify text content is present and not stringified
|
||||||
|
assert "def calculate_sum" in chunks[0]["text"]
|
||||||
|
assert "class DataProcessor" in chunks[1]["text"]
|
||||||
|
|
||||||
|
def test_ast_chunks_include_astchunk_metadata(self):
|
||||||
|
"""Test that astchunk-specific metadata is merged into chunk metadata.
|
||||||
|
|
||||||
|
This test verifies that astchunk's metadata (line_count, start_line_no,
|
||||||
|
end_line_no, node_count) is merged with document metadata.
|
||||||
|
|
||||||
|
This will FAIL because current code returns list[str], not list[dict].
|
||||||
|
"""
|
||||||
|
python_code = '''
|
||||||
|
def function_one():
|
||||||
|
"""First function."""
|
||||||
|
x = 1
|
||||||
|
y = 2
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
def function_two():
|
||||||
|
"""Second function."""
|
||||||
|
return 42
|
||||||
|
'''
|
||||||
|
doc = MockDocument(
|
||||||
|
python_code,
|
||||||
|
file_path="/test/code.py",
|
||||||
|
metadata={
|
||||||
|
"language": "python",
|
||||||
|
"file_path": "/test/code.py",
|
||||||
|
"file_name": "code.py",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock astchunk with detailed metadata
|
||||||
|
mock_builder = Mock()
|
||||||
|
astchunk_chunks = [
|
||||||
|
{
|
||||||
|
"content": "def function_one():\n x = 1\n y = 2\n return x + y",
|
||||||
|
"metadata": {
|
||||||
|
"filepath": "/test/code.py",
|
||||||
|
"line_count": 4,
|
||||||
|
"start_line_no": 1,
|
||||||
|
"end_line_no": 4,
|
||||||
|
"node_count": 5, # function, assignments, return
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"content": "def function_two():\n return 42",
|
||||||
|
"metadata": {
|
||||||
|
"filepath": "/test/code.py",
|
||||||
|
"line_count": 2,
|
||||||
|
"start_line_no": 7,
|
||||||
|
"end_line_no": 8,
|
||||||
|
"node_count": 2, # function, return
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
mock_builder.chunkify.return_value = astchunk_chunks
|
||||||
|
|
||||||
|
mock_astchunk = Mock()
|
||||||
|
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||||
|
|
||||||
|
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||||
|
chunks = create_ast_chunks([doc])
|
||||||
|
|
||||||
|
# CRITICAL: These will FAIL with current list[str] return
|
||||||
|
assert len(chunks) == 2
|
||||||
|
|
||||||
|
# First chunk - function_one
|
||||||
|
chunk1 = chunks[0]
|
||||||
|
assert isinstance(chunk1, dict), "Chunk should be dict"
|
||||||
|
assert "metadata" in chunk1
|
||||||
|
|
||||||
|
metadata1 = chunk1["metadata"]
|
||||||
|
|
||||||
|
# Check astchunk metadata is present
|
||||||
|
assert "line_count" in metadata1, "Should include astchunk line_count"
|
||||||
|
assert metadata1["line_count"] == 4, "line_count should be 4"
|
||||||
|
|
||||||
|
assert "start_line_no" in metadata1, "Should include astchunk start_line_no"
|
||||||
|
assert metadata1["start_line_no"] == 1, "start_line_no should be 1"
|
||||||
|
|
||||||
|
assert "end_line_no" in metadata1, "Should include astchunk end_line_no"
|
||||||
|
assert metadata1["end_line_no"] == 4, "end_line_no should be 4"
|
||||||
|
|
||||||
|
assert "node_count" in metadata1, "Should include astchunk node_count"
|
||||||
|
assert metadata1["node_count"] == 5, "node_count should be 5"
|
||||||
|
|
||||||
|
# Second chunk - function_two
|
||||||
|
chunk2 = chunks[1]
|
||||||
|
metadata2 = chunk2["metadata"]
|
||||||
|
|
||||||
|
assert metadata2["line_count"] == 2, "line_count should be 2"
|
||||||
|
assert metadata2["start_line_no"] == 7, "start_line_no should be 7"
|
||||||
|
assert metadata2["end_line_no"] == 8, "end_line_no should be 8"
|
||||||
|
assert metadata2["node_count"] == 2, "node_count should be 2"
|
||||||
|
|
||||||
|
# Verify document metadata is ALSO present (merged, not replaced)
|
||||||
|
assert metadata1["file_path"] == "/test/code.py"
|
||||||
|
assert metadata1["file_name"] == "code.py"
|
||||||
|
assert metadata2["file_path"] == "/test/code.py"
|
||||||
|
assert metadata2["file_name"] == "code.py"
|
||||||
|
|
||||||
|
# Verify text content is correct
|
||||||
|
assert "def function_one" in chunk1["text"]
|
||||||
|
assert "def function_two" in chunk2["text"]
|
||||||
|
|
||||||
|
def test_traditional_chunks_as_dicts_helper(self):
|
||||||
|
"""Test the helper function that wraps traditional chunks as dicts.
|
||||||
|
|
||||||
|
This test verifies that when create_traditional_chunks is called,
|
||||||
|
its plain string chunks are wrapped into dict format with metadata.
|
||||||
|
|
||||||
|
This will FAIL because the helper function _traditional_chunks_as_dicts()
|
||||||
|
doesn't exist yet, and create_traditional_chunks returns list[str].
|
||||||
|
"""
|
||||||
|
# Create documents with various metadata
|
||||||
|
docs = [
|
||||||
|
MockDocument(
|
||||||
|
"This is the first paragraph of text. It contains multiple sentences. "
|
||||||
|
"This should be split into chunks based on size.",
|
||||||
|
file_path="/docs/readme.txt",
|
||||||
|
metadata={
|
||||||
|
"file_path": "/docs/readme.txt",
|
||||||
|
"file_name": "readme.txt",
|
||||||
|
"creation_date": "2024-01-01",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
MockDocument(
|
||||||
|
"Second document with different metadata. It also has content that needs chunking.",
|
||||||
|
file_path="/docs/guide.md",
|
||||||
|
metadata={
|
||||||
|
"file_path": "/docs/guide.md",
|
||||||
|
"file_name": "guide.md",
|
||||||
|
"last_modified_date": "2024-10-31",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Call create_traditional_chunks (which should now return list[dict])
|
||||||
|
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
|
||||||
|
|
||||||
|
# CRITICAL: Will FAIL - current code returns list[str]
|
||||||
|
assert len(chunks) > 0, "Should return chunks"
|
||||||
|
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
# Structure assertions - WILL FAIL
|
||||||
|
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
|
||||||
|
assert "text" in chunk, f"Chunk {i} must have 'text' key"
|
||||||
|
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
|
||||||
|
|
||||||
|
# Text should be non-empty
|
||||||
|
assert len(chunk["text"].strip()) > 0, f"Chunk {i} text should be non-empty"
|
||||||
|
|
||||||
|
# Metadata should include document info
|
||||||
|
metadata = chunk["metadata"]
|
||||||
|
assert "file_path" in metadata, f"Chunk {i} should have file_path in metadata"
|
||||||
|
assert "file_name" in metadata, f"Chunk {i} should have file_name in metadata"
|
||||||
|
|
||||||
|
# Verify metadata tracking works correctly
|
||||||
|
# At least one chunk should be from readme.txt
|
||||||
|
readme_chunks = [c for c in chunks if "readme.txt" in c["metadata"]["file_name"]]
|
||||||
|
assert len(readme_chunks) > 0, "Should have chunks from readme.txt"
|
||||||
|
|
||||||
|
# At least one chunk should be from guide.md
|
||||||
|
guide_chunks = [c for c in chunks if "guide.md" in c["metadata"]["file_name"]]
|
||||||
|
assert len(guide_chunks) > 0, "Should have chunks from guide.md"
|
||||||
|
|
||||||
|
# Verify creation_date is preserved for readme chunks
|
||||||
|
for chunk in readme_chunks:
|
||||||
|
assert chunk["metadata"].get("creation_date") == "2024-01-01", (
|
||||||
|
"readme.txt chunks should preserve creation_date"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify last_modified_date is preserved for guide chunks
|
||||||
|
for chunk in guide_chunks:
|
||||||
|
assert chunk["metadata"].get("last_modified_date") == "2024-10-31", (
|
||||||
|
"guide.md chunks should preserve last_modified_date"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify text content is present
|
||||||
|
all_text = " ".join([c["text"] for c in chunks])
|
||||||
|
assert "first paragraph" in all_text
|
||||||
|
assert "Second document" in all_text
|
||||||
|
|
||||||
|
|
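
# --- Illustrative sketch (not part of the original test file) ---------------
# One possible shape for the _traditional_chunks_as_dicts() helper that
# test_traditional_chunks_as_dicts_helper above anticipates: wrap the plain
# string chunks produced for a document into {"text", "metadata"} dicts and
# copy the document metadata onto every chunk. A sketch of the expected
# contract only, not the shipped implementation.
def _traditional_chunks_as_dicts(text_chunks: list[str], doc_metadata: dict) -> list[dict]:
    chunk_dicts = []
    for text in text_chunks:
        text = text.strip()
        if not text:
            continue  # keep the "non-empty text" guarantee the tests assert
        chunk_dicts.append({"text": text, "metadata": dict(doc_metadata)})
    return chunk_dicts


# Example: document-level metadata ends up on every chunk.
_chunks = _traditional_chunks_as_dicts(
    ["This is the first paragraph of text.", "It contains multiple sentences."],
    {"file_path": "/docs/readme.txt", "file_name": "readme.txt", "creation_date": "2024-01-01"},
)
assert all(c["metadata"]["file_name"] == "readme.txt" for c in _chunks)
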
||||||
class TestErrorHandling:
    """Test error handling and edge cases."""
|
||||||
|
|
||||||
|
533
tests/test_cli_prompt_template.py
Normal file
@@ -0,0 +1,533 @@
|
"""
|
||||||
|
Tests for CLI argument integration of --embedding-prompt-template.
|
||||||
|
|
||||||
|
These tests verify that:
|
||||||
|
1. The --embedding-prompt-template flag is properly registered on build and search commands
|
||||||
|
2. The template value flows from CLI args to embedding_options dict
|
||||||
|
3. The template is passed through to compute_embeddings() function
|
||||||
|
4. Default behavior (no flag) is handled correctly
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from leann.cli import LeannCLI
|
||||||
|
|
||||||
|
|
||||||
|
class TestCLIPromptTemplateArgument:
|
||||||
|
"""Tests for --embedding-prompt-template on build and search commands."""
|
||||||
|
|
||||||
|
def test_commands_accept_prompt_template_argument(self):
|
||||||
|
"""Verify that build and search parsers accept --embedding-prompt-template flag."""
|
||||||
|
cli = LeannCLI()
|
||||||
|
parser = cli.create_parser()
|
||||||
|
template_value = "search_query: "
|
||||||
|
|
||||||
|
# Test build command
|
||||||
|
build_args = parser.parse_args(
|
||||||
|
[
|
||||||
|
"build",
|
||||||
|
"test-index",
|
||||||
|
"--docs",
|
||||||
|
"/tmp/test-docs",
|
||||||
|
"--embedding-prompt-template",
|
||||||
|
template_value,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert build_args.command == "build"
|
||||||
|
assert hasattr(build_args, "embedding_prompt_template"), (
|
||||||
|
"build command should have embedding_prompt_template attribute"
|
||||||
|
)
|
||||||
|
assert build_args.embedding_prompt_template == template_value
|
||||||
|
|
||||||
|
# Test search command
|
||||||
|
search_args = parser.parse_args(
|
||||||
|
["search", "test-index", "my query", "--embedding-prompt-template", template_value]
|
||||||
|
)
|
||||||
|
assert search_args.command == "search"
|
||||||
|
assert hasattr(search_args, "embedding_prompt_template"), (
|
||||||
|
"search command should have embedding_prompt_template attribute"
|
||||||
|
)
|
||||||
|
assert search_args.embedding_prompt_template == template_value
|
||||||
|
|
||||||
|
def test_commands_default_to_none(self):
|
||||||
|
"""Verify default value is None when flag not provided (backward compatibility)."""
|
||||||
|
cli = LeannCLI()
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
# Test build command default
|
||||||
|
build_args = parser.parse_args(["build", "test-index", "--docs", "/tmp/test-docs"])
|
||||||
|
assert hasattr(build_args, "embedding_prompt_template"), (
|
||||||
|
"build command should have embedding_prompt_template attribute"
|
||||||
|
)
|
||||||
|
assert build_args.embedding_prompt_template is None, (
|
||||||
|
"Build default value should be None when flag not provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test search command default
|
||||||
|
search_args = parser.parse_args(["search", "test-index", "my query"])
|
||||||
|
assert hasattr(search_args, "embedding_prompt_template"), (
|
||||||
|
"search command should have embedding_prompt_template attribute"
|
||||||
|
)
|
||||||
|
assert search_args.embedding_prompt_template is None, (
|
||||||
|
"Search default value should be None when flag not provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
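
# --- Illustrative sketch (not part of the original test file) ---------------
# Roughly what the tests above require of the parser: both the build and the
# search subcommands register --embedding-prompt-template and default to None.
# This is a standalone argparse example, not the actual LeannCLI.create_parser().
import argparse


def _make_sketch_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog="leann-sketch")
    subparsers = parser.add_subparsers(dest="command")
    for name in ("build", "search"):
        sub = subparsers.add_parser(name)
        sub.add_argument("index_name")
        if name == "build":
            sub.add_argument("--docs")
        else:
            sub.add_argument("query")
        sub.add_argument(
            "--embedding-prompt-template",
            default=None,
            help="Prompt template to prepend to every text before embedding",
        )
    return parser


_args = _make_sketch_parser().parse_args(
    ["search", "test-index", "my query", "--embedding-prompt-template", "search_query: "]
)
assert _args.embedding_prompt_template == "search_query: "
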
class TestBuildCommandPromptTemplateArgumentExtras:
|
||||||
|
"""Additional build-specific tests for prompt template argument."""
|
||||||
|
|
||||||
|
def test_build_command_prompt_template_with_multiword_value(self):
|
||||||
|
"""
|
||||||
|
Verify that template values with spaces are handled correctly.
|
||||||
|
|
||||||
|
Templates like "search_document: " or "Represent this sentence for searching: "
|
||||||
|
should be accepted as a single string argument.
|
||||||
|
"""
|
||||||
|
cli = LeannCLI()
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
template = "Represent this sentence for searching: "
|
||||||
|
args = parser.parse_args(
|
||||||
|
[
|
||||||
|
"build",
|
||||||
|
"test-index",
|
||||||
|
"--docs",
|
||||||
|
"/tmp/test-docs",
|
||||||
|
"--embedding-prompt-template",
|
||||||
|
template,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert args.embedding_prompt_template == template
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptTemplateStoredInEmbeddingOptions:
|
||||||
|
"""Tests for template storage in embedding_options dict."""
|
||||||
|
|
||||||
|
@patch("leann.cli.LeannBuilder")
|
||||||
|
def test_prompt_template_stored_in_embedding_options_on_build(
|
||||||
|
self, mock_builder_class, tmp_path
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Verify that when --embedding-prompt-template is provided to build command,
|
||||||
|
the value is stored in embedding_options dict passed to LeannBuilder.
|
||||||
|
|
||||||
|
This test will fail because the CLI doesn't currently process this argument
|
||||||
|
and add it to embedding_options.
|
||||||
|
"""
|
||||||
|
# Setup mocks
|
||||||
|
mock_builder = Mock()
|
||||||
|
mock_builder_class.return_value = mock_builder
|
||||||
|
|
||||||
|
# Create CLI and run build command
|
||||||
|
cli = LeannCLI()
|
||||||
|
|
||||||
|
# Mock load_documents to return a document so builder is created
|
||||||
|
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||||
|
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
template = "search_query: "
|
||||||
|
args = parser.parse_args(
|
||||||
|
[
|
||||||
|
"build",
|
||||||
|
"test-index",
|
||||||
|
"--docs",
|
||||||
|
str(tmp_path),
|
||||||
|
"--embedding-prompt-template",
|
||||||
|
template,
|
||||||
|
"--force", # Force rebuild to ensure LeannBuilder is called
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the build command
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(cli.build_index(args))
|
||||||
|
|
||||||
|
# Check that LeannBuilder was called with embedding_options containing prompt_template
|
||||||
|
call_kwargs = mock_builder_class.call_args.kwargs
|
||||||
|
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
|
||||||
|
|
||||||
|
embedding_options = call_kwargs["embedding_options"]
|
||||||
|
assert embedding_options is not None, (
|
||||||
|
"embedding_options should not be None when template provided"
|
||||||
|
)
|
||||||
|
assert "prompt_template" in embedding_options, (
|
||||||
|
"embedding_options should contain 'prompt_template' key"
|
||||||
|
)
|
||||||
|
assert embedding_options["prompt_template"] == template, (
|
||||||
|
f"Template should be '{template}', got {embedding_options.get('prompt_template')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("leann.cli.LeannBuilder")
|
||||||
|
def test_prompt_template_not_in_options_when_not_provided(self, mock_builder_class, tmp_path):
|
||||||
|
"""
|
||||||
|
Verify that when --embedding-prompt-template is NOT provided,
|
||||||
|
embedding_options either doesn't have the key or it's None.
|
||||||
|
|
||||||
|
This ensures we don't pass empty/None values unnecessarily.
|
||||||
|
"""
|
||||||
|
# Setup mocks
|
||||||
|
mock_builder = Mock()
|
||||||
|
mock_builder_class.return_value = mock_builder
|
||||||
|
|
||||||
|
cli = LeannCLI()
|
||||||
|
|
||||||
|
# Mock load_documents to return a document so builder is created
|
||||||
|
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||||
|
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
args = parser.parse_args(
|
||||||
|
[
|
||||||
|
"build",
|
||||||
|
"test-index",
|
||||||
|
"--docs",
|
||||||
|
str(tmp_path),
|
||||||
|
"--force", # Force rebuild to ensure LeannBuilder is called
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(cli.build_index(args))
|
||||||
|
|
||||||
|
# Check that if embedding_options is passed, it doesn't have prompt_template
|
||||||
|
call_kwargs = mock_builder_class.call_args.kwargs
|
||||||
|
if call_kwargs.get("embedding_options"):
|
||||||
|
embedding_options = call_kwargs["embedding_options"]
|
||||||
|
# Either the key shouldn't exist, or it should be None
|
||||||
|
assert (
|
||||||
|
"prompt_template" not in embedding_options
|
||||||
|
or embedding_options["prompt_template"] is None
|
||||||
|
), "prompt_template should not be set when flag not provided"
|
||||||
|
|
||||||
|
# R1 Tests: Build-time separate template storage
|
||||||
|
@patch("leann.cli.LeannBuilder")
|
||||||
|
def test_build_stores_separate_templates(self, mock_builder_class, tmp_path):
|
||||||
|
"""
|
||||||
|
R1 Test 1: Verify that when both --embedding-prompt-template and
|
||||||
|
--query-prompt-template are provided to build command, both values
|
||||||
|
are stored separately in embedding_options dict as build_prompt_template
|
||||||
|
and query_prompt_template.
|
||||||
|
|
||||||
|
This test will fail because:
|
||||||
|
1. CLI doesn't accept --query-prompt-template flag yet
|
||||||
|
2. CLI doesn't store templates as separate build_prompt_template and
|
||||||
|
query_prompt_template keys
|
||||||
|
|
||||||
|
Expected behavior after implementation:
|
||||||
|
- .meta.json contains: {"embedding_options": {
|
||||||
|
"build_prompt_template": "doc: ",
|
||||||
|
"query_prompt_template": "query: "
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
# Setup mocks
|
||||||
|
mock_builder = Mock()
|
||||||
|
mock_builder_class.return_value = mock_builder
|
||||||
|
|
||||||
|
cli = LeannCLI()
|
||||||
|
|
||||||
|
# Mock load_documents to return a document so builder is created
|
||||||
|
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||||
|
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
build_template = "doc: "
|
||||||
|
query_template = "query: "
|
||||||
|
args = parser.parse_args(
|
||||||
|
[
|
||||||
|
"build",
|
||||||
|
"test-index",
|
||||||
|
"--docs",
|
||||||
|
str(tmp_path),
|
||||||
|
"--embedding-prompt-template",
|
||||||
|
build_template,
|
||||||
|
"--query-prompt-template",
|
||||||
|
query_template,
|
||||||
|
"--force",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the build command
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(cli.build_index(args))
|
||||||
|
|
||||||
|
# Check that LeannBuilder was called with separate template keys
|
||||||
|
call_kwargs = mock_builder_class.call_args.kwargs
|
||||||
|
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
|
||||||
|
|
||||||
|
embedding_options = call_kwargs["embedding_options"]
|
||||||
|
assert embedding_options is not None, (
|
||||||
|
"embedding_options should not be None when templates provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "build_prompt_template" in embedding_options, (
|
||||||
|
"embedding_options should contain 'build_prompt_template' key"
|
||||||
|
)
|
||||||
|
assert embedding_options["build_prompt_template"] == build_template, (
|
||||||
|
f"build_prompt_template should be '{build_template}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "query_prompt_template" in embedding_options, (
|
||||||
|
"embedding_options should contain 'query_prompt_template' key"
|
||||||
|
)
|
||||||
|
assert embedding_options["query_prompt_template"] == query_template, (
|
||||||
|
f"query_prompt_template should be '{query_template}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Old key should NOT be present when using new separate template format
|
||||||
|
assert "prompt_template" not in embedding_options, (
|
||||||
|
"Old 'prompt_template' key should not be present with separate templates"
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("leann.cli.LeannBuilder")
|
||||||
|
def test_build_backward_compat_single_template(self, mock_builder_class, tmp_path):
|
||||||
|
"""
|
||||||
|
R1 Test 2: Verify backward compatibility - when only
|
||||||
|
--embedding-prompt-template is provided (old behavior), it should
|
||||||
|
still be stored as 'prompt_template' in embedding_options.
|
||||||
|
|
||||||
|
This ensures existing workflows continue to work unchanged.
|
||||||
|
|
||||||
|
This test currently passes because it matches existing behavior, but it
|
||||||
|
documents the requirement that this behavior must be preserved after
|
||||||
|
implementing the separate template feature.
|
||||||
|
|
||||||
|
Expected behavior:
|
||||||
|
- .meta.json contains: {"embedding_options": {"prompt_template": "prompt: "}}
|
||||||
|
- No build_prompt_template or query_prompt_template keys
|
||||||
|
"""
|
||||||
|
# Setup mocks
|
||||||
|
mock_builder = Mock()
|
||||||
|
mock_builder_class.return_value = mock_builder
|
||||||
|
|
||||||
|
cli = LeannCLI()
|
||||||
|
|
||||||
|
# Mock load_documents to return a document so builder is created
|
||||||
|
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||||
|
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
template = "prompt: "
|
||||||
|
args = parser.parse_args(
|
||||||
|
[
|
||||||
|
"build",
|
||||||
|
"test-index",
|
||||||
|
"--docs",
|
||||||
|
str(tmp_path),
|
||||||
|
"--embedding-prompt-template",
|
||||||
|
template,
|
||||||
|
"--force",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the build command
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(cli.build_index(args))
|
||||||
|
|
||||||
|
# Check that LeannBuilder was called with old format
|
||||||
|
call_kwargs = mock_builder_class.call_args.kwargs
|
||||||
|
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
|
||||||
|
|
||||||
|
embedding_options = call_kwargs["embedding_options"]
|
||||||
|
assert embedding_options is not None, (
|
||||||
|
"embedding_options should not be None when template provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "prompt_template" in embedding_options, (
|
||||||
|
"embedding_options should contain old 'prompt_template' key for backward compat"
|
||||||
|
)
|
||||||
|
assert embedding_options["prompt_template"] == template, (
|
||||||
|
f"prompt_template should be '{template}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# New keys should NOT be present in backward compat mode
|
||||||
|
assert "build_prompt_template" not in embedding_options, (
|
||||||
|
"build_prompt_template should not be present with single template flag"
|
||||||
|
)
|
||||||
|
assert "query_prompt_template" not in embedding_options, (
|
||||||
|
"query_prompt_template should not be present with single template flag"
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("leann.cli.LeannBuilder")
|
||||||
|
def test_build_no_templates(self, mock_builder_class, tmp_path):
|
||||||
|
"""
|
||||||
|
R1 Test 3: Verify that when no template flags are provided,
|
||||||
|
embedding_options has no prompt template keys.
|
||||||
|
|
||||||
|
This ensures clean defaults and no unnecessary keys in .meta.json.
|
||||||
|
|
||||||
|
This test currently passes because it matches existing behavior, but it
|
||||||
|
documents the requirement that this behavior must be preserved after
|
||||||
|
implementing the separate template feature.
|
||||||
|
|
||||||
|
Expected behavior:
|
||||||
|
- .meta.json has no prompt_template, build_prompt_template, or
|
||||||
|
query_prompt_template keys (or embedding_options is empty/None)
|
||||||
|
"""
|
||||||
|
# Setup mocks
|
||||||
|
mock_builder = Mock()
|
||||||
|
mock_builder_class.return_value = mock_builder
|
||||||
|
|
||||||
|
cli = LeannCLI()
|
||||||
|
|
||||||
|
# Mock load_documents to return a document so builder is created
|
||||||
|
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||||
|
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
args = parser.parse_args(["build", "test-index", "--docs", str(tmp_path), "--force"])
|
||||||
|
|
||||||
|
# Run the build command
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(cli.build_index(args))
|
||||||
|
|
||||||
|
# Check that no template keys are present
|
||||||
|
call_kwargs = mock_builder_class.call_args.kwargs
|
||||||
|
if call_kwargs.get("embedding_options"):
|
||||||
|
embedding_options = call_kwargs["embedding_options"]
|
||||||
|
|
||||||
|
# None of the template keys should be present
|
||||||
|
assert "prompt_template" not in embedding_options, (
|
||||||
|
"prompt_template should not be present when no flags provided"
|
||||||
|
)
|
||||||
|
assert "build_prompt_template" not in embedding_options, (
|
||||||
|
"build_prompt_template should not be present when no flags provided"
|
||||||
|
)
|
||||||
|
assert "query_prompt_template" not in embedding_options, (
|
||||||
|
"query_prompt_template should not be present when no flags provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
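
# --- Illustrative sketch (not part of the original test file) ---------------
# The decision table the three R1 tests above encode, as a small pure function.
# It takes the two CLI values directly; the function name and exact wiring in
# the real CLI are assumptions.
def _templates_to_embedding_options(
    embedding_prompt_template: str | None,
    query_prompt_template: str | None,
) -> dict:
    options: dict = {}
    if embedding_prompt_template and query_prompt_template:
        # Separate build/query templates -> store under the new keys only.
        options["build_prompt_template"] = embedding_prompt_template
        options["query_prompt_template"] = query_prompt_template
    elif embedding_prompt_template:
        # Single template -> keep the legacy key for backward compatibility.
        options["prompt_template"] = embedding_prompt_template
    # No flags -> no template keys at all (clean .meta.json).
    return options


assert _templates_to_embedding_options("doc: ", "query: ") == {
    "build_prompt_template": "doc: ",
    "query_prompt_template": "query: ",
}
assert _templates_to_embedding_options("prompt: ", None) == {"prompt_template": "prompt: "}
assert _templates_to_embedding_options(None, None) == {}
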
class TestPromptTemplateFlowsToComputeEmbeddings:
|
||||||
|
"""Tests for template flowing through to compute_embeddings function."""
|
||||||
|
|
||||||
|
@patch("leann.api.compute_embeddings")
|
||||||
|
def test_prompt_template_flows_to_compute_embeddings_via_provider_options(
|
||||||
|
self, mock_compute_embeddings, tmp_path
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Verify that the prompt template flows from CLI args through LeannBuilder
|
||||||
|
to compute_embeddings() function via provider_options parameter.
|
||||||
|
|
||||||
|
This is an integration test that verifies the complete flow:
|
||||||
|
CLI → embedding_options → LeannBuilder → compute_embeddings(provider_options)
|
||||||
|
|
||||||
|
This test will fail because:
|
||||||
|
1. CLI doesn't capture the argument yet
|
||||||
|
2. embedding_options doesn't include prompt_template
|
||||||
|
3. LeannBuilder doesn't pass it through to compute_embeddings
|
||||||
|
"""
|
||||||
|
# Mock compute_embeddings to return dummy embeddings as numpy array
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
mock_compute_embeddings.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||||
|
|
||||||
|
# Use real LeannBuilder (not mocked) to test the actual flow
|
||||||
|
cli = LeannCLI()
|
||||||
|
|
||||||
|
# Mock load_documents to return a simple document
|
||||||
|
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||||
|
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
template = "search_document: "
|
||||||
|
args = parser.parse_args(
|
||||||
|
[
|
||||||
|
"build",
|
||||||
|
"test-index",
|
||||||
|
"--docs",
|
||||||
|
str(tmp_path),
|
||||||
|
"--embedding-prompt-template",
|
||||||
|
template,
|
||||||
|
"--backend-name",
|
||||||
|
"hnsw", # Use hnsw backend
|
||||||
|
"--force", # Force rebuild to ensure index is created
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# This should fail because the flow isn't implemented yet
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(cli.build_index(args))
|
||||||
|
|
||||||
|
# Verify compute_embeddings was called with provider_options containing prompt_template
|
||||||
|
assert mock_compute_embeddings.called, "compute_embeddings should have been called"
|
||||||
|
|
||||||
|
# Check the call arguments
|
||||||
|
call_kwargs = mock_compute_embeddings.call_args.kwargs
|
||||||
|
assert "provider_options" in call_kwargs, (
|
||||||
|
"compute_embeddings should receive provider_options parameter"
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_options = call_kwargs["provider_options"]
|
||||||
|
assert provider_options is not None, "provider_options should not be None"
|
||||||
|
assert "prompt_template" in provider_options, (
|
||||||
|
"provider_options should contain prompt_template key"
|
||||||
|
)
|
||||||
|
assert provider_options["prompt_template"] == template, (
|
||||||
|
f"Template should be '{template}', got {provider_options.get('prompt_template')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
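
# --- Illustrative sketch (not part of the original test file) ---------------
# The pass-through the integration test above checks, in miniature: whatever
# lands in embedding_options is handed to the embedding function unchanged as
# provider_options. `compute_fn` stands in for leann.api.compute_embeddings.
from typing import Callable

import numpy as np


def _embed_with_options(
    texts: list[str],
    model_name: str,
    embedding_options: dict | None,
    compute_fn: Callable[..., np.ndarray],
) -> np.ndarray:
    # No filtering or renaming: the CLI-provided options (including
    # prompt_template) travel through as provider_options.
    return compute_fn(texts, model_name, provider_options=embedding_options)


# Example with a dummy compute function:
_result = _embed_with_options(
    ["test content"],
    "text-embedding-3-small",
    {"prompt_template": "search_document: "},
    lambda texts, model, provider_options=None: np.zeros((len(texts), 3), dtype=np.float32),
)
assert _result.shape == (1, 3)
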
class TestPromptTemplateArgumentHelp:
|
||||||
|
"""Tests for argument help text and documentation."""
|
||||||
|
|
||||||
|
def test_build_command_prompt_template_has_help_text(self):
|
||||||
|
"""
|
||||||
|
Verify that --embedding-prompt-template has descriptive help text.
|
||||||
|
|
||||||
|
Good help text is crucial for CLI usability.
|
||||||
|
"""
|
||||||
|
cli = LeannCLI()
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
# argparse doesn't expose per-subcommand help directly, so capture the output
# of "build --help" and check it for the relevant keywords
|
||||||
|
import io
|
||||||
|
from contextlib import redirect_stdout
|
||||||
|
|
||||||
|
f = io.StringIO()
|
||||||
|
try:
|
||||||
|
with redirect_stdout(f):
|
||||||
|
parser.parse_args(["build", "--help"])
|
||||||
|
except SystemExit:
|
||||||
|
pass # --help causes sys.exit(0)
|
||||||
|
|
||||||
|
help_text = f.getvalue()
|
||||||
|
assert "--embedding-prompt-template" in help_text, (
|
||||||
|
"Help text should mention --embedding-prompt-template"
|
||||||
|
)
|
||||||
|
# Check for keywords that should be in the help
|
||||||
|
help_lower = help_text.lower()
|
||||||
|
assert any(keyword in help_lower for keyword in ["template", "prompt", "prepend"]), (
|
||||||
|
"Help text should explain what the prompt template does"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_search_command_prompt_template_has_help_text(self):
|
||||||
|
"""
|
||||||
|
Verify that search command also has help text for --embedding-prompt-template.
|
||||||
|
"""
|
||||||
|
cli = LeannCLI()
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
import io
|
||||||
|
from contextlib import redirect_stdout
|
||||||
|
|
||||||
|
f = io.StringIO()
|
||||||
|
try:
|
||||||
|
with redirect_stdout(f):
|
||||||
|
parser.parse_args(["search", "--help"])
|
||||||
|
except SystemExit:
|
||||||
|
pass # --help causes sys.exit(0)
|
||||||
|
|
||||||
|
help_text = f.getvalue()
|
||||||
|
assert "--embedding-prompt-template" in help_text, (
|
||||||
|
"Search help text should mention --embedding-prompt-template"
|
||||||
|
)
|
||||||
281
tests/test_embedding_prompt_template.py
Normal file
@@ -0,0 +1,281 @@
|
"""Unit tests for prompt template prepending in OpenAI embeddings.
|
||||||
|
|
||||||
|
This test suite defines the contract for prompt template functionality that allows
|
||||||
|
users to prepend a consistent prompt to all embedding inputs. These tests verify:
|
||||||
|
|
||||||
|
1. Template prepending to all input texts before embedding computation
|
||||||
|
2. Graceful handling of None/missing provider_options
|
||||||
|
3. Empty string template behavior (no-op)
|
||||||
|
4. Logging of template application for observability
|
||||||
|
5. Template application before token truncation
|
||||||
|
|
||||||
|
All tests are written in Red Phase - they should FAIL initially because the
|
||||||
|
implementation does not exist yet.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from leann.embedding_compute import compute_embeddings_openai
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptTemplatePrepending:
|
||||||
|
"""Tests for prompt template prepending in compute_embeddings_openai."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_openai_client(self):
|
||||||
|
"""Create mock OpenAI client that captures input texts."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
|
||||||
|
# Mock the embeddings.create response
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.data = [
|
||||||
|
Mock(embedding=[0.1, 0.2, 0.3]),
|
||||||
|
Mock(embedding=[0.4, 0.5, 0.6]),
|
||||||
|
]
|
||||||
|
mock_client.embeddings.create.return_value = mock_response
|
||||||
|
|
||||||
|
return mock_client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_openai_module(self, mock_openai_client, monkeypatch):
|
||||||
|
"""Mock the openai module to return our mock client."""
|
||||||
|
# Mock the API key environment variable
|
||||||
|
monkeypatch.setenv("OPENAI_API_KEY", "fake-test-key-for-mocking")
|
||||||
|
|
||||||
|
# openai is imported inside the function, so we need to patch it there
|
||||||
|
with patch("openai.OpenAI", return_value=mock_openai_client) as mock_openai:
|
||||||
|
yield mock_openai
|
||||||
|
|
||||||
|
def test_prompt_template_prepended_to_all_texts(self, mock_openai_module, mock_openai_client):
|
||||||
|
"""Verify template is prepended to all input texts.
|
||||||
|
|
||||||
|
When provider_options contains "prompt_template", that template should
|
||||||
|
be prepended to every text in the input list before sending to OpenAI API.
|
||||||
|
|
||||||
|
This is the core functionality: the template acts as a consistent prefix
|
||||||
|
that provides context or instruction for the embedding model.
|
||||||
|
"""
|
||||||
|
texts = ["First document", "Second document"]
|
||||||
|
template = "search_document: "
|
||||||
|
provider_options = {"prompt_template": template}
|
||||||
|
|
||||||
|
# Call compute_embeddings_openai with provider_options
|
||||||
|
result = compute_embeddings_openai(
|
||||||
|
texts=texts,
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify embeddings.create was called with templated texts
|
||||||
|
mock_openai_client.embeddings.create.assert_called_once()
|
||||||
|
call_args = mock_openai_client.embeddings.create.call_args
|
||||||
|
|
||||||
|
# Extract the input texts sent to API
|
||||||
|
sent_texts = call_args.kwargs["input"]
|
||||||
|
|
||||||
|
# Verify template was prepended to all texts
|
||||||
|
assert len(sent_texts) == 2, "Should send same number of texts"
|
||||||
|
assert sent_texts[0] == "search_document: First document", (
|
||||||
|
"Template should be prepended to first text"
|
||||||
|
)
|
||||||
|
assert sent_texts[1] == "search_document: Second document", (
|
||||||
|
"Template should be prepended to second text"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify result is valid embeddings array
|
||||||
|
assert isinstance(result, np.ndarray)
|
||||||
|
assert result.shape == (2, 3), "Should return correct shape"
|
||||||
|
|
||||||
|
def test_template_not_applied_when_missing_or_empty(
|
||||||
|
self, mock_openai_module, mock_openai_client
|
||||||
|
):
|
||||||
|
"""Verify template not applied when provider_options is None, missing key, or empty string.
|
||||||
|
|
||||||
|
This consolidated test covers three scenarios where templates should NOT be applied:
|
||||||
|
1. provider_options is None (default behavior)
|
||||||
|
2. provider_options exists but missing 'prompt_template' key
|
||||||
|
3. prompt_template is explicitly set to empty string ""
|
||||||
|
|
||||||
|
In all cases, texts should be sent to the API unchanged.
|
||||||
|
"""
|
||||||
|
# Scenario 1: None provider_options
|
||||||
|
texts = ["Original text one", "Original text two"]
|
||||||
|
result = compute_embeddings_openai(
|
||||||
|
texts=texts,
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
provider_options=None,
|
||||||
|
)
|
||||||
|
call_args = mock_openai_client.embeddings.create.call_args
|
||||||
|
sent_texts = call_args.kwargs["input"]
|
||||||
|
assert sent_texts[0] == "Original text one", (
|
||||||
|
"Text should be unchanged with None provider_options"
|
||||||
|
)
|
||||||
|
assert sent_texts[1] == "Original text two"
|
||||||
|
assert isinstance(result, np.ndarray)
|
||||||
|
assert result.shape == (2, 3)
|
||||||
|
|
||||||
|
# Reset mock for next scenario
|
||||||
|
mock_openai_client.reset_mock()
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.data = [
|
||||||
|
Mock(embedding=[0.1, 0.2, 0.3]),
|
||||||
|
Mock(embedding=[0.4, 0.5, 0.6]),
|
||||||
|
]
|
||||||
|
mock_openai_client.embeddings.create.return_value = mock_response
|
||||||
|
|
||||||
|
# Scenario 2: Missing 'prompt_template' key
|
||||||
|
texts = ["Text without template", "Another text"]
|
||||||
|
provider_options = {"base_url": "https://api.openai.com/v1"}
|
||||||
|
result = compute_embeddings_openai(
|
||||||
|
texts=texts,
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
call_args = mock_openai_client.embeddings.create.call_args
|
||||||
|
sent_texts = call_args.kwargs["input"]
|
||||||
|
assert sent_texts[0] == "Text without template", "Text should be unchanged with missing key"
|
||||||
|
assert sent_texts[1] == "Another text"
|
||||||
|
assert isinstance(result, np.ndarray)
|
||||||
|
|
||||||
|
# Reset mock for next scenario
|
||||||
|
mock_openai_client.reset_mock()
|
||||||
|
mock_openai_client.embeddings.create.return_value = mock_response
|
||||||
|
|
||||||
|
# Scenario 3: Empty string template
|
||||||
|
texts = ["Text one", "Text two"]
|
||||||
|
provider_options = {"prompt_template": ""}
|
||||||
|
result = compute_embeddings_openai(
|
||||||
|
texts=texts,
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
call_args = mock_openai_client.embeddings.create.call_args
|
||||||
|
sent_texts = call_args.kwargs["input"]
|
||||||
|
assert sent_texts[0] == "Text one", "Empty template should not modify text"
|
||||||
|
assert sent_texts[1] == "Text two"
|
||||||
|
assert isinstance(result, np.ndarray)
|
||||||
|
|
||||||
|
def test_prompt_template_with_multiple_batches(self, mock_openai_module, mock_openai_client):
|
||||||
|
"""Verify template is prepended in all batches when texts exceed batch size.
|
||||||
|
|
||||||
|
OpenAI API has batch size limits. When input texts are split into
|
||||||
|
multiple batches, the template should be prepended to texts in every batch.
|
||||||
|
|
||||||
|
This ensures consistency across all API calls.
|
||||||
|
"""
|
||||||
|
# Create many texts that will be split into multiple batches
|
||||||
|
texts = [f"Document {i}" for i in range(1000)]
|
||||||
|
template = "passage: "
|
||||||
|
provider_options = {"prompt_template": template}
|
||||||
|
|
||||||
|
# Mock multiple batch responses
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.data = [Mock(embedding=[0.1, 0.2, 0.3]) for _ in range(1000)]
|
||||||
|
mock_openai_client.embeddings.create.return_value = mock_response
|
||||||
|
|
||||||
|
result = compute_embeddings_openai(
|
||||||
|
texts=texts,
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify embeddings.create was called multiple times (batching)
|
||||||
|
assert mock_openai_client.embeddings.create.call_count >= 2, (
|
||||||
|
"Should make multiple API calls for large text list"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify template was prepended in ALL batches
|
||||||
|
for call in mock_openai_client.embeddings.create.call_args_list:
|
||||||
|
sent_texts = call.kwargs["input"]
|
||||||
|
for text in sent_texts:
|
||||||
|
assert text.startswith(template), (
|
||||||
|
f"All texts in all batches should start with template. Got: {text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify result shape
|
||||||
|
assert result.shape[0] == 1000, "Should return embeddings for all texts"
|
||||||
|
|
||||||
|
def test_prompt_template_with_special_characters(self, mock_openai_module, mock_openai_client):
|
||||||
|
"""Verify template with special characters is handled correctly.
|
||||||
|
|
||||||
|
Templates may contain special characters, Unicode, newlines, etc.
|
||||||
|
These should all be prepended correctly without encoding issues.
|
||||||
|
"""
|
||||||
|
texts = ["Document content"]
|
||||||
|
# Template with various special characters
|
||||||
|
template = "🔍 Search query [EN]: "
|
||||||
|
provider_options = {"prompt_template": template}
|
||||||
|
|
||||||
|
result = compute_embeddings_openai(
|
||||||
|
texts=texts,
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify special characters in template were preserved
|
||||||
|
call_args = mock_openai_client.embeddings.create.call_args
|
||||||
|
sent_texts = call_args.kwargs["input"]
|
||||||
|
|
||||||
|
assert sent_texts[0] == "🔍 Search query [EN]: Document content", (
|
||||||
|
"Special characters in template should be preserved"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, np.ndarray)
|
||||||
|
|
||||||
|
def test_prompt_template_integration_with_existing_validation(
|
||||||
|
self, mock_openai_module, mock_openai_client
|
||||||
|
):
|
||||||
|
"""Verify template works with existing input validation.
|
||||||
|
|
||||||
|
compute_embeddings_openai has validation for empty texts and whitespace.
|
||||||
|
Template prepending should happen AFTER validation, so validation errors
|
||||||
|
are thrown based on original texts, not templated texts.
|
||||||
|
|
||||||
|
This ensures users get clear error messages about their input.
|
||||||
|
"""
|
||||||
|
# Empty text should still raise ValueError even with template
|
||||||
|
texts = [""]
|
||||||
|
provider_options = {"prompt_template": "prefix: "}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="empty/invalid"):
|
||||||
|
compute_embeddings_openai(
|
||||||
|
texts=texts,
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_prompt_template_with_api_key_and_base_url(
|
||||||
|
self, mock_openai_module, mock_openai_client
|
||||||
|
):
|
||||||
|
"""Verify template works alongside other provider_options.
|
||||||
|
|
||||||
|
provider_options may contain multiple settings: prompt_template,
|
||||||
|
base_url, api_key. All should work together correctly.
|
||||||
|
"""
|
||||||
|
texts = ["Test document"]
|
||||||
|
provider_options = {
|
||||||
|
"prompt_template": "embed: ",
|
||||||
|
"base_url": "https://custom.api.com/v1",
|
||||||
|
"api_key": "test-key-123",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = compute_embeddings_openai(
|
||||||
|
texts=texts,
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify template was applied
|
||||||
|
call_args = mock_openai_client.embeddings.create.call_args
|
||||||
|
sent_texts = call_args.kwargs["input"]
|
||||||
|
assert sent_texts[0] == "embed: Test document"
|
||||||
|
|
||||||
|
# Verify OpenAI client was created with correct base_url
|
||||||
|
mock_openai_module.assert_called()
|
||||||
|
client_init_kwargs = mock_openai_module.call_args.kwargs
|
||||||
|
assert client_init_kwargs["base_url"] == "https://custom.api.com/v1"
|
||||||
|
assert client_init_kwargs["api_key"] == "test-key-123"
|
||||||
|
|
||||||
|
assert isinstance(result, np.ndarray)
|
||||||
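
# --- Illustrative sketch (not part of the original test file) ---------------
# The behaviour this file specifies for compute_embeddings_openai, reduced to
# the text-preparation step: validate the original inputs first, then prepend
# the template (when present and non-empty) to every text in every batch.
# The batch size of 100 is illustrative only.
def _prepare_openai_inputs(
    texts: list[str],
    provider_options: dict | None,
    batch_size: int = 100,
) -> list[list[str]]:
    if any(not t or not t.strip() for t in texts):
        # Validation runs on the original texts, before templating.
        raise ValueError("Input contains empty/invalid texts")

    template = (provider_options or {}).get("prompt_template") or ""
    prepared = [template + t for t in texts] if template else list(texts)

    # Split into batches; the template has already been applied to all of them.
    return [prepared[i : i + batch_size] for i in range(0, len(prepared), batch_size)]


_batches = _prepare_openai_inputs(
    ["First document", "Second document"], {"prompt_template": "search_document: "}
)
assert _batches[0][0] == "search_document: First document"
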
315
tests/test_lmstudio_bridge.py
Normal file
@@ -0,0 +1,315 @@
|
"""Unit tests for LM Studio TypeScript SDK bridge functionality.
|
||||||
|
|
||||||
|
This test suite defines the contract for the LM Studio SDK bridge that queries
|
||||||
|
model context length via Node.js subprocess. These tests verify:
|
||||||
|
|
||||||
|
1. Successful SDK query returns context length
|
||||||
|
2. Graceful fallback when Node.js not installed (FileNotFoundError)
|
||||||
|
3. Graceful fallback when SDK not installed (npm error)
|
||||||
|
4. Timeout handling (subprocess.TimeoutExpired)
|
||||||
|
5. Invalid JSON response handling
|
||||||
|
|
||||||
|
All tests are written in Red Phase - they should FAIL initially because the
|
||||||
|
`_query_lmstudio_context_limit` function does not exist yet.
|
||||||
|
|
||||||
|
The function contract:
|
||||||
|
- Inputs: model_name (str), base_url (str, WebSocket format "ws://localhost:1234")
|
||||||
|
- Outputs: context_length (int) or None on error
|
||||||
|
- Requirements:
|
||||||
|
1. Call Node.js with inline JavaScript using @lmstudio/sdk
|
||||||
|
2. 10-second timeout (accounts for Node.js startup)
|
||||||
|
3. Graceful fallback on any error (returns None, doesn't raise)
|
||||||
|
4. Parse JSON response with contextLength field
|
||||||
|
5. Log errors at debug level (not warning/error)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Try to import the function - if it doesn't exist, tests will fail as expected
|
||||||
|
try:
|
||||||
|
from leann.embedding_compute import _query_lmstudio_context_limit
|
||||||
|
except ImportError:
|
||||||
|
# Function doesn't exist yet (Red Phase) - create a placeholder that will fail
|
||||||
|
def _query_lmstudio_context_limit(*args, **kwargs):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"_query_lmstudio_context_limit not implemented yet - this is the Red Phase"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
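
# --- Illustrative sketch (not part of the original test file) ---------------
# A minimal implementation matching the contract above: run Node.js with an
# inline script that uses @lmstudio/sdk, parse {"contextLength": ...} from
# stdout, and fall back to None on any failure. The inline JavaScript and the
# exact SDK calls are assumptions; only the Python-side behaviour is what the
# tests below pin down.
import json
import logging
import subprocess

logger = logging.getLogger(__name__)

_BRIDGE_JS = """
const { LMStudioClient } = require("@lmstudio/sdk");
// ...query the model and print JSON like {"contextLength": 8192}...
"""


def _query_lmstudio_context_limit_sketch(model_name: str, base_url: str) -> int | None:
    try:
        result = subprocess.run(
            ["node", "-e", _BRIDGE_JS, model_name, base_url],
            capture_output=True,
            text=True,
            timeout=10,  # generous enough for Node.js startup
        )
        if result.returncode != 0:
            logger.debug("LM Studio bridge failed: %s", result.stderr)
            return None
        return int(json.loads(result.stdout)["contextLength"])
    except (FileNotFoundError, subprocess.TimeoutExpired, ValueError, KeyError, TypeError) as exc:
        # Node missing, slow LM Studio, or malformed JSON -> fall back quietly.
        logger.debug("LM Studio bridge unavailable: %s", exc)
        return None
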
class TestLMStudioBridge:
|
||||||
|
"""Tests for LM Studio TypeScript SDK bridge integration."""
|
||||||
|
|
||||||
|
def test_query_lmstudio_success(self, monkeypatch):
|
||||||
|
"""Verify successful SDK query returns context length.
|
||||||
|
|
||||||
|
When the Node.js subprocess successfully queries the LM Studio SDK,
|
||||||
|
it should return a JSON response with contextLength field. The function
|
||||||
|
should parse this and return the integer context length.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
# Verify timeout is set to 10 seconds
|
||||||
|
assert kwargs.get("timeout") == 10, "Should use 10-second timeout for Node.js startup"
|
||||||
|
|
||||||
|
# Verify capture_output and text=True are set
|
||||||
|
assert kwargs.get("capture_output") is True, "Should capture stdout/stderr"
|
||||||
|
assert kwargs.get("text") is True, "Should decode output as text"
|
||||||
|
|
||||||
|
# Return successful JSON response
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.returncode = 0
|
||||||
|
mock_result.stdout = '{"contextLength": 8192, "identifier": "custom-model"}'
|
||||||
|
mock_result.stderr = ""
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
# Test with typical LM Studio model
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="custom-model", base_url="ws://localhost:1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit == 8192, "Should return context length from SDK response"
|
||||||
|
|
||||||
|
def test_query_lmstudio_nodejs_not_found(self, monkeypatch):
|
||||||
|
"""Verify graceful fallback when Node.js not installed.
|
||||||
|
|
||||||
|
When Node.js is not installed, subprocess.run will raise FileNotFoundError.
|
||||||
|
The function should catch this and return None (graceful fallback to registry).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
raise FileNotFoundError("node: command not found")
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="custom-model", base_url="ws://localhost:1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit is None, "Should return None when Node.js not installed"
|
||||||
|
|
||||||
|
def test_query_lmstudio_sdk_not_installed(self, monkeypatch):
|
||||||
|
"""Verify graceful fallback when @lmstudio/sdk not installed.
|
||||||
|
|
||||||
|
When the SDK npm package is not installed, Node.js will return non-zero
|
||||||
|
exit code with error message in stderr. The function should detect this
|
||||||
|
and return None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.returncode = 1
|
||||||
|
mock_result.stdout = ""
|
||||||
|
mock_result.stderr = (
|
||||||
|
"Error: Cannot find module '@lmstudio/sdk'\nRequire stack:\n- /path/to/script.js"
|
||||||
|
)
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="custom-model", base_url="ws://localhost:1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit is None, "Should return None when SDK not installed"
|
||||||
|
|
||||||
|
def test_query_lmstudio_timeout(self, monkeypatch):
|
||||||
|
"""Verify graceful fallback when subprocess times out.
|
||||||
|
|
||||||
|
When the Node.js process takes longer than 10 seconds (e.g., LM Studio
|
||||||
|
not responding), subprocess.TimeoutExpired should be raised. The function
|
||||||
|
should catch this and return None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
raise subprocess.TimeoutExpired(cmd=["node", "lmstudio_bridge.js"], timeout=10)
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="custom-model", base_url="ws://localhost:1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit is None, "Should return None on timeout"
|
||||||
|
|
||||||
|
def test_query_lmstudio_invalid_json(self, monkeypatch):
|
||||||
|
"""Verify graceful fallback when response is invalid JSON.
|
||||||
|
|
||||||
|
When the subprocess returns malformed JSON (e.g., due to SDK error),
|
||||||
|
json.loads will raise ValueError/JSONDecodeError. The function should
|
||||||
|
catch this and return None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.returncode = 0
|
||||||
|
mock_result.stdout = "This is not valid JSON"
|
||||||
|
mock_result.stderr = ""
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="custom-model", base_url="ws://localhost:1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit is None, "Should return None when JSON parsing fails"
|
||||||
|
|
||||||
|
def test_query_lmstudio_missing_context_length_field(self, monkeypatch):
|
||||||
|
"""Verify graceful fallback when JSON lacks contextLength field.
|
||||||
|
|
||||||
|
When the SDK returns valid JSON but without the expected contextLength
|
||||||
|
field (e.g., error response), the function should return None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.returncode = 0
|
||||||
|
mock_result.stdout = '{"identifier": "test-model", "error": "Model not found"}'
|
||||||
|
mock_result.stderr = ""
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="nonexistent-model", base_url="ws://localhost:1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit is None, "Should return None when contextLength field missing"
|
||||||
|
|
||||||
|
def test_query_lmstudio_null_context_length(self, monkeypatch):
|
||||||
|
"""Verify graceful fallback when contextLength is null.
|
||||||
|
|
||||||
|
When the SDK returns contextLength: null (model couldn't be loaded),
|
||||||
|
the function should return None for registry fallback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.returncode = 0
|
||||||
|
mock_result.stdout = '{"contextLength": null, "identifier": "test-model"}'
|
||||||
|
mock_result.stderr = ""
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="test-model", base_url="ws://localhost:1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit is None, "Should return None when contextLength is null"
|
||||||
|
|
||||||
|
def test_query_lmstudio_zero_context_length(self, monkeypatch):
|
||||||
|
"""Verify graceful fallback when contextLength is zero.
|
||||||
|
|
||||||
|
When the SDK returns contextLength: 0 (invalid value), the function
|
||||||
|
should return None to trigger registry fallback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.returncode = 0
|
||||||
|
mock_result.stdout = '{"contextLength": 0, "identifier": "test-model"}'
|
||||||
|
mock_result.stderr = ""
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="test-model", base_url="ws://localhost:1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit is None, "Should return None when contextLength is zero"
|
||||||
|
|
||||||
|
def test_query_lmstudio_with_custom_port(self, monkeypatch):
|
||||||
|
"""Verify SDK query works with non-default WebSocket port.
|
||||||
|
|
||||||
|
LM Studio can run on custom ports. The function should pass the
|
||||||
|
provided base_url to the Node.js subprocess.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
# Verify the base_url argument is passed correctly
|
||||||
|
command = args[0] if args else kwargs.get("args", [])
|
||||||
|
assert "ws://localhost:8080" in " ".join(command), (
|
||||||
|
"Should pass custom port to subprocess"
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.returncode = 0
|
||||||
|
mock_result.stdout = '{"contextLength": 4096, "identifier": "custom-model"}'
|
||||||
|
mock_result.stderr = ""
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="custom-model", base_url="ws://localhost:8080"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit == 4096, "Should work with custom WebSocket port"
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"context_length,expected",
|
||||||
|
[
|
||||||
|
(512, 512), # Small context
|
||||||
|
(2048, 2048), # Common context
|
||||||
|
(8192, 8192), # Large context
|
||||||
|
(32768, 32768), # Very large context
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_query_lmstudio_various_context_lengths(self, monkeypatch, context_length, expected):
|
||||||
|
"""Verify SDK query handles various context length values.
|
||||||
|
|
||||||
|
Different models have different context lengths. The function should
|
||||||
|
correctly parse and return any positive integer value.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.returncode = 0
|
||||||
|
mock_result.stdout = f'{{"contextLength": {context_length}, "identifier": "test"}}'
|
||||||
|
mock_result.stderr = ""
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="test-model", base_url="ws://localhost:1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit == expected, f"Should return {expected} for context length {context_length}"
|
||||||
|
|
||||||
|
def test_query_lmstudio_logs_at_debug_level(self, monkeypatch, caplog):
|
||||||
|
"""Verify errors are logged at DEBUG level, not WARNING/ERROR.
|
||||||
|
|
||||||
|
Following the graceful fallback pattern from Ollama implementation,
|
||||||
|
errors should be logged at debug level to avoid alarming users when
|
||||||
|
fallback to registry works fine.
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
caplog.set_level(logging.DEBUG, logger="leann.embedding_compute")
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
raise FileNotFoundError("node: command not found")
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
_query_lmstudio_context_limit(model_name="test-model", base_url="ws://localhost:1234")
|
||||||
|
|
||||||
|
# Check that debug logging occurred (not warning/error)
|
||||||
|
debug_logs = [record for record in caplog.records if record.levelname == "DEBUG"]
|
||||||
|
assert len(debug_logs) > 0, "Should log error at DEBUG level"
|
||||||
|
|
||||||
|
# Verify no WARNING or ERROR logs
|
||||||
|
warning_or_error_logs = [
|
||||||
|
record for record in caplog.records if record.levelname in ["WARNING", "ERROR"]
|
||||||
|
]
|
||||||
|
assert len(warning_or_error_logs) == 0, (
|
||||||
|
"Should not log at WARNING/ERROR level for expected failures"
|
||||||
|
)
|
||||||
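For orientation, here is a minimal sketch of a bridge that would satisfy the contract these tests pin down. It is an assumption, not the PR's implementation: the inline JavaScript, the `@lmstudio/sdk` call shape (`LMStudioClient`, `client.embedding.model`, `getContextLength`), and the argument layout are illustrative only.

```python
# Hypothetical sketch, assuming the SDK call shape described above.
import json
import logging
import subprocess

logger = logging.getLogger("leann.embedding_compute")

# Inline script executed via `node -e`; extra CLI args land in process.argv[1..].
_BRIDGE_JS = """
const { LMStudioClient } = require("@lmstudio/sdk");
(async () => {
  const client = new LMStudioClient({ baseUrl: process.argv[1] });
  const model = await client.embedding.model(process.argv[2]);
  const contextLength = await model.getContextLength();
  console.log(JSON.stringify({ contextLength, identifier: process.argv[2] }));
})().catch(() => console.log(JSON.stringify({ contextLength: null })));
"""


def _query_lmstudio_context_limit(model_name: str, base_url: str) -> int | None:
    """Return the model's context length via the LM Studio SDK, or None on any error."""
    try:
        result = subprocess.run(
            ["node", "-e", _BRIDGE_JS, base_url, model_name],
            capture_output=True,
            text=True,
            timeout=10,  # allow for Node.js startup
        )
        if result.returncode != 0:
            logger.debug("LM Studio bridge failed: %s", result.stderr)
            return None
        context_length = json.loads(result.stdout).get("contextLength")
        if isinstance(context_length, int) and context_length > 0:
            return context_length
        return None
    except (FileNotFoundError, subprocess.TimeoutExpired, ValueError) as exc:
        logger.debug("LM Studio bridge unavailable: %s", exc)
        return None
```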
400
tests/test_prompt_template_e2e.py
Normal file
@@ -0,0 +1,400 @@
"""End-to-end integration tests for prompt template and token limit features.

These tests verify real-world functionality with live services:
- OpenAI-compatible APIs (OpenAI, LM Studio) with prompt template support
- Ollama with dynamic token limit detection
- Hybrid token limit discovery mechanism

Run with: pytest tests/test_prompt_template_e2e.py -v -s
Skip if services unavailable: pytest tests/test_prompt_template_e2e.py -m "not integration"

Prerequisites:
1. LM Studio running with embedding model: http://localhost:1234
2. [Optional] Ollama running: ollama serve
3. [Optional] Ollama model: ollama pull nomic-embed-text
4. [Optional] Node.js + @lmstudio/sdk for context length detection
"""

import logging
import socket

import numpy as np
import pytest
import requests
from leann.embedding_compute import (
    compute_embeddings_ollama,
    compute_embeddings_openai,
    get_model_token_limit,
)

# Test markers for conditional execution
pytestmark = pytest.mark.integration

logger = logging.getLogger(__name__)


def check_service_available(host: str, port: int, timeout: float = 2.0) -> bool:
    """Check if a service is available on the given host:port."""
    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(timeout)
        result = sock.connect_ex((host, port))
        sock.close()
        return result == 0
    except Exception:
        return False


def check_ollama_available() -> bool:
    """Check if Ollama service is available."""
    if not check_service_available("localhost", 11434):
        return False
    try:
        response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
        return response.status_code == 200
    except Exception:
        return False


def check_lmstudio_available() -> bool:
    """Check if LM Studio service is available."""
    if not check_service_available("localhost", 1234):
        return False
    try:
        response = requests.get("http://localhost:1234/v1/models", timeout=2.0)
        return response.status_code == 200
    except Exception:
        return False


def get_lmstudio_first_model() -> str:
    """Get the first available model from LM Studio."""
    try:
        response = requests.get("http://localhost:1234/v1/models", timeout=5.0)
        data = response.json()
        models = data.get("data", [])
        if models:
            return models[0]["id"]
    except Exception:
        pass
    return None


class TestPromptTemplateOpenAI:
    """End-to-end tests for prompt template with OpenAI-compatible APIs (LM Studio)."""

    @pytest.mark.skipif(
        not check_lmstudio_available(), reason="LM Studio service not available on localhost:1234"
    )
    def test_lmstudio_embedding_with_prompt_template(self):
        """Test prompt templates with LM Studio using OpenAI-compatible API."""
        model_name = get_lmstudio_first_model()
        if not model_name:
            pytest.skip("No models loaded in LM Studio")

        texts = ["artificial intelligence", "machine learning"]
        prompt_template = "search_query: "

        # Get embeddings with prompt template via provider_options
        provider_options = {"prompt_template": prompt_template}
        embeddings = compute_embeddings_openai(
            texts=texts,
            model_name=model_name,
            base_url="http://localhost:1234/v1",
            api_key="lm-studio",  # LM Studio doesn't require real key
            provider_options=provider_options,
        )

        assert embeddings is not None
        assert len(embeddings) == 2
        assert all(isinstance(emb, np.ndarray) for emb in embeddings)
        assert all(len(emb) > 0 for emb in embeddings)

        logger.info(
            f"✓ LM Studio embeddings with prompt template: {len(embeddings)} vectors, {len(embeddings[0])} dimensions"
        )

    @pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
    def test_lmstudio_prompt_template_affects_embeddings(self):
        """Verify that prompt templates actually change embedding values."""
        model_name = get_lmstudio_first_model()
        if not model_name:
            pytest.skip("No models loaded in LM Studio")

        text = "machine learning"
        base_url = "http://localhost:1234/v1"
        api_key = "lm-studio"

        # Get embeddings without template
        embeddings_no_template = compute_embeddings_openai(
            texts=[text],
            model_name=model_name,
            base_url=base_url,
            api_key=api_key,
            provider_options={},
        )

        # Get embeddings with template
        embeddings_with_template = compute_embeddings_openai(
            texts=[text],
            model_name=model_name,
            base_url=base_url,
            api_key=api_key,
            provider_options={"prompt_template": "search_query: "},
        )

        # Embeddings should be different when template is applied
        assert not np.allclose(embeddings_no_template[0], embeddings_with_template[0])

        logger.info("✓ Prompt template changes embedding values as expected")


class TestPromptTemplateOllama:
    """End-to-end tests for prompt template with Ollama."""

    @pytest.mark.skipif(
        not check_ollama_available(), reason="Ollama service not available on localhost:11434"
    )
    def test_ollama_embedding_with_prompt_template(self):
        """Test prompt templates with Ollama using any available embedding model."""
        # Get any available embedding model
        try:
            response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
            models = response.json().get("models", [])

            embedding_models = []
            for model in models:
                name = model["name"]
                base_name = name.split(":")[0]
                if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
                    embedding_models.append(name)

            if not embedding_models:
                pytest.skip("No embedding models available in Ollama")

            model_name = embedding_models[0]

            texts = ["artificial intelligence", "machine learning"]
            prompt_template = "search_query: "

            # Get embeddings with prompt template via provider_options
            provider_options = {"prompt_template": prompt_template}
            embeddings = compute_embeddings_ollama(
                texts=texts,
                model_name=model_name,
                is_build=False,
                host="http://localhost:11434",
                provider_options=provider_options,
            )

            assert embeddings is not None
            assert len(embeddings) == 2
            assert all(isinstance(emb, np.ndarray) for emb in embeddings)
            assert all(len(emb) > 0 for emb in embeddings)

            logger.info(
                f"✓ Ollama embeddings with prompt template: {len(embeddings)} vectors, {len(embeddings[0])} dimensions"
            )

        except Exception as e:
            pytest.skip(f"Could not test Ollama prompt template: {e}")

    @pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
    def test_ollama_prompt_template_affects_embeddings(self):
        """Verify that prompt templates actually change embedding values with Ollama."""
        # Get any available embedding model
        try:
            response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
            models = response.json().get("models", [])

            embedding_models = []
            for model in models:
                name = model["name"]
                base_name = name.split(":")[0]
                if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
                    embedding_models.append(name)

            if not embedding_models:
                pytest.skip("No embedding models available in Ollama")

            model_name = embedding_models[0]
            text = "machine learning"
            host = "http://localhost:11434"

            # Get embeddings without template
            embeddings_no_template = compute_embeddings_ollama(
                texts=[text], model_name=model_name, is_build=False, host=host, provider_options={}
            )

            # Get embeddings with template
            embeddings_with_template = compute_embeddings_ollama(
                texts=[text],
                model_name=model_name,
                is_build=False,
                host=host,
                provider_options={"prompt_template": "search_query: "},
            )

            # Embeddings should be different when template is applied
            assert not np.allclose(embeddings_no_template[0], embeddings_with_template[0])

            logger.info("✓ Ollama prompt template changes embedding values as expected")

        except Exception as e:
            pytest.skip(f"Could not test Ollama prompt template: {e}")


class TestLMStudioSDK:
    """End-to-end tests for LM Studio SDK integration."""

    @pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
    def test_lmstudio_model_listing(self):
        """Test that we can list models from LM Studio."""
        try:
            response = requests.get("http://localhost:1234/v1/models", timeout=5.0)
            assert response.status_code == 200

            data = response.json()
            assert "data" in data

            models = data["data"]
            logger.info(f"✓ LM Studio models available: {len(models)}")

            if models:
                logger.info(f"  First model: {models[0].get('id', 'unknown')}")
        except Exception as e:
            pytest.skip(f"LM Studio API error: {e}")

    @pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
    def test_lmstudio_sdk_context_length_detection(self):
        """Test context length detection via LM Studio SDK bridge (requires Node.js + SDK)."""
        model_name = get_lmstudio_first_model()
        if not model_name:
            pytest.skip("No models loaded in LM Studio")

        try:
            from leann.embedding_compute import _query_lmstudio_context_limit

            # SDK requires WebSocket URL (ws://)
            context_length = _query_lmstudio_context_limit(
                model_name=model_name, base_url="ws://localhost:1234"
            )

            if context_length is None:
                logger.warning(
                    "⚠ LM Studio SDK bridge returned None (Node.js or SDK may not be available)"
                )
                pytest.skip("Node.js or @lmstudio/sdk not available - SDK bridge unavailable")
            else:
                assert context_length > 0
                logger.info(
                    f"✓ LM Studio context length detected via SDK: {context_length} for {model_name}"
                )

        except ImportError:
            pytest.skip("_query_lmstudio_context_limit not implemented yet")
        except Exception as e:
            logger.error(f"LM Studio SDK test error: {e}")
            raise


class TestOllamaTokenLimit:
    """End-to-end tests for Ollama token limit discovery."""

    @pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
    def test_ollama_token_limit_detection(self):
        """Test dynamic token limit detection from Ollama /api/show endpoint."""
        # Get any available embedding model
        try:
            response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
            models = response.json().get("models", [])

            embedding_models = []
            for model in models:
                name = model["name"]
                base_name = name.split(":")[0]
                if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
                    embedding_models.append(name)

            if not embedding_models:
                pytest.skip("No embedding models available in Ollama")

            test_model = embedding_models[0]

            # Test token limit detection
            limit = get_model_token_limit(model_name=test_model, base_url="http://localhost:11434")

            assert limit > 0
            logger.info(f"✓ Ollama token limit detected: {limit} for {test_model}")

        except Exception as e:
            pytest.skip(f"Could not test Ollama token detection: {e}")


class TestHybridTokenLimit:
    """End-to-end tests for hybrid token limit discovery mechanism."""

    def test_hybrid_discovery_registry_fallback(self):
        """Test fallback to static registry for known OpenAI models."""
        # Use a known OpenAI model (should be in registry)
        limit = get_model_token_limit(
            model_name="text-embedding-3-small",
            base_url="http://fake-server:9999",  # Fake URL to force registry lookup
        )

        # text-embedding-3-small should have 8192 in registry
        assert limit == 8192
        logger.info(f"✓ Hybrid discovery (registry fallback): {limit} tokens")

    def test_hybrid_discovery_default_fallback(self):
        """Test fallback to safe default for completely unknown models."""
        limit = get_model_token_limit(
            model_name="completely-unknown-model-xyz-12345",
            base_url="http://fake-server:9999",
            default=512,
        )

        # Should get the specified default
        assert limit == 512
        logger.info(f"✓ Hybrid discovery (default fallback): {limit} tokens")

    @pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
    def test_hybrid_discovery_ollama_dynamic_first(self):
        """Test that Ollama models use dynamic discovery first."""
        # Get any available embedding model
        try:
            response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
            models = response.json().get("models", [])

            embedding_models = []
            for model in models:
                name = model["name"]
                base_name = name.split(":")[0]
                if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
                    embedding_models.append(name)

            if not embedding_models:
                pytest.skip("No embedding models available in Ollama")

            test_model = embedding_models[0]

            # Should query Ollama /api/show dynamically
            limit = get_model_token_limit(model_name=test_model, base_url="http://localhost:11434")

            assert limit > 0
            logger.info(f"✓ Hybrid discovery (Ollama dynamic): {limit} tokens for {test_model}")

        except Exception as e:
            pytest.skip(f"Could not test hybrid Ollama discovery: {e}")


if __name__ == "__main__":
    print("\n" + "=" * 70)
    print("INTEGRATION TEST SUITE - Real Service Testing")
    print("=" * 70)
    print("\nThese tests require live services:")
    print("  • LM Studio: http://localhost:1234 (with embedding model loaded)")
    print("  • [Optional] Ollama: http://localhost:11434")
    print("  • [Optional] Node.js + @lmstudio/sdk for SDK bridge tests")
    print("\nRun with: pytest tests/test_prompt_template_e2e.py -v -s")
    print("=" * 70 + "\n")
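A minimal sketch of the hybrid discovery order these tests exercise (dynamic Ollama lookup, then static registry, then a safe default). This is not the PR's implementation: the registry contents and the parsing of the Ollama `/api/show` response are assumptions for illustration only.

```python
# Hypothetical sketch, assuming Ollama's /api/show returns a "model_info" mapping
# whose "<architecture>.context_length" entry holds the context window.
import requests

_TOKEN_LIMIT_REGISTRY = {"text-embedding-3-small": 8192}  # assumed subset of known models


def get_model_token_limit(model_name: str, base_url: str, default: int = 512) -> int:
    # 1. Dynamic discovery: ask the (possibly local) Ollama server about the model.
    try:
        resp = requests.post(f"{base_url}/api/show", json={"name": model_name}, timeout=2.0)
        if resp.ok:
            for key, value in resp.json().get("model_info", {}).items():
                if key.endswith("context_length") and isinstance(value, int) and value > 0:
                    return value
    except requests.RequestException:
        pass  # server unreachable: fall through to the registry
    # 2. Static registry for known hosted models.
    if model_name in _TOKEN_LIMIT_REGISTRY:
        return _TOKEN_LIMIT_REGISTRY[model_name]
    # 3. Safe default for completely unknown models.
    return default
```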
808
tests/test_prompt_template_persistence.py
Normal file
@@ -0,0 +1,808 @@
"""
Integration tests for prompt template metadata persistence and reuse.

These tests verify the complete lifecycle of prompt template persistence:
1. Template is saved to .meta.json during index build
2. Template is automatically loaded during search operations
3. Template can be overridden with explicit flag during search
4. Template is reused during chat/ask operations

These are integration tests that:
- Use real file system with temporary directories
- Run actual build and search operations
- Inspect .meta.json file contents directly
- Mock embedding servers to avoid external dependencies
- Use small test codebases for fast execution

Expected to FAIL in Red Phase because metadata persistence verification is not yet implemented.
"""

import json
import tempfile
from pathlib import Path
from unittest.mock import Mock, patch

import numpy as np
import pytest
from leann.api import LeannBuilder, LeannSearcher


class TestPromptTemplateMetadataPersistence:
    """Tests for prompt template storage in .meta.json during build."""

    @pytest.fixture
    def temp_index_dir(self):
        """Create temporary directory for test indexes."""
        with tempfile.TemporaryDirectory() as tmpdir:
            yield Path(tmpdir)

    @pytest.fixture
    def mock_embeddings(self):
        """Mock compute_embeddings to return dummy embeddings."""
        with patch("leann.api.compute_embeddings") as mock_compute:
            # Return dummy embeddings as numpy array
            mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
            yield mock_compute

    def test_prompt_template_saved_to_metadata(self, temp_index_dir, mock_embeddings):
        """
        Verify that when build is run with embedding_options containing prompt_template,
        the template value is saved to .meta.json file.

        This is the core persistence requirement - templates must be saved to allow
        reuse in subsequent search operations without re-specifying the flag.

        Expected failure: .meta.json exists but doesn't contain embedding_options
        with prompt_template, or the value is not persisted correctly.
        """
        # Setup test data
        index_path = temp_index_dir / "test_index.leann"
        template = "search_document: "

        # Build index with prompt template in embedding_options
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="text-embedding-3-small",
            embedding_mode="openai",
            embedding_options={"prompt_template": template},
        )

        # Add a simple document
        builder.add_text("This is a test document for indexing")

        # Build the index
        builder.build_index(str(index_path))

        # Verify .meta.json was created and contains the template
        meta_path = temp_index_dir / "test_index.leann.meta.json"
        assert meta_path.exists(), ".meta.json file should be created during build"

        # Read and parse metadata
        with open(meta_path, encoding="utf-8") as f:
            meta_data = json.load(f)

        # Verify embedding_options exists in metadata
        assert "embedding_options" in meta_data, (
            "embedding_options should be saved to .meta.json when provided"
        )

        # Verify prompt_template is in embedding_options
        embedding_options = meta_data["embedding_options"]
        assert "prompt_template" in embedding_options, (
            "prompt_template should be saved within embedding_options"
        )

        # Verify the template value matches what we provided
        assert embedding_options["prompt_template"] == template, (
            f"Template should be '{template}', got '{embedding_options.get('prompt_template')}'"
        )

    def test_prompt_template_absent_when_not_provided(self, temp_index_dir, mock_embeddings):
        """
        Verify that when no prompt template is provided during build,
        .meta.json either doesn't have embedding_options or prompt_template key.

        This ensures clean metadata without unnecessary keys when features aren't used.

        Expected behavior: Build succeeds, .meta.json doesn't contain prompt_template.
        """
        index_path = temp_index_dir / "test_no_template.leann"

        # Build index WITHOUT prompt template
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="text-embedding-3-small",
            embedding_mode="openai",
            # No embedding_options provided
        )

        builder.add_text("Document without template")
        builder.build_index(str(index_path))

        # Verify metadata
        meta_path = temp_index_dir / "test_no_template.leann.meta.json"
        assert meta_path.exists()

        with open(meta_path, encoding="utf-8") as f:
            meta_data = json.load(f)

        # If embedding_options exists, it should not contain prompt_template
        if "embedding_options" in meta_data:
            embedding_options = meta_data["embedding_options"]
            assert "prompt_template" not in embedding_options, (
                "prompt_template should not be in metadata when not provided"
            )


class TestPromptTemplateAutoLoadOnSearch:
    """Tests for automatic loading of prompt template during search operations.

    NOTE: Over-mocked test removed (test_prompt_template_auto_loaded_on_search).
    This functionality is now comprehensively tested by TestQueryPromptTemplateAutoLoad
    which uses simpler mocking and doesn't hang.
    """

    @pytest.fixture
    def temp_index_dir(self):
        """Create temporary directory for test indexes."""
        with tempfile.TemporaryDirectory() as tmpdir:
            yield Path(tmpdir)

    @pytest.fixture
    def mock_embeddings(self):
        """Mock compute_embeddings to capture calls and return dummy embeddings."""
        with patch("leann.api.compute_embeddings") as mock_compute:
            mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
            yield mock_compute

    def test_search_without_template_in_metadata(self, temp_index_dir, mock_embeddings):
        """
        Verify that searching an index built WITHOUT a prompt template
        works correctly (backward compatibility).

        The searcher should handle missing prompt_template gracefully.

        Expected behavior: Search succeeds, no template is used.
        """
        # Build index without template
        index_path = temp_index_dir / "no_template.leann"

        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="text-embedding-3-small",
            embedding_mode="openai",
        )
        builder.add_text("Document without template")
        builder.build_index(str(index_path))

        # Reset mocks
        mock_embeddings.reset_mock()

        # Create searcher and search
        searcher = LeannSearcher(index_path=str(index_path))

        # Verify no template in embedding_options
        assert "prompt_template" not in searcher.embedding_options, (
            "Searcher should not have prompt_template when not in metadata"
        )


class TestQueryPromptTemplateAutoLoad:
    """Tests for automatic loading of separate query_prompt_template during search (R2).

    These tests verify the new two-template system where:
    - build_prompt_template: Applied during index building
    - query_prompt_template: Applied during search operations

    Expected to FAIL in Red Phase (R2) because query template extraction
    and application is not yet implemented in LeannSearcher.search().
    """

    @pytest.fixture
    def temp_index_dir(self):
        """Create temporary directory for test indexes."""
        with tempfile.TemporaryDirectory() as tmpdir:
            yield Path(tmpdir)

    @pytest.fixture
    def mock_compute_embeddings(self):
        """Mock compute_embeddings to capture calls and return dummy embeddings."""
        with patch("leann.embedding_compute.compute_embeddings") as mock_compute:
            mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
            yield mock_compute

    def test_search_auto_loads_query_template(self, temp_index_dir, mock_compute_embeddings):
        """
        Verify that search() automatically loads and applies query_prompt_template from .meta.json.

        Given: Index built with separate build_prompt_template and query_prompt_template
        When: LeannSearcher.search("my query") is called
        Then: Query embedding is computed with "query: my query" (query template applied)

        This is the core R2 requirement - query templates must be auto-loaded and applied
        during search without user intervention.

        Expected failure: compute_embeddings called with raw "my query" instead of
        "query: my query" because query template extraction is not implemented.
        """
        # Setup: Build index with separate templates in new format
        index_path = temp_index_dir / "query_template.leann"

        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="text-embedding-3-small",
            embedding_mode="openai",
            embedding_options={
                "build_prompt_template": "doc: ",
                "query_prompt_template": "query: ",
            },
        )
        builder.add_text("Test document")
        builder.build_index(str(index_path))

        # Reset mock to ignore build calls
        mock_compute_embeddings.reset_mock()

        # Act: Search with query
        searcher = LeannSearcher(index_path=str(index_path))

        # Mock the backend search to avoid actual search
        with patch.object(searcher.backend_impl, "search") as mock_backend_search:
            mock_backend_search.return_value = {
                "labels": [["test_id_0"]],  # IDs (nested list for batch support)
                "distances": [[0.9]],  # Distances (nested list for batch support)
            }

            searcher.search("my query", top_k=1, recompute_embeddings=False)

        # Assert: compute_embeddings was called with query template applied
        assert mock_compute_embeddings.called, "compute_embeddings should be called during search"

        # Get the actual text passed to compute_embeddings
        call_args = mock_compute_embeddings.call_args
        texts_arg = call_args[0][0]  # First positional arg (list of texts)

        assert len(texts_arg) == 1, "Should compute embedding for one query"
        assert texts_arg[0] == "query: my query", (
            f"Query template should be applied: expected 'query: my query', got '{texts_arg[0]}'"
        )

    def test_search_backward_compat_single_template(self, temp_index_dir, mock_compute_embeddings):
        """
        Verify backward compatibility with old single prompt_template format.

        Given: Index with old format (single prompt_template, no query_prompt_template)
        When: LeannSearcher.search("my query") is called
        Then: Query embedding computed with "doc: my query" (old template applied)

        This ensures indexes built with the old single-template system continue
        to work correctly with the new search implementation.

        Expected failure: Old template not recognized/applied because backward
        compatibility logic is not implemented.
        """
        # Setup: Build index with old single-template format
        index_path = temp_index_dir / "old_template.leann"

        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="text-embedding-3-small",
            embedding_mode="openai",
            embedding_options={"prompt_template": "doc: "},  # Old format
        )
        builder.add_text("Test document")
        builder.build_index(str(index_path))

        # Reset mock
        mock_compute_embeddings.reset_mock()

        # Act: Search
        searcher = LeannSearcher(index_path=str(index_path))

        with patch.object(searcher.backend_impl, "search") as mock_backend_search:
            mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}

            searcher.search("my query", top_k=1, recompute_embeddings=False)

        # Assert: Old template was applied
        call_args = mock_compute_embeddings.call_args
        texts_arg = call_args[0][0]

        assert texts_arg[0] == "doc: my query", (
            f"Old prompt_template should be applied for backward compatibility: "
            f"expected 'doc: my query', got '{texts_arg[0]}'"
        )

    def test_search_backward_compat_no_template(self, temp_index_dir, mock_compute_embeddings):
        """
        Verify backward compatibility when no template is present in .meta.json.

        Given: Index with no template in .meta.json (very old indexes)
        When: LeannSearcher.search("my query") is called
        Then: Query embedding computed with "my query" (no template, raw query)

        This ensures the most basic backward compatibility - indexes without
        any template support continue to work as before.

        Expected failure: May fail if default template is incorrectly applied,
        or if missing template causes error.
        """
        # Setup: Build index without any template
        index_path = temp_index_dir / "no_template.leann"

        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="text-embedding-3-small",
            embedding_mode="openai",
            # No embedding_options at all
        )
        builder.add_text("Test document")
        builder.build_index(str(index_path))

        # Reset mock
        mock_compute_embeddings.reset_mock()

        # Act: Search
        searcher = LeannSearcher(index_path=str(index_path))

        with patch.object(searcher.backend_impl, "search") as mock_backend_search:
            mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}

            searcher.search("my query", top_k=1, recompute_embeddings=False)

        # Assert: No template applied (raw query)
        call_args = mock_compute_embeddings.call_args
        texts_arg = call_args[0][0]

        assert texts_arg[0] == "my query", (
            f"No template should be applied when missing from metadata: "
            f"expected 'my query', got '{texts_arg[0]}'"
        )

    def test_search_override_via_provider_options(self, temp_index_dir, mock_compute_embeddings):
        """
        Verify that explicit provider_options can override stored query template.

        Given: Index with query_prompt_template: "query: "
        When: search() called with provider_options={"prompt_template": "override: "}
        Then: Query embedding computed with "override: test" (override takes precedence)

        This enables users to experiment with different query templates without
        rebuilding the index, or to handle special query types differently.

        Expected failure: provider_options parameter is accepted via **kwargs but
        not used. Query embedding computed with raw "test" instead of "override: test"
        because override logic is not implemented.
        """
        # Setup: Build index with query template
        index_path = temp_index_dir / "override_template.leann"

        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="text-embedding-3-small",
            embedding_mode="openai",
            embedding_options={
                "build_prompt_template": "doc: ",
                "query_prompt_template": "query: ",
            },
        )
        builder.add_text("Test document")
        builder.build_index(str(index_path))

        # Reset mock
        mock_compute_embeddings.reset_mock()

        # Act: Search with override
        searcher = LeannSearcher(index_path=str(index_path))

        with patch.object(searcher.backend_impl, "search") as mock_backend_search:
            mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}

            # This should accept provider_options parameter
            searcher.search(
                "test",
                top_k=1,
                recompute_embeddings=False,
                provider_options={"prompt_template": "override: "},
            )

        # Assert: Override template was applied
        call_args = mock_compute_embeddings.call_args
        texts_arg = call_args[0][0]

        assert texts_arg[0] == "override: test", (
            f"Override template should take precedence: "
            f"expected 'override: test', got '{texts_arg[0]}'"
        )


class TestPromptTemplateReuseInChat:
    """Tests for prompt template reuse in chat/ask operations."""

    @pytest.fixture
    def temp_index_dir(self):
        """Create temporary directory for test indexes."""
        with tempfile.TemporaryDirectory() as tmpdir:
            yield Path(tmpdir)

    @pytest.fixture
    def mock_embeddings(self):
        """Mock compute_embeddings to return dummy embeddings."""
        with patch("leann.api.compute_embeddings") as mock_compute:
            mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
            yield mock_compute

    @pytest.fixture
    def mock_embedding_server_manager(self):
        """Mock EmbeddingServerManager for chat tests."""
        with patch("leann.searcher_base.EmbeddingServerManager") as mock_manager_class:
            mock_manager = Mock()
            mock_manager.start_server.return_value = (True, 5557)
            mock_manager_class.return_value = mock_manager
            yield mock_manager

    @pytest.fixture
    def index_with_template(self, temp_index_dir, mock_embeddings):
        """Build an index with a prompt template."""
        index_path = temp_index_dir / "chat_template_index.leann"
        template = "document_query: "

        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="text-embedding-3-small",
            embedding_mode="openai",
            embedding_options={"prompt_template": template},
        )

        builder.add_text("Test document for chat")
        builder.build_index(str(index_path))

        return str(index_path), template


class TestPromptTemplateIntegrationWithEmbeddingModes:
    """Tests for prompt template compatibility with different embedding modes."""

    @pytest.fixture
    def temp_index_dir(self):
        """Create temporary directory for test indexes."""
        with tempfile.TemporaryDirectory() as tmpdir:
            yield Path(tmpdir)

    @pytest.mark.parametrize(
        "mode,model,template,filename_prefix",
        [
            (
                "openai",
                "text-embedding-3-small",
                "Represent this for searching: ",
                "openai_template",
            ),
            ("ollama", "nomic-embed-text", "search_query: ", "ollama_template"),
            ("sentence-transformers", "facebook/contriever", "query: ", "st_template"),
        ],
    )
    def test_prompt_template_metadata_with_embedding_modes(
        self, temp_index_dir, mode, model, template, filename_prefix
    ):
        """Verify prompt template is saved correctly across different embedding modes.

        Tests that prompt templates are persisted to .meta.json for:
        - OpenAI mode (primary use case)
        - Ollama mode (also supports templates)
        - Sentence-transformers mode (saved for forward compatibility)

        Expected behavior: Template is saved to .meta.json regardless of mode.
        """
        with patch("leann.api.compute_embeddings") as mock_compute:
            mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)

            index_path = temp_index_dir / f"{filename_prefix}.leann"

            builder = LeannBuilder(
                backend_name="hnsw",
                embedding_model=model,
                embedding_mode=mode,
                embedding_options={"prompt_template": template},
            )

            builder.add_text(f"{mode.capitalize()} test document")
            builder.build_index(str(index_path))

            # Verify metadata
            meta_path = temp_index_dir / f"{filename_prefix}.leann.meta.json"
            with open(meta_path, encoding="utf-8") as f:
                meta_data = json.load(f)

            assert meta_data["embedding_mode"] == mode
            # Template should be saved for all modes (even if not used by some)
            if "embedding_options" in meta_data:
                assert meta_data["embedding_options"]["prompt_template"] == template


class TestQueryTemplateApplicationInComputeEmbedding:
    """Tests for query template application in compute_query_embedding() (Bug Fix).

    These tests verify that query templates are applied consistently in BOTH
    code paths (server and fallback) when computing query embeddings.

    This addresses the bug where query templates were only applied in the
    fallback path, not when using the embedding server (the default path).

    Bug Context:
    - Issue: Query templates were stored in metadata but only applied during
      fallback (direct) computation, not when using embedding server
    - Fix: Move template application to BEFORE any computation path in
      compute_query_embedding() (searcher_base.py:107-110)
    - Impact: Critical for models like EmbeddingGemma that require task-specific
      templates for optimal performance

    These tests ensure the fix works correctly and prevent regression.
    """

    @pytest.fixture
    def temp_index_with_template(self):
        """Create a temporary index with query template in metadata"""
        with tempfile.TemporaryDirectory() as tmpdir:
            index_dir = Path(tmpdir)
            index_file = index_dir / "test.leann"
            meta_file = index_dir / "test.leann.meta.json"

            # Create minimal metadata with query template
            metadata = {
                "version": "1.0",
                "backend_name": "hnsw",
                "embedding_model": "text-embedding-embeddinggemma-300m-qat",
                "dimensions": 768,
                "embedding_mode": "openai",
                "backend_kwargs": {
                    "graph_degree": 32,
                    "complexity": 64,
                    "distance_metric": "cosine",
                },
                "embedding_options": {
                    "base_url": "http://localhost:1234/v1",
                    "api_key": "test-key",
                    "build_prompt_template": "title: none | text: ",
                    "query_prompt_template": "task: search result | query: ",
                },
            }

            meta_file.write_text(json.dumps(metadata, indent=2))

            # Create minimal HNSW index file (empty is okay for this test)
            index_file.write_bytes(b"")

            yield str(index_file)

    def test_query_template_applied_in_fallback_path(self, temp_index_with_template):
        """Test that query template is applied when using fallback (direct) path"""
        from leann.searcher_base import BaseSearcher

        # Create a concrete implementation for testing
        class TestSearcher(BaseSearcher):
            def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
                return {"labels": [], "distances": []}

        searcher = object.__new__(TestSearcher)
        searcher.index_path = Path(temp_index_with_template)
        searcher.index_dir = searcher.index_path.parent

        # Load metadata
        meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
        with open(meta_file) as f:
            searcher.meta = json.load(f)

        searcher.embedding_model = searcher.meta["embedding_model"]
        searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
        searcher.embedding_options = searcher.meta.get("embedding_options", {})

        # Mock compute_embeddings to capture the query text
        captured_queries = []

        def mock_compute_embeddings(texts, model, mode, provider_options=None):
            captured_queries.extend(texts)
            return np.random.rand(len(texts), 768).astype(np.float32)

        with patch(
            "leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
        ):
            # Call compute_query_embedding with template (fallback path)
            result = searcher.compute_query_embedding(
                query="vector database",
                use_server_if_available=False,  # Force fallback path
                query_template="task: search result | query: ",
            )

            # Verify template was applied
            assert len(captured_queries) == 1
            assert captured_queries[0] == "task: search result | query: vector database"
            assert result.shape == (1, 768)

    def test_query_template_applied_in_server_path(self, temp_index_with_template):
        """Test that query template is applied when using server path"""
        from leann.searcher_base import BaseSearcher

        # Create a concrete implementation for testing
        class TestSearcher(BaseSearcher):
            def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
                return {"labels": [], "distances": []}

        searcher = object.__new__(TestSearcher)
        searcher.index_path = Path(temp_index_with_template)
        searcher.index_dir = searcher.index_path.parent

        # Load metadata
        meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
        with open(meta_file) as f:
            searcher.meta = json.load(f)

        searcher.embedding_model = searcher.meta["embedding_model"]
        searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
        searcher.embedding_options = searcher.meta.get("embedding_options", {})

        # Mock the server methods to capture the query text
        captured_queries = []

        def mock_ensure_server_running(passages_file, port):
            return port

        def mock_compute_embedding_via_server(chunks, port):
            captured_queries.extend(chunks)
            return np.random.rand(len(chunks), 768).astype(np.float32)

        searcher._ensure_server_running = mock_ensure_server_running
        searcher._compute_embedding_via_server = mock_compute_embedding_via_server

        # Call compute_query_embedding with template (server path)
        result = searcher.compute_query_embedding(
            query="vector database",
            use_server_if_available=True,  # Use server path
            query_template="task: search result | query: ",
        )

        # Verify template was applied BEFORE calling server
        assert len(captured_queries) == 1
        assert captured_queries[0] == "task: search result | query: vector database"
        assert result.shape == (1, 768)

    def test_query_template_without_template_parameter(self, temp_index_with_template):
        """Test that query is unchanged when no template is provided"""
        from leann.searcher_base import BaseSearcher

        class TestSearcher(BaseSearcher):
            def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
                return {"labels": [], "distances": []}

        searcher = object.__new__(TestSearcher)
        searcher.index_path = Path(temp_index_with_template)
        searcher.index_dir = searcher.index_path.parent

        meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
        with open(meta_file) as f:
            searcher.meta = json.load(f)

        searcher.embedding_model = searcher.meta["embedding_model"]
        searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
        searcher.embedding_options = searcher.meta.get("embedding_options", {})

        captured_queries = []

        def mock_compute_embeddings(texts, model, mode, provider_options=None):
            captured_queries.extend(texts)
            return np.random.rand(len(texts), 768).astype(np.float32)

        with patch(
            "leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
        ):
            searcher.compute_query_embedding(
                query="vector database",
                use_server_if_available=False,
                query_template=None,  # No template
            )

            # Verify query is unchanged
            assert len(captured_queries) == 1
            assert captured_queries[0] == "vector database"

    def test_query_template_consistency_between_paths(self, temp_index_with_template):
        """Test that both paths apply template identically"""
        from leann.searcher_base import BaseSearcher

        class TestSearcher(BaseSearcher):
            def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
                return {"labels": [], "distances": []}

        searcher = object.__new__(TestSearcher)
        searcher.index_path = Path(temp_index_with_template)
        searcher.index_dir = searcher.index_path.parent

        meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
        with open(meta_file) as f:
            searcher.meta = json.load(f)

        searcher.embedding_model = searcher.meta["embedding_model"]
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||||
|
|
||||||
|
query_template = "task: search result | query: "
|
||||||
|
original_query = "vector database"
|
||||||
|
|
||||||
|
# Capture queries from fallback path
|
||||||
|
fallback_queries = []
|
||||||
|
|
||||||
|
def mock_compute_embeddings(texts, model, mode, provider_options=None):
|
||||||
|
fallback_queries.extend(texts)
|
||||||
|
return np.random.rand(len(texts), 768).astype(np.float32)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
|
||||||
|
):
|
||||||
|
searcher.compute_query_embedding(
|
||||||
|
query=original_query,
|
||||||
|
use_server_if_available=False,
|
||||||
|
query_template=query_template,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Capture queries from server path
|
||||||
|
server_queries = []
|
||||||
|
|
||||||
|
def mock_ensure_server_running(passages_file, port):
|
||||||
|
return port
|
||||||
|
|
||||||
|
def mock_compute_embedding_via_server(chunks, port):
|
||||||
|
server_queries.extend(chunks)
|
||||||
|
return np.random.rand(len(chunks), 768).astype(np.float32)
|
||||||
|
|
||||||
|
searcher._ensure_server_running = mock_ensure_server_running
|
||||||
|
searcher._compute_embedding_via_server = mock_compute_embedding_via_server
|
||||||
|
|
||||||
|
searcher.compute_query_embedding(
|
||||||
|
query=original_query,
|
||||||
|
use_server_if_available=True,
|
||||||
|
query_template=query_template,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify both paths produced identical templated queries
|
||||||
|
assert len(fallback_queries) == 1
|
||||||
|
assert len(server_queries) == 1
|
||||||
|
assert fallback_queries[0] == server_queries[0]
|
||||||
|
assert fallback_queries[0] == f"{query_template}{original_query}"
|
||||||
|
|
||||||
|
def test_query_template_with_empty_string(self, temp_index_with_template):
|
||||||
|
"""Test behavior with empty template string"""
|
||||||
|
from leann.searcher_base import BaseSearcher
|
||||||
|
|
||||||
|
class TestSearcher(BaseSearcher):
|
||||||
|
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||||
|
return {"labels": [], "distances": []}
|
||||||
|
|
||||||
|
searcher = object.__new__(TestSearcher)
|
||||||
|
searcher.index_path = Path(temp_index_with_template)
|
||||||
|
searcher.index_dir = searcher.index_path.parent
|
||||||
|
|
||||||
|
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||||
|
with open(meta_file) as f:
|
||||||
|
searcher.meta = json.load(f)
|
||||||
|
|
||||||
|
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||||
|
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||||
|
|
||||||
|
captured_queries = []
|
||||||
|
|
||||||
|
def mock_compute_embeddings(texts, model, mode, provider_options=None):
|
||||||
|
captured_queries.extend(texts)
|
||||||
|
return np.random.rand(len(texts), 768).astype(np.float32)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
|
||||||
|
):
|
||||||
|
searcher.compute_query_embedding(
|
||||||
|
query="vector database",
|
||||||
|
use_server_if_available=False,
|
||||||
|
query_template="", # Empty string
|
||||||
|
)
|
||||||
|
|
||||||
|
# Empty string is falsy, so no template should be applied
|
||||||
|
assert captured_queries[0] == "vector database"
|
||||||
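# ---------------------------------------------------------------------------
# Editor's sketch (not part of the PR): the tests above assume that
# compute_query_embedding prepends a truthy query_template to the raw query
# before embedding, identically on the server and the fallback path, and that
# an empty template (falsy) leaves the query untouched. A minimal illustration
# of that rule, with a hypothetical helper name:


def _apply_query_template_sketch(query: str, query_template: str | None) -> str:
    """Prepend the template when provided; None or "" means no templating."""
    return f"{query_template}{query}" if query_template else query


# Usage: both paths are expected to embed the same templated string, e.g.
# _apply_query_template_sketch("vector database", "task: search result | query: ")
# == "task: search result | query: vector database"
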
643 tests/test_token_truncation.py Normal file
@@ -0,0 +1,643 @@
"""Unit tests for token-aware truncation functionality.

This test suite defines the contract for token truncation functions that prevent
500 errors from Ollama when text exceeds model token limits. These tests verify:

1. Model token limit retrieval (known and unknown models)
2. Text truncation behavior for single and multiple texts
3. Token counting and truncation accuracy using tiktoken

All tests are written in Red Phase - they should FAIL initially because the
implementation does not exist yet.
"""

import pytest
import tiktoken
from leann.embedding_compute import (
    EMBEDDING_MODEL_LIMITS,
    get_model_token_limit,
    truncate_to_token_limit,
)


class TestModelTokenLimits:
    """Tests for retrieving model-specific token limits."""

    def test_get_model_token_limit_known_model(self):
        """Verify correct token limit is returned for known models.

        Known models should return their specific token limits from the
        EMBEDDING_MODEL_LIMITS dictionary.
        """
        # Test nomic-embed-text (2048 tokens)
        limit = get_model_token_limit("nomic-embed-text")
        assert limit == 2048, "nomic-embed-text should have 2048 token limit"

        # Test nomic-embed-text-v1.5 (2048 tokens)
        limit = get_model_token_limit("nomic-embed-text-v1.5")
        assert limit == 2048, "nomic-embed-text-v1.5 should have 2048 token limit"

        # Test nomic-embed-text-v2 (512 tokens)
        limit = get_model_token_limit("nomic-embed-text-v2")
        assert limit == 512, "nomic-embed-text-v2 should have 512 token limit"

        # Test OpenAI models (8192 tokens)
        limit = get_model_token_limit("text-embedding-3-small")
        assert limit == 8192, "text-embedding-3-small should have 8192 token limit"

    def test_get_model_token_limit_unknown_model(self):
        """Verify default token limit is returned for unknown models.

        Unknown models should return the default limit (2048) to allow
        operation with a reasonable safety margin.
        """
        # Test with completely unknown model
        limit = get_model_token_limit("unknown-model-xyz")
        assert limit == 2048, "Unknown models should return default 2048"

        # Test with empty string
        limit = get_model_token_limit("")
        assert limit == 2048, "Empty model name should return default 2048"

    def test_get_model_token_limit_custom_default(self):
        """Verify custom default can be specified for unknown models.

        Allow callers to specify their own default token limit when the
        model is not in the known models dictionary.
        """
        limit = get_model_token_limit("unknown-model", default=4096)
        assert limit == 4096, "Should return custom default for unknown models"

        # Known model should ignore custom default
        limit = get_model_token_limit("nomic-embed-text", default=4096)
        assert limit == 2048, "Known model should ignore custom default"

    def test_embedding_model_limits_dictionary_exists(self):
        """Verify EMBEDDING_MODEL_LIMITS dictionary contains expected models.

        The dictionary should be importable and contain at least the
        known nomic models with correct token limits.
        """
        assert isinstance(EMBEDDING_MODEL_LIMITS, dict), "Should be a dictionary"
        assert "nomic-embed-text" in EMBEDDING_MODEL_LIMITS, "Should contain nomic-embed-text"
        assert "nomic-embed-text-v1.5" in EMBEDDING_MODEL_LIMITS, (
            "Should contain nomic-embed-text-v1.5"
        )
        assert EMBEDDING_MODEL_LIMITS["nomic-embed-text"] == 2048
        assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v1.5"] == 2048
        assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v2"] == 512
        # OpenAI models
        assert EMBEDDING_MODEL_LIMITS["text-embedding-3-small"] == 8192

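# ---------------------------------------------------------------------------
# Editor's sketch (not part of the PR): the lookup contract exercised above is
# a registry lookup with a caller-supplied default. The values below are taken
# from the assertions in TestModelTokenLimits; the helper name is hypothetical
# and only illustrates the expected behaviour of get_model_token_limit.


def _registry_token_limit_sketch(model_name: str, default: int = 2048) -> int:
    """Known models map to fixed limits; unknown names fall back to `default`."""
    registry = {
        "nomic-embed-text": 2048,
        "nomic-embed-text-v1.5": 2048,
        "nomic-embed-text-v2": 512,
        "text-embedding-3-small": 8192,
    }
    return registry.get(model_name, default)
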
class TestTokenTruncation:
    """Tests for truncating texts to token limits."""

    @pytest.fixture
    def tokenizer(self):
        """Provide tiktoken tokenizer for token counting verification."""
        return tiktoken.get_encoding("cl100k_base")

    def test_truncate_single_text_under_limit(self, tokenizer):
        """Verify text under token limit remains unchanged.

        When text is already within the token limit, it should be
        returned unchanged with no truncation.
        """
        text = "This is a short text that is well under the token limit."
        token_count = len(tokenizer.encode(text))
        assert token_count < 100, f"Test setup: text should be short (has {token_count} tokens)"

        # Truncate with generous limit
        result = truncate_to_token_limit([text], token_limit=512)

        assert len(result) == 1, "Should return same number of texts"
        assert result[0] == text, "Text under limit should be unchanged"

    def test_truncate_single_text_over_limit(self, tokenizer):
        """Verify text over token limit is truncated correctly.

        When text exceeds the token limit, it should be truncated to
        fit within the limit while maintaining valid token boundaries.
        """
        # Create a text that definitely exceeds the limit
        text = "word " * 200  # ~200 tokens (each "word " is typically 1-2 tokens)
        original_token_count = len(tokenizer.encode(text))
        assert original_token_count > 50, (
            f"Test setup: text should be long (has {original_token_count} tokens)"
        )

        # Truncate to 50 tokens
        result = truncate_to_token_limit([text], token_limit=50)

        assert len(result) == 1, "Should return same number of texts"
        assert result[0] != text, "Text over limit should be truncated"
        assert len(result[0]) < len(text), "Truncated text should be shorter"

        # Verify truncated text is within token limit
        truncated_token_count = len(tokenizer.encode(result[0]))
        assert truncated_token_count <= 50, (
            f"Truncated text should be ≤50 tokens, got {truncated_token_count}"
        )

    def test_truncate_multiple_texts_mixed_lengths(self, tokenizer):
        """Verify multiple texts with mixed lengths are handled correctly.

        When processing multiple texts:
        - Texts under limit should remain unchanged
        - Texts over limit should be truncated independently
        - Output list should maintain same order and length
        """
        texts = [
            "Short text.",  # Under limit
            "word " * 200,  # Over limit
            "Another short one.",  # Under limit
            "token " * 150,  # Over limit
        ]

        # Verify test setup
        for i, text in enumerate(texts):
            token_count = len(tokenizer.encode(text))
            if i in [1, 3]:
                assert token_count > 50, f"Text {i} should be over limit (has {token_count} tokens)"
            else:
                assert token_count < 50, (
                    f"Text {i} should be under limit (has {token_count} tokens)"
                )

        # Truncate with 50 token limit
        result = truncate_to_token_limit(texts, token_limit=50)

        assert len(result) == len(texts), "Should return same number of texts"

        # Verify each text individually
        for i, (original, truncated) in enumerate(zip(texts, result)):
            token_count = len(tokenizer.encode(truncated))
            assert token_count <= 50, f"Text {i} should be ≤50 tokens, got {token_count}"

            # Short texts should be unchanged
            if i in [0, 2]:
                assert truncated == original, f"Short text {i} should be unchanged"
            # Long texts should be truncated
            else:
                assert len(truncated) < len(original), f"Long text {i} should be truncated"

    def test_truncate_empty_list(self):
        """Verify empty input list returns empty output list.

        Edge case: empty list should return empty list without errors.
        """
        result = truncate_to_token_limit([], token_limit=512)
        assert result == [], "Empty input should return empty output"

    def test_truncate_preserves_order(self, tokenizer):
        """Verify truncation preserves original text order.

        Output list should maintain the same order as input list,
        regardless of which texts were truncated.
        """
        texts = [
            "First text " * 50,  # Will be truncated
            "Second text.",  # Won't be truncated
            "Third text " * 50,  # Will be truncated
        ]

        result = truncate_to_token_limit(texts, token_limit=20)

        assert len(result) == 3, "Should preserve list length"
        # Check that order is maintained by looking for distinctive words
        assert "First" in result[0], "First text should remain in first position"
        assert "Second" in result[1], "Second text should remain in second position"
        assert "Third" in result[2], "Third text should remain in third position"

    def test_truncate_extremely_long_text(self, tokenizer):
        """Verify extremely long texts are truncated efficiently.

        Test with text that far exceeds the token limit to ensure
        truncation handles extreme cases without performance issues.
        """
        # Create very long text (simulate real-world scenario)
        text = "token " * 5000  # ~5000+ tokens
        original_token_count = len(tokenizer.encode(text))
        assert original_token_count > 1000, "Test setup: text should be very long"

        # Truncate to small limit
        result = truncate_to_token_limit([text], token_limit=100)

        assert len(result) == 1
        truncated_token_count = len(tokenizer.encode(result[0]))
        assert truncated_token_count <= 100, (
            f"Should truncate to ≤100 tokens, got {truncated_token_count}"
        )
        assert len(result[0]) < len(text) // 10, "Should significantly reduce text length"

    def test_truncate_exact_token_limit(self, tokenizer):
        """Verify text at exactly the token limit is handled correctly.

        Edge case: text with exactly the token limit should either
        remain unchanged or be safely truncated by 1 token.
        """
        # Create text with approximately 50 tokens
        # We'll adjust to get exactly 50
        target_tokens = 50
        text = "word " * 50
        tokens = tokenizer.encode(text)

        # Adjust to get exactly target_tokens
        if len(tokens) > target_tokens:
            tokens = tokens[:target_tokens]
            text = tokenizer.decode(tokens)
        elif len(tokens) < target_tokens:
            # Add more words
            while len(tokenizer.encode(text)) < target_tokens:
                text += "word "
            tokens = tokenizer.encode(text)[:target_tokens]
            text = tokenizer.decode(tokens)

        # Verify we have exactly target_tokens
        assert len(tokenizer.encode(text)) == target_tokens, (
            "Test setup: should have exactly 50 tokens"
        )

        result = truncate_to_token_limit([text], token_limit=target_tokens)

        assert len(result) == 1
        result_tokens = len(tokenizer.encode(result[0]))
        assert result_tokens <= target_tokens, (
            f"Should be ≤{target_tokens} tokens, got {result_tokens}"
        )

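# ---------------------------------------------------------------------------
# Editor's sketch (not part of the PR): one way to satisfy the truncation
# contract above, assuming tiktoken's cl100k_base encoding (the same encoding
# the tests use for verification; it reuses the module-level tiktoken import).
# The helper name is hypothetical; the real truncate_to_token_limit may differ.


def _truncate_to_token_limit_sketch(texts: list[str], token_limit: int) -> list[str]:
    """Return texts unchanged when within the limit, otherwise decode only the
    first `token_limit` tokens; order and list length are preserved."""
    enc = tiktoken.get_encoding("cl100k_base")
    truncated = []
    for text in texts:
        tokens = enc.encode(text)
        truncated.append(text if len(tokens) <= token_limit else enc.decode(tokens[:token_limit]))
    return truncated
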
class TestLMStudioHybridDiscovery:
    """Tests for LM Studio integration in get_model_token_limit() hybrid discovery.

    These tests verify that get_model_token_limit() properly integrates with
    the LM Studio SDK bridge for dynamic token limit discovery. The integration
    should:

    1. Detect LM Studio URLs (port 1234 or 'lmstudio'/'lm.studio' in URL)
    2. Convert HTTP URLs to WebSocket format for SDK queries
    3. Query the LM Studio SDK and use the discovered limit
    4. Fall back to the registry when the SDK returns None
    5. Execute AFTER Ollama detection but BEFORE registry fallback

    All tests are written in Red Phase - they should FAIL initially because the
    LM Studio detection and integration logic does not exist yet in get_model_token_limit().
    """

    def test_get_model_token_limit_lmstudio_success(self, monkeypatch):
        """Verify LM Studio SDK query succeeds and returns detected limit.

        When an LM Studio base_url is detected and the SDK query succeeds,
        get_model_token_limit() should return the dynamically discovered
        context length without falling back to the registry.
        """

        # Mock _query_lmstudio_context_limit to return successful SDK query
        def mock_query_lmstudio(model_name, base_url):
            # Verify WebSocket URL was passed (not HTTP)
            assert base_url.startswith("ws://"), (
                f"Should convert HTTP to WebSocket format, got: {base_url}"
            )
            return 8192  # Successful SDK query

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        # Test with HTTP URL that should be converted to WebSocket
        limit = get_model_token_limit(
            model_name="custom-model", base_url="http://localhost:1234/v1"
        )

        assert limit == 8192, "Should return limit from LM Studio SDK query"

    def test_get_model_token_limit_lmstudio_fallback_to_registry(self, monkeypatch):
        """Verify fallback to registry when LM Studio SDK returns None.

        When the LM Studio SDK query fails (returns None), get_model_token_limit()
        should fall back to the EMBEDDING_MODEL_LIMITS registry.
        """

        # Mock _query_lmstudio_context_limit to return None (SDK failure)
        def mock_query_lmstudio(model_name, base_url):
            return None  # SDK query failed

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        # Test with known model that exists in registry
        limit = get_model_token_limit(
            model_name="nomic-embed-text", base_url="http://localhost:1234/v1"
        )

        # Should fall back to registry value
        assert limit == 2048, "Should fall back to registry when SDK returns None"

    def test_get_model_token_limit_lmstudio_port_detection(self, monkeypatch):
        """Verify detection of LM Studio via port 1234.

        get_model_token_limit() should recognize port 1234 as an LM Studio
        server and attempt an SDK query, regardless of hostname.
        """
        query_called = False

        def mock_query_lmstudio(model_name, base_url):
            nonlocal query_called
            query_called = True
            return 4096

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        # Test with port 1234 (default LM Studio port)
        limit = get_model_token_limit(model_name="test-model", base_url="http://127.0.0.1:1234/v1")

        assert query_called, "Should detect port 1234 and call LM Studio SDK query"
        assert limit == 4096, "Should return SDK query result"

    @pytest.mark.parametrize(
        "test_url,expected_limit,keyword",
        [
            ("http://lmstudio.local:8080/v1", 16384, "lmstudio"),
            ("http://api.lm.studio:5000/v1", 32768, "lm.studio"),
        ],
    )
    def test_get_model_token_limit_lmstudio_url_keyword_detection(
        self, monkeypatch, test_url, expected_limit, keyword
    ):
        """Verify detection of LM Studio via keywords in URL.

        get_model_token_limit() should recognize 'lmstudio' or 'lm.studio'
        in the URL as indicating an LM Studio server.
        """
        query_called = False

        def mock_query_lmstudio(model_name, base_url):
            nonlocal query_called
            query_called = True
            return expected_limit

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        limit = get_model_token_limit(model_name="test-model", base_url=test_url)

        assert query_called, f"Should detect '{keyword}' keyword and call SDK query"
        assert limit == expected_limit, f"Should return SDK query result for {keyword}"

    @pytest.mark.parametrize(
        "input_url,expected_protocol,expected_host",
        [
            ("http://localhost:1234/v1", "ws://", "localhost:1234"),
            ("https://lmstudio.example.com:1234/v1", "wss://", "lmstudio.example.com:1234"),
        ],
    )
    def test_get_model_token_limit_protocol_conversion(
        self, monkeypatch, input_url, expected_protocol, expected_host
    ):
        """Verify HTTP/HTTPS URL is converted to WebSocket format for SDK query.

        The LM Studio SDK requires WebSocket URLs. get_model_token_limit() should:
        1. Convert 'http://' to 'ws://'
        2. Convert 'https://' to 'wss://'
        3. Remove '/v1' or other path suffixes (SDK expects base URL)
        4. Preserve host and port
        """
        conversions_tested = []

        def mock_query_lmstudio(model_name, base_url):
            conversions_tested.append(base_url)
            return 8192

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        get_model_token_limit(model_name="test-model", base_url=input_url)

        # Verify conversion happened
        assert len(conversions_tested) == 1, "Should have called SDK query once"
        assert conversions_tested[0].startswith(expected_protocol), (
            f"Should convert to {expected_protocol}"
        )
        assert expected_host in conversions_tested[0], (
            f"Should preserve host and port: {expected_host}"
        )

    def test_get_model_token_limit_lmstudio_executes_after_ollama(self, monkeypatch):
        """Verify LM Studio detection happens AFTER Ollama detection.

        The hybrid discovery order should be:
        1. Ollama dynamic discovery (port 11434 or 'ollama' in URL)
        2. LM Studio dynamic discovery (port 1234 or 'lmstudio' in URL)
        3. Registry fallback

        If both Ollama and LM Studio patterns match, Ollama should take precedence.
        This test verifies that LM Studio is checked but doesn't interfere with Ollama.
        """
        ollama_called = False
        lmstudio_called = False

        def mock_query_ollama(model_name, base_url):
            nonlocal ollama_called
            ollama_called = True
            return 2048  # Ollama query succeeds

        def mock_query_lmstudio(model_name, base_url):
            nonlocal lmstudio_called
            lmstudio_called = True
            return None  # Should not be reached if Ollama succeeds

        monkeypatch.setattr(
            "leann.embedding_compute._query_ollama_context_limit",
            mock_query_ollama,
        )
        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        # Test with Ollama URL
        limit = get_model_token_limit(
            model_name="test-model", base_url="http://localhost:11434/api"
        )

        assert ollama_called, "Should attempt Ollama query first"
        assert not lmstudio_called, "Should not attempt LM Studio query when Ollama succeeds"
        assert limit == 2048, "Should return Ollama result"

    def test_get_model_token_limit_lmstudio_not_detected_for_non_lmstudio_urls(self, monkeypatch):
        """Verify LM Studio SDK query is NOT called for non-LM Studio URLs.

        Only URLs with port 1234 or 'lmstudio'/'lm.studio' keywords should
        trigger LM Studio SDK queries. Other URLs should skip to registry fallback.
        """
        lmstudio_called = False

        def mock_query_lmstudio(model_name, base_url):
            nonlocal lmstudio_called
            lmstudio_called = True
            return 8192

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        # Test with non-LM Studio URLs
        test_cases = [
            "http://localhost:8080/v1",  # Different port
            "http://openai.example.com/v1",  # Different service
            "http://localhost:3000/v1",  # Another port
        ]

        for base_url in test_cases:
            lmstudio_called = False  # Reset for each test
            get_model_token_limit(model_name="nomic-embed-text", base_url=base_url)
            assert not lmstudio_called, f"Should NOT call LM Studio SDK for URL: {base_url}"

    def test_get_model_token_limit_lmstudio_case_insensitive_detection(self, monkeypatch):
        """Verify LM Studio detection is case-insensitive for keywords.

        Keywords 'lmstudio' and 'lm.studio' should be detected regardless
        of case (LMStudio, LMSTUDIO, LmStudio, etc.).
        """
        query_called = False

        def mock_query_lmstudio(model_name, base_url):
            nonlocal query_called
            query_called = True
            return 8192

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        # Test various case variations
        test_cases = [
            "http://LMStudio.local:8080/v1",
            "http://LMSTUDIO.example.com/v1",
            "http://LmStudio.local/v1",
            "http://api.LM.STUDIO:5000/v1",
        ]

        for base_url in test_cases:
            query_called = False  # Reset for each test
            limit = get_model_token_limit(model_name="test-model", base_url=base_url)
            assert query_called, f"Should detect LM Studio in URL: {base_url}"
            assert limit == 8192, f"Should return SDK result for URL: {base_url}"

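# ---------------------------------------------------------------------------
# Editor's sketch (not part of the PR): the detection and URL-conversion rules
# the tests above pin down, expressed as two hypothetical helpers. Per the
# tests, the real logic is expected to live inside get_model_token_limit.


def _looks_like_lmstudio_sketch(base_url: str) -> bool:
    """Port 1234 or an 'lmstudio' / 'lm.studio' keyword, case-insensitive."""
    url = base_url.lower()
    return ":1234" in url or "lmstudio" in url or "lm.studio" in url


def _to_lmstudio_websocket_url_sketch(base_url: str) -> str:
    """Map http(s)://host:port/v1 to ws(s)://host:port, keeping host and port."""
    from urllib.parse import urlparse

    parsed = urlparse(base_url)
    scheme = "wss" if parsed.scheme == "https" else "ws"
    return f"{scheme}://{parsed.netloc}"
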
class TestTokenLimitCaching:
    """Tests for token limit caching to prevent repeated SDK/API calls.

    Caching prevents duplicate SDK/API calls within the same Python process,
    which is important because:
    1. LM Studio SDK load() can load duplicate model instances
    2. Ollama /api/show queries add latency
    3. Registry lookups are pure overhead

    Cache is process-scoped and resets between leann build invocations.
    """

    def setup_method(self):
        """Clear cache before each test."""
        from leann.embedding_compute import _token_limit_cache

        _token_limit_cache.clear()

    def test_registry_lookup_is_cached(self):
        """Verify that registry lookups are cached."""
        from leann.embedding_compute import _token_limit_cache

        # First call
        limit1 = get_model_token_limit("text-embedding-3-small")
        assert limit1 == 8192

        # Verify it's in cache
        cache_key = ("text-embedding-3-small", "")
        assert cache_key in _token_limit_cache
        assert _token_limit_cache[cache_key] == 8192

        # Second call should use cache
        limit2 = get_model_token_limit("text-embedding-3-small")
        assert limit2 == 8192

    def test_default_fallback_is_cached(self):
        """Verify that default fallbacks are cached."""
        from leann.embedding_compute import _token_limit_cache

        # First call with unknown model
        limit1 = get_model_token_limit("unknown-model-xyz", default=512)
        assert limit1 == 512

        # Verify it's in cache
        cache_key = ("unknown-model-xyz", "")
        assert cache_key in _token_limit_cache
        assert _token_limit_cache[cache_key] == 512

        # Second call should use cache
        limit2 = get_model_token_limit("unknown-model-xyz", default=512)
        assert limit2 == 512

    def test_different_urls_create_separate_cache_entries(self):
        """Verify that different base_urls create separate cache entries."""
        from leann.embedding_compute import _token_limit_cache

        # Same model, different URLs
        limit1 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:11434")
        limit2 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:1234/v1")

        # Both should find the model in registry (2048)
        assert limit1 == 2048
        assert limit2 == 2048

        # But they should be separate cache entries
        cache_key1 = ("nomic-embed-text", "http://localhost:11434")
        cache_key2 = ("nomic-embed-text", "http://localhost:1234/v1")

        assert cache_key1 in _token_limit_cache
        assert cache_key2 in _token_limit_cache
        assert len(_token_limit_cache) == 2

    def test_cache_prevents_repeated_lookups(self):
        """Verify that cache prevents repeated registry/API lookups."""
        from leann.embedding_compute import _token_limit_cache

        model_name = "text-embedding-ada-002"

        # First call - should add to cache
        assert len(_token_limit_cache) == 0
        limit1 = get_model_token_limit(model_name)

        cache_size_after_first = len(_token_limit_cache)
        assert cache_size_after_first == 1

        # Multiple subsequent calls - cache size should not change
        for _ in range(5):
            limit = get_model_token_limit(model_name)
            assert limit == limit1
            assert len(_token_limit_cache) == cache_size_after_first

    def test_versioned_model_names_cached_correctly(self):
        """Verify that versioned model names (e.g., model:tag) are cached."""
        from leann.embedding_compute import _token_limit_cache

        # Model with version tag
        limit = get_model_token_limit("nomic-embed-text:latest", base_url="http://localhost:11434")
        assert limit == 2048

        # Should be cached with full name including version
        cache_key = ("nomic-embed-text:latest", "http://localhost:11434")
        assert cache_key in _token_limit_cache
        assert _token_limit_cache[cache_key] == 2048
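# ---------------------------------------------------------------------------
# Editor's sketch (not part of the PR): the caching behaviour asserted above
# amounts to a process-scoped dict keyed on (model_name, base_url), consulted
# before any registry or SDK/API lookup. Names below are hypothetical; the
# registry lookup reuses the EMBEDDING_MODEL_LIMITS import at the top of
# this module.


_token_limit_cache_sketch: dict[tuple[str, str], int] = {}


def _cached_token_limit_sketch(model_name: str, base_url: str = "", default: int = 2048) -> int:
    """First lookup stores the resolved limit; later calls with the same
    (model_name, base_url) pair reuse it without re-querying."""
    key = (model_name, base_url)
    if key not in _token_limit_cache_sketch:
        _token_limit_cache_sketch[key] = EMBEDDING_MODEL_LIMITS.get(model_name, default)
    return _token_limit_cache_sketch[key]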