Compare commits

...

12 Commits

Author SHA1 Message Date
aakash
877fbe81f4 Fix: Prevent duplicate PDF processing when using --file-types .pdf
Fixes #175

Problem:
When --file-types .pdf is specified, PDFs were processed twice:
1. Once with the PyMuPDF/pdfplumber extractors
2. Again in the 'other file types' section via SimpleDirectoryReader

This caused duplicate processing and potential conflicts.

Solution:
- Exclude .pdf from other_file_extensions when PDFs are already
  processed separately
- Only load other file types if there are extensions to process
- Prevent duplicate PDF processing

Changes:
- Added logic to filter out .pdf from code_extensions when loading
  other file types if PDFs were processed separately
- Updated SimpleDirectoryReader to use filtered extensions
- Added check to skip loading if no other extensions to process
2025-11-30 11:18:57 -08:00
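A minimal sketch of the filtering this commit describes; variable names follow the commit message and the CLI diff further down, and the extension list here is illustrative:

```python
# Illustrative values; in the real CLI these come from --file-types parsing.
code_extensions = [".pdf", ".md", ".py"]
should_process_pdfs = ".pdf" in code_extensions  # PDFs go through PyMuPDF/pdfplumber first

# Exclude .pdf from the "other file types" pass when it was already handled separately
other_file_extensions = code_extensions
if should_process_pdfs and ".pdf" in code_extensions:
    other_file_extensions = [ext for ext in code_extensions if ext != ".pdf"]

# Only fall back to SimpleDirectoryReader when something is left to load
if other_file_extensions:
    print(f"Loading other file types: {other_file_extensions}")
else:
    other_docs = []
```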
Andy Lee
eb909ccec5 docs: survey 2025-11-20 09:20:45 -08:00
ww26
969f514564 Fix prompt template bugs: build template ignored and runtime override not wired (#173)
* Fix prompt template bugs in build and search

Bug 1: Build template ignored in new format
- Updated compute_embeddings_openai() to read build_prompt_template or prompt_template
- Updated compute_embeddings_ollama() with same fix
- Maintains backward compatibility with old single-template format

Bug 2: Runtime override not wired up
- Wired CLI search to pass provider_options to searcher.search()
- Enables runtime template override during search via --embedding-prompt-template

All 42 prompt template tests passing.

Fixes #155
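
A condensed sketch of the fallback order described in Bug 1 (the full implementation is in the embedding_compute.py diff below):

```python
from typing import Optional

def resolve_build_template(provider_options: Optional[dict]) -> Optional[str]:
    """Prefer the new build_prompt_template key, fall back to the legacy prompt_template."""
    opts = provider_options or {}
    return opts.get("build_prompt_template") or opts.get("prompt_template")

# When a template is found it is prepended to every text before truncation:
# texts = [f"{template}{text}" for text in texts]
```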

* Fix: Prevent embedding server from applying templates during search

- Filter out all prompt templates (build_prompt_template, query_prompt_template, prompt_template) from provider_options when launching embedding server during search
- Templates are already applied in compute_query_embedding() before server call
- Prevents double-templating and ensures runtime override works correctly

This fixes the issue where --embedding-prompt-template during search was ignored because the server was applying build_prompt_template instead.
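
A condensed sketch of the filtering described here; the sample option values are illustrative and the actual change is in the searcher_base.py diff below:

```python
embedding_options = {
    "base_url": "http://localhost:1234/v1",
    "build_prompt_template": "title: none | text: ",
    "query_prompt_template": "task: search result | query: ",
}

TEMPLATE_KEYS = ("build_prompt_template", "query_prompt_template", "prompt_template")

# Templates are applied to the query text in compute_query_embedding() before the
# server is called, so none of them are forwarded to the embedding server.
search_provider_options = {k: v for k, v in embedding_options.items() if k not in TEMPLATE_KEYS}
```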

* Format code with ruff
2025-11-16 20:56:42 -08:00
ww26
1ef9cba7de Feature/prompt templates and lmstudio sdk (#171)
* Add prompt template support and LM Studio SDK integration

Features:

- Prompt template support for embedding models (via --embedding-prompt-template)

- LM Studio SDK integration for automatic context length detection

- Hybrid token limit discovery (Ollama → LM Studio → Registry → Default)

- Client-side token truncation to prevent silent failures

- Automatic persistence of embedding_options to .meta.json

Implementation:

- Added _query_lmstudio_context_limit() with Node.js subprocess bridge

- Modified compute_embeddings_openai() to apply prompt templates before truncation

- Extended CLI with --embedding-prompt-template flag for build and search

- URL detection for LM Studio (port 1234 or lmstudio/lm.studio keywords)

- HTTP→WebSocket URL conversion for SDK compatibility

Tests:

- 60 passing tests across 5 test files

- Comprehensive coverage of prompt templates, LM Studio integration, and token handling

- Parametrized tests for maintainability and clarity

* Add integration tests and fix LM Studio SDK bridge

Features:
- End-to-end integration tests for prompt template with EmbeddingGemma
- Integration tests for hybrid token limit discovery mechanism
- Tests verify real-world functionality with live services (LM Studio, Ollama)

Fixes:
- LM Studio SDK bridge now uses client.embedding.load() for embedding models
- Fixed NODE_PATH resolution to include npm global modules
- Fixed integration test to use WebSocket URL (ws://) for SDK bridge

Tests:
- test_prompt_template_e2e.py: 8 integration tests covering:
  - Prompt template prepending with LM Studio (EmbeddingGemma)
  - LM Studio SDK bridge for context length detection
  - Ollama dynamic token limit detection
  - Hybrid discovery fallback mechanism (registry, default)
- All tests marked with @pytest.mark.integration for selective execution
- Tests gracefully skip when services unavailable

Documentation:
- Updated tests/README.md with integration test section
- Added prerequisites and running instructions
- Documented that prompt templates are ONLY for EmbeddingGemma
- Added integration marker to pyproject.toml

Test Results:
- All 8 integration tests passing with live services
- Confirmed prompt templates work correctly with EmbeddingGemma
- Verified LM Studio SDK bridge auto-detects context length (2048)
- Validated hybrid token limit discovery across all backends

* Add prompt template support to Ollama mode

Extends prompt template functionality from OpenAI mode to Ollama for backend consistency.

Changes:
- Add provider_options parameter to compute_embeddings_ollama()
- Apply prompt template before token truncation (lines 1005-1011)
- Pass provider_options through compute_embeddings() call chain

Tests:
- test_ollama_embedding_with_prompt_template: Verifies templates work with Ollama
- test_ollama_prompt_template_affects_embeddings: Confirms embeddings differ with/without template
- Both tests pass with live Ollama service (2/2 passing)

Usage:
leann build --embedding-mode ollama --embedding-prompt-template "query: " ...

* Fix LM Studio SDK bridge to respect JIT auto-evict settings

Problem: SDK bridge called client.embedding.load() which loaded models into
LM Studio memory and bypassed JIT auto-evict settings, causing duplicate
model instances to accumulate.

Root cause analysis (from Perplexity research):
- Explicit SDK load() commands are treated as "pinned" models
- JIT auto-evict only applies to models loaded reactively via API requests
- SDK-loaded models remain in memory until explicitly unloaded

Solutions implemented:

1. Add model.unload() after metadata query (line 243)
   - Load model temporarily to get context length
   - Unload immediately to hand control back to JIT system
   - Subsequent API requests trigger JIT load with auto-evict

2. Add token limit caching to prevent repeated SDK calls
   - Cache discovered limits in _token_limit_cache dict (line 48)
   - Key: (model_name, base_url), Value: token_limit
   - Prevents duplicate load/unload cycles within same process
   - Cache shared across all discovery methods (Ollama, SDK, registry)

Tests:
- TestTokenLimitCaching: 5 tests for cache behavior (integrated into test_token_truncation.py)
- Manual testing confirmed no duplicate models in LM Studio after fix
- All existing tests pass

Impact:
- Respects user's LM Studio JIT and auto-evict settings
- Reduces model memory footprint
- Faster subsequent builds (cached limits)
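
A minimal sketch of the cache described above, keyed by `(model_name, base_url)`; `discover` is a stand-in for whichever backend lookup (Ollama, SDK, registry) applies:

```python
from typing import Callable, Optional

# Module-level cache shared across all discovery methods
_token_limit_cache: dict[tuple[str, str], int] = {}

def cached_token_limit(
    model_name: str,
    base_url: Optional[str],
    discover: Callable[[str, Optional[str]], int],
) -> int:
    """Return the cached limit if present; otherwise discover it once and cache it."""
    key = (model_name, base_url or "")
    if key not in _token_limit_cache:
        _token_limit_cache[key] = discover(model_name, base_url)
    return _token_limit_cache[key]
```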

* Document prompt template and LM Studio SDK features

Added comprehensive documentation for new optional embedding features:

Configuration Guide (docs/configuration-guide.md):
- New section: "Optional Embedding Features"
- Task-Specific Prompt Templates subsection:
  - Explains EmbeddingGemma use case with document/query prompts
  - CLI and Python API examples
  - Clear warnings about compatible vs incompatible models
  - References to GitHub issue #155 and HuggingFace blog
- LM Studio Auto-Detection subsection:
  - Prerequisites (Node.js + @lmstudio/sdk)
  - How auto-detection works (4-step process)
  - Benefits and optional nature clearly stated

FAQ (docs/faq.md):
- FAQ #2: When should I use prompt templates?
  - DO/DON'T guidance with examples
  - Links to detailed configuration guide
- FAQ #3: Why is LM Studio loading multiple copies?
  - Explains the JIT auto-evict fix
  - Troubleshooting steps if still seeing issues
- FAQ #4: Do I need Node.js and @lmstudio/sdk?
  - Clarifies it's completely optional
  - Lists benefits if installed
  - Installation instructions

Cross-references between the two documents allow easy navigation between the quick FAQ answers and the detailed configuration guide.

* Add separate build/query template support for task-specific models

Task-specific models like EmbeddingGemma require different templates for indexing and searching. Store both templates at build time and auto-apply the query template during search, with backward compatibility for the old single-template format.
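
A sketch of the `embedding_options` block that ends up in `.meta.json`; the key names match the diff, and the template strings are the EmbeddingGemma examples used elsewhere in this PR:

```python
# New format: both templates stored at build time
embedding_options = {
    "build_prompt_template": "title: none | text: ",            # applied while indexing documents
    "query_prompt_template": "task: search result | query: ",   # auto-applied during search
}

# Old single-template format is still accepted for backward compatibility
legacy_embedding_options = {"prompt_template": "title: none | text: "}
```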

* Consolidate prompt template tests from 44 to 37 tests

Merged redundant no-op tests, removed low-value implementation tests, consolidated parameterized CLI tests, and removed a hanging over-mocked test. All tests pass, with an improved focus on behavioral testing.

* Fix query template application in compute_query_embedding

Query templates were only applied in the fallback code path, not when using the embedding server (default path). This meant stored query templates in index metadata were ignored during MCP and CLI searches.

Changes:

- Move template application to before any computation path (searcher_base.py:109-110)

- Add comprehensive tests for both server and fallback paths

- Consolidate tests into test_prompt_template_persistence.py

Tests verify:

- Template applied when using embedding server

- Template applied in fallback path

- Consistent behavior between both paths

* Apply ruff formatting and fix linting issues

- Remove unused imports

- Fix import ordering

- Remove unused variables

- Apply code formatting

* Fix CI test failures: mock OPENAI_API_KEY in tests

Tests were failing in CI because compute_embeddings_openai() checks for OPENAI_API_KEY before using the mocked client. Added a monkeypatch to set a fake API key in the test fixture.
2025-11-14 15:25:17 -08:00
yichuan520030910320
a63550944b update pkg name 2025-11-13 14:57:28 -08:00
yichuan520030910320
97493a2896 update module to install 2025-11-13 14:53:25 -08:00
Aakash Suresh
f7d2dc6e7c Merge pull request #163 from orbisai0security/fix-semgrep-python.lang.security.audit.eval-detected.eval-detected-apps-slack-data-slack-mc-b134f52c-f326
Fix: Unsafe Code Execution Function Could Allow External Code Injection in apps/slack_data/slack_mcp_reader.py
2025-11-13 14:35:44 -08:00
Gil Vernik
ea86b283cb [bug] fix when no package name found (#167)
* fix when no package name

* fix pre-commit style issue

* fix ruff (legacy alias) test failure
2025-11-13 14:23:13 -08:00
aakash
e7519bceaa Fix CI: sync uv.lock from main and remove .lycheeignore (workflow exclusion is sufficient) 2025-11-13 13:10:07 -08:00
aakash
abf0b2c676 Fix CI: improve security fix and add link checker configuration
- Fix import order (ast before asyncio)
- Remove NameError from exception handling (ast.literal_eval doesn't raise it)
- Add .lycheeignore to exclude intermittently unavailable star-history API
- Update link-check workflow to exclude star-history API and accept 503 status codes
2025-11-13 13:05:00 -08:00
GitHub Actions
3c4785bb63 chore: release v0.3.5 2025-11-12 06:01:25 +00:00
orbisai0security
930b79cc98 fix: semgrep_python.lang.security.audit.eval-detected.eval-detected_apps/slack_data/slack_mcp_reader.py_157 2025-11-12 03:47:18 +00:00
23 changed files with 3163 additions and 37 deletions

View File

@@ -14,6 +14,6 @@ jobs:
- uses: actions/checkout@v4
- uses: lycheeverse/lychee-action@v2
with:
args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
args: --no-progress --insecure --user-agent 'curl/7.68.0' --exclude '.*api\.star-history\.com.*' --accept 200,201,202,203,204,205,206,207,208,226,300,301,302,303,304,305,306,307,308,503 README.md docs/ apps/ examples/ benchmarks/
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -16,12 +16,24 @@
</a>
</p>
<div align="center">
<a href="https://forms.gle/rDbZf864gMNxhpTq8">
<img src="https://img.shields.io/badge/📣_Community_Survey-Help_Shape_v0.4-007ec6?style=for-the-badge&logo=google-forms&logoColor=white" alt="Take Survey">
</a>
<p>
We track <b>zero telemetry</b>. This survey is the ONLY way to tell us if you want <br>
<b>GPU Acceleration</b> or <b>More Integrations</b> next.<br>
👉 <a href="https://forms.gle/rDbZf864gMNxhpTq8"><b>Click here to cast your vote (2 mins)</b></a>
</p>
</div>
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
The smallest vector index in the world. RAG Everything with LEANN!
</h2>
LEANN is an innovative vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#mcp-integration-rag-on-live-data-from-any-platform), [Twitter](#mcp-integration-rag-on-live-data-from-any-platform)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.

View File

@@ -7,6 +7,7 @@ for indexing in LEANN. It supports various Slack MCP server implementations and
flexible message processing options.
"""
import ast
import asyncio
import json
import logging
@@ -146,16 +147,16 @@ class SlackMCPReader:
match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
if match:
try:
error_dict = eval(match.group(1))
except (ValueError, SyntaxError, NameError):
error_dict = ast.literal_eval(match.group(1))
except (ValueError, SyntaxError):
pass
else:
# Try alternative format
match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
if match:
try:
error_dict = eval(match.group(1))
except (ValueError, SyntaxError, NameError):
error_dict = ast.literal_eval(match.group(1))
except (ValueError, SyntaxError):
pass
if self._is_cache_sync_error(error_dict):

View File

@@ -158,6 +158,95 @@ builder.build_index("./indexes/my-notes", chunks)
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
## Optional Embedding Features
### Task-Specific Prompt Templates
Some embedding models are trained with task-specific prompts to differentiate between documents and queries. The most notable example is **Google's EmbeddingGemma**, which requires different prompts depending on the use case:
- **Indexing documents**: `"title: none | text: "`
- **Search queries**: `"task: search result | query: "`
LEANN supports automatic prompt prepending via the `--embedding-prompt-template` flag:
```bash
# Build index with EmbeddingGemma (via LM Studio or Ollama)
leann build my-docs \
--docs ./documents \
--embedding-mode openai \
--embedding-model text-embedding-embeddinggemma-300m-qat \
--embedding-api-base http://localhost:1234/v1 \
--embedding-prompt-template "title: none | text: " \
--force
# Search with query-specific prompt
leann search my-docs \
--query "What is quantum computing?" \
--embedding-prompt-template "task: search result | query: "
```
**Important Notes:**
- **Only use with compatible models**: EmbeddingGemma and similar task-specific models
- **NOT for regular models**: Adding prompts to models like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` will corrupt embeddings
- **Template is saved**: Build-time templates are saved to `.meta.json` for reference
- **Flexible prompts**: You can use any prompt string, or leave it empty (`""`)
**Python API:**
```python
from leann.api import LeannBuilder
builder = LeannBuilder(
embedding_mode="openai",
embedding_model="text-embedding-embeddinggemma-300m-qat",
embedding_options={
"base_url": "http://localhost:1234/v1",
"api_key": "lm-studio",
"prompt_template": "title: none | text: ",
},
)
builder.build_index("./indexes/my-docs", chunks)
```
**References:**
- [HuggingFace Blog: EmbeddingGemma](https://huggingface.co/blog/embeddinggemma) - Technical details
### LM Studio Auto-Detection (Optional)
When using LM Studio with the OpenAI-compatible API, LEANN can optionally auto-detect model context lengths via the LM Studio SDK. This eliminates manual configuration for token limits.
**Prerequisites:**
```bash
# Install Node.js (if not already installed)
# Then install the LM Studio SDK globally
npm install -g @lmstudio/sdk
```
**How it works:**
1. LEANN detects LM Studio URLs (`:1234`, `lmstudio` in URL)
2. Queries model metadata via Node.js subprocess
3. Automatically unloads model after query (respects your JIT auto-evict settings)
4. Falls back to static registry if SDK unavailable
**No configuration needed** - it works automatically when SDK is installed:
```bash
leann build my-docs \
--docs ./documents \
--embedding-mode openai \
--embedding-model text-embedding-nomic-embed-text-v1.5 \
--embedding-api-base http://localhost:1234/v1
# Context length auto-detected if SDK available
# Falls back to registry (2048) if not
```
**Benefits:**
- ✅ Automatic token limit detection
- ✅ Respects LM Studio JIT auto-evict settings
- ✅ No manual registry maintenance
- ✅ Graceful fallback if SDK unavailable
**Note:** This is completely optional. LEANN works perfectly fine without the SDK using the built-in token limit registry.
## Index Selection: Matching Your Scale
### HNSW (Hierarchical Navigable Small World)

View File

@@ -8,3 +8,51 @@ You can speed up the process by using a lightweight embedding model. Add this to
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
## 2. When should I use prompt templates?
**Use prompt templates ONLY with task-specific embedding models** like Google's EmbeddingGemma. These models are specially trained to use different prompts for documents vs queries.
**DO NOT use with regular models** like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` - adding prompts to these models will corrupt the embeddings.
**Example usage with EmbeddingGemma:**
```bash
# Build with document prompt
leann build my-docs --embedding-prompt-template "title: none | text: "
# Search with query prompt
leann search my-docs --query "your question" --embedding-prompt-template "task: search result | query: "
```
See the [Configuration Guide: Task-Specific Prompt Templates](configuration-guide.md#task-specific-prompt-templates) for detailed usage.
## 3. Why is LM Studio loading multiple copies of my model?
This was fixed in recent versions. LEANN now properly unloads models after querying metadata, respecting your LM Studio JIT auto-evict settings.
**If you still see duplicates:**
- Update to the latest LEANN version
- Restart LM Studio to clear loaded models
- Check that you have JIT auto-evict enabled in LM Studio settings
**How it works now:**
1. LEANN loads model temporarily to get context length
2. Immediately unloads after query
3. LM Studio JIT loads model on-demand for actual embeddings
4. Auto-evicts per your settings
## 4. Do I need Node.js and @lmstudio/sdk?
**No, it's completely optional.** LEANN works perfectly fine without them using a built-in token limit registry.
**Benefits if you install it:**
- Automatic context length detection for LM Studio models
- No manual registry maintenance
- Always gets accurate token limits from the model itself
**To install (optional):**
```bash
npm install -g @lmstudio/sdk
```
See [Configuration Guide: LM Studio Auto-Detection](configuration-guide.md#lm-studio-auto-detection-optional) for details.

View File

@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-diskann"
version = "0.3.4"
dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
version = "0.3.5"
dependencies = ["leann-core==0.3.5", "numpy", "protobuf>=3.19.0"]
[tool.scikit-build]
# Key: simplified CMake path

View File

@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-hnsw"
version = "0.3.4"
version = "0.3.5"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = [
"leann-core==0.3.4",
"leann-core==0.3.5",
"numpy",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann-core"
version = "0.3.4"
version = "0.3.5"
description = "Core API and plugin system for LEANN"
readme = "README.md"
requires-python = ">=3.9"

View File

@@ -916,6 +916,7 @@ class LeannSearcher:
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
batch_size: int = 0,
use_grep: bool = False,
provider_options: Optional[dict[str, Any]] = None,
**kwargs,
) -> list[SearchResult]:
"""
@@ -979,10 +980,24 @@ class LeannSearcher:
start_time = time.time()
# Extract query template from stored embedding_options with fallback chain:
# 1. Check provider_options override (highest priority)
# 2. Check query_prompt_template (new format)
# 3. Check prompt_template (old format for backward compat)
# 4. None (no template)
query_template = None
if provider_options and "prompt_template" in provider_options:
query_template = provider_options["prompt_template"]
elif "query_prompt_template" in self.embedding_options:
query_template = self.embedding_options["query_prompt_template"]
elif "prompt_template" in self.embedding_options:
query_template = self.embedding_options["prompt_template"]
query_embedding = self.backend_impl.compute_query_embedding(
query,
use_server_if_available=recompute_embeddings,
zmq_port=zmq_port,
query_template=query_template,
)
logger.info(f" Generated embedding shape: {query_embedding.shape}")
embedding_time = time.time() - start_time

View File

@@ -144,6 +144,18 @@ Examples:
default=None,
help="API key for embedding service (defaults to OPENAI_API_KEY)",
)
build_parser.add_argument(
"--embedding-prompt-template",
type=str,
default=None,
help="Prompt template to prepend to all texts for embedding (e.g., 'query: ' for search)",
)
build_parser.add_argument(
"--query-prompt-template",
type=str,
default=None,
help="Prompt template for queries (different from build template for task-specific models)",
)
build_parser.add_argument(
"--force", "-f", action="store_true", help="Force rebuild existing index"
)
@@ -260,6 +272,12 @@ Examples:
action="store_true",
help="Display file paths and metadata in search results",
)
search_parser.add_argument(
"--embedding-prompt-template",
type=str,
default=None,
help="Prompt template to prepend to query for embedding (e.g., 'query: ' for search)",
)
# Ask command
ask_parser = subparsers.add_parser("ask", help="Ask questions")
@@ -1162,6 +1180,11 @@ Examples:
print(f"Warning: Could not process {file_path}: {e}")
# Load other file types with default reader
# Exclude PDFs from code_extensions if they were already processed separately
other_file_extensions = code_extensions
if should_process_pdfs and ".pdf" in code_extensions:
other_file_extensions = [ext for ext in code_extensions if ext != ".pdf"]
try:
# Create a custom file filter function using our PathSpec
def file_filter(
@@ -1177,15 +1200,19 @@ Examples:
except (ValueError, OSError):
return True # Include files that can't be processed
other_docs = SimpleDirectoryReader(
docs_dir,
recursive=True,
encoding="utf-8",
required_exts=code_extensions,
file_extractor={}, # Use default extractors
exclude_hidden=not include_hidden,
filename_as_id=True,
).load_data(show_progress=True)
# Only load other file types if there are extensions to process
if other_file_extensions:
other_docs = SimpleDirectoryReader(
docs_dir,
recursive=True,
encoding="utf-8",
required_exts=other_file_extensions,
file_extractor={}, # Use default extractors
exclude_hidden=not include_hidden,
filename_as_id=True,
).load_data(show_progress=True)
else:
other_docs = []
# Filter documents after loading based on gitignore rules
filtered_docs = []
@@ -1398,6 +1425,14 @@ Examples:
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
if resolved_embedding_key:
embedding_options["api_key"] = resolved_embedding_key
if args.query_prompt_template:
# New format: separate templates
if args.embedding_prompt_template:
embedding_options["build_prompt_template"] = args.embedding_prompt_template
embedding_options["query_prompt_template"] = args.query_prompt_template
elif args.embedding_prompt_template:
# Old format: single template (backward compat)
embedding_options["prompt_template"] = args.embedding_prompt_template
builder = LeannBuilder(
backend_name=args.backend_name,
@@ -1519,6 +1554,11 @@ Examples:
print("Invalid input. Aborting search.")
return
# Build provider_options for runtime override
provider_options = {}
if args.embedding_prompt_template:
provider_options["prompt_template"] = args.embedding_prompt_template
searcher = LeannSearcher(index_path=index_path)
results = searcher.search(
query,
@@ -1528,6 +1568,7 @@ Examples:
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
provider_options=provider_options if provider_options else None,
)
print(f"Search results for '{query}' (top {len(results)}):")

View File

@@ -4,8 +4,10 @@ Consolidates all embedding computation logic using SentenceTransformer
Preserves all optimization parameters to ensure performance
"""
import json
import logging
import os
import subprocess
import time
from typing import Any, Optional
@@ -40,6 +42,11 @@ EMBEDDING_MODEL_LIMITS = {
"text-embedding-ada-002": 8192,
}
# Runtime cache for dynamically discovered token limits
# Key: (model_name, base_url), Value: token_limit
# Prevents repeated SDK/API calls for the same model
_token_limit_cache: dict[tuple[str, str], int] = {}
def get_model_token_limit(
model_name: str,
@@ -49,6 +56,7 @@ def get_model_token_limit(
"""
Get token limit for a given embedding model.
Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others.
Caches discovered limits to prevent repeated API/SDK calls.
Args:
model_name: Name of the embedding model
@@ -58,12 +66,33 @@ def get_model_token_limit(
Returns:
Token limit for the model in tokens
"""
# Check cache first to avoid repeated SDK/API calls
cache_key = (model_name, base_url or "")
if cache_key in _token_limit_cache:
cached_limit = _token_limit_cache[cache_key]
logger.debug(f"Using cached token limit for {model_name}: {cached_limit}")
return cached_limit
# Try Ollama dynamic discovery if base_url provided
if base_url:
# Detect Ollama servers by port or "ollama" in URL
if "11434" in base_url or "ollama" in base_url.lower():
limit = _query_ollama_context_limit(model_name, base_url)
if limit:
_token_limit_cache[cache_key] = limit
return limit
# Try LM Studio SDK discovery
if "1234" in base_url or "lmstudio" in base_url.lower() or "lm.studio" in base_url.lower():
# Convert HTTP to WebSocket URL
ws_url = base_url.replace("https://", "wss://").replace("http://", "ws://")
# Remove /v1 suffix if present
if ws_url.endswith("/v1"):
ws_url = ws_url[:-3]
limit = _query_lmstudio_context_limit(model_name, ws_url)
if limit:
_token_limit_cache[cache_key] = limit
return limit
# Fallback to known model registry with version handling (from PR #154)
@@ -72,19 +101,25 @@ def get_model_token_limit(
# Check exact match first
if model_name in EMBEDDING_MODEL_LIMITS:
return EMBEDDING_MODEL_LIMITS[model_name]
limit = EMBEDDING_MODEL_LIMITS[model_name]
_token_limit_cache[cache_key] = limit
return limit
# Check base name match
if base_model_name in EMBEDDING_MODEL_LIMITS:
return EMBEDDING_MODEL_LIMITS[base_model_name]
limit = EMBEDDING_MODEL_LIMITS[base_model_name]
_token_limit_cache[cache_key] = limit
return limit
# Check partial matches for common patterns
for known_model, limit in EMBEDDING_MODEL_LIMITS.items():
for known_model, registry_limit in EMBEDDING_MODEL_LIMITS.items():
if known_model in base_model_name or base_model_name in known_model:
return limit
_token_limit_cache[cache_key] = registry_limit
return registry_limit
# Default fallback
logger.warning(f"Unknown model '{model_name}', using default {default} token limit")
_token_limit_cache[cache_key] = default
return default
@@ -185,6 +220,91 @@ def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int]
return None
def _query_lmstudio_context_limit(model_name: str, base_url: str) -> Optional[int]:
"""
Query LM Studio SDK for model context length via Node.js subprocess.
Args:
model_name: Name of the LM Studio model
base_url: Base URL of the LM Studio server (WebSocket format, e.g., "ws://localhost:1234")
Returns:
Context limit in tokens if found, None otherwise
"""
# Inline JavaScript using @lmstudio/sdk
# Note: Load model temporarily for metadata, then unload to respect JIT auto-evict
js_code = f"""
const {{ LMStudioClient }} = require('@lmstudio/sdk');
(async () => {{
try {{
const client = new LMStudioClient({{ baseUrl: '{base_url}' }});
const model = await client.embedding.load('{model_name}', {{ verbose: false }});
const contextLength = await model.getContextLength();
await model.unload(); // Unload immediately to respect JIT auto-evict settings
console.log(JSON.stringify({{ contextLength, identifier: '{model_name}' }}));
}} catch (error) {{
console.error(JSON.stringify({{ error: error.message }}));
process.exit(1);
}}
}})();
"""
try:
# Set NODE_PATH to include global modules for @lmstudio/sdk resolution
env = os.environ.copy()
# Try to get npm global root (works with nvm, brew node, etc.)
try:
npm_root = subprocess.run(
["npm", "root", "-g"],
capture_output=True,
text=True,
timeout=5,
)
if npm_root.returncode == 0:
global_modules = npm_root.stdout.strip()
# Append to existing NODE_PATH if present
existing_node_path = env.get("NODE_PATH", "")
env["NODE_PATH"] = (
f"{global_modules}:{existing_node_path}"
if existing_node_path
else global_modules
)
except Exception:
# If npm not available, continue with existing NODE_PATH
pass
result = subprocess.run(
["node", "-e", js_code],
capture_output=True,
text=True,
timeout=10,
env=env,
)
if result.returncode != 0:
logger.debug(f"LM Studio SDK error: {result.stderr}")
return None
data = json.loads(result.stdout)
context_length = data.get("contextLength")
if context_length and context_length > 0:
logger.info(f"LM Studio SDK detected {model_name} context length: {context_length}")
return context_length
except FileNotFoundError:
logger.debug("Node.js not found - install Node.js for LM Studio SDK features")
except subprocess.TimeoutExpired:
logger.debug("LM Studio SDK query timeout")
except json.JSONDecodeError:
logger.debug("LM Studio SDK returned invalid JSON")
except Exception as e:
logger.debug(f"LM Studio SDK query failed: {e}")
return None
# Global model cache to avoid repeated loading
_model_cache: dict[str, Any] = {}
@@ -232,6 +352,7 @@ def compute_embeddings(
model_name,
base_url=provider_options.get("base_url"),
api_key=provider_options.get("api_key"),
provider_options=provider_options,
)
elif mode == "mlx":
return compute_embeddings_mlx(texts, model_name)
@@ -241,6 +362,7 @@ def compute_embeddings(
model_name,
is_build=is_build,
host=provider_options.get("host"),
provider_options=provider_options,
)
elif mode == "gemini":
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
@@ -579,6 +701,7 @@ def compute_embeddings_openai(
model_name: str,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Compute embeddings using OpenAI API"""
@@ -597,26 +720,40 @@ def compute_embeddings_openai(
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
)
resolved_base_url = resolve_openai_base_url(base_url)
resolved_api_key = resolve_openai_api_key(api_key)
# Extract base_url and api_key from provider_options if not provided directly
provider_options = provider_options or {}
effective_base_url = base_url or provider_options.get("base_url")
effective_api_key = api_key or provider_options.get("api_key")
resolved_base_url = resolve_openai_base_url(effective_base_url)
resolved_api_key = resolve_openai_api_key(effective_api_key)
if not resolved_api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
# Cache OpenAI client
cache_key = f"openai_client::{resolved_base_url}"
if cache_key in _model_cache:
client = _model_cache[cache_key]
else:
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
_model_cache[cache_key] = client
logger.info("OpenAI client cached")
# Create OpenAI client
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
logger.info(
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
)
print(f"len of texts: {len(texts)}")
# Apply prompt template if provided
# Priority: build_prompt_template (new format) > prompt_template (old format)
prompt_template = provider_options.get("build_prompt_template") or provider_options.get(
"prompt_template"
)
if prompt_template:
logger.warning(f"Applying prompt template: '{prompt_template}'")
texts = [f"{prompt_template}{text}" for text in texts]
# Query token limit and apply truncation
token_limit = get_model_token_limit(model_name, base_url=effective_base_url)
logger.info(f"Using token limit: {token_limit} for model '{model_name}'")
texts = truncate_to_token_limit(texts, token_limit)
# OpenAI has limits on batch size and input length
max_batch_size = 800 # Conservative batch size because the token limit is 300K
all_embeddings = []
@@ -647,7 +784,15 @@ def compute_embeddings_openai(
try:
response = client.embeddings.create(model=model_name, input=batch_texts)
batch_embeddings = [embedding.embedding for embedding in response.data]
all_embeddings.extend(batch_embeddings)
# Verify we got the expected number of embeddings
if len(batch_embeddings) != len(batch_texts):
logger.warning(
f"Expected {len(batch_texts)} embeddings but got {len(batch_embeddings)}"
)
# Only take the number of embeddings that match the batch size
all_embeddings.extend(batch_embeddings[: len(batch_texts)])
except Exception as e:
logger.error(f"Batch {i} failed: {e}")
raise
@@ -737,6 +882,7 @@ def compute_embeddings_ollama(
model_name: str,
is_build: bool = False,
host: Optional[str] = None,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray:
"""
Compute embeddings using Ollama API with true batch processing.
@@ -749,6 +895,7 @@ def compute_embeddings_ollama(
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
is_build: Whether this is a build operation (shows progress bar)
host: Ollama host URL (defaults to environment or http://localhost:11434)
provider_options: Optional provider-specific options (e.g., prompt_template)
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
@@ -885,6 +1032,17 @@ def compute_embeddings_ollama(
logger.info(f"Using batch size: {batch_size} for true batch processing")
# Apply prompt template if provided
provider_options = provider_options or {}
# Priority: build_prompt_template (new format) > prompt_template (old format)
prompt_template = provider_options.get("build_prompt_template") or provider_options.get(
"prompt_template"
)
if prompt_template:
logger.warning(f"Applying prompt template: '{prompt_template}'")
texts = [f"{prompt_template}{text}" for text in texts]
# Get model token limit and apply truncation before batching
token_limit = get_model_token_limit(model_name, base_url=resolved_host)
logger.info(f"Model '{model_name}' token limit: {token_limit}")

View File

@@ -77,6 +77,7 @@ class LeannBackendSearcherInterface(ABC):
query: str,
use_server_if_available: bool = True,
zmq_port: Optional[int] = None,
query_template: Optional[str] = None,
) -> np.ndarray:
"""Compute embedding for a query string
@@ -84,6 +85,7 @@ class LeannBackendSearcherInterface(ABC):
query: The query string to embed
zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first
query_template: Optional prompt template to prepend to query
Returns:
Query embedding as numpy array with shape (1, D)

View File

@@ -33,6 +33,8 @@ def autodiscover_backends():
discovered_backends = []
for dist in importlib.metadata.distributions():
dist_name = dist.metadata["name"]
if dist_name is None:
continue
if dist_name.startswith("leann-backend-"):
backend_module_name = dist_name.replace("-", "_")
discovered_backends.append(backend_module_name)

View File

@@ -71,6 +71,15 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
or "mips"
)
# Filter out ALL prompt templates from provider_options during search
# Templates are applied in compute_query_embedding (line 109-110) BEFORE server call
# The server should never apply templates during search to avoid double-templating
search_provider_options = {
k: v
for k, v in self.embedding_options.items()
if k not in ("build_prompt_template", "query_prompt_template", "prompt_template")
}
server_started, actual_port = self.embedding_server_manager.start_server(
port=port,
model_name=self.embedding_model,
@@ -78,7 +87,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
passages_file=passages_source_file,
distance_metric=distance_metric,
enable_warmup=kwargs.get("enable_warmup", False),
provider_options=self.embedding_options,
provider_options=search_provider_options,
)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
@@ -90,6 +99,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
query: str,
use_server_if_available: bool = True,
zmq_port: int = 5557,
query_template: Optional[str] = None,
) -> np.ndarray:
"""
Compute embedding for a query string.
@@ -98,10 +108,16 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
query: The query string to embed
zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first
query_template: Optional prompt template to prepend to query
Returns:
Query embedding as numpy array
"""
# Apply query template BEFORE any computation path
# This ensures template is applied consistently for both server and fallback paths
if query_template:
query = f"{query_template}{query}"
# Try to use embedding server if available and requested
if use_server_if_available:
try:

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann"
version = "0.3.4"
version = "0.3.5"
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
readme = "README.md"
requires-python = ">=3.9"

View File

@@ -69,7 +69,8 @@ diskann = [
# Add a new optional dependency group for document processing
documents = [
"beautifulsoup4>=4.13.0", # For HTML parsing
"python-docx>=0.8.11", # For Word documents
"python-docx>=0.8.11", # For Word documents (creating/editing)
"docx2txt>=0.9", # For Word documents (text extraction)
"openpyxl>=3.1.0", # For Excel files
"pandas>=2.2.0", # For data processing
]
@@ -164,6 +165,7 @@ python_functions = ["test_*"]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"openai: marks tests that require OpenAI API key",
"integration: marks tests that require live services (Ollama, LM Studio, etc.)",
]
timeout = 300 # Reduced from 600s (10min) to 300s (5min) for CI safety
addopts = [

View File

@@ -36,6 +36,14 @@ Tests DiskANN graph partitioning functionality:
- Includes performance comparison between DiskANN (with partition) and HNSW
- **Note**: These tests are skipped in CI due to hardware requirements and computation time
### `test_prompt_template_e2e.py`
Integration tests for prompt template feature with live embedding services:
- Tests prompt template prepending with EmbeddingGemma (OpenAI-compatible API via LM Studio)
- Tests hybrid token limit discovery (Ollama dynamic detection, registry fallback, default)
- Tests LM Studio SDK bridge for automatic context length detection (requires Node.js + @lmstudio/sdk)
- **Note**: These tests require live services (LM Studio, Ollama) and are marked with `@pytest.mark.integration`
- **Important**: Prompt templates are ONLY for EmbeddingGemma and similar task-specific models, NOT regular embedding models
## Running Tests
### Install test dependencies:
@@ -66,6 +74,12 @@ pytest tests/ -m "not openai"
# Skip slow tests
pytest tests/ -m "not slow"
# Skip integration tests (that require live services)
pytest tests/ -m "not integration"
# Run only integration tests (requires LM Studio or Ollama running)
pytest tests/test_prompt_template_e2e.py -v -s
# Run DiskANN partition tests (requires local machine, not CI)
pytest tests/test_diskann_partition.py
```
@@ -101,6 +115,20 @@ The `pytest.ini` file configures:
- Custom markers for slow and OpenAI tests
- Verbose output with short tracebacks
### Integration Test Prerequisites
Integration tests (`test_prompt_template_e2e.py`) require live services:
**Required:**
- LM Studio running at `http://localhost:1234` with EmbeddingGemma model loaded
**Optional:**
- Ollama running at `http://localhost:11434` for token limit detection tests
- Node.js + @lmstudio/sdk installed (`npm install -g @lmstudio/sdk`) for SDK bridge tests
Tests gracefully skip if services are unavailable.
### Known Issues
- OpenAI tests are automatically skipped if no API key is provided
- Integration tests require live embedding services and may fail due to proxy settings (run `unset ALL_PROXY all_proxy` if needed)

View File

@@ -0,0 +1,533 @@
"""
Tests for CLI argument integration of --embedding-prompt-template.
These tests verify that:
1. The --embedding-prompt-template flag is properly registered on build and search commands
2. The template value flows from CLI args to embedding_options dict
3. The template is passed through to compute_embeddings() function
4. Default behavior (no flag) is handled correctly
"""
from unittest.mock import Mock, patch
from leann.cli import LeannCLI
class TestCLIPromptTemplateArgument:
"""Tests for --embedding-prompt-template on build and search commands."""
def test_commands_accept_prompt_template_argument(self):
"""Verify that build and search parsers accept --embedding-prompt-template flag."""
cli = LeannCLI()
parser = cli.create_parser()
template_value = "search_query: "
# Test build command
build_args = parser.parse_args(
[
"build",
"test-index",
"--docs",
"/tmp/test-docs",
"--embedding-prompt-template",
template_value,
]
)
assert build_args.command == "build"
assert hasattr(build_args, "embedding_prompt_template"), (
"build command should have embedding_prompt_template attribute"
)
assert build_args.embedding_prompt_template == template_value
# Test search command
search_args = parser.parse_args(
["search", "test-index", "my query", "--embedding-prompt-template", template_value]
)
assert search_args.command == "search"
assert hasattr(search_args, "embedding_prompt_template"), (
"search command should have embedding_prompt_template attribute"
)
assert search_args.embedding_prompt_template == template_value
def test_commands_default_to_none(self):
"""Verify default value is None when flag not provided (backward compatibility)."""
cli = LeannCLI()
parser = cli.create_parser()
# Test build command default
build_args = parser.parse_args(["build", "test-index", "--docs", "/tmp/test-docs"])
assert hasattr(build_args, "embedding_prompt_template"), (
"build command should have embedding_prompt_template attribute"
)
assert build_args.embedding_prompt_template is None, (
"Build default value should be None when flag not provided"
)
# Test search command default
search_args = parser.parse_args(["search", "test-index", "my query"])
assert hasattr(search_args, "embedding_prompt_template"), (
"search command should have embedding_prompt_template attribute"
)
assert search_args.embedding_prompt_template is None, (
"Search default value should be None when flag not provided"
)
class TestBuildCommandPromptTemplateArgumentExtras:
"""Additional build-specific tests for prompt template argument."""
def test_build_command_prompt_template_with_multiword_value(self):
"""
Verify that template values with spaces are handled correctly.
Templates like "search_document: " or "Represent this sentence for searching: "
should be accepted as a single string argument.
"""
cli = LeannCLI()
parser = cli.create_parser()
template = "Represent this sentence for searching: "
args = parser.parse_args(
[
"build",
"test-index",
"--docs",
"/tmp/test-docs",
"--embedding-prompt-template",
template,
]
)
assert args.embedding_prompt_template == template
class TestPromptTemplateStoredInEmbeddingOptions:
"""Tests for template storage in embedding_options dict."""
@patch("leann.cli.LeannBuilder")
def test_prompt_template_stored_in_embedding_options_on_build(
self, mock_builder_class, tmp_path
):
"""
Verify that when --embedding-prompt-template is provided to build command,
the value is stored in embedding_options dict passed to LeannBuilder.
This test will fail because the CLI doesn't currently process this argument
and add it to embedding_options.
"""
# Setup mocks
mock_builder = Mock()
mock_builder_class.return_value = mock_builder
# Create CLI and run build command
cli = LeannCLI()
# Mock load_documents to return a document so builder is created
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
parser = cli.create_parser()
template = "search_query: "
args = parser.parse_args(
[
"build",
"test-index",
"--docs",
str(tmp_path),
"--embedding-prompt-template",
template,
"--force", # Force rebuild to ensure LeannBuilder is called
]
)
# Run the build command
import asyncio
asyncio.run(cli.build_index(args))
# Check that LeannBuilder was called with embedding_options containing prompt_template
call_kwargs = mock_builder_class.call_args.kwargs
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
embedding_options = call_kwargs["embedding_options"]
assert embedding_options is not None, (
"embedding_options should not be None when template provided"
)
assert "prompt_template" in embedding_options, (
"embedding_options should contain 'prompt_template' key"
)
assert embedding_options["prompt_template"] == template, (
f"Template should be '{template}', got {embedding_options.get('prompt_template')}"
)
@patch("leann.cli.LeannBuilder")
def test_prompt_template_not_in_options_when_not_provided(self, mock_builder_class, tmp_path):
"""
Verify that when --embedding-prompt-template is NOT provided,
embedding_options either doesn't have the key or it's None.
This ensures we don't pass empty/None values unnecessarily.
"""
# Setup mocks
mock_builder = Mock()
mock_builder_class.return_value = mock_builder
cli = LeannCLI()
# Mock load_documents to return a document so builder is created
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
parser = cli.create_parser()
args = parser.parse_args(
[
"build",
"test-index",
"--docs",
str(tmp_path),
"--force", # Force rebuild to ensure LeannBuilder is called
]
)
import asyncio
asyncio.run(cli.build_index(args))
# Check that if embedding_options is passed, it doesn't have prompt_template
call_kwargs = mock_builder_class.call_args.kwargs
if call_kwargs.get("embedding_options"):
embedding_options = call_kwargs["embedding_options"]
# Either the key shouldn't exist, or it should be None
assert (
"prompt_template" not in embedding_options
or embedding_options["prompt_template"] is None
), "prompt_template should not be set when flag not provided"
# R1 Tests: Build-time separate template storage
@patch("leann.cli.LeannBuilder")
def test_build_stores_separate_templates(self, mock_builder_class, tmp_path):
"""
R1 Test 1: Verify that when both --embedding-prompt-template and
--query-prompt-template are provided to build command, both values
are stored separately in embedding_options dict as build_prompt_template
and query_prompt_template.
This test will fail because:
1. CLI doesn't accept --query-prompt-template flag yet
2. CLI doesn't store templates as separate build_prompt_template and
query_prompt_template keys
Expected behavior after implementation:
- .meta.json contains: {"embedding_options": {
"build_prompt_template": "doc: ",
"query_prompt_template": "query: "
}}
"""
# Setup mocks
mock_builder = Mock()
mock_builder_class.return_value = mock_builder
cli = LeannCLI()
# Mock load_documents to return a document so builder is created
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
parser = cli.create_parser()
build_template = "doc: "
query_template = "query: "
args = parser.parse_args(
[
"build",
"test-index",
"--docs",
str(tmp_path),
"--embedding-prompt-template",
build_template,
"--query-prompt-template",
query_template,
"--force",
]
)
# Run the build command
import asyncio
asyncio.run(cli.build_index(args))
# Check that LeannBuilder was called with separate template keys
call_kwargs = mock_builder_class.call_args.kwargs
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
embedding_options = call_kwargs["embedding_options"]
assert embedding_options is not None, (
"embedding_options should not be None when templates provided"
)
assert "build_prompt_template" in embedding_options, (
"embedding_options should contain 'build_prompt_template' key"
)
assert embedding_options["build_prompt_template"] == build_template, (
f"build_prompt_template should be '{build_template}'"
)
assert "query_prompt_template" in embedding_options, (
"embedding_options should contain 'query_prompt_template' key"
)
assert embedding_options["query_prompt_template"] == query_template, (
f"query_prompt_template should be '{query_template}'"
)
# Old key should NOT be present when using new separate template format
assert "prompt_template" not in embedding_options, (
"Old 'prompt_template' key should not be present with separate templates"
)
@patch("leann.cli.LeannBuilder")
def test_build_backward_compat_single_template(self, mock_builder_class, tmp_path):
"""
R1 Test 2: Verify backward compatibility - when only
--embedding-prompt-template is provided (old behavior), it should
still be stored as 'prompt_template' in embedding_options.
This ensures existing workflows continue to work unchanged.
This test currently passes because it matches existing behavior, but it
documents the requirement that this behavior must be preserved after
implementing the separate template feature.
Expected behavior:
- .meta.json contains: {"embedding_options": {"prompt_template": "prompt: "}}
- No build_prompt_template or query_prompt_template keys
"""
# Setup mocks
mock_builder = Mock()
mock_builder_class.return_value = mock_builder
cli = LeannCLI()
# Mock load_documents to return a document so builder is created
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
parser = cli.create_parser()
template = "prompt: "
args = parser.parse_args(
[
"build",
"test-index",
"--docs",
str(tmp_path),
"--embedding-prompt-template",
template,
"--force",
]
)
# Run the build command
import asyncio
asyncio.run(cli.build_index(args))
# Check that LeannBuilder was called with old format
call_kwargs = mock_builder_class.call_args.kwargs
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
embedding_options = call_kwargs["embedding_options"]
assert embedding_options is not None, (
"embedding_options should not be None when template provided"
)
assert "prompt_template" in embedding_options, (
"embedding_options should contain old 'prompt_template' key for backward compat"
)
assert embedding_options["prompt_template"] == template, (
f"prompt_template should be '{template}'"
)
# New keys should NOT be present in backward compat mode
assert "build_prompt_template" not in embedding_options, (
"build_prompt_template should not be present with single template flag"
)
assert "query_prompt_template" not in embedding_options, (
"query_prompt_template should not be present with single template flag"
)
@patch("leann.cli.LeannBuilder")
def test_build_no_templates(self, mock_builder_class, tmp_path):
"""
R1 Test 3: Verify that when no template flags are provided,
embedding_options has no prompt template keys.
This ensures clean defaults and no unnecessary keys in .meta.json.
This test currently passes because it matches existing behavior, but it
documents the requirement that this behavior must be preserved after
implementing the separate template feature.
Expected behavior:
- .meta.json has no prompt_template, build_prompt_template, or
query_prompt_template keys (or embedding_options is empty/None)
"""
# Setup mocks
mock_builder = Mock()
mock_builder_class.return_value = mock_builder
cli = LeannCLI()
# Mock load_documents to return a document so builder is created
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
parser = cli.create_parser()
args = parser.parse_args(["build", "test-index", "--docs", str(tmp_path), "--force"])
# Run the build command
import asyncio
asyncio.run(cli.build_index(args))
# Check that no template keys are present
call_kwargs = mock_builder_class.call_args.kwargs
if call_kwargs.get("embedding_options"):
embedding_options = call_kwargs["embedding_options"]
# None of the template keys should be present
assert "prompt_template" not in embedding_options, (
"prompt_template should not be present when no flags provided"
)
assert "build_prompt_template" not in embedding_options, (
"build_prompt_template should not be present when no flags provided"
)
assert "query_prompt_template" not in embedding_options, (
"query_prompt_template should not be present when no flags provided"
)
class TestPromptTemplateFlowsToComputeEmbeddings:
"""Tests for template flowing through to compute_embeddings function."""
@patch("leann.api.compute_embeddings")
def test_prompt_template_flows_to_compute_embeddings_via_provider_options(
self, mock_compute_embeddings, tmp_path
):
"""
Verify that the prompt template flows from CLI args through LeannBuilder
to compute_embeddings() function via provider_options parameter.
This is an integration test that verifies the complete flow:
CLI → embedding_options → LeannBuilder → compute_embeddings(provider_options)
This test will fail because:
1. CLI doesn't capture the argument yet
2. embedding_options doesn't include prompt_template
3. LeannBuilder doesn't pass it through to compute_embeddings
"""
# Mock compute_embeddings to return dummy embeddings as numpy array
import numpy as np
mock_compute_embeddings.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
# Use real LeannBuilder (not mocked) to test the actual flow
cli = LeannCLI()
# Mock load_documents to return a simple document
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
parser = cli.create_parser()
template = "search_document: "
args = parser.parse_args(
[
"build",
"test-index",
"--docs",
str(tmp_path),
"--embedding-prompt-template",
template,
"--backend-name",
"hnsw", # Use hnsw backend
"--force", # Force rebuild to ensure index is created
]
)
# This should fail because the flow isn't implemented yet
import asyncio
asyncio.run(cli.build_index(args))
# Verify compute_embeddings was called with provider_options containing prompt_template
assert mock_compute_embeddings.called, "compute_embeddings should have been called"
# Check the call arguments
call_kwargs = mock_compute_embeddings.call_args.kwargs
assert "provider_options" in call_kwargs, (
"compute_embeddings should receive provider_options parameter"
)
provider_options = call_kwargs["provider_options"]
assert provider_options is not None, "provider_options should not be None"
assert "prompt_template" in provider_options, (
"provider_options should contain prompt_template key"
)
assert provider_options["prompt_template"] == template, (
f"Template should be '{template}', got {provider_options.get('prompt_template')}"
)
class TestPromptTemplateArgumentHelp:
"""Tests for argument help text and documentation."""
def test_build_command_prompt_template_has_help_text(self):
"""
Verify that --embedding-prompt-template has descriptive help text.
Good help text is crucial for CLI usability.
"""
cli = LeannCLI()
parser = cli.create_parser()
# Get the build subparser
# This is a bit tricky - we need to parse to get the help
# We'll check that the help includes relevant keywords
import io
from contextlib import redirect_stdout
f = io.StringIO()
try:
with redirect_stdout(f):
parser.parse_args(["build", "--help"])
except SystemExit:
pass # --help causes sys.exit(0)
help_text = f.getvalue()
assert "--embedding-prompt-template" in help_text, (
"Help text should mention --embedding-prompt-template"
)
# Check for keywords that should be in the help
help_lower = help_text.lower()
assert any(keyword in help_lower for keyword in ["template", "prompt", "prepend"]), (
"Help text should explain what the prompt template does"
)
def test_search_command_prompt_template_has_help_text(self):
"""
Verify that search command also has help text for --embedding-prompt-template.
"""
cli = LeannCLI()
parser = cli.create_parser()
import io
from contextlib import redirect_stdout
f = io.StringIO()
try:
with redirect_stdout(f):
parser.parse_args(["search", "--help"])
except SystemExit:
pass # --help causes sys.exit(0)
help_text = f.getvalue()
assert "--embedding-prompt-template" in help_text, (
"Search help text should mention --embedding-prompt-template"
)

View File

@@ -0,0 +1,281 @@
"""Unit tests for prompt template prepending in OpenAI embeddings.
This test suite defines the contract for prompt template functionality that allows
users to prepend a consistent prompt to all embedding inputs. These tests verify:
1. Template prepending to all input texts before embedding computation
2. Graceful handling of None/missing provider_options
3. Empty string template behavior (no-op)
4. Logging of template application for observability
5. Template application before token truncation
All tests are written in Red Phase - they should FAIL initially because the
implementation does not exist yet.
"""
from unittest.mock import MagicMock, Mock, patch
import numpy as np
import pytest
from leann.embedding_compute import compute_embeddings_openai
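# Reference sketch (assumption only, not the shipped compute_embeddings_openai code):
# the prepending behaviour the tests below specify. A missing provider_options dict,
# a missing key, or an empty template is a no-op; otherwise the prefix is applied to
# every input text.
def _sketch_apply_prompt_template(texts, provider_options):
    template = (provider_options or {}).get("prompt_template") or ""
    if not template:
        return list(texts)
    return [f"{template}{text}" for text in texts]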
class TestPromptTemplatePrepending:
"""Tests for prompt template prepending in compute_embeddings_openai."""
@pytest.fixture
def mock_openai_client(self):
"""Create mock OpenAI client that captures input texts."""
mock_client = MagicMock()
# Mock the embeddings.create response
mock_response = Mock()
mock_response.data = [
Mock(embedding=[0.1, 0.2, 0.3]),
Mock(embedding=[0.4, 0.5, 0.6]),
]
mock_client.embeddings.create.return_value = mock_response
return mock_client
@pytest.fixture
def mock_openai_module(self, mock_openai_client, monkeypatch):
"""Mock the openai module to return our mock client."""
# Mock the API key environment variable
monkeypatch.setenv("OPENAI_API_KEY", "fake-test-key-for-mocking")
# openai is imported inside the function, so we need to patch it there
with patch("openai.OpenAI", return_value=mock_openai_client) as mock_openai:
yield mock_openai
def test_prompt_template_prepended_to_all_texts(self, mock_openai_module, mock_openai_client):
"""Verify template is prepended to all input texts.
When provider_options contains "prompt_template", that template should
be prepended to every text in the input list before sending to OpenAI API.
This is the core functionality: the template acts as a consistent prefix
that provides context or instruction for the embedding model.
"""
texts = ["First document", "Second document"]
template = "search_document: "
provider_options = {"prompt_template": template}
# Call compute_embeddings_openai with provider_options
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
# Verify embeddings.create was called with templated texts
mock_openai_client.embeddings.create.assert_called_once()
call_args = mock_openai_client.embeddings.create.call_args
# Extract the input texts sent to API
sent_texts = call_args.kwargs["input"]
# Verify template was prepended to all texts
assert len(sent_texts) == 2, "Should send same number of texts"
assert sent_texts[0] == "search_document: First document", (
"Template should be prepended to first text"
)
assert sent_texts[1] == "search_document: Second document", (
"Template should be prepended to second text"
)
# Verify result is valid embeddings array
assert isinstance(result, np.ndarray)
assert result.shape == (2, 3), "Should return correct shape"
def test_template_not_applied_when_missing_or_empty(
self, mock_openai_module, mock_openai_client
):
"""Verify template not applied when provider_options is None, missing key, or empty string.
This consolidated test covers three scenarios where templates should NOT be applied:
1. provider_options is None (default behavior)
2. provider_options exists but missing 'prompt_template' key
3. prompt_template is explicitly set to empty string ""
In all cases, texts should be sent to the API unchanged.
"""
# Scenario 1: None provider_options
texts = ["Original text one", "Original text two"]
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=None,
)
call_args = mock_openai_client.embeddings.create.call_args
sent_texts = call_args.kwargs["input"]
assert sent_texts[0] == "Original text one", (
"Text should be unchanged with None provider_options"
)
assert sent_texts[1] == "Original text two"
assert isinstance(result, np.ndarray)
assert result.shape == (2, 3)
# Reset mock for next scenario
mock_openai_client.reset_mock()
mock_response = Mock()
mock_response.data = [
Mock(embedding=[0.1, 0.2, 0.3]),
Mock(embedding=[0.4, 0.5, 0.6]),
]
mock_openai_client.embeddings.create.return_value = mock_response
# Scenario 2: Missing 'prompt_template' key
texts = ["Text without template", "Another text"]
provider_options = {"base_url": "https://api.openai.com/v1"}
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
call_args = mock_openai_client.embeddings.create.call_args
sent_texts = call_args.kwargs["input"]
assert sent_texts[0] == "Text without template", "Text should be unchanged with missing key"
assert sent_texts[1] == "Another text"
assert isinstance(result, np.ndarray)
# Reset mock for next scenario
mock_openai_client.reset_mock()
mock_openai_client.embeddings.create.return_value = mock_response
# Scenario 3: Empty string template
texts = ["Text one", "Text two"]
provider_options = {"prompt_template": ""}
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
call_args = mock_openai_client.embeddings.create.call_args
sent_texts = call_args.kwargs["input"]
assert sent_texts[0] == "Text one", "Empty template should not modify text"
assert sent_texts[1] == "Text two"
assert isinstance(result, np.ndarray)
def test_prompt_template_with_multiple_batches(self, mock_openai_module, mock_openai_client):
"""Verify template is prepended in all batches when texts exceed batch size.
OpenAI API has batch size limits. When input texts are split into
multiple batches, the template should be prepended to texts in every batch.
This ensures consistency across all API calls.
"""
# Create many texts that will be split into multiple batches
texts = [f"Document {i}" for i in range(1000)]
template = "passage: "
provider_options = {"prompt_template": template}
# Mock multiple batch responses
mock_response = Mock()
mock_response.data = [Mock(embedding=[0.1, 0.2, 0.3]) for _ in range(1000)]
mock_openai_client.embeddings.create.return_value = mock_response
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
# Verify embeddings.create was called multiple times (batching)
assert mock_openai_client.embeddings.create.call_count >= 2, (
"Should make multiple API calls for large text list"
)
# Verify template was prepended in ALL batches
for call in mock_openai_client.embeddings.create.call_args_list:
sent_texts = call.kwargs["input"]
for text in sent_texts:
assert text.startswith(template), (
f"All texts in all batches should start with template. Got: {text}"
)
# Verify result shape
assert result.shape[0] == 1000, "Should return embeddings for all texts"
def test_prompt_template_with_special_characters(self, mock_openai_module, mock_openai_client):
"""Verify template with special characters is handled correctly.
Templates may contain special characters, Unicode, newlines, etc.
These should all be prepended correctly without encoding issues.
"""
texts = ["Document content"]
# Template with various special characters
template = "🔍 Search query [EN]: "
provider_options = {"prompt_template": template}
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
# Verify special characters in template were preserved
call_args = mock_openai_client.embeddings.create.call_args
sent_texts = call_args.kwargs["input"]
assert sent_texts[0] == "🔍 Search query [EN]: Document content", (
"Special characters in template should be preserved"
)
assert isinstance(result, np.ndarray)
def test_prompt_template_integration_with_existing_validation(
self, mock_openai_module, mock_openai_client
):
"""Verify template works with existing input validation.
compute_embeddings_openai has validation for empty texts and whitespace.
Template prepending should happen AFTER validation, so validation errors
are thrown based on original texts, not templated texts.
This ensures users get clear error messages about their input.
"""
# Empty text should still raise ValueError even with template
texts = [""]
provider_options = {"prompt_template": "prefix: "}
with pytest.raises(ValueError, match="empty/invalid"):
compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
def test_prompt_template_with_api_key_and_base_url(
self, mock_openai_module, mock_openai_client
):
"""Verify template works alongside other provider_options.
provider_options may contain multiple settings: prompt_template,
base_url, api_key. All should work together correctly.
"""
texts = ["Test document"]
provider_options = {
"prompt_template": "embed: ",
"base_url": "https://custom.api.com/v1",
"api_key": "test-key-123",
}
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
# Verify template was applied
call_args = mock_openai_client.embeddings.create.call_args
sent_texts = call_args.kwargs["input"]
assert sent_texts[0] == "embed: Test document"
# Verify OpenAI client was created with correct base_url
mock_openai_module.assert_called()
client_init_kwargs = mock_openai_module.call_args.kwargs
assert client_init_kwargs["base_url"] == "https://custom.api.com/v1"
assert client_init_kwargs["api_key"] == "test-key-123"
assert isinstance(result, np.ndarray)

View File

@@ -0,0 +1,315 @@
"""Unit tests for LM Studio TypeScript SDK bridge functionality.
This test suite defines the contract for the LM Studio SDK bridge that queries
model context length via Node.js subprocess. These tests verify:
1. Successful SDK query returns context length
2. Graceful fallback when Node.js not installed (FileNotFoundError)
3. Graceful fallback when SDK not installed (npm error)
4. Timeout handling (subprocess.TimeoutExpired)
5. Invalid JSON response handling
All tests are written in Red Phase - they should FAIL initially because the
`_query_lmstudio_context_limit` function does not exist yet.
The function contract:
- Inputs: model_name (str), base_url (str, WebSocket format "ws://localhost:1234")
- Outputs: context_length (int) or None on error
- Requirements:
1. Call Node.js with inline JavaScript using @lmstudio/sdk
2. 10-second timeout (accounts for Node.js startup)
3. Graceful fallback on any error (returns None, doesn't raise)
4. Parse JSON response with contextLength field
5. Log errors at debug level (not warning/error)
"""
import subprocess
from unittest.mock import Mock
import pytest
# Try to import the function - if it doesn't exist, tests will fail as expected
try:
from leann.embedding_compute import _query_lmstudio_context_limit
except ImportError:
# Function doesn't exist yet (Red Phase) - create a placeholder that will fail
def _query_lmstudio_context_limit(*args, **kwargs):
raise NotImplementedError(
"_query_lmstudio_context_limit not implemented yet - this is the Red Phase"
)
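# Reference sketch (assumption, not the shipped implementation): one way to satisfy the
# contract in the module docstring. The inline @lmstudio/sdk JavaScript is elided; only
# the subprocess call, JSON parsing, and graceful-fallback handling are illustrated.
import json
import logging
_sketch_logger = logging.getLogger("leann.embedding_compute")
def _sketch_query_lmstudio_context_limit(model_name: str, base_url: str) -> int | None:
    script = "/* inline @lmstudio/sdk query, elided in this sketch */"
    try:
        result = subprocess.run(
            ["node", "-e", script, model_name, base_url],
            capture_output=True,
            text=True,
            timeout=10,  # generous enough to cover Node.js startup
        )
        if result.returncode != 0:
            _sketch_logger.debug("LM Studio SDK bridge failed: %s", result.stderr)
            return None
        context_length = json.loads(result.stdout).get("contextLength")
        return context_length or None  # treat null/0 as "unknown" so callers fall back
    except (FileNotFoundError, subprocess.TimeoutExpired, ValueError) as exc:
        _sketch_logger.debug("LM Studio SDK bridge unavailable: %s", exc)
        return None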
class TestLMStudioBridge:
"""Tests for LM Studio TypeScript SDK bridge integration."""
def test_query_lmstudio_success(self, monkeypatch):
"""Verify successful SDK query returns context length.
When the Node.js subprocess successfully queries the LM Studio SDK,
it should return a JSON response with contextLength field. The function
should parse this and return the integer context length.
"""
def mock_run(*args, **kwargs):
# Verify timeout is set to 10 seconds
assert kwargs.get("timeout") == 10, "Should use 10-second timeout for Node.js startup"
# Verify capture_output and text=True are set
assert kwargs.get("capture_output") is True, "Should capture stdout/stderr"
assert kwargs.get("text") is True, "Should decode output as text"
# Return successful JSON response
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = '{"contextLength": 8192, "identifier": "custom-model"}'
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
# Test with typical LM Studio model
limit = _query_lmstudio_context_limit(
model_name="custom-model", base_url="ws://localhost:1234"
)
assert limit == 8192, "Should return context length from SDK response"
def test_query_lmstudio_nodejs_not_found(self, monkeypatch):
"""Verify graceful fallback when Node.js not installed.
When Node.js is not installed, subprocess.run will raise FileNotFoundError.
The function should catch this and return None (graceful fallback to registry).
"""
def mock_run(*args, **kwargs):
raise FileNotFoundError("node: command not found")
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="custom-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None when Node.js not installed"
def test_query_lmstudio_sdk_not_installed(self, monkeypatch):
"""Verify graceful fallback when @lmstudio/sdk not installed.
When the SDK npm package is not installed, Node.js will return non-zero
exit code with error message in stderr. The function should detect this
and return None.
"""
def mock_run(*args, **kwargs):
mock_result = Mock()
mock_result.returncode = 1
mock_result.stdout = ""
mock_result.stderr = (
"Error: Cannot find module '@lmstudio/sdk'\nRequire stack:\n- /path/to/script.js"
)
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="custom-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None when SDK not installed"
def test_query_lmstudio_timeout(self, monkeypatch):
"""Verify graceful fallback when subprocess times out.
When the Node.js process takes longer than 10 seconds (e.g., LM Studio
not responding), subprocess.TimeoutExpired should be raised. The function
should catch this and return None.
"""
def mock_run(*args, **kwargs):
raise subprocess.TimeoutExpired(cmd=["node", "lmstudio_bridge.js"], timeout=10)
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="custom-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None on timeout"
def test_query_lmstudio_invalid_json(self, monkeypatch):
"""Verify graceful fallback when response is invalid JSON.
When the subprocess returns malformed JSON (e.g., due to SDK error),
json.loads will raise ValueError/JSONDecodeError. The function should
catch this and return None.
"""
def mock_run(*args, **kwargs):
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = "This is not valid JSON"
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="custom-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None when JSON parsing fails"
def test_query_lmstudio_missing_context_length_field(self, monkeypatch):
"""Verify graceful fallback when JSON lacks contextLength field.
When the SDK returns valid JSON but without the expected contextLength
field (e.g., error response), the function should return None.
"""
def mock_run(*args, **kwargs):
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = '{"identifier": "test-model", "error": "Model not found"}'
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="nonexistent-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None when contextLength field missing"
def test_query_lmstudio_null_context_length(self, monkeypatch):
"""Verify graceful fallback when contextLength is null.
When the SDK returns contextLength: null (model couldn't be loaded),
the function should return None for registry fallback.
"""
def mock_run(*args, **kwargs):
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = '{"contextLength": null, "identifier": "test-model"}'
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="test-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None when contextLength is null"
def test_query_lmstudio_zero_context_length(self, monkeypatch):
"""Verify graceful fallback when contextLength is zero.
When the SDK returns contextLength: 0 (invalid value), the function
should return None to trigger registry fallback.
"""
def mock_run(*args, **kwargs):
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = '{"contextLength": 0, "identifier": "test-model"}'
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="test-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None when contextLength is zero"
def test_query_lmstudio_with_custom_port(self, monkeypatch):
"""Verify SDK query works with non-default WebSocket port.
LM Studio can run on custom ports. The function should pass the
provided base_url to the Node.js subprocess.
"""
def mock_run(*args, **kwargs):
# Verify the base_url argument is passed correctly
command = args[0] if args else kwargs.get("args", [])
assert "ws://localhost:8080" in " ".join(command), (
"Should pass custom port to subprocess"
)
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = '{"contextLength": 4096, "identifier": "custom-model"}'
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="custom-model", base_url="ws://localhost:8080"
)
assert limit == 4096, "Should work with custom WebSocket port"
@pytest.mark.parametrize(
"context_length,expected",
[
(512, 512), # Small context
(2048, 2048), # Common context
(8192, 8192), # Large context
(32768, 32768), # Very large context
],
)
def test_query_lmstudio_various_context_lengths(self, monkeypatch, context_length, expected):
"""Verify SDK query handles various context length values.
Different models have different context lengths. The function should
correctly parse and return any positive integer value.
"""
def mock_run(*args, **kwargs):
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = f'{{"contextLength": {context_length}, "identifier": "test"}}'
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="test-model", base_url="ws://localhost:1234"
)
assert limit == expected, f"Should return {expected} for context length {context_length}"
def test_query_lmstudio_logs_at_debug_level(self, monkeypatch, caplog):
"""Verify errors are logged at DEBUG level, not WARNING/ERROR.
Following the graceful fallback pattern from Ollama implementation,
errors should be logged at debug level to avoid alarming users when
fallback to registry works fine.
"""
import logging
caplog.set_level(logging.DEBUG, logger="leann.embedding_compute")
def mock_run(*args, **kwargs):
raise FileNotFoundError("node: command not found")
monkeypatch.setattr("subprocess.run", mock_run)
_query_lmstudio_context_limit(model_name="test-model", base_url="ws://localhost:1234")
# Check that debug logging occurred (not warning/error)
debug_logs = [record for record in caplog.records if record.levelname == "DEBUG"]
assert len(debug_logs) > 0, "Should log error at DEBUG level"
# Verify no WARNING or ERROR logs
warning_or_error_logs = [
record for record in caplog.records if record.levelname in ["WARNING", "ERROR"]
]
assert len(warning_or_error_logs) == 0, (
"Should not log at WARNING/ERROR level for expected failures"
)

View File

@@ -0,0 +1,400 @@
"""End-to-end integration tests for prompt template and token limit features.
These tests verify real-world functionality with live services:
- OpenAI-compatible APIs (OpenAI, LM Studio) with prompt template support
- Ollama with dynamic token limit detection
- Hybrid token limit discovery mechanism
Run with: pytest tests/test_prompt_template_e2e.py -v -s
Skip if services unavailable: pytest tests/test_prompt_template_e2e.py -m "not integration"
Prerequisites:
1. LM Studio running with embedding model: http://localhost:1234
2. [Optional] Ollama running: ollama serve
3. [Optional] Ollama model: ollama pull nomic-embed-text
4. [Optional] Node.js + @lmstudio/sdk for context length detection
"""
import logging
import socket
import numpy as np
import pytest
import requests
from leann.embedding_compute import (
compute_embeddings_ollama,
compute_embeddings_openai,
get_model_token_limit,
)
# Test markers for conditional execution
pytestmark = pytest.mark.integration
logger = logging.getLogger(__name__)
def check_service_available(host: str, port: int, timeout: float = 2.0) -> bool:
"""Check if a service is available on the given host:port."""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
result = sock.connect_ex((host, port))
sock.close()
return result == 0
except Exception:
return False
def check_ollama_available() -> bool:
"""Check if Ollama service is available."""
if not check_service_available("localhost", 11434):
return False
try:
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
return response.status_code == 200
except Exception:
return False
def check_lmstudio_available() -> bool:
"""Check if LM Studio service is available."""
if not check_service_available("localhost", 1234):
return False
try:
response = requests.get("http://localhost:1234/v1/models", timeout=2.0)
return response.status_code == 200
except Exception:
return False
def get_lmstudio_first_model() -> str | None:
"""Get the first available model from LM Studio."""
try:
response = requests.get("http://localhost:1234/v1/models", timeout=5.0)
data = response.json()
models = data.get("data", [])
if models:
return models[0]["id"]
except Exception:
pass
return None
class TestPromptTemplateOpenAI:
"""End-to-end tests for prompt template with OpenAI-compatible APIs (LM Studio)."""
@pytest.mark.skipif(
not check_lmstudio_available(), reason="LM Studio service not available on localhost:1234"
)
def test_lmstudio_embedding_with_prompt_template(self):
"""Test prompt templates with LM Studio using OpenAI-compatible API."""
model_name = get_lmstudio_first_model()
if not model_name:
pytest.skip("No models loaded in LM Studio")
texts = ["artificial intelligence", "machine learning"]
prompt_template = "search_query: "
# Get embeddings with prompt template via provider_options
provider_options = {"prompt_template": prompt_template}
embeddings = compute_embeddings_openai(
texts=texts,
model_name=model_name,
base_url="http://localhost:1234/v1",
api_key="lm-studio", # LM Studio doesn't require real key
provider_options=provider_options,
)
assert embeddings is not None
assert len(embeddings) == 2
assert all(isinstance(emb, np.ndarray) for emb in embeddings)
assert all(len(emb) > 0 for emb in embeddings)
logger.info(
f"✓ LM Studio embeddings with prompt template: {len(embeddings)} vectors, {len(embeddings[0])} dimensions"
)
@pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
def test_lmstudio_prompt_template_affects_embeddings(self):
"""Verify that prompt templates actually change embedding values."""
model_name = get_lmstudio_first_model()
if not model_name:
pytest.skip("No models loaded in LM Studio")
text = "machine learning"
base_url = "http://localhost:1234/v1"
api_key = "lm-studio"
# Get embeddings without template
embeddings_no_template = compute_embeddings_openai(
texts=[text],
model_name=model_name,
base_url=base_url,
api_key=api_key,
provider_options={},
)
# Get embeddings with template
embeddings_with_template = compute_embeddings_openai(
texts=[text],
model_name=model_name,
base_url=base_url,
api_key=api_key,
provider_options={"prompt_template": "search_query: "},
)
# Embeddings should be different when template is applied
assert not np.allclose(embeddings_no_template[0], embeddings_with_template[0])
logger.info("✓ Prompt template changes embedding values as expected")
class TestPromptTemplateOllama:
"""End-to-end tests for prompt template with Ollama."""
@pytest.mark.skipif(
not check_ollama_available(), reason="Ollama service not available on localhost:11434"
)
def test_ollama_embedding_with_prompt_template(self):
"""Test prompt templates with Ollama using any available embedding model."""
# Get any available embedding model
try:
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
models = response.json().get("models", [])
embedding_models = []
for model in models:
name = model["name"]
base_name = name.split(":")[0]
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
embedding_models.append(name)
if not embedding_models:
pytest.skip("No embedding models available in Ollama")
model_name = embedding_models[0]
texts = ["artificial intelligence", "machine learning"]
prompt_template = "search_query: "
# Get embeddings with prompt template via provider_options
provider_options = {"prompt_template": prompt_template}
embeddings = compute_embeddings_ollama(
texts=texts,
model_name=model_name,
is_build=False,
host="http://localhost:11434",
provider_options=provider_options,
)
assert embeddings is not None
assert len(embeddings) == 2
assert all(isinstance(emb, np.ndarray) for emb in embeddings)
assert all(len(emb) > 0 for emb in embeddings)
logger.info(
f"✓ Ollama embeddings with prompt template: {len(embeddings)} vectors, {len(embeddings[0])} dimensions"
)
except Exception as e:
pytest.skip(f"Could not test Ollama prompt template: {e}")
@pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
def test_ollama_prompt_template_affects_embeddings(self):
"""Verify that prompt templates actually change embedding values with Ollama."""
# Get any available embedding model
try:
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
models = response.json().get("models", [])
embedding_models = []
for model in models:
name = model["name"]
base_name = name.split(":")[0]
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
embedding_models.append(name)
if not embedding_models:
pytest.skip("No embedding models available in Ollama")
model_name = embedding_models[0]
text = "machine learning"
host = "http://localhost:11434"
# Get embeddings without template
embeddings_no_template = compute_embeddings_ollama(
texts=[text], model_name=model_name, is_build=False, host=host, provider_options={}
)
# Get embeddings with template
embeddings_with_template = compute_embeddings_ollama(
texts=[text],
model_name=model_name,
is_build=False,
host=host,
provider_options={"prompt_template": "search_query: "},
)
# Embeddings should be different when template is applied
assert not np.allclose(embeddings_no_template[0], embeddings_with_template[0])
logger.info("✓ Ollama prompt template changes embedding values as expected")
except Exception as e:
pytest.skip(f"Could not test Ollama prompt template: {e}")
class TestLMStudioSDK:
"""End-to-end tests for LM Studio SDK integration."""
@pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
def test_lmstudio_model_listing(self):
"""Test that we can list models from LM Studio."""
try:
response = requests.get("http://localhost:1234/v1/models", timeout=5.0)
assert response.status_code == 200
data = response.json()
assert "data" in data
models = data["data"]
logger.info(f"✓ LM Studio models available: {len(models)}")
if models:
logger.info(f" First model: {models[0].get('id', 'unknown')}")
except Exception as e:
pytest.skip(f"LM Studio API error: {e}")
@pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
def test_lmstudio_sdk_context_length_detection(self):
"""Test context length detection via LM Studio SDK bridge (requires Node.js + SDK)."""
model_name = get_lmstudio_first_model()
if not model_name:
pytest.skip("No models loaded in LM Studio")
try:
from leann.embedding_compute import _query_lmstudio_context_limit
# SDK requires WebSocket URL (ws://)
context_length = _query_lmstudio_context_limit(
model_name=model_name, base_url="ws://localhost:1234"
)
if context_length is None:
logger.warning(
"⚠ LM Studio SDK bridge returned None (Node.js or SDK may not be available)"
)
pytest.skip("Node.js or @lmstudio/sdk not available - SDK bridge unavailable")
else:
assert context_length > 0
logger.info(
f"✓ LM Studio context length detected via SDK: {context_length} for {model_name}"
)
except ImportError:
pytest.skip("_query_lmstudio_context_limit not implemented yet")
except Exception as e:
logger.error(f"LM Studio SDK test error: {e}")
raise
class TestOllamaTokenLimit:
"""End-to-end tests for Ollama token limit discovery."""
@pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
def test_ollama_token_limit_detection(self):
"""Test dynamic token limit detection from Ollama /api/show endpoint."""
# Get any available embedding model
try:
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
models = response.json().get("models", [])
embedding_models = []
for model in models:
name = model["name"]
base_name = name.split(":")[0]
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
embedding_models.append(name)
if not embedding_models:
pytest.skip("No embedding models available in Ollama")
test_model = embedding_models[0]
# Test token limit detection
limit = get_model_token_limit(model_name=test_model, base_url="http://localhost:11434")
assert limit > 0
logger.info(f"✓ Ollama token limit detected: {limit} for {test_model}")
except Exception as e:
pytest.skip(f"Could not test Ollama token detection: {e}")
class TestHybridTokenLimit:
"""End-to-end tests for hybrid token limit discovery mechanism."""
def test_hybrid_discovery_registry_fallback(self):
"""Test fallback to static registry for known OpenAI models."""
# Use a known OpenAI model (should be in registry)
limit = get_model_token_limit(
model_name="text-embedding-3-small",
base_url="http://fake-server:9999", # Fake URL to force registry lookup
)
# text-embedding-3-small should have 8192 in registry
assert limit == 8192
logger.info(f"✓ Hybrid discovery (registry fallback): {limit} tokens")
def test_hybrid_discovery_default_fallback(self):
"""Test fallback to safe default for completely unknown models."""
limit = get_model_token_limit(
model_name="completely-unknown-model-xyz-12345",
base_url="http://fake-server:9999",
default=512,
)
# Should get the specified default
assert limit == 512
logger.info(f"✓ Hybrid discovery (default fallback): {limit} tokens")
@pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
def test_hybrid_discovery_ollama_dynamic_first(self):
"""Test that Ollama models use dynamic discovery first."""
# Get any available embedding model
try:
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
models = response.json().get("models", [])
embedding_models = []
for model in models:
name = model["name"]
base_name = name.split(":")[0]
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
embedding_models.append(name)
if not embedding_models:
pytest.skip("No embedding models available in Ollama")
test_model = embedding_models[0]
# Should query Ollama /api/show dynamically
limit = get_model_token_limit(model_name=test_model, base_url="http://localhost:11434")
assert limit > 0
logger.info(f"✓ Hybrid discovery (Ollama dynamic): {limit} tokens for {test_model}")
except Exception as e:
pytest.skip(f"Could not test hybrid Ollama discovery: {e}")
if __name__ == "__main__":
print("\n" + "=" * 70)
print("INTEGRATION TEST SUITE - Real Service Testing")
print("=" * 70)
print("\nThese tests require live services:")
print(" • LM Studio: http://localhost:1234 (with embedding model loaded)")
print(" • [Optional] Ollama: http://localhost:11434")
print(" • [Optional] Node.js + @lmstudio/sdk for SDK bridge tests")
print("\nRun with: pytest tests/test_prompt_template_e2e.py -v -s")
print("=" * 70 + "\n")

View File

@@ -0,0 +1,808 @@
"""
Integration tests for prompt template metadata persistence and reuse.
These tests verify the complete lifecycle of prompt template persistence:
1. Template is saved to .meta.json during index build
2. Template is automatically loaded during search operations
3. Template can be overridden with explicit flag during search
4. Template is reused during chat/ask operations
These are integration tests that:
- Use real file system with temporary directories
- Run actual build and search operations
- Inspect .meta.json file contents directly
- Mock embedding servers to avoid external dependencies
- Use small test codebases for fast execution
Expected to FAIL in Red Phase because metadata persistence verification is not yet implemented.
"""
import json
import tempfile
from pathlib import Path
from unittest.mock import Mock, patch
import numpy as np
import pytest
from leann.api import LeannBuilder, LeannSearcher
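# Illustrative sketch of the .meta.json shape these tests inspect (values are examples,
# not requirements): embedding_options is persisted at build time and carries the
# template keys that LeannSearcher later auto-loads.
_EXAMPLE_META_JSON = {
    "backend_name": "hnsw",
    "embedding_model": "text-embedding-3-small",
    "embedding_mode": "openai",
    "embedding_options": {
        "build_prompt_template": "doc: ",    # applied to passages at build time
        "query_prompt_template": "query: ",  # applied to queries at search time
    },
}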
class TestPromptTemplateMetadataPersistence:
"""Tests for prompt template storage in .meta.json during build."""
@pytest.fixture
def temp_index_dir(self):
"""Create temporary directory for test indexes."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
@pytest.fixture
def mock_embeddings(self):
"""Mock compute_embeddings to return dummy embeddings."""
with patch("leann.api.compute_embeddings") as mock_compute:
# Return dummy embeddings as numpy array
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
yield mock_compute
def test_prompt_template_saved_to_metadata(self, temp_index_dir, mock_embeddings):
"""
Verify that when build is run with embedding_options containing prompt_template,
the template value is saved to .meta.json file.
This is the core persistence requirement - templates must be saved to allow
reuse in subsequent search operations without re-specifying the flag.
Expected failure: .meta.json exists but doesn't contain embedding_options
with prompt_template, or the value is not persisted correctly.
"""
# Setup test data
index_path = temp_index_dir / "test_index.leann"
template = "search_document: "
# Build index with prompt template in embedding_options
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
embedding_options={"prompt_template": template},
)
# Add a simple document
builder.add_text("This is a test document for indexing")
# Build the index
builder.build_index(str(index_path))
# Verify .meta.json was created and contains the template
meta_path = temp_index_dir / "test_index.leann.meta.json"
assert meta_path.exists(), ".meta.json file should be created during build"
# Read and parse metadata
with open(meta_path, encoding="utf-8") as f:
meta_data = json.load(f)
# Verify embedding_options exists in metadata
assert "embedding_options" in meta_data, (
"embedding_options should be saved to .meta.json when provided"
)
# Verify prompt_template is in embedding_options
embedding_options = meta_data["embedding_options"]
assert "prompt_template" in embedding_options, (
"prompt_template should be saved within embedding_options"
)
# Verify the template value matches what we provided
assert embedding_options["prompt_template"] == template, (
f"Template should be '{template}', got '{embedding_options.get('prompt_template')}'"
)
def test_prompt_template_absent_when_not_provided(self, temp_index_dir, mock_embeddings):
"""
Verify that when no prompt template is provided during build,
.meta.json either doesn't have embedding_options or prompt_template key.
This ensures clean metadata without unnecessary keys when features aren't used.
Expected behavior: Build succeeds, .meta.json doesn't contain prompt_template.
"""
index_path = temp_index_dir / "test_no_template.leann"
# Build index WITHOUT prompt template
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
# No embedding_options provided
)
builder.add_text("Document without template")
builder.build_index(str(index_path))
# Verify metadata
meta_path = temp_index_dir / "test_no_template.leann.meta.json"
assert meta_path.exists()
with open(meta_path, encoding="utf-8") as f:
meta_data = json.load(f)
# If embedding_options exists, it should not contain prompt_template
if "embedding_options" in meta_data:
embedding_options = meta_data["embedding_options"]
assert "prompt_template" not in embedding_options, (
"prompt_template should not be in metadata when not provided"
)
class TestPromptTemplateAutoLoadOnSearch:
"""Tests for automatic loading of prompt template during search operations.
NOTE: Over-mocked test removed (test_prompt_template_auto_loaded_on_search).
This functionality is now comprehensively tested by TestQueryPromptTemplateAutoLoad
which uses simpler mocking and doesn't hang.
"""
@pytest.fixture
def temp_index_dir(self):
"""Create temporary directory for test indexes."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
@pytest.fixture
def mock_embeddings(self):
"""Mock compute_embeddings to capture calls and return dummy embeddings."""
with patch("leann.api.compute_embeddings") as mock_compute:
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
yield mock_compute
def test_search_without_template_in_metadata(self, temp_index_dir, mock_embeddings):
"""
Verify that searching an index built WITHOUT a prompt template
works correctly (backward compatibility).
The searcher should handle missing prompt_template gracefully.
Expected behavior: Search succeeds, no template is used.
"""
# Build index without template
index_path = temp_index_dir / "no_template.leann"
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
)
builder.add_text("Document without template")
builder.build_index(str(index_path))
# Reset mocks
mock_embeddings.reset_mock()
# Create searcher and search
searcher = LeannSearcher(index_path=str(index_path))
# Verify no template in embedding_options
assert "prompt_template" not in searcher.embedding_options, (
"Searcher should not have prompt_template when not in metadata"
)
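# Precedence sketch (assumption, not the shipped LeannSearcher code): the resolution
# order the class below pins down: an explicit provider_options override wins, then the
# stored query_prompt_template, then the legacy single prompt_template, else no template.
def _sketch_resolve_query_template(embedding_options, provider_options=None):
    override = (provider_options or {}).get("prompt_template")
    if override:
        return override
    options = embedding_options or {}
    return options.get("query_prompt_template") or options.get("prompt_template") or ""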
class TestQueryPromptTemplateAutoLoad:
"""Tests for automatic loading of separate query_prompt_template during search (R2).
These tests verify the new two-template system where:
- build_prompt_template: Applied during index building
- query_prompt_template: Applied during search operations
Expected to FAIL in Red Phase (R2) because query template extraction
and application is not yet implemented in LeannSearcher.search().
"""
@pytest.fixture
def temp_index_dir(self):
"""Create temporary directory for test indexes."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
@pytest.fixture
def mock_compute_embeddings(self):
"""Mock compute_embeddings to capture calls and return dummy embeddings."""
with patch("leann.embedding_compute.compute_embeddings") as mock_compute:
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
yield mock_compute
def test_search_auto_loads_query_template(self, temp_index_dir, mock_compute_embeddings):
"""
Verify that search() automatically loads and applies query_prompt_template from .meta.json.
Given: Index built with separate build_prompt_template and query_prompt_template
When: LeannSearcher.search("my query") is called
Then: Query embedding is computed with "query: my query" (query template applied)
This is the core R2 requirement - query templates must be auto-loaded and applied
during search without user intervention.
Expected failure: compute_embeddings called with raw "my query" instead of
"query: my query" because query template extraction is not implemented.
"""
# Setup: Build index with separate templates in new format
index_path = temp_index_dir / "query_template.leann"
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
embedding_options={
"build_prompt_template": "doc: ",
"query_prompt_template": "query: ",
},
)
builder.add_text("Test document")
builder.build_index(str(index_path))
# Reset mock to ignore build calls
mock_compute_embeddings.reset_mock()
# Act: Search with query
searcher = LeannSearcher(index_path=str(index_path))
# Mock the backend search to avoid actual search
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
mock_backend_search.return_value = {
"labels": [["test_id_0"]], # IDs (nested list for batch support)
"distances": [[0.9]], # Distances (nested list for batch support)
}
searcher.search("my query", top_k=1, recompute_embeddings=False)
# Assert: compute_embeddings was called with query template applied
assert mock_compute_embeddings.called, "compute_embeddings should be called during search"
# Get the actual text passed to compute_embeddings
call_args = mock_compute_embeddings.call_args
texts_arg = call_args[0][0] # First positional arg (list of texts)
assert len(texts_arg) == 1, "Should compute embedding for one query"
assert texts_arg[0] == "query: my query", (
f"Query template should be applied: expected 'query: my query', got '{texts_arg[0]}'"
)
def test_search_backward_compat_single_template(self, temp_index_dir, mock_compute_embeddings):
"""
Verify backward compatibility with old single prompt_template format.
Given: Index with old format (single prompt_template, no query_prompt_template)
When: LeannSearcher.search("my query") is called
Then: Query embedding computed with "doc: my query" (old template applied)
This ensures indexes built with the old single-template system continue
to work correctly with the new search implementation.
Expected failure: Old template not recognized/applied because backward
compatibility logic is not implemented.
"""
# Setup: Build index with old single-template format
index_path = temp_index_dir / "old_template.leann"
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
embedding_options={"prompt_template": "doc: "}, # Old format
)
builder.add_text("Test document")
builder.build_index(str(index_path))
# Reset mock
mock_compute_embeddings.reset_mock()
# Act: Search
searcher = LeannSearcher(index_path=str(index_path))
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
searcher.search("my query", top_k=1, recompute_embeddings=False)
# Assert: Old template was applied
call_args = mock_compute_embeddings.call_args
texts_arg = call_args[0][0]
assert texts_arg[0] == "doc: my query", (
f"Old prompt_template should be applied for backward compatibility: "
f"expected 'doc: my query', got '{texts_arg[0]}'"
)
def test_search_backward_compat_no_template(self, temp_index_dir, mock_compute_embeddings):
"""
Verify backward compatibility when no template is present in .meta.json.
Given: Index with no template in .meta.json (very old indexes)
When: LeannSearcher.search("my query") is called
Then: Query embedding computed with "my query" (no template, raw query)
This ensures the most basic backward compatibility - indexes without
any template support continue to work as before.
Expected failure: May fail if default template is incorrectly applied,
or if missing template causes error.
"""
# Setup: Build index without any template
index_path = temp_index_dir / "no_template.leann"
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
# No embedding_options at all
)
builder.add_text("Test document")
builder.build_index(str(index_path))
# Reset mock
mock_compute_embeddings.reset_mock()
# Act: Search
searcher = LeannSearcher(index_path=str(index_path))
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
searcher.search("my query", top_k=1, recompute_embeddings=False)
# Assert: No template applied (raw query)
call_args = mock_compute_embeddings.call_args
texts_arg = call_args[0][0]
assert texts_arg[0] == "my query", (
f"No template should be applied when missing from metadata: "
f"expected 'my query', got '{texts_arg[0]}'"
)
def test_search_override_via_provider_options(self, temp_index_dir, mock_compute_embeddings):
"""
Verify that explicit provider_options can override stored query template.
Given: Index with query_prompt_template: "query: "
When: search() called with provider_options={"prompt_template": "override: "}
Then: Query embedding computed with "override: test" (override takes precedence)
This enables users to experiment with different query templates without
rebuilding the index, or to handle special query types differently.
Expected failure: provider_options parameter is accepted via **kwargs but
not used. Query embedding computed with raw "test" instead of "override: test"
because override logic is not implemented.
"""
# Setup: Build index with query template
index_path = temp_index_dir / "override_template.leann"
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
embedding_options={
"build_prompt_template": "doc: ",
"query_prompt_template": "query: ",
},
)
builder.add_text("Test document")
builder.build_index(str(index_path))
# Reset mock
mock_compute_embeddings.reset_mock()
# Act: Search with override
searcher = LeannSearcher(index_path=str(index_path))
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
# This should accept provider_options parameter
searcher.search(
"test",
top_k=1,
recompute_embeddings=False,
provider_options={"prompt_template": "override: "},
)
# Assert: Override template was applied
call_args = mock_compute_embeddings.call_args
texts_arg = call_args[0][0]
assert texts_arg[0] == "override: test", (
f"Override template should take precedence: "
f"expected 'override: test', got '{texts_arg[0]}'"
)
class TestPromptTemplateReuseInChat:
"""Tests for prompt template reuse in chat/ask operations."""
@pytest.fixture
def temp_index_dir(self):
"""Create temporary directory for test indexes."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
@pytest.fixture
def mock_embeddings(self):
"""Mock compute_embeddings to return dummy embeddings."""
with patch("leann.api.compute_embeddings") as mock_compute:
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
yield mock_compute
@pytest.fixture
def mock_embedding_server_manager(self):
"""Mock EmbeddingServerManager for chat tests."""
with patch("leann.searcher_base.EmbeddingServerManager") as mock_manager_class:
mock_manager = Mock()
mock_manager.start_server.return_value = (True, 5557)
mock_manager_class.return_value = mock_manager
yield mock_manager
@pytest.fixture
def index_with_template(self, temp_index_dir, mock_embeddings):
"""Build an index with a prompt template."""
index_path = temp_index_dir / "chat_template_index.leann"
template = "document_query: "
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
embedding_options={"prompt_template": template},
)
builder.add_text("Test document for chat")
builder.build_index(str(index_path))
return str(index_path), template
class TestPromptTemplateIntegrationWithEmbeddingModes:
"""Tests for prompt template compatibility with different embedding modes."""
@pytest.fixture
def temp_index_dir(self):
"""Create temporary directory for test indexes."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
@pytest.mark.parametrize(
"mode,model,template,filename_prefix",
[
(
"openai",
"text-embedding-3-small",
"Represent this for searching: ",
"openai_template",
),
("ollama", "nomic-embed-text", "search_query: ", "ollama_template"),
("sentence-transformers", "facebook/contriever", "query: ", "st_template"),
],
)
def test_prompt_template_metadata_with_embedding_modes(
self, temp_index_dir, mode, model, template, filename_prefix
):
"""Verify prompt template is saved correctly across different embedding modes.
Tests that prompt templates are persisted to .meta.json for:
- OpenAI mode (primary use case)
- Ollama mode (also supports templates)
- Sentence-transformers mode (saved for forward compatibility)
Expected behavior: Template is saved to .meta.json regardless of mode.
"""
with patch("leann.api.compute_embeddings") as mock_compute:
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
index_path = temp_index_dir / f"{filename_prefix}.leann"
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=model,
embedding_mode=mode,
embedding_options={"prompt_template": template},
)
builder.add_text(f"{mode.capitalize()} test document")
builder.build_index(str(index_path))
# Verify metadata
meta_path = temp_index_dir / f"{filename_prefix}.leann.meta.json"
with open(meta_path, encoding="utf-8") as f:
meta_data = json.load(f)
assert meta_data["embedding_mode"] == mode
# Template should be saved for all modes (even if not used by some)
if "embedding_options" in meta_data:
assert meta_data["embedding_options"]["prompt_template"] == template
class TestQueryTemplateApplicationInComputeEmbedding:
"""Tests for query template application in compute_query_embedding() (Bug Fix).
These tests verify that query templates are applied consistently in BOTH
code paths (server and fallback) when computing query embeddings.
This addresses the bug where query templates were only applied in the
fallback path, not when using the embedding server (the default path).
Bug Context:
- Issue: Query templates were stored in metadata but only applied during
fallback (direct) computation, not when using embedding server
- Fix: Move template application to BEFORE any computation path in
compute_query_embedding() (searcher_base.py:107-110)
- Impact: Critical for models like EmbeddingGemma that require task-specific
templates for optimal performance
These tests ensure the fix works correctly and prevent regression.
"""
@pytest.fixture
def temp_index_with_template(self):
"""Create a temporary index with query template in metadata"""
with tempfile.TemporaryDirectory() as tmpdir:
index_dir = Path(tmpdir)
index_file = index_dir / "test.leann"
meta_file = index_dir / "test.leann.meta.json"
# Create minimal metadata with query template
metadata = {
"version": "1.0",
"backend_name": "hnsw",
"embedding_model": "text-embedding-embeddinggemma-300m-qat",
"dimensions": 768,
"embedding_mode": "openai",
"backend_kwargs": {
"graph_degree": 32,
"complexity": 64,
"distance_metric": "cosine",
},
"embedding_options": {
"base_url": "http://localhost:1234/v1",
"api_key": "test-key",
"build_prompt_template": "title: none | text: ",
"query_prompt_template": "task: search result | query: ",
},
}
meta_file.write_text(json.dumps(metadata, indent=2))
# Create minimal HNSW index file (empty is okay for this test)
index_file.write_bytes(b"")
yield str(index_file)
def test_query_template_applied_in_fallback_path(self, temp_index_with_template):
"""Test that query template is applied when using fallback (direct) path"""
from leann.searcher_base import BaseSearcher
# Create a concrete implementation for testing
class TestSearcher(BaseSearcher):
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)
searcher.index_path = Path(temp_index_with_template)
searcher.index_dir = searcher.index_path.parent
# Load metadata
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
with open(meta_file) as f:
searcher.meta = json.load(f)
searcher.embedding_model = searcher.meta["embedding_model"]
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
searcher.embedding_options = searcher.meta.get("embedding_options", {})
# Mock compute_embeddings to capture the query text
captured_queries = []
def mock_compute_embeddings(texts, model, mode, provider_options=None):
captured_queries.extend(texts)
return np.random.rand(len(texts), 768).astype(np.float32)
with patch(
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
):
# Call compute_query_embedding with template (fallback path)
result = searcher.compute_query_embedding(
query="vector database",
use_server_if_available=False, # Force fallback path
query_template="task: search result | query: ",
)
# Verify template was applied
assert len(captured_queries) == 1
assert captured_queries[0] == "task: search result | query: vector database"
assert result.shape == (1, 768)
def test_query_template_applied_in_server_path(self, temp_index_with_template):
"""Test that query template is applied when using server path"""
from leann.searcher_base import BaseSearcher
# Create a concrete implementation for testing
class TestSearcher(BaseSearcher):
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)
searcher.index_path = Path(temp_index_with_template)
searcher.index_dir = searcher.index_path.parent
# Load metadata
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
with open(meta_file) as f:
searcher.meta = json.load(f)
searcher.embedding_model = searcher.meta["embedding_model"]
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
searcher.embedding_options = searcher.meta.get("embedding_options", {})
# Mock the server methods to capture the query text
captured_queries = []
def mock_ensure_server_running(passages_file, port):
return port
def mock_compute_embedding_via_server(chunks, port):
captured_queries.extend(chunks)
return np.random.rand(len(chunks), 768).astype(np.float32)
searcher._ensure_server_running = mock_ensure_server_running
searcher._compute_embedding_via_server = mock_compute_embedding_via_server
# Call compute_query_embedding with template (server path)
result = searcher.compute_query_embedding(
query="vector database",
use_server_if_available=True, # Use server path
query_template="task: search result | query: ",
)
# Verify template was applied BEFORE calling server
assert len(captured_queries) == 1
assert captured_queries[0] == "task: search result | query: vector database"
assert result.shape == (1, 768)
def test_query_template_without_template_parameter(self, temp_index_with_template):
"""Test that query is unchanged when no template is provided"""
from leann.searcher_base import BaseSearcher
class TestSearcher(BaseSearcher):
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)
searcher.index_path = Path(temp_index_with_template)
searcher.index_dir = searcher.index_path.parent
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
with open(meta_file) as f:
searcher.meta = json.load(f)
searcher.embedding_model = searcher.meta["embedding_model"]
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
searcher.embedding_options = searcher.meta.get("embedding_options", {})
captured_queries = []
def mock_compute_embeddings(texts, model, mode, provider_options=None):
captured_queries.extend(texts)
return np.random.rand(len(texts), 768).astype(np.float32)
with patch(
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
):
searcher.compute_query_embedding(
query="vector database",
use_server_if_available=False,
query_template=None, # No template
)
# Verify query is unchanged
assert len(captured_queries) == 1
assert captured_queries[0] == "vector database"
def test_query_template_consistency_between_paths(self, temp_index_with_template):
"""Test that both paths apply template identically"""
from leann.searcher_base import BaseSearcher
class TestSearcher(BaseSearcher):
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)
searcher.index_path = Path(temp_index_with_template)
searcher.index_dir = searcher.index_path.parent
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
with open(meta_file) as f:
searcher.meta = json.load(f)
searcher.embedding_model = searcher.meta["embedding_model"]
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
searcher.embedding_options = searcher.meta.get("embedding_options", {})
query_template = "task: search result | query: "
original_query = "vector database"
# Capture queries from fallback path
fallback_queries = []
def mock_compute_embeddings(texts, model, mode, provider_options=None):
fallback_queries.extend(texts)
return np.random.rand(len(texts), 768).astype(np.float32)
with patch(
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
):
searcher.compute_query_embedding(
query=original_query,
use_server_if_available=False,
query_template=query_template,
)
# Capture queries from server path
server_queries = []
def mock_ensure_server_running(passages_file, port):
return port
def mock_compute_embedding_via_server(chunks, port):
server_queries.extend(chunks)
return np.random.rand(len(chunks), 768).astype(np.float32)
searcher._ensure_server_running = mock_ensure_server_running
searcher._compute_embedding_via_server = mock_compute_embedding_via_server
searcher.compute_query_embedding(
query=original_query,
use_server_if_available=True,
query_template=query_template,
)
# Verify both paths produced identical templated queries
assert len(fallback_queries) == 1
assert len(server_queries) == 1
assert fallback_queries[0] == server_queries[0]
assert fallback_queries[0] == f"{query_template}{original_query}"
def test_query_template_with_empty_string(self, temp_index_with_template):
"""Test behavior with empty template string"""
from leann.searcher_base import BaseSearcher
class TestSearcher(BaseSearcher):
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)
searcher.index_path = Path(temp_index_with_template)
searcher.index_dir = searcher.index_path.parent
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
with open(meta_file) as f:
searcher.meta = json.load(f)
searcher.embedding_model = searcher.meta["embedding_model"]
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
searcher.embedding_options = searcher.meta.get("embedding_options", {})
captured_queries = []
def mock_compute_embeddings(texts, model, mode, provider_options=None):
captured_queries.extend(texts)
return np.random.rand(len(texts), 768).astype(np.float32)
with patch(
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
):
searcher.compute_query_embedding(
query="vector database",
use_server_if_available=False,
query_template="", # Empty string
)
# Empty string is falsy, so no template should be applied
assert captured_queries[0] == "vector database"
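# --- Illustrative sketch (not part of the test suite) ---
# The query-template tests above pin down one rule: a truthy query_template is
# prepended to the raw query before *either* embedding path runs (server or
# fallback), while a falsy template ("" or None) leaves the query untouched.
# A minimal sketch of that rule under those assumptions; the helper name
# _apply_query_template is hypothetical, not the actual BaseSearcher internals.
def _apply_query_template(query: str, query_template: str | None) -> str:
    """Prepend the template when it is truthy; otherwise return the query unchanged."""
    return f"{query_template}{query}" if query_template else query

# _apply_query_template("vector database", "task: search result | query: ")
#   -> "task: search result | query: vector database"
# _apply_query_template("vector database", "")   -> "vector database"
# _apply_query_template("vector database", None) -> "vector database"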


@@ -266,3 +266,378 @@ class TestTokenTruncation:
assert result_tokens <= target_tokens, (
f"Should be ≤{target_tokens} tokens, got {result_tokens}"
)
class TestLMStudioHybridDiscovery:
"""Tests for LM Studio integration in get_model_token_limit() hybrid discovery.
These tests verify that get_model_token_limit() properly integrates with
the LM Studio SDK bridge for dynamic token limit discovery. The integration
should:
1. Detect LM Studio URLs (port 1234 or 'lmstudio'/'lm.studio' in URL)
2. Convert HTTP URLs to WebSocket format for SDK queries
3. Query LM Studio SDK and use discovered limit
4. Fall back to registry when SDK returns None
5. Execute AFTER Ollama detection but BEFORE registry fallback
All tests are written in the red phase of TDD: they should FAIL initially because the
LM Studio detection and integration logic does not yet exist in get_model_token_limit().
"""
def test_get_model_token_limit_lmstudio_success(self, monkeypatch):
"""Verify LM Studio SDK query succeeds and returns detected limit.
When an LM Studio base_url is detected and the SDK query succeeds,
get_model_token_limit() should return the dynamically discovered
context length without falling back to the registry.
"""
# Mock _query_lmstudio_context_limit to return successful SDK query
def mock_query_lmstudio(model_name, base_url):
# Verify WebSocket URL was passed (not HTTP)
assert base_url.startswith("ws://"), (
f"Should convert HTTP to WebSocket format, got: {base_url}"
)
return 8192 # Successful SDK query
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
# Test with HTTP URL that should be converted to WebSocket
limit = get_model_token_limit(
model_name="custom-model", base_url="http://localhost:1234/v1"
)
assert limit == 8192, "Should return limit from LM Studio SDK query"
def test_get_model_token_limit_lmstudio_fallback_to_registry(self, monkeypatch):
"""Verify fallback to registry when LM Studio SDK returns None.
When LM Studio SDK query fails (returns None), get_model_token_limit()
should fall back to the EMBEDDING_MODEL_LIMITS registry.
"""
# Mock _query_lmstudio_context_limit to return None (SDK failure)
def mock_query_lmstudio(model_name, base_url):
return None # SDK query failed
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
# Test with known model that exists in registry
limit = get_model_token_limit(
model_name="nomic-embed-text", base_url="http://localhost:1234/v1"
)
# Should fall back to registry value
assert limit == 2048, "Should fall back to registry when SDK returns None"
def test_get_model_token_limit_lmstudio_port_detection(self, monkeypatch):
"""Verify detection of LM Studio via port 1234.
get_model_token_limit() should recognize port 1234 as an LM Studio
server and attempt an SDK query, regardless of hostname.
"""
query_called = False
def mock_query_lmstudio(model_name, base_url):
nonlocal query_called
query_called = True
return 4096
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
# Test with port 1234 (default LM Studio port)
limit = get_model_token_limit(model_name="test-model", base_url="http://127.0.0.1:1234/v1")
assert query_called, "Should detect port 1234 and call LM Studio SDK query"
assert limit == 4096, "Should return SDK query result"
@pytest.mark.parametrize(
"test_url,expected_limit,keyword",
[
("http://lmstudio.local:8080/v1", 16384, "lmstudio"),
("http://api.lm.studio:5000/v1", 32768, "lm.studio"),
],
)
def test_get_model_token_limit_lmstudio_url_keyword_detection(
self, monkeypatch, test_url, expected_limit, keyword
):
"""Verify detection of LM Studio via keywords in URL.
get_model_token_limit() should recognize 'lmstudio' or 'lm.studio'
in the URL as indicating an LM Studio server.
"""
query_called = False
def mock_query_lmstudio(model_name, base_url):
nonlocal query_called
query_called = True
return expected_limit
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
limit = get_model_token_limit(model_name="test-model", base_url=test_url)
assert query_called, f"Should detect '{keyword}' keyword and call SDK query"
assert limit == expected_limit, f"Should return SDK query result for {keyword}"
@pytest.mark.parametrize(
"input_url,expected_protocol,expected_host",
[
("http://localhost:1234/v1", "ws://", "localhost:1234"),
("https://lmstudio.example.com:1234/v1", "wss://", "lmstudio.example.com:1234"),
],
)
def test_get_model_token_limit_protocol_conversion(
self, monkeypatch, input_url, expected_protocol, expected_host
):
"""Verify HTTP/HTTPS URL is converted to WebSocket format for SDK query.
LM Studio SDK requires WebSocket URLs. get_model_token_limit() should:
1. Convert 'http://' to 'ws://'
2. Convert 'https://' to 'wss://'
3. Remove '/v1' or other path suffixes (SDK expects base URL)
4. Preserve host and port
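(A minimal conversion sketch follows this test.)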
"""
conversions_tested = []
def mock_query_lmstudio(model_name, base_url):
conversions_tested.append(base_url)
return 8192
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
get_model_token_limit(model_name="test-model", base_url=input_url)
# Verify conversion happened
assert len(conversions_tested) == 1, "Should have called SDK query once"
assert conversions_tested[0].startswith(expected_protocol), (
f"Should convert to {expected_protocol}"
)
assert expected_host in conversions_tested[0], (
f"Should preserve host and port: {expected_host}"
)
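# --- Illustrative sketch (not part of the test suite) ---
# The conversion checked above: 'http' becomes 'ws', 'https' becomes 'wss',
# the path suffix (e.g. '/v1') is dropped, and host:port is preserved. A
# minimal sketch under those assumptions; _http_to_ws_url is a hypothetical
# name, not necessarily what leann.embedding_compute calls it.
from urllib.parse import urlparse

def _http_to_ws_url(base_url: str) -> str:
    """Rewrite an HTTP(S) base URL into the WebSocket form the SDK bridge expects."""
    parsed = urlparse(base_url)
    scheme = "wss" if parsed.scheme == "https" else "ws"
    return f"{scheme}://{parsed.netloc}"

# _http_to_ws_url("http://localhost:1234/v1")             -> "ws://localhost:1234"
# _http_to_ws_url("https://lmstudio.example.com:1234/v1") -> "wss://lmstudio.example.com:1234"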
def test_get_model_token_limit_lmstudio_executes_after_ollama(self, monkeypatch):
"""Verify LM Studio detection happens AFTER Ollama detection.
The hybrid discovery order should be:
1. Ollama dynamic discovery (port 11434 or 'ollama' in URL)
2. LM Studio dynamic discovery (port 1234 or 'lmstudio' in URL)
3. Registry fallback
If both Ollama and LM Studio patterns match, Ollama should take precedence.
This test verifies that the LM Studio check is never reached, and so cannot interfere, when Ollama succeeds.
"""
ollama_called = False
lmstudio_called = False
def mock_query_ollama(model_name, base_url):
nonlocal ollama_called
ollama_called = True
return 2048 # Ollama query succeeds
def mock_query_lmstudio(model_name, base_url):
nonlocal lmstudio_called
lmstudio_called = True
return None # Should not be reached if Ollama succeeds
monkeypatch.setattr(
"leann.embedding_compute._query_ollama_context_limit",
mock_query_ollama,
)
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
# Test with Ollama URL
limit = get_model_token_limit(
model_name="test-model", base_url="http://localhost:11434/api"
)
assert ollama_called, "Should attempt Ollama query first"
assert not lmstudio_called, "Should not attempt LM Studio query when Ollama succeeds"
assert limit == 2048, "Should return Ollama result"
def test_get_model_token_limit_lmstudio_not_detected_for_non_lmstudio_urls(self, monkeypatch):
"""Verify LM Studio SDK query is NOT called for non-LM Studio URLs.
Only URLs with port 1234 or 'lmstudio'/'lm.studio' keywords should
trigger LM Studio SDK queries. Other URLs should skip to registry fallback.
"""
lmstudio_called = False
def mock_query_lmstudio(model_name, base_url):
nonlocal lmstudio_called
lmstudio_called = True
return 8192
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
# Test with non-LM Studio URLs
test_cases = [
"http://localhost:8080/v1", # Different port
"http://openai.example.com/v1", # Different service
"http://localhost:3000/v1", # Another port
]
for base_url in test_cases:
lmstudio_called = False # Reset for each test
get_model_token_limit(model_name="nomic-embed-text", base_url=base_url)
assert not lmstudio_called, f"Should NOT call LM Studio SDK for URL: {base_url}"
def test_get_model_token_limit_lmstudio_case_insensitive_detection(self, monkeypatch):
"""Verify LM Studio detection is case-insensitive for keywords.
Keywords 'lmstudio' and 'lm.studio' should be detected regardless
of case (LMStudio, LMSTUDIO, LmStudio, etc.).
"""
query_called = False
def mock_query_lmstudio(model_name, base_url):
nonlocal query_called
query_called = True
return 8192
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
# Test various case variations
test_cases = [
"http://LMStudio.local:8080/v1",
"http://LMSTUDIO.example.com/v1",
"http://LmStudio.local/v1",
"http://api.LM.STUDIO:5000/v1",
]
for base_url in test_cases:
query_called = False # Reset for each test
limit = get_model_token_limit(model_name="test-model", base_url=base_url)
assert query_called, f"Should detect LM Studio in URL: {base_url}"
assert limit == 8192, f"Should return SDK result for URL: {base_url}"
class TestTokenLimitCaching:
"""Tests for token limit caching to prevent repeated SDK/API calls.
Caching prevents duplicate SDK/API calls within the same Python process,
which is important because:
1. LM Studio SDK load() can load duplicate model instances
2. Ollama /api/show queries add latency
3. Registry lookups are pure overhead
Cache is process-scoped and resets between leann build invocations.
"""
def setup_method(self):
"""Clear cache before each test."""
from leann.embedding_compute import _token_limit_cache
_token_limit_cache.clear()
def test_registry_lookup_is_cached(self):
"""Verify that registry lookups are cached."""
from leann.embedding_compute import _token_limit_cache
# First call
limit1 = get_model_token_limit("text-embedding-3-small")
assert limit1 == 8192
# Verify it's in cache
cache_key = ("text-embedding-3-small", "")
assert cache_key in _token_limit_cache
assert _token_limit_cache[cache_key] == 8192
# Second call should use cache
limit2 = get_model_token_limit("text-embedding-3-small")
assert limit2 == 8192
def test_default_fallback_is_cached(self):
"""Verify that default fallbacks are cached."""
from leann.embedding_compute import _token_limit_cache
# First call with unknown model
limit1 = get_model_token_limit("unknown-model-xyz", default=512)
assert limit1 == 512
# Verify it's in cache
cache_key = ("unknown-model-xyz", "")
assert cache_key in _token_limit_cache
assert _token_limit_cache[cache_key] == 512
# Second call should use cache
limit2 = get_model_token_limit("unknown-model-xyz", default=512)
assert limit2 == 512
def test_different_urls_create_separate_cache_entries(self):
"""Verify that different base_urls create separate cache entries."""
from leann.embedding_compute import _token_limit_cache
# Same model, different URLs
limit1 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:11434")
limit2 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:1234/v1")
# Both should find the model in registry (2048)
assert limit1 == 2048
assert limit2 == 2048
# But they should be separate cache entries
cache_key1 = ("nomic-embed-text", "http://localhost:11434")
cache_key2 = ("nomic-embed-text", "http://localhost:1234/v1")
assert cache_key1 in _token_limit_cache
assert cache_key2 in _token_limit_cache
assert len(_token_limit_cache) == 2
def test_cache_prevents_repeated_lookups(self):
"""Verify that cache prevents repeated registry/API lookups."""
from leann.embedding_compute import _token_limit_cache
model_name = "text-embedding-ada-002"
# First call - should add to cache
assert len(_token_limit_cache) == 0
limit1 = get_model_token_limit(model_name)
cache_size_after_first = len(_token_limit_cache)
assert cache_size_after_first == 1
# Multiple subsequent calls - cache size should not change
for _ in range(5):
limit = get_model_token_limit(model_name)
assert limit == limit1
assert len(_token_limit_cache) == cache_size_after_first
def test_versioned_model_names_cached_correctly(self):
"""Verify that versioned model names (e.g., model:tag) are cached."""
from leann.embedding_compute import _token_limit_cache
# Model with version tag
limit = get_model_token_limit("nomic-embed-text:latest", base_url="http://localhost:11434")
assert limit == 2048
# Should be cached with full name including version
cache_key = ("nomic-embed-text:latest", "http://localhost:11434")
assert cache_key in _token_limit_cache
assert _token_limit_cache[cache_key] == 2048