Feature/prompt templates and lmstudio sdk (#171)

* Add prompt template support and LM Studio SDK integration

Features:

- Prompt template support for embedding models (via --embedding-prompt-template)

- LM Studio SDK integration for automatic context length detection

- Hybrid token limit discovery (Ollama → LM Studio → Registry → Default)

- Client-side token truncation to prevent silent failures

- Automatic persistence of embedding_options to .meta.json

Implementation:

- Added _query_lmstudio_context_limit() with Node.js subprocess bridge

- Modified compute_embeddings_openai() to apply prompt templates before truncation

- Extended CLI with --embedding-prompt-template flag for build and search

- URL detection for LM Studio (port 1234 or lmstudio/lm.studio keywords)

- HTTP→WebSocket URL conversion for SDK compatibility

Tests:

- 60 passing tests across 5 test files

- Comprehensive coverage of prompt templates, LM Studio integration, and token handling

- Parametrized tests for maintainability and clarity

* Add integration tests and fix LM Studio SDK bridge

Features:
- End-to-end integration tests for prompt template with EmbeddingGemma
- Integration tests for hybrid token limit discovery mechanism
- Tests verify real-world functionality with live services (LM Studio, Ollama)

Fixes:
- LM Studio SDK bridge now uses client.embedding.load() for embedding models
- Fixed NODE_PATH resolution to include npm global modules
- Fixed integration test to use WebSocket URL (ws://) for SDK bridge

Tests:
- test_prompt_template_e2e.py: 8 integration tests covering:
  - Prompt template prepending with LM Studio (EmbeddingGemma)
  - LM Studio SDK bridge for context length detection
  - Ollama dynamic token limit detection
  - Hybrid discovery fallback mechanism (registry, default)
- All tests marked with @pytest.mark.integration for selective execution
- Tests gracefully skip when services unavailable

Documentation:
- Updated tests/README.md with integration test section
- Added prerequisites and running instructions
- Documented that prompt templates are ONLY for EmbeddingGemma
- Added integration marker to pyproject.toml

Test Results:
- All 8 integration tests passing with live services
- Confirmed prompt templates work correctly with EmbeddingGemma
- Verified LM Studio SDK bridge auto-detects context length (2048)
- Validated hybrid token limit discovery across all backends

* Add prompt template support to Ollama mode

Extends prompt template functionality from OpenAI mode to Ollama for backend consistency.

Changes:
- Add provider_options parameter to compute_embeddings_ollama()
- Apply prompt template before token truncation (lines 1005-1011)
- Pass provider_options through compute_embeddings() call chain

Tests:
- test_ollama_embedding_with_prompt_template: Verifies templates work with Ollama
- test_ollama_prompt_template_affects_embeddings: Confirms embeddings differ with/without template
- Both tests pass with live Ollama service (2/2 passing)

Usage:
leann build --embedding-mode ollama --embedding-prompt-template "query: " ...

* Fix LM Studio SDK bridge to respect JIT auto-evict settings

Problem: SDK bridge called client.embedding.load() which loaded models into
LM Studio memory and bypassed JIT auto-evict settings, causing duplicate
model instances to accumulate.

Root cause analysis (from Perplexity research):
- Explicit SDK load() commands are treated as "pinned" models
- JIT auto-evict only applies to models loaded reactively via API requests
- SDK-loaded models remain in memory until explicitly unloaded

Solutions implemented:

1. Add model.unload() after metadata query (line 243)
   - Load model temporarily to get context length
   - Unload immediately to hand control back to JIT system
   - Subsequent API requests trigger JIT load with auto-evict

2. Add token limit caching to prevent repeated SDK calls
   - Cache discovered limits in _token_limit_cache dict (line 48)
   - Key: (model_name, base_url), Value: token_limit
   - Prevents duplicate load/unload cycles within same process
   - Cache shared across all discovery methods (Ollama, SDK, registry)

Tests:
- TestTokenLimitCaching: 5 tests for cache behavior (integrated into test_token_truncation.py)
- Manual testing confirmed no duplicate models in LM Studio after fix
- All existing tests pass

Impact:
- Respects user's LM Studio JIT and auto-evict settings
- Reduces model memory footprint
- Faster subsequent builds (cached limits)

* Document prompt template and LM Studio SDK features

Added comprehensive documentation for new optional embedding features:

Configuration Guide (docs/configuration-guide.md):
- New section: "Optional Embedding Features"
- Task-Specific Prompt Templates subsection:
  - Explains EmbeddingGemma use case with document/query prompts
  - CLI and Python API examples
  - Clear warnings about compatible vs incompatible models
  - References to GitHub issue #155 and HuggingFace blog
- LM Studio Auto-Detection subsection:
  - Prerequisites (Node.js + @lmstudio/sdk)
  - How auto-detection works (4-step process)
  - Benefits and optional nature clearly stated

FAQ (docs/faq.md):
- FAQ #2: When should I use prompt templates?
  - DO/DON'T guidance with examples
  - Links to detailed configuration guide
- FAQ #3: Why is LM Studio loading multiple copies?
  - Explains the JIT auto-evict fix
  - Troubleshooting steps if still seeing issues
- FAQ #4: Do I need Node.js and @lmstudio/sdk?
  - Clarifies it's completely optional
  - Lists benefits if installed
  - Installation instructions

Cross-references between documents for easy navigation between quick reference and detailed guides.

* Add separate build/query template support for task-specific models

Task-specific models like EmbeddingGemma require different templates for indexing vs searching. Store both templates at build time and auto-apply query template during search with backward compatibility.

* Consolidate prompt template tests from 44 to 37 tests

Merged redundant no-op tests, removed low-value implementation tests, consolidated parameterized CLI tests, and removed hanging over-mocked test. All tests pass with improved focus on behavioral testing.

* Fix query template application in compute_query_embedding

Query templates were only applied in the fallback code path, not when using the embedding server (default path). This meant stored query templates in index metadata were ignored during MCP and CLI searches.

Changes:

- Move template application to before any computation path (searcher_base.py:109-110)

- Add comprehensive tests for both server and fallback paths

- Consolidate tests into test_prompt_template_persistence.py

Tests verify:

- Template applied when using embedding server

- Template applied in fallback path

- Consistent behavior between both paths

* Apply ruff formatting and fix linting issues

- Remove unused imports

- Fix import ordering

- Remove unused variables

- Apply code formatting

* Fix CI test failures: mock OPENAI_API_KEY in tests

Tests were failing in CI because compute_embeddings_openai() checks for OPENAI_API_KEY before using the mocked client. Added monkeypatch to set fake API key in test fixture.
This commit is contained in:
ww26
2025-11-14 18:25:17 -05:00
committed by GitHub
parent a63550944b
commit 1ef9cba7de
15 changed files with 3095 additions and 15 deletions

View File

@@ -916,6 +916,7 @@ class LeannSearcher:
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
batch_size: int = 0,
use_grep: bool = False,
provider_options: Optional[dict[str, Any]] = None,
**kwargs,
) -> list[SearchResult]:
"""
@@ -979,10 +980,24 @@ class LeannSearcher:
start_time = time.time()
# Extract query template from stored embedding_options with fallback chain:
# 1. Check provider_options override (highest priority)
# 2. Check query_prompt_template (new format)
# 3. Check prompt_template (old format for backward compat)
# 4. None (no template)
query_template = None
if provider_options and "prompt_template" in provider_options:
query_template = provider_options["prompt_template"]
elif "query_prompt_template" in self.embedding_options:
query_template = self.embedding_options["query_prompt_template"]
elif "prompt_template" in self.embedding_options:
query_template = self.embedding_options["prompt_template"]
query_embedding = self.backend_impl.compute_query_embedding(
query,
use_server_if_available=recompute_embeddings,
zmq_port=zmq_port,
query_template=query_template,
)
logger.info(f" Generated embedding shape: {query_embedding.shape}")
embedding_time = time.time() - start_time

View File

@@ -144,6 +144,18 @@ Examples:
default=None,
help="API key for embedding service (defaults to OPENAI_API_KEY)",
)
build_parser.add_argument(
"--embedding-prompt-template",
type=str,
default=None,
help="Prompt template to prepend to all texts for embedding (e.g., 'query: ' for search)",
)
build_parser.add_argument(
"--query-prompt-template",
type=str,
default=None,
help="Prompt template for queries (different from build template for task-specific models)",
)
build_parser.add_argument(
"--force", "-f", action="store_true", help="Force rebuild existing index"
)
@@ -260,6 +272,12 @@ Examples:
action="store_true",
help="Display file paths and metadata in search results",
)
search_parser.add_argument(
"--embedding-prompt-template",
type=str,
default=None,
help="Prompt template to prepend to query for embedding (e.g., 'query: ' for search)",
)
# Ask command
ask_parser = subparsers.add_parser("ask", help="Ask questions")
@@ -1398,6 +1416,14 @@ Examples:
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
if resolved_embedding_key:
embedding_options["api_key"] = resolved_embedding_key
if args.query_prompt_template:
# New format: separate templates
if args.embedding_prompt_template:
embedding_options["build_prompt_template"] = args.embedding_prompt_template
embedding_options["query_prompt_template"] = args.query_prompt_template
elif args.embedding_prompt_template:
# Old format: single template (backward compat)
embedding_options["prompt_template"] = args.embedding_prompt_template
builder = LeannBuilder(
backend_name=args.backend_name,

View File

@@ -4,8 +4,10 @@ Consolidates all embedding computation logic using SentenceTransformer
Preserves all optimization parameters to ensure performance
"""
import json
import logging
import os
import subprocess
import time
from typing import Any, Optional
@@ -40,6 +42,11 @@ EMBEDDING_MODEL_LIMITS = {
"text-embedding-ada-002": 8192,
}
# Runtime cache for dynamically discovered token limits
# Key: (model_name, base_url), Value: token_limit
# Prevents repeated SDK/API calls for the same model
_token_limit_cache: dict[tuple[str, str], int] = {}
def get_model_token_limit(
model_name: str,
@@ -49,6 +56,7 @@ def get_model_token_limit(
"""
Get token limit for a given embedding model.
Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others.
Caches discovered limits to prevent repeated API/SDK calls.
Args:
model_name: Name of the embedding model
@@ -58,12 +66,33 @@ def get_model_token_limit(
Returns:
Token limit for the model in tokens
"""
# Check cache first to avoid repeated SDK/API calls
cache_key = (model_name, base_url or "")
if cache_key in _token_limit_cache:
cached_limit = _token_limit_cache[cache_key]
logger.debug(f"Using cached token limit for {model_name}: {cached_limit}")
return cached_limit
# Try Ollama dynamic discovery if base_url provided
if base_url:
# Detect Ollama servers by port or "ollama" in URL
if "11434" in base_url or "ollama" in base_url.lower():
limit = _query_ollama_context_limit(model_name, base_url)
if limit:
_token_limit_cache[cache_key] = limit
return limit
# Try LM Studio SDK discovery
if "1234" in base_url or "lmstudio" in base_url.lower() or "lm.studio" in base_url.lower():
# Convert HTTP to WebSocket URL
ws_url = base_url.replace("https://", "wss://").replace("http://", "ws://")
# Remove /v1 suffix if present
if ws_url.endswith("/v1"):
ws_url = ws_url[:-3]
limit = _query_lmstudio_context_limit(model_name, ws_url)
if limit:
_token_limit_cache[cache_key] = limit
return limit
# Fallback to known model registry with version handling (from PR #154)
@@ -72,19 +101,25 @@ def get_model_token_limit(
# Check exact match first
if model_name in EMBEDDING_MODEL_LIMITS:
return EMBEDDING_MODEL_LIMITS[model_name]
limit = EMBEDDING_MODEL_LIMITS[model_name]
_token_limit_cache[cache_key] = limit
return limit
# Check base name match
if base_model_name in EMBEDDING_MODEL_LIMITS:
return EMBEDDING_MODEL_LIMITS[base_model_name]
limit = EMBEDDING_MODEL_LIMITS[base_model_name]
_token_limit_cache[cache_key] = limit
return limit
# Check partial matches for common patterns
for known_model, limit in EMBEDDING_MODEL_LIMITS.items():
for known_model, registry_limit in EMBEDDING_MODEL_LIMITS.items():
if known_model in base_model_name or base_model_name in known_model:
return limit
_token_limit_cache[cache_key] = registry_limit
return registry_limit
# Default fallback
logger.warning(f"Unknown model '{model_name}', using default {default} token limit")
_token_limit_cache[cache_key] = default
return default
@@ -185,6 +220,91 @@ def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int]
return None
def _query_lmstudio_context_limit(model_name: str, base_url: str) -> Optional[int]:
"""
Query LM Studio SDK for model context length via Node.js subprocess.
Args:
model_name: Name of the LM Studio model
base_url: Base URL of the LM Studio server (WebSocket format, e.g., "ws://localhost:1234")
Returns:
Context limit in tokens if found, None otherwise
"""
# Inline JavaScript using @lmstudio/sdk
# Note: Load model temporarily for metadata, then unload to respect JIT auto-evict
js_code = f"""
const {{ LMStudioClient }} = require('@lmstudio/sdk');
(async () => {{
try {{
const client = new LMStudioClient({{ baseUrl: '{base_url}' }});
const model = await client.embedding.load('{model_name}', {{ verbose: false }});
const contextLength = await model.getContextLength();
await model.unload(); // Unload immediately to respect JIT auto-evict settings
console.log(JSON.stringify({{ contextLength, identifier: '{model_name}' }}));
}} catch (error) {{
console.error(JSON.stringify({{ error: error.message }}));
process.exit(1);
}}
}})();
"""
try:
# Set NODE_PATH to include global modules for @lmstudio/sdk resolution
env = os.environ.copy()
# Try to get npm global root (works with nvm, brew node, etc.)
try:
npm_root = subprocess.run(
["npm", "root", "-g"],
capture_output=True,
text=True,
timeout=5,
)
if npm_root.returncode == 0:
global_modules = npm_root.stdout.strip()
# Append to existing NODE_PATH if present
existing_node_path = env.get("NODE_PATH", "")
env["NODE_PATH"] = (
f"{global_modules}:{existing_node_path}"
if existing_node_path
else global_modules
)
except Exception:
# If npm not available, continue with existing NODE_PATH
pass
result = subprocess.run(
["node", "-e", js_code],
capture_output=True,
text=True,
timeout=10,
env=env,
)
if result.returncode != 0:
logger.debug(f"LM Studio SDK error: {result.stderr}")
return None
data = json.loads(result.stdout)
context_length = data.get("contextLength")
if context_length and context_length > 0:
logger.info(f"LM Studio SDK detected {model_name} context length: {context_length}")
return context_length
except FileNotFoundError:
logger.debug("Node.js not found - install Node.js for LM Studio SDK features")
except subprocess.TimeoutExpired:
logger.debug("LM Studio SDK query timeout")
except json.JSONDecodeError:
logger.debug("LM Studio SDK returned invalid JSON")
except Exception as e:
logger.debug(f"LM Studio SDK query failed: {e}")
return None
# Global model cache to avoid repeated loading
_model_cache: dict[str, Any] = {}
@@ -232,6 +352,7 @@ def compute_embeddings(
model_name,
base_url=provider_options.get("base_url"),
api_key=provider_options.get("api_key"),
provider_options=provider_options,
)
elif mode == "mlx":
return compute_embeddings_mlx(texts, model_name)
@@ -241,6 +362,7 @@ def compute_embeddings(
model_name,
is_build=is_build,
host=provider_options.get("host"),
provider_options=provider_options,
)
elif mode == "gemini":
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
@@ -579,6 +701,7 @@ def compute_embeddings_openai(
model_name: str,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Compute embeddings using OpenAI API"""
@@ -597,26 +720,37 @@ def compute_embeddings_openai(
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
)
resolved_base_url = resolve_openai_base_url(base_url)
resolved_api_key = resolve_openai_api_key(api_key)
# Extract base_url and api_key from provider_options if not provided directly
provider_options = provider_options or {}
effective_base_url = base_url or provider_options.get("base_url")
effective_api_key = api_key or provider_options.get("api_key")
resolved_base_url = resolve_openai_base_url(effective_base_url)
resolved_api_key = resolve_openai_api_key(effective_api_key)
if not resolved_api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
# Cache OpenAI client
cache_key = f"openai_client::{resolved_base_url}"
if cache_key in _model_cache:
client = _model_cache[cache_key]
else:
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
_model_cache[cache_key] = client
logger.info("OpenAI client cached")
# Create OpenAI client
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
logger.info(
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
)
print(f"len of texts: {len(texts)}")
# Apply prompt template if provided
prompt_template = provider_options.get("prompt_template")
if prompt_template:
logger.warning(f"Applying prompt template: '{prompt_template}'")
texts = [f"{prompt_template}{text}" for text in texts]
# Query token limit and apply truncation
token_limit = get_model_token_limit(model_name, base_url=effective_base_url)
logger.info(f"Using token limit: {token_limit} for model '{model_name}'")
texts = truncate_to_token_limit(texts, token_limit)
# OpenAI has limits on batch size and input length
max_batch_size = 800 # Conservative batch size because the token limit is 300K
all_embeddings = []
@@ -647,7 +781,15 @@ def compute_embeddings_openai(
try:
response = client.embeddings.create(model=model_name, input=batch_texts)
batch_embeddings = [embedding.embedding for embedding in response.data]
all_embeddings.extend(batch_embeddings)
# Verify we got the expected number of embeddings
if len(batch_embeddings) != len(batch_texts):
logger.warning(
f"Expected {len(batch_texts)} embeddings but got {len(batch_embeddings)}"
)
# Only take the number of embeddings that match the batch size
all_embeddings.extend(batch_embeddings[: len(batch_texts)])
except Exception as e:
logger.error(f"Batch {i} failed: {e}")
raise
@@ -737,6 +879,7 @@ def compute_embeddings_ollama(
model_name: str,
is_build: bool = False,
host: Optional[str] = None,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray:
"""
Compute embeddings using Ollama API with true batch processing.
@@ -749,6 +892,7 @@ def compute_embeddings_ollama(
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
is_build: Whether this is a build operation (shows progress bar)
host: Ollama host URL (defaults to environment or http://localhost:11434)
provider_options: Optional provider-specific options (e.g., prompt_template)
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
@@ -885,6 +1029,14 @@ def compute_embeddings_ollama(
logger.info(f"Using batch size: {batch_size} for true batch processing")
# Apply prompt template if provided
provider_options = provider_options or {}
prompt_template = provider_options.get("prompt_template")
if prompt_template:
logger.warning(f"Applying prompt template: '{prompt_template}'")
texts = [f"{prompt_template}{text}" for text in texts]
# Get model token limit and apply truncation before batching
token_limit = get_model_token_limit(model_name, base_url=resolved_host)
logger.info(f"Model '{model_name}' token limit: {token_limit}")

View File

@@ -77,6 +77,7 @@ class LeannBackendSearcherInterface(ABC):
query: str,
use_server_if_available: bool = True,
zmq_port: Optional[int] = None,
query_template: Optional[str] = None,
) -> np.ndarray:
"""Compute embedding for a query string
@@ -84,6 +85,7 @@ class LeannBackendSearcherInterface(ABC):
query: The query string to embed
zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first
query_template: Optional prompt template to prepend to query
Returns:
Query embedding as numpy array with shape (1, D)

View File

@@ -90,6 +90,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
query: str,
use_server_if_available: bool = True,
zmq_port: int = 5557,
query_template: Optional[str] = None,
) -> np.ndarray:
"""
Compute embedding for a query string.
@@ -98,10 +99,16 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
query: The query string to embed
zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first
query_template: Optional prompt template to prepend to query
Returns:
Query embedding as numpy array
"""
# Apply query template BEFORE any computation path
# This ensures template is applied consistently for both server and fallback paths
if query_template:
query = f"{query_template}{query}"
# Try to use embedding server if available and requested
if use_server_if_available:
try: