* Add prompt template support and LM Studio SDK integration

  Features:
  - Prompt template support for embedding models (via --embedding-prompt-template)
  - LM Studio SDK integration for automatic context length detection
  - Hybrid token limit discovery (Ollama → LM Studio → Registry → Default)
  - Client-side token truncation to prevent silent failures
  - Automatic persistence of embedding_options to .meta.json

  Implementation:
  - Added _query_lmstudio_context_limit() with Node.js subprocess bridge
  - Modified compute_embeddings_openai() to apply prompt templates before truncation
  - Extended CLI with --embedding-prompt-template flag for build and search
  - URL detection for LM Studio (port 1234 or lmstudio/lm.studio keywords)
  - HTTP→WebSocket URL conversion for SDK compatibility

  Tests:
  - 60 passing tests across 5 test files
  - Comprehensive coverage of prompt templates, LM Studio integration, and token handling
  - Parametrized tests for maintainability and clarity

* Add integration tests and fix LM Studio SDK bridge

  Features:
  - End-to-end integration tests for prompt template with EmbeddingGemma
  - Integration tests for hybrid token limit discovery mechanism
  - Tests verify real-world functionality with live services (LM Studio, Ollama)

  Fixes:
  - LM Studio SDK bridge now uses client.embedding.load() for embedding models
  - Fixed NODE_PATH resolution to include npm global modules
  - Fixed integration test to use WebSocket URL (ws://) for SDK bridge

  Tests:
  - test_prompt_template_e2e.py: 8 integration tests covering:
    - Prompt template prepending with LM Studio (EmbeddingGemma)
    - LM Studio SDK bridge for context length detection
    - Ollama dynamic token limit detection
    - Hybrid discovery fallback mechanism (registry, default)
  - All tests marked with @pytest.mark.integration for selective execution
  - Tests gracefully skip when services unavailable

  Documentation:
  - Updated tests/README.md with integration test section
  - Added prerequisites and running instructions
  - Documented that prompt templates are ONLY for EmbeddingGemma
  - Added integration marker to pyproject.toml

  Test Results:
  - All 8 integration tests passing with live services
  - Confirmed prompt templates work correctly with EmbeddingGemma
  - Verified LM Studio SDK bridge auto-detects context length (2048)
  - Validated hybrid token limit discovery across all backends

* Add prompt template support to Ollama mode

  Extends prompt template functionality from OpenAI mode to Ollama for backend consistency.

  Changes:
  - Add provider_options parameter to compute_embeddings_ollama()
  - Apply prompt template before token truncation (lines 1005-1011)
  - Pass provider_options through compute_embeddings() call chain

  Tests:
  - test_ollama_embedding_with_prompt_template: Verifies templates work with Ollama
  - test_ollama_prompt_template_affects_embeddings: Confirms embeddings differ with/without template
  - Both tests pass with live Ollama service (2/2 passing)

  Usage: leann build --embedding-mode ollama --embedding-prompt-template "query: " ...

* Fix LM Studio SDK bridge to respect JIT auto-evict settings

  Problem: SDK bridge called client.embedding.load() which loaded models into LM Studio memory
  and bypassed JIT auto-evict settings, causing duplicate model instances to accumulate.

  Root cause analysis (from Perplexity research):
  - Explicit SDK load() commands are treated as "pinned" models
  - JIT auto-evict only applies to models loaded reactively via API requests
  - SDK-loaded models remain in memory until explicitly unloaded

  Solutions implemented:
  1. Add model.unload() after metadata query (line 243)
     - Load model temporarily to get context length
     - Unload immediately to hand control back to JIT system
     - Subsequent API requests trigger JIT load with auto-evict
  2. Add token limit caching to prevent repeated SDK calls
     - Cache discovered limits in _token_limit_cache dict (line 48)
     - Key: (model_name, base_url), Value: token_limit
     - Prevents duplicate load/unload cycles within same process
     - Cache shared across all discovery methods (Ollama, SDK, registry)

  Tests:
  - TestTokenLimitCaching: 5 tests for cache behavior (integrated into test_token_truncation.py)
  - Manual testing confirmed no duplicate models in LM Studio after fix
  - All existing tests pass

  Impact:
  - Respects user's LM Studio JIT and auto-evict settings
  - Reduces model memory footprint
  - Faster subsequent builds (cached limits)

* Document prompt template and LM Studio SDK features

  Added comprehensive documentation for new optional embedding features.

  Configuration Guide (docs/configuration-guide.md):
  - New section: "Optional Embedding Features"
  - Task-Specific Prompt Templates subsection:
    - Explains EmbeddingGemma use case with document/query prompts
    - CLI and Python API examples
    - Clear warnings about compatible vs incompatible models
    - References to GitHub issue #155 and HuggingFace blog
  - LM Studio Auto-Detection subsection:
    - Prerequisites (Node.js + @lmstudio/sdk)
    - How auto-detection works (4-step process)
    - Benefits and optional nature clearly stated

  FAQ (docs/faq.md):
  - FAQ #2: When should I use prompt templates?
    - DO/DON'T guidance with examples
    - Links to detailed configuration guide
  - FAQ #3: Why is LM Studio loading multiple copies?
    - Explains the JIT auto-evict fix
    - Troubleshooting steps if still seeing issues
  - FAQ #4: Do I need Node.js and @lmstudio/sdk?
    - Clarifies it's completely optional
    - Lists benefits if installed
    - Installation instructions

  Cross-references between documents for easy navigation between quick reference and detailed guides.

* Add separate build/query template support for task-specific models

  Task-specific models like EmbeddingGemma require different templates for indexing vs searching.
  Store both templates at build time and auto-apply query template during search with backward
  compatibility.

* Consolidate prompt template tests from 44 to 37 tests

  Merged redundant no-op tests, removed low-value implementation tests, consolidated
  parameterized CLI tests, and removed hanging over-mocked test. All tests pass with improved
  focus on behavioral testing.

* Fix query template application in compute_query_embedding

  Query templates were only applied in the fallback code path, not when using the embedding
  server (default path). This meant stored query templates in index metadata were ignored
  during MCP and CLI searches.

  Changes:
  - Move template application to before any computation path (searcher_base.py:109-110)
  - Add comprehensive tests for both server and fallback paths
  - Consolidate tests into test_prompt_template_persistence.py

  Tests verify:
  - Template applied when using embedding server
  - Template applied in fallback path
  - Consistent behavior between both paths

* Apply ruff formatting and fix linting issues

  - Remove unused imports
  - Fix import ordering
  - Remove unused variables
  - Apply code formatting

* Fix CI test failures: mock OPENAI_API_KEY in tests

  Tests were failing in CI because compute_embeddings_openai() checks for OPENAI_API_KEY
  before using the mocked client. Added monkeypatch to set fake API key in test fixture.
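  For orientation, the hybrid discovery and caching described in the commits above resolve a
  model's token limit in a fixed order. The sketch below is illustrative only, not the actual
  leann implementation; the helper names (_query_ollama_context_limit,
  _query_lmstudio_context_limit, EMBEDDING_MODEL_LIMITS, _token_limit_cache) are the ones the
  tests in this file reference.

      # Illustrative sketch only -- not the real leann.embedding_compute code.
      def get_model_token_limit(model_name, base_url="", default=2048):
          key = (model_name, base_url)
          if key in _token_limit_cache:  # cached from an earlier call in this process
              return _token_limit_cache[key]

          limit = None
          url = base_url.lower()
          if ":11434" in url or "ollama" in url:
              # 1. Ollama dynamic discovery (/api/show)
              limit = _query_ollama_context_limit(model_name, base_url)
          elif ":1234" in url or "lmstudio" in url or "lm.studio" in url:
              # 2. LM Studio SDK bridge expects a WebSocket URL without the /v1 suffix
              ws_url = url.replace("https://", "wss://").replace("http://", "ws://")
              ws_url = ws_url.removesuffix("/v1")
              limit = _query_lmstudio_context_limit(model_name, ws_url)

          if limit is None:
              # 3. Registry lookup, then 4. default fallback
              limit = EMBEDDING_MODEL_LIMITS.get(model_name, default)

          _token_limit_cache[key] = limit
          return limit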
644 lines · 25 KiB · Python
"""Unit tests for token-aware truncation functionality.
|
|
|
|
This test suite defines the contract for token truncation functions that prevent
|
|
500 errors from Ollama when text exceeds model token limits. These tests verify:
|
|
|
|
1. Model token limit retrieval (known and unknown models)
|
|
2. Text truncation behavior for single and multiple texts
|
|
3. Token counting and truncation accuracy using tiktoken
|
|
|
|
All tests are written in Red Phase - they should FAIL initially because the
|
|
implementation does not exist yet.
|
|
"""
|
|
|
|
import pytest
|
|
import tiktoken
|
|
from leann.embedding_compute import (
|
|
EMBEDDING_MODEL_LIMITS,
|
|
get_model_token_limit,
|
|
truncate_to_token_limit,
|
|
)
|
|
|
|
|
|
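# For orientation, a minimal reference sketch of the truncation behavior these tests pin down
# (an illustrative assumption, not the production implementation in leann.embedding_compute):
#
#     def _reference_truncate(texts: list[str], token_limit: int) -> list[str]:
#         enc = tiktoken.get_encoding("cl100k_base")
#         return [enc.decode(enc.encode(text)[:token_limit]) for text in texts]
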
class TestModelTokenLimits:
    """Tests for retrieving model-specific token limits."""

    def test_get_model_token_limit_known_model(self):
        """Verify correct token limit is returned for known models.

        Known models should return their specific token limits from
        EMBEDDING_MODEL_LIMITS dictionary.
        """
        # Test nomic-embed-text (2048 tokens)
        limit = get_model_token_limit("nomic-embed-text")
        assert limit == 2048, "nomic-embed-text should have 2048 token limit"

        # Test nomic-embed-text-v1.5 (2048 tokens)
        limit = get_model_token_limit("nomic-embed-text-v1.5")
        assert limit == 2048, "nomic-embed-text-v1.5 should have 2048 token limit"

        # Test nomic-embed-text-v2 (512 tokens)
        limit = get_model_token_limit("nomic-embed-text-v2")
        assert limit == 512, "nomic-embed-text-v2 should have 512 token limit"

        # Test OpenAI models (8192 tokens)
        limit = get_model_token_limit("text-embedding-3-small")
        assert limit == 8192, "text-embedding-3-small should have 8192 token limit"

    def test_get_model_token_limit_unknown_model(self):
        """Verify default token limit is returned for unknown models.

        Unknown models should return the default limit (2048) to allow
        operation with reasonable safety margin.
        """
        # Test with completely unknown model
        limit = get_model_token_limit("unknown-model-xyz")
        assert limit == 2048, "Unknown models should return default 2048"

        # Test with empty string
        limit = get_model_token_limit("")
        assert limit == 2048, "Empty model name should return default 2048"

    def test_get_model_token_limit_custom_default(self):
        """Verify custom default can be specified for unknown models.

        Allow callers to specify their own default token limit when
        model is not in the known models dictionary.
        """
        limit = get_model_token_limit("unknown-model", default=4096)
        assert limit == 4096, "Should return custom default for unknown models"

        # Known model should ignore custom default
        limit = get_model_token_limit("nomic-embed-text", default=4096)
        assert limit == 2048, "Known model should ignore custom default"

    def test_embedding_model_limits_dictionary_exists(self):
        """Verify EMBEDDING_MODEL_LIMITS dictionary contains expected models.

        The dictionary should be importable and contain at least the
        known nomic models with correct token limits.
        """
        assert isinstance(EMBEDDING_MODEL_LIMITS, dict), "Should be a dictionary"
        assert "nomic-embed-text" in EMBEDDING_MODEL_LIMITS, "Should contain nomic-embed-text"
        assert "nomic-embed-text-v1.5" in EMBEDDING_MODEL_LIMITS, (
            "Should contain nomic-embed-text-v1.5"
        )
        assert EMBEDDING_MODEL_LIMITS["nomic-embed-text"] == 2048
        assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v1.5"] == 2048
        assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v2"] == 512
        # OpenAI models
        assert EMBEDDING_MODEL_LIMITS["text-embedding-3-small"] == 8192

class TestTokenTruncation:
    """Tests for truncating texts to token limits."""

    @pytest.fixture
    def tokenizer(self):
        """Provide tiktoken tokenizer for token counting verification."""
        return tiktoken.get_encoding("cl100k_base")

    def test_truncate_single_text_under_limit(self, tokenizer):
        """Verify text under token limit remains unchanged.

        When text is already within the token limit, it should be
        returned unchanged with no truncation.
        """
        text = "This is a short text that is well under the token limit."
        token_count = len(tokenizer.encode(text))
        assert token_count < 100, f"Test setup: text should be short (has {token_count} tokens)"

        # Truncate with generous limit
        result = truncate_to_token_limit([text], token_limit=512)

        assert len(result) == 1, "Should return same number of texts"
        assert result[0] == text, "Text under limit should be unchanged"

    def test_truncate_single_text_over_limit(self, tokenizer):
        """Verify text over token limit is truncated correctly.

        When text exceeds the token limit, it should be truncated to
        fit within the limit while maintaining valid token boundaries.
        """
        # Create a text that definitely exceeds limit
        text = "word " * 200  # ~200 tokens (each "word " is typically 1-2 tokens)
        original_token_count = len(tokenizer.encode(text))
        assert original_token_count > 50, (
            f"Test setup: text should be long (has {original_token_count} tokens)"
        )

        # Truncate to 50 tokens
        result = truncate_to_token_limit([text], token_limit=50)

        assert len(result) == 1, "Should return same number of texts"
        assert result[0] != text, "Text over limit should be truncated"
        assert len(result[0]) < len(text), "Truncated text should be shorter"

        # Verify truncated text is within token limit
        truncated_token_count = len(tokenizer.encode(result[0]))
        assert truncated_token_count <= 50, (
            f"Truncated text should be ≤50 tokens, got {truncated_token_count}"
        )

    def test_truncate_multiple_texts_mixed_lengths(self, tokenizer):
        """Verify multiple texts with mixed lengths are handled correctly.

        When processing multiple texts:
        - Texts under limit should remain unchanged
        - Texts over limit should be truncated independently
        - Output list should maintain same order and length
        """
        texts = [
            "Short text.",  # Under limit
            "word " * 200,  # Over limit
            "Another short one.",  # Under limit
            "token " * 150,  # Over limit
        ]

        # Verify test setup
        for i, text in enumerate(texts):
            token_count = len(tokenizer.encode(text))
            if i in [1, 3]:
                assert token_count > 50, f"Text {i} should be over limit (has {token_count} tokens)"
            else:
                assert token_count < 50, (
                    f"Text {i} should be under limit (has {token_count} tokens)"
                )

        # Truncate with 50 token limit
        result = truncate_to_token_limit(texts, token_limit=50)

        assert len(result) == len(texts), "Should return same number of texts"

        # Verify each text individually
        for i, (original, truncated) in enumerate(zip(texts, result)):
            token_count = len(tokenizer.encode(truncated))
            assert token_count <= 50, f"Text {i} should be ≤50 tokens, got {token_count}"

            # Short texts should be unchanged
            if i in [0, 2]:
                assert truncated == original, f"Short text {i} should be unchanged"
            # Long texts should be truncated
            else:
                assert len(truncated) < len(original), f"Long text {i} should be truncated"

    def test_truncate_empty_list(self):
        """Verify empty input list returns empty output list.

        Edge case: empty list should return empty list without errors.
        """
        result = truncate_to_token_limit([], token_limit=512)
        assert result == [], "Empty input should return empty output"

    def test_truncate_preserves_order(self, tokenizer):
        """Verify truncation preserves original text order.

        Output list should maintain the same order as input list,
        regardless of which texts were truncated.
        """
        texts = [
            "First text " * 50,  # Will be truncated
            "Second text.",  # Won't be truncated
            "Third text " * 50,  # Will be truncated
        ]

        result = truncate_to_token_limit(texts, token_limit=20)

        assert len(result) == 3, "Should preserve list length"
        # Check that order is maintained by looking for distinctive words
        assert "First" in result[0], "First text should remain in first position"
        assert "Second" in result[1], "Second text should remain in second position"
        assert "Third" in result[2], "Third text should remain in third position"

    def test_truncate_extremely_long_text(self, tokenizer):
        """Verify extremely long texts are truncated efficiently.

        Test with text that far exceeds token limit to ensure
        truncation handles extreme cases without performance issues.
        """
        # Create very long text (simulate real-world scenario)
        text = "token " * 5000  # ~5000+ tokens
        original_token_count = len(tokenizer.encode(text))
        assert original_token_count > 1000, "Test setup: text should be very long"

        # Truncate to small limit
        result = truncate_to_token_limit([text], token_limit=100)

        assert len(result) == 1
        truncated_token_count = len(tokenizer.encode(result[0]))
        assert truncated_token_count <= 100, (
            f"Should truncate to ≤100 tokens, got {truncated_token_count}"
        )
        assert len(result[0]) < len(text) // 10, "Should significantly reduce text length"

    def test_truncate_exact_token_limit(self, tokenizer):
        """Verify text at exactly token limit is handled correctly.

        Edge case: text with exactly the token limit should either
        remain unchanged or be safely truncated by 1 token.
        """
        # Create text with approximately 50 tokens
        # We'll adjust to get exactly 50
        target_tokens = 50
        text = "word " * 50
        tokens = tokenizer.encode(text)

        # Adjust to get exactly target_tokens
        if len(tokens) > target_tokens:
            tokens = tokens[:target_tokens]
            text = tokenizer.decode(tokens)
        elif len(tokens) < target_tokens:
            # Add more words
            while len(tokenizer.encode(text)) < target_tokens:
                text += "word "
            tokens = tokenizer.encode(text)[:target_tokens]
            text = tokenizer.decode(tokens)

        # Verify we have exactly target_tokens
        assert len(tokenizer.encode(text)) == target_tokens, (
            "Test setup: should have exactly 50 tokens"
        )

        result = truncate_to_token_limit([text], token_limit=target_tokens)

        assert len(result) == 1
        result_tokens = len(tokenizer.encode(result[0]))
        assert result_tokens <= target_tokens, (
            f"Should be ≤{target_tokens} tokens, got {result_tokens}"
        )

class TestLMStudioHybridDiscovery:
    """Tests for LM Studio integration in get_model_token_limit() hybrid discovery.

    These tests verify that get_model_token_limit() properly integrates with
    the LM Studio SDK bridge for dynamic token limit discovery. The integration
    should:

    1. Detect LM Studio URLs (port 1234 or 'lmstudio'/'lm.studio' in URL)
    2. Convert HTTP URLs to WebSocket format for SDK queries
    3. Query LM Studio SDK and use discovered limit
    4. Fall back to registry when SDK returns None
    5. Execute AFTER Ollama detection but BEFORE registry fallback

    All tests are written in Red Phase - they should FAIL initially because the
    LM Studio detection and integration logic does not exist yet in get_model_token_limit().
    """

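    # The URL handling these tests assume, roughly (a sketch under those assumptions,
    # not leann's actual helper):
    #
    #     def _to_ws_url(http_url: str) -> str:
    #         scheme, rest = http_url.split("://", 1)
    #         ws_scheme = "wss" if scheme == "https" else "ws"
    #         return f"{ws_scheme}://{rest.removesuffix('/v1').rstrip('/')}"
    #
    #     _to_ws_url("http://localhost:1234/v1")              -> "ws://localhost:1234"
    #     _to_ws_url("https://lmstudio.example.com:1234/v1")  -> "wss://lmstudio.example.com:1234"
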
    def test_get_model_token_limit_lmstudio_success(self, monkeypatch):
        """Verify LM Studio SDK query succeeds and returns detected limit.

        When a LM Studio base_url is detected and the SDK query succeeds,
        get_model_token_limit() should return the dynamically discovered
        context length without falling back to the registry.
        """

        # Mock _query_lmstudio_context_limit to return successful SDK query
        def mock_query_lmstudio(model_name, base_url):
            # Verify WebSocket URL was passed (not HTTP)
            assert base_url.startswith("ws://"), (
                f"Should convert HTTP to WebSocket format, got: {base_url}"
            )
            return 8192  # Successful SDK query

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        # Test with HTTP URL that should be converted to WebSocket
        limit = get_model_token_limit(
            model_name="custom-model", base_url="http://localhost:1234/v1"
        )

        assert limit == 8192, "Should return limit from LM Studio SDK query"

    def test_get_model_token_limit_lmstudio_fallback_to_registry(self, monkeypatch):
        """Verify fallback to registry when LM Studio SDK returns None.

        When LM Studio SDK query fails (returns None), get_model_token_limit()
        should fall back to the EMBEDDING_MODEL_LIMITS registry.
        """

        # Mock _query_lmstudio_context_limit to return None (SDK failure)
        def mock_query_lmstudio(model_name, base_url):
            return None  # SDK query failed

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        # Test with known model that exists in registry
        limit = get_model_token_limit(
            model_name="nomic-embed-text", base_url="http://localhost:1234/v1"
        )

        # Should fall back to registry value
        assert limit == 2048, "Should fall back to registry when SDK returns None"

    def test_get_model_token_limit_lmstudio_port_detection(self, monkeypatch):
        """Verify detection of LM Studio via port 1234.

        get_model_token_limit() should recognize port 1234 as a LM Studio
        server and attempt SDK query, regardless of hostname.
        """
        query_called = False

        def mock_query_lmstudio(model_name, base_url):
            nonlocal query_called
            query_called = True
            return 4096

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        # Test with port 1234 (default LM Studio port)
        limit = get_model_token_limit(model_name="test-model", base_url="http://127.0.0.1:1234/v1")

        assert query_called, "Should detect port 1234 and call LM Studio SDK query"
        assert limit == 4096, "Should return SDK query result"

    @pytest.mark.parametrize(
        "test_url,expected_limit,keyword",
        [
            ("http://lmstudio.local:8080/v1", 16384, "lmstudio"),
            ("http://api.lm.studio:5000/v1", 32768, "lm.studio"),
        ],
    )
    def test_get_model_token_limit_lmstudio_url_keyword_detection(
        self, monkeypatch, test_url, expected_limit, keyword
    ):
        """Verify detection of LM Studio via keywords in URL.

        get_model_token_limit() should recognize 'lmstudio' or 'lm.studio'
        in the URL as indicating a LM Studio server.
        """
        query_called = False

        def mock_query_lmstudio(model_name, base_url):
            nonlocal query_called
            query_called = True
            return expected_limit

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        limit = get_model_token_limit(model_name="test-model", base_url=test_url)

        assert query_called, f"Should detect '{keyword}' keyword and call SDK query"
        assert limit == expected_limit, f"Should return SDK query result for {keyword}"

    @pytest.mark.parametrize(
        "input_url,expected_protocol,expected_host",
        [
            ("http://localhost:1234/v1", "ws://", "localhost:1234"),
            ("https://lmstudio.example.com:1234/v1", "wss://", "lmstudio.example.com:1234"),
        ],
    )
    def test_get_model_token_limit_protocol_conversion(
        self, monkeypatch, input_url, expected_protocol, expected_host
    ):
        """Verify HTTP/HTTPS URL is converted to WebSocket format for SDK query.

        LM Studio SDK requires WebSocket URLs. get_model_token_limit() should:
        1. Convert 'http://' to 'ws://'
        2. Convert 'https://' to 'wss://'
        3. Remove '/v1' or other path suffixes (SDK expects base URL)
        4. Preserve host and port
        """
        conversions_tested = []

        def mock_query_lmstudio(model_name, base_url):
            conversions_tested.append(base_url)
            return 8192

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        get_model_token_limit(model_name="test-model", base_url=input_url)

        # Verify conversion happened
        assert len(conversions_tested) == 1, "Should have called SDK query once"
        assert conversions_tested[0].startswith(expected_protocol), (
            f"Should convert to {expected_protocol}"
        )
        assert expected_host in conversions_tested[0], (
            f"Should preserve host and port: {expected_host}"
        )

    def test_get_model_token_limit_lmstudio_executes_after_ollama(self, monkeypatch):
        """Verify LM Studio detection happens AFTER Ollama detection.

        The hybrid discovery order should be:
        1. Ollama dynamic discovery (port 11434 or 'ollama' in URL)
        2. LM Studio dynamic discovery (port 1234 or 'lmstudio' in URL)
        3. Registry fallback

        If both Ollama and LM Studio patterns match, Ollama should take precedence.
        This test verifies that LM Studio is checked but doesn't interfere with Ollama.
        """
        ollama_called = False
        lmstudio_called = False

        def mock_query_ollama(model_name, base_url):
            nonlocal ollama_called
            ollama_called = True
            return 2048  # Ollama query succeeds

        def mock_query_lmstudio(model_name, base_url):
            nonlocal lmstudio_called
            lmstudio_called = True
            return None  # Should not be reached if Ollama succeeds

        monkeypatch.setattr(
            "leann.embedding_compute._query_ollama_context_limit",
            mock_query_ollama,
        )
        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        # Test with Ollama URL
        limit = get_model_token_limit(
            model_name="test-model", base_url="http://localhost:11434/api"
        )

        assert ollama_called, "Should attempt Ollama query first"
        assert not lmstudio_called, "Should not attempt LM Studio query when Ollama succeeds"
        assert limit == 2048, "Should return Ollama result"

    def test_get_model_token_limit_lmstudio_not_detected_for_non_lmstudio_urls(self, monkeypatch):
        """Verify LM Studio SDK query is NOT called for non-LM Studio URLs.

        Only URLs with port 1234 or 'lmstudio'/'lm.studio' keywords should
        trigger LM Studio SDK queries. Other URLs should skip to registry fallback.
        """
        lmstudio_called = False

        def mock_query_lmstudio(model_name, base_url):
            nonlocal lmstudio_called
            lmstudio_called = True
            return 8192

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        # Test with non-LM Studio URLs
        test_cases = [
            "http://localhost:8080/v1",  # Different port
            "http://openai.example.com/v1",  # Different service
            "http://localhost:3000/v1",  # Another port
        ]

        for base_url in test_cases:
            lmstudio_called = False  # Reset for each test
            get_model_token_limit(model_name="nomic-embed-text", base_url=base_url)
            assert not lmstudio_called, f"Should NOT call LM Studio SDK for URL: {base_url}"

    def test_get_model_token_limit_lmstudio_case_insensitive_detection(self, monkeypatch):
        """Verify LM Studio detection is case-insensitive for keywords.

        Keywords 'lmstudio' and 'lm.studio' should be detected regardless
        of case (LMStudio, LMSTUDIO, LmStudio, etc.).
        """
        query_called = False

        def mock_query_lmstudio(model_name, base_url):
            nonlocal query_called
            query_called = True
            return 8192

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        # Test various case variations
        test_cases = [
            "http://LMStudio.local:8080/v1",
            "http://LMSTUDIO.example.com/v1",
            "http://LmStudio.local/v1",
            "http://api.LM.STUDIO:5000/v1",
        ]

        for base_url in test_cases:
            query_called = False  # Reset for each test
            limit = get_model_token_limit(model_name="test-model", base_url=base_url)
            assert query_called, f"Should detect LM Studio in URL: {base_url}"
            assert limit == 8192, f"Should return SDK result for URL: {base_url}"

class TestTokenLimitCaching:
    """Tests for token limit caching to prevent repeated SDK/API calls.

    Caching prevents duplicate SDK/API calls within the same Python process,
    which is important because:
    1. LM Studio SDK load() can load duplicate model instances
    2. Ollama /api/show queries add latency
    3. Registry lookups are pure overhead

    Cache is process-scoped and resets between leann build invocations.
    """

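    # Assumed cache shape, for orientation only (the keys below are what these tests assert):
    #
    #     _token_limit_cache: dict[tuple[str, str], int] = {}
    #     # key = (model_name, base_url), value = discovered token limit, e.g.
    #     # _token_limit_cache[("nomic-embed-text", "http://localhost:11434")] = 2048
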
    def setup_method(self):
        """Clear cache before each test."""
        from leann.embedding_compute import _token_limit_cache

        _token_limit_cache.clear()

    def test_registry_lookup_is_cached(self):
        """Verify that registry lookups are cached."""
        from leann.embedding_compute import _token_limit_cache

        # First call
        limit1 = get_model_token_limit("text-embedding-3-small")
        assert limit1 == 8192

        # Verify it's in cache
        cache_key = ("text-embedding-3-small", "")
        assert cache_key in _token_limit_cache
        assert _token_limit_cache[cache_key] == 8192

        # Second call should use cache
        limit2 = get_model_token_limit("text-embedding-3-small")
        assert limit2 == 8192

    def test_default_fallback_is_cached(self):
        """Verify that default fallbacks are cached."""
        from leann.embedding_compute import _token_limit_cache

        # First call with unknown model
        limit1 = get_model_token_limit("unknown-model-xyz", default=512)
        assert limit1 == 512

        # Verify it's in cache
        cache_key = ("unknown-model-xyz", "")
        assert cache_key in _token_limit_cache
        assert _token_limit_cache[cache_key] == 512

        # Second call should use cache
        limit2 = get_model_token_limit("unknown-model-xyz", default=512)
        assert limit2 == 512

    def test_different_urls_create_separate_cache_entries(self):
        """Verify that different base_urls create separate cache entries."""
        from leann.embedding_compute import _token_limit_cache

        # Same model, different URLs
        limit1 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:11434")
        limit2 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:1234/v1")

        # Both should find the model in registry (2048)
        assert limit1 == 2048
        assert limit2 == 2048

        # But they should be separate cache entries
        cache_key1 = ("nomic-embed-text", "http://localhost:11434")
        cache_key2 = ("nomic-embed-text", "http://localhost:1234/v1")

        assert cache_key1 in _token_limit_cache
        assert cache_key2 in _token_limit_cache
        assert len(_token_limit_cache) == 2

    def test_cache_prevents_repeated_lookups(self):
        """Verify that cache prevents repeated registry/API lookups."""
        from leann.embedding_compute import _token_limit_cache

        model_name = "text-embedding-ada-002"

        # First call - should add to cache
        assert len(_token_limit_cache) == 0
        limit1 = get_model_token_limit(model_name)

        cache_size_after_first = len(_token_limit_cache)
        assert cache_size_after_first == 1

        # Multiple subsequent calls - cache size should not change
        for _ in range(5):
            limit = get_model_token_limit(model_name)
            assert limit == limit1
            assert len(_token_limit_cache) == cache_size_after_first

    def test_versioned_model_names_cached_correctly(self):
        """Verify that versioned model names (e.g., model:tag) are cached."""
        from leann.embedding_compute import _token_limit_cache

        # Model with version tag
        limit = get_model_token_limit("nomic-embed-text:latest", base_url="http://localhost:11434")
        assert limit == 2048

        # Should be cached with full name including version
        cache_key = ("nomic-embed-text:latest", "http://localhost:11434")
        assert cache_key in _token_limit_cache
        assert _token_limit_cache[cache_key] == 2048