Feature/prompt templates and lmstudio sdk (#171)
* Add prompt template support and LM Studio SDK integration

Features:
- Prompt template support for embedding models (via --embedding-prompt-template)
- LM Studio SDK integration for automatic context length detection
- Hybrid token limit discovery (Ollama → LM Studio → Registry → Default)
- Client-side token truncation to prevent silent failures
- Automatic persistence of embedding_options to .meta.json

Implementation:
- Added _query_lmstudio_context_limit() with Node.js subprocess bridge
- Modified compute_embeddings_openai() to apply prompt templates before truncation
- Extended CLI with --embedding-prompt-template flag for build and search
- URL detection for LM Studio (port 1234 or lmstudio/lm.studio keywords)
- HTTP→WebSocket URL conversion for SDK compatibility

Tests:
- 60 passing tests across 5 test files
- Comprehensive coverage of prompt templates, LM Studio integration, and token handling
- Parametrized tests for maintainability and clarity

* Add integration tests and fix LM Studio SDK bridge

Features:
- End-to-end integration tests for prompt templates with EmbeddingGemma
- Integration tests for the hybrid token limit discovery mechanism
- Tests verify real-world functionality against live services (LM Studio, Ollama)

Fixes:
- LM Studio SDK bridge now uses client.embedding.load() for embedding models
- Fixed NODE_PATH resolution to include npm global modules
- Fixed integration test to use a WebSocket URL (ws://) for the SDK bridge

Tests:
- test_prompt_template_e2e.py: 8 integration tests covering:
  - Prompt template prepending with LM Studio (EmbeddingGemma)
  - LM Studio SDK bridge for context length detection
  - Ollama dynamic token limit detection
  - Hybrid discovery fallback mechanism (registry, default)
- All tests marked with @pytest.mark.integration for selective execution
- Tests gracefully skip when services are unavailable

Documentation:
- Updated tests/README.md with an integration test section
- Added prerequisites and running instructions
- Documented that prompt templates are ONLY for EmbeddingGemma
- Added integration marker to pyproject.toml

Test results:
- All 8 integration tests passing with live services
- Confirmed prompt templates work correctly with EmbeddingGemma
- Verified LM Studio SDK bridge auto-detects context length (2048)
- Validated hybrid token limit discovery across all backends

* Add prompt template support to Ollama mode

Extends prompt template functionality from OpenAI mode to Ollama for backend consistency.

Changes:
- Add provider_options parameter to compute_embeddings_ollama()
- Apply prompt template before token truncation (lines 1005-1011)
- Pass provider_options through the compute_embeddings() call chain

Tests:
- test_ollama_embedding_with_prompt_template: verifies templates work with Ollama
- test_ollama_prompt_template_affects_embeddings: confirms embeddings differ with/without a template
- Both tests pass with a live Ollama service (2/2 passing)

Usage:
leann build --embedding-mode ollama --embedding-prompt-template "query: " ...
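As a concrete illustration of the Ollama path above, a minimal sketch; the import path and model name are assumptions, and the call shape mirrors the compute_embeddings_ollama() changes in the diff below rather than documenting a guaranteed public API:

from leann.embedding_compute import compute_embeddings_ollama  # assumed module path

# "query: " is prepended to each text before token truncation and batching,
# matching the behavior already in place for the OpenAI path.
embeddings = compute_embeddings_ollama(
    ["how do I rebuild the index?"],
    "nomic-embed-text",
    provider_options={"prompt_template": "query: "},
)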
* Fix LM Studio SDK bridge to respect JIT auto-evict settings

Problem: the SDK bridge called client.embedding.load(), which loaded models into LM Studio memory and bypassed JIT auto-evict settings, causing duplicate model instances to accumulate.

Root cause analysis (from Perplexity research):
- Explicit SDK load() commands are treated as "pinned" models
- JIT auto-evict only applies to models loaded reactively via API requests
- SDK-loaded models remain in memory until explicitly unloaded

Solutions implemented:
1. Add model.unload() after the metadata query (line 243)
   - Load the model temporarily to read its context length
   - Unload immediately to hand control back to the JIT system
   - Subsequent API requests trigger a JIT load with auto-evict
2. Add token limit caching to prevent repeated SDK calls
   - Cache discovered limits in the _token_limit_cache dict (line 48)
   - Key: (model_name, base_url); value: token_limit
   - Prevents duplicate load/unload cycles within the same process
   - Cache shared across all discovery methods (Ollama, SDK, registry)

Tests:
- TestTokenLimitCaching: 5 tests for cache behavior (integrated into test_token_truncation.py)
- Manual testing confirmed no duplicate models in LM Studio after the fix
- All existing tests pass

Impact:
- Respects the user's LM Studio JIT and auto-evict settings
- Reduces model memory footprint
- Faster subsequent builds (cached limits)

* Document prompt template and LM Studio SDK features

Added comprehensive documentation for the new optional embedding features.

Configuration guide (docs/configuration-guide.md):
- New section: "Optional Embedding Features"
- Task-Specific Prompt Templates subsection:
  - Explains the EmbeddingGemma use case with document/query prompts
  - CLI and Python API examples
  - Clear warnings about compatible vs. incompatible models
  - References to GitHub issue #155 and the HuggingFace blog
- LM Studio Auto-Detection subsection:
  - Prerequisites (Node.js + @lmstudio/sdk)
  - How auto-detection works (4-step process)
  - Benefits and optional nature clearly stated

FAQ (docs/faq.md):
- FAQ #2: When should I use prompt templates?
  - DO/DON'T guidance with examples
  - Links to the detailed configuration guide
- FAQ #3: Why is LM Studio loading multiple copies?
  - Explains the JIT auto-evict fix
  - Troubleshooting steps if issues persist
- FAQ #4: Do I need Node.js and @lmstudio/sdk?
  - Clarifies that they are completely optional
  - Lists the benefits if installed
  - Installation instructions

Cross-references between the documents allow easy navigation between the quick reference and the detailed guides.

* Add separate build/query template support for task-specific models

Task-specific models like EmbeddingGemma require different templates for indexing vs. searching. Store both templates at build time and auto-apply the query template during search, with backward compatibility for the old single-template format.

* Consolidate prompt template tests from 44 to 37 tests

Merged redundant no-op tests, removed low-value implementation tests, consolidated parameterized CLI tests, and removed a hanging over-mocked test. All tests pass with an improved focus on behavioral testing.

* Fix query template application in compute_query_embedding

Query templates were only applied in the fallback code path, not when using the embedding server (the default path). This meant stored query templates in index metadata were ignored during MCP and CLI searches.

Changes:
- Move template application to before any computation path (searcher_base.py:109-110)
- Add comprehensive tests for both the server and fallback paths
- Consolidate tests into test_prompt_template_persistence.py

Tests verify:
- Template applied when using the embedding server
- Template applied in the fallback path
- Consistent behavior between both paths

* Apply ruff formatting and fix linting issues

- Remove unused imports
- Fix import ordering
- Remove unused variables
- Apply code formatting

* Fix CI test failures: mock OPENAI_API_KEY in tests

Tests were failing in CI because compute_embeddings_openai() checks for OPENAI_API_KEY before using the mocked client. Added a monkeypatch to set a fake API key in the test fixture.
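For reference, a sketch of the two embedding_options shapes that end up in the index's .meta.json, as wired up in the CLI changes in the diff below; the template strings are illustrative and any other metadata fields are omitted:

# New format: separate build/query templates for task-specific models such as EmbeddingGemma
embedding_options = {
    "build_prompt_template": "title: none | text: ",
    "query_prompt_template": "task: search result | query: ",
}

# Old format, still accepted for backward compatibility:
# one template used at both build and search time
embedding_options = {"prompt_template": "query: "}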
@@ -916,6 +916,7 @@ class LeannSearcher:
        metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
        batch_size: int = 0,
        use_grep: bool = False,
        provider_options: Optional[dict[str, Any]] = None,
        **kwargs,
    ) -> list[SearchResult]:
        """
@@ -979,10 +980,24 @@ class LeannSearcher:

        start_time = time.time()

        # Extract query template from stored embedding_options with fallback chain:
        # 1. Check provider_options override (highest priority)
        # 2. Check query_prompt_template (new format)
        # 3. Check prompt_template (old format for backward compat)
        # 4. None (no template)
        query_template = None
        if provider_options and "prompt_template" in provider_options:
            query_template = provider_options["prompt_template"]
        elif "query_prompt_template" in self.embedding_options:
            query_template = self.embedding_options["query_prompt_template"]
        elif "prompt_template" in self.embedding_options:
            query_template = self.embedding_options["prompt_template"]

        query_embedding = self.backend_impl.compute_query_embedding(
            query,
            use_server_if_available=recompute_embeddings,
            zmq_port=zmq_port,
            query_template=query_template,
        )
        logger.info(f" Generated embedding shape: {query_embedding.shape}")
        embedding_time = time.time() - start_time
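To make the fallback order above concrete, a small worked example with hypothetical stored values and override:

stored_options = {
    "build_prompt_template": "title: none | text: ",          # illustrative values
    "query_prompt_template": "task: search result | query: ",
}
override = {"prompt_template": "override: "}

# With an override passed to search():        query_template == "override: "
# Without an override (stored options only):  query_template == "task: search result | query: "
# With only the legacy "prompt_template" key: that single template is reused for queries
# With none of the above:                     query_template stays None, query embedded as-is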
@@ -144,6 +144,18 @@ Examples:
        default=None,
        help="API key for embedding service (defaults to OPENAI_API_KEY)",
    )
    build_parser.add_argument(
        "--embedding-prompt-template",
        type=str,
        default=None,
        help="Prompt template to prepend to all texts for embedding (e.g., 'query: ' for search)",
    )
    build_parser.add_argument(
        "--query-prompt-template",
        type=str,
        default=None,
        help="Prompt template for queries (different from build template for task-specific models)",
    )
    build_parser.add_argument(
        "--force", "-f", action="store_true", help="Force rebuild existing index"
    )
@@ -260,6 +272,12 @@ Examples:
        action="store_true",
        help="Display file paths and metadata in search results",
    )
    search_parser.add_argument(
        "--embedding-prompt-template",
        type=str,
        default=None,
        help="Prompt template to prepend to query for embedding (e.g., 'query: ' for search)",
    )

    # Ask command
    ask_parser = subparsers.add_parser("ask", help="Ask questions")
@@ -1398,6 +1416,14 @@ Examples:
    resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
    if resolved_embedding_key:
        embedding_options["api_key"] = resolved_embedding_key
    if args.query_prompt_template:
        # New format: separate templates
        if args.embedding_prompt_template:
            embedding_options["build_prompt_template"] = args.embedding_prompt_template
        embedding_options["query_prompt_template"] = args.query_prompt_template
    elif args.embedding_prompt_template:
        # Old format: single template (backward compat)
        embedding_options["prompt_template"] = args.embedding_prompt_template

    builder = LeannBuilder(
        backend_name=args.backend_name,
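As a quick sanity check on the flag-to-options mapping above, a self-contained argparse sketch; it mirrors the logic in the diff but is not the project's actual CLI module, and the flag values are illustrative:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--embedding-prompt-template", type=str, default=None)
parser.add_argument("--query-prompt-template", type=str, default=None)

args = parser.parse_args([
    "--embedding-prompt-template", "title: none | text: ",
    "--query-prompt-template", "task: search result | query: ",
])

embedding_options = {}
if args.query_prompt_template:
    # New format: separate templates
    if args.embedding_prompt_template:
        embedding_options["build_prompt_template"] = args.embedding_prompt_template
    embedding_options["query_prompt_template"] = args.query_prompt_template
elif args.embedding_prompt_template:
    # Old format: single template (backward compat)
    embedding_options["prompt_template"] = args.embedding_prompt_template

print(embedding_options)
# {'build_prompt_template': 'title: none | text: ', 'query_prompt_template': 'task: search result | query: '}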
@@ -4,8 +4,10 @@ Consolidates all embedding computation logic using SentenceTransformer
Preserves all optimization parameters to ensure performance
"""

import json
import logging
import os
import subprocess
import time
from typing import Any, Optional

@@ -40,6 +42,11 @@ EMBEDDING_MODEL_LIMITS = {
    "text-embedding-ada-002": 8192,
}

# Runtime cache for dynamically discovered token limits
# Key: (model_name, base_url), Value: token_limit
# Prevents repeated SDK/API calls for the same model
_token_limit_cache: dict[tuple[str, str], int] = {}


def get_model_token_limit(
    model_name: str,
@@ -49,6 +56,7 @@ def get_model_token_limit(
    """
    Get token limit for a given embedding model.
    Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others.
    Caches discovered limits to prevent repeated API/SDK calls.

    Args:
        model_name: Name of the embedding model
@@ -58,12 +66,33 @@ def get_model_token_limit(
    Returns:
        Token limit for the model in tokens
    """
    # Check cache first to avoid repeated SDK/API calls
    cache_key = (model_name, base_url or "")
    if cache_key in _token_limit_cache:
        cached_limit = _token_limit_cache[cache_key]
        logger.debug(f"Using cached token limit for {model_name}: {cached_limit}")
        return cached_limit

    # Try Ollama dynamic discovery if base_url provided
    if base_url:
        # Detect Ollama servers by port or "ollama" in URL
        if "11434" in base_url or "ollama" in base_url.lower():
            limit = _query_ollama_context_limit(model_name, base_url)
            if limit:
                _token_limit_cache[cache_key] = limit
                return limit

        # Try LM Studio SDK discovery
        if "1234" in base_url or "lmstudio" in base_url.lower() or "lm.studio" in base_url.lower():
            # Convert HTTP to WebSocket URL
            ws_url = base_url.replace("https://", "wss://").replace("http://", "ws://")
            # Remove /v1 suffix if present
            if ws_url.endswith("/v1"):
                ws_url = ws_url[:-3]

            limit = _query_lmstudio_context_limit(model_name, ws_url)
            if limit:
                _token_limit_cache[cache_key] = limit
                return limit

    # Fallback to known model registry with version handling (from PR #154)
@@ -72,19 +101,25 @@ def get_model_token_limit(

    # Check exact match first
    if model_name in EMBEDDING_MODEL_LIMITS:
        return EMBEDDING_MODEL_LIMITS[model_name]
        limit = EMBEDDING_MODEL_LIMITS[model_name]
        _token_limit_cache[cache_key] = limit
        return limit

    # Check base name match
    if base_model_name in EMBEDDING_MODEL_LIMITS:
        return EMBEDDING_MODEL_LIMITS[base_model_name]
        limit = EMBEDDING_MODEL_LIMITS[base_model_name]
        _token_limit_cache[cache_key] = limit
        return limit

    # Check partial matches for common patterns
    for known_model, limit in EMBEDDING_MODEL_LIMITS.items():
    for known_model, registry_limit in EMBEDDING_MODEL_LIMITS.items():
        if known_model in base_model_name or base_model_name in known_model:
            return limit
            _token_limit_cache[cache_key] = registry_limit
            return registry_limit

    # Default fallback
    logger.warning(f"Unknown model '{model_name}', using default {default} token limit")
    _token_limit_cache[cache_key] = default
    return default
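A minimal sketch of the caching behavior added above; the import path is an assumption, and a registry model is used so no live Ollama or LM Studio server is needed (the second call is a pure cache hit):

from leann.embedding_compute import _token_limit_cache, get_model_token_limit  # assumed module path

first = get_model_token_limit("text-embedding-ada-002")   # registry lookup (8192), result cached
second = get_model_token_limit("text-embedding-ada-002")  # served from _token_limit_cache

assert first == second == 8192
assert ("text-embedding-ada-002", "") in _token_limit_cache  # key is (model_name, base_url or "")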
@@ -185,6 +220,91 @@ def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int]
    return None


def _query_lmstudio_context_limit(model_name: str, base_url: str) -> Optional[int]:
    """
    Query LM Studio SDK for model context length via Node.js subprocess.

    Args:
        model_name: Name of the LM Studio model
        base_url: Base URL of the LM Studio server (WebSocket format, e.g., "ws://localhost:1234")

    Returns:
        Context limit in tokens if found, None otherwise
    """
    # Inline JavaScript using @lmstudio/sdk
    # Note: Load model temporarily for metadata, then unload to respect JIT auto-evict
    js_code = f"""
    const {{ LMStudioClient }} = require('@lmstudio/sdk');
    (async () => {{
        try {{
            const client = new LMStudioClient({{ baseUrl: '{base_url}' }});
            const model = await client.embedding.load('{model_name}', {{ verbose: false }});
            const contextLength = await model.getContextLength();
            await model.unload(); // Unload immediately to respect JIT auto-evict settings
            console.log(JSON.stringify({{ contextLength, identifier: '{model_name}' }}));
        }} catch (error) {{
            console.error(JSON.stringify({{ error: error.message }}));
            process.exit(1);
        }}
    }})();
    """

    try:
        # Set NODE_PATH to include global modules for @lmstudio/sdk resolution
        env = os.environ.copy()

        # Try to get npm global root (works with nvm, brew node, etc.)
        try:
            npm_root = subprocess.run(
                ["npm", "root", "-g"],
                capture_output=True,
                text=True,
                timeout=5,
            )
            if npm_root.returncode == 0:
                global_modules = npm_root.stdout.strip()
                # Append to existing NODE_PATH if present
                existing_node_path = env.get("NODE_PATH", "")
                env["NODE_PATH"] = (
                    f"{global_modules}:{existing_node_path}"
                    if existing_node_path
                    else global_modules
                )
        except Exception:
            # If npm not available, continue with existing NODE_PATH
            pass

        result = subprocess.run(
            ["node", "-e", js_code],
            capture_output=True,
            text=True,
            timeout=10,
            env=env,
        )

        if result.returncode != 0:
            logger.debug(f"LM Studio SDK error: {result.stderr}")
            return None

        data = json.loads(result.stdout)
        context_length = data.get("contextLength")

        if context_length and context_length > 0:
            logger.info(f"LM Studio SDK detected {model_name} context length: {context_length}")
            return context_length

    except FileNotFoundError:
        logger.debug("Node.js not found - install Node.js for LM Studio SDK features")
    except subprocess.TimeoutExpired:
        logger.debug("LM Studio SDK query timeout")
    except json.JSONDecodeError:
        logger.debug("LM Studio SDK returned invalid JSON")
    except Exception as e:
        logger.debug(f"LM Studio SDK query failed: {e}")

    return None


# Global model cache to avoid repeated loading
_model_cache: dict[str, Any] = {}

@@ -232,6 +352,7 @@ def compute_embeddings(
            model_name,
            base_url=provider_options.get("base_url"),
            api_key=provider_options.get("api_key"),
            provider_options=provider_options,
        )
    elif mode == "mlx":
        return compute_embeddings_mlx(texts, model_name)
@@ -241,6 +362,7 @@ def compute_embeddings(
            model_name,
            is_build=is_build,
            host=provider_options.get("host"),
            provider_options=provider_options,
        )
    elif mode == "gemini":
        return compute_embeddings_gemini(texts, model_name, is_build=is_build)
@@ -579,6 +701,7 @@ def compute_embeddings_openai(
    model_name: str,
    base_url: Optional[str] = None,
    api_key: Optional[str] = None,
    provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray:
    # TODO: @yichuan-w add progress bar only in build mode
    """Compute embeddings using OpenAI API"""
@@ -597,26 +720,37 @@ def compute_embeddings_openai(
            f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
        )

    resolved_base_url = resolve_openai_base_url(base_url)
    resolved_api_key = resolve_openai_api_key(api_key)
    # Extract base_url and api_key from provider_options if not provided directly
    provider_options = provider_options or {}
    effective_base_url = base_url or provider_options.get("base_url")
    effective_api_key = api_key or provider_options.get("api_key")

    resolved_base_url = resolve_openai_base_url(effective_base_url)
    resolved_api_key = resolve_openai_api_key(effective_api_key)

    if not resolved_api_key:
        raise RuntimeError("OPENAI_API_KEY environment variable not set")

    # Cache OpenAI client
    cache_key = f"openai_client::{resolved_base_url}"
    if cache_key in _model_cache:
        client = _model_cache[cache_key]
    else:
        client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
        _model_cache[cache_key] = client
        logger.info("OpenAI client cached")
    # Create OpenAI client
    client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)

    logger.info(
        f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
    )
    print(f"len of texts: {len(texts)}")

    # Apply prompt template if provided
    prompt_template = provider_options.get("prompt_template")

    if prompt_template:
        logger.warning(f"Applying prompt template: '{prompt_template}'")
        texts = [f"{prompt_template}{text}" for text in texts]

    # Query token limit and apply truncation
    token_limit = get_model_token_limit(model_name, base_url=effective_base_url)
    logger.info(f"Using token limit: {token_limit} for model '{model_name}'")
    texts = truncate_to_token_limit(texts, token_limit)

    # OpenAI has limits on batch size and input length
    max_batch_size = 800  # Conservative batch size because the token limit is 300K
    all_embeddings = []
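The ordering above matters: the template is prepended first so that truncation budgets for the prefix and long documents cannot silently drop it. A condensed sketch of those two steps using the helpers from this module; the model name and template string are illustrative:

prompt_template = "query: "                     # illustrative
texts = ["a very long document " * 2000]

# 1) attach the template prefix
texts = [f"{prompt_template}{text}" for text in texts]

# 2) then truncate to the discovered limit, so the prefix always survives
token_limit = get_model_token_limit("text-embedding-ada-002")  # 8192 from the registry above
texts = truncate_to_token_limit(texts, token_limit)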
@@ -647,7 +781,15 @@ def compute_embeddings_openai(
        try:
            response = client.embeddings.create(model=model_name, input=batch_texts)
            batch_embeddings = [embedding.embedding for embedding in response.data]
            all_embeddings.extend(batch_embeddings)

            # Verify we got the expected number of embeddings
            if len(batch_embeddings) != len(batch_texts):
                logger.warning(
                    f"Expected {len(batch_texts)} embeddings but got {len(batch_embeddings)}"
                )

            # Only take the number of embeddings that match the batch size
            all_embeddings.extend(batch_embeddings[: len(batch_texts)])
        except Exception as e:
            logger.error(f"Batch {i} failed: {e}")
            raise
@@ -737,6 +879,7 @@ def compute_embeddings_ollama(
    model_name: str,
    is_build: bool = False,
    host: Optional[str] = None,
    provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray:
    """
    Compute embeddings using Ollama API with true batch processing.
@@ -749,6 +892,7 @@ def compute_embeddings_ollama(
        model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
        is_build: Whether this is a build operation (shows progress bar)
        host: Ollama host URL (defaults to environment or http://localhost:11434)
        provider_options: Optional provider-specific options (e.g., prompt_template)

    Returns:
        Normalized embeddings array, shape: (len(texts), embedding_dim)
@@ -885,6 +1029,14 @@ def compute_embeddings_ollama(

    logger.info(f"Using batch size: {batch_size} for true batch processing")

    # Apply prompt template if provided
    provider_options = provider_options or {}
    prompt_template = provider_options.get("prompt_template")

    if prompt_template:
        logger.warning(f"Applying prompt template: '{prompt_template}'")
        texts = [f"{prompt_template}{text}" for text in texts]

    # Get model token limit and apply truncation before batching
    token_limit = get_model_token_limit(model_name, base_url=resolved_host)
    logger.info(f"Model '{model_name}' token limit: {token_limit}")

@@ -77,6 +77,7 @@ class LeannBackendSearcherInterface(ABC):
        query: str,
        use_server_if_available: bool = True,
        zmq_port: Optional[int] = None,
        query_template: Optional[str] = None,
    ) -> np.ndarray:
        """Compute embedding for a query string

@@ -84,6 +85,7 @@ class LeannBackendSearcherInterface(ABC):
            query: The query string to embed
            zmq_port: ZMQ port for embedding server
            use_server_if_available: Whether to try using embedding server first
            query_template: Optional prompt template to prepend to query

        Returns:
            Query embedding as numpy array with shape (1, D)

@@ -90,6 +90,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
        query: str,
        use_server_if_available: bool = True,
        zmq_port: int = 5557,
        query_template: Optional[str] = None,
    ) -> np.ndarray:
        """
        Compute embedding for a query string.
@@ -98,10 +99,16 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
            query: The query string to embed
            zmq_port: ZMQ port for embedding server
            use_server_if_available: Whether to try using embedding server first
            query_template: Optional prompt template to prepend to query

        Returns:
            Query embedding as numpy array
        """
        # Apply query template BEFORE any computation path
        # This ensures template is applied consistently for both server and fallback paths
        if query_template:
            query = f"{query_template}{query}"

        # Try to use embedding server if available and requested
        if use_server_if_available:
            try:
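Tying it together at the call site, a hedged usage sketch of the per-query override; it assumes an existing LeannSearcher instance named searcher, omits search()'s other arguments, and uses an illustrative template string:

# provider_options takes precedence over any template stored in the index metadata,
# per the fallback chain in the LeannSearcher.search() hunk above.
results = searcher.search(
    "how do I rebuild the index?",
    provider_options={"prompt_template": "task: search result | query: "},
)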