Feature/prompt templates and lmstudio sdk (#171)
* Add prompt template support and LM Studio SDK integration Features: - Prompt template support for embedding models (via --embedding-prompt-template) - LM Studio SDK integration for automatic context length detection - Hybrid token limit discovery (Ollama → LM Studio → Registry → Default) - Client-side token truncation to prevent silent failures - Automatic persistence of embedding_options to .meta.json Implementation: - Added _query_lmstudio_context_limit() with Node.js subprocess bridge - Modified compute_embeddings_openai() to apply prompt templates before truncation - Extended CLI with --embedding-prompt-template flag for build and search - URL detection for LM Studio (port 1234 or lmstudio/lm.studio keywords) - HTTP→WebSocket URL conversion for SDK compatibility Tests: - 60 passing tests across 5 test files - Comprehensive coverage of prompt templates, LM Studio integration, and token handling - Parametrized tests for maintainability and clarity * Add integration tests and fix LM Studio SDK bridge Features: - End-to-end integration tests for prompt template with EmbeddingGemma - Integration tests for hybrid token limit discovery mechanism - Tests verify real-world functionality with live services (LM Studio, Ollama) Fixes: - LM Studio SDK bridge now uses client.embedding.load() for embedding models - Fixed NODE_PATH resolution to include npm global modules - Fixed integration test to use WebSocket URL (ws://) for SDK bridge Tests: - test_prompt_template_e2e.py: 8 integration tests covering: - Prompt template prepending with LM Studio (EmbeddingGemma) - LM Studio SDK bridge for context length detection - Ollama dynamic token limit detection - Hybrid discovery fallback mechanism (registry, default) - All tests marked with @pytest.mark.integration for selective execution - Tests gracefully skip when services unavailable Documentation: - Updated tests/README.md with integration test section - Added prerequisites and running instructions - Documented that prompt templates are ONLY for EmbeddingGemma - Added integration marker to pyproject.toml Test Results: - All 8 integration tests passing with live services - Confirmed prompt templates work correctly with EmbeddingGemma - Verified LM Studio SDK bridge auto-detects context length (2048) - Validated hybrid token limit discovery across all backends * Add prompt template support to Ollama mode Extends prompt template functionality from OpenAI mode to Ollama for backend consistency. Changes: - Add provider_options parameter to compute_embeddings_ollama() - Apply prompt template before token truncation (lines 1005-1011) - Pass provider_options through compute_embeddings() call chain Tests: - test_ollama_embedding_with_prompt_template: Verifies templates work with Ollama - test_ollama_prompt_template_affects_embeddings: Confirms embeddings differ with/without template - Both tests pass with live Ollama service (2/2 passing) Usage: leann build --embedding-mode ollama --embedding-prompt-template "query: " ... * Fix LM Studio SDK bridge to respect JIT auto-evict settings Problem: SDK bridge called client.embedding.load() which loaded models into LM Studio memory and bypassed JIT auto-evict settings, causing duplicate model instances to accumulate. Root cause analysis (from Perplexity research): - Explicit SDK load() commands are treated as "pinned" models - JIT auto-evict only applies to models loaded reactively via API requests - SDK-loaded models remain in memory until explicitly unloaded Solutions implemented: 1. Add model.unload() after metadata query (line 243) - Load model temporarily to get context length - Unload immediately to hand control back to JIT system - Subsequent API requests trigger JIT load with auto-evict 2. Add token limit caching to prevent repeated SDK calls - Cache discovered limits in _token_limit_cache dict (line 48) - Key: (model_name, base_url), Value: token_limit - Prevents duplicate load/unload cycles within same process - Cache shared across all discovery methods (Ollama, SDK, registry) Tests: - TestTokenLimitCaching: 5 tests for cache behavior (integrated into test_token_truncation.py) - Manual testing confirmed no duplicate models in LM Studio after fix - All existing tests pass Impact: - Respects user's LM Studio JIT and auto-evict settings - Reduces model memory footprint - Faster subsequent builds (cached limits) * Document prompt template and LM Studio SDK features Added comprehensive documentation for new optional embedding features: Configuration Guide (docs/configuration-guide.md): - New section: "Optional Embedding Features" - Task-Specific Prompt Templates subsection: - Explains EmbeddingGemma use case with document/query prompts - CLI and Python API examples - Clear warnings about compatible vs incompatible models - References to GitHub issue #155 and HuggingFace blog - LM Studio Auto-Detection subsection: - Prerequisites (Node.js + @lmstudio/sdk) - How auto-detection works (4-step process) - Benefits and optional nature clearly stated FAQ (docs/faq.md): - FAQ #2: When should I use prompt templates? - DO/DON'T guidance with examples - Links to detailed configuration guide - FAQ #3: Why is LM Studio loading multiple copies? - Explains the JIT auto-evict fix - Troubleshooting steps if still seeing issues - FAQ #4: Do I need Node.js and @lmstudio/sdk? - Clarifies it's completely optional - Lists benefits if installed - Installation instructions Cross-references between documents for easy navigation between quick reference and detailed guides. * Add separate build/query template support for task-specific models Task-specific models like EmbeddingGemma require different templates for indexing vs searching. Store both templates at build time and auto-apply query template during search with backward compatibility. * Consolidate prompt template tests from 44 to 37 tests Merged redundant no-op tests, removed low-value implementation tests, consolidated parameterized CLI tests, and removed hanging over-mocked test. All tests pass with improved focus on behavioral testing. * Fix query template application in compute_query_embedding Query templates were only applied in the fallback code path, not when using the embedding server (default path). This meant stored query templates in index metadata were ignored during MCP and CLI searches. Changes: - Move template application to before any computation path (searcher_base.py:109-110) - Add comprehensive tests for both server and fallback paths - Consolidate tests into test_prompt_template_persistence.py Tests verify: - Template applied when using embedding server - Template applied in fallback path - Consistent behavior between both paths * Apply ruff formatting and fix linting issues - Remove unused imports - Fix import ordering - Remove unused variables - Apply code formatting * Fix CI test failures: mock OPENAI_API_KEY in tests Tests were failing in CI because compute_embeddings_openai() checks for OPENAI_API_KEY before using the mocked client. Added monkeypatch to set fake API key in test fixture.
This commit is contained in:
@@ -158,6 +158,95 @@ builder.build_index("./indexes/my-notes", chunks)
|
|||||||
|
|
||||||
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
|
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
|
||||||
|
|
||||||
|
## Optional Embedding Features
|
||||||
|
|
||||||
|
### Task-Specific Prompt Templates
|
||||||
|
|
||||||
|
Some embedding models are trained with task-specific prompts to differentiate between documents and queries. The most notable example is **Google's EmbeddingGemma**, which requires different prompts depending on the use case:
|
||||||
|
|
||||||
|
- **Indexing documents**: `"title: none | text: "`
|
||||||
|
- **Search queries**: `"task: search result | query: "`
|
||||||
|
|
||||||
|
LEANN supports automatic prompt prepending via the `--embedding-prompt-template` flag:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build index with EmbeddingGemma (via LM Studio or Ollama)
|
||||||
|
leann build my-docs \
|
||||||
|
--docs ./documents \
|
||||||
|
--embedding-mode openai \
|
||||||
|
--embedding-model text-embedding-embeddinggemma-300m-qat \
|
||||||
|
--embedding-api-base http://localhost:1234/v1 \
|
||||||
|
--embedding-prompt-template "title: none | text: " \
|
||||||
|
--force
|
||||||
|
|
||||||
|
# Search with query-specific prompt
|
||||||
|
leann search my-docs \
|
||||||
|
--query "What is quantum computing?" \
|
||||||
|
--embedding-prompt-template "task: search result | query: "
|
||||||
|
```
|
||||||
|
|
||||||
|
**Important Notes:**
|
||||||
|
- **Only use with compatible models**: EmbeddingGemma and similar task-specific models
|
||||||
|
- **NOT for regular models**: Adding prompts to models like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` will corrupt embeddings
|
||||||
|
- **Template is saved**: Build-time templates are saved to `.meta.json` for reference
|
||||||
|
- **Flexible prompts**: You can use any prompt string, or leave it empty (`""`)
|
||||||
|
|
||||||
|
**Python API:**
|
||||||
|
```python
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
embedding_mode="openai",
|
||||||
|
embedding_model="text-embedding-embeddinggemma-300m-qat",
|
||||||
|
embedding_options={
|
||||||
|
"base_url": "http://localhost:1234/v1",
|
||||||
|
"api_key": "lm-studio",
|
||||||
|
"prompt_template": "title: none | text: ",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
builder.build_index("./indexes/my-docs", chunks)
|
||||||
|
```
|
||||||
|
|
||||||
|
**References:**
|
||||||
|
- [HuggingFace Blog: EmbeddingGemma](https://huggingface.co/blog/embeddinggemma) - Technical details
|
||||||
|
|
||||||
|
### LM Studio Auto-Detection (Optional)
|
||||||
|
|
||||||
|
When using LM Studio with the OpenAI-compatible API, LEANN can optionally auto-detect model context lengths via the LM Studio SDK. This eliminates manual configuration for token limits.
|
||||||
|
|
||||||
|
**Prerequisites:**
|
||||||
|
```bash
|
||||||
|
# Install Node.js (if not already installed)
|
||||||
|
# Then install the LM Studio SDK globally
|
||||||
|
npm install -g @lmstudio/sdk
|
||||||
|
```
|
||||||
|
|
||||||
|
**How it works:**
|
||||||
|
1. LEANN detects LM Studio URLs (`:1234`, `lmstudio` in URL)
|
||||||
|
2. Queries model metadata via Node.js subprocess
|
||||||
|
3. Automatically unloads model after query (respects your JIT auto-evict settings)
|
||||||
|
4. Falls back to static registry if SDK unavailable
|
||||||
|
|
||||||
|
**No configuration needed** - it works automatically when SDK is installed:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
leann build my-docs \
|
||||||
|
--docs ./documents \
|
||||||
|
--embedding-mode openai \
|
||||||
|
--embedding-model text-embedding-nomic-embed-text-v1.5 \
|
||||||
|
--embedding-api-base http://localhost:1234/v1
|
||||||
|
# Context length auto-detected if SDK available
|
||||||
|
# Falls back to registry (2048) if not
|
||||||
|
```
|
||||||
|
|
||||||
|
**Benefits:**
|
||||||
|
- ✅ Automatic token limit detection
|
||||||
|
- ✅ Respects LM Studio JIT auto-evict settings
|
||||||
|
- ✅ No manual registry maintenance
|
||||||
|
- ✅ Graceful fallback if SDK unavailable
|
||||||
|
|
||||||
|
**Note:** This is completely optional. LEANN works perfectly fine without the SDK using the built-in token limit registry.
|
||||||
|
|
||||||
## Index Selection: Matching Your Scale
|
## Index Selection: Matching Your Scale
|
||||||
|
|
||||||
### HNSW (Hierarchical Navigable Small World)
|
### HNSW (Hierarchical Navigable Small World)
|
||||||
|
|||||||
48
docs/faq.md
48
docs/faq.md
@@ -8,3 +8,51 @@ You can speed up the process by using a lightweight embedding model. Add this to
|
|||||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||||
```
|
```
|
||||||
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
||||||
|
|
||||||
|
## 2. When should I use prompt templates?
|
||||||
|
|
||||||
|
**Use prompt templates ONLY with task-specific embedding models** like Google's EmbeddingGemma. These models are specially trained to use different prompts for documents vs queries.
|
||||||
|
|
||||||
|
**DO NOT use with regular models** like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` - adding prompts to these models will corrupt the embeddings.
|
||||||
|
|
||||||
|
**Example usage with EmbeddingGemma:**
|
||||||
|
```bash
|
||||||
|
# Build with document prompt
|
||||||
|
leann build my-docs --embedding-prompt-template "title: none | text: "
|
||||||
|
|
||||||
|
# Search with query prompt
|
||||||
|
leann search my-docs --query "your question" --embedding-prompt-template "task: search result | query: "
|
||||||
|
```
|
||||||
|
|
||||||
|
See the [Configuration Guide: Task-Specific Prompt Templates](configuration-guide.md#task-specific-prompt-templates) for detailed usage.
|
||||||
|
|
||||||
|
## 3. Why is LM Studio loading multiple copies of my model?
|
||||||
|
|
||||||
|
This was fixed in recent versions. LEANN now properly unloads models after querying metadata, respecting your LM Studio JIT auto-evict settings.
|
||||||
|
|
||||||
|
**If you still see duplicates:**
|
||||||
|
- Update to the latest LEANN version
|
||||||
|
- Restart LM Studio to clear loaded models
|
||||||
|
- Check that you have JIT auto-evict enabled in LM Studio settings
|
||||||
|
|
||||||
|
**How it works now:**
|
||||||
|
1. LEANN loads model temporarily to get context length
|
||||||
|
2. Immediately unloads after query
|
||||||
|
3. LM Studio JIT loads model on-demand for actual embeddings
|
||||||
|
4. Auto-evicts per your settings
|
||||||
|
|
||||||
|
## 4. Do I need Node.js and @lmstudio/sdk?
|
||||||
|
|
||||||
|
**No, it's completely optional.** LEANN works perfectly fine without them using a built-in token limit registry.
|
||||||
|
|
||||||
|
**Benefits if you install it:**
|
||||||
|
- Automatic context length detection for LM Studio models
|
||||||
|
- No manual registry maintenance
|
||||||
|
- Always gets accurate token limits from the model itself
|
||||||
|
|
||||||
|
**To install (optional):**
|
||||||
|
```bash
|
||||||
|
npm install -g @lmstudio/sdk
|
||||||
|
```
|
||||||
|
|
||||||
|
See [Configuration Guide: LM Studio Auto-Detection](configuration-guide.md#lm-studio-auto-detection-optional) for details.
|
||||||
|
|||||||
@@ -916,6 +916,7 @@ class LeannSearcher:
|
|||||||
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
||||||
batch_size: int = 0,
|
batch_size: int = 0,
|
||||||
use_grep: bool = False,
|
use_grep: bool = False,
|
||||||
|
provider_options: Optional[dict[str, Any]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> list[SearchResult]:
|
) -> list[SearchResult]:
|
||||||
"""
|
"""
|
||||||
@@ -979,10 +980,24 @@ class LeannSearcher:
|
|||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Extract query template from stored embedding_options with fallback chain:
|
||||||
|
# 1. Check provider_options override (highest priority)
|
||||||
|
# 2. Check query_prompt_template (new format)
|
||||||
|
# 3. Check prompt_template (old format for backward compat)
|
||||||
|
# 4. None (no template)
|
||||||
|
query_template = None
|
||||||
|
if provider_options and "prompt_template" in provider_options:
|
||||||
|
query_template = provider_options["prompt_template"]
|
||||||
|
elif "query_prompt_template" in self.embedding_options:
|
||||||
|
query_template = self.embedding_options["query_prompt_template"]
|
||||||
|
elif "prompt_template" in self.embedding_options:
|
||||||
|
query_template = self.embedding_options["prompt_template"]
|
||||||
|
|
||||||
query_embedding = self.backend_impl.compute_query_embedding(
|
query_embedding = self.backend_impl.compute_query_embedding(
|
||||||
query,
|
query,
|
||||||
use_server_if_available=recompute_embeddings,
|
use_server_if_available=recompute_embeddings,
|
||||||
zmq_port=zmq_port,
|
zmq_port=zmq_port,
|
||||||
|
query_template=query_template,
|
||||||
)
|
)
|
||||||
logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
||||||
embedding_time = time.time() - start_time
|
embedding_time = time.time() - start_time
|
||||||
|
|||||||
@@ -144,6 +144,18 @@ Examples:
|
|||||||
default=None,
|
default=None,
|
||||||
help="API key for embedding service (defaults to OPENAI_API_KEY)",
|
help="API key for embedding service (defaults to OPENAI_API_KEY)",
|
||||||
)
|
)
|
||||||
|
build_parser.add_argument(
|
||||||
|
"--embedding-prompt-template",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Prompt template to prepend to all texts for embedding (e.g., 'query: ' for search)",
|
||||||
|
)
|
||||||
|
build_parser.add_argument(
|
||||||
|
"--query-prompt-template",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Prompt template for queries (different from build template for task-specific models)",
|
||||||
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--force", "-f", action="store_true", help="Force rebuild existing index"
|
"--force", "-f", action="store_true", help="Force rebuild existing index"
|
||||||
)
|
)
|
||||||
@@ -260,6 +272,12 @@ Examples:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Display file paths and metadata in search results",
|
help="Display file paths and metadata in search results",
|
||||||
)
|
)
|
||||||
|
search_parser.add_argument(
|
||||||
|
"--embedding-prompt-template",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Prompt template to prepend to query for embedding (e.g., 'query: ' for search)",
|
||||||
|
)
|
||||||
|
|
||||||
# Ask command
|
# Ask command
|
||||||
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
||||||
@@ -1398,6 +1416,14 @@ Examples:
|
|||||||
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
|
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
|
||||||
if resolved_embedding_key:
|
if resolved_embedding_key:
|
||||||
embedding_options["api_key"] = resolved_embedding_key
|
embedding_options["api_key"] = resolved_embedding_key
|
||||||
|
if args.query_prompt_template:
|
||||||
|
# New format: separate templates
|
||||||
|
if args.embedding_prompt_template:
|
||||||
|
embedding_options["build_prompt_template"] = args.embedding_prompt_template
|
||||||
|
embedding_options["query_prompt_template"] = args.query_prompt_template
|
||||||
|
elif args.embedding_prompt_template:
|
||||||
|
# Old format: single template (backward compat)
|
||||||
|
embedding_options["prompt_template"] = args.embedding_prompt_template
|
||||||
|
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name=args.backend_name,
|
backend_name=args.backend_name,
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ Consolidates all embedding computation logic using SentenceTransformer
|
|||||||
Preserves all optimization parameters to ensure performance
|
Preserves all optimization parameters to ensure performance
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import subprocess
|
||||||
import time
|
import time
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
@@ -40,6 +42,11 @@ EMBEDDING_MODEL_LIMITS = {
|
|||||||
"text-embedding-ada-002": 8192,
|
"text-embedding-ada-002": 8192,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Runtime cache for dynamically discovered token limits
|
||||||
|
# Key: (model_name, base_url), Value: token_limit
|
||||||
|
# Prevents repeated SDK/API calls for the same model
|
||||||
|
_token_limit_cache: dict[tuple[str, str], int] = {}
|
||||||
|
|
||||||
|
|
||||||
def get_model_token_limit(
|
def get_model_token_limit(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@@ -49,6 +56,7 @@ def get_model_token_limit(
|
|||||||
"""
|
"""
|
||||||
Get token limit for a given embedding model.
|
Get token limit for a given embedding model.
|
||||||
Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others.
|
Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others.
|
||||||
|
Caches discovered limits to prevent repeated API/SDK calls.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name: Name of the embedding model
|
model_name: Name of the embedding model
|
||||||
@@ -58,12 +66,33 @@ def get_model_token_limit(
|
|||||||
Returns:
|
Returns:
|
||||||
Token limit for the model in tokens
|
Token limit for the model in tokens
|
||||||
"""
|
"""
|
||||||
|
# Check cache first to avoid repeated SDK/API calls
|
||||||
|
cache_key = (model_name, base_url or "")
|
||||||
|
if cache_key in _token_limit_cache:
|
||||||
|
cached_limit = _token_limit_cache[cache_key]
|
||||||
|
logger.debug(f"Using cached token limit for {model_name}: {cached_limit}")
|
||||||
|
return cached_limit
|
||||||
|
|
||||||
# Try Ollama dynamic discovery if base_url provided
|
# Try Ollama dynamic discovery if base_url provided
|
||||||
if base_url:
|
if base_url:
|
||||||
# Detect Ollama servers by port or "ollama" in URL
|
# Detect Ollama servers by port or "ollama" in URL
|
||||||
if "11434" in base_url or "ollama" in base_url.lower():
|
if "11434" in base_url or "ollama" in base_url.lower():
|
||||||
limit = _query_ollama_context_limit(model_name, base_url)
|
limit = _query_ollama_context_limit(model_name, base_url)
|
||||||
if limit:
|
if limit:
|
||||||
|
_token_limit_cache[cache_key] = limit
|
||||||
|
return limit
|
||||||
|
|
||||||
|
# Try LM Studio SDK discovery
|
||||||
|
if "1234" in base_url or "lmstudio" in base_url.lower() or "lm.studio" in base_url.lower():
|
||||||
|
# Convert HTTP to WebSocket URL
|
||||||
|
ws_url = base_url.replace("https://", "wss://").replace("http://", "ws://")
|
||||||
|
# Remove /v1 suffix if present
|
||||||
|
if ws_url.endswith("/v1"):
|
||||||
|
ws_url = ws_url[:-3]
|
||||||
|
|
||||||
|
limit = _query_lmstudio_context_limit(model_name, ws_url)
|
||||||
|
if limit:
|
||||||
|
_token_limit_cache[cache_key] = limit
|
||||||
return limit
|
return limit
|
||||||
|
|
||||||
# Fallback to known model registry with version handling (from PR #154)
|
# Fallback to known model registry with version handling (from PR #154)
|
||||||
@@ -72,19 +101,25 @@ def get_model_token_limit(
|
|||||||
|
|
||||||
# Check exact match first
|
# Check exact match first
|
||||||
if model_name in EMBEDDING_MODEL_LIMITS:
|
if model_name in EMBEDDING_MODEL_LIMITS:
|
||||||
return EMBEDDING_MODEL_LIMITS[model_name]
|
limit = EMBEDDING_MODEL_LIMITS[model_name]
|
||||||
|
_token_limit_cache[cache_key] = limit
|
||||||
|
return limit
|
||||||
|
|
||||||
# Check base name match
|
# Check base name match
|
||||||
if base_model_name in EMBEDDING_MODEL_LIMITS:
|
if base_model_name in EMBEDDING_MODEL_LIMITS:
|
||||||
return EMBEDDING_MODEL_LIMITS[base_model_name]
|
limit = EMBEDDING_MODEL_LIMITS[base_model_name]
|
||||||
|
_token_limit_cache[cache_key] = limit
|
||||||
|
return limit
|
||||||
|
|
||||||
# Check partial matches for common patterns
|
# Check partial matches for common patterns
|
||||||
for known_model, limit in EMBEDDING_MODEL_LIMITS.items():
|
for known_model, registry_limit in EMBEDDING_MODEL_LIMITS.items():
|
||||||
if known_model in base_model_name or base_model_name in known_model:
|
if known_model in base_model_name or base_model_name in known_model:
|
||||||
return limit
|
_token_limit_cache[cache_key] = registry_limit
|
||||||
|
return registry_limit
|
||||||
|
|
||||||
# Default fallback
|
# Default fallback
|
||||||
logger.warning(f"Unknown model '{model_name}', using default {default} token limit")
|
logger.warning(f"Unknown model '{model_name}', using default {default} token limit")
|
||||||
|
_token_limit_cache[cache_key] = default
|
||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
@@ -185,6 +220,91 @@ def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int]
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _query_lmstudio_context_limit(model_name: str, base_url: str) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
Query LM Studio SDK for model context length via Node.js subprocess.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the LM Studio model
|
||||||
|
base_url: Base URL of the LM Studio server (WebSocket format, e.g., "ws://localhost:1234")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Context limit in tokens if found, None otherwise
|
||||||
|
"""
|
||||||
|
# Inline JavaScript using @lmstudio/sdk
|
||||||
|
# Note: Load model temporarily for metadata, then unload to respect JIT auto-evict
|
||||||
|
js_code = f"""
|
||||||
|
const {{ LMStudioClient }} = require('@lmstudio/sdk');
|
||||||
|
(async () => {{
|
||||||
|
try {{
|
||||||
|
const client = new LMStudioClient({{ baseUrl: '{base_url}' }});
|
||||||
|
const model = await client.embedding.load('{model_name}', {{ verbose: false }});
|
||||||
|
const contextLength = await model.getContextLength();
|
||||||
|
await model.unload(); // Unload immediately to respect JIT auto-evict settings
|
||||||
|
console.log(JSON.stringify({{ contextLength, identifier: '{model_name}' }}));
|
||||||
|
}} catch (error) {{
|
||||||
|
console.error(JSON.stringify({{ error: error.message }}));
|
||||||
|
process.exit(1);
|
||||||
|
}}
|
||||||
|
}})();
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Set NODE_PATH to include global modules for @lmstudio/sdk resolution
|
||||||
|
env = os.environ.copy()
|
||||||
|
|
||||||
|
# Try to get npm global root (works with nvm, brew node, etc.)
|
||||||
|
try:
|
||||||
|
npm_root = subprocess.run(
|
||||||
|
["npm", "root", "-g"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
if npm_root.returncode == 0:
|
||||||
|
global_modules = npm_root.stdout.strip()
|
||||||
|
# Append to existing NODE_PATH if present
|
||||||
|
existing_node_path = env.get("NODE_PATH", "")
|
||||||
|
env["NODE_PATH"] = (
|
||||||
|
f"{global_modules}:{existing_node_path}"
|
||||||
|
if existing_node_path
|
||||||
|
else global_modules
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# If npm not available, continue with existing NODE_PATH
|
||||||
|
pass
|
||||||
|
|
||||||
|
result = subprocess.run(
|
||||||
|
["node", "-e", js_code],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=10,
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
logger.debug(f"LM Studio SDK error: {result.stderr}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = json.loads(result.stdout)
|
||||||
|
context_length = data.get("contextLength")
|
||||||
|
|
||||||
|
if context_length and context_length > 0:
|
||||||
|
logger.info(f"LM Studio SDK detected {model_name} context length: {context_length}")
|
||||||
|
return context_length
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.debug("Node.js not found - install Node.js for LM Studio SDK features")
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
logger.debug("LM Studio SDK query timeout")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.debug("LM Studio SDK returned invalid JSON")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"LM Studio SDK query failed: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
# Global model cache to avoid repeated loading
|
# Global model cache to avoid repeated loading
|
||||||
_model_cache: dict[str, Any] = {}
|
_model_cache: dict[str, Any] = {}
|
||||||
|
|
||||||
@@ -232,6 +352,7 @@ def compute_embeddings(
|
|||||||
model_name,
|
model_name,
|
||||||
base_url=provider_options.get("base_url"),
|
base_url=provider_options.get("base_url"),
|
||||||
api_key=provider_options.get("api_key"),
|
api_key=provider_options.get("api_key"),
|
||||||
|
provider_options=provider_options,
|
||||||
)
|
)
|
||||||
elif mode == "mlx":
|
elif mode == "mlx":
|
||||||
return compute_embeddings_mlx(texts, model_name)
|
return compute_embeddings_mlx(texts, model_name)
|
||||||
@@ -241,6 +362,7 @@ def compute_embeddings(
|
|||||||
model_name,
|
model_name,
|
||||||
is_build=is_build,
|
is_build=is_build,
|
||||||
host=provider_options.get("host"),
|
host=provider_options.get("host"),
|
||||||
|
provider_options=provider_options,
|
||||||
)
|
)
|
||||||
elif mode == "gemini":
|
elif mode == "gemini":
|
||||||
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
|
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
|
||||||
@@ -579,6 +701,7 @@ def compute_embeddings_openai(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
base_url: Optional[str] = None,
|
base_url: Optional[str] = None,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
provider_options: Optional[dict[str, Any]] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
# TODO: @yichuan-w add progress bar only in build mode
|
# TODO: @yichuan-w add progress bar only in build mode
|
||||||
"""Compute embeddings using OpenAI API"""
|
"""Compute embeddings using OpenAI API"""
|
||||||
@@ -597,26 +720,37 @@ def compute_embeddings_openai(
|
|||||||
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
|
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
|
||||||
)
|
)
|
||||||
|
|
||||||
resolved_base_url = resolve_openai_base_url(base_url)
|
# Extract base_url and api_key from provider_options if not provided directly
|
||||||
resolved_api_key = resolve_openai_api_key(api_key)
|
provider_options = provider_options or {}
|
||||||
|
effective_base_url = base_url or provider_options.get("base_url")
|
||||||
|
effective_api_key = api_key or provider_options.get("api_key")
|
||||||
|
|
||||||
|
resolved_base_url = resolve_openai_base_url(effective_base_url)
|
||||||
|
resolved_api_key = resolve_openai_api_key(effective_api_key)
|
||||||
|
|
||||||
if not resolved_api_key:
|
if not resolved_api_key:
|
||||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||||
|
|
||||||
# Cache OpenAI client
|
# Create OpenAI client
|
||||||
cache_key = f"openai_client::{resolved_base_url}"
|
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
|
||||||
if cache_key in _model_cache:
|
|
||||||
client = _model_cache[cache_key]
|
|
||||||
else:
|
|
||||||
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
|
|
||||||
_model_cache[cache_key] = client
|
|
||||||
logger.info("OpenAI client cached")
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
|
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
|
||||||
)
|
)
|
||||||
print(f"len of texts: {len(texts)}")
|
print(f"len of texts: {len(texts)}")
|
||||||
|
|
||||||
|
# Apply prompt template if provided
|
||||||
|
prompt_template = provider_options.get("prompt_template")
|
||||||
|
|
||||||
|
if prompt_template:
|
||||||
|
logger.warning(f"Applying prompt template: '{prompt_template}'")
|
||||||
|
texts = [f"{prompt_template}{text}" for text in texts]
|
||||||
|
|
||||||
|
# Query token limit and apply truncation
|
||||||
|
token_limit = get_model_token_limit(model_name, base_url=effective_base_url)
|
||||||
|
logger.info(f"Using token limit: {token_limit} for model '{model_name}'")
|
||||||
|
texts = truncate_to_token_limit(texts, token_limit)
|
||||||
|
|
||||||
# OpenAI has limits on batch size and input length
|
# OpenAI has limits on batch size and input length
|
||||||
max_batch_size = 800 # Conservative batch size because the token limit is 300K
|
max_batch_size = 800 # Conservative batch size because the token limit is 300K
|
||||||
all_embeddings = []
|
all_embeddings = []
|
||||||
@@ -647,7 +781,15 @@ def compute_embeddings_openai(
|
|||||||
try:
|
try:
|
||||||
response = client.embeddings.create(model=model_name, input=batch_texts)
|
response = client.embeddings.create(model=model_name, input=batch_texts)
|
||||||
batch_embeddings = [embedding.embedding for embedding in response.data]
|
batch_embeddings = [embedding.embedding for embedding in response.data]
|
||||||
all_embeddings.extend(batch_embeddings)
|
|
||||||
|
# Verify we got the expected number of embeddings
|
||||||
|
if len(batch_embeddings) != len(batch_texts):
|
||||||
|
logger.warning(
|
||||||
|
f"Expected {len(batch_texts)} embeddings but got {len(batch_embeddings)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only take the number of embeddings that match the batch size
|
||||||
|
all_embeddings.extend(batch_embeddings[: len(batch_texts)])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Batch {i} failed: {e}")
|
logger.error(f"Batch {i} failed: {e}")
|
||||||
raise
|
raise
|
||||||
@@ -737,6 +879,7 @@ def compute_embeddings_ollama(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
is_build: bool = False,
|
is_build: bool = False,
|
||||||
host: Optional[str] = None,
|
host: Optional[str] = None,
|
||||||
|
provider_options: Optional[dict[str, Any]] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Compute embeddings using Ollama API with true batch processing.
|
Compute embeddings using Ollama API with true batch processing.
|
||||||
@@ -749,6 +892,7 @@ def compute_embeddings_ollama(
|
|||||||
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
|
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
|
||||||
is_build: Whether this is a build operation (shows progress bar)
|
is_build: Whether this is a build operation (shows progress bar)
|
||||||
host: Ollama host URL (defaults to environment or http://localhost:11434)
|
host: Ollama host URL (defaults to environment or http://localhost:11434)
|
||||||
|
provider_options: Optional provider-specific options (e.g., prompt_template)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||||
@@ -885,6 +1029,14 @@ def compute_embeddings_ollama(
|
|||||||
|
|
||||||
logger.info(f"Using batch size: {batch_size} for true batch processing")
|
logger.info(f"Using batch size: {batch_size} for true batch processing")
|
||||||
|
|
||||||
|
# Apply prompt template if provided
|
||||||
|
provider_options = provider_options or {}
|
||||||
|
prompt_template = provider_options.get("prompt_template")
|
||||||
|
|
||||||
|
if prompt_template:
|
||||||
|
logger.warning(f"Applying prompt template: '{prompt_template}'")
|
||||||
|
texts = [f"{prompt_template}{text}" for text in texts]
|
||||||
|
|
||||||
# Get model token limit and apply truncation before batching
|
# Get model token limit and apply truncation before batching
|
||||||
token_limit = get_model_token_limit(model_name, base_url=resolved_host)
|
token_limit = get_model_token_limit(model_name, base_url=resolved_host)
|
||||||
logger.info(f"Model '{model_name}' token limit: {token_limit}")
|
logger.info(f"Model '{model_name}' token limit: {token_limit}")
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ class LeannBackendSearcherInterface(ABC):
|
|||||||
query: str,
|
query: str,
|
||||||
use_server_if_available: bool = True,
|
use_server_if_available: bool = True,
|
||||||
zmq_port: Optional[int] = None,
|
zmq_port: Optional[int] = None,
|
||||||
|
query_template: Optional[str] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Compute embedding for a query string
|
"""Compute embedding for a query string
|
||||||
|
|
||||||
@@ -84,6 +85,7 @@ class LeannBackendSearcherInterface(ABC):
|
|||||||
query: The query string to embed
|
query: The query string to embed
|
||||||
zmq_port: ZMQ port for embedding server
|
zmq_port: ZMQ port for embedding server
|
||||||
use_server_if_available: Whether to try using embedding server first
|
use_server_if_available: Whether to try using embedding server first
|
||||||
|
query_template: Optional prompt template to prepend to query
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Query embedding as numpy array with shape (1, D)
|
Query embedding as numpy array with shape (1, D)
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
query: str,
|
query: str,
|
||||||
use_server_if_available: bool = True,
|
use_server_if_available: bool = True,
|
||||||
zmq_port: int = 5557,
|
zmq_port: int = 5557,
|
||||||
|
query_template: Optional[str] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Compute embedding for a query string.
|
Compute embedding for a query string.
|
||||||
@@ -98,10 +99,16 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
query: The query string to embed
|
query: The query string to embed
|
||||||
zmq_port: ZMQ port for embedding server
|
zmq_port: ZMQ port for embedding server
|
||||||
use_server_if_available: Whether to try using embedding server first
|
use_server_if_available: Whether to try using embedding server first
|
||||||
|
query_template: Optional prompt template to prepend to query
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Query embedding as numpy array
|
Query embedding as numpy array
|
||||||
"""
|
"""
|
||||||
|
# Apply query template BEFORE any computation path
|
||||||
|
# This ensures template is applied consistently for both server and fallback paths
|
||||||
|
if query_template:
|
||||||
|
query = f"{query_template}{query}"
|
||||||
|
|
||||||
# Try to use embedding server if available and requested
|
# Try to use embedding server if available and requested
|
||||||
if use_server_if_available:
|
if use_server_if_available:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -165,6 +165,7 @@ python_functions = ["test_*"]
|
|||||||
markers = [
|
markers = [
|
||||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||||
"openai: marks tests that require OpenAI API key",
|
"openai: marks tests that require OpenAI API key",
|
||||||
|
"integration: marks tests that require live services (Ollama, LM Studio, etc.)",
|
||||||
]
|
]
|
||||||
timeout = 300 # Reduced from 600s (10min) to 300s (5min) for CI safety
|
timeout = 300 # Reduced from 600s (10min) to 300s (5min) for CI safety
|
||||||
addopts = [
|
addopts = [
|
||||||
|
|||||||
@@ -36,6 +36,14 @@ Tests DiskANN graph partitioning functionality:
|
|||||||
- Includes performance comparison between DiskANN (with partition) and HNSW
|
- Includes performance comparison between DiskANN (with partition) and HNSW
|
||||||
- **Note**: These tests are skipped in CI due to hardware requirements and computation time
|
- **Note**: These tests are skipped in CI due to hardware requirements and computation time
|
||||||
|
|
||||||
|
### `test_prompt_template_e2e.py`
|
||||||
|
Integration tests for prompt template feature with live embedding services:
|
||||||
|
- Tests prompt template prepending with EmbeddingGemma (OpenAI-compatible API via LM Studio)
|
||||||
|
- Tests hybrid token limit discovery (Ollama dynamic detection, registry fallback, default)
|
||||||
|
- Tests LM Studio SDK bridge for automatic context length detection (requires Node.js + @lmstudio/sdk)
|
||||||
|
- **Note**: These tests require live services (LM Studio, Ollama) and are marked with `@pytest.mark.integration`
|
||||||
|
- **Important**: Prompt templates are ONLY for EmbeddingGemma and similar task-specific models, NOT regular embedding models
|
||||||
|
|
||||||
## Running Tests
|
## Running Tests
|
||||||
|
|
||||||
### Install test dependencies:
|
### Install test dependencies:
|
||||||
@@ -66,6 +74,12 @@ pytest tests/ -m "not openai"
|
|||||||
# Skip slow tests
|
# Skip slow tests
|
||||||
pytest tests/ -m "not slow"
|
pytest tests/ -m "not slow"
|
||||||
|
|
||||||
|
# Skip integration tests (that require live services)
|
||||||
|
pytest tests/ -m "not integration"
|
||||||
|
|
||||||
|
# Run only integration tests (requires LM Studio or Ollama running)
|
||||||
|
pytest tests/test_prompt_template_e2e.py -v -s
|
||||||
|
|
||||||
# Run DiskANN partition tests (requires local machine, not CI)
|
# Run DiskANN partition tests (requires local machine, not CI)
|
||||||
pytest tests/test_diskann_partition.py
|
pytest tests/test_diskann_partition.py
|
||||||
```
|
```
|
||||||
@@ -101,6 +115,20 @@ The `pytest.ini` file configures:
|
|||||||
- Custom markers for slow and OpenAI tests
|
- Custom markers for slow and OpenAI tests
|
||||||
- Verbose output with short tracebacks
|
- Verbose output with short tracebacks
|
||||||
|
|
||||||
|
### Integration Test Prerequisites
|
||||||
|
|
||||||
|
Integration tests (`test_prompt_template_e2e.py`) require live services:
|
||||||
|
|
||||||
|
**Required:**
|
||||||
|
- LM Studio running at `http://localhost:1234` with EmbeddingGemma model loaded
|
||||||
|
|
||||||
|
**Optional:**
|
||||||
|
- Ollama running at `http://localhost:11434` for token limit detection tests
|
||||||
|
- Node.js + @lmstudio/sdk installed (`npm install -g @lmstudio/sdk`) for SDK bridge tests
|
||||||
|
|
||||||
|
Tests gracefully skip if services are unavailable.
|
||||||
|
|
||||||
### Known Issues
|
### Known Issues
|
||||||
|
|
||||||
- OpenAI tests are automatically skipped if no API key is provided
|
- OpenAI tests are automatically skipped if no API key is provided
|
||||||
|
- Integration tests require live embedding services and may fail due to proxy settings (set `unset ALL_PROXY all_proxy` if needed)
|
||||||
|
|||||||
533
tests/test_cli_prompt_template.py
Normal file
533
tests/test_cli_prompt_template.py
Normal file
@@ -0,0 +1,533 @@
|
|||||||
|
"""
|
||||||
|
Tests for CLI argument integration of --embedding-prompt-template.
|
||||||
|
|
||||||
|
These tests verify that:
|
||||||
|
1. The --embedding-prompt-template flag is properly registered on build and search commands
|
||||||
|
2. The template value flows from CLI args to embedding_options dict
|
||||||
|
3. The template is passed through to compute_embeddings() function
|
||||||
|
4. Default behavior (no flag) is handled correctly
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from leann.cli import LeannCLI
|
||||||
|
|
||||||
|
|
||||||
|
class TestCLIPromptTemplateArgument:
|
||||||
|
"""Tests for --embedding-prompt-template on build and search commands."""
|
||||||
|
|
||||||
|
def test_commands_accept_prompt_template_argument(self):
|
||||||
|
"""Verify that build and search parsers accept --embedding-prompt-template flag."""
|
||||||
|
cli = LeannCLI()
|
||||||
|
parser = cli.create_parser()
|
||||||
|
template_value = "search_query: "
|
||||||
|
|
||||||
|
# Test build command
|
||||||
|
build_args = parser.parse_args(
|
||||||
|
[
|
||||||
|
"build",
|
||||||
|
"test-index",
|
||||||
|
"--docs",
|
||||||
|
"/tmp/test-docs",
|
||||||
|
"--embedding-prompt-template",
|
||||||
|
template_value,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert build_args.command == "build"
|
||||||
|
assert hasattr(build_args, "embedding_prompt_template"), (
|
||||||
|
"build command should have embedding_prompt_template attribute"
|
||||||
|
)
|
||||||
|
assert build_args.embedding_prompt_template == template_value
|
||||||
|
|
||||||
|
# Test search command
|
||||||
|
search_args = parser.parse_args(
|
||||||
|
["search", "test-index", "my query", "--embedding-prompt-template", template_value]
|
||||||
|
)
|
||||||
|
assert search_args.command == "search"
|
||||||
|
assert hasattr(search_args, "embedding_prompt_template"), (
|
||||||
|
"search command should have embedding_prompt_template attribute"
|
||||||
|
)
|
||||||
|
assert search_args.embedding_prompt_template == template_value
|
||||||
|
|
||||||
|
def test_commands_default_to_none(self):
|
||||||
|
"""Verify default value is None when flag not provided (backward compatibility)."""
|
||||||
|
cli = LeannCLI()
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
# Test build command default
|
||||||
|
build_args = parser.parse_args(["build", "test-index", "--docs", "/tmp/test-docs"])
|
||||||
|
assert hasattr(build_args, "embedding_prompt_template"), (
|
||||||
|
"build command should have embedding_prompt_template attribute"
|
||||||
|
)
|
||||||
|
assert build_args.embedding_prompt_template is None, (
|
||||||
|
"Build default value should be None when flag not provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test search command default
|
||||||
|
search_args = parser.parse_args(["search", "test-index", "my query"])
|
||||||
|
assert hasattr(search_args, "embedding_prompt_template"), (
|
||||||
|
"search command should have embedding_prompt_template attribute"
|
||||||
|
)
|
||||||
|
assert search_args.embedding_prompt_template is None, (
|
||||||
|
"Search default value should be None when flag not provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildCommandPromptTemplateArgumentExtras:
|
||||||
|
"""Additional build-specific tests for prompt template argument."""
|
||||||
|
|
||||||
|
def test_build_command_prompt_template_with_multiword_value(self):
|
||||||
|
"""
|
||||||
|
Verify that template values with spaces are handled correctly.
|
||||||
|
|
||||||
|
Templates like "search_document: " or "Represent this sentence for searching: "
|
||||||
|
should be accepted as a single string argument.
|
||||||
|
"""
|
||||||
|
cli = LeannCLI()
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
template = "Represent this sentence for searching: "
|
||||||
|
args = parser.parse_args(
|
||||||
|
[
|
||||||
|
"build",
|
||||||
|
"test-index",
|
||||||
|
"--docs",
|
||||||
|
"/tmp/test-docs",
|
||||||
|
"--embedding-prompt-template",
|
||||||
|
template,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert args.embedding_prompt_template == template
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptTemplateStoredInEmbeddingOptions:
|
||||||
|
"""Tests for template storage in embedding_options dict."""
|
||||||
|
|
||||||
|
@patch("leann.cli.LeannBuilder")
|
||||||
|
def test_prompt_template_stored_in_embedding_options_on_build(
|
||||||
|
self, mock_builder_class, tmp_path
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Verify that when --embedding-prompt-template is provided to build command,
|
||||||
|
the value is stored in embedding_options dict passed to LeannBuilder.
|
||||||
|
|
||||||
|
This test will fail because the CLI doesn't currently process this argument
|
||||||
|
and add it to embedding_options.
|
||||||
|
"""
|
||||||
|
# Setup mocks
|
||||||
|
mock_builder = Mock()
|
||||||
|
mock_builder_class.return_value = mock_builder
|
||||||
|
|
||||||
|
# Create CLI and run build command
|
||||||
|
cli = LeannCLI()
|
||||||
|
|
||||||
|
# Mock load_documents to return a document so builder is created
|
||||||
|
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||||
|
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
template = "search_query: "
|
||||||
|
args = parser.parse_args(
|
||||||
|
[
|
||||||
|
"build",
|
||||||
|
"test-index",
|
||||||
|
"--docs",
|
||||||
|
str(tmp_path),
|
||||||
|
"--embedding-prompt-template",
|
||||||
|
template,
|
||||||
|
"--force", # Force rebuild to ensure LeannBuilder is called
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the build command
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(cli.build_index(args))
|
||||||
|
|
||||||
|
# Check that LeannBuilder was called with embedding_options containing prompt_template
|
||||||
|
call_kwargs = mock_builder_class.call_args.kwargs
|
||||||
|
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
|
||||||
|
|
||||||
|
embedding_options = call_kwargs["embedding_options"]
|
||||||
|
assert embedding_options is not None, (
|
||||||
|
"embedding_options should not be None when template provided"
|
||||||
|
)
|
||||||
|
assert "prompt_template" in embedding_options, (
|
||||||
|
"embedding_options should contain 'prompt_template' key"
|
||||||
|
)
|
||||||
|
assert embedding_options["prompt_template"] == template, (
|
||||||
|
f"Template should be '{template}', got {embedding_options.get('prompt_template')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("leann.cli.LeannBuilder")
|
||||||
|
def test_prompt_template_not_in_options_when_not_provided(self, mock_builder_class, tmp_path):
|
||||||
|
"""
|
||||||
|
Verify that when --embedding-prompt-template is NOT provided,
|
||||||
|
embedding_options either doesn't have the key or it's None.
|
||||||
|
|
||||||
|
This ensures we don't pass empty/None values unnecessarily.
|
||||||
|
"""
|
||||||
|
# Setup mocks
|
||||||
|
mock_builder = Mock()
|
||||||
|
mock_builder_class.return_value = mock_builder
|
||||||
|
|
||||||
|
cli = LeannCLI()
|
||||||
|
|
||||||
|
# Mock load_documents to return a document so builder is created
|
||||||
|
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||||
|
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
args = parser.parse_args(
|
||||||
|
[
|
||||||
|
"build",
|
||||||
|
"test-index",
|
||||||
|
"--docs",
|
||||||
|
str(tmp_path),
|
||||||
|
"--force", # Force rebuild to ensure LeannBuilder is called
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(cli.build_index(args))
|
||||||
|
|
||||||
|
# Check that if embedding_options is passed, it doesn't have prompt_template
|
||||||
|
call_kwargs = mock_builder_class.call_args.kwargs
|
||||||
|
if call_kwargs.get("embedding_options"):
|
||||||
|
embedding_options = call_kwargs["embedding_options"]
|
||||||
|
# Either the key shouldn't exist, or it should be None
|
||||||
|
assert (
|
||||||
|
"prompt_template" not in embedding_options
|
||||||
|
or embedding_options["prompt_template"] is None
|
||||||
|
), "prompt_template should not be set when flag not provided"
|
||||||
|
|
||||||
|
# R1 Tests: Build-time separate template storage
|
||||||
|
@patch("leann.cli.LeannBuilder")
|
||||||
|
def test_build_stores_separate_templates(self, mock_builder_class, tmp_path):
|
||||||
|
"""
|
||||||
|
R1 Test 1: Verify that when both --embedding-prompt-template and
|
||||||
|
--query-prompt-template are provided to build command, both values
|
||||||
|
are stored separately in embedding_options dict as build_prompt_template
|
||||||
|
and query_prompt_template.
|
||||||
|
|
||||||
|
This test will fail because:
|
||||||
|
1. CLI doesn't accept --query-prompt-template flag yet
|
||||||
|
2. CLI doesn't store templates as separate build_prompt_template and
|
||||||
|
query_prompt_template keys
|
||||||
|
|
||||||
|
Expected behavior after implementation:
|
||||||
|
- .meta.json contains: {"embedding_options": {
|
||||||
|
"build_prompt_template": "doc: ",
|
||||||
|
"query_prompt_template": "query: "
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
# Setup mocks
|
||||||
|
mock_builder = Mock()
|
||||||
|
mock_builder_class.return_value = mock_builder
|
||||||
|
|
||||||
|
cli = LeannCLI()
|
||||||
|
|
||||||
|
# Mock load_documents to return a document so builder is created
|
||||||
|
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||||
|
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
build_template = "doc: "
|
||||||
|
query_template = "query: "
|
||||||
|
args = parser.parse_args(
|
||||||
|
[
|
||||||
|
"build",
|
||||||
|
"test-index",
|
||||||
|
"--docs",
|
||||||
|
str(tmp_path),
|
||||||
|
"--embedding-prompt-template",
|
||||||
|
build_template,
|
||||||
|
"--query-prompt-template",
|
||||||
|
query_template,
|
||||||
|
"--force",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the build command
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(cli.build_index(args))
|
||||||
|
|
||||||
|
# Check that LeannBuilder was called with separate template keys
|
||||||
|
call_kwargs = mock_builder_class.call_args.kwargs
|
||||||
|
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
|
||||||
|
|
||||||
|
embedding_options = call_kwargs["embedding_options"]
|
||||||
|
assert embedding_options is not None, (
|
||||||
|
"embedding_options should not be None when templates provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "build_prompt_template" in embedding_options, (
|
||||||
|
"embedding_options should contain 'build_prompt_template' key"
|
||||||
|
)
|
||||||
|
assert embedding_options["build_prompt_template"] == build_template, (
|
||||||
|
f"build_prompt_template should be '{build_template}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "query_prompt_template" in embedding_options, (
|
||||||
|
"embedding_options should contain 'query_prompt_template' key"
|
||||||
|
)
|
||||||
|
assert embedding_options["query_prompt_template"] == query_template, (
|
||||||
|
f"query_prompt_template should be '{query_template}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Old key should NOT be present when using new separate template format
|
||||||
|
assert "prompt_template" not in embedding_options, (
|
||||||
|
"Old 'prompt_template' key should not be present with separate templates"
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("leann.cli.LeannBuilder")
|
||||||
|
def test_build_backward_compat_single_template(self, mock_builder_class, tmp_path):
|
||||||
|
"""
|
||||||
|
R1 Test 2: Verify backward compatibility - when only
|
||||||
|
--embedding-prompt-template is provided (old behavior), it should
|
||||||
|
still be stored as 'prompt_template' in embedding_options.
|
||||||
|
|
||||||
|
This ensures existing workflows continue to work unchanged.
|
||||||
|
|
||||||
|
This test currently passes because it matches existing behavior, but it
|
||||||
|
documents the requirement that this behavior must be preserved after
|
||||||
|
implementing the separate template feature.
|
||||||
|
|
||||||
|
Expected behavior:
|
||||||
|
- .meta.json contains: {"embedding_options": {"prompt_template": "prompt: "}}
|
||||||
|
- No build_prompt_template or query_prompt_template keys
|
||||||
|
"""
|
||||||
|
# Setup mocks
|
||||||
|
mock_builder = Mock()
|
||||||
|
mock_builder_class.return_value = mock_builder
|
||||||
|
|
||||||
|
cli = LeannCLI()
|
||||||
|
|
||||||
|
# Mock load_documents to return a document so builder is created
|
||||||
|
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||||
|
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
template = "prompt: "
|
||||||
|
args = parser.parse_args(
|
||||||
|
[
|
||||||
|
"build",
|
||||||
|
"test-index",
|
||||||
|
"--docs",
|
||||||
|
str(tmp_path),
|
||||||
|
"--embedding-prompt-template",
|
||||||
|
template,
|
||||||
|
"--force",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the build command
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(cli.build_index(args))
|
||||||
|
|
||||||
|
# Check that LeannBuilder was called with old format
|
||||||
|
call_kwargs = mock_builder_class.call_args.kwargs
|
||||||
|
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
|
||||||
|
|
||||||
|
embedding_options = call_kwargs["embedding_options"]
|
||||||
|
assert embedding_options is not None, (
|
||||||
|
"embedding_options should not be None when template provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "prompt_template" in embedding_options, (
|
||||||
|
"embedding_options should contain old 'prompt_template' key for backward compat"
|
||||||
|
)
|
||||||
|
assert embedding_options["prompt_template"] == template, (
|
||||||
|
f"prompt_template should be '{template}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# New keys should NOT be present in backward compat mode
|
||||||
|
assert "build_prompt_template" not in embedding_options, (
|
||||||
|
"build_prompt_template should not be present with single template flag"
|
||||||
|
)
|
||||||
|
assert "query_prompt_template" not in embedding_options, (
|
||||||
|
"query_prompt_template should not be present with single template flag"
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("leann.cli.LeannBuilder")
|
||||||
|
def test_build_no_templates(self, mock_builder_class, tmp_path):
|
||||||
|
"""
|
||||||
|
R1 Test 3: Verify that when no template flags are provided,
|
||||||
|
embedding_options has no prompt template keys.
|
||||||
|
|
||||||
|
This ensures clean defaults and no unnecessary keys in .meta.json.
|
||||||
|
|
||||||
|
This test currently passes because it matches existing behavior, but it
|
||||||
|
documents the requirement that this behavior must be preserved after
|
||||||
|
implementing the separate template feature.
|
||||||
|
|
||||||
|
Expected behavior:
|
||||||
|
- .meta.json has no prompt_template, build_prompt_template, or
|
||||||
|
query_prompt_template keys (or embedding_options is empty/None)
|
||||||
|
"""
|
||||||
|
# Setup mocks
|
||||||
|
mock_builder = Mock()
|
||||||
|
mock_builder_class.return_value = mock_builder
|
||||||
|
|
||||||
|
cli = LeannCLI()
|
||||||
|
|
||||||
|
# Mock load_documents to return a document so builder is created
|
||||||
|
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||||
|
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
args = parser.parse_args(["build", "test-index", "--docs", str(tmp_path), "--force"])
|
||||||
|
|
||||||
|
# Run the build command
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(cli.build_index(args))
|
||||||
|
|
||||||
|
# Check that no template keys are present
|
||||||
|
call_kwargs = mock_builder_class.call_args.kwargs
|
||||||
|
if call_kwargs.get("embedding_options"):
|
||||||
|
embedding_options = call_kwargs["embedding_options"]
|
||||||
|
|
||||||
|
# None of the template keys should be present
|
||||||
|
assert "prompt_template" not in embedding_options, (
|
||||||
|
"prompt_template should not be present when no flags provided"
|
||||||
|
)
|
||||||
|
assert "build_prompt_template" not in embedding_options, (
|
||||||
|
"build_prompt_template should not be present when no flags provided"
|
||||||
|
)
|
||||||
|
assert "query_prompt_template" not in embedding_options, (
|
||||||
|
"query_prompt_template should not be present when no flags provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptTemplateFlowsToComputeEmbeddings:
|
||||||
|
"""Tests for template flowing through to compute_embeddings function."""
|
||||||
|
|
||||||
|
@patch("leann.api.compute_embeddings")
|
||||||
|
def test_prompt_template_flows_to_compute_embeddings_via_provider_options(
|
||||||
|
self, mock_compute_embeddings, tmp_path
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Verify that the prompt template flows from CLI args through LeannBuilder
|
||||||
|
to compute_embeddings() function via provider_options parameter.
|
||||||
|
|
||||||
|
This is an integration test that verifies the complete flow:
|
||||||
|
CLI → embedding_options → LeannBuilder → compute_embeddings(provider_options)
|
||||||
|
|
||||||
|
This test will fail because:
|
||||||
|
1. CLI doesn't capture the argument yet
|
||||||
|
2. embedding_options doesn't include prompt_template
|
||||||
|
3. LeannBuilder doesn't pass it through to compute_embeddings
|
||||||
|
"""
|
||||||
|
# Mock compute_embeddings to return dummy embeddings as numpy array
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
mock_compute_embeddings.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||||
|
|
||||||
|
# Use real LeannBuilder (not mocked) to test the actual flow
|
||||||
|
cli = LeannCLI()
|
||||||
|
|
||||||
|
# Mock load_documents to return a simple document
|
||||||
|
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||||
|
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
template = "search_document: "
|
||||||
|
args = parser.parse_args(
|
||||||
|
[
|
||||||
|
"build",
|
||||||
|
"test-index",
|
||||||
|
"--docs",
|
||||||
|
str(tmp_path),
|
||||||
|
"--embedding-prompt-template",
|
||||||
|
template,
|
||||||
|
"--backend-name",
|
||||||
|
"hnsw", # Use hnsw backend
|
||||||
|
"--force", # Force rebuild to ensure index is created
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# This should fail because the flow isn't implemented yet
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(cli.build_index(args))
|
||||||
|
|
||||||
|
# Verify compute_embeddings was called with provider_options containing prompt_template
|
||||||
|
assert mock_compute_embeddings.called, "compute_embeddings should have been called"
|
||||||
|
|
||||||
|
# Check the call arguments
|
||||||
|
call_kwargs = mock_compute_embeddings.call_args.kwargs
|
||||||
|
assert "provider_options" in call_kwargs, (
|
||||||
|
"compute_embeddings should receive provider_options parameter"
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_options = call_kwargs["provider_options"]
|
||||||
|
assert provider_options is not None, "provider_options should not be None"
|
||||||
|
assert "prompt_template" in provider_options, (
|
||||||
|
"provider_options should contain prompt_template key"
|
||||||
|
)
|
||||||
|
assert provider_options["prompt_template"] == template, (
|
||||||
|
f"Template should be '{template}', got {provider_options.get('prompt_template')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptTemplateArgumentHelp:
|
||||||
|
"""Tests for argument help text and documentation."""
|
||||||
|
|
||||||
|
def test_build_command_prompt_template_has_help_text(self):
|
||||||
|
"""
|
||||||
|
Verify that --embedding-prompt-template has descriptive help text.
|
||||||
|
|
||||||
|
Good help text is crucial for CLI usability.
|
||||||
|
"""
|
||||||
|
cli = LeannCLI()
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
# Get the build subparser
|
||||||
|
# This is a bit tricky - we need to parse to get the help
|
||||||
|
# We'll check that the help includes relevant keywords
|
||||||
|
import io
|
||||||
|
from contextlib import redirect_stdout
|
||||||
|
|
||||||
|
f = io.StringIO()
|
||||||
|
try:
|
||||||
|
with redirect_stdout(f):
|
||||||
|
parser.parse_args(["build", "--help"])
|
||||||
|
except SystemExit:
|
||||||
|
pass # --help causes sys.exit(0)
|
||||||
|
|
||||||
|
help_text = f.getvalue()
|
||||||
|
assert "--embedding-prompt-template" in help_text, (
|
||||||
|
"Help text should mention --embedding-prompt-template"
|
||||||
|
)
|
||||||
|
# Check for keywords that should be in the help
|
||||||
|
help_lower = help_text.lower()
|
||||||
|
assert any(keyword in help_lower for keyword in ["template", "prompt", "prepend"]), (
|
||||||
|
"Help text should explain what the prompt template does"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_search_command_prompt_template_has_help_text(self):
|
||||||
|
"""
|
||||||
|
Verify that search command also has help text for --embedding-prompt-template.
|
||||||
|
"""
|
||||||
|
cli = LeannCLI()
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
import io
|
||||||
|
from contextlib import redirect_stdout
|
||||||
|
|
||||||
|
f = io.StringIO()
|
||||||
|
try:
|
||||||
|
with redirect_stdout(f):
|
||||||
|
parser.parse_args(["search", "--help"])
|
||||||
|
except SystemExit:
|
||||||
|
pass # --help causes sys.exit(0)
|
||||||
|
|
||||||
|
help_text = f.getvalue()
|
||||||
|
assert "--embedding-prompt-template" in help_text, (
|
||||||
|
"Search help text should mention --embedding-prompt-template"
|
||||||
|
)
|
||||||
281
tests/test_embedding_prompt_template.py
Normal file
281
tests/test_embedding_prompt_template.py
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
"""Unit tests for prompt template prepending in OpenAI embeddings.
|
||||||
|
|
||||||
|
This test suite defines the contract for prompt template functionality that allows
|
||||||
|
users to prepend a consistent prompt to all embedding inputs. These tests verify:
|
||||||
|
|
||||||
|
1. Template prepending to all input texts before embedding computation
|
||||||
|
2. Graceful handling of None/missing provider_options
|
||||||
|
3. Empty string template behavior (no-op)
|
||||||
|
4. Logging of template application for observability
|
||||||
|
5. Template application before token truncation
|
||||||
|
|
||||||
|
All tests are written in Red Phase - they should FAIL initially because the
|
||||||
|
implementation does not exist yet.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from leann.embedding_compute import compute_embeddings_openai
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptTemplatePrepending:
|
||||||
|
"""Tests for prompt template prepending in compute_embeddings_openai."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_openai_client(self):
|
||||||
|
"""Create mock OpenAI client that captures input texts."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
|
||||||
|
# Mock the embeddings.create response
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.data = [
|
||||||
|
Mock(embedding=[0.1, 0.2, 0.3]),
|
||||||
|
Mock(embedding=[0.4, 0.5, 0.6]),
|
||||||
|
]
|
||||||
|
mock_client.embeddings.create.return_value = mock_response
|
||||||
|
|
||||||
|
return mock_client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_openai_module(self, mock_openai_client, monkeypatch):
|
||||||
|
"""Mock the openai module to return our mock client."""
|
||||||
|
# Mock the API key environment variable
|
||||||
|
monkeypatch.setenv("OPENAI_API_KEY", "fake-test-key-for-mocking")
|
||||||
|
|
||||||
|
# openai is imported inside the function, so we need to patch it there
|
||||||
|
with patch("openai.OpenAI", return_value=mock_openai_client) as mock_openai:
|
||||||
|
yield mock_openai
|
||||||
|
|
||||||
|
def test_prompt_template_prepended_to_all_texts(self, mock_openai_module, mock_openai_client):
|
||||||
|
"""Verify template is prepended to all input texts.
|
||||||
|
|
||||||
|
When provider_options contains "prompt_template", that template should
|
||||||
|
be prepended to every text in the input list before sending to OpenAI API.
|
||||||
|
|
||||||
|
This is the core functionality: the template acts as a consistent prefix
|
||||||
|
that provides context or instruction for the embedding model.
|
||||||
|
"""
|
||||||
|
texts = ["First document", "Second document"]
|
||||||
|
template = "search_document: "
|
||||||
|
provider_options = {"prompt_template": template}
|
||||||
|
|
||||||
|
# Call compute_embeddings_openai with provider_options
|
||||||
|
result = compute_embeddings_openai(
|
||||||
|
texts=texts,
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify embeddings.create was called with templated texts
|
||||||
|
mock_openai_client.embeddings.create.assert_called_once()
|
||||||
|
call_args = mock_openai_client.embeddings.create.call_args
|
||||||
|
|
||||||
|
# Extract the input texts sent to API
|
||||||
|
sent_texts = call_args.kwargs["input"]
|
||||||
|
|
||||||
|
# Verify template was prepended to all texts
|
||||||
|
assert len(sent_texts) == 2, "Should send same number of texts"
|
||||||
|
assert sent_texts[0] == "search_document: First document", (
|
||||||
|
"Template should be prepended to first text"
|
||||||
|
)
|
||||||
|
assert sent_texts[1] == "search_document: Second document", (
|
||||||
|
"Template should be prepended to second text"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify result is valid embeddings array
|
||||||
|
assert isinstance(result, np.ndarray)
|
||||||
|
assert result.shape == (2, 3), "Should return correct shape"
|
||||||
|
|
||||||
|
def test_template_not_applied_when_missing_or_empty(
|
||||||
|
self, mock_openai_module, mock_openai_client
|
||||||
|
):
|
||||||
|
"""Verify template not applied when provider_options is None, missing key, or empty string.
|
||||||
|
|
||||||
|
This consolidated test covers three scenarios where templates should NOT be applied:
|
||||||
|
1. provider_options is None (default behavior)
|
||||||
|
2. provider_options exists but missing 'prompt_template' key
|
||||||
|
3. prompt_template is explicitly set to empty string ""
|
||||||
|
|
||||||
|
In all cases, texts should be sent to the API unchanged.
|
||||||
|
"""
|
||||||
|
# Scenario 1: None provider_options
|
||||||
|
texts = ["Original text one", "Original text two"]
|
||||||
|
result = compute_embeddings_openai(
|
||||||
|
texts=texts,
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
provider_options=None,
|
||||||
|
)
|
||||||
|
call_args = mock_openai_client.embeddings.create.call_args
|
||||||
|
sent_texts = call_args.kwargs["input"]
|
||||||
|
assert sent_texts[0] == "Original text one", (
|
||||||
|
"Text should be unchanged with None provider_options"
|
||||||
|
)
|
||||||
|
assert sent_texts[1] == "Original text two"
|
||||||
|
assert isinstance(result, np.ndarray)
|
||||||
|
assert result.shape == (2, 3)
|
||||||
|
|
||||||
|
# Reset mock for next scenario
|
||||||
|
mock_openai_client.reset_mock()
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.data = [
|
||||||
|
Mock(embedding=[0.1, 0.2, 0.3]),
|
||||||
|
Mock(embedding=[0.4, 0.5, 0.6]),
|
||||||
|
]
|
||||||
|
mock_openai_client.embeddings.create.return_value = mock_response
|
||||||
|
|
||||||
|
# Scenario 2: Missing 'prompt_template' key
|
||||||
|
texts = ["Text without template", "Another text"]
|
||||||
|
provider_options = {"base_url": "https://api.openai.com/v1"}
|
||||||
|
result = compute_embeddings_openai(
|
||||||
|
texts=texts,
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
call_args = mock_openai_client.embeddings.create.call_args
|
||||||
|
sent_texts = call_args.kwargs["input"]
|
||||||
|
assert sent_texts[0] == "Text without template", "Text should be unchanged with missing key"
|
||||||
|
assert sent_texts[1] == "Another text"
|
||||||
|
assert isinstance(result, np.ndarray)
|
||||||
|
|
||||||
|
# Reset mock for next scenario
|
||||||
|
mock_openai_client.reset_mock()
|
||||||
|
mock_openai_client.embeddings.create.return_value = mock_response
|
||||||
|
|
||||||
|
# Scenario 3: Empty string template
|
||||||
|
texts = ["Text one", "Text two"]
|
||||||
|
provider_options = {"prompt_template": ""}
|
||||||
|
result = compute_embeddings_openai(
|
||||||
|
texts=texts,
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
call_args = mock_openai_client.embeddings.create.call_args
|
||||||
|
sent_texts = call_args.kwargs["input"]
|
||||||
|
assert sent_texts[0] == "Text one", "Empty template should not modify text"
|
||||||
|
assert sent_texts[1] == "Text two"
|
||||||
|
assert isinstance(result, np.ndarray)
|
||||||
|
|
||||||
|
def test_prompt_template_with_multiple_batches(self, mock_openai_module, mock_openai_client):
|
||||||
|
"""Verify template is prepended in all batches when texts exceed batch size.
|
||||||
|
|
||||||
|
OpenAI API has batch size limits. When input texts are split into
|
||||||
|
multiple batches, the template should be prepended to texts in every batch.
|
||||||
|
|
||||||
|
This ensures consistency across all API calls.
|
||||||
|
"""
|
||||||
|
# Create many texts that will be split into multiple batches
|
||||||
|
texts = [f"Document {i}" for i in range(1000)]
|
||||||
|
template = "passage: "
|
||||||
|
provider_options = {"prompt_template": template}
|
||||||
|
|
||||||
|
# Mock multiple batch responses
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.data = [Mock(embedding=[0.1, 0.2, 0.3]) for _ in range(1000)]
|
||||||
|
mock_openai_client.embeddings.create.return_value = mock_response
|
||||||
|
|
||||||
|
result = compute_embeddings_openai(
|
||||||
|
texts=texts,
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify embeddings.create was called multiple times (batching)
|
||||||
|
assert mock_openai_client.embeddings.create.call_count >= 2, (
|
||||||
|
"Should make multiple API calls for large text list"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify template was prepended in ALL batches
|
||||||
|
for call in mock_openai_client.embeddings.create.call_args_list:
|
||||||
|
sent_texts = call.kwargs["input"]
|
||||||
|
for text in sent_texts:
|
||||||
|
assert text.startswith(template), (
|
||||||
|
f"All texts in all batches should start with template. Got: {text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify result shape
|
||||||
|
assert result.shape[0] == 1000, "Should return embeddings for all texts"
|
||||||
|
|
||||||
|
def test_prompt_template_with_special_characters(self, mock_openai_module, mock_openai_client):
|
||||||
|
"""Verify template with special characters is handled correctly.
|
||||||
|
|
||||||
|
Templates may contain special characters, Unicode, newlines, etc.
|
||||||
|
These should all be prepended correctly without encoding issues.
|
||||||
|
"""
|
||||||
|
texts = ["Document content"]
|
||||||
|
# Template with various special characters
|
||||||
|
template = "🔍 Search query [EN]: "
|
||||||
|
provider_options = {"prompt_template": template}
|
||||||
|
|
||||||
|
result = compute_embeddings_openai(
|
||||||
|
texts=texts,
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify special characters in template were preserved
|
||||||
|
call_args = mock_openai_client.embeddings.create.call_args
|
||||||
|
sent_texts = call_args.kwargs["input"]
|
||||||
|
|
||||||
|
assert sent_texts[0] == "🔍 Search query [EN]: Document content", (
|
||||||
|
"Special characters in template should be preserved"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, np.ndarray)
|
||||||
|
|
||||||
|
def test_prompt_template_integration_with_existing_validation(
|
||||||
|
self, mock_openai_module, mock_openai_client
|
||||||
|
):
|
||||||
|
"""Verify template works with existing input validation.
|
||||||
|
|
||||||
|
compute_embeddings_openai has validation for empty texts and whitespace.
|
||||||
|
Template prepending should happen AFTER validation, so validation errors
|
||||||
|
are thrown based on original texts, not templated texts.
|
||||||
|
|
||||||
|
This ensures users get clear error messages about their input.
|
||||||
|
"""
|
||||||
|
# Empty text should still raise ValueError even with template
|
||||||
|
texts = [""]
|
||||||
|
provider_options = {"prompt_template": "prefix: "}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="empty/invalid"):
|
||||||
|
compute_embeddings_openai(
|
||||||
|
texts=texts,
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_prompt_template_with_api_key_and_base_url(
|
||||||
|
self, mock_openai_module, mock_openai_client
|
||||||
|
):
|
||||||
|
"""Verify template works alongside other provider_options.
|
||||||
|
|
||||||
|
provider_options may contain multiple settings: prompt_template,
|
||||||
|
base_url, api_key. All should work together correctly.
|
||||||
|
"""
|
||||||
|
texts = ["Test document"]
|
||||||
|
provider_options = {
|
||||||
|
"prompt_template": "embed: ",
|
||||||
|
"base_url": "https://custom.api.com/v1",
|
||||||
|
"api_key": "test-key-123",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = compute_embeddings_openai(
|
||||||
|
texts=texts,
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify template was applied
|
||||||
|
call_args = mock_openai_client.embeddings.create.call_args
|
||||||
|
sent_texts = call_args.kwargs["input"]
|
||||||
|
assert sent_texts[0] == "embed: Test document"
|
||||||
|
|
||||||
|
# Verify OpenAI client was created with correct base_url
|
||||||
|
mock_openai_module.assert_called()
|
||||||
|
client_init_kwargs = mock_openai_module.call_args.kwargs
|
||||||
|
assert client_init_kwargs["base_url"] == "https://custom.api.com/v1"
|
||||||
|
assert client_init_kwargs["api_key"] == "test-key-123"
|
||||||
|
|
||||||
|
assert isinstance(result, np.ndarray)
|
||||||
315
tests/test_lmstudio_bridge.py
Normal file
315
tests/test_lmstudio_bridge.py
Normal file
@@ -0,0 +1,315 @@
|
|||||||
|
"""Unit tests for LM Studio TypeScript SDK bridge functionality.
|
||||||
|
|
||||||
|
This test suite defines the contract for the LM Studio SDK bridge that queries
|
||||||
|
model context length via Node.js subprocess. These tests verify:
|
||||||
|
|
||||||
|
1. Successful SDK query returns context length
|
||||||
|
2. Graceful fallback when Node.js not installed (FileNotFoundError)
|
||||||
|
3. Graceful fallback when SDK not installed (npm error)
|
||||||
|
4. Timeout handling (subprocess.TimeoutExpired)
|
||||||
|
5. Invalid JSON response handling
|
||||||
|
|
||||||
|
All tests are written in Red Phase - they should FAIL initially because the
|
||||||
|
`_query_lmstudio_context_limit` function does not exist yet.
|
||||||
|
|
||||||
|
The function contract:
|
||||||
|
- Inputs: model_name (str), base_url (str, WebSocket format "ws://localhost:1234")
|
||||||
|
- Outputs: context_length (int) or None on error
|
||||||
|
- Requirements:
|
||||||
|
1. Call Node.js with inline JavaScript using @lmstudio/sdk
|
||||||
|
2. 10-second timeout (accounts for Node.js startup)
|
||||||
|
3. Graceful fallback on any error (returns None, doesn't raise)
|
||||||
|
4. Parse JSON response with contextLength field
|
||||||
|
5. Log errors at debug level (not warning/error)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Try to import the function - if it doesn't exist, tests will fail as expected
|
||||||
|
try:
|
||||||
|
from leann.embedding_compute import _query_lmstudio_context_limit
|
||||||
|
except ImportError:
|
||||||
|
# Function doesn't exist yet (Red Phase) - create a placeholder that will fail
|
||||||
|
def _query_lmstudio_context_limit(*args, **kwargs):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"_query_lmstudio_context_limit not implemented yet - this is the Red Phase"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLMStudioBridge:
|
||||||
|
"""Tests for LM Studio TypeScript SDK bridge integration."""
|
||||||
|
|
||||||
|
def test_query_lmstudio_success(self, monkeypatch):
|
||||||
|
"""Verify successful SDK query returns context length.
|
||||||
|
|
||||||
|
When the Node.js subprocess successfully queries the LM Studio SDK,
|
||||||
|
it should return a JSON response with contextLength field. The function
|
||||||
|
should parse this and return the integer context length.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
# Verify timeout is set to 10 seconds
|
||||||
|
assert kwargs.get("timeout") == 10, "Should use 10-second timeout for Node.js startup"
|
||||||
|
|
||||||
|
# Verify capture_output and text=True are set
|
||||||
|
assert kwargs.get("capture_output") is True, "Should capture stdout/stderr"
|
||||||
|
assert kwargs.get("text") is True, "Should decode output as text"
|
||||||
|
|
||||||
|
# Return successful JSON response
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.returncode = 0
|
||||||
|
mock_result.stdout = '{"contextLength": 8192, "identifier": "custom-model"}'
|
||||||
|
mock_result.stderr = ""
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
# Test with typical LM Studio model
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="custom-model", base_url="ws://localhost:1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit == 8192, "Should return context length from SDK response"
|
||||||
|
|
||||||
|
def test_query_lmstudio_nodejs_not_found(self, monkeypatch):
|
||||||
|
"""Verify graceful fallback when Node.js not installed.
|
||||||
|
|
||||||
|
When Node.js is not installed, subprocess.run will raise FileNotFoundError.
|
||||||
|
The function should catch this and return None (graceful fallback to registry).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
raise FileNotFoundError("node: command not found")
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="custom-model", base_url="ws://localhost:1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit is None, "Should return None when Node.js not installed"
|
||||||
|
|
||||||
|
def test_query_lmstudio_sdk_not_installed(self, monkeypatch):
|
||||||
|
"""Verify graceful fallback when @lmstudio/sdk not installed.
|
||||||
|
|
||||||
|
When the SDK npm package is not installed, Node.js will return non-zero
|
||||||
|
exit code with error message in stderr. The function should detect this
|
||||||
|
and return None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.returncode = 1
|
||||||
|
mock_result.stdout = ""
|
||||||
|
mock_result.stderr = (
|
||||||
|
"Error: Cannot find module '@lmstudio/sdk'\nRequire stack:\n- /path/to/script.js"
|
||||||
|
)
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="custom-model", base_url="ws://localhost:1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit is None, "Should return None when SDK not installed"
|
||||||
|
|
||||||
|
def test_query_lmstudio_timeout(self, monkeypatch):
|
||||||
|
"""Verify graceful fallback when subprocess times out.
|
||||||
|
|
||||||
|
When the Node.js process takes longer than 10 seconds (e.g., LM Studio
|
||||||
|
not responding), subprocess.TimeoutExpired should be raised. The function
|
||||||
|
should catch this and return None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
raise subprocess.TimeoutExpired(cmd=["node", "lmstudio_bridge.js"], timeout=10)
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="custom-model", base_url="ws://localhost:1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit is None, "Should return None on timeout"
|
||||||
|
|
||||||
|
def test_query_lmstudio_invalid_json(self, monkeypatch):
|
||||||
|
"""Verify graceful fallback when response is invalid JSON.
|
||||||
|
|
||||||
|
When the subprocess returns malformed JSON (e.g., due to SDK error),
|
||||||
|
json.loads will raise ValueError/JSONDecodeError. The function should
|
||||||
|
catch this and return None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.returncode = 0
|
||||||
|
mock_result.stdout = "This is not valid JSON"
|
||||||
|
mock_result.stderr = ""
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="custom-model", base_url="ws://localhost:1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit is None, "Should return None when JSON parsing fails"
|
||||||
|
|
||||||
|
def test_query_lmstudio_missing_context_length_field(self, monkeypatch):
|
||||||
|
"""Verify graceful fallback when JSON lacks contextLength field.
|
||||||
|
|
||||||
|
When the SDK returns valid JSON but without the expected contextLength
|
||||||
|
field (e.g., error response), the function should return None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.returncode = 0
|
||||||
|
mock_result.stdout = '{"identifier": "test-model", "error": "Model not found"}'
|
||||||
|
mock_result.stderr = ""
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="nonexistent-model", base_url="ws://localhost:1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit is None, "Should return None when contextLength field missing"
|
||||||
|
|
||||||
|
def test_query_lmstudio_null_context_length(self, monkeypatch):
|
||||||
|
"""Verify graceful fallback when contextLength is null.
|
||||||
|
|
||||||
|
When the SDK returns contextLength: null (model couldn't be loaded),
|
||||||
|
the function should return None for registry fallback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.returncode = 0
|
||||||
|
mock_result.stdout = '{"contextLength": null, "identifier": "test-model"}'
|
||||||
|
mock_result.stderr = ""
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="test-model", base_url="ws://localhost:1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit is None, "Should return None when contextLength is null"
|
||||||
|
|
||||||
|
def test_query_lmstudio_zero_context_length(self, monkeypatch):
|
||||||
|
"""Verify graceful fallback when contextLength is zero.
|
||||||
|
|
||||||
|
When the SDK returns contextLength: 0 (invalid value), the function
|
||||||
|
should return None to trigger registry fallback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.returncode = 0
|
||||||
|
mock_result.stdout = '{"contextLength": 0, "identifier": "test-model"}'
|
||||||
|
mock_result.stderr = ""
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="test-model", base_url="ws://localhost:1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit is None, "Should return None when contextLength is zero"
|
||||||
|
|
||||||
|
def test_query_lmstudio_with_custom_port(self, monkeypatch):
|
||||||
|
"""Verify SDK query works with non-default WebSocket port.
|
||||||
|
|
||||||
|
LM Studio can run on custom ports. The function should pass the
|
||||||
|
provided base_url to the Node.js subprocess.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
# Verify the base_url argument is passed correctly
|
||||||
|
command = args[0] if args else kwargs.get("args", [])
|
||||||
|
assert "ws://localhost:8080" in " ".join(command), (
|
||||||
|
"Should pass custom port to subprocess"
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.returncode = 0
|
||||||
|
mock_result.stdout = '{"contextLength": 4096, "identifier": "custom-model"}'
|
||||||
|
mock_result.stderr = ""
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="custom-model", base_url="ws://localhost:8080"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit == 4096, "Should work with custom WebSocket port"
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"context_length,expected",
|
||||||
|
[
|
||||||
|
(512, 512), # Small context
|
||||||
|
(2048, 2048), # Common context
|
||||||
|
(8192, 8192), # Large context
|
||||||
|
(32768, 32768), # Very large context
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_query_lmstudio_various_context_lengths(self, monkeypatch, context_length, expected):
|
||||||
|
"""Verify SDK query handles various context length values.
|
||||||
|
|
||||||
|
Different models have different context lengths. The function should
|
||||||
|
correctly parse and return any positive integer value.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.returncode = 0
|
||||||
|
mock_result.stdout = f'{{"contextLength": {context_length}, "identifier": "test"}}'
|
||||||
|
mock_result.stderr = ""
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
limit = _query_lmstudio_context_limit(
|
||||||
|
model_name="test-model", base_url="ws://localhost:1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit == expected, f"Should return {expected} for context length {context_length}"
|
||||||
|
|
||||||
|
def test_query_lmstudio_logs_at_debug_level(self, monkeypatch, caplog):
|
||||||
|
"""Verify errors are logged at DEBUG level, not WARNING/ERROR.
|
||||||
|
|
||||||
|
Following the graceful fallback pattern from Ollama implementation,
|
||||||
|
errors should be logged at debug level to avoid alarming users when
|
||||||
|
fallback to registry works fine.
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
caplog.set_level(logging.DEBUG, logger="leann.embedding_compute")
|
||||||
|
|
||||||
|
def mock_run(*args, **kwargs):
|
||||||
|
raise FileNotFoundError("node: command not found")
|
||||||
|
|
||||||
|
monkeypatch.setattr("subprocess.run", mock_run)
|
||||||
|
|
||||||
|
_query_lmstudio_context_limit(model_name="test-model", base_url="ws://localhost:1234")
|
||||||
|
|
||||||
|
# Check that debug logging occurred (not warning/error)
|
||||||
|
debug_logs = [record for record in caplog.records if record.levelname == "DEBUG"]
|
||||||
|
assert len(debug_logs) > 0, "Should log error at DEBUG level"
|
||||||
|
|
||||||
|
# Verify no WARNING or ERROR logs
|
||||||
|
warning_or_error_logs = [
|
||||||
|
record for record in caplog.records if record.levelname in ["WARNING", "ERROR"]
|
||||||
|
]
|
||||||
|
assert len(warning_or_error_logs) == 0, (
|
||||||
|
"Should not log at WARNING/ERROR level for expected failures"
|
||||||
|
)
|
||||||
400
tests/test_prompt_template_e2e.py
Normal file
400
tests/test_prompt_template_e2e.py
Normal file
@@ -0,0 +1,400 @@
|
|||||||
|
"""End-to-end integration tests for prompt template and token limit features.
|
||||||
|
|
||||||
|
These tests verify real-world functionality with live services:
|
||||||
|
- OpenAI-compatible APIs (OpenAI, LM Studio) with prompt template support
|
||||||
|
- Ollama with dynamic token limit detection
|
||||||
|
- Hybrid token limit discovery mechanism
|
||||||
|
|
||||||
|
Run with: pytest tests/test_prompt_template_e2e.py -v -s
|
||||||
|
Skip if services unavailable: pytest tests/test_prompt_template_e2e.py -m "not integration"
|
||||||
|
|
||||||
|
Prerequisites:
|
||||||
|
1. LM Studio running with embedding model: http://localhost:1234
|
||||||
|
2. [Optional] Ollama running: ollama serve
|
||||||
|
3. [Optional] Ollama model: ollama pull nomic-embed-text
|
||||||
|
4. [Optional] Node.js + @lmstudio/sdk for context length detection
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import socket
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
from leann.embedding_compute import (
|
||||||
|
compute_embeddings_ollama,
|
||||||
|
compute_embeddings_openai,
|
||||||
|
get_model_token_limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test markers for conditional execution
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def check_service_available(host: str, port: int, timeout: float = 2.0) -> bool:
|
||||||
|
"""Check if a service is available on the given host:port."""
|
||||||
|
try:
|
||||||
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
sock.settimeout(timeout)
|
||||||
|
result = sock.connect_ex((host, port))
|
||||||
|
sock.close()
|
||||||
|
return result == 0
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def check_ollama_available() -> bool:
|
||||||
|
"""Check if Ollama service is available."""
|
||||||
|
if not check_service_available("localhost", 11434):
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
|
||||||
|
return response.status_code == 200
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def check_lmstudio_available() -> bool:
|
||||||
|
"""Check if LM Studio service is available."""
|
||||||
|
if not check_service_available("localhost", 1234):
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
response = requests.get("http://localhost:1234/v1/models", timeout=2.0)
|
||||||
|
return response.status_code == 200
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_lmstudio_first_model() -> str:
|
||||||
|
"""Get the first available model from LM Studio."""
|
||||||
|
try:
|
||||||
|
response = requests.get("http://localhost:1234/v1/models", timeout=5.0)
|
||||||
|
data = response.json()
|
||||||
|
models = data.get("data", [])
|
||||||
|
if models:
|
||||||
|
return models[0]["id"]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptTemplateOpenAI:
|
||||||
|
"""End-to-end tests for prompt template with OpenAI-compatible APIs (LM Studio)."""
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not check_lmstudio_available(), reason="LM Studio service not available on localhost:1234"
|
||||||
|
)
|
||||||
|
def test_lmstudio_embedding_with_prompt_template(self):
|
||||||
|
"""Test prompt templates with LM Studio using OpenAI-compatible API."""
|
||||||
|
model_name = get_lmstudio_first_model()
|
||||||
|
if not model_name:
|
||||||
|
pytest.skip("No models loaded in LM Studio")
|
||||||
|
|
||||||
|
texts = ["artificial intelligence", "machine learning"]
|
||||||
|
prompt_template = "search_query: "
|
||||||
|
|
||||||
|
# Get embeddings with prompt template via provider_options
|
||||||
|
provider_options = {"prompt_template": prompt_template}
|
||||||
|
embeddings = compute_embeddings_openai(
|
||||||
|
texts=texts,
|
||||||
|
model_name=model_name,
|
||||||
|
base_url="http://localhost:1234/v1",
|
||||||
|
api_key="lm-studio", # LM Studio doesn't require real key
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert embeddings is not None
|
||||||
|
assert len(embeddings) == 2
|
||||||
|
assert all(isinstance(emb, np.ndarray) for emb in embeddings)
|
||||||
|
assert all(len(emb) > 0 for emb in embeddings)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"✓ LM Studio embeddings with prompt template: {len(embeddings)} vectors, {len(embeddings[0])} dimensions"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
|
||||||
|
def test_lmstudio_prompt_template_affects_embeddings(self):
|
||||||
|
"""Verify that prompt templates actually change embedding values."""
|
||||||
|
model_name = get_lmstudio_first_model()
|
||||||
|
if not model_name:
|
||||||
|
pytest.skip("No models loaded in LM Studio")
|
||||||
|
|
||||||
|
text = "machine learning"
|
||||||
|
base_url = "http://localhost:1234/v1"
|
||||||
|
api_key = "lm-studio"
|
||||||
|
|
||||||
|
# Get embeddings without template
|
||||||
|
embeddings_no_template = compute_embeddings_openai(
|
||||||
|
texts=[text],
|
||||||
|
model_name=model_name,
|
||||||
|
base_url=base_url,
|
||||||
|
api_key=api_key,
|
||||||
|
provider_options={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get embeddings with template
|
||||||
|
embeddings_with_template = compute_embeddings_openai(
|
||||||
|
texts=[text],
|
||||||
|
model_name=model_name,
|
||||||
|
base_url=base_url,
|
||||||
|
api_key=api_key,
|
||||||
|
provider_options={"prompt_template": "search_query: "},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Embeddings should be different when template is applied
|
||||||
|
assert not np.allclose(embeddings_no_template[0], embeddings_with_template[0])
|
||||||
|
|
||||||
|
logger.info("✓ Prompt template changes embedding values as expected")
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptTemplateOllama:
|
||||||
|
"""End-to-end tests for prompt template with Ollama."""
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not check_ollama_available(), reason="Ollama service not available on localhost:11434"
|
||||||
|
)
|
||||||
|
def test_ollama_embedding_with_prompt_template(self):
|
||||||
|
"""Test prompt templates with Ollama using any available embedding model."""
|
||||||
|
# Get any available embedding model
|
||||||
|
try:
|
||||||
|
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
|
||||||
|
models = response.json().get("models", [])
|
||||||
|
|
||||||
|
embedding_models = []
|
||||||
|
for model in models:
|
||||||
|
name = model["name"]
|
||||||
|
base_name = name.split(":")[0]
|
||||||
|
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
|
||||||
|
embedding_models.append(name)
|
||||||
|
|
||||||
|
if not embedding_models:
|
||||||
|
pytest.skip("No embedding models available in Ollama")
|
||||||
|
|
||||||
|
model_name = embedding_models[0]
|
||||||
|
|
||||||
|
texts = ["artificial intelligence", "machine learning"]
|
||||||
|
prompt_template = "search_query: "
|
||||||
|
|
||||||
|
# Get embeddings with prompt template via provider_options
|
||||||
|
provider_options = {"prompt_template": prompt_template}
|
||||||
|
embeddings = compute_embeddings_ollama(
|
||||||
|
texts=texts,
|
||||||
|
model_name=model_name,
|
||||||
|
is_build=False,
|
||||||
|
host="http://localhost:11434",
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert embeddings is not None
|
||||||
|
assert len(embeddings) == 2
|
||||||
|
assert all(isinstance(emb, np.ndarray) for emb in embeddings)
|
||||||
|
assert all(len(emb) > 0 for emb in embeddings)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"✓ Ollama embeddings with prompt template: {len(embeddings)} vectors, {len(embeddings[0])} dimensions"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
pytest.skip(f"Could not test Ollama prompt template: {e}")
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
|
||||||
|
def test_ollama_prompt_template_affects_embeddings(self):
|
||||||
|
"""Verify that prompt templates actually change embedding values with Ollama."""
|
||||||
|
# Get any available embedding model
|
||||||
|
try:
|
||||||
|
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
|
||||||
|
models = response.json().get("models", [])
|
||||||
|
|
||||||
|
embedding_models = []
|
||||||
|
for model in models:
|
||||||
|
name = model["name"]
|
||||||
|
base_name = name.split(":")[0]
|
||||||
|
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
|
||||||
|
embedding_models.append(name)
|
||||||
|
|
||||||
|
if not embedding_models:
|
||||||
|
pytest.skip("No embedding models available in Ollama")
|
||||||
|
|
||||||
|
model_name = embedding_models[0]
|
||||||
|
text = "machine learning"
|
||||||
|
host = "http://localhost:11434"
|
||||||
|
|
||||||
|
# Get embeddings without template
|
||||||
|
embeddings_no_template = compute_embeddings_ollama(
|
||||||
|
texts=[text], model_name=model_name, is_build=False, host=host, provider_options={}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get embeddings with template
|
||||||
|
embeddings_with_template = compute_embeddings_ollama(
|
||||||
|
texts=[text],
|
||||||
|
model_name=model_name,
|
||||||
|
is_build=False,
|
||||||
|
host=host,
|
||||||
|
provider_options={"prompt_template": "search_query: "},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Embeddings should be different when template is applied
|
||||||
|
assert not np.allclose(embeddings_no_template[0], embeddings_with_template[0])
|
||||||
|
|
||||||
|
logger.info("✓ Ollama prompt template changes embedding values as expected")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
pytest.skip(f"Could not test Ollama prompt template: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class TestLMStudioSDK:
|
||||||
|
"""End-to-end tests for LM Studio SDK integration."""
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
|
||||||
|
def test_lmstudio_model_listing(self):
|
||||||
|
"""Test that we can list models from LM Studio."""
|
||||||
|
try:
|
||||||
|
response = requests.get("http://localhost:1234/v1/models", timeout=5.0)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
assert "data" in data
|
||||||
|
|
||||||
|
models = data["data"]
|
||||||
|
logger.info(f"✓ LM Studio models available: {len(models)}")
|
||||||
|
|
||||||
|
if models:
|
||||||
|
logger.info(f" First model: {models[0].get('id', 'unknown')}")
|
||||||
|
except Exception as e:
|
||||||
|
pytest.skip(f"LM Studio API error: {e}")
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
|
||||||
|
def test_lmstudio_sdk_context_length_detection(self):
|
||||||
|
"""Test context length detection via LM Studio SDK bridge (requires Node.js + SDK)."""
|
||||||
|
model_name = get_lmstudio_first_model()
|
||||||
|
if not model_name:
|
||||||
|
pytest.skip("No models loaded in LM Studio")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from leann.embedding_compute import _query_lmstudio_context_limit
|
||||||
|
|
||||||
|
# SDK requires WebSocket URL (ws://)
|
||||||
|
context_length = _query_lmstudio_context_limit(
|
||||||
|
model_name=model_name, base_url="ws://localhost:1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
if context_length is None:
|
||||||
|
logger.warning(
|
||||||
|
"⚠ LM Studio SDK bridge returned None (Node.js or SDK may not be available)"
|
||||||
|
)
|
||||||
|
pytest.skip("Node.js or @lmstudio/sdk not available - SDK bridge unavailable")
|
||||||
|
else:
|
||||||
|
assert context_length > 0
|
||||||
|
logger.info(
|
||||||
|
f"✓ LM Studio context length detected via SDK: {context_length} for {model_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("_query_lmstudio_context_limit not implemented yet")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"LM Studio SDK test error: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class TestOllamaTokenLimit:
|
||||||
|
"""End-to-end tests for Ollama token limit discovery."""
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
|
||||||
|
def test_ollama_token_limit_detection(self):
|
||||||
|
"""Test dynamic token limit detection from Ollama /api/show endpoint."""
|
||||||
|
# Get any available embedding model
|
||||||
|
try:
|
||||||
|
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
|
||||||
|
models = response.json().get("models", [])
|
||||||
|
|
||||||
|
embedding_models = []
|
||||||
|
for model in models:
|
||||||
|
name = model["name"]
|
||||||
|
base_name = name.split(":")[0]
|
||||||
|
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
|
||||||
|
embedding_models.append(name)
|
||||||
|
|
||||||
|
if not embedding_models:
|
||||||
|
pytest.skip("No embedding models available in Ollama")
|
||||||
|
|
||||||
|
test_model = embedding_models[0]
|
||||||
|
|
||||||
|
# Test token limit detection
|
||||||
|
limit = get_model_token_limit(model_name=test_model, base_url="http://localhost:11434")
|
||||||
|
|
||||||
|
assert limit > 0
|
||||||
|
logger.info(f"✓ Ollama token limit detected: {limit} for {test_model}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
pytest.skip(f"Could not test Ollama token detection: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class TestHybridTokenLimit:
|
||||||
|
"""End-to-end tests for hybrid token limit discovery mechanism."""
|
||||||
|
|
||||||
|
def test_hybrid_discovery_registry_fallback(self):
|
||||||
|
"""Test fallback to static registry for known OpenAI models."""
|
||||||
|
# Use a known OpenAI model (should be in registry)
|
||||||
|
limit = get_model_token_limit(
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
base_url="http://fake-server:9999", # Fake URL to force registry lookup
|
||||||
|
)
|
||||||
|
|
||||||
|
# text-embedding-3-small should have 8192 in registry
|
||||||
|
assert limit == 8192
|
||||||
|
logger.info(f"✓ Hybrid discovery (registry fallback): {limit} tokens")
|
||||||
|
|
||||||
|
def test_hybrid_discovery_default_fallback(self):
|
||||||
|
"""Test fallback to safe default for completely unknown models."""
|
||||||
|
limit = get_model_token_limit(
|
||||||
|
model_name="completely-unknown-model-xyz-12345",
|
||||||
|
base_url="http://fake-server:9999",
|
||||||
|
default=512,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should get the specified default
|
||||||
|
assert limit == 512
|
||||||
|
logger.info(f"✓ Hybrid discovery (default fallback): {limit} tokens")
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
|
||||||
|
def test_hybrid_discovery_ollama_dynamic_first(self):
|
||||||
|
"""Test that Ollama models use dynamic discovery first."""
|
||||||
|
# Get any available embedding model
|
||||||
|
try:
|
||||||
|
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
|
||||||
|
models = response.json().get("models", [])
|
||||||
|
|
||||||
|
embedding_models = []
|
||||||
|
for model in models:
|
||||||
|
name = model["name"]
|
||||||
|
base_name = name.split(":")[0]
|
||||||
|
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
|
||||||
|
embedding_models.append(name)
|
||||||
|
|
||||||
|
if not embedding_models:
|
||||||
|
pytest.skip("No embedding models available in Ollama")
|
||||||
|
|
||||||
|
test_model = embedding_models[0]
|
||||||
|
|
||||||
|
# Should query Ollama /api/show dynamically
|
||||||
|
limit = get_model_token_limit(model_name=test_model, base_url="http://localhost:11434")
|
||||||
|
|
||||||
|
assert limit > 0
|
||||||
|
logger.info(f"✓ Hybrid discovery (Ollama dynamic): {limit} tokens for {test_model}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
pytest.skip(f"Could not test hybrid Ollama discovery: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("\n" + "=" * 70)
|
||||||
|
print("INTEGRATION TEST SUITE - Real Service Testing")
|
||||||
|
print("=" * 70)
|
||||||
|
print("\nThese tests require live services:")
|
||||||
|
print(" • LM Studio: http://localhost:1234 (with embedding model loaded)")
|
||||||
|
print(" • [Optional] Ollama: http://localhost:11434")
|
||||||
|
print(" • [Optional] Node.js + @lmstudio/sdk for SDK bridge tests")
|
||||||
|
print("\nRun with: pytest tests/test_prompt_template_e2e.py -v -s")
|
||||||
|
print("=" * 70 + "\n")
|
||||||
808
tests/test_prompt_template_persistence.py
Normal file
808
tests/test_prompt_template_persistence.py
Normal file
@@ -0,0 +1,808 @@
|
|||||||
|
"""
|
||||||
|
Integration tests for prompt template metadata persistence and reuse.
|
||||||
|
|
||||||
|
These tests verify the complete lifecycle of prompt template persistence:
|
||||||
|
1. Template is saved to .meta.json during index build
|
||||||
|
2. Template is automatically loaded during search operations
|
||||||
|
3. Template can be overridden with explicit flag during search
|
||||||
|
4. Template is reused during chat/ask operations
|
||||||
|
|
||||||
|
These are integration tests that:
|
||||||
|
- Use real file system with temporary directories
|
||||||
|
- Run actual build and search operations
|
||||||
|
- Inspect .meta.json file contents directly
|
||||||
|
- Mock embedding servers to avoid external dependencies
|
||||||
|
- Use small test codebases for fast execution
|
||||||
|
|
||||||
|
Expected to FAIL in Red Phase because metadata persistence verification is not yet implemented.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptTemplateMetadataPersistence:
|
||||||
|
"""Tests for prompt template storage in .meta.json during build."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_index_dir(self):
|
||||||
|
"""Create temporary directory for test indexes."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
yield Path(tmpdir)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_embeddings(self):
|
||||||
|
"""Mock compute_embeddings to return dummy embeddings."""
|
||||||
|
with patch("leann.api.compute_embeddings") as mock_compute:
|
||||||
|
# Return dummy embeddings as numpy array
|
||||||
|
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||||
|
yield mock_compute
|
||||||
|
|
||||||
|
def test_prompt_template_saved_to_metadata(self, temp_index_dir, mock_embeddings):
|
||||||
|
"""
|
||||||
|
Verify that when build is run with embedding_options containing prompt_template,
|
||||||
|
the template value is saved to .meta.json file.
|
||||||
|
|
||||||
|
This is the core persistence requirement - templates must be saved to allow
|
||||||
|
reuse in subsequent search operations without re-specifying the flag.
|
||||||
|
|
||||||
|
Expected failure: .meta.json exists but doesn't contain embedding_options
|
||||||
|
with prompt_template, or the value is not persisted correctly.
|
||||||
|
"""
|
||||||
|
# Setup test data
|
||||||
|
index_path = temp_index_dir / "test_index.leann"
|
||||||
|
template = "search_document: "
|
||||||
|
|
||||||
|
# Build index with prompt template in embedding_options
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
embedding_options={"prompt_template": template},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add a simple document
|
||||||
|
builder.add_text("This is a test document for indexing")
|
||||||
|
|
||||||
|
# Build the index
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
# Verify .meta.json was created and contains the template
|
||||||
|
meta_path = temp_index_dir / "test_index.leann.meta.json"
|
||||||
|
assert meta_path.exists(), ".meta.json file should be created during build"
|
||||||
|
|
||||||
|
# Read and parse metadata
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta_data = json.load(f)
|
||||||
|
|
||||||
|
# Verify embedding_options exists in metadata
|
||||||
|
assert "embedding_options" in meta_data, (
|
||||||
|
"embedding_options should be saved to .meta.json when provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify prompt_template is in embedding_options
|
||||||
|
embedding_options = meta_data["embedding_options"]
|
||||||
|
assert "prompt_template" in embedding_options, (
|
||||||
|
"prompt_template should be saved within embedding_options"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the template value matches what we provided
|
||||||
|
assert embedding_options["prompt_template"] == template, (
|
||||||
|
f"Template should be '{template}', got '{embedding_options.get('prompt_template')}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_prompt_template_absent_when_not_provided(self, temp_index_dir, mock_embeddings):
|
||||||
|
"""
|
||||||
|
Verify that when no prompt template is provided during build,
|
||||||
|
.meta.json either doesn't have embedding_options or prompt_template key.
|
||||||
|
|
||||||
|
This ensures clean metadata without unnecessary keys when features aren't used.
|
||||||
|
|
||||||
|
Expected behavior: Build succeeds, .meta.json doesn't contain prompt_template.
|
||||||
|
"""
|
||||||
|
index_path = temp_index_dir / "test_no_template.leann"
|
||||||
|
|
||||||
|
# Build index WITHOUT prompt template
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
# No embedding_options provided
|
||||||
|
)
|
||||||
|
|
||||||
|
builder.add_text("Document without template")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
# Verify metadata
|
||||||
|
meta_path = temp_index_dir / "test_no_template.leann.meta.json"
|
||||||
|
assert meta_path.exists()
|
||||||
|
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta_data = json.load(f)
|
||||||
|
|
||||||
|
# If embedding_options exists, it should not contain prompt_template
|
||||||
|
if "embedding_options" in meta_data:
|
||||||
|
embedding_options = meta_data["embedding_options"]
|
||||||
|
assert "prompt_template" not in embedding_options, (
|
||||||
|
"prompt_template should not be in metadata when not provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptTemplateAutoLoadOnSearch:
|
||||||
|
"""Tests for automatic loading of prompt template during search operations.
|
||||||
|
|
||||||
|
NOTE: Over-mocked test removed (test_prompt_template_auto_loaded_on_search).
|
||||||
|
This functionality is now comprehensively tested by TestQueryPromptTemplateAutoLoad
|
||||||
|
which uses simpler mocking and doesn't hang.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_index_dir(self):
|
||||||
|
"""Create temporary directory for test indexes."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
yield Path(tmpdir)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_embeddings(self):
|
||||||
|
"""Mock compute_embeddings to capture calls and return dummy embeddings."""
|
||||||
|
with patch("leann.api.compute_embeddings") as mock_compute:
|
||||||
|
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||||
|
yield mock_compute
|
||||||
|
|
||||||
|
def test_search_without_template_in_metadata(self, temp_index_dir, mock_embeddings):
|
||||||
|
"""
|
||||||
|
Verify that searching an index built WITHOUT a prompt template
|
||||||
|
works correctly (backward compatibility).
|
||||||
|
|
||||||
|
The searcher should handle missing prompt_template gracefully.
|
||||||
|
|
||||||
|
Expected behavior: Search succeeds, no template is used.
|
||||||
|
"""
|
||||||
|
# Build index without template
|
||||||
|
index_path = temp_index_dir / "no_template.leann"
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
)
|
||||||
|
builder.add_text("Document without template")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
# Reset mocks
|
||||||
|
mock_embeddings.reset_mock()
|
||||||
|
|
||||||
|
# Create searcher and search
|
||||||
|
searcher = LeannSearcher(index_path=str(index_path))
|
||||||
|
|
||||||
|
# Verify no template in embedding_options
|
||||||
|
assert "prompt_template" not in searcher.embedding_options, (
|
||||||
|
"Searcher should not have prompt_template when not in metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueryPromptTemplateAutoLoad:
|
||||||
|
"""Tests for automatic loading of separate query_prompt_template during search (R2).
|
||||||
|
|
||||||
|
These tests verify the new two-template system where:
|
||||||
|
- build_prompt_template: Applied during index building
|
||||||
|
- query_prompt_template: Applied during search operations
|
||||||
|
|
||||||
|
Expected to FAIL in Red Phase (R2) because query template extraction
|
||||||
|
and application is not yet implemented in LeannSearcher.search().
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_index_dir(self):
|
||||||
|
"""Create temporary directory for test indexes."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
yield Path(tmpdir)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_compute_embeddings(self):
|
||||||
|
"""Mock compute_embeddings to capture calls and return dummy embeddings."""
|
||||||
|
with patch("leann.embedding_compute.compute_embeddings") as mock_compute:
|
||||||
|
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||||
|
yield mock_compute
|
||||||
|
|
||||||
|
def test_search_auto_loads_query_template(self, temp_index_dir, mock_compute_embeddings):
|
||||||
|
"""
|
||||||
|
Verify that search() automatically loads and applies query_prompt_template from .meta.json.
|
||||||
|
|
||||||
|
Given: Index built with separate build_prompt_template and query_prompt_template
|
||||||
|
When: LeannSearcher.search("my query") is called
|
||||||
|
Then: Query embedding is computed with "query: my query" (query template applied)
|
||||||
|
|
||||||
|
This is the core R2 requirement - query templates must be auto-loaded and applied
|
||||||
|
during search without user intervention.
|
||||||
|
|
||||||
|
Expected failure: compute_embeddings called with raw "my query" instead of
|
||||||
|
"query: my query" because query template extraction is not implemented.
|
||||||
|
"""
|
||||||
|
# Setup: Build index with separate templates in new format
|
||||||
|
index_path = temp_index_dir / "query_template.leann"
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
embedding_options={
|
||||||
|
"build_prompt_template": "doc: ",
|
||||||
|
"query_prompt_template": "query: ",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
builder.add_text("Test document")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
# Reset mock to ignore build calls
|
||||||
|
mock_compute_embeddings.reset_mock()
|
||||||
|
|
||||||
|
# Act: Search with query
|
||||||
|
searcher = LeannSearcher(index_path=str(index_path))
|
||||||
|
|
||||||
|
# Mock the backend search to avoid actual search
|
||||||
|
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
|
||||||
|
mock_backend_search.return_value = {
|
||||||
|
"labels": [["test_id_0"]], # IDs (nested list for batch support)
|
||||||
|
"distances": [[0.9]], # Distances (nested list for batch support)
|
||||||
|
}
|
||||||
|
|
||||||
|
searcher.search("my query", top_k=1, recompute_embeddings=False)
|
||||||
|
|
||||||
|
# Assert: compute_embeddings was called with query template applied
|
||||||
|
assert mock_compute_embeddings.called, "compute_embeddings should be called during search"
|
||||||
|
|
||||||
|
# Get the actual text passed to compute_embeddings
|
||||||
|
call_args = mock_compute_embeddings.call_args
|
||||||
|
texts_arg = call_args[0][0] # First positional arg (list of texts)
|
||||||
|
|
||||||
|
assert len(texts_arg) == 1, "Should compute embedding for one query"
|
||||||
|
assert texts_arg[0] == "query: my query", (
|
||||||
|
f"Query template should be applied: expected 'query: my query', got '{texts_arg[0]}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_search_backward_compat_single_template(self, temp_index_dir, mock_compute_embeddings):
|
||||||
|
"""
|
||||||
|
Verify backward compatibility with old single prompt_template format.
|
||||||
|
|
||||||
|
Given: Index with old format (single prompt_template, no query_prompt_template)
|
||||||
|
When: LeannSearcher.search("my query") is called
|
||||||
|
Then: Query embedding computed with "doc: my query" (old template applied)
|
||||||
|
|
||||||
|
This ensures indexes built with the old single-template system continue
|
||||||
|
to work correctly with the new search implementation.
|
||||||
|
|
||||||
|
Expected failure: Old template not recognized/applied because backward
|
||||||
|
compatibility logic is not implemented.
|
||||||
|
"""
|
||||||
|
# Setup: Build index with old single-template format
|
||||||
|
index_path = temp_index_dir / "old_template.leann"
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
embedding_options={"prompt_template": "doc: "}, # Old format
|
||||||
|
)
|
||||||
|
builder.add_text("Test document")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
# Reset mock
|
||||||
|
mock_compute_embeddings.reset_mock()
|
||||||
|
|
||||||
|
# Act: Search
|
||||||
|
searcher = LeannSearcher(index_path=str(index_path))
|
||||||
|
|
||||||
|
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
|
||||||
|
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
|
||||||
|
|
||||||
|
searcher.search("my query", top_k=1, recompute_embeddings=False)
|
||||||
|
|
||||||
|
# Assert: Old template was applied
|
||||||
|
call_args = mock_compute_embeddings.call_args
|
||||||
|
texts_arg = call_args[0][0]
|
||||||
|
|
||||||
|
assert texts_arg[0] == "doc: my query", (
|
||||||
|
f"Old prompt_template should be applied for backward compatibility: "
|
||||||
|
f"expected 'doc: my query', got '{texts_arg[0]}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_search_backward_compat_no_template(self, temp_index_dir, mock_compute_embeddings):
|
||||||
|
"""
|
||||||
|
Verify backward compatibility when no template is present in .meta.json.
|
||||||
|
|
||||||
|
Given: Index with no template in .meta.json (very old indexes)
|
||||||
|
When: LeannSearcher.search("my query") is called
|
||||||
|
Then: Query embedding computed with "my query" (no template, raw query)
|
||||||
|
|
||||||
|
This ensures the most basic backward compatibility - indexes without
|
||||||
|
any template support continue to work as before.
|
||||||
|
|
||||||
|
Expected failure: May fail if default template is incorrectly applied,
|
||||||
|
or if missing template causes error.
|
||||||
|
"""
|
||||||
|
# Setup: Build index without any template
|
||||||
|
index_path = temp_index_dir / "no_template.leann"
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
# No embedding_options at all
|
||||||
|
)
|
||||||
|
builder.add_text("Test document")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
# Reset mock
|
||||||
|
mock_compute_embeddings.reset_mock()
|
||||||
|
|
||||||
|
# Act: Search
|
||||||
|
searcher = LeannSearcher(index_path=str(index_path))
|
||||||
|
|
||||||
|
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
|
||||||
|
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
|
||||||
|
|
||||||
|
searcher.search("my query", top_k=1, recompute_embeddings=False)
|
||||||
|
|
||||||
|
# Assert: No template applied (raw query)
|
||||||
|
call_args = mock_compute_embeddings.call_args
|
||||||
|
texts_arg = call_args[0][0]
|
||||||
|
|
||||||
|
assert texts_arg[0] == "my query", (
|
||||||
|
f"No template should be applied when missing from metadata: "
|
||||||
|
f"expected 'my query', got '{texts_arg[0]}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_search_override_via_provider_options(self, temp_index_dir, mock_compute_embeddings):
|
||||||
|
"""
|
||||||
|
Verify that explicit provider_options can override stored query template.
|
||||||
|
|
||||||
|
Given: Index with query_prompt_template: "query: "
|
||||||
|
When: search() called with provider_options={"prompt_template": "override: "}
|
||||||
|
Then: Query embedding computed with "override: test" (override takes precedence)
|
||||||
|
|
||||||
|
This enables users to experiment with different query templates without
|
||||||
|
rebuilding the index, or to handle special query types differently.
|
||||||
|
|
||||||
|
Expected failure: provider_options parameter is accepted via **kwargs but
|
||||||
|
not used. Query embedding computed with raw "test" instead of "override: test"
|
||||||
|
because override logic is not implemented.
|
||||||
|
"""
|
||||||
|
# Setup: Build index with query template
|
||||||
|
index_path = temp_index_dir / "override_template.leann"
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
embedding_options={
|
||||||
|
"build_prompt_template": "doc: ",
|
||||||
|
"query_prompt_template": "query: ",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
builder.add_text("Test document")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
# Reset mock
|
||||||
|
mock_compute_embeddings.reset_mock()
|
||||||
|
|
||||||
|
# Act: Search with override
|
||||||
|
searcher = LeannSearcher(index_path=str(index_path))
|
||||||
|
|
||||||
|
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
|
||||||
|
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
|
||||||
|
|
||||||
|
# This should accept provider_options parameter
|
||||||
|
searcher.search(
|
||||||
|
"test",
|
||||||
|
top_k=1,
|
||||||
|
recompute_embeddings=False,
|
||||||
|
provider_options={"prompt_template": "override: "},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert: Override template was applied
|
||||||
|
call_args = mock_compute_embeddings.call_args
|
||||||
|
texts_arg = call_args[0][0]
|
||||||
|
|
||||||
|
assert texts_arg[0] == "override: test", (
|
||||||
|
f"Override template should take precedence: "
|
||||||
|
f"expected 'override: test', got '{texts_arg[0]}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptTemplateReuseInChat:
|
||||||
|
"""Tests for prompt template reuse in chat/ask operations."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_index_dir(self):
|
||||||
|
"""Create temporary directory for test indexes."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
yield Path(tmpdir)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_embeddings(self):
|
||||||
|
"""Mock compute_embeddings to return dummy embeddings."""
|
||||||
|
with patch("leann.api.compute_embeddings") as mock_compute:
|
||||||
|
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||||
|
yield mock_compute
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_embedding_server_manager(self):
|
||||||
|
"""Mock EmbeddingServerManager for chat tests."""
|
||||||
|
with patch("leann.searcher_base.EmbeddingServerManager") as mock_manager_class:
|
||||||
|
mock_manager = Mock()
|
||||||
|
mock_manager.start_server.return_value = (True, 5557)
|
||||||
|
mock_manager_class.return_value = mock_manager
|
||||||
|
yield mock_manager
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def index_with_template(self, temp_index_dir, mock_embeddings):
|
||||||
|
"""Build an index with a prompt template."""
|
||||||
|
index_path = temp_index_dir / "chat_template_index.leann"
|
||||||
|
template = "document_query: "
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
embedding_options={"prompt_template": template},
|
||||||
|
)
|
||||||
|
|
||||||
|
builder.add_text("Test document for chat")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
return str(index_path), template
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptTemplateIntegrationWithEmbeddingModes:
|
||||||
|
"""Tests for prompt template compatibility with different embedding modes."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_index_dir(self):
|
||||||
|
"""Create temporary directory for test indexes."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
yield Path(tmpdir)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"mode,model,template,filename_prefix",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"openai",
|
||||||
|
"text-embedding-3-small",
|
||||||
|
"Represent this for searching: ",
|
||||||
|
"openai_template",
|
||||||
|
),
|
||||||
|
("ollama", "nomic-embed-text", "search_query: ", "ollama_template"),
|
||||||
|
("sentence-transformers", "facebook/contriever", "query: ", "st_template"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_prompt_template_metadata_with_embedding_modes(
|
||||||
|
self, temp_index_dir, mode, model, template, filename_prefix
|
||||||
|
):
|
||||||
|
"""Verify prompt template is saved correctly across different embedding modes.
|
||||||
|
|
||||||
|
Tests that prompt templates are persisted to .meta.json for:
|
||||||
|
- OpenAI mode (primary use case)
|
||||||
|
- Ollama mode (also supports templates)
|
||||||
|
- Sentence-transformers mode (saved for forward compatibility)
|
||||||
|
|
||||||
|
Expected behavior: Template is saved to .meta.json regardless of mode.
|
||||||
|
"""
|
||||||
|
with patch("leann.api.compute_embeddings") as mock_compute:
|
||||||
|
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
|
||||||
|
|
||||||
|
index_path = temp_index_dir / f"{filename_prefix}.leann"
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=model,
|
||||||
|
embedding_mode=mode,
|
||||||
|
embedding_options={"prompt_template": template},
|
||||||
|
)
|
||||||
|
|
||||||
|
builder.add_text(f"{mode.capitalize()} test document")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
# Verify metadata
|
||||||
|
meta_path = temp_index_dir / f"{filename_prefix}.leann.meta.json"
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta_data = json.load(f)
|
||||||
|
|
||||||
|
assert meta_data["embedding_mode"] == mode
|
||||||
|
# Template should be saved for all modes (even if not used by some)
|
||||||
|
if "embedding_options" in meta_data:
|
||||||
|
assert meta_data["embedding_options"]["prompt_template"] == template
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueryTemplateApplicationInComputeEmbedding:
|
||||||
|
"""Tests for query template application in compute_query_embedding() (Bug Fix).
|
||||||
|
|
||||||
|
These tests verify that query templates are applied consistently in BOTH
|
||||||
|
code paths (server and fallback) when computing query embeddings.
|
||||||
|
|
||||||
|
This addresses the bug where query templates were only applied in the
|
||||||
|
fallback path, not when using the embedding server (the default path).
|
||||||
|
|
||||||
|
Bug Context:
|
||||||
|
- Issue: Query templates were stored in metadata but only applied during
|
||||||
|
fallback (direct) computation, not when using embedding server
|
||||||
|
- Fix: Move template application to BEFORE any computation path in
|
||||||
|
compute_query_embedding() (searcher_base.py:107-110)
|
||||||
|
- Impact: Critical for models like EmbeddingGemma that require task-specific
|
||||||
|
templates for optimal performance
|
||||||
|
|
||||||
|
These tests ensure the fix works correctly and prevent regression.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_index_with_template(self):
|
||||||
|
"""Create a temporary index with query template in metadata"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
index_dir = Path(tmpdir)
|
||||||
|
index_file = index_dir / "test.leann"
|
||||||
|
meta_file = index_dir / "test.leann.meta.json"
|
||||||
|
|
||||||
|
# Create minimal metadata with query template
|
||||||
|
metadata = {
|
||||||
|
"version": "1.0",
|
||||||
|
"backend_name": "hnsw",
|
||||||
|
"embedding_model": "text-embedding-embeddinggemma-300m-qat",
|
||||||
|
"dimensions": 768,
|
||||||
|
"embedding_mode": "openai",
|
||||||
|
"backend_kwargs": {
|
||||||
|
"graph_degree": 32,
|
||||||
|
"complexity": 64,
|
||||||
|
"distance_metric": "cosine",
|
||||||
|
},
|
||||||
|
"embedding_options": {
|
||||||
|
"base_url": "http://localhost:1234/v1",
|
||||||
|
"api_key": "test-key",
|
||||||
|
"build_prompt_template": "title: none | text: ",
|
||||||
|
"query_prompt_template": "task: search result | query: ",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
meta_file.write_text(json.dumps(metadata, indent=2))
|
||||||
|
|
||||||
|
# Create minimal HNSW index file (empty is okay for this test)
|
||||||
|
index_file.write_bytes(b"")
|
||||||
|
|
||||||
|
yield str(index_file)
|
||||||
|
|
||||||
|
def test_query_template_applied_in_fallback_path(self, temp_index_with_template):
|
||||||
|
"""Test that query template is applied when using fallback (direct) path"""
|
||||||
|
from leann.searcher_base import BaseSearcher
|
||||||
|
|
||||||
|
# Create a concrete implementation for testing
|
||||||
|
class TestSearcher(BaseSearcher):
|
||||||
|
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||||
|
return {"labels": [], "distances": []}
|
||||||
|
|
||||||
|
searcher = object.__new__(TestSearcher)
|
||||||
|
searcher.index_path = Path(temp_index_with_template)
|
||||||
|
searcher.index_dir = searcher.index_path.parent
|
||||||
|
|
||||||
|
# Load metadata
|
||||||
|
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||||
|
with open(meta_file) as f:
|
||||||
|
searcher.meta = json.load(f)
|
||||||
|
|
||||||
|
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||||
|
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||||
|
|
||||||
|
# Mock compute_embeddings to capture the query text
|
||||||
|
captured_queries = []
|
||||||
|
|
||||||
|
def mock_compute_embeddings(texts, model, mode, provider_options=None):
|
||||||
|
captured_queries.extend(texts)
|
||||||
|
return np.random.rand(len(texts), 768).astype(np.float32)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
|
||||||
|
):
|
||||||
|
# Call compute_query_embedding with template (fallback path)
|
||||||
|
result = searcher.compute_query_embedding(
|
||||||
|
query="vector database",
|
||||||
|
use_server_if_available=False, # Force fallback path
|
||||||
|
query_template="task: search result | query: ",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify template was applied
|
||||||
|
assert len(captured_queries) == 1
|
||||||
|
assert captured_queries[0] == "task: search result | query: vector database"
|
||||||
|
assert result.shape == (1, 768)
|
||||||
|
|
||||||
|
def test_query_template_applied_in_server_path(self, temp_index_with_template):
|
||||||
|
"""Test that query template is applied when using server path"""
|
||||||
|
from leann.searcher_base import BaseSearcher
|
||||||
|
|
||||||
|
# Create a concrete implementation for testing
|
||||||
|
class TestSearcher(BaseSearcher):
|
||||||
|
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||||
|
return {"labels": [], "distances": []}
|
||||||
|
|
||||||
|
searcher = object.__new__(TestSearcher)
|
||||||
|
searcher.index_path = Path(temp_index_with_template)
|
||||||
|
searcher.index_dir = searcher.index_path.parent
|
||||||
|
|
||||||
|
# Load metadata
|
||||||
|
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||||
|
with open(meta_file) as f:
|
||||||
|
searcher.meta = json.load(f)
|
||||||
|
|
||||||
|
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||||
|
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||||
|
|
||||||
|
# Mock the server methods to capture the query text
|
||||||
|
captured_queries = []
|
||||||
|
|
||||||
|
def mock_ensure_server_running(passages_file, port):
|
||||||
|
return port
|
||||||
|
|
||||||
|
def mock_compute_embedding_via_server(chunks, port):
|
||||||
|
captured_queries.extend(chunks)
|
||||||
|
return np.random.rand(len(chunks), 768).astype(np.float32)
|
||||||
|
|
||||||
|
searcher._ensure_server_running = mock_ensure_server_running
|
||||||
|
searcher._compute_embedding_via_server = mock_compute_embedding_via_server
|
||||||
|
|
||||||
|
# Call compute_query_embedding with template (server path)
|
||||||
|
result = searcher.compute_query_embedding(
|
||||||
|
query="vector database",
|
||||||
|
use_server_if_available=True, # Use server path
|
||||||
|
query_template="task: search result | query: ",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify template was applied BEFORE calling server
|
||||||
|
assert len(captured_queries) == 1
|
||||||
|
assert captured_queries[0] == "task: search result | query: vector database"
|
||||||
|
assert result.shape == (1, 768)
|
||||||
|
|
||||||
|
def test_query_template_without_template_parameter(self, temp_index_with_template):
|
||||||
|
"""Test that query is unchanged when no template is provided"""
|
||||||
|
from leann.searcher_base import BaseSearcher
|
||||||
|
|
||||||
|
class TestSearcher(BaseSearcher):
|
||||||
|
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||||
|
return {"labels": [], "distances": []}
|
||||||
|
|
||||||
|
searcher = object.__new__(TestSearcher)
|
||||||
|
searcher.index_path = Path(temp_index_with_template)
|
||||||
|
searcher.index_dir = searcher.index_path.parent
|
||||||
|
|
||||||
|
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||||
|
with open(meta_file) as f:
|
||||||
|
searcher.meta = json.load(f)
|
||||||
|
|
||||||
|
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||||
|
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||||
|
|
||||||
|
captured_queries = []
|
||||||
|
|
||||||
|
def mock_compute_embeddings(texts, model, mode, provider_options=None):
|
||||||
|
captured_queries.extend(texts)
|
||||||
|
return np.random.rand(len(texts), 768).astype(np.float32)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
|
||||||
|
):
|
||||||
|
searcher.compute_query_embedding(
|
||||||
|
query="vector database",
|
||||||
|
use_server_if_available=False,
|
||||||
|
query_template=None, # No template
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify query is unchanged
|
||||||
|
assert len(captured_queries) == 1
|
||||||
|
assert captured_queries[0] == "vector database"
|
||||||
|
|
||||||
|
def test_query_template_consistency_between_paths(self, temp_index_with_template):
|
||||||
|
"""Test that both paths apply template identically"""
|
||||||
|
from leann.searcher_base import BaseSearcher
|
||||||
|
|
||||||
|
class TestSearcher(BaseSearcher):
|
||||||
|
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||||
|
return {"labels": [], "distances": []}
|
||||||
|
|
||||||
|
searcher = object.__new__(TestSearcher)
|
||||||
|
searcher.index_path = Path(temp_index_with_template)
|
||||||
|
searcher.index_dir = searcher.index_path.parent
|
||||||
|
|
||||||
|
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||||
|
with open(meta_file) as f:
|
||||||
|
searcher.meta = json.load(f)
|
||||||
|
|
||||||
|
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||||
|
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||||
|
|
||||||
|
query_template = "task: search result | query: "
|
||||||
|
original_query = "vector database"
|
||||||
|
|
||||||
|
# Capture queries from fallback path
|
||||||
|
fallback_queries = []
|
||||||
|
|
||||||
|
def mock_compute_embeddings(texts, model, mode, provider_options=None):
|
||||||
|
fallback_queries.extend(texts)
|
||||||
|
return np.random.rand(len(texts), 768).astype(np.float32)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
|
||||||
|
):
|
||||||
|
searcher.compute_query_embedding(
|
||||||
|
query=original_query,
|
||||||
|
use_server_if_available=False,
|
||||||
|
query_template=query_template,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Capture queries from server path
|
||||||
|
server_queries = []
|
||||||
|
|
||||||
|
def mock_ensure_server_running(passages_file, port):
|
||||||
|
return port
|
||||||
|
|
||||||
|
def mock_compute_embedding_via_server(chunks, port):
|
||||||
|
server_queries.extend(chunks)
|
||||||
|
return np.random.rand(len(chunks), 768).astype(np.float32)
|
||||||
|
|
||||||
|
searcher._ensure_server_running = mock_ensure_server_running
|
||||||
|
searcher._compute_embedding_via_server = mock_compute_embedding_via_server
|
||||||
|
|
||||||
|
searcher.compute_query_embedding(
|
||||||
|
query=original_query,
|
||||||
|
use_server_if_available=True,
|
||||||
|
query_template=query_template,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify both paths produced identical templated queries
|
||||||
|
assert len(fallback_queries) == 1
|
||||||
|
assert len(server_queries) == 1
|
||||||
|
assert fallback_queries[0] == server_queries[0]
|
||||||
|
assert fallback_queries[0] == f"{query_template}{original_query}"
|
||||||
|
|
||||||
|
def test_query_template_with_empty_string(self, temp_index_with_template):
|
||||||
|
"""Test behavior with empty template string"""
|
||||||
|
from leann.searcher_base import BaseSearcher
|
||||||
|
|
||||||
|
class TestSearcher(BaseSearcher):
|
||||||
|
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||||
|
return {"labels": [], "distances": []}
|
||||||
|
|
||||||
|
searcher = object.__new__(TestSearcher)
|
||||||
|
searcher.index_path = Path(temp_index_with_template)
|
||||||
|
searcher.index_dir = searcher.index_path.parent
|
||||||
|
|
||||||
|
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
|
||||||
|
with open(meta_file) as f:
|
||||||
|
searcher.meta = json.load(f)
|
||||||
|
|
||||||
|
searcher.embedding_model = searcher.meta["embedding_model"]
|
||||||
|
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
searcher.embedding_options = searcher.meta.get("embedding_options", {})
|
||||||
|
|
||||||
|
captured_queries = []
|
||||||
|
|
||||||
|
def mock_compute_embeddings(texts, model, mode, provider_options=None):
|
||||||
|
captured_queries.extend(texts)
|
||||||
|
return np.random.rand(len(texts), 768).astype(np.float32)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
|
||||||
|
):
|
||||||
|
searcher.compute_query_embedding(
|
||||||
|
query="vector database",
|
||||||
|
use_server_if_available=False,
|
||||||
|
query_template="", # Empty string
|
||||||
|
)
|
||||||
|
|
||||||
|
# Empty string is falsy, so no template should be applied
|
||||||
|
assert captured_queries[0] == "vector database"
|
||||||
@@ -266,3 +266,378 @@ class TestTokenTruncation:
|
|||||||
assert result_tokens <= target_tokens, (
|
assert result_tokens <= target_tokens, (
|
||||||
f"Should be ≤{target_tokens} tokens, got {result_tokens}"
|
f"Should be ≤{target_tokens} tokens, got {result_tokens}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLMStudioHybridDiscovery:
|
||||||
|
"""Tests for LM Studio integration in get_model_token_limit() hybrid discovery.
|
||||||
|
|
||||||
|
These tests verify that get_model_token_limit() properly integrates with
|
||||||
|
the LM Studio SDK bridge for dynamic token limit discovery. The integration
|
||||||
|
should:
|
||||||
|
|
||||||
|
1. Detect LM Studio URLs (port 1234 or 'lmstudio'/'lm.studio' in URL)
|
||||||
|
2. Convert HTTP URLs to WebSocket format for SDK queries
|
||||||
|
3. Query LM Studio SDK and use discovered limit
|
||||||
|
4. Fall back to registry when SDK returns None
|
||||||
|
5. Execute AFTER Ollama detection but BEFORE registry fallback
|
||||||
|
|
||||||
|
All tests are written in Red Phase - they should FAIL initially because the
|
||||||
|
LM Studio detection and integration logic does not exist yet in get_model_token_limit().
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_get_model_token_limit_lmstudio_success(self, monkeypatch):
|
||||||
|
"""Verify LM Studio SDK query succeeds and returns detected limit.
|
||||||
|
|
||||||
|
When a LM Studio base_url is detected and the SDK query succeeds,
|
||||||
|
get_model_token_limit() should return the dynamically discovered
|
||||||
|
context length without falling back to the registry.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Mock _query_lmstudio_context_limit to return successful SDK query
|
||||||
|
def mock_query_lmstudio(model_name, base_url):
|
||||||
|
# Verify WebSocket URL was passed (not HTTP)
|
||||||
|
assert base_url.startswith("ws://"), (
|
||||||
|
f"Should convert HTTP to WebSocket format, got: {base_url}"
|
||||||
|
)
|
||||||
|
return 8192 # Successful SDK query
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||||
|
mock_query_lmstudio,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with HTTP URL that should be converted to WebSocket
|
||||||
|
limit = get_model_token_limit(
|
||||||
|
model_name="custom-model", base_url="http://localhost:1234/v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert limit == 8192, "Should return limit from LM Studio SDK query"
|
||||||
|
|
||||||
|
def test_get_model_token_limit_lmstudio_fallback_to_registry(self, monkeypatch):
|
||||||
|
"""Verify fallback to registry when LM Studio SDK returns None.
|
||||||
|
|
||||||
|
When LM Studio SDK query fails (returns None), get_model_token_limit()
|
||||||
|
should fall back to the EMBEDDING_MODEL_LIMITS registry.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Mock _query_lmstudio_context_limit to return None (SDK failure)
|
||||||
|
def mock_query_lmstudio(model_name, base_url):
|
||||||
|
return None # SDK query failed
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||||
|
mock_query_lmstudio,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with known model that exists in registry
|
||||||
|
limit = get_model_token_limit(
|
||||||
|
model_name="nomic-embed-text", base_url="http://localhost:1234/v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should fall back to registry value
|
||||||
|
assert limit == 2048, "Should fall back to registry when SDK returns None"
|
||||||
|
|
||||||
|
def test_get_model_token_limit_lmstudio_port_detection(self, monkeypatch):
|
||||||
|
"""Verify detection of LM Studio via port 1234.
|
||||||
|
|
||||||
|
get_model_token_limit() should recognize port 1234 as a LM Studio
|
||||||
|
server and attempt SDK query, regardless of hostname.
|
||||||
|
"""
|
||||||
|
query_called = False
|
||||||
|
|
||||||
|
def mock_query_lmstudio(model_name, base_url):
|
||||||
|
nonlocal query_called
|
||||||
|
query_called = True
|
||||||
|
return 4096
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||||
|
mock_query_lmstudio,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with port 1234 (default LM Studio port)
|
||||||
|
limit = get_model_token_limit(model_name="test-model", base_url="http://127.0.0.1:1234/v1")
|
||||||
|
|
||||||
|
assert query_called, "Should detect port 1234 and call LM Studio SDK query"
|
||||||
|
assert limit == 4096, "Should return SDK query result"
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_url,expected_limit,keyword",
|
||||||
|
[
|
||||||
|
("http://lmstudio.local:8080/v1", 16384, "lmstudio"),
|
||||||
|
("http://api.lm.studio:5000/v1", 32768, "lm.studio"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_get_model_token_limit_lmstudio_url_keyword_detection(
|
||||||
|
self, monkeypatch, test_url, expected_limit, keyword
|
||||||
|
):
|
||||||
|
"""Verify detection of LM Studio via keywords in URL.
|
||||||
|
|
||||||
|
get_model_token_limit() should recognize 'lmstudio' or 'lm.studio'
|
||||||
|
in the URL as indicating a LM Studio server.
|
||||||
|
"""
|
||||||
|
query_called = False
|
||||||
|
|
||||||
|
def mock_query_lmstudio(model_name, base_url):
|
||||||
|
nonlocal query_called
|
||||||
|
query_called = True
|
||||||
|
return expected_limit
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||||
|
mock_query_lmstudio,
|
||||||
|
)
|
||||||
|
|
||||||
|
limit = get_model_token_limit(model_name="test-model", base_url=test_url)
|
||||||
|
|
||||||
|
assert query_called, f"Should detect '{keyword}' keyword and call SDK query"
|
||||||
|
assert limit == expected_limit, f"Should return SDK query result for {keyword}"
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_url,expected_protocol,expected_host",
|
||||||
|
[
|
||||||
|
("http://localhost:1234/v1", "ws://", "localhost:1234"),
|
||||||
|
("https://lmstudio.example.com:1234/v1", "wss://", "lmstudio.example.com:1234"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_get_model_token_limit_protocol_conversion(
|
||||||
|
self, monkeypatch, input_url, expected_protocol, expected_host
|
||||||
|
):
|
||||||
|
"""Verify HTTP/HTTPS URL is converted to WebSocket format for SDK query.
|
||||||
|
|
||||||
|
LM Studio SDK requires WebSocket URLs. get_model_token_limit() should:
|
||||||
|
1. Convert 'http://' to 'ws://'
|
||||||
|
2. Convert 'https://' to 'wss://'
|
||||||
|
3. Remove '/v1' or other path suffixes (SDK expects base URL)
|
||||||
|
4. Preserve host and port
|
||||||
|
"""
|
||||||
|
conversions_tested = []
|
||||||
|
|
||||||
|
def mock_query_lmstudio(model_name, base_url):
|
||||||
|
conversions_tested.append(base_url)
|
||||||
|
return 8192
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||||
|
mock_query_lmstudio,
|
||||||
|
)
|
||||||
|
|
||||||
|
get_model_token_limit(model_name="test-model", base_url=input_url)
|
||||||
|
|
||||||
|
# Verify conversion happened
|
||||||
|
assert len(conversions_tested) == 1, "Should have called SDK query once"
|
||||||
|
assert conversions_tested[0].startswith(expected_protocol), (
|
||||||
|
f"Should convert to {expected_protocol}"
|
||||||
|
)
|
||||||
|
assert expected_host in conversions_tested[0], (
|
||||||
|
f"Should preserve host and port: {expected_host}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_get_model_token_limit_lmstudio_executes_after_ollama(self, monkeypatch):
|
||||||
|
"""Verify LM Studio detection happens AFTER Ollama detection.
|
||||||
|
|
||||||
|
The hybrid discovery order should be:
|
||||||
|
1. Ollama dynamic discovery (port 11434 or 'ollama' in URL)
|
||||||
|
2. LM Studio dynamic discovery (port 1234 or 'lmstudio' in URL)
|
||||||
|
3. Registry fallback
|
||||||
|
|
||||||
|
If both Ollama and LM Studio patterns match, Ollama should take precedence.
|
||||||
|
This test verifies that LM Studio is checked but doesn't interfere with Ollama.
|
||||||
|
"""
|
||||||
|
ollama_called = False
|
||||||
|
lmstudio_called = False
|
||||||
|
|
||||||
|
def mock_query_ollama(model_name, base_url):
|
||||||
|
nonlocal ollama_called
|
||||||
|
ollama_called = True
|
||||||
|
return 2048 # Ollama query succeeds
|
||||||
|
|
||||||
|
def mock_query_lmstudio(model_name, base_url):
|
||||||
|
nonlocal lmstudio_called
|
||||||
|
lmstudio_called = True
|
||||||
|
return None # Should not be reached if Ollama succeeds
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"leann.embedding_compute._query_ollama_context_limit",
|
||||||
|
mock_query_ollama,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||||
|
mock_query_lmstudio,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with Ollama URL
|
||||||
|
limit = get_model_token_limit(
|
||||||
|
model_name="test-model", base_url="http://localhost:11434/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ollama_called, "Should attempt Ollama query first"
|
||||||
|
assert not lmstudio_called, "Should not attempt LM Studio query when Ollama succeeds"
|
||||||
|
assert limit == 2048, "Should return Ollama result"
|
||||||
|
|
||||||
|
def test_get_model_token_limit_lmstudio_not_detected_for_non_lmstudio_urls(self, monkeypatch):
|
||||||
|
"""Verify LM Studio SDK query is NOT called for non-LM Studio URLs.
|
||||||
|
|
||||||
|
Only URLs with port 1234 or 'lmstudio'/'lm.studio' keywords should
|
||||||
|
trigger LM Studio SDK queries. Other URLs should skip to registry fallback.
|
||||||
|
"""
|
||||||
|
lmstudio_called = False
|
||||||
|
|
||||||
|
def mock_query_lmstudio(model_name, base_url):
|
||||||
|
nonlocal lmstudio_called
|
||||||
|
lmstudio_called = True
|
||||||
|
return 8192
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||||
|
mock_query_lmstudio,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with non-LM Studio URLs
|
||||||
|
test_cases = [
|
||||||
|
"http://localhost:8080/v1", # Different port
|
||||||
|
"http://openai.example.com/v1", # Different service
|
||||||
|
"http://localhost:3000/v1", # Another port
|
||||||
|
]
|
||||||
|
|
||||||
|
for base_url in test_cases:
|
||||||
|
lmstudio_called = False # Reset for each test
|
||||||
|
get_model_token_limit(model_name="nomic-embed-text", base_url=base_url)
|
||||||
|
assert not lmstudio_called, f"Should NOT call LM Studio SDK for URL: {base_url}"
|
||||||
|
|
||||||
|
def test_get_model_token_limit_lmstudio_case_insensitive_detection(self, monkeypatch):
|
||||||
|
"""Verify LM Studio detection is case-insensitive for keywords.
|
||||||
|
|
||||||
|
Keywords 'lmstudio' and 'lm.studio' should be detected regardless
|
||||||
|
of case (LMStudio, LMSTUDIO, LmStudio, etc.).
|
||||||
|
"""
|
||||||
|
query_called = False
|
||||||
|
|
||||||
|
def mock_query_lmstudio(model_name, base_url):
|
||||||
|
nonlocal query_called
|
||||||
|
query_called = True
|
||||||
|
return 8192
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"leann.embedding_compute._query_lmstudio_context_limit",
|
||||||
|
mock_query_lmstudio,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test various case variations
|
||||||
|
test_cases = [
|
||||||
|
"http://LMStudio.local:8080/v1",
|
||||||
|
"http://LMSTUDIO.example.com/v1",
|
||||||
|
"http://LmStudio.local/v1",
|
||||||
|
"http://api.LM.STUDIO:5000/v1",
|
||||||
|
]
|
||||||
|
|
||||||
|
for base_url in test_cases:
|
||||||
|
query_called = False # Reset for each test
|
||||||
|
limit = get_model_token_limit(model_name="test-model", base_url=base_url)
|
||||||
|
assert query_called, f"Should detect LM Studio in URL: {base_url}"
|
||||||
|
assert limit == 8192, f"Should return SDK result for URL: {base_url}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenLimitCaching:
|
||||||
|
"""Tests for token limit caching to prevent repeated SDK/API calls.
|
||||||
|
|
||||||
|
Caching prevents duplicate SDK/API calls within the same Python process,
|
||||||
|
which is important because:
|
||||||
|
1. LM Studio SDK load() can load duplicate model instances
|
||||||
|
2. Ollama /api/show queries add latency
|
||||||
|
3. Registry lookups are pure overhead
|
||||||
|
|
||||||
|
Cache is process-scoped and resets between leann build invocations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Clear cache before each test."""
|
||||||
|
from leann.embedding_compute import _token_limit_cache
|
||||||
|
|
||||||
|
_token_limit_cache.clear()
|
||||||
|
|
||||||
|
def test_registry_lookup_is_cached(self):
|
||||||
|
"""Verify that registry lookups are cached."""
|
||||||
|
from leann.embedding_compute import _token_limit_cache
|
||||||
|
|
||||||
|
# First call
|
||||||
|
limit1 = get_model_token_limit("text-embedding-3-small")
|
||||||
|
assert limit1 == 8192
|
||||||
|
|
||||||
|
# Verify it's in cache
|
||||||
|
cache_key = ("text-embedding-3-small", "")
|
||||||
|
assert cache_key in _token_limit_cache
|
||||||
|
assert _token_limit_cache[cache_key] == 8192
|
||||||
|
|
||||||
|
# Second call should use cache
|
||||||
|
limit2 = get_model_token_limit("text-embedding-3-small")
|
||||||
|
assert limit2 == 8192
|
||||||
|
|
||||||
|
def test_default_fallback_is_cached(self):
|
||||||
|
"""Verify that default fallbacks are cached."""
|
||||||
|
from leann.embedding_compute import _token_limit_cache
|
||||||
|
|
||||||
|
# First call with unknown model
|
||||||
|
limit1 = get_model_token_limit("unknown-model-xyz", default=512)
|
||||||
|
assert limit1 == 512
|
||||||
|
|
||||||
|
# Verify it's in cache
|
||||||
|
cache_key = ("unknown-model-xyz", "")
|
||||||
|
assert cache_key in _token_limit_cache
|
||||||
|
assert _token_limit_cache[cache_key] == 512
|
||||||
|
|
||||||
|
# Second call should use cache
|
||||||
|
limit2 = get_model_token_limit("unknown-model-xyz", default=512)
|
||||||
|
assert limit2 == 512
|
||||||
|
|
||||||
|
def test_different_urls_create_separate_cache_entries(self):
|
||||||
|
"""Verify that different base_urls create separate cache entries."""
|
||||||
|
from leann.embedding_compute import _token_limit_cache
|
||||||
|
|
||||||
|
# Same model, different URLs
|
||||||
|
limit1 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:11434")
|
||||||
|
limit2 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:1234/v1")
|
||||||
|
|
||||||
|
# Both should find the model in registry (2048)
|
||||||
|
assert limit1 == 2048
|
||||||
|
assert limit2 == 2048
|
||||||
|
|
||||||
|
# But they should be separate cache entries
|
||||||
|
cache_key1 = ("nomic-embed-text", "http://localhost:11434")
|
||||||
|
cache_key2 = ("nomic-embed-text", "http://localhost:1234/v1")
|
||||||
|
|
||||||
|
assert cache_key1 in _token_limit_cache
|
||||||
|
assert cache_key2 in _token_limit_cache
|
||||||
|
assert len(_token_limit_cache) == 2
|
||||||
|
|
||||||
|
def test_cache_prevents_repeated_lookups(self):
|
||||||
|
"""Verify that cache prevents repeated registry/API lookups."""
|
||||||
|
from leann.embedding_compute import _token_limit_cache
|
||||||
|
|
||||||
|
model_name = "text-embedding-ada-002"
|
||||||
|
|
||||||
|
# First call - should add to cache
|
||||||
|
assert len(_token_limit_cache) == 0
|
||||||
|
limit1 = get_model_token_limit(model_name)
|
||||||
|
|
||||||
|
cache_size_after_first = len(_token_limit_cache)
|
||||||
|
assert cache_size_after_first == 1
|
||||||
|
|
||||||
|
# Multiple subsequent calls - cache size should not change
|
||||||
|
for _ in range(5):
|
||||||
|
limit = get_model_token_limit(model_name)
|
||||||
|
assert limit == limit1
|
||||||
|
assert len(_token_limit_cache) == cache_size_after_first
|
||||||
|
|
||||||
|
def test_versioned_model_names_cached_correctly(self):
|
||||||
|
"""Verify that versioned model names (e.g., model:tag) are cached."""
|
||||||
|
from leann.embedding_compute import _token_limit_cache
|
||||||
|
|
||||||
|
# Model with version tag
|
||||||
|
limit = get_model_token_limit("nomic-embed-text:latest", base_url="http://localhost:11434")
|
||||||
|
assert limit == 2048
|
||||||
|
|
||||||
|
# Should be cached with full name including version
|
||||||
|
cache_key = ("nomic-embed-text:latest", "http://localhost:11434")
|
||||||
|
assert cache_key in _token_limit_cache
|
||||||
|
assert _token_limit_cache[cache_key] == 2048
|
||||||
|
|||||||
Reference in New Issue
Block a user