Compare commits
12 Commits (embed-laun... → fix/pdf-du)

| SHA1 |
|---|
| 877fbe81f4 |
| eb909ccec5 |
| 969f514564 |
| 1ef9cba7de |
| a63550944b |
| 97493a2896 |
| f7d2dc6e7c |
| ea86b283cb |
| e7519bceaa |
| abf0b2c676 |
| 3c4785bb63 |
| 930b79cc98 |
.github/workflows/link-check.yml (2 changes)

@@ -14,6 +14,6 @@ jobs:
       - uses: actions/checkout@v4
       - uses: lycheeverse/lychee-action@v2
         with:
-          args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
+          args: --no-progress --insecure --user-agent 'curl/7.68.0' --exclude '.*api\.star-history\.com.*' --accept 200,201,202,203,204,205,206,207,208,226,300,301,302,303,304,305,306,307,308,503 README.md docs/ apps/ examples/ benchmarks/
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

README.md (12 changes)

@@ -16,12 +16,24 @@
 </a>
 </p>
 
+<div align="center">
+  <a href="https://forms.gle/rDbZf864gMNxhpTq8">
+    <img src="https://img.shields.io/badge/📣_Community_Survey-Help_Shape_v0.4-007ec6?style=for-the-badge&logo=google-forms&logoColor=white" alt="Take Survey">
+  </a>
+  <p>
+    We track <b>zero telemetry</b>. This survey is the ONLY way to tell us if you want <br>
+    <b>GPU Acceleration</b> or <b>More Integrations</b> next.<br>
+    👉 <a href="https://forms.gle/rDbZf864gMNxhpTq8"><b>Click here to cast your vote (2 mins)</b></a>
+  </p>
+</div>
+
 <h2 align="center" tabindex="-1" class="heading-element" dir="auto">
 The smallest vector index in the world. RAG Everything with LEANN!
 </h2>
 
 LEANN is an innovative vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
 
 LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
 
 **Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#mcp-integration-rag-on-live-data-from-any-platform), [Twitter](#mcp-integration-rag-on-live-data-from-any-platform)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.

@@ -7,6 +7,7 @@ for indexing in LEANN. It supports various Slack MCP server implementations and
 flexible message processing options.
 """
 
+import ast
 import asyncio
 import json
 import logging

@@ -146,16 +147,16 @@ class SlackMCPReader:
             match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
             if match:
                 try:
-                    error_dict = eval(match.group(1))
-                except (ValueError, SyntaxError, NameError):
+                    error_dict = ast.literal_eval(match.group(1))
+                except (ValueError, SyntaxError):
                     pass
             else:
                 # Try alternative format
                 match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
                 if match:
                     try:
-                        error_dict = eval(match.group(1))
-                    except (ValueError, SyntaxError, NameError):
+                        error_dict = ast.literal_eval(match.group(1))
+                    except (ValueError, SyntaxError):
                         pass
 
             if self._is_cache_sync_error(error_dict):

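A quick note on the `eval` → `ast.literal_eval` change above: the matched substring comes out of an exception message, i.e. untrusted text, and `ast.literal_eval` refuses to evaluate anything that is not a plain Python literal. A standalone sketch of the difference (illustrative only, not part of the diff):

```python
import ast

# The text being parsed originates from an exception message, i.e. untrusted input.
error_text = "Failed to fetch messages: {'error': 'channel_not_found', 'ok': False}"
payload = error_text.split(": ", 1)[1]

# ast.literal_eval only accepts Python literals (dicts, lists, strings, numbers, booleans, None).
error_dict = ast.literal_eval(payload)
print(error_dict["error"])  # -> channel_not_found

# Anything that is not a pure literal is rejected instead of executed,
# which is exactly why it replaces eval() here.
try:
    ast.literal_eval("__import__('os').system('echo pwned')")
except ValueError:
    print("rejected non-literal input")
```
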
@@ -158,6 +158,95 @@ builder.build_index("./indexes/my-notes", chunks)
 
 `embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
 
+## Optional Embedding Features
+
+### Task-Specific Prompt Templates
+
+Some embedding models are trained with task-specific prompts to differentiate between documents and queries. The most notable example is **Google's EmbeddingGemma**, which requires different prompts depending on the use case:
+
+- **Indexing documents**: `"title: none | text: "`
+- **Search queries**: `"task: search result | query: "`
+
+LEANN supports automatic prompt prepending via the `--embedding-prompt-template` flag:
+
+```bash
+# Build index with EmbeddingGemma (via LM Studio or Ollama)
+leann build my-docs \
+  --docs ./documents \
+  --embedding-mode openai \
+  --embedding-model text-embedding-embeddinggemma-300m-qat \
+  --embedding-api-base http://localhost:1234/v1 \
+  --embedding-prompt-template "title: none | text: " \
+  --force
+
+# Search with query-specific prompt
+leann search my-docs \
+  --query "What is quantum computing?" \
+  --embedding-prompt-template "task: search result | query: "
+```
+
+**Important Notes:**
+- **Only use with compatible models**: EmbeddingGemma and similar task-specific models
+- **NOT for regular models**: Adding prompts to models like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` will corrupt embeddings
+- **Template is saved**: Build-time templates are saved to `.meta.json` for reference
+- **Flexible prompts**: You can use any prompt string, or leave it empty (`""`)
+
+**Python API:**
+```python
+from leann.api import LeannBuilder
+
+builder = LeannBuilder(
+    embedding_mode="openai",
+    embedding_model="text-embedding-embeddinggemma-300m-qat",
+    embedding_options={
+        "base_url": "http://localhost:1234/v1",
+        "api_key": "lm-studio",
+        "prompt_template": "title: none | text: ",
+    },
+)
+builder.build_index("./indexes/my-docs", chunks)
+```
+
+**References:**
+- [HuggingFace Blog: EmbeddingGemma](https://huggingface.co/blog/embeddinggemma) - Technical details
+
+### LM Studio Auto-Detection (Optional)
+
+When using LM Studio with the OpenAI-compatible API, LEANN can optionally auto-detect model context lengths via the LM Studio SDK. This eliminates manual configuration for token limits.
+
+**Prerequisites:**
+```bash
+# Install Node.js (if not already installed)
+# Then install the LM Studio SDK globally
+npm install -g @lmstudio/sdk
+```
+
+**How it works:**
+1. LEANN detects LM Studio URLs (`:1234`, `lmstudio` in URL)
+2. Queries model metadata via Node.js subprocess
+3. Automatically unloads model after query (respects your JIT auto-evict settings)
+4. Falls back to static registry if SDK unavailable
+
+**No configuration needed** - it works automatically when SDK is installed:
+
+```bash
+leann build my-docs \
+  --docs ./documents \
+  --embedding-mode openai \
+  --embedding-model text-embedding-nomic-embed-text-v1.5 \
+  --embedding-api-base http://localhost:1234/v1
+# Context length auto-detected if SDK available
+# Falls back to registry (2048) if not
+```
+
+**Benefits:**
+- ✅ Automatic token limit detection
+- ✅ Respects LM Studio JIT auto-evict settings
+- ✅ No manual registry maintenance
+- ✅ Graceful fallback if SDK unavailable
+
+**Note:** This is completely optional. LEANN works perfectly fine without the SDK using the built-in token limit registry.
+
 ## Index Selection: Matching Your Scale
 
 ### HNSW (Hierarchical Navigable Small World)

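The search-side override added elsewhere in this PR pairs naturally with the Python build API shown in the documentation above. A minimal sketch, assuming the same local EmbeddingGemma index built in that example (`./indexes/my-docs`); it mirrors what `leann search --embedding-prompt-template ...` does through the CLI:

```python
from leann.api import LeannSearcher

# Provider settings (base_url, api_key, templates) are reloaded from the index's meta.json;
# provider_options only overrides the prompt template for this call.
searcher = LeannSearcher(index_path="./indexes/my-docs")
results = searcher.search(
    "What is quantum computing?",
    provider_options={"prompt_template": "task: search result | query: "},
)
for result in results:
    print(result)
```
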
docs/faq.md (48 changes)

@@ -8,3 +8,51 @@ You can speed up the process by using a lightweight embedding model. Add this to
 --embedding-model sentence-transformers/all-MiniLM-L6-v2
 ```
 **Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
+
+## 2. When should I use prompt templates?
+
+**Use prompt templates ONLY with task-specific embedding models** like Google's EmbeddingGemma. These models are specially trained to use different prompts for documents vs queries.
+
+**DO NOT use with regular models** like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` - adding prompts to these models will corrupt the embeddings.
+
+**Example usage with EmbeddingGemma:**
+```bash
+# Build with document prompt
+leann build my-docs --embedding-prompt-template "title: none | text: "
+
+# Search with query prompt
+leann search my-docs --query "your question" --embedding-prompt-template "task: search result | query: "
+```
+
+See the [Configuration Guide: Task-Specific Prompt Templates](configuration-guide.md#task-specific-prompt-templates) for detailed usage.
+
+## 3. Why is LM Studio loading multiple copies of my model?
+
+This was fixed in recent versions. LEANN now properly unloads models after querying metadata, respecting your LM Studio JIT auto-evict settings.
+
+**If you still see duplicates:**
+- Update to the latest LEANN version
+- Restart LM Studio to clear loaded models
+- Check that you have JIT auto-evict enabled in LM Studio settings
+
+**How it works now:**
+1. LEANN loads model temporarily to get context length
+2. Immediately unloads after query
+3. LM Studio JIT loads model on-demand for actual embeddings
+4. Auto-evicts per your settings
+
+## 4. Do I need Node.js and @lmstudio/sdk?
+
+**No, it's completely optional.** LEANN works perfectly fine without them using a built-in token limit registry.
+
+**Benefits if you install it:**
+- Automatic context length detection for LM Studio models
+- No manual registry maintenance
+- Always gets accurate token limits from the model itself
+
+**To install (optional):**
+```bash
+npm install -g @lmstudio/sdk
+```
+
+See [Configuration Guide: LM Studio Auto-Detection](configuration-guide.md#lm-studio-auto-detection-optional) for details.

@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
 
 [project]
 name = "leann-backend-diskann"
-version = "0.3.4"
-dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
+version = "0.3.5"
+dependencies = ["leann-core==0.3.5", "numpy", "protobuf>=3.19.0"]
 
 [tool.scikit-build]
 # Key: simplified CMake path

@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
 
 [project]
 name = "leann-backend-hnsw"
-version = "0.3.4"
+version = "0.3.5"
 description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
 dependencies = [
-    "leann-core==0.3.4",
+    "leann-core==0.3.5",
     "numpy",
     "pyzmq>=23.0.0",
     "msgpack>=1.0.0",

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "leann-core"
-version = "0.3.4"
+version = "0.3.5"
 description = "Core API and plugin system for LEANN"
 readme = "README.md"
 requires-python = ">=3.9"

@@ -916,6 +916,7 @@ class LeannSearcher:
         metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
         batch_size: int = 0,
         use_grep: bool = False,
+        provider_options: Optional[dict[str, Any]] = None,
         **kwargs,
     ) -> list[SearchResult]:
         """

@@ -979,10 +980,24 @@
 
         start_time = time.time()
 
+        # Extract query template from stored embedding_options with fallback chain:
+        # 1. Check provider_options override (highest priority)
+        # 2. Check query_prompt_template (new format)
+        # 3. Check prompt_template (old format for backward compat)
+        # 4. None (no template)
+        query_template = None
+        if provider_options and "prompt_template" in provider_options:
+            query_template = provider_options["prompt_template"]
+        elif "query_prompt_template" in self.embedding_options:
+            query_template = self.embedding_options["query_prompt_template"]
+        elif "prompt_template" in self.embedding_options:
+            query_template = self.embedding_options["prompt_template"]
+
         query_embedding = self.backend_impl.compute_query_embedding(
             query,
             use_server_if_available=recompute_embeddings,
             zmq_port=zmq_port,
+            query_template=query_template,
         )
         logger.info(f" Generated embedding shape: {query_embedding.shape}")
         embedding_time = time.time() - start_time

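The priority comments above are the contract that keeps old indexes working: a runtime override wins, then the new split templates saved at build time, then the legacy single template. Restated as a small standalone helper (hypothetical, for illustration only; the PR implements this inline in `search`):

```python
from typing import Optional


def resolve_query_template(
    provider_options: Optional[dict],
    embedding_options: dict,
) -> Optional[str]:
    """Mirror of the resolution order used in LeannSearcher.search (illustrative)."""
    if provider_options and "prompt_template" in provider_options:
        return provider_options["prompt_template"]          # 1. runtime override
    if "query_prompt_template" in embedding_options:
        return embedding_options["query_prompt_template"]   # 2. new split format
    if "prompt_template" in embedding_options:
        return embedding_options["prompt_template"]         # 3. legacy single template
    return None                                             # 4. no template


# Example: meta.json stored split templates, caller passes no override.
opts = {
    "build_prompt_template": "title: none | text: ",
    "query_prompt_template": "task: search result | query: ",
}
assert resolve_query_template(None, opts) == "task: search result | query: "
```
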
@@ -144,6 +144,18 @@ Examples:
         default=None,
         help="API key for embedding service (defaults to OPENAI_API_KEY)",
     )
+    build_parser.add_argument(
+        "--embedding-prompt-template",
+        type=str,
+        default=None,
+        help="Prompt template to prepend to all texts for embedding (e.g., 'query: ' for search)",
+    )
+    build_parser.add_argument(
+        "--query-prompt-template",
+        type=str,
+        default=None,
+        help="Prompt template for queries (different from build template for task-specific models)",
+    )
     build_parser.add_argument(
         "--force", "-f", action="store_true", help="Force rebuild existing index"
     )

@@ -260,6 +272,12 @@ Examples:
         action="store_true",
         help="Display file paths and metadata in search results",
     )
+    search_parser.add_argument(
+        "--embedding-prompt-template",
+        type=str,
+        default=None,
+        help="Prompt template to prepend to query for embedding (e.g., 'query: ' for search)",
+    )
 
     # Ask command
     ask_parser = subparsers.add_parser("ask", help="Ask questions")

@@ -1162,6 +1180,11 @@ Examples:
                     print(f"Warning: Could not process {file_path}: {e}")
 
         # Load other file types with default reader
+        # Exclude PDFs from code_extensions if they were already processed separately
+        other_file_extensions = code_extensions
+        if should_process_pdfs and ".pdf" in code_extensions:
+            other_file_extensions = [ext for ext in code_extensions if ext != ".pdf"]
+
         try:
             # Create a custom file filter function using our PathSpec
             def file_filter(

@@ -1177,15 +1200,19 @@ Examples:
                 except (ValueError, OSError):
                     return True  # Include files that can't be processed
 
-            other_docs = SimpleDirectoryReader(
-                docs_dir,
-                recursive=True,
-                encoding="utf-8",
-                required_exts=code_extensions,
-                file_extractor={},  # Use default extractors
-                exclude_hidden=not include_hidden,
-                filename_as_id=True,
-            ).load_data(show_progress=True)
+            # Only load other file types if there are extensions to process
+            if other_file_extensions:
+                other_docs = SimpleDirectoryReader(
+                    docs_dir,
+                    recursive=True,
+                    encoding="utf-8",
+                    required_exts=other_file_extensions,
+                    file_extractor={},  # Use default extractors
+                    exclude_hidden=not include_hidden,
+                    filename_as_id=True,
+                ).load_data(show_progress=True)
+            else:
+                other_docs = []
 
             # Filter documents after loading based on gitignore rules
             filtered_docs = []

@@ -1398,6 +1425,14 @@ Examples:
         resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
         if resolved_embedding_key:
             embedding_options["api_key"] = resolved_embedding_key
+        if args.query_prompt_template:
+            # New format: separate templates
+            if args.embedding_prompt_template:
+                embedding_options["build_prompt_template"] = args.embedding_prompt_template
+            embedding_options["query_prompt_template"] = args.query_prompt_template
+        elif args.embedding_prompt_template:
+            # Old format: single template (backward compat)
+            embedding_options["prompt_template"] = args.embedding_prompt_template
 
         builder = LeannBuilder(
             backend_name=args.backend_name,

@@ -1519,6 +1554,11 @@ Examples:
                 print("Invalid input. Aborting search.")
                 return
 
+        # Build provider_options for runtime override
+        provider_options = {}
+        if args.embedding_prompt_template:
+            provider_options["prompt_template"] = args.embedding_prompt_template
+
         searcher = LeannSearcher(index_path=index_path)
         results = searcher.search(
             query,

@@ -1528,6 +1568,7 @@ Examples:
             prune_ratio=args.prune_ratio,
             recompute_embeddings=args.recompute_embeddings,
             pruning_strategy=args.pruning_strategy,
+            provider_options=provider_options if provider_options else None,
         )
 
         print(f"Search results for '{query}' (top {len(results)}):")

@@ -4,8 +4,10 @@ Consolidates all embedding computation logic using SentenceTransformer
 Preserves all optimization parameters to ensure performance
 """
 
+import json
 import logging
 import os
+import subprocess
 import time
 from typing import Any, Optional
 

@@ -40,6 +42,11 @@ EMBEDDING_MODEL_LIMITS = {
     "text-embedding-ada-002": 8192,
 }
 
+# Runtime cache for dynamically discovered token limits
+# Key: (model_name, base_url), Value: token_limit
+# Prevents repeated SDK/API calls for the same model
+_token_limit_cache: dict[tuple[str, str], int] = {}
+
 
 def get_model_token_limit(
     model_name: str,

@@ -49,6 +56,7 @@ def get_model_token_limit(
     """
     Get token limit for a given embedding model.
     Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others.
+    Caches discovered limits to prevent repeated API/SDK calls.
 
     Args:
         model_name: Name of the embedding model

@@ -58,12 +66,33 @@ def get_model_token_limit(
     Returns:
         Token limit for the model in tokens
     """
+    # Check cache first to avoid repeated SDK/API calls
+    cache_key = (model_name, base_url or "")
+    if cache_key in _token_limit_cache:
+        cached_limit = _token_limit_cache[cache_key]
+        logger.debug(f"Using cached token limit for {model_name}: {cached_limit}")
+        return cached_limit
+
     # Try Ollama dynamic discovery if base_url provided
     if base_url:
         # Detect Ollama servers by port or "ollama" in URL
         if "11434" in base_url or "ollama" in base_url.lower():
             limit = _query_ollama_context_limit(model_name, base_url)
             if limit:
+                _token_limit_cache[cache_key] = limit
+                return limit
+
+        # Try LM Studio SDK discovery
+        if "1234" in base_url or "lmstudio" in base_url.lower() or "lm.studio" in base_url.lower():
+            # Convert HTTP to WebSocket URL
+            ws_url = base_url.replace("https://", "wss://").replace("http://", "ws://")
+            # Remove /v1 suffix if present
+            if ws_url.endswith("/v1"):
+                ws_url = ws_url[:-3]
+
+            limit = _query_lmstudio_context_limit(model_name, ws_url)
+            if limit:
+                _token_limit_cache[cache_key] = limit
                 return limit
 
     # Fallback to known model registry with version handling (from PR #154)

@@ -72,19 +101,25 @@ def get_model_token_limit(
 
     # Check exact match first
     if model_name in EMBEDDING_MODEL_LIMITS:
-        return EMBEDDING_MODEL_LIMITS[model_name]
+        limit = EMBEDDING_MODEL_LIMITS[model_name]
+        _token_limit_cache[cache_key] = limit
+        return limit
 
     # Check base name match
     if base_model_name in EMBEDDING_MODEL_LIMITS:
-        return EMBEDDING_MODEL_LIMITS[base_model_name]
+        limit = EMBEDDING_MODEL_LIMITS[base_model_name]
+        _token_limit_cache[cache_key] = limit
+        return limit
 
     # Check partial matches for common patterns
-    for known_model, limit in EMBEDDING_MODEL_LIMITS.items():
+    for known_model, registry_limit in EMBEDDING_MODEL_LIMITS.items():
         if known_model in base_model_name or base_model_name in known_model:
-            return limit
+            _token_limit_cache[cache_key] = registry_limit
+            return registry_limit
 
     # Default fallback
     logger.warning(f"Unknown model '{model_name}', using default {default} token limit")
+    _token_limit_cache[cache_key] = default
     return default
 
 

@@ -185,6 +220,91 @@ def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int]
     return None
 
 
+def _query_lmstudio_context_limit(model_name: str, base_url: str) -> Optional[int]:
+    """
+    Query LM Studio SDK for model context length via Node.js subprocess.
+
+    Args:
+        model_name: Name of the LM Studio model
+        base_url: Base URL of the LM Studio server (WebSocket format, e.g., "ws://localhost:1234")
+
+    Returns:
+        Context limit in tokens if found, None otherwise
+    """
+    # Inline JavaScript using @lmstudio/sdk
+    # Note: Load model temporarily for metadata, then unload to respect JIT auto-evict
+    js_code = f"""
+const {{ LMStudioClient }} = require('@lmstudio/sdk');
+(async () => {{
+  try {{
+    const client = new LMStudioClient({{ baseUrl: '{base_url}' }});
+    const model = await client.embedding.load('{model_name}', {{ verbose: false }});
+    const contextLength = await model.getContextLength();
+    await model.unload(); // Unload immediately to respect JIT auto-evict settings
+    console.log(JSON.stringify({{ contextLength, identifier: '{model_name}' }}));
+  }} catch (error) {{
+    console.error(JSON.stringify({{ error: error.message }}));
+    process.exit(1);
+  }}
+}})();
+"""
+
+    try:
+        # Set NODE_PATH to include global modules for @lmstudio/sdk resolution
+        env = os.environ.copy()
+
+        # Try to get npm global root (works with nvm, brew node, etc.)
+        try:
+            npm_root = subprocess.run(
+                ["npm", "root", "-g"],
+                capture_output=True,
+                text=True,
+                timeout=5,
+            )
+            if npm_root.returncode == 0:
+                global_modules = npm_root.stdout.strip()
+                # Append to existing NODE_PATH if present
+                existing_node_path = env.get("NODE_PATH", "")
+                env["NODE_PATH"] = (
+                    f"{global_modules}:{existing_node_path}"
+                    if existing_node_path
+                    else global_modules
+                )
+        except Exception:
+            # If npm not available, continue with existing NODE_PATH
+            pass
+
+        result = subprocess.run(
+            ["node", "-e", js_code],
+            capture_output=True,
+            text=True,
+            timeout=10,
+            env=env,
+        )
+
+        if result.returncode != 0:
+            logger.debug(f"LM Studio SDK error: {result.stderr}")
+            return None
+
+        data = json.loads(result.stdout)
+        context_length = data.get("contextLength")
+
+        if context_length and context_length > 0:
+            logger.info(f"LM Studio SDK detected {model_name} context length: {context_length}")
+            return context_length
+
+    except FileNotFoundError:
+        logger.debug("Node.js not found - install Node.js for LM Studio SDK features")
+    except subprocess.TimeoutExpired:
+        logger.debug("LM Studio SDK query timeout")
+    except json.JSONDecodeError:
+        logger.debug("LM Studio SDK returned invalid JSON")
+    except Exception as e:
+        logger.debug(f"LM Studio SDK query failed: {e}")
+
+    return None
+
+
 # Global model cache to avoid repeated loading
 _model_cache: dict[str, Any] = {}
 

@@ -232,6 +352,7 @@ def compute_embeddings(
             model_name,
             base_url=provider_options.get("base_url"),
             api_key=provider_options.get("api_key"),
+            provider_options=provider_options,
         )
     elif mode == "mlx":
         return compute_embeddings_mlx(texts, model_name)

@@ -241,6 +362,7 @@ def compute_embeddings(
             model_name,
             is_build=is_build,
             host=provider_options.get("host"),
+            provider_options=provider_options,
         )
     elif mode == "gemini":
         return compute_embeddings_gemini(texts, model_name, is_build=is_build)

@@ -579,6 +701,7 @@ def compute_embeddings_openai(
     model_name: str,
     base_url: Optional[str] = None,
     api_key: Optional[str] = None,
+    provider_options: Optional[dict[str, Any]] = None,
 ) -> np.ndarray:
     # TODO: @yichuan-w add progress bar only in build mode
     """Compute embeddings using OpenAI API"""

@@ -597,26 +720,40 @@ def compute_embeddings_openai(
             f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
         )
 
-    resolved_base_url = resolve_openai_base_url(base_url)
-    resolved_api_key = resolve_openai_api_key(api_key)
+    # Extract base_url and api_key from provider_options if not provided directly
+    provider_options = provider_options or {}
+    effective_base_url = base_url or provider_options.get("base_url")
+    effective_api_key = api_key or provider_options.get("api_key")
+
+    resolved_base_url = resolve_openai_base_url(effective_base_url)
+    resolved_api_key = resolve_openai_api_key(effective_api_key)
 
     if not resolved_api_key:
         raise RuntimeError("OPENAI_API_KEY environment variable not set")
 
-    # Cache OpenAI client
-    cache_key = f"openai_client::{resolved_base_url}"
-    if cache_key in _model_cache:
-        client = _model_cache[cache_key]
-    else:
-        client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
-        _model_cache[cache_key] = client
-        logger.info("OpenAI client cached")
+    # Create OpenAI client
+    client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
 
     logger.info(
         f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
     )
     print(f"len of texts: {len(texts)}")
 
+    # Apply prompt template if provided
+    # Priority: build_prompt_template (new format) > prompt_template (old format)
+    prompt_template = provider_options.get("build_prompt_template") or provider_options.get(
+        "prompt_template"
+    )
+
+    if prompt_template:
+        logger.warning(f"Applying prompt template: '{prompt_template}'")
+        texts = [f"{prompt_template}{text}" for text in texts]
+
+    # Query token limit and apply truncation
+    token_limit = get_model_token_limit(model_name, base_url=effective_base_url)
+    logger.info(f"Using token limit: {token_limit} for model '{model_name}'")
+    texts = truncate_to_token_limit(texts, token_limit)
+
     # OpenAI has limits on batch size and input length
     max_batch_size = 800  # Conservative batch size because the token limit is 300K
     all_embeddings = []

@@ -647,7 +784,15 @@ def compute_embeddings_openai(
         try:
             response = client.embeddings.create(model=model_name, input=batch_texts)
             batch_embeddings = [embedding.embedding for embedding in response.data]
-            all_embeddings.extend(batch_embeddings)
+
+            # Verify we got the expected number of embeddings
+            if len(batch_embeddings) != len(batch_texts):
+                logger.warning(
+                    f"Expected {len(batch_texts)} embeddings but got {len(batch_embeddings)}"
+                )
+
+            # Only take the number of embeddings that match the batch size
+            all_embeddings.extend(batch_embeddings[: len(batch_texts)])
         except Exception as e:
             logger.error(f"Batch {i} failed: {e}")
             raise

@@ -737,6 +882,7 @@ def compute_embeddings_ollama(
     model_name: str,
     is_build: bool = False,
     host: Optional[str] = None,
+    provider_options: Optional[dict[str, Any]] = None,
 ) -> np.ndarray:
     """
     Compute embeddings using Ollama API with true batch processing.

@@ -749,6 +895,7 @@ def compute_embeddings_ollama(
         model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
         is_build: Whether this is a build operation (shows progress bar)
         host: Ollama host URL (defaults to environment or http://localhost:11434)
+        provider_options: Optional provider-specific options (e.g., prompt_template)
 
     Returns:
         Normalized embeddings array, shape: (len(texts), embedding_dim)

@@ -885,6 +1032,17 @@ def compute_embeddings_ollama(
 
     logger.info(f"Using batch size: {batch_size} for true batch processing")
 
+    # Apply prompt template if provided
+    provider_options = provider_options or {}
+    # Priority: build_prompt_template (new format) > prompt_template (old format)
+    prompt_template = provider_options.get("build_prompt_template") or provider_options.get(
+        "prompt_template"
+    )
+
+    if prompt_template:
+        logger.warning(f"Applying prompt template: '{prompt_template}'")
+        texts = [f"{prompt_template}{text}" for text in texts]
+
     # Get model token limit and apply truncation before batching
     token_limit = get_model_token_limit(model_name, base_url=resolved_host)
     logger.info(f"Model '{model_name}' token limit: {token_limit}")

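One practical effect of the new `(model_name, base_url)` cache is that a build with many batches spawns the Node.js bridge at most once per model/server pair. A rough usage sketch (the import path is assumed from the function and cache names in this diff; adjust to the actual module):

```python
from leann.embedding_compute import get_model_token_limit  # module path assumed

url = "http://localhost:1234/v1"  # LM Studio's OpenAI-compatible endpoint
model = "text-embedding-nomic-embed-text-v1.5"

# First call: detects an LM Studio URL, shells out to `node -e ...` via @lmstudio/sdk,
# or falls back to the static registry / 2048 default if the SDK is missing.
first = get_model_token_limit(model, base_url=url)

# Subsequent calls with the same (model, base_url) are served from _token_limit_cache,
# so per-batch truncation during a build never re-queries the SDK.
second = get_model_token_limit(model, base_url=url)
assert first == second
```
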
@@ -77,6 +77,7 @@ class LeannBackendSearcherInterface(ABC):
         query: str,
         use_server_if_available: bool = True,
         zmq_port: Optional[int] = None,
+        query_template: Optional[str] = None,
     ) -> np.ndarray:
         """Compute embedding for a query string
 

@@ -84,6 +85,7 @@ class LeannBackendSearcherInterface(ABC):
             query: The query string to embed
             zmq_port: ZMQ port for embedding server
             use_server_if_available: Whether to try using embedding server first
+            query_template: Optional prompt template to prepend to query
 
         Returns:
             Query embedding as numpy array with shape (1, D)

@@ -33,6 +33,8 @@ def autodiscover_backends():
     discovered_backends = []
     for dist in importlib.metadata.distributions():
         dist_name = dist.metadata["name"]
+        if dist_name is None:
+            continue
         if dist_name.startswith("leann-backend-"):
             backend_module_name = dist_name.replace("-", "_")
             discovered_backends.append(backend_module_name)

@@ -71,6 +71,15 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             or "mips"
         )
 
+        # Filter out ALL prompt templates from provider_options during search
+        # Templates are applied in compute_query_embedding (line 109-110) BEFORE server call
+        # The server should never apply templates during search to avoid double-templating
+        search_provider_options = {
+            k: v
+            for k, v in self.embedding_options.items()
+            if k not in ("build_prompt_template", "query_prompt_template", "prompt_template")
+        }
+
         server_started, actual_port = self.embedding_server_manager.start_server(
             port=port,
             model_name=self.embedding_model,

@@ -78,7 +87,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             passages_file=passages_source_file,
             distance_metric=distance_metric,
             enable_warmup=kwargs.get("enable_warmup", False),
-            provider_options=self.embedding_options,
+            provider_options=search_provider_options,
         )
         if not server_started:
             raise RuntimeError(f"Failed to start embedding server on port {actual_port}")

@@ -90,6 +99,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         query: str,
         use_server_if_available: bool = True,
         zmq_port: int = 5557,
+        query_template: Optional[str] = None,
     ) -> np.ndarray:
         """
         Compute embedding for a query string.

@@ -98,10 +108,16 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             query: The query string to embed
             zmq_port: ZMQ port for embedding server
             use_server_if_available: Whether to try using embedding server first
+            query_template: Optional prompt template to prepend to query
 
         Returns:
             Query embedding as numpy array
         """
+        # Apply query template BEFORE any computation path
+        # This ensures template is applied consistently for both server and fallback paths
+        if query_template:
+            query = f"{query_template}{query}"
+
         # Try to use embedding server if available and requested
         if use_server_if_available:
             try:

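The filtering above, together with the early prepend in `compute_query_embedding`, guarantees the template is applied exactly once per query. A toy illustration of the double-templating problem it prevents:

```python
# Illustrative only: what double-templating would do to the text that gets embedded.
template = "task: search result | query: "
query = "What is quantum computing?"

applied_once = f"{template}{query}"          # what compute_query_embedding sends
applied_twice = f"{template}{applied_once}"  # what a template-aware server would produce
                                             # if templates were not filtered out

# The doubled prefix changes the input text, so its embedding would no longer line up
# with the single-templated documents stored in the index.
assert applied_once != applied_twice
print(applied_twice)
# task: search result | query: task: search result | query: What is quantum computing?
```
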
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "leann"
-version = "0.3.4"
+version = "0.3.5"
 description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
 readme = "README.md"
 requires-python = ">=3.9"

@@ -69,7 +69,8 @@ diskann = [
 # Add a new optional dependency group for document processing
 documents = [
     "beautifulsoup4>=4.13.0",  # For HTML parsing
-    "python-docx>=0.8.11",  # For Word documents
+    "python-docx>=0.8.11",  # For Word documents (creating/editing)
+    "docx2txt>=0.9",  # For Word documents (text extraction)
     "openpyxl>=3.1.0",  # For Excel files
     "pandas>=2.2.0",  # For data processing
 ]

@@ -164,6 +165,7 @@ python_functions = ["test_*"]
 markers = [
     "slow: marks tests as slow (deselect with '-m \"not slow\"')",
     "openai: marks tests that require OpenAI API key",
+    "integration: marks tests that require live services (Ollama, LM Studio, etc.)",
 ]
 timeout = 300  # Reduced from 600s (10min) to 300s (5min) for CI safety
 addopts = [

@@ -36,6 +36,14 @@ Tests DiskANN graph partitioning functionality:
 - Includes performance comparison between DiskANN (with partition) and HNSW
 - **Note**: These tests are skipped in CI due to hardware requirements and computation time
 
+### `test_prompt_template_e2e.py`
+Integration tests for prompt template feature with live embedding services:
+- Tests prompt template prepending with EmbeddingGemma (OpenAI-compatible API via LM Studio)
+- Tests hybrid token limit discovery (Ollama dynamic detection, registry fallback, default)
+- Tests LM Studio SDK bridge for automatic context length detection (requires Node.js + @lmstudio/sdk)
+- **Note**: These tests require live services (LM Studio, Ollama) and are marked with `@pytest.mark.integration`
+- **Important**: Prompt templates are ONLY for EmbeddingGemma and similar task-specific models, NOT regular embedding models
+
 ## Running Tests
 
 ### Install test dependencies:

@@ -66,6 +74,12 @@ pytest tests/ -m "not openai"
 # Skip slow tests
 pytest tests/ -m "not slow"
 
+# Skip integration tests (that require live services)
+pytest tests/ -m "not integration"
+
+# Run only integration tests (requires LM Studio or Ollama running)
+pytest tests/test_prompt_template_e2e.py -v -s
+
 # Run DiskANN partition tests (requires local machine, not CI)
 pytest tests/test_diskann_partition.py
 ```

@@ -101,6 +115,20 @@ The `pytest.ini` file configures:
 - Custom markers for slow and OpenAI tests
 - Verbose output with short tracebacks
 
+### Integration Test Prerequisites
+
+Integration tests (`test_prompt_template_e2e.py`) require live services:
+
+**Required:**
+- LM Studio running at `http://localhost:1234` with EmbeddingGemma model loaded
+
+**Optional:**
+- Ollama running at `http://localhost:11434` for token limit detection tests
+- Node.js + @lmstudio/sdk installed (`npm install -g @lmstudio/sdk`) for SDK bridge tests
+
+Tests gracefully skip if services are unavailable.
+
 ### Known Issues
 
 - OpenAI tests are automatically skipped if no API key is provided
+- Integration tests require live embedding services and may fail due to proxy settings (set `unset ALL_PROXY all_proxy` if needed)

tests/test_cli_prompt_template.py (new file, 533 additions)

@@ -0,0 +1,533 @@
+"""
+Tests for CLI argument integration of --embedding-prompt-template.
+
+These tests verify that:
+1. The --embedding-prompt-template flag is properly registered on build and search commands
+2. The template value flows from CLI args to embedding_options dict
+3. The template is passed through to compute_embeddings() function
+4. Default behavior (no flag) is handled correctly
+"""
+
+from unittest.mock import Mock, patch
+
+from leann.cli import LeannCLI
+
+
+class TestCLIPromptTemplateArgument:
+    """Tests for --embedding-prompt-template on build and search commands."""
+
+    def test_commands_accept_prompt_template_argument(self):
+        """Verify that build and search parsers accept --embedding-prompt-template flag."""
+        cli = LeannCLI()
+        parser = cli.create_parser()
+        template_value = "search_query: "
+
+        # Test build command
+        build_args = parser.parse_args(
+            [
+                "build",
+                "test-index",
+                "--docs",
+                "/tmp/test-docs",
+                "--embedding-prompt-template",
+                template_value,
+            ]
+        )
+        assert build_args.command == "build"
+        assert hasattr(build_args, "embedding_prompt_template"), (
+            "build command should have embedding_prompt_template attribute"
+        )
+        assert build_args.embedding_prompt_template == template_value
+
+        # Test search command
+        search_args = parser.parse_args(
+            ["search", "test-index", "my query", "--embedding-prompt-template", template_value]
+        )
+        assert search_args.command == "search"
+        assert hasattr(search_args, "embedding_prompt_template"), (
+            "search command should have embedding_prompt_template attribute"
+        )
+        assert search_args.embedding_prompt_template == template_value
+
+    def test_commands_default_to_none(self):
+        """Verify default value is None when flag not provided (backward compatibility)."""
+        cli = LeannCLI()
+        parser = cli.create_parser()
+
+        # Test build command default
+        build_args = parser.parse_args(["build", "test-index", "--docs", "/tmp/test-docs"])
+        assert hasattr(build_args, "embedding_prompt_template"), (
+            "build command should have embedding_prompt_template attribute"
+        )
+        assert build_args.embedding_prompt_template is None, (
+            "Build default value should be None when flag not provided"
+        )
+
+        # Test search command default
+        search_args = parser.parse_args(["search", "test-index", "my query"])
+        assert hasattr(search_args, "embedding_prompt_template"), (
+            "search command should have embedding_prompt_template attribute"
+        )
+        assert search_args.embedding_prompt_template is None, (
+            "Search default value should be None when flag not provided"
+        )
+
+
+class TestBuildCommandPromptTemplateArgumentExtras:
+    """Additional build-specific tests for prompt template argument."""
+
+    def test_build_command_prompt_template_with_multiword_value(self):
+        """
+        Verify that template values with spaces are handled correctly.
+
+        Templates like "search_document: " or "Represent this sentence for searching: "
+        should be accepted as a single string argument.
+        """
+        cli = LeannCLI()
+        parser = cli.create_parser()
+
+        template = "Represent this sentence for searching: "
+        args = parser.parse_args(
+            [
+                "build",
+                "test-index",
+                "--docs",
+                "/tmp/test-docs",
+                "--embedding-prompt-template",
+                template,
+            ]
+        )
+
+        assert args.embedding_prompt_template == template
+
+
+class TestPromptTemplateStoredInEmbeddingOptions:
+    """Tests for template storage in embedding_options dict."""
+
+    @patch("leann.cli.LeannBuilder")
+    def test_prompt_template_stored_in_embedding_options_on_build(
+        self, mock_builder_class, tmp_path
+    ):
+        """
+        Verify that when --embedding-prompt-template is provided to build command,
+        the value is stored in embedding_options dict passed to LeannBuilder.
+
+        This test will fail because the CLI doesn't currently process this argument
+        and add it to embedding_options.
+        """
+        # Setup mocks
+        mock_builder = Mock()
+        mock_builder_class.return_value = mock_builder
+
+        # Create CLI and run build command
+        cli = LeannCLI()
+
+        # Mock load_documents to return a document so builder is created
+        cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
+
+        parser = cli.create_parser()
+
+        template = "search_query: "
+        args = parser.parse_args(
+            [
+                "build",
+                "test-index",
+                "--docs",
+                str(tmp_path),
+                "--embedding-prompt-template",
+                template,
+                "--force",  # Force rebuild to ensure LeannBuilder is called
+            ]
+        )
+
+        # Run the build command
+        import asyncio
+
+        asyncio.run(cli.build_index(args))
+
+        # Check that LeannBuilder was called with embedding_options containing prompt_template
+        call_kwargs = mock_builder_class.call_args.kwargs
+        assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
+
+        embedding_options = call_kwargs["embedding_options"]
+        assert embedding_options is not None, (
+            "embedding_options should not be None when template provided"
+        )
+        assert "prompt_template" in embedding_options, (
+            "embedding_options should contain 'prompt_template' key"
+        )
+        assert embedding_options["prompt_template"] == template, (
+            f"Template should be '{template}', got {embedding_options.get('prompt_template')}"
+        )
+
+    @patch("leann.cli.LeannBuilder")
+    def test_prompt_template_not_in_options_when_not_provided(self, mock_builder_class, tmp_path):
+        """
+        Verify that when --embedding-prompt-template is NOT provided,
+        embedding_options either doesn't have the key or it's None.
+
+        This ensures we don't pass empty/None values unnecessarily.
+        """
+        # Setup mocks
+        mock_builder = Mock()
+        mock_builder_class.return_value = mock_builder
+
+        cli = LeannCLI()
+
+        # Mock load_documents to return a document so builder is created
+        cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
+
+        parser = cli.create_parser()
+
+        args = parser.parse_args(
+            [
+                "build",
+                "test-index",
+                "--docs",
+                str(tmp_path),
+                "--force",  # Force rebuild to ensure LeannBuilder is called
+            ]
+        )
+
+        import asyncio
+
+        asyncio.run(cli.build_index(args))
+
+        # Check that if embedding_options is passed, it doesn't have prompt_template
+        call_kwargs = mock_builder_class.call_args.kwargs
+        if call_kwargs.get("embedding_options"):
+            embedding_options = call_kwargs["embedding_options"]
+            # Either the key shouldn't exist, or it should be None
+            assert (
+                "prompt_template" not in embedding_options
+                or embedding_options["prompt_template"] is None
+            ), "prompt_template should not be set when flag not provided"
+
+    # R1 Tests: Build-time separate template storage
+    @patch("leann.cli.LeannBuilder")
+    def test_build_stores_separate_templates(self, mock_builder_class, tmp_path):
+        """
+        R1 Test 1: Verify that when both --embedding-prompt-template and
+        --query-prompt-template are provided to build command, both values
+        are stored separately in embedding_options dict as build_prompt_template
+        and query_prompt_template.
+
+        This test will fail because:
+        1. CLI doesn't accept --query-prompt-template flag yet
+        2. CLI doesn't store templates as separate build_prompt_template and
+           query_prompt_template keys
+
+        Expected behavior after implementation:
+        - .meta.json contains: {"embedding_options": {
+            "build_prompt_template": "doc: ",
+            "query_prompt_template": "query: "
+          }}
+        """
|
# Setup mocks
|
||||||
|
mock_builder = Mock()
|
||||||
|
mock_builder_class.return_value = mock_builder
|
||||||
|
|
||||||
|
cli = LeannCLI()
|
||||||
|
|
||||||
|
# Mock load_documents to return a document so builder is created
|
||||||
|
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||||
|
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
build_template = "doc: "
|
||||||
|
query_template = "query: "
|
||||||
|
args = parser.parse_args(
|
||||||
|
[
|
||||||
|
"build",
|
||||||
|
"test-index",
|
||||||
|
"--docs",
|
||||||
|
str(tmp_path),
|
||||||
|
"--embedding-prompt-template",
|
||||||
|
build_template,
|
||||||
|
"--query-prompt-template",
|
||||||
|
query_template,
|
||||||
|
"--force",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the build command
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(cli.build_index(args))
|
||||||
|
|
||||||
|
# Check that LeannBuilder was called with separate template keys
|
||||||
|
call_kwargs = mock_builder_class.call_args.kwargs
|
||||||
|
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
|
||||||
|
|
||||||
|
embedding_options = call_kwargs["embedding_options"]
|
||||||
|
assert embedding_options is not None, (
|
||||||
|
"embedding_options should not be None when templates provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "build_prompt_template" in embedding_options, (
|
||||||
|
"embedding_options should contain 'build_prompt_template' key"
|
||||||
|
)
|
||||||
|
assert embedding_options["build_prompt_template"] == build_template, (
|
||||||
|
f"build_prompt_template should be '{build_template}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "query_prompt_template" in embedding_options, (
|
||||||
|
"embedding_options should contain 'query_prompt_template' key"
|
||||||
|
)
|
||||||
|
assert embedding_options["query_prompt_template"] == query_template, (
|
||||||
|
f"query_prompt_template should be '{query_template}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Old key should NOT be present when using new separate template format
|
||||||
|
assert "prompt_template" not in embedding_options, (
|
||||||
|
"Old 'prompt_template' key should not be present with separate templates"
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("leann.cli.LeannBuilder")
|
||||||
|
def test_build_backward_compat_single_template(self, mock_builder_class, tmp_path):
|
||||||
|
"""
|
||||||
|
R1 Test 2: Verify backward compatibility - when only
|
||||||
|
--embedding-prompt-template is provided (old behavior), it should
|
||||||
|
still be stored as 'prompt_template' in embedding_options.
|
||||||
|
|
||||||
|
This ensures existing workflows continue to work unchanged.
|
||||||
|
|
||||||
|
This test currently passes because it matches existing behavior, but it
|
||||||
|
documents the requirement that this behavior must be preserved after
|
||||||
|
implementing the separate template feature.
|
||||||
|
|
||||||
|
Expected behavior:
|
||||||
|
- .meta.json contains: {"embedding_options": {"prompt_template": "prompt: "}}
|
||||||
|
- No build_prompt_template or query_prompt_template keys
|
||||||
|
"""
|
||||||
|
# Setup mocks
|
||||||
|
mock_builder = Mock()
|
||||||
|
mock_builder_class.return_value = mock_builder
|
||||||
|
|
||||||
|
cli = LeannCLI()
|
||||||
|
|
||||||
|
# Mock load_documents to return a document so builder is created
|
||||||
|
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||||
|
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
template = "prompt: "
|
||||||
|
args = parser.parse_args(
|
||||||
|
[
|
||||||
|
"build",
|
||||||
|
"test-index",
|
||||||
|
"--docs",
|
||||||
|
str(tmp_path),
|
||||||
|
"--embedding-prompt-template",
|
||||||
|
template,
|
||||||
|
"--force",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the build command
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(cli.build_index(args))
|
||||||
|
|
||||||
|
# Check that LeannBuilder was called with old format
|
||||||
|
call_kwargs = mock_builder_class.call_args.kwargs
|
||||||
|
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
|
||||||
|
|
||||||
|
embedding_options = call_kwargs["embedding_options"]
|
||||||
|
assert embedding_options is not None, (
|
||||||
|
"embedding_options should not be None when template provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "prompt_template" in embedding_options, (
|
||||||
|
"embedding_options should contain old 'prompt_template' key for backward compat"
|
||||||
|
)
|
||||||
|
assert embedding_options["prompt_template"] == template, (
|
||||||
|
f"prompt_template should be '{template}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# New keys should NOT be present in backward compat mode
|
||||||
|
assert "build_prompt_template" not in embedding_options, (
|
||||||
|
"build_prompt_template should not be present with single template flag"
|
||||||
|
)
|
||||||
|
assert "query_prompt_template" not in embedding_options, (
|
||||||
|
"query_prompt_template should not be present with single template flag"
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("leann.cli.LeannBuilder")
|
||||||
|
def test_build_no_templates(self, mock_builder_class, tmp_path):
|
||||||
|
"""
|
||||||
|
R1 Test 3: Verify that when no template flags are provided,
|
||||||
|
embedding_options has no prompt template keys.
|
||||||
|
|
||||||
|
This ensures clean defaults and no unnecessary keys in .meta.json.
|
||||||
|
|
||||||
|
This test currently passes because it matches existing behavior, but it
|
||||||
|
documents the requirement that this behavior must be preserved after
|
||||||
|
implementing the separate template feature.
|
||||||
|
|
||||||
|
Expected behavior:
|
||||||
|
- .meta.json has no prompt_template, build_prompt_template, or
|
||||||
|
query_prompt_template keys (or embedding_options is empty/None)
|
||||||
|
"""
|
||||||
|
# Setup mocks
|
||||||
|
mock_builder = Mock()
|
||||||
|
mock_builder_class.return_value = mock_builder
|
||||||
|
|
||||||
|
cli = LeannCLI()
|
||||||
|
|
||||||
|
# Mock load_documents to return a document so builder is created
|
||||||
|
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||||
|
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
args = parser.parse_args(["build", "test-index", "--docs", str(tmp_path), "--force"])
|
||||||
|
|
||||||
|
# Run the build command
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(cli.build_index(args))
|
||||||
|
|
||||||
|
# Check that no template keys are present
|
||||||
|
call_kwargs = mock_builder_class.call_args.kwargs
|
||||||
|
if call_kwargs.get("embedding_options"):
|
||||||
|
embedding_options = call_kwargs["embedding_options"]
|
||||||
|
|
||||||
|
# None of the template keys should be present
|
||||||
|
assert "prompt_template" not in embedding_options, (
|
||||||
|
"prompt_template should not be present when no flags provided"
|
||||||
|
)
|
||||||
|
assert "build_prompt_template" not in embedding_options, (
|
||||||
|
"build_prompt_template should not be present when no flags provided"
|
||||||
|
)
|
||||||
|
assert "query_prompt_template" not in embedding_options, (
|
||||||
|
"query_prompt_template should not be present when no flags provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
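

# --- Illustrative sketch (not part of this test suite or the actual CLI) ---
# The tests above pin down how parsed CLI arguments should be mapped into the
# embedding_options dict handed to LeannBuilder: separate keys when both
# templates are given, the legacy single key otherwise. The helper below is a
# hedged sketch of that mapping; its name and structure are assumptions for
# illustration only, not the project's implementation.
def _embedding_options_from_args_sketch(args):
    """Map parsed CLI args to the embedding_options dict the tests above expect."""
    options = {}
    build_tpl = getattr(args, "embedding_prompt_template", None)
    query_tpl = getattr(args, "query_prompt_template", None)
    if build_tpl and query_tpl:
        # New format: both templates stored under separate keys.
        options["build_prompt_template"] = build_tpl
        options["query_prompt_template"] = query_tpl
    elif build_tpl:
        # Backward-compatible format: single template under the old key.
        options["prompt_template"] = build_tpl
    return options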


class TestPromptTemplateFlowsToComputeEmbeddings:
    """Tests for template flowing through to compute_embeddings function."""

    @patch("leann.api.compute_embeddings")
    def test_prompt_template_flows_to_compute_embeddings_via_provider_options(
        self, mock_compute_embeddings, tmp_path
    ):
        """
        Verify that the prompt template flows from CLI args through LeannBuilder
        to compute_embeddings() function via provider_options parameter.

        This is an integration test that verifies the complete flow:
        CLI → embedding_options → LeannBuilder → compute_embeddings(provider_options)

        This test will fail because:
        1. CLI doesn't capture the argument yet
        2. embedding_options doesn't include prompt_template
        3. LeannBuilder doesn't pass it through to compute_embeddings
        """
        # Mock compute_embeddings to return dummy embeddings as numpy array
        import numpy as np

        mock_compute_embeddings.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)

        # Use real LeannBuilder (not mocked) to test the actual flow
        cli = LeannCLI()

        # Mock load_documents to return a simple document
        cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])

        parser = cli.create_parser()

        template = "search_document: "
        args = parser.parse_args(
            [
                "build",
                "test-index",
                "--docs",
                str(tmp_path),
                "--embedding-prompt-template",
                template,
                "--backend-name",
                "hnsw",  # Use hnsw backend
                "--force",  # Force rebuild to ensure index is created
            ]
        )

        # This should fail because the flow isn't implemented yet
        import asyncio

        asyncio.run(cli.build_index(args))

        # Verify compute_embeddings was called with provider_options containing prompt_template
        assert mock_compute_embeddings.called, "compute_embeddings should have been called"

        # Check the call arguments
        call_kwargs = mock_compute_embeddings.call_args.kwargs
        assert "provider_options" in call_kwargs, (
            "compute_embeddings should receive provider_options parameter"
        )

        provider_options = call_kwargs["provider_options"]
        assert provider_options is not None, "provider_options should not be None"
        assert "prompt_template" in provider_options, (
            "provider_options should contain prompt_template key"
        )
        assert provider_options["prompt_template"] == template, (
            f"Template should be '{template}', got {provider_options.get('prompt_template')}"
        )
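

# --- Illustrative sketch (not part of this test suite) ---------------------
# The flow asserted above is CLI -> embedding_options -> LeannBuilder ->
# compute_embeddings(provider_options). A builder that satisfies it only has
# to hand its stored embedding_options through unchanged; the class and method
# names below are assumptions for illustration, not LeannBuilder's real API.
class _BuilderForwardingSketch:
    """Stores embedding_options at build time and forwards them when embedding."""

    def __init__(self, embedding_options=None):
        self.embedding_options = embedding_options or {}

    def embed(self, compute_embeddings, texts, model_name):
        # Hand the stored options through as provider_options, prompt_template included.
        return compute_embeddings(texts, model_name, provider_options=self.embedding_options)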


class TestPromptTemplateArgumentHelp:
    """Tests for argument help text and documentation."""

    def test_build_command_prompt_template_has_help_text(self):
        """
        Verify that --embedding-prompt-template has descriptive help text.

        Good help text is crucial for CLI usability.
        """
        cli = LeannCLI()
        parser = cli.create_parser()

        # Get the build subparser
        # This is a bit tricky - we need to parse to get the help
        # We'll check that the help includes relevant keywords
        import io
        from contextlib import redirect_stdout

        f = io.StringIO()
        try:
            with redirect_stdout(f):
                parser.parse_args(["build", "--help"])
        except SystemExit:
            pass  # --help causes sys.exit(0)

        help_text = f.getvalue()
        assert "--embedding-prompt-template" in help_text, (
            "Help text should mention --embedding-prompt-template"
        )
        # Check for keywords that should be in the help
        help_lower = help_text.lower()
        assert any(keyword in help_lower for keyword in ["template", "prompt", "prepend"]), (
            "Help text should explain what the prompt template does"
        )

    def test_search_command_prompt_template_has_help_text(self):
        """
        Verify that search command also has help text for --embedding-prompt-template.
        """
        cli = LeannCLI()
        parser = cli.create_parser()

        import io
        from contextlib import redirect_stdout

        f = io.StringIO()
        try:
            with redirect_stdout(f):
                parser.parse_args(["search", "--help"])
        except SystemExit:
            pass  # --help causes sys.exit(0)

        help_text = f.getvalue()
        assert "--embedding-prompt-template" in help_text, (
            "Search help text should mention --embedding-prompt-template"
        )
281
tests/test_embedding_prompt_template.py
Normal file
@@ -0,0 +1,281 @@
"""Unit tests for prompt template prepending in OpenAI embeddings.

This test suite defines the contract for prompt template functionality that allows
users to prepend a consistent prompt to all embedding inputs. These tests verify:

1. Template prepending to all input texts before embedding computation
2. Graceful handling of None/missing provider_options
3. Empty string template behavior (no-op)
4. Logging of template application for observability
5. Template application before token truncation

All tests are written in Red Phase - they should FAIL initially because the
implementation does not exist yet.
"""

from unittest.mock import MagicMock, Mock, patch

import numpy as np
import pytest
from leann.embedding_compute import compute_embeddings_openai
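

# --- Illustrative sketch (not part of this test suite) ---------------------
# The contract exercised below reduces to one transformation ahead of the API
# call: if provider_options carries a non-empty "prompt_template", prepend it
# to every input text; otherwise pass the texts through unchanged. The helper
# below is a hedged sketch of that behavior, not the implementation inside
# leann.embedding_compute.
def _apply_prompt_template_sketch(texts, provider_options):
    """Prepend provider_options['prompt_template'] to every text, if set."""
    template = (provider_options or {}).get("prompt_template") or ""
    if not template:
        return list(texts)
    return [template + text for text in texts]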


class TestPromptTemplatePrepending:
    """Tests for prompt template prepending in compute_embeddings_openai."""

    @pytest.fixture
    def mock_openai_client(self):
        """Create mock OpenAI client that captures input texts."""
        mock_client = MagicMock()

        # Mock the embeddings.create response
        mock_response = Mock()
        mock_response.data = [
            Mock(embedding=[0.1, 0.2, 0.3]),
            Mock(embedding=[0.4, 0.5, 0.6]),
        ]
        mock_client.embeddings.create.return_value = mock_response

        return mock_client

    @pytest.fixture
    def mock_openai_module(self, mock_openai_client, monkeypatch):
        """Mock the openai module to return our mock client."""
        # Mock the API key environment variable
        monkeypatch.setenv("OPENAI_API_KEY", "fake-test-key-for-mocking")

        # openai is imported inside the function, so we need to patch it there
        with patch("openai.OpenAI", return_value=mock_openai_client) as mock_openai:
            yield mock_openai

    def test_prompt_template_prepended_to_all_texts(self, mock_openai_module, mock_openai_client):
        """Verify template is prepended to all input texts.

        When provider_options contains "prompt_template", that template should
        be prepended to every text in the input list before sending to OpenAI API.

        This is the core functionality: the template acts as a consistent prefix
        that provides context or instruction for the embedding model.
        """
        texts = ["First document", "Second document"]
        template = "search_document: "
        provider_options = {"prompt_template": template}

        # Call compute_embeddings_openai with provider_options
        result = compute_embeddings_openai(
            texts=texts,
            model_name="text-embedding-3-small",
            provider_options=provider_options,
        )

        # Verify embeddings.create was called with templated texts
        mock_openai_client.embeddings.create.assert_called_once()
        call_args = mock_openai_client.embeddings.create.call_args

        # Extract the input texts sent to API
        sent_texts = call_args.kwargs["input"]

        # Verify template was prepended to all texts
        assert len(sent_texts) == 2, "Should send same number of texts"
        assert sent_texts[0] == "search_document: First document", (
            "Template should be prepended to first text"
        )
        assert sent_texts[1] == "search_document: Second document", (
            "Template should be prepended to second text"
        )

        # Verify result is valid embeddings array
        assert isinstance(result, np.ndarray)
        assert result.shape == (2, 3), "Should return correct shape"

    def test_template_not_applied_when_missing_or_empty(
        self, mock_openai_module, mock_openai_client
    ):
        """Verify template not applied when provider_options is None, missing key, or empty string.

        This consolidated test covers three scenarios where templates should NOT be applied:
        1. provider_options is None (default behavior)
        2. provider_options exists but missing 'prompt_template' key
        3. prompt_template is explicitly set to empty string ""

        In all cases, texts should be sent to the API unchanged.
        """
        # Scenario 1: None provider_options
        texts = ["Original text one", "Original text two"]
        result = compute_embeddings_openai(
            texts=texts,
            model_name="text-embedding-3-small",
            provider_options=None,
        )
        call_args = mock_openai_client.embeddings.create.call_args
        sent_texts = call_args.kwargs["input"]
        assert sent_texts[0] == "Original text one", (
            "Text should be unchanged with None provider_options"
        )
        assert sent_texts[1] == "Original text two"
        assert isinstance(result, np.ndarray)
        assert result.shape == (2, 3)

        # Reset mock for next scenario
        mock_openai_client.reset_mock()
        mock_response = Mock()
        mock_response.data = [
            Mock(embedding=[0.1, 0.2, 0.3]),
            Mock(embedding=[0.4, 0.5, 0.6]),
        ]
        mock_openai_client.embeddings.create.return_value = mock_response

        # Scenario 2: Missing 'prompt_template' key
        texts = ["Text without template", "Another text"]
        provider_options = {"base_url": "https://api.openai.com/v1"}
        result = compute_embeddings_openai(
            texts=texts,
            model_name="text-embedding-3-small",
            provider_options=provider_options,
        )
        call_args = mock_openai_client.embeddings.create.call_args
        sent_texts = call_args.kwargs["input"]
        assert sent_texts[0] == "Text without template", "Text should be unchanged with missing key"
        assert sent_texts[1] == "Another text"
        assert isinstance(result, np.ndarray)

        # Reset mock for next scenario
        mock_openai_client.reset_mock()
        mock_openai_client.embeddings.create.return_value = mock_response

        # Scenario 3: Empty string template
        texts = ["Text one", "Text two"]
        provider_options = {"prompt_template": ""}
        result = compute_embeddings_openai(
            texts=texts,
            model_name="text-embedding-3-small",
            provider_options=provider_options,
        )
        call_args = mock_openai_client.embeddings.create.call_args
        sent_texts = call_args.kwargs["input"]
        assert sent_texts[0] == "Text one", "Empty template should not modify text"
        assert sent_texts[1] == "Text two"
        assert isinstance(result, np.ndarray)

    def test_prompt_template_with_multiple_batches(self, mock_openai_module, mock_openai_client):
        """Verify template is prepended in all batches when texts exceed batch size.

        OpenAI API has batch size limits. When input texts are split into
        multiple batches, the template should be prepended to texts in every batch.

        This ensures consistency across all API calls.
        """
        # Create many texts that will be split into multiple batches
        texts = [f"Document {i}" for i in range(1000)]
        template = "passage: "
        provider_options = {"prompt_template": template}

        # Mock multiple batch responses
        mock_response = Mock()
        mock_response.data = [Mock(embedding=[0.1, 0.2, 0.3]) for _ in range(1000)]
        mock_openai_client.embeddings.create.return_value = mock_response

        result = compute_embeddings_openai(
            texts=texts,
            model_name="text-embedding-3-small",
            provider_options=provider_options,
        )

        # Verify embeddings.create was called multiple times (batching)
        assert mock_openai_client.embeddings.create.call_count >= 2, (
            "Should make multiple API calls for large text list"
        )

        # Verify template was prepended in ALL batches
        for call in mock_openai_client.embeddings.create.call_args_list:
            sent_texts = call.kwargs["input"]
            for text in sent_texts:
                assert text.startswith(template), (
                    f"All texts in all batches should start with template. Got: {text}"
                )

        # Verify result shape
        assert result.shape[0] == 1000, "Should return embeddings for all texts"

    def test_prompt_template_with_special_characters(self, mock_openai_module, mock_openai_client):
        """Verify template with special characters is handled correctly.

        Templates may contain special characters, Unicode, newlines, etc.
        These should all be prepended correctly without encoding issues.
        """
        texts = ["Document content"]
        # Template with various special characters
        template = "🔍 Search query [EN]: "
        provider_options = {"prompt_template": template}

        result = compute_embeddings_openai(
            texts=texts,
            model_name="text-embedding-3-small",
            provider_options=provider_options,
        )

        # Verify special characters in template were preserved
        call_args = mock_openai_client.embeddings.create.call_args
        sent_texts = call_args.kwargs["input"]

        assert sent_texts[0] == "🔍 Search query [EN]: Document content", (
            "Special characters in template should be preserved"
        )

        assert isinstance(result, np.ndarray)

    def test_prompt_template_integration_with_existing_validation(
        self, mock_openai_module, mock_openai_client
    ):
        """Verify template works with existing input validation.

        compute_embeddings_openai has validation for empty texts and whitespace.
        Template prepending should happen AFTER validation, so validation errors
        are thrown based on original texts, not templated texts.

        This ensures users get clear error messages about their input.
        """
        # Empty text should still raise ValueError even with template
        texts = [""]
        provider_options = {"prompt_template": "prefix: "}

        with pytest.raises(ValueError, match="empty/invalid"):
            compute_embeddings_openai(
                texts=texts,
                model_name="text-embedding-3-small",
                provider_options=provider_options,
            )

    def test_prompt_template_with_api_key_and_base_url(
        self, mock_openai_module, mock_openai_client
    ):
        """Verify template works alongside other provider_options.

        provider_options may contain multiple settings: prompt_template,
        base_url, api_key. All should work together correctly.
        """
        texts = ["Test document"]
        provider_options = {
            "prompt_template": "embed: ",
            "base_url": "https://custom.api.com/v1",
            "api_key": "test-key-123",
        }

        result = compute_embeddings_openai(
            texts=texts,
            model_name="text-embedding-3-small",
            provider_options=provider_options,
        )

        # Verify template was applied
        call_args = mock_openai_client.embeddings.create.call_args
        sent_texts = call_args.kwargs["input"]
        assert sent_texts[0] == "embed: Test document"

        # Verify OpenAI client was created with correct base_url
        mock_openai_module.assert_called()
        client_init_kwargs = mock_openai_module.call_args.kwargs
        assert client_init_kwargs["base_url"] == "https://custom.api.com/v1"
        assert client_init_kwargs["api_key"] == "test-key-123"

        assert isinstance(result, np.ndarray)
315
tests/test_lmstudio_bridge.py
Normal file
@@ -0,0 +1,315 @@
"""Unit tests for LM Studio TypeScript SDK bridge functionality.

This test suite defines the contract for the LM Studio SDK bridge that queries
model context length via Node.js subprocess. These tests verify:

1. Successful SDK query returns context length
2. Graceful fallback when Node.js not installed (FileNotFoundError)
3. Graceful fallback when SDK not installed (npm error)
4. Timeout handling (subprocess.TimeoutExpired)
5. Invalid JSON response handling

All tests are written in Red Phase - they should FAIL initially because the
`_query_lmstudio_context_limit` function does not exist yet.

The function contract:
- Inputs: model_name (str), base_url (str, WebSocket format "ws://localhost:1234")
- Outputs: context_length (int) or None on error
- Requirements:
  1. Call Node.js with inline JavaScript using @lmstudio/sdk
  2. 10-second timeout (accounts for Node.js startup)
  3. Graceful fallback on any error (returns None, doesn't raise)
  4. Parse JSON response with contextLength field
  5. Log errors at debug level (not warning/error)
"""

import subprocess
from unittest.mock import Mock

import pytest

# Try to import the function - if it doesn't exist, tests will fail as expected
try:
    from leann.embedding_compute import _query_lmstudio_context_limit
except ImportError:
    # Function doesn't exist yet (Red Phase) - create a placeholder that will fail
    def _query_lmstudio_context_limit(*args, **kwargs):
        raise NotImplementedError(
            "_query_lmstudio_context_limit not implemented yet - this is the Red Phase"
        )
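

# --- Illustrative sketch (not part of this test suite) ---------------------
# The contract above can be satisfied by shelling out to Node.js with an inline
# script that uses @lmstudio/sdk and prints JSON. The sketch below mirrors the
# behavior the tests expect (10-second timeout, capture_output/text, JSON
# parsing, None on any failure, debug-level logging); the JavaScript SDK calls
# and the helper name are assumptions for illustration only.
def _query_lmstudio_context_limit_sketch(model_name, base_url):
    """Return the model's context length via the Node.js SDK bridge, or None."""
    import json
    import logging

    script = (
        "const { LMStudioClient } = require('@lmstudio/sdk');"
        "(async () => {"
        f"const client = new LMStudioClient({{ baseUrl: '{base_url}' }});"
        f"const model = await client.embedding.model('{model_name}');"
        "const contextLength = await model.getContextLength();"
        "console.log(JSON.stringify({ contextLength }));"
        "})();"
    )
    try:
        result = subprocess.run(
            ["node", "-e", script], capture_output=True, text=True, timeout=10
        )
        if result.returncode != 0:
            return None
        context_length = json.loads(result.stdout).get("contextLength")
        # null/0/missing all fall back to the static registry.
        return context_length if context_length else None
    except Exception as exc:  # FileNotFoundError, TimeoutExpired, JSON errors, ...
        logging.getLogger(__name__).debug("LM Studio SDK bridge failed: %s", exc)
        return None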


class TestLMStudioBridge:
    """Tests for LM Studio TypeScript SDK bridge integration."""

    def test_query_lmstudio_success(self, monkeypatch):
        """Verify successful SDK query returns context length.

        When the Node.js subprocess successfully queries the LM Studio SDK,
        it should return a JSON response with contextLength field. The function
        should parse this and return the integer context length.
        """

        def mock_run(*args, **kwargs):
            # Verify timeout is set to 10 seconds
            assert kwargs.get("timeout") == 10, "Should use 10-second timeout for Node.js startup"

            # Verify capture_output and text=True are set
            assert kwargs.get("capture_output") is True, "Should capture stdout/stderr"
            assert kwargs.get("text") is True, "Should decode output as text"

            # Return successful JSON response
            mock_result = Mock()
            mock_result.returncode = 0
            mock_result.stdout = '{"contextLength": 8192, "identifier": "custom-model"}'
            mock_result.stderr = ""
            return mock_result

        monkeypatch.setattr("subprocess.run", mock_run)

        # Test with typical LM Studio model
        limit = _query_lmstudio_context_limit(
            model_name="custom-model", base_url="ws://localhost:1234"
        )

        assert limit == 8192, "Should return context length from SDK response"

    def test_query_lmstudio_nodejs_not_found(self, monkeypatch):
        """Verify graceful fallback when Node.js not installed.

        When Node.js is not installed, subprocess.run will raise FileNotFoundError.
        The function should catch this and return None (graceful fallback to registry).
        """

        def mock_run(*args, **kwargs):
            raise FileNotFoundError("node: command not found")

        monkeypatch.setattr("subprocess.run", mock_run)

        limit = _query_lmstudio_context_limit(
            model_name="custom-model", base_url="ws://localhost:1234"
        )

        assert limit is None, "Should return None when Node.js not installed"

    def test_query_lmstudio_sdk_not_installed(self, monkeypatch):
        """Verify graceful fallback when @lmstudio/sdk not installed.

        When the SDK npm package is not installed, Node.js will return non-zero
        exit code with error message in stderr. The function should detect this
        and return None.
        """

        def mock_run(*args, **kwargs):
            mock_result = Mock()
            mock_result.returncode = 1
            mock_result.stdout = ""
            mock_result.stderr = (
                "Error: Cannot find module '@lmstudio/sdk'\nRequire stack:\n- /path/to/script.js"
            )
            return mock_result

        monkeypatch.setattr("subprocess.run", mock_run)

        limit = _query_lmstudio_context_limit(
            model_name="custom-model", base_url="ws://localhost:1234"
        )

        assert limit is None, "Should return None when SDK not installed"

    def test_query_lmstudio_timeout(self, monkeypatch):
        """Verify graceful fallback when subprocess times out.

        When the Node.js process takes longer than 10 seconds (e.g., LM Studio
        not responding), subprocess.TimeoutExpired should be raised. The function
        should catch this and return None.
        """

        def mock_run(*args, **kwargs):
            raise subprocess.TimeoutExpired(cmd=["node", "lmstudio_bridge.js"], timeout=10)

        monkeypatch.setattr("subprocess.run", mock_run)

        limit = _query_lmstudio_context_limit(
            model_name="custom-model", base_url="ws://localhost:1234"
        )

        assert limit is None, "Should return None on timeout"

    def test_query_lmstudio_invalid_json(self, monkeypatch):
        """Verify graceful fallback when response is invalid JSON.

        When the subprocess returns malformed JSON (e.g., due to SDK error),
        json.loads will raise ValueError/JSONDecodeError. The function should
        catch this and return None.
        """

        def mock_run(*args, **kwargs):
            mock_result = Mock()
            mock_result.returncode = 0
            mock_result.stdout = "This is not valid JSON"
            mock_result.stderr = ""
            return mock_result

        monkeypatch.setattr("subprocess.run", mock_run)

        limit = _query_lmstudio_context_limit(
            model_name="custom-model", base_url="ws://localhost:1234"
        )

        assert limit is None, "Should return None when JSON parsing fails"

    def test_query_lmstudio_missing_context_length_field(self, monkeypatch):
        """Verify graceful fallback when JSON lacks contextLength field.

        When the SDK returns valid JSON but without the expected contextLength
        field (e.g., error response), the function should return None.
        """

        def mock_run(*args, **kwargs):
            mock_result = Mock()
            mock_result.returncode = 0
            mock_result.stdout = '{"identifier": "test-model", "error": "Model not found"}'
            mock_result.stderr = ""
            return mock_result

        monkeypatch.setattr("subprocess.run", mock_run)

        limit = _query_lmstudio_context_limit(
            model_name="nonexistent-model", base_url="ws://localhost:1234"
        )

        assert limit is None, "Should return None when contextLength field missing"

    def test_query_lmstudio_null_context_length(self, monkeypatch):
        """Verify graceful fallback when contextLength is null.

        When the SDK returns contextLength: null (model couldn't be loaded),
        the function should return None for registry fallback.
        """

        def mock_run(*args, **kwargs):
            mock_result = Mock()
            mock_result.returncode = 0
            mock_result.stdout = '{"contextLength": null, "identifier": "test-model"}'
            mock_result.stderr = ""
            return mock_result

        monkeypatch.setattr("subprocess.run", mock_run)

        limit = _query_lmstudio_context_limit(
            model_name="test-model", base_url="ws://localhost:1234"
        )

        assert limit is None, "Should return None when contextLength is null"

    def test_query_lmstudio_zero_context_length(self, monkeypatch):
        """Verify graceful fallback when contextLength is zero.

        When the SDK returns contextLength: 0 (invalid value), the function
        should return None to trigger registry fallback.
        """

        def mock_run(*args, **kwargs):
            mock_result = Mock()
            mock_result.returncode = 0
            mock_result.stdout = '{"contextLength": 0, "identifier": "test-model"}'
            mock_result.stderr = ""
            return mock_result

        monkeypatch.setattr("subprocess.run", mock_run)

        limit = _query_lmstudio_context_limit(
            model_name="test-model", base_url="ws://localhost:1234"
        )

        assert limit is None, "Should return None when contextLength is zero"

    def test_query_lmstudio_with_custom_port(self, monkeypatch):
        """Verify SDK query works with non-default WebSocket port.

        LM Studio can run on custom ports. The function should pass the
        provided base_url to the Node.js subprocess.
        """

        def mock_run(*args, **kwargs):
            # Verify the base_url argument is passed correctly
            command = args[0] if args else kwargs.get("args", [])
            assert "ws://localhost:8080" in " ".join(command), (
                "Should pass custom port to subprocess"
            )

            mock_result = Mock()
            mock_result.returncode = 0
            mock_result.stdout = '{"contextLength": 4096, "identifier": "custom-model"}'
            mock_result.stderr = ""
            return mock_result

        monkeypatch.setattr("subprocess.run", mock_run)

        limit = _query_lmstudio_context_limit(
            model_name="custom-model", base_url="ws://localhost:8080"
        )

        assert limit == 4096, "Should work with custom WebSocket port"

    @pytest.mark.parametrize(
        "context_length,expected",
        [
            (512, 512),  # Small context
            (2048, 2048),  # Common context
            (8192, 8192),  # Large context
            (32768, 32768),  # Very large context
        ],
    )
    def test_query_lmstudio_various_context_lengths(self, monkeypatch, context_length, expected):
        """Verify SDK query handles various context length values.

        Different models have different context lengths. The function should
        correctly parse and return any positive integer value.
        """

        def mock_run(*args, **kwargs):
            mock_result = Mock()
            mock_result.returncode = 0
            mock_result.stdout = f'{{"contextLength": {context_length}, "identifier": "test"}}'
            mock_result.stderr = ""
            return mock_result

        monkeypatch.setattr("subprocess.run", mock_run)

        limit = _query_lmstudio_context_limit(
            model_name="test-model", base_url="ws://localhost:1234"
        )

        assert limit == expected, f"Should return {expected} for context length {context_length}"

    def test_query_lmstudio_logs_at_debug_level(self, monkeypatch, caplog):
        """Verify errors are logged at DEBUG level, not WARNING/ERROR.

        Following the graceful fallback pattern from Ollama implementation,
        errors should be logged at debug level to avoid alarming users when
        fallback to registry works fine.
        """
        import logging

        caplog.set_level(logging.DEBUG, logger="leann.embedding_compute")

        def mock_run(*args, **kwargs):
            raise FileNotFoundError("node: command not found")

        monkeypatch.setattr("subprocess.run", mock_run)

        _query_lmstudio_context_limit(model_name="test-model", base_url="ws://localhost:1234")

        # Check that debug logging occurred (not warning/error)
        debug_logs = [record for record in caplog.records if record.levelname == "DEBUG"]
        assert len(debug_logs) > 0, "Should log error at DEBUG level"

        # Verify no WARNING or ERROR logs
        warning_or_error_logs = [
            record for record in caplog.records if record.levelname in ["WARNING", "ERROR"]
        ]
        assert len(warning_or_error_logs) == 0, (
            "Should not log at WARNING/ERROR level for expected failures"
        )
400
tests/test_prompt_template_e2e.py
Normal file
@@ -0,0 +1,400 @@
"""End-to-end integration tests for prompt template and token limit features.

These tests verify real-world functionality with live services:
- OpenAI-compatible APIs (OpenAI, LM Studio) with prompt template support
- Ollama with dynamic token limit detection
- Hybrid token limit discovery mechanism

Run with: pytest tests/test_prompt_template_e2e.py -v -s
Skip if services unavailable: pytest tests/test_prompt_template_e2e.py -m "not integration"

Prerequisites:
1. LM Studio running with embedding model: http://localhost:1234
2. [Optional] Ollama running: ollama serve
3. [Optional] Ollama model: ollama pull nomic-embed-text
4. [Optional] Node.js + @lmstudio/sdk for context length detection
"""

import logging
import socket

import numpy as np
import pytest
import requests
from leann.embedding_compute import (
    compute_embeddings_ollama,
    compute_embeddings_openai,
    get_model_token_limit,
)

# Test markers for conditional execution
pytestmark = pytest.mark.integration

logger = logging.getLogger(__name__)


def check_service_available(host: str, port: int, timeout: float = 2.0) -> bool:
    """Check if a service is available on the given host:port."""
    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(timeout)
        result = sock.connect_ex((host, port))
        sock.close()
        return result == 0
    except Exception:
        return False


def check_ollama_available() -> bool:
    """Check if Ollama service is available."""
    if not check_service_available("localhost", 11434):
        return False
    try:
        response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
        return response.status_code == 200
    except Exception:
        return False


def check_lmstudio_available() -> bool:
    """Check if LM Studio service is available."""
    if not check_service_available("localhost", 1234):
        return False
    try:
        response = requests.get("http://localhost:1234/v1/models", timeout=2.0)
        return response.status_code == 200
    except Exception:
        return False


def get_lmstudio_first_model() -> str:
    """Get the first available model from LM Studio."""
    try:
        response = requests.get("http://localhost:1234/v1/models", timeout=5.0)
        data = response.json()
        models = data.get("data", [])
        if models:
            return models[0]["id"]
    except Exception:
        pass
    return None


class TestPromptTemplateOpenAI:
    """End-to-end tests for prompt template with OpenAI-compatible APIs (LM Studio)."""

    @pytest.mark.skipif(
        not check_lmstudio_available(), reason="LM Studio service not available on localhost:1234"
    )
    def test_lmstudio_embedding_with_prompt_template(self):
        """Test prompt templates with LM Studio using OpenAI-compatible API."""
        model_name = get_lmstudio_first_model()
        if not model_name:
            pytest.skip("No models loaded in LM Studio")

        texts = ["artificial intelligence", "machine learning"]
        prompt_template = "search_query: "

        # Get embeddings with prompt template via provider_options
        provider_options = {"prompt_template": prompt_template}
        embeddings = compute_embeddings_openai(
            texts=texts,
            model_name=model_name,
            base_url="http://localhost:1234/v1",
            api_key="lm-studio",  # LM Studio doesn't require real key
            provider_options=provider_options,
        )

        assert embeddings is not None
        assert len(embeddings) == 2
        assert all(isinstance(emb, np.ndarray) for emb in embeddings)
        assert all(len(emb) > 0 for emb in embeddings)

        logger.info(
            f"✓ LM Studio embeddings with prompt template: {len(embeddings)} vectors, {len(embeddings[0])} dimensions"
        )

    @pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
    def test_lmstudio_prompt_template_affects_embeddings(self):
        """Verify that prompt templates actually change embedding values."""
        model_name = get_lmstudio_first_model()
        if not model_name:
            pytest.skip("No models loaded in LM Studio")

        text = "machine learning"
        base_url = "http://localhost:1234/v1"
        api_key = "lm-studio"

        # Get embeddings without template
        embeddings_no_template = compute_embeddings_openai(
            texts=[text],
            model_name=model_name,
            base_url=base_url,
            api_key=api_key,
            provider_options={},
        )

        # Get embeddings with template
        embeddings_with_template = compute_embeddings_openai(
            texts=[text],
            model_name=model_name,
            base_url=base_url,
            api_key=api_key,
            provider_options={"prompt_template": "search_query: "},
        )

        # Embeddings should be different when template is applied
        assert not np.allclose(embeddings_no_template[0], embeddings_with_template[0])

        logger.info("✓ Prompt template changes embedding values as expected")


class TestPromptTemplateOllama:
    """End-to-end tests for prompt template with Ollama."""

    @pytest.mark.skipif(
        not check_ollama_available(), reason="Ollama service not available on localhost:11434"
    )
    def test_ollama_embedding_with_prompt_template(self):
        """Test prompt templates with Ollama using any available embedding model."""
        # Get any available embedding model
        try:
            response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
            models = response.json().get("models", [])

            embedding_models = []
            for model in models:
                name = model["name"]
                base_name = name.split(":")[0]
                if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
                    embedding_models.append(name)

            if not embedding_models:
                pytest.skip("No embedding models available in Ollama")

            model_name = embedding_models[0]

            texts = ["artificial intelligence", "machine learning"]
            prompt_template = "search_query: "

            # Get embeddings with prompt template via provider_options
            provider_options = {"prompt_template": prompt_template}
            embeddings = compute_embeddings_ollama(
                texts=texts,
                model_name=model_name,
                is_build=False,
                host="http://localhost:11434",
                provider_options=provider_options,
            )

            assert embeddings is not None
            assert len(embeddings) == 2
            assert all(isinstance(emb, np.ndarray) for emb in embeddings)
            assert all(len(emb) > 0 for emb in embeddings)

            logger.info(
                f"✓ Ollama embeddings with prompt template: {len(embeddings)} vectors, {len(embeddings[0])} dimensions"
            )

        except Exception as e:
            pytest.skip(f"Could not test Ollama prompt template: {e}")

    @pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
    def test_ollama_prompt_template_affects_embeddings(self):
        """Verify that prompt templates actually change embedding values with Ollama."""
        # Get any available embedding model
        try:
            response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
            models = response.json().get("models", [])

            embedding_models = []
            for model in models:
                name = model["name"]
                base_name = name.split(":")[0]
                if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
                    embedding_models.append(name)

            if not embedding_models:
                pytest.skip("No embedding models available in Ollama")

            model_name = embedding_models[0]
            text = "machine learning"
            host = "http://localhost:11434"

            # Get embeddings without template
            embeddings_no_template = compute_embeddings_ollama(
                texts=[text], model_name=model_name, is_build=False, host=host, provider_options={}
            )

            # Get embeddings with template
            embeddings_with_template = compute_embeddings_ollama(
                texts=[text],
                model_name=model_name,
                is_build=False,
                host=host,
                provider_options={"prompt_template": "search_query: "},
            )

            # Embeddings should be different when template is applied
            assert not np.allclose(embeddings_no_template[0], embeddings_with_template[0])

            logger.info("✓ Ollama prompt template changes embedding values as expected")

        except Exception as e:
            pytest.skip(f"Could not test Ollama prompt template: {e}")


class TestLMStudioSDK:
    """End-to-end tests for LM Studio SDK integration."""

    @pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
    def test_lmstudio_model_listing(self):
        """Test that we can list models from LM Studio."""
        try:
            response = requests.get("http://localhost:1234/v1/models", timeout=5.0)
            assert response.status_code == 200

            data = response.json()
            assert "data" in data

            models = data["data"]
            logger.info(f"✓ LM Studio models available: {len(models)}")

            if models:
                logger.info(f"  First model: {models[0].get('id', 'unknown')}")
        except Exception as e:
            pytest.skip(f"LM Studio API error: {e}")

    @pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
    def test_lmstudio_sdk_context_length_detection(self):
        """Test context length detection via LM Studio SDK bridge (requires Node.js + SDK)."""
        model_name = get_lmstudio_first_model()
        if not model_name:
            pytest.skip("No models loaded in LM Studio")

        try:
            from leann.embedding_compute import _query_lmstudio_context_limit

            # SDK requires WebSocket URL (ws://)
            context_length = _query_lmstudio_context_limit(
                model_name=model_name, base_url="ws://localhost:1234"
            )

            if context_length is None:
                logger.warning(
                    "⚠ LM Studio SDK bridge returned None (Node.js or SDK may not be available)"
                )
                pytest.skip("Node.js or @lmstudio/sdk not available - SDK bridge unavailable")
            else:
                assert context_length > 0
                logger.info(
                    f"✓ LM Studio context length detected via SDK: {context_length} for {model_name}"
                )

        except ImportError:
            pytest.skip("_query_lmstudio_context_limit not implemented yet")
        except Exception as e:
            logger.error(f"LM Studio SDK test error: {e}")
            raise


class TestOllamaTokenLimit:
    """End-to-end tests for Ollama token limit discovery."""

    @pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
    def test_ollama_token_limit_detection(self):
        """Test dynamic token limit detection from Ollama /api/show endpoint."""
        # Get any available embedding model
        try:
            response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
            models = response.json().get("models", [])

            embedding_models = []
            for model in models:
                name = model["name"]
                base_name = name.split(":")[0]
                if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
                    embedding_models.append(name)

            if not embedding_models:
                pytest.skip("No embedding models available in Ollama")

            test_model = embedding_models[0]

            # Test token limit detection
            limit = get_model_token_limit(model_name=test_model, base_url="http://localhost:11434")

            assert limit > 0
            logger.info(f"✓ Ollama token limit detected: {limit} for {test_model}")

        except Exception as e:
            pytest.skip(f"Could not test Ollama token detection: {e}")
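

# --- Illustrative sketch (not part of this test suite) ---------------------
# The hybrid discovery exercised below resolves a model's token limit in
# layers: ask the live runtime first (e.g. Ollama's /api/show, or the LM
# Studio SDK bridge), then fall back to a static registry of known hosted
# models, and finally to a caller-supplied default. The resolution order is a
# hedged sketch of the behavior these tests assume; everything except
# get_model_token_limit's observable results is an assumption.
def _resolve_token_limit_sketch(model_name, base_url, default=512):
    """Resolve a model's max input tokens: live service, then registry, then default."""
    registry = {"text-embedding-3-small": 8192}  # excerpt of a static registry
    try:
        # 1) Dynamic discovery, e.g. Ollama's /api/show (field names vary per runtime).
        resp = requests.post(f"{base_url}/api/show", json={"name": model_name}, timeout=2.0)
        if resp.status_code == 200:
            for key, value in resp.json().get("model_info", {}).items():
                if key.endswith("context_length") and isinstance(value, int) and value > 0:
                    return value
    except Exception:
        pass  # service unreachable -> fall through to static data
    # 2) Static registry of known hosted models, then 3) caller-supplied default.
    return registry.get(model_name, default)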
class TestHybridTokenLimit:
|
||||||
|
"""End-to-end tests for hybrid token limit discovery mechanism."""
|
||||||
|
|
||||||
|
def test_hybrid_discovery_registry_fallback(self):
|
||||||
|
"""Test fallback to static registry for known OpenAI models."""
|
||||||
|
# Use a known OpenAI model (should be in registry)
|
||||||
|
limit = get_model_token_limit(
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
base_url="http://fake-server:9999", # Fake URL to force registry lookup
|
||||||
|
)
|
||||||
|
|
||||||
|
# text-embedding-3-small should have 8192 in registry
|
||||||
|
assert limit == 8192
|
||||||
|
logger.info(f"✓ Hybrid discovery (registry fallback): {limit} tokens")
|
||||||
|
|
||||||
|
def test_hybrid_discovery_default_fallback(self):
|
||||||
|
"""Test fallback to safe default for completely unknown models."""
|
||||||
|
limit = get_model_token_limit(
|
||||||
|
model_name="completely-unknown-model-xyz-12345",
|
||||||
|
base_url="http://fake-server:9999",
|
||||||
|
default=512,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should get the specified default
|
||||||
|
assert limit == 512
|
||||||
|
logger.info(f"✓ Hybrid discovery (default fallback): {limit} tokens")
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
|
||||||
|
def test_hybrid_discovery_ollama_dynamic_first(self):
|
||||||
|
"""Test that Ollama models use dynamic discovery first."""
|
||||||
|
# Get any available embedding model
|
||||||
|
try:
|
||||||
|
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
|
||||||
|
models = response.json().get("models", [])
|
||||||
|
|
||||||
|
embedding_models = []
|
||||||
|
for model in models:
|
||||||
|
name = model["name"]
|
||||||
|
base_name = name.split(":")[0]
|
||||||
|
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
|
||||||
|
embedding_models.append(name)
|
||||||
|
|
||||||
|
if not embedding_models:
|
||||||
|
pytest.skip("No embedding models available in Ollama")
|
||||||
|
|
||||||
|
test_model = embedding_models[0]
|
||||||
|
|
||||||
|
# Should query Ollama /api/show dynamically
|
||||||
|
limit = get_model_token_limit(model_name=test_model, base_url="http://localhost:11434")
|
||||||
|
|
||||||
|
assert limit > 0
|
||||||
|
logger.info(f"✓ Hybrid discovery (Ollama dynamic): {limit} tokens for {test_model}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
pytest.skip(f"Could not test hybrid Ollama discovery: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("\n" + "=" * 70)
|
||||||
|
print("INTEGRATION TEST SUITE - Real Service Testing")
|
||||||
|
print("=" * 70)
|
||||||
|
print("\nThese tests require live services:")
|
||||||
|
print(" • LM Studio: http://localhost:1234 (with embedding model loaded)")
|
||||||
|
print(" • [Optional] Ollama: http://localhost:11434")
|
||||||
|
print(" • [Optional] Node.js + @lmstudio/sdk for SDK bridge tests")
|
||||||
|
print("\nRun with: pytest tests/test_prompt_template_e2e.py -v -s")
|
||||||
|
print("=" * 70 + "\n")
|
808
tests/test_prompt_template_persistence.py
Normal file
@@ -0,0 +1,808 @@
"""
Integration tests for prompt template metadata persistence and reuse.

These tests verify the complete lifecycle of prompt template persistence:
1. Template is saved to .meta.json during index build
2. Template is automatically loaded during search operations
3. Template can be overridden with explicit flag during search
4. Template is reused during chat/ask operations

These are integration tests that:
- Use real file system with temporary directories
- Run actual build and search operations
- Inspect .meta.json file contents directly
- Mock embedding servers to avoid external dependencies
- Use small test codebases for fast execution

Expected to FAIL in Red Phase because metadata persistence verification is not yet implemented.
"""

import json
import tempfile
from pathlib import Path
from unittest.mock import Mock, patch

import numpy as np
import pytest
from leann.api import LeannBuilder, LeannSearcher
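
# For orientation: a minimal sketch of the .meta.json payload these tests expect
# to find on disk. The key set is an assumption inferred from the assertions
# below (embedding_options nested under the top-level metadata), not a spec of
# the real on-disk format.
_EXAMPLE_META_SKETCH = {
    "backend_name": "hnsw",
    "embedding_model": "text-embedding-3-small",
    "embedding_mode": "openai",
    "embedding_options": {"prompt_template": "search_document: "},
}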
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptTemplateMetadataPersistence:
|
||||||
|
"""Tests for prompt template storage in .meta.json during build."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_index_dir(self):
|
||||||
|
"""Create temporary directory for test indexes."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
yield Path(tmpdir)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_embeddings(self):
|
||||||
|
"""Mock compute_embeddings to return dummy embeddings."""
|
||||||
|
with patch("leann.api.compute_embeddings") as mock_compute:
|
||||||
|
# Return dummy embeddings as numpy array
|
||||||
|
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||||
|
yield mock_compute
|
||||||
|
|
||||||
|
def test_prompt_template_saved_to_metadata(self, temp_index_dir, mock_embeddings):
|
||||||
|
"""
|
||||||
|
Verify that when build is run with embedding_options containing prompt_template,
|
||||||
|
the template value is saved to .meta.json file.
|
||||||
|
|
||||||
|
This is the core persistence requirement - templates must be saved to allow
|
||||||
|
reuse in subsequent search operations without re-specifying the flag.
|
||||||
|
|
||||||
|
Expected failure: .meta.json exists but doesn't contain embedding_options
|
||||||
|
with prompt_template, or the value is not persisted correctly.
|
||||||
|
"""
|
||||||
|
# Setup test data
|
||||||
|
index_path = temp_index_dir / "test_index.leann"
|
||||||
|
template = "search_document: "
|
||||||
|
|
||||||
|
# Build index with prompt template in embedding_options
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
embedding_options={"prompt_template": template},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add a simple document
|
||||||
|
builder.add_text("This is a test document for indexing")
|
||||||
|
|
||||||
|
# Build the index
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
# Verify .meta.json was created and contains the template
|
||||||
|
meta_path = temp_index_dir / "test_index.leann.meta.json"
|
||||||
|
assert meta_path.exists(), ".meta.json file should be created during build"
|
||||||
|
|
||||||
|
# Read and parse metadata
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta_data = json.load(f)
|
||||||
|
|
||||||
|
# Verify embedding_options exists in metadata
|
||||||
|
assert "embedding_options" in meta_data, (
|
||||||
|
"embedding_options should be saved to .meta.json when provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify prompt_template is in embedding_options
|
||||||
|
embedding_options = meta_data["embedding_options"]
|
||||||
|
assert "prompt_template" in embedding_options, (
|
||||||
|
"prompt_template should be saved within embedding_options"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the template value matches what we provided
|
||||||
|
assert embedding_options["prompt_template"] == template, (
|
||||||
|
f"Template should be '{template}', got '{embedding_options.get('prompt_template')}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_prompt_template_absent_when_not_provided(self, temp_index_dir, mock_embeddings):
|
||||||
|
"""
|
||||||
|
Verify that when no prompt template is provided during build,
|
||||||
|
.meta.json either has no embedding_options entry or no prompt_template key.
|
||||||
|
|
||||||
|
This ensures clean metadata without unnecessary keys when features aren't used.
|
||||||
|
|
||||||
|
Expected behavior: Build succeeds, .meta.json doesn't contain prompt_template.
|
||||||
|
"""
|
||||||
|
index_path = temp_index_dir / "test_no_template.leann"
|
||||||
|
|
||||||
|
# Build index WITHOUT prompt template
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
# No embedding_options provided
|
||||||
|
)
|
||||||
|
|
||||||
|
builder.add_text("Document without template")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
# Verify metadata
|
||||||
|
meta_path = temp_index_dir / "test_no_template.leann.meta.json"
|
||||||
|
assert meta_path.exists()
|
||||||
|
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta_data = json.load(f)
|
||||||
|
|
||||||
|
# If embedding_options exists, it should not contain prompt_template
|
||||||
|
if "embedding_options" in meta_data:
|
||||||
|
embedding_options = meta_data["embedding_options"]
|
||||||
|
assert "prompt_template" not in embedding_options, (
|
||||||
|
"prompt_template should not be in metadata when not provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptTemplateAutoLoadOnSearch:
|
||||||
|
"""Tests for automatic loading of prompt template during search operations.
|
||||||
|
|
||||||
|
NOTE: Over-mocked test removed (test_prompt_template_auto_loaded_on_search).
|
||||||
|
This functionality is now comprehensively tested by TestQueryPromptTemplateAutoLoad
|
||||||
|
which uses simpler mocking and doesn't hang.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_index_dir(self):
|
||||||
|
"""Create temporary directory for test indexes."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
yield Path(tmpdir)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_embeddings(self):
|
||||||
|
"""Mock compute_embeddings to capture calls and return dummy embeddings."""
|
||||||
|
with patch("leann.api.compute_embeddings") as mock_compute:
|
||||||
|
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||||
|
yield mock_compute
|
||||||
|
|
||||||
|
def test_search_without_template_in_metadata(self, temp_index_dir, mock_embeddings):
|
||||||
|
"""
|
||||||
|
Verify that searching an index built WITHOUT a prompt template
|
||||||
|
works correctly (backward compatibility).
|
||||||
|
|
||||||
|
The searcher should handle missing prompt_template gracefully.
|
||||||
|
|
||||||
|
Expected behavior: Search succeeds, no template is used.
|
||||||
|
"""
|
||||||
|
# Build index without template
|
||||||
|
index_path = temp_index_dir / "no_template.leann"
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
)
|
||||||
|
builder.add_text("Document without template")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
# Reset mocks
|
||||||
|
mock_embeddings.reset_mock()
|
||||||
|
|
||||||
|
# Create searcher and search
|
||||||
|
searcher = LeannSearcher(index_path=str(index_path))
|
||||||
|
|
||||||
|
# Verify no template in embedding_options
|
||||||
|
assert "prompt_template" not in searcher.embedding_options, (
|
||||||
|
"Searcher should not have prompt_template when not in metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueryPromptTemplateAutoLoad:
    """Tests for automatic loading of separate query_prompt_template during search (R2).

    These tests verify the new two-template system where:
    - build_prompt_template: Applied during index building
    - query_prompt_template: Applied during search operations

    Expected to FAIL in Red Phase (R2) because query template extraction
    and application is not yet implemented in LeannSearcher.search().
    """
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_index_dir(self):
|
||||||
|
"""Create temporary directory for test indexes."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
yield Path(tmpdir)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_compute_embeddings(self):
|
||||||
|
"""Mock compute_embeddings to capture calls and return dummy embeddings."""
|
||||||
|
with patch("leann.embedding_compute.compute_embeddings") as mock_compute:
|
||||||
|
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||||
|
yield mock_compute
|
||||||
|
|
||||||
|
def test_search_auto_loads_query_template(self, temp_index_dir, mock_compute_embeddings):
|
||||||
|
"""
|
||||||
|
Verify that search() automatically loads and applies query_prompt_template from .meta.json.
|
||||||
|
|
||||||
|
Given: Index built with separate build_prompt_template and query_prompt_template
|
||||||
|
When: LeannSearcher.search("my query") is called
|
||||||
|
Then: Query embedding is computed with "query: my query" (query template applied)
|
||||||
|
|
||||||
|
This is the core R2 requirement - query templates must be auto-loaded and applied
|
||||||
|
during search without user intervention.
|
||||||
|
|
||||||
|
Expected failure: compute_embeddings called with raw "my query" instead of
|
||||||
|
"query: my query" because query template extraction is not implemented.
|
||||||
|
"""
|
||||||
|
# Setup: Build index with separate templates in new format
|
||||||
|
index_path = temp_index_dir / "query_template.leann"
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
embedding_options={
|
||||||
|
"build_prompt_template": "doc: ",
|
||||||
|
"query_prompt_template": "query: ",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
builder.add_text("Test document")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
# Reset mock to ignore build calls
|
||||||
|
mock_compute_embeddings.reset_mock()
|
||||||
|
|
||||||
|
# Act: Search with query
|
||||||
|
searcher = LeannSearcher(index_path=str(index_path))
|
||||||
|
|
||||||
|
# Mock the backend search to avoid actual search
|
||||||
|
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
|
||||||
|
mock_backend_search.return_value = {
|
||||||
|
"labels": [["test_id_0"]], # IDs (nested list for batch support)
|
||||||
|
"distances": [[0.9]], # Distances (nested list for batch support)
|
||||||
|
}
|
||||||
|
|
||||||
|
searcher.search("my query", top_k=1, recompute_embeddings=False)
|
||||||
|
|
||||||
|
# Assert: compute_embeddings was called with query template applied
|
||||||
|
assert mock_compute_embeddings.called, "compute_embeddings should be called during search"
|
||||||
|
|
||||||
|
# Get the actual text passed to compute_embeddings
|
||||||
|
call_args = mock_compute_embeddings.call_args
|
||||||
|
texts_arg = call_args[0][0] # First positional arg (list of texts)
|
||||||
|
|
||||||
|
assert len(texts_arg) == 1, "Should compute embedding for one query"
|
||||||
|
assert texts_arg[0] == "query: my query", (
|
||||||
|
f"Query template should be applied: expected 'query: my query', got '{texts_arg[0]}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_search_backward_compat_single_template(self, temp_index_dir, mock_compute_embeddings):
|
||||||
|
"""
|
||||||
|
Verify backward compatibility with old single prompt_template format.
|
||||||
|
|
||||||
|
Given: Index with old format (single prompt_template, no query_prompt_template)
|
||||||
|
When: LeannSearcher.search("my query") is called
|
||||||
|
Then: Query embedding computed with "doc: my query" (old template applied)
|
||||||
|
|
||||||
|
This ensures indexes built with the old single-template system continue
|
||||||
|
to work correctly with the new search implementation.
|
||||||
|
|
||||||
|
Expected failure: Old template not recognized/applied because backward
|
||||||
|
compatibility logic is not implemented.
|
||||||
|
"""
|
||||||
|
# Setup: Build index with old single-template format
|
||||||
|
index_path = temp_index_dir / "old_template.leann"
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
embedding_options={"prompt_template": "doc: "}, # Old format
|
||||||
|
)
|
||||||
|
builder.add_text("Test document")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
# Reset mock
|
||||||
|
mock_compute_embeddings.reset_mock()
|
||||||
|
|
||||||
|
# Act: Search
|
||||||
|
searcher = LeannSearcher(index_path=str(index_path))
|
||||||
|
|
||||||
|
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
|
||||||
|
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
|
||||||
|
|
||||||
|
searcher.search("my query", top_k=1, recompute_embeddings=False)
|
||||||
|
|
||||||
|
# Assert: Old template was applied
|
||||||
|
call_args = mock_compute_embeddings.call_args
|
||||||
|
texts_arg = call_args[0][0]
|
||||||
|
|
||||||
|
assert texts_arg[0] == "doc: my query", (
|
||||||
|
f"Old prompt_template should be applied for backward compatibility: "
|
||||||
|
f"expected 'doc: my query', got '{texts_arg[0]}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_search_backward_compat_no_template(self, temp_index_dir, mock_compute_embeddings):
|
||||||
|
"""
|
||||||
|
Verify backward compatibility when no template is present in .meta.json.
|
||||||
|
|
||||||
|
Given: Index with no template in .meta.json (very old indexes)
|
||||||
|
When: LeannSearcher.search("my query") is called
|
||||||
|
Then: Query embedding computed with "my query" (no template, raw query)
|
||||||
|
|
||||||
|
This ensures the most basic backward compatibility - indexes without
|
||||||
|
any template support continue to work as before.
|
||||||
|
|
||||||
|
Expected failure: May fail if default template is incorrectly applied,
|
||||||
|
or if a missing template causes an error.
|
||||||
|
"""
|
||||||
|
# Setup: Build index without any template
|
||||||
|
index_path = temp_index_dir / "no_template.leann"
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
# No embedding_options at all
|
||||||
|
)
|
||||||
|
builder.add_text("Test document")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
# Reset mock
|
||||||
|
mock_compute_embeddings.reset_mock()
|
||||||
|
|
||||||
|
# Act: Search
|
||||||
|
searcher = LeannSearcher(index_path=str(index_path))
|
||||||
|
|
||||||
|
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
|
||||||
|
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
|
||||||
|
|
||||||
|
searcher.search("my query", top_k=1, recompute_embeddings=False)
|
||||||
|
|
||||||
|
# Assert: No template applied (raw query)
|
||||||
|
call_args = mock_compute_embeddings.call_args
|
||||||
|
texts_arg = call_args[0][0]
|
||||||
|
|
||||||
|
assert texts_arg[0] == "my query", (
|
||||||
|
f"No template should be applied when missing from metadata: "
|
||||||
|
f"expected 'my query', got '{texts_arg[0]}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_search_override_via_provider_options(self, temp_index_dir, mock_compute_embeddings):
|
||||||
|
"""
|
||||||
|
Verify that explicit provider_options can override stored query template.
|
||||||
|
|
||||||
|
Given: Index with query_prompt_template: "query: "
|
||||||
|
When: search() called with provider_options={"prompt_template": "override: "}
|
||||||
|
Then: Query embedding computed with "override: test" (override takes precedence)
|
||||||
|
|
||||||
|
This enables users to experiment with different query templates without
|
||||||
|
rebuilding the index, or to handle special query types differently.
|
||||||
|
|
||||||
|
Expected failure: provider_options parameter is accepted via **kwargs but
|
||||||
|
not used. Query embedding computed with raw "test" instead of "override: test"
|
||||||
|
because override logic is not implemented.
|
||||||
|
"""
|
||||||
|
# Setup: Build index with query template
|
||||||
|
index_path = temp_index_dir / "override_template.leann"
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
embedding_options={
|
||||||
|
"build_prompt_template": "doc: ",
|
||||||
|
"query_prompt_template": "query: ",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
builder.add_text("Test document")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
# Reset mock
|
||||||
|
mock_compute_embeddings.reset_mock()
|
||||||
|
|
||||||
|
# Act: Search with override
|
||||||
|
searcher = LeannSearcher(index_path=str(index_path))
|
||||||
|
|
||||||
|
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
|
||||||
|
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
|
||||||
|
|
||||||
|
# This should accept provider_options parameter
|
||||||
|
searcher.search(
|
||||||
|
"test",
|
||||||
|
top_k=1,
|
||||||
|
recompute_embeddings=False,
|
||||||
|
provider_options={"prompt_template": "override: "},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert: Override template was applied
|
||||||
|
call_args = mock_compute_embeddings.call_args
|
||||||
|
texts_arg = call_args[0][0]
|
||||||
|
|
||||||
|
assert texts_arg[0] == "override: test", (
|
||||||
|
f"Override template should take precedence: "
|
||||||
|
f"expected 'override: test', got '{texts_arg[0]}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptTemplateReuseInChat:
|
||||||
|
"""Tests for prompt template reuse in chat/ask operations."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_index_dir(self):
|
||||||
|
"""Create temporary directory for test indexes."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
yield Path(tmpdir)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_embeddings(self):
|
||||||
|
"""Mock compute_embeddings to return dummy embeddings."""
|
||||||
|
with patch("leann.api.compute_embeddings") as mock_compute:
|
||||||
|
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||||
|
yield mock_compute
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_embedding_server_manager(self):
|
||||||
|
"""Mock EmbeddingServerManager for chat tests."""
|
||||||
|
with patch("leann.searcher_base.EmbeddingServerManager") as mock_manager_class:
|
||||||
|
mock_manager = Mock()
|
||||||
|
mock_manager.start_server.return_value = (True, 5557)
|
||||||
|
mock_manager_class.return_value = mock_manager
|
||||||
|
yield mock_manager
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def index_with_template(self, temp_index_dir, mock_embeddings):
|
||||||
|
"""Build an index with a prompt template."""
|
||||||
|
index_path = temp_index_dir / "chat_template_index.leann"
|
||||||
|
template = "document_query: "
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
embedding_options={"prompt_template": template},
|
||||||
|
)
|
||||||
|
|
||||||
|
builder.add_text("Test document for chat")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
return str(index_path), template
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptTemplateIntegrationWithEmbeddingModes:
|
||||||
|
"""Tests for prompt template compatibility with different embedding modes."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_index_dir(self):
|
||||||
|
"""Create temporary directory for test indexes."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
yield Path(tmpdir)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"mode,model,template,filename_prefix",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"openai",
|
||||||
|
"text-embedding-3-small",
|
||||||
|
"Represent this for searching: ",
|
||||||
|
"openai_template",
|
||||||
|
),
|
||||||
|
("ollama", "nomic-embed-text", "search_query: ", "ollama_template"),
|
||||||
|
("sentence-transformers", "facebook/contriever", "query: ", "st_template"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_prompt_template_metadata_with_embedding_modes(
|
||||||
|
self, temp_index_dir, mode, model, template, filename_prefix
|
||||||
|
):
|
||||||
|
"""Verify prompt template is saved correctly across different embedding modes.
|
||||||
|
|
||||||
|
Tests that prompt templates are persisted to .meta.json for:
|
||||||
|
- OpenAI mode (primary use case)
|
||||||
|
- Ollama mode (also supports templates)
|
||||||
|
- Sentence-transformers mode (saved for forward compatibility)
|
||||||
|
|
||||||
|
Expected behavior: Template is saved to .meta.json regardless of mode.
|
||||||
|
"""
|
||||||
|
with patch("leann.api.compute_embeddings") as mock_compute:
|
||||||
|
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||||
|
|
||||||
|
index_path = temp_index_dir / f"{filename_prefix}.leann"
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=model,
|
||||||
|
embedding_mode=mode,
|
||||||
|
embedding_options={"prompt_template": template},
|
||||||
|
)
|
||||||
|
|
||||||
|
builder.add_text(f"{mode.capitalize()} test document")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
# Verify metadata
|
||||||
|
meta_path = temp_index_dir / f"{filename_prefix}.leann.meta.json"
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta_data = json.load(f)
|
||||||
|
|
||||||
|
assert meta_data["embedding_mode"] == mode
|
||||||
|
# Template should be saved for all modes (even if not used by some)
|
||||||
|
if "embedding_options" in meta_data:
|
||||||
|
assert meta_data["embedding_options"]["prompt_template"] == template
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueryTemplateApplicationInComputeEmbedding:
    """Tests for query template application in compute_query_embedding() (Bug Fix).

    These tests verify that query templates are applied consistently in BOTH
    code paths (server and fallback) when computing query embeddings.

    This addresses the bug where query templates were only applied in the
    fallback path, not when using the embedding server (the default path).

    Bug Context:
    - Issue: Query templates were stored in metadata but only applied during
      fallback (direct) computation, not when using the embedding server
    - Fix: Move template application to BEFORE any computation path in
      compute_query_embedding() (searcher_base.py:107-110)
    - Impact: Critical for models like EmbeddingGemma that require task-specific
      templates for optimal performance

    These tests ensure the fix works correctly and prevent regression.
    """
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_index_with_template(self):
|
||||||
|
"""Create a temporary index with query template in metadata"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
index_dir = Path(tmpdir)
|
||||||
|
index_file = index_dir / "test.leann"
|
||||||
|
meta_file = index_dir / "test.leann.meta.json"
|
||||||
|
|
||||||
|
# Create minimal metadata with query template
|
||||||
|
metadata = {
|
||||||
|
"version": "1.0",
|
||||||
|
"backend_name": "hnsw",
|
||||||
|
"embedding_model": "text-embedding-embeddinggemma-300m-qat",
|
||||||
|
"dimensions": 768,
|
||||||
|
"embedding_mode": "openai",
|
||||||
|
"backend_kwargs": {
|
||||||
|
"graph_degree": 32,
|
||||||
|
"complexity": 64,
|
||||||
|
"distance_metric": "cosine",
|
||||||
|
},
|
||||||
|
"embedding_options": {
|
||||||
|
"base_url": "http://localhost:1234/v1",
|
||||||
|
"api_key": "test-key",
|
||||||
|
"build_prompt_template": "title: none | text: ",
|
||||||
|
"query_prompt_template": "task: search result | query: ",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
meta_file.write_text(json.dumps(metadata, indent=2))
|
||||||
|
|
||||||
|
# Create minimal HNSW index file (empty is okay for this test)
|
||||||
|
index_file.write_bytes(b"")
|
||||||
|
|
||||||
|
yield str(index_file)
|
||||||
|
|
||||||
|
def test_query_template_applied_in_fallback_path(self, temp_index_with_template):
|
||||||
|
"""Test that query template is applied when using fallback (direct) path"""
|
||||||
|
from leann.searcher_base import BaseSearcher
|
||||||
|
|
||||||
|
# Create a concrete implementation for testing
|
||||||
|
class TestSearcher(BaseSearcher):
|
||||||
|
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||||
|
return {"labels": [], "distances": []}
|
||||||
|
|
||||||
|
searcher = object.__new__(TestSearcher)
|
||||||
|
searcher.index_path = Path(temp_index_with_template)
|
||||||
|
searcher.index_dir = searcher.index_path.parent
|
||||||
|
|
||||||
|
# Load metadata
|
||||||
|
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||||
|
with open(meta_file) as f:
|
||||||
|
searcher.meta = json.load(f)
|
||||||
|
|
||||||
|
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||||
|
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||||
|
|
||||||
|
# Mock compute_embeddings to capture the query text
|
||||||
|
captured_queries = []
|
||||||
|
|
||||||
|
def mock_compute_embeddings(texts, model, mode, provider_options=None):
|
||||||
|
captured_queries.extend(texts)
|
||||||
|
return np.random.rand(len(texts), 768).astype(np.float32)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
|
||||||
|
):
|
||||||
|
# Call compute_query_embedding with template (fallback path)
|
||||||
|
result = searcher.compute_query_embedding(
|
||||||
|
query="vector database",
|
||||||
|
use_server_if_available=False, # Force fallback path
|
||||||
|
query_template="task: search result | query: ",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify template was applied
|
||||||
|
assert len(captured_queries) == 1
|
||||||
|
assert captured_queries[0] == "task: search result | query: vector database"
|
||||||
|
assert result.shape == (1, 768)
|
||||||
|
|
||||||
|
def test_query_template_applied_in_server_path(self, temp_index_with_template):
|
||||||
|
"""Test that query template is applied when using server path"""
|
||||||
|
from leann.searcher_base import BaseSearcher
|
||||||
|
|
||||||
|
# Create a concrete implementation for testing
|
||||||
|
class TestSearcher(BaseSearcher):
|
||||||
|
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||||
|
return {"labels": [], "distances": []}
|
||||||
|
|
||||||
|
searcher = object.__new__(TestSearcher)
|
||||||
|
searcher.index_path = Path(temp_index_with_template)
|
||||||
|
searcher.index_dir = searcher.index_path.parent
|
||||||
|
|
||||||
|
# Load metadata
|
||||||
|
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||||
|
with open(meta_file) as f:
|
||||||
|
searcher.meta = json.load(f)
|
||||||
|
|
||||||
|
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||||
|
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||||
|
|
||||||
|
# Mock the server methods to capture the query text
|
||||||
|
captured_queries = []
|
||||||
|
|
||||||
|
def mock_ensure_server_running(passages_file, port):
|
||||||
|
return port
|
||||||
|
|
||||||
|
def mock_compute_embedding_via_server(chunks, port):
|
||||||
|
captured_queries.extend(chunks)
|
||||||
|
return np.random.rand(len(chunks), 768).astype(np.float32)
|
||||||
|
|
||||||
|
searcher._ensure_server_running = mock_ensure_server_running
|
||||||
|
searcher._compute_embedding_via_server = mock_compute_embedding_via_server
|
||||||
|
|
||||||
|
# Call compute_query_embedding with template (server path)
|
||||||
|
result = searcher.compute_query_embedding(
|
||||||
|
query="vector database",
|
||||||
|
use_server_if_available=True, # Use server path
|
||||||
|
query_template="task: search result | query: ",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify template was applied BEFORE calling server
|
||||||
|
assert len(captured_queries) == 1
|
||||||
|
assert captured_queries[0] == "task: search result | query: vector database"
|
||||||
|
assert result.shape == (1, 768)
|
||||||
|
|
||||||
|
def test_query_template_without_template_parameter(self, temp_index_with_template):
|
||||||
|
"""Test that query is unchanged when no template is provided"""
|
||||||
|
from leann.searcher_base import BaseSearcher
|
||||||
|
|
||||||
|
class TestSearcher(BaseSearcher):
|
||||||
|
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||||
|
return {"labels": [], "distances": []}
|
||||||
|
|
||||||
|
searcher = object.__new__(TestSearcher)
|
||||||
|
searcher.index_path = Path(temp_index_with_template)
|
||||||
|
searcher.index_dir = searcher.index_path.parent
|
||||||
|
|
||||||
|
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||||
|
with open(meta_file) as f:
|
||||||
|
searcher.meta = json.load(f)
|
||||||
|
|
||||||
|
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||||
|
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||||
|
|
||||||
|
captured_queries = []
|
||||||
|
|
||||||
|
def mock_compute_embeddings(texts, model, mode, provider_options=None):
|
||||||
|
captured_queries.extend(texts)
|
||||||
|
return np.random.rand(len(texts), 768).astype(np.float32)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
|
||||||
|
):
|
||||||
|
searcher.compute_query_embedding(
|
||||||
|
query="vector database",
|
||||||
|
use_server_if_available=False,
|
||||||
|
query_template=None, # No template
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify query is unchanged
|
||||||
|
assert len(captured_queries) == 1
|
||||||
|
assert captured_queries[0] == "vector database"
|
||||||
|
|
||||||
|
def test_query_template_consistency_between_paths(self, temp_index_with_template):
|
||||||
|
"""Test that both paths apply template identically"""
|
||||||
|
from leann.searcher_base import BaseSearcher
|
||||||
|
|
||||||
|
class TestSearcher(BaseSearcher):
|
||||||
|
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||||
|
return {"labels": [], "distances": []}
|
||||||
|
|
||||||
|
searcher = object.__new__(TestSearcher)
|
||||||
|
searcher.index_path = Path(temp_index_with_template)
|
||||||
|
searcher.index_dir = searcher.index_path.parent
|
||||||
|
|
||||||
|
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||||
|
with open(meta_file) as f:
|
||||||
|
searcher.meta = json.load(f)
|
||||||
|
|
||||||
|
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||||
|
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||||
|
|
||||||
|
query_template = "task: search result | query: "
|
||||||
|
original_query = "vector database"
|
||||||
|
|
||||||
|
# Capture queries from fallback path
|
||||||
|
fallback_queries = []
|
||||||
|
|
||||||
|
def mock_compute_embeddings(texts, model, mode, provider_options=None):
|
||||||
|
fallback_queries.extend(texts)
|
||||||
|
return np.random.rand(len(texts), 768).astype(np.float32)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
|
||||||
|
):
|
||||||
|
searcher.compute_query_embedding(
|
||||||
|
query=original_query,
|
||||||
|
use_server_if_available=False,
|
||||||
|
query_template=query_template,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Capture queries from server path
|
||||||
|
server_queries = []
|
||||||
|
|
||||||
|
def mock_ensure_server_running(passages_file, port):
|
||||||
|
return port
|
||||||
|
|
||||||
|
def mock_compute_embedding_via_server(chunks, port):
|
||||||
|
server_queries.extend(chunks)
|
||||||
|
return np.random.rand(len(chunks), 768).astype(np.float32)
|
||||||
|
|
||||||
|
searcher._ensure_server_running = mock_ensure_server_running
|
||||||
|
searcher._compute_embedding_via_server = mock_compute_embedding_via_server
|
||||||
|
|
||||||
|
searcher.compute_query_embedding(
|
||||||
|
query=original_query,
|
||||||
|
use_server_if_available=True,
|
||||||
|
query_template=query_template,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify both paths produced identical templated queries
|
||||||
|
assert len(fallback_queries) == 1
|
||||||
|
assert len(server_queries) == 1
|
||||||
|
assert fallback_queries[0] == server_queries[0]
|
||||||
|
assert fallback_queries[0] == f"{query_template}{original_query}"
|
||||||
|
|
||||||
|
def test_query_template_with_empty_string(self, temp_index_with_template):
|
||||||
|
"""Test behavior with empty template string"""
|
||||||
|
from leann.searcher_base import BaseSearcher
|
||||||
|
|
||||||
|
class TestSearcher(BaseSearcher):
|
||||||
|
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||||
|
return {"labels": [], "distances": []}
|
||||||
|
|
||||||
|
searcher = object.__new__(TestSearcher)
|
||||||
|
searcher.index_path = Path(temp_index_with_template)
|
||||||
|
searcher.index_dir = searcher.index_path.parent
|
||||||
|
|
||||||
|
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||||
|
with open(meta_file) as f:
|
||||||
|
searcher.meta = json.load(f)
|
||||||
|
|
||||||
|
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||||
|
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||||
|
|
||||||
|
captured_queries = []
|
||||||
|
|
||||||
|
def mock_compute_embeddings(texts, model, mode, provider_options=None):
|
||||||
|
captured_queries.extend(texts)
|
||||||
|
return np.random.rand(len(texts), 768).astype(np.float32)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
|
||||||
|
):
|
||||||
|
searcher.compute_query_embedding(
|
||||||
|
query="vector database",
|
||||||
|
use_server_if_available=False,
|
||||||
|
query_template="", # Empty string
|
||||||
|
)
|
||||||
|
|
||||||
|
# Empty string is falsy, so no template should be applied
|
||||||
|
assert captured_queries[0] == "vector database"
|
||||||
@@ -266,3 +266,378 @@ class TestTokenTruncation:
        assert result_tokens <= target_tokens, (
            f"Should be ≤{target_tokens} tokens, got {result_tokens}"
        )


class TestLMStudioHybridDiscovery:
    """Tests for LM Studio integration in get_model_token_limit() hybrid discovery.

    These tests verify that get_model_token_limit() properly integrates with
    the LM Studio SDK bridge for dynamic token limit discovery. The integration
    should:

    1. Detect LM Studio URLs (port 1234 or 'lmstudio'/'lm.studio' in URL)
    2. Convert HTTP URLs to WebSocket format for SDK queries
    3. Query LM Studio SDK and use discovered limit
    4. Fall back to registry when SDK returns None
    5. Execute AFTER Ollama detection but BEFORE registry fallback

    All tests are written in Red Phase - they should FAIL initially because the
    LM Studio detection and integration logic does not exist yet in get_model_token_limit().
    """
|
||||||
|
|
||||||
|
def test_get_model_token_limit_lmstudio_success(self, monkeypatch):
|
||||||
|
"""Verify LM Studio SDK query succeeds and returns detected limit.
|
||||||
|
|
||||||
|
When an LM Studio base_url is detected and the SDK query succeeds,
|
||||||
|
get_model_token_limit() should return the dynamically discovered
|
||||||
|
context length without falling back to the registry.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Mock _query_lmstudio_context_limit to return successful SDK query
|
||||||
|
def mock_query_lmstudio(model_name, base_url):
|
||||||
|
# Verify WebSocket URL was passed (not HTTP)
|
||||||
|
assert base_url.startswith("ws://"), (
|
||||||
|
f"Should convert HTTP to WebSocket format, got: {base_url}"
|
||||||
|
)
|
||||||
|
return 8192 # Successful SDK query
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||||
|
mock_query_lmstudio,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with HTTP URL that should be converted to WebSocket
|
||||||
|
limit = get_model_token_limit(
|
||||||
|
model_name="custom-model", base_url="http://localhost:1234/v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit == 8192, "Should return limit from LM Studio SDK query"
|
||||||
|
|
||||||
|
def test_get_model_token_limit_lmstudio_fallback_to_registry(self, monkeypatch):
|
||||||
|
"""Verify fallback to registry when LM Studio SDK returns None.
|
||||||
|
|
||||||
|
When LM Studio SDK query fails (returns None), get_model_token_limit()
|
||||||
|
should fall back to the EMBEDDING_MODEL_LIMITS registry.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Mock _query_lmstudio_context_limit to return None (SDK failure)
|
||||||
|
def mock_query_lmstudio(model_name, base_url):
|
||||||
|
return None # SDK query failed
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||||
|
mock_query_lmstudio,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with known model that exists in registry
|
||||||
|
limit = get_model_token_limit(
|
||||||
|
model_name="nomic-embed-text", base_url="http://localhost:1234/v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should fall back to registry value
|
||||||
|
assert limit == 2048, "Should fall back to registry when SDK returns None"
|
||||||
|
|
||||||
|
def test_get_model_token_limit_lmstudio_port_detection(self, monkeypatch):
|
||||||
|
"""Verify detection of LM Studio via port 1234.
|
||||||
|
|
||||||
|
get_model_token_limit() should recognize port 1234 as an LM Studio
|
||||||
|
server and attempt SDK query, regardless of hostname.
|
||||||
|
"""
|
||||||
|
query_called = False
|
||||||
|
|
||||||
|
def mock_query_lmstudio(model_name, base_url):
|
||||||
|
nonlocal query_called
|
||||||
|
query_called = True
|
||||||
|
return 4096
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||||
|
mock_query_lmstudio,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with port 1234 (default LM Studio port)
|
||||||
|
limit = get_model_token_limit(model_name="test-model", base_url="http://127.0.0.1:1234/v1")
|
||||||
|
|
||||||
|
assert query_called, "Should detect port 1234 and call LM Studio SDK query"
|
||||||
|
assert limit == 4096, "Should return SDK query result"
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_url,expected_limit,keyword",
|
||||||
|
[
|
||||||
|
("http://lmstudio.local:8080/v1", 16384, "lmstudio"),
|
||||||
|
("http://api.lm.studio:5000/v1", 32768, "lm.studio"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_get_model_token_limit_lmstudio_url_keyword_detection(
|
||||||
|
self, monkeypatch, test_url, expected_limit, keyword
|
||||||
|
):
|
||||||
|
"""Verify detection of LM Studio via keywords in URL.
|
||||||
|
|
||||||
|
get_model_token_limit() should recognize 'lmstudio' or 'lm.studio'
|
||||||
|
in the URL as indicating an LM Studio server.
|
||||||
|
"""
|
||||||
|
query_called = False
|
||||||
|
|
||||||
|
def mock_query_lmstudio(model_name, base_url):
|
||||||
|
nonlocal query_called
|
||||||
|
query_called = True
|
||||||
|
return expected_limit
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||||
|
mock_query_lmstudio,
|
||||||
|
)
|
||||||
|
|
||||||
|
limit = get_model_token_limit(model_name="test-model", base_url=test_url)
|
||||||
|
|
||||||
|
assert query_called, f"Should detect '{keyword}' keyword and call SDK query"
|
||||||
|
assert limit == expected_limit, f"Should return SDK query result for {keyword}"
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_url,expected_protocol,expected_host",
|
||||||
|
[
|
||||||
|
("http://localhost:1234/v1", "ws://", "localhost:1234"),
|
||||||
|
("https://lmstudio.example.com:1234/v1", "wss://", "lmstudio.example.com:1234"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_get_model_token_limit_protocol_conversion(
|
||||||
|
self, monkeypatch, input_url, expected_protocol, expected_host
|
||||||
|
):
|
||||||
|
"""Verify HTTP/HTTPS URL is converted to WebSocket format for SDK query.
|
||||||
|
|
||||||
|
LM Studio SDK requires WebSocket URLs. get_model_token_limit() should:
|
||||||
|
1. Convert 'http://' to 'ws://'
|
||||||
|
2. Convert 'https://' to 'wss://'
|
||||||
|
3. Remove '/v1' or other path suffixes (SDK expects base URL)
|
||||||
|
4. Preserve host and port
|
||||||
|
"""
|
||||||
|
conversions_tested = []
|
||||||
|
|
||||||
|
def mock_query_lmstudio(model_name, base_url):
|
||||||
|
conversions_tested.append(base_url)
|
||||||
|
return 8192
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||||
|
mock_query_lmstudio,
|
||||||
|
)
|
||||||
|
|
||||||
|
get_model_token_limit(model_name="test-model", base_url=input_url)
|
||||||
|
|
||||||
|
# Verify conversion happened
|
||||||
|
assert len(conversions_tested) == 1, "Should have called SDK query once"
|
||||||
|
assert conversions_tested[0].startswith(expected_protocol), (
|
||||||
|
f"Should convert to {expected_protocol}"
|
||||||
|
)
|
||||||
|
assert expected_host in conversions_tested[0], (
|
||||||
|
f"Should preserve host and port: {expected_host}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_get_model_token_limit_lmstudio_executes_after_ollama(self, monkeypatch):
|
||||||
|
"""Verify LM Studio detection happens AFTER Ollama detection.
|
||||||
|
|
||||||
|
The hybrid discovery order should be:
|
||||||
|
1. Ollama dynamic discovery (port 11434 or 'ollama' in URL)
|
||||||
|
2. LM Studio dynamic discovery (port 1234 or 'lmstudio' in URL)
|
||||||
|
3. Registry fallback
|
||||||
|
|
||||||
|
If both Ollama and LM Studio patterns match, Ollama should take precedence.
|
||||||
|
This test verifies that LM Studio is checked but doesn't interfere with Ollama.
|
||||||
|
"""
|
||||||
|
ollama_called = False
|
||||||
|
lmstudio_called = False
|
||||||
|
|
||||||
|
def mock_query_ollama(model_name, base_url):
|
||||||
|
nonlocal ollama_called
|
||||||
|
ollama_called = True
|
||||||
|
return 2048 # Ollama query succeeds
|
||||||
|
|
||||||
|
def mock_query_lmstudio(model_name, base_url):
|
||||||
|
nonlocal lmstudio_called
|
||||||
|
lmstudio_called = True
|
||||||
|
return None # Should not be reached if Ollama succeeds
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"leann.embedding_compute._query_ollama_context_limit",
|
||||||
|
mock_query_ollama,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||||
|
mock_query_lmstudio,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with Ollama URL
|
||||||
|
limit = get_model_token_limit(
|
||||||
|
model_name="test-model", base_url="http://localhost:11434/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ollama_called, "Should attempt Ollama query first"
|
||||||
|
assert not lmstudio_called, "Should not attempt LM Studio query when Ollama succeeds"
|
||||||
|
assert limit == 2048, "Should return Ollama result"
|
||||||
|
|
||||||
|
def test_get_model_token_limit_lmstudio_not_detected_for_non_lmstudio_urls(self, monkeypatch):
|
||||||
|
"""Verify LM Studio SDK query is NOT called for non-LM Studio URLs.
|
||||||
|
|
||||||
|
Only URLs with port 1234 or 'lmstudio'/'lm.studio' keywords should
|
||||||
|
trigger LM Studio SDK queries. Other URLs should skip to registry fallback.
|
||||||
|
"""
|
||||||
|
lmstudio_called = False
|
||||||
|
|
||||||
|
def mock_query_lmstudio(model_name, base_url):
|
||||||
|
nonlocal lmstudio_called
|
||||||
|
lmstudio_called = True
|
||||||
|
return 8192
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||||
|
mock_query_lmstudio,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with non-LM Studio URLs
|
||||||
|
test_cases = [
|
||||||
|
"http://localhost:8080/v1", # Different port
|
||||||
|
"http://openai.example.com/v1", # Different service
|
||||||
|
"http://localhost:3000/v1", # Another port
|
||||||
|
]
|
||||||
|
|
||||||
|
for base_url in test_cases:
|
||||||
|
lmstudio_called = False # Reset for each test
|
||||||
|
get_model_token_limit(model_name="nomic-embed-text", base_url=base_url)
|
||||||
|
assert not lmstudio_called, f"Should NOT call LM Studio SDK for URL: {base_url}"
|
||||||
|
|
||||||
|
def test_get_model_token_limit_lmstudio_case_insensitive_detection(self, monkeypatch):
|
||||||
|
"""Verify LM Studio detection is case-insensitive for keywords.
|
||||||
|
|
||||||
|
Keywords 'lmstudio' and 'lm.studio' should be detected regardless
|
||||||
|
of case (LMStudio, LMSTUDIO, LmStudio, etc.).
|
||||||
|
"""
|
||||||
|
query_called = False
|
||||||
|
|
||||||
|
def mock_query_lmstudio(model_name, base_url):
|
||||||
|
nonlocal query_called
|
||||||
|
query_called = True
|
||||||
|
return 8192
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||||
|
mock_query_lmstudio,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test various case variations
|
||||||
|
test_cases = [
|
||||||
|
"http://LMStudio.local:8080/v1",
|
||||||
|
"http://LMSTUDIO.example.com/v1",
|
||||||
|
"http://LmStudio.local/v1",
|
||||||
|
"http://api.LM.STUDIO:5000/v1",
|
||||||
|
]
|
||||||
|
|
||||||
|
for base_url in test_cases:
|
||||||
|
query_called = False # Reset for each test
|
||||||
|
limit = get_model_token_limit(model_name="test-model", base_url=base_url)
|
||||||
|
assert query_called, f"Should detect LM Studio in URL: {base_url}"
|
||||||
|
assert limit == 8192, f"Should return SDK result for URL: {base_url}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenLimitCaching:
    """Tests for token limit caching to prevent repeated SDK/API calls.

    Caching prevents duplicate SDK/API calls within the same Python process,
    which is important because:
    1. LM Studio SDK load() can load duplicate model instances
    2. Ollama /api/show queries add latency
    3. Registry lookups are pure overhead

    The cache is process-scoped and resets between leann build invocations.
    """
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Clear cache before each test."""
|
||||||
|
from leann.embedding_compute import _token_limit_cache
|
||||||
|
|
||||||
|
_token_limit_cache.clear()
|
||||||
|
|
||||||
|
def test_registry_lookup_is_cached(self):
|
||||||
|
"""Verify that registry lookups are cached."""
|
||||||
|
from leann.embedding_compute import _token_limit_cache
|
||||||
|
|
||||||
|
# First call
|
||||||
|
limit1 = get_model_token_limit("text-embedding-3-small")
|
||||||
|
assert limit1 == 8192
|
||||||
|
|
||||||
|
# Verify it's in cache
|
||||||
|
cache_key = ("text-embedding-3-small", "")
|
||||||
|
assert cache_key in _token_limit_cache
|
||||||
|
assert _token_limit_cache[cache_key] == 8192
|
||||||
|
|
||||||
|
# Second call should use cache
|
||||||
|
limit2 = get_model_token_limit("text-embedding-3-small")
|
||||||
|
assert limit2 == 8192
|
||||||
|
|
||||||
|
def test_default_fallback_is_cached(self):
|
||||||
|
"""Verify that default fallbacks are cached."""
|
||||||
|
from leann.embedding_compute import _token_limit_cache
|
||||||
|
|
||||||
|
# First call with unknown model
|
||||||
|
limit1 = get_model_token_limit("unknown-model-xyz", default=512)
|
||||||
|
assert limit1 == 512
|
||||||
|
|
||||||
|
# Verify it's in cache
|
||||||
|
cache_key = ("unknown-model-xyz", "")
|
||||||
|
assert cache_key in _token_limit_cache
|
||||||
|
assert _token_limit_cache[cache_key] == 512
|
||||||
|
|
||||||
|
# Second call should use cache
|
||||||
|
limit2 = get_model_token_limit("unknown-model-xyz", default=512)
|
||||||
|
assert limit2 == 512
|
||||||
|
|
||||||
|
def test_different_urls_create_separate_cache_entries(self):
|
||||||
|
"""Verify that different base_urls create separate cache entries."""
|
||||||
|
from leann.embedding_compute import _token_limit_cache
|
||||||
|
|
||||||
|
# Same model, different URLs
|
||||||
|
limit1 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:11434")
|
||||||
|
limit2 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:1234/v1")
|
||||||
|
|
||||||
|
# Both should find the model in registry (2048)
|
||||||
|
assert limit1 == 2048
|
||||||
|
assert limit2 == 2048
|
||||||
|
|
||||||
|
# But they should be separate cache entries
|
||||||
|
cache_key1 = ("nomic-embed-text", "http://localhost:11434")
|
||||||
|
cache_key2 = ("nomic-embed-text", "http://localhost:1234/v1")
|
||||||
|
|
||||||
|
assert cache_key1 in _token_limit_cache
|
||||||
|
assert cache_key2 in _token_limit_cache
|
||||||
|
assert len(_token_limit_cache) == 2
|
||||||
|
|
||||||
|
def test_cache_prevents_repeated_lookups(self):
|
||||||
|
"""Verify that cache prevents repeated registry/API lookups."""
|
||||||
|
from leann.embedding_compute import _token_limit_cache
|
||||||
|
|
||||||
|
model_name = "text-embedding-ada-002"
|
||||||
|
|
||||||
|
# First call - should add to cache
|
||||||
|
assert len(_token_limit_cache) == 0
|
||||||
|
limit1 = get_model_token_limit(model_name)
|
||||||
|
|
||||||
|
cache_size_after_first = len(_token_limit_cache)
|
||||||
|
assert cache_size_after_first == 1
|
||||||
|
|
||||||
|
# Multiple subsequent calls - cache size should not change
|
||||||
|
for _ in range(5):
|
||||||
|
limit = get_model_token_limit(model_name)
|
||||||
|
assert limit == limit1
|
||||||
|
assert len(_token_limit_cache) == cache_size_after_first
|
||||||
|
|
||||||
|
def test_versioned_model_names_cached_correctly(self):
|
||||||
|
"""Verify that versioned model names (e.g., model:tag) are cached."""
|
||||||
|
from leann.embedding_compute import _token_limit_cache
|
||||||
|
|
||||||
|
# Model with version tag
|
||||||
|
limit = get_model_token_limit("nomic-embed-text:latest", base_url="http://localhost:11434")
|
||||||
|
assert limit == 2048
|
||||||
|
|
||||||
|
# Should be cached with full name including version
|
||||||
|
cache_key = ("nomic-embed-text:latest", "http://localhost:11434")
|
||||||
|
assert cache_key in _token_limit_cache
|
||||||
|
assert _token_limit_cache[cache_key] == 2048