LEANN/tests/test_lmstudio_bridge.py
ww26 1ef9cba7de Feature/prompt templates and lmstudio sdk (#171)
* Add prompt template support and LM Studio SDK integration

Features:
- Prompt template support for embedding models (via --embedding-prompt-template)
- LM Studio SDK integration for automatic context length detection
- Hybrid token limit discovery (Ollama → LM Studio → Registry → Default)
- Client-side token truncation to prevent silent failures
- Automatic persistence of embedding_options to .meta.json
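The hybrid discovery order listed above (Ollama → LM Studio → Registry → Default) amounts to a first-hit-wins chain. This is a minimal sketch. It assumes each source is wrapped in a zero-argument probe that returns a positive limit or None. The function name, the registry contents, and the default of 512 are hypothetical, not LEANN's actual values:

```python
from typing import Callable, Iterable, Optional

# Hypothetical fallback value; LEANN's real default may differ.
DEFAULT_TOKEN_LIMIT = 512


def discover_token_limit(
    probes: Iterable[Callable[[], Optional[int]]],
    default: int = DEFAULT_TOKEN_LIMIT,
) -> int:
    """Return the first positive limit reported by the probes, else the default.

    Probes are tried in order (e.g. Ollama, LM Studio SDK, static registry) and
    are expected to swallow their own errors and return None when unavailable.
    """
    for probe in probes:
        limit = probe()
        if limit is not None and limit > 0:
            return limit
    return default


# Usage: Ollama and LM Studio are unreachable, so the registry entry wins.
registry = {"my-embedding-model": 2048}  # hypothetical registry entry
limit = discover_token_limit(
    [
        lambda: None,  # Ollama probe: service not running
        lambda: None,  # LM Studio SDK bridge: Node.js missing
        lambda: registry.get("my-embedding-model"),
    ]
)
print(limit)  # 2048
```

Treating 0 the same as None means a bogus zero from one backend falls through to the next source instead of disabling truncation.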

Implementation:
- Added _query_lmstudio_context_limit() with Node.js subprocess bridge
- Modified compute_embeddings_openai() to apply prompt templates before truncation
- Extended CLI with --embedding-prompt-template flag for build and search
- URL detection for LM Studio (port 1234 or lmstudio/lm.studio keywords)
- HTTP→WebSocket URL conversion for SDK compatibility
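The LM Studio URL detection and HTTP→WebSocket conversion described above could look roughly like the following sketch. The helper names are hypothetical. Dropping the path (such as an OpenAI-style `/v1` suffix) is an assumption, based on the SDK bridge being called with bare `ws://host:port` URLs in the tests:

```python
from urllib.parse import urlsplit, urlunsplit


def looks_like_lmstudio(base_url: str) -> bool:
    """Heuristic from the commit notes: port 1234, or an lmstudio/lm.studio keyword."""
    lowered = base_url.lower()
    if "lmstudio" in lowered or "lm.studio" in lowered:
        return True
    try:
        return urlsplit(base_url).port == 1234
    except ValueError:  # malformed port such as "host:abc"
        return False


def to_websocket_url(base_url: str) -> str:
    """Map http→ws and https→wss, keeping only scheme and host:port (assumption)."""
    parts = urlsplit(base_url)
    scheme = {"http": "ws", "https": "wss"}.get(parts.scheme, parts.scheme)
    return urlunsplit((scheme, parts.netloc, "", "", ""))


print(to_websocket_url("http://localhost:1234/v1"))  # ws://localhost:1234
```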

Tests:
- 60 passing tests across 5 test files
- Comprehensive coverage of prompt templates, LM Studio integration, and token handling
- Parametrized tests for maintainability and clarity

* Add integration tests and fix LM Studio SDK bridge

Features:
- End-to-end integration tests for prompt template with EmbeddingGemma
- Integration tests for hybrid token limit discovery mechanism
- Tests verify real-world functionality with live services (LM Studio, Ollama)

Fixes:
- LM Studio SDK bridge now uses client.embedding.load() for embedding models
- Fixed NODE_PATH resolution to include npm global modules
- Fixed integration test to use WebSocket URL (ws://) for SDK bridge

Tests:
- test_prompt_template_e2e.py: 8 integration tests covering:
  - Prompt template prepending with LM Studio (EmbeddingGemma)
  - LM Studio SDK bridge for context length detection
  - Ollama dynamic token limit detection
  - Hybrid discovery fallback mechanism (registry, default)
- All tests marked with @pytest.mark.integration for selective execution
- Tests gracefully skip when services unavailable
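A cheap TCP probe can decide whether to skip a test when its service is down. This sketch assumes that a successful socket connect is a good enough signal that LM Studio (default port 1234) or Ollama (default port 11434) is up. The helper name is hypothetical:

```python
import socket


def service_available(host: str, port: int, timeout: float = 0.5) -> bool:
    """Return True if something accepts TCP connections on host:port."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:  # refused, unreachable, or timed out
        return False
```

In a test module this pairs with the marker: stack `@pytest.mark.skipif(not service_available("localhost", 1234), reason="LM Studio not running")` under `@pytest.mark.integration`, and `pytest -m integration` selects only those tests.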

Documentation:
- Updated tests/README.md with integration test section
- Added prerequisites and running instructions
- Documented that prompt templates are ONLY for EmbeddingGemma
- Added integration marker to pyproject.toml

Test Results:
- All 8 integration tests passing with live services
- Confirmed prompt templates work correctly with EmbeddingGemma
- Verified LM Studio SDK bridge auto-detects context length (2048)
- Validated hybrid token limit discovery across all backends

* Add prompt template support to Ollama mode

Extends prompt template functionality from OpenAI mode to Ollama for backend consistency.

Changes:
- Add provider_options parameter to compute_embeddings_ollama()
- Apply prompt template before token truncation (lines 1005-1011)
- Pass provider_options through compute_embeddings() call chain
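Applying the template before truncation matters because the template's own tokens count toward the model's limit. This sketch assumes a whitespace tokenizer stands in for the model's real tokenizer. `prepare_texts` and its parameters are hypothetical:

```python
from typing import Callable, List, Optional


def prepare_texts(
    texts: List[str],
    token_limit: int,
    tokenize: Callable[[str], List[str]],
    detokenize: Callable[[List[str]], str],
    prompt_template: Optional[str] = None,
) -> List[str]:
    """Prepend the prompt template, then truncate each text to token_limit tokens."""
    prepared = []
    for text in texts:
        if prompt_template:
            text = prompt_template + text  # template first, so it is counted
        tokens = tokenize(text)
        if len(tokens) > token_limit:
            text = detokenize(tokens[:token_limit])
        prepared.append(text)
    return prepared


print(prepare_texts(["a b c d"], 3, str.split, " ".join, prompt_template="query: "))
# ['query: a b']
```

If you truncated first and prepended afterwards, the result would be `query: a b c`. That is four tokens, which overruns the limit the truncation was meant to enforce.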

Tests:
- test_ollama_embedding_with_prompt_template: Verifies templates work with Ollama
- test_ollama_prompt_template_affects_embeddings: Confirms embeddings differ with/without template
- Both tests pass with live Ollama service (2/2 passing)

Usage:
leann build --embedding-mode ollama --embedding-prompt-template "query: " ...

* Fix LM Studio SDK bridge to respect JIT auto-evict settings

Problem: SDK bridge called client.embedding.load() which loaded models into
LM Studio memory and bypassed JIT auto-evict settings, causing duplicate
model instances to accumulate.

Root cause analysis (from Perplexity research):
- Explicit SDK load() commands are treated as "pinned" models
- JIT auto-evict only applies to models loaded reactively via API requests
- SDK-loaded models remain in memory until explicitly unloaded

Solutions implemented:

1. Add model.unload() after metadata query (line 243)
   - Load model temporarily to get context length
   - Unload immediately to hand control back to JIT system
   - Subsequent API requests trigger JIT load with auto-evict

2. Add token limit caching to prevent repeated SDK calls
   - Cache discovered limits in _token_limit_cache dict (line 48)
   - Key: (model_name, base_url), Value: token_limit
   - Prevents duplicate load/unload cycles within same process
   - Cache shared across all discovery methods (Ollama, SDK, registry)
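The caching scheme described above uses a dict keyed by `(model_name, base_url)` and shared by every discovery method. It can be sketched as a memoizing wrapper. `_token_limit_cache` is the name given in these notes; the wrapper function and the `discover` callback are hypothetical:

```python
from typing import Callable, Dict, Tuple

# Process-wide cache: (model_name, base_url) -> token limit.
_token_limit_cache: Dict[Tuple[str, str], int] = {}


def get_token_limit(
    model_name: str, base_url: str, discover: Callable[[str, str], int]
) -> int:
    """Run discovery at most once per (model_name, base_url) in this process."""
    key = (model_name, base_url)
    if key not in _token_limit_cache:
        # Only a cache miss reaches the expensive path (e.g. an SDK load/unload).
        _token_limit_cache[key] = discover(model_name, base_url)
    return _token_limit_cache[key]
```

Including `base_url` in the key keeps two servers hosting a model with the same name from sharing a limit.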

Tests:
- TestTokenLimitCaching: 5 tests for cache behavior (integrated into test_token_truncation.py)
- Manual testing confirmed no duplicate models in LM Studio after fix
- All existing tests pass

Impact:
- Respects user's LM Studio JIT and auto-evict settings
- Reduces model memory footprint
- Faster subsequent builds (cached limits)

* Document prompt template and LM Studio SDK features

Added comprehensive documentation for new optional embedding features:

Configuration Guide (docs/configuration-guide.md):
- New section: "Optional Embedding Features"
- Task-Specific Prompt Templates subsection:
  - Explains EmbeddingGemma use case with document/query prompts
  - CLI and Python API examples
  - Clear warnings about compatible vs incompatible models
  - References to GitHub issue #155 and HuggingFace blog
- LM Studio Auto-Detection subsection:
  - Prerequisites (Node.js + @lmstudio/sdk)
  - How auto-detection works (4-step process)
  - Benefits and optional nature clearly stated

FAQ (docs/faq.md):
- FAQ #2: When should I use prompt templates?
  - DO/DON'T guidance with examples
  - Links to detailed configuration guide
- FAQ #3: Why is LM Studio loading multiple copies?
  - Explains the JIT auto-evict fix
  - Troubleshooting steps if still seeing issues
- FAQ #4: Do I need Node.js and @lmstudio/sdk?
  - Clarifies it's completely optional
  - Lists benefits if installed
  - Installation instructions

Cross-references between documents for easy navigation between quick reference and detailed guides.

* Add separate build/query template support for task-specific models

Task-specific models like EmbeddingGemma require different templates for indexing vs searching. Store both templates at build time and auto-apply query template during search with backward compatibility.

* Consolidate prompt template tests from 44 to 37 tests

Merged redundant no-op tests, removed low-value implementation tests, consolidated parameterized CLI tests, and removed hanging over-mocked test. All tests pass with improved focus on behavioral testing.

* Fix query template application in compute_query_embedding

Query templates were only applied in the fallback code path, not when using the embedding server (default path). This meant stored query templates in index metadata were ignored during MCP and CLI searches.

Changes:
- Move template application to before any computation path (searcher_base.py:109-110)
- Add comprehensive tests for both server and fallback paths
- Consolidate tests into test_prompt_template_persistence.py

Tests verify:
- Template applied when using embedding server
- Template applied in fallback path
- Consistent behavior between both paths
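The fix amounts to applying the template once, before the code chooses between the embedding server and the local fallback. The sketch below shows that ordering. `embed_query` and its callbacks are hypothetical stand-ins, not the signature of `compute_query_embedding` in searcher_base.py:

```python
from typing import Callable, List, Optional

Embedder = Callable[[str], List[float]]


def embed_query(
    query: str,
    query_template: Optional[str],
    use_server: bool,
    server_embed: Embedder,
    local_embed: Embedder,
) -> List[float]:
    """Embed a search query with its stored template on either code path."""
    # Apply the template before branching so both paths see identical text.
    if query_template:
        query = query_template + query
    embed = server_embed if use_server else local_embed
    return embed(query)
```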

* Apply ruff formatting and fix linting issues

- Remove unused imports
- Fix import ordering
- Remove unused variables
- Apply code formatting

* Fix CI test failures: mock OPENAI_API_KEY in tests

Tests were failing in CI because compute_embeddings_openai() checks for OPENAI_API_KEY before using the mocked client. Added monkeypatch to set fake API key in test fixture.
2025-11-14 15:25:17 -08:00


"""Unit tests for LM Studio TypeScript SDK bridge functionality.

This test suite defines the contract for the LM Studio SDK bridge that queries
model context length via Node.js subprocess. These tests verify:
1. Successful SDK query returns context length
2. Graceful fallback when Node.js not installed (FileNotFoundError)
3. Graceful fallback when SDK not installed (npm error)
4. Timeout handling (subprocess.TimeoutExpired)
5. Invalid JSON response handling

All tests are written in Red Phase - they should FAIL initially because the
`_query_lmstudio_context_limit` function does not exist yet.

The function contract:
- Inputs: model_name (str), base_url (str, WebSocket format "ws://localhost:1234")
- Outputs: context_length (int) or None on error
- Requirements:
  1. Call Node.js with inline JavaScript using @lmstudio/sdk
  2. 10-second timeout (accounts for Node.js startup)
  3. Graceful fallback on any error (returns None, doesn't raise)
  4. Parse JSON response with contextLength field
  5. Log errors at debug level (not warning/error)
"""

import subprocess
from unittest.mock import Mock

import pytest

# Try to import the function - if it doesn't exist, tests will fail as expected
try:
    from leann.embedding_compute import _query_lmstudio_context_limit
except ImportError:
    # Function doesn't exist yet (Red Phase) - create a placeholder that will fail
    def _query_lmstudio_context_limit(*args, **kwargs):
        raise NotImplementedError(
            "_query_lmstudio_context_limit not implemented yet - this is the Red Phase"
        )


class TestLMStudioBridge:
    """Tests for LM Studio TypeScript SDK bridge integration."""

    def test_query_lmstudio_success(self, monkeypatch):
        """Verify successful SDK query returns context length.

        When the Node.js subprocess successfully queries the LM Studio SDK,
        it should return a JSON response with contextLength field. The function
        should parse this and return the integer context length.
        """

        def mock_run(*args, **kwargs):
            # Verify timeout is set to 10 seconds
            assert kwargs.get("timeout") == 10, "Should use 10-second timeout for Node.js startup"
            # Verify capture_output and text=True are set
            assert kwargs.get("capture_output") is True, "Should capture stdout/stderr"
            assert kwargs.get("text") is True, "Should decode output as text"
            # Return successful JSON response
            mock_result = Mock()
            mock_result.returncode = 0
            mock_result.stdout = '{"contextLength": 8192, "identifier": "custom-model"}'
            mock_result.stderr = ""
            return mock_result

        monkeypatch.setattr("subprocess.run", mock_run)

        # Test with typical LM Studio model
        limit = _query_lmstudio_context_limit(
            model_name="custom-model", base_url="ws://localhost:1234"
        )
        assert limit == 8192, "Should return context length from SDK response"

    def test_query_lmstudio_nodejs_not_found(self, monkeypatch):
        """Verify graceful fallback when Node.js not installed.

        When Node.js is not installed, subprocess.run will raise FileNotFoundError.
        The function should catch this and return None (graceful fallback to registry).
        """

        def mock_run(*args, **kwargs):
            raise FileNotFoundError("node: command not found")

        monkeypatch.setattr("subprocess.run", mock_run)

        limit = _query_lmstudio_context_limit(
            model_name="custom-model", base_url="ws://localhost:1234"
        )
        assert limit is None, "Should return None when Node.js not installed"

    def test_query_lmstudio_sdk_not_installed(self, monkeypatch):
        """Verify graceful fallback when @lmstudio/sdk not installed.

        When the SDK npm package is not installed, Node.js will return non-zero
        exit code with error message in stderr. The function should detect this
        and return None.
        """

        def mock_run(*args, **kwargs):
            mock_result = Mock()
            mock_result.returncode = 1
            mock_result.stdout = ""
            mock_result.stderr = (
                "Error: Cannot find module '@lmstudio/sdk'\nRequire stack:\n- /path/to/script.js"
            )
            return mock_result

        monkeypatch.setattr("subprocess.run", mock_run)

        limit = _query_lmstudio_context_limit(
            model_name="custom-model", base_url="ws://localhost:1234"
        )
        assert limit is None, "Should return None when SDK not installed"

    def test_query_lmstudio_timeout(self, monkeypatch):
        """Verify graceful fallback when subprocess times out.

        When the Node.js process takes longer than 10 seconds (e.g., LM Studio
        not responding), subprocess.TimeoutExpired should be raised. The function
        should catch this and return None.
        """

        def mock_run(*args, **kwargs):
            raise subprocess.TimeoutExpired(cmd=["node", "lmstudio_bridge.js"], timeout=10)

        monkeypatch.setattr("subprocess.run", mock_run)

        limit = _query_lmstudio_context_limit(
            model_name="custom-model", base_url="ws://localhost:1234"
        )
        assert limit is None, "Should return None on timeout"

    def test_query_lmstudio_invalid_json(self, monkeypatch):
        """Verify graceful fallback when response is invalid JSON.

        When the subprocess returns malformed JSON (e.g., due to SDK error),
        json.loads will raise ValueError/JSONDecodeError. The function should
        catch this and return None.
        """

        def mock_run(*args, **kwargs):
            mock_result = Mock()
            mock_result.returncode = 0
            mock_result.stdout = "This is not valid JSON"
            mock_result.stderr = ""
            return mock_result

        monkeypatch.setattr("subprocess.run", mock_run)

        limit = _query_lmstudio_context_limit(
            model_name="custom-model", base_url="ws://localhost:1234"
        )
        assert limit is None, "Should return None when JSON parsing fails"

    def test_query_lmstudio_missing_context_length_field(self, monkeypatch):
        """Verify graceful fallback when JSON lacks contextLength field.

        When the SDK returns valid JSON but without the expected contextLength
        field (e.g., error response), the function should return None.
        """

        def mock_run(*args, **kwargs):
            mock_result = Mock()
            mock_result.returncode = 0
            mock_result.stdout = '{"identifier": "test-model", "error": "Model not found"}'
            mock_result.stderr = ""
            return mock_result

        monkeypatch.setattr("subprocess.run", mock_run)

        limit = _query_lmstudio_context_limit(
            model_name="nonexistent-model", base_url="ws://localhost:1234"
        )
        assert limit is None, "Should return None when contextLength field missing"

    def test_query_lmstudio_null_context_length(self, monkeypatch):
        """Verify graceful fallback when contextLength is null.

        When the SDK returns contextLength: null (model couldn't be loaded),
        the function should return None for registry fallback.
        """

        def mock_run(*args, **kwargs):
            mock_result = Mock()
            mock_result.returncode = 0
            mock_result.stdout = '{"contextLength": null, "identifier": "test-model"}'
            mock_result.stderr = ""
            return mock_result

        monkeypatch.setattr("subprocess.run", mock_run)

        limit = _query_lmstudio_context_limit(
            model_name="test-model", base_url="ws://localhost:1234"
        )
        assert limit is None, "Should return None when contextLength is null"

    def test_query_lmstudio_zero_context_length(self, monkeypatch):
        """Verify graceful fallback when contextLength is zero.

        When the SDK returns contextLength: 0 (invalid value), the function
        should return None to trigger registry fallback.
        """

        def mock_run(*args, **kwargs):
            mock_result = Mock()
            mock_result.returncode = 0
            mock_result.stdout = '{"contextLength": 0, "identifier": "test-model"}'
            mock_result.stderr = ""
            return mock_result

        monkeypatch.setattr("subprocess.run", mock_run)

        limit = _query_lmstudio_context_limit(
            model_name="test-model", base_url="ws://localhost:1234"
        )
        assert limit is None, "Should return None when contextLength is zero"

    def test_query_lmstudio_with_custom_port(self, monkeypatch):
        """Verify SDK query works with non-default WebSocket port.

        LM Studio can run on custom ports. The function should pass the
        provided base_url to the Node.js subprocess.
        """

        def mock_run(*args, **kwargs):
            # Verify the base_url argument is passed correctly
            command = args[0] if args else kwargs.get("args", [])
            assert "ws://localhost:8080" in " ".join(command), (
                "Should pass custom port to subprocess"
            )
            mock_result = Mock()
            mock_result.returncode = 0
            mock_result.stdout = '{"contextLength": 4096, "identifier": "custom-model"}'
            mock_result.stderr = ""
            return mock_result

        monkeypatch.setattr("subprocess.run", mock_run)

        limit = _query_lmstudio_context_limit(
            model_name="custom-model", base_url="ws://localhost:8080"
        )
        assert limit == 4096, "Should work with custom WebSocket port"

    @pytest.mark.parametrize(
        "context_length,expected",
        [
            (512, 512),  # Small context
            (2048, 2048),  # Common context
            (8192, 8192),  # Large context
            (32768, 32768),  # Very large context
        ],
    )
    def test_query_lmstudio_various_context_lengths(self, monkeypatch, context_length, expected):
        """Verify SDK query handles various context length values.

        Different models have different context lengths. The function should
        correctly parse and return any positive integer value.
        """

        def mock_run(*args, **kwargs):
            mock_result = Mock()
            mock_result.returncode = 0
            mock_result.stdout = f'{{"contextLength": {context_length}, "identifier": "test"}}'
            mock_result.stderr = ""
            return mock_result

        monkeypatch.setattr("subprocess.run", mock_run)

        limit = _query_lmstudio_context_limit(
            model_name="test-model", base_url="ws://localhost:1234"
        )
        assert limit == expected, f"Should return {expected} for context length {context_length}"

    def test_query_lmstudio_logs_at_debug_level(self, monkeypatch, caplog):
        """Verify errors are logged at DEBUG level, not WARNING/ERROR.

        Following the graceful fallback pattern from Ollama implementation,
        errors should be logged at debug level to avoid alarming users when
        fallback to registry works fine.
        """
        import logging

        caplog.set_level(logging.DEBUG, logger="leann.embedding_compute")

        def mock_run(*args, **kwargs):
            raise FileNotFoundError("node: command not found")

        monkeypatch.setattr("subprocess.run", mock_run)

        _query_lmstudio_context_limit(model_name="test-model", base_url="ws://localhost:1234")

        # Check that debug logging occurred (not warning/error)
        debug_logs = [record for record in caplog.records if record.levelname == "DEBUG"]
        assert len(debug_logs) > 0, "Should log error at DEBUG level"

        # Verify no WARNING or ERROR logs
        warning_or_error_logs = [
            record for record in caplog.records if record.levelname in ["WARNING", "ERROR"]
        ]
        assert len(warning_or_error_logs) == 0, (
            "Should not log at WARNING/ERROR level for expected failures"
        )
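For reference, here is one minimal way to satisfy the contract these tests pin down: a 10-second timeout, `capture_output=True`, `text=True`, the base URL passed on the command line, `None` on any failure, and DEBUG-level logging. It is a sketch, not LEANN's implementation, and the name omits the leading underscore to keep it separate. The inline JavaScript assumes these `@lmstudio/sdk` calls: `new LMStudioClient({ baseUrl })`, `client.embedding.load()`, `getContextLength()` and `unload()`:

```python
import json
import logging
import subprocess
from typing import Optional

logger = logging.getLogger(__name__)

# Assumed SDK usage; with `node -e`, process.argv[1..] holds the extra arguments.
_LMSTUDIO_SCRIPT = """
const { LMStudioClient } = require("@lmstudio/sdk");
(async () => {
  const [modelName, baseUrl] = process.argv.slice(1);
  const client = new LMStudioClient({ baseUrl });
  const model = await client.embedding.load(modelName);
  try {
    const contextLength = await model.getContextLength();
    console.log(JSON.stringify({ contextLength, identifier: modelName }));
  } finally {
    await model.unload();  // hand the model back to LM Studio's JIT auto-evict
  }
  process.exit(0);  // the open client connection would otherwise keep Node alive
})().catch((err) => {
  console.error(String(err));
  process.exit(1);
});
"""


def query_lmstudio_context_limit(model_name: str, base_url: str) -> Optional[int]:
    """Return the model's context length, or None so callers fall back to the registry."""
    try:
        result = subprocess.run(
            ["node", "-e", _LMSTUDIO_SCRIPT, model_name, base_url],
            capture_output=True,
            text=True,
            timeout=10,  # generous budget for Node.js startup
        )
    except (FileNotFoundError, subprocess.TimeoutExpired) as exc:
        logger.debug("LM Studio SDK bridge unavailable: %s", exc)
        return None
    if result.returncode != 0:
        logger.debug("LM Studio SDK bridge failed: %s", result.stderr.strip())
        return None
    try:
        payload = json.loads(result.stdout)
    except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
        logger.debug("LM Studio SDK bridge returned invalid JSON: %s", exc)
        return None
    context_length = payload.get("contextLength") if isinstance(payload, dict) else None
    # Reject missing, null, boolean, zero, or negative values.
    if isinstance(context_length, bool) or not isinstance(context_length, int):
        context_length = None
    if context_length is None or context_length <= 0:
        logger.debug("LM Studio SDK bridge returned no usable contextLength: %r", payload)
        return None
    return context_length
```

Logging at DEBUG rather than WARNING mirrors the test above. A missing Node.js install is an expected, recoverable condition, because the registry and default still supply a limit.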