File: LEANN/tests/test_token_truncation.py
ww26 1ef9cba7de Feature/prompt templates and lmstudio sdk (#171)
* Add prompt template support and LM Studio SDK integration

Features:

- Prompt template support for embedding models (via --embedding-prompt-template)

- LM Studio SDK integration for automatic context length detection

- Hybrid token limit discovery (Ollama → LM Studio → Registry → Default)

- Client-side token truncation to prevent silent failures

- Automatic persistence of embedding_options to .meta.json
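
Usage sketch for the template flag above (the flag name is from this PR; the template strings are illustrative and model-specific):

leann build --embedding-prompt-template "document: " ...
leann search --embedding-prompt-template "query: " ...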

Implementation:

- Added _query_lmstudio_context_limit() with Node.js subprocess bridge

- Modified compute_embeddings_openai() to apply prompt templates before truncation

- Extended CLI with --embedding-prompt-template flag for build and search

- URL detection for LM Studio (port 1234 or lmstudio/lm.studio keywords)

- HTTP→WebSocket URL conversion for SDK compatibility
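
A minimal sketch of the detection and URL-conversion logic described in the last two items (helper names here are illustrative, not the actual functions in embedding_compute):

from urllib.parse import urlparse

def _looks_like_lmstudio(base_url: str) -> bool:
    # Default LM Studio port 1234, or 'lmstudio' / 'lm.studio' anywhere in the
    # URL, matched case-insensitively.
    url = base_url.lower()
    return ":1234" in url or "lmstudio" in url or "lm.studio" in url

def _to_websocket_url(base_url: str) -> str:
    # The SDK speaks WebSocket to the server root, so drop the /v1 path and
    # map http -> ws, https -> wss while keeping host and port.
    parsed = urlparse(base_url)
    scheme = "wss" if parsed.scheme == "https" else "ws"
    return f"{scheme}://{parsed.netloc}"

The unit tests later in this file (TestLMStudioHybridDiscovery) pin down exactly this behavior: port and keyword detection, case-insensitivity, and ws://host:port output.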

Tests:

- 60 passing tests across 5 test files

- Comprehensive coverage of prompt templates, LM Studio integration, and token handling

- Parametrized tests for maintainability and clarity

* Add integration tests and fix LM Studio SDK bridge

Features:
- End-to-end integration tests for prompt template with EmbeddingGemma
- Integration tests for hybrid token limit discovery mechanism
- Tests verify real-world functionality with live services (LM Studio, Ollama)

Fixes:
- LM Studio SDK bridge now uses client.embedding.load() for embedding models
- Fixed NODE_PATH resolution to include npm global modules
- Fixed integration test to use WebSocket URL (ws://) for SDK bridge

Tests:
- test_prompt_template_e2e.py: 8 integration tests covering:
  - Prompt template prepending with LM Studio (EmbeddingGemma)
  - LM Studio SDK bridge for context length detection
  - Ollama dynamic token limit detection
  - Hybrid discovery fallback mechanism (registry, default)
- All tests marked with @pytest.mark.integration for selective execution
- Tests gracefully skip when services unavailable
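
A sketch of the marker-plus-skip pattern these tests use (the probe URL and helper are illustrative):

import pytest
import requests

def _service_up(url: str) -> bool:
    # Probe the local service; integration tests skip instead of failing when it is down.
    try:
        requests.get(url, timeout=2)
        return True
    except requests.RequestException:
        return False

@pytest.mark.integration
def test_lmstudio_context_detection():
    if not _service_up("http://localhost:1234/v1/models"):
        pytest.skip("LM Studio is not running")
    ...

Run them selectively with pytest -m integration, or exclude them with pytest -m "not integration".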

Documentation:
- Updated tests/README.md with integration test section
- Added prerequisites and running instructions
- Documented that prompt templates are ONLY for EmbeddingGemma
- Added integration marker to pyproject.toml

Test Results:
- All 8 integration tests passing with live services
- Confirmed prompt templates work correctly with EmbeddingGemma
- Verified LM Studio SDK bridge auto-detects context length (2048)
- Validated hybrid token limit discovery across all backends

* Add prompt template support to Ollama mode

Extends prompt template functionality from OpenAI mode to Ollama for backend consistency.

Changes:
- Add provider_options parameter to compute_embeddings_ollama()
- Apply prompt template before token truncation (lines 1005-1011)
- Pass provider_options through compute_embeddings() call chain
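
Sketch of the relevant lines inside compute_embeddings_ollama() (assuming the template arrives under a "prompt_template" key in provider_options; the real key name may differ):

template = (provider_options or {}).get("prompt_template")
if template:
    # Prepend the task prompt first so it counts against the model's token limit
    texts = [template + text for text in texts]
texts = truncate_to_token_limit(texts, token_limit=get_model_token_limit(model_name))
# ...the (possibly truncated) texts are then sent to Ollama's embeddings endpoint as before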

Tests:
- test_ollama_embedding_with_prompt_template: Verifies templates work with Ollama
- test_ollama_prompt_template_affects_embeddings: Confirms embeddings differ with/without template
- Both tests pass with live Ollama service (2/2 passing)

Usage:
leann build --embedding-mode ollama --embedding-prompt-template "query: " ...

* Fix LM Studio SDK bridge to respect JIT auto-evict settings

Problem: SDK bridge called client.embedding.load() which loaded models into
LM Studio memory and bypassed JIT auto-evict settings, causing duplicate
model instances to accumulate.

Root cause analysis (from Perplexity research):
- Explicit SDK load() commands are treated as "pinned" models
- JIT auto-evict only applies to models loaded reactively via API requests
- SDK-loaded models remain in memory until explicitly unloaded

Solutions implemented:

1. Add model.unload() after metadata query (line 243)
   - Load model temporarily to get context length
   - Unload immediately to hand control back to JIT system
   - Subsequent API requests trigger JIT load with auto-evict

2. Add token limit caching to prevent repeated SDK calls
   - Cache discovered limits in _token_limit_cache dict (line 48)
   - Key: (model_name, base_url), Value: token_limit
   - Prevents duplicate load/unload cycles within same process
   - Cache shared across all discovery methods (Ollama, SDK, registry)
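
In code, the cache is a module-level dict consulted before any discovery work; a sketch of the pattern (signature details are approximate):

# Process-scoped cache: (model_name, base_url) -> discovered token limit
_token_limit_cache: dict[tuple[str, str], int] = {}

def get_model_token_limit(model_name, default=2048, base_url=""):
    cache_key = (model_name, base_url)
    if cache_key in _token_limit_cache:
        return _token_limit_cache[cache_key]
    # ...Ollama query, then LM Studio SDK query, then registry, then default...
    limit = EMBEDDING_MODEL_LIMITS.get(model_name, default)
    _token_limit_cache[cache_key] = limit
    return limit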

Tests:
- TestTokenLimitCaching: 5 tests for cache behavior (integrated into test_token_truncation.py)
- Manual testing confirmed no duplicate models in LM Studio after fix
- All existing tests pass

Impact:
- Respects user's LM Studio JIT and auto-evict settings
- Reduces model memory footprint
- Faster subsequent builds (cached limits)

* Document prompt template and LM Studio SDK features

Added comprehensive documentation for new optional embedding features:

Configuration Guide (docs/configuration-guide.md):
- New section: "Optional Embedding Features"
- Task-Specific Prompt Templates subsection:
  - Explains EmbeddingGemma use case with document/query prompts
  - CLI and Python API examples
  - Clear warnings about compatible vs incompatible models
  - References to GitHub issue #155 and HuggingFace blog
- LM Studio Auto-Detection subsection:
  - Prerequisites (Node.js + @lmstudio/sdk)
  - How auto-detection works (4-step process)
  - Benefits and optional nature clearly stated
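
If you want auto-detection, one plausible setup is a global install, so the Node.js subprocess bridge can resolve the SDK through NODE_PATH:

npm install -g @lmstudio/sdk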

FAQ (docs/faq.md):
- FAQ #2: When should I use prompt templates?
  - DO/DON'T guidance with examples
  - Links to detailed configuration guide
- FAQ #3: Why is LM Studio loading multiple copies?
  - Explains the JIT auto-evict fix
  - Troubleshooting steps if still seeing issues
- FAQ #4: Do I need Node.js and @lmstudio/sdk?
  - Clarifies it's completely optional
  - Lists benefits if installed
  - Installation instructions

Added cross-references between the two documents so readers can move easily between the quick FAQ answers and the detailed configuration guide.

* Add separate build/query template support for task-specific models

Task-specific models like EmbeddingGemma require different templates for indexing and for searching. Store both templates at build time and automatically apply the query template during search, with backward compatibility for indexes built before this change.
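
Conceptually, the stored embedding_options now carries both templates; a sketch of the idea (key names and template strings are illustrative, not the exact .meta.json schema):

embedding_options = {
    "build_prompt_template": "title: none | text: ",           # applied when indexing documents
    "query_prompt_template": "task: search result | query: ",  # auto-applied at search time
}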

* Consolidate prompt template tests from 44 to 37 tests

Merged redundant no-op tests, removed low-value implementation tests, consolidated parameterized CLI tests, and removed a hanging, over-mocked test. All tests pass, with improved focus on behavioral testing.

* Fix query template application in compute_query_embedding

Query templates were only applied in the fallback code path, not when using the embedding server (default path). This meant stored query templates in index metadata were ignored during MCP and CLI searches.

Changes:

- Move template application to before any computation path (searcher_base.py:109-110)

- Add comprehensive tests for both server and fallback paths

- Consolidate tests into test_prompt_template_persistence.py
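
The fix boils down to applying the stored template to the query once, before either path runs; a sketch (attribute and helper names are illustrative, following the description above):

def compute_query_embedding(self, query: str):
    template = self.embedding_options.get("query_prompt_template")
    if template:
        query = template + query  # applied before any computation path
    if self._embedding_server_available():
        return self._embed_via_server(query)  # default path
    return self._embed_locally(query)  # fallback path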

Tests verify:

- Template applied when using embedding server

- Template applied in fallback path

- Consistent behavior between both paths

* Apply ruff formatting and fix linting issues

- Remove unused imports

- Fix import ordering

- Remove unused variables

- Apply code formatting

* Fix CI test failures: mock OPENAI_API_KEY in tests

Tests were failing in CI because compute_embeddings_openai() checks for OPENAI_API_KEY before using the mocked client. Added a monkeypatch to set a fake API key in the test fixture.
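
A sketch of that fixture (the fixture name and key value are illustrative, and whether it is autouse is an assumption):

@pytest.fixture(autouse=True)
def fake_openai_api_key(monkeypatch):
    # compute_embeddings_openai() only checks that the variable is set;
    # the mocked client never issues a real request.
    monkeypatch.setenv("OPENAI_API_KEY", "test-key")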
2025-11-14 15:25:17 -08:00

File contents (644 lines, 25 KiB, Python):

"""Unit tests for token-aware truncation functionality.
This test suite defines the contract for token truncation functions that prevent
500 errors from Ollama when text exceeds model token limits. These tests verify:
1. Model token limit retrieval (known and unknown models)
2. Text truncation behavior for single and multiple texts
3. Token counting and truncation accuracy using tiktoken
All tests are written in Red Phase - they should FAIL initially because the
implementation does not exist yet.
"""
import pytest
import tiktoken
from leann.embedding_compute import (
EMBEDDING_MODEL_LIMITS,
get_model_token_limit,
truncate_to_token_limit,
)

class TestModelTokenLimits:
    """Tests for retrieving model-specific token limits."""

    def test_get_model_token_limit_known_model(self):
        """Verify correct token limit is returned for known models.

        Known models should return their specific token limits from
        EMBEDDING_MODEL_LIMITS dictionary.
        """
        # Test nomic-embed-text (2048 tokens)
        limit = get_model_token_limit("nomic-embed-text")
        assert limit == 2048, "nomic-embed-text should have 2048 token limit"

        # Test nomic-embed-text-v1.5 (2048 tokens)
        limit = get_model_token_limit("nomic-embed-text-v1.5")
        assert limit == 2048, "nomic-embed-text-v1.5 should have 2048 token limit"

        # Test nomic-embed-text-v2 (512 tokens)
        limit = get_model_token_limit("nomic-embed-text-v2")
        assert limit == 512, "nomic-embed-text-v2 should have 512 token limit"

        # Test OpenAI models (8192 tokens)
        limit = get_model_token_limit("text-embedding-3-small")
        assert limit == 8192, "text-embedding-3-small should have 8192 token limit"

    def test_get_model_token_limit_unknown_model(self):
        """Verify default token limit is returned for unknown models.

        Unknown models should return the default limit (2048) to allow
        operation with reasonable safety margin.
        """
        # Test with completely unknown model
        limit = get_model_token_limit("unknown-model-xyz")
        assert limit == 2048, "Unknown models should return default 2048"

        # Test with empty string
        limit = get_model_token_limit("")
        assert limit == 2048, "Empty model name should return default 2048"

    def test_get_model_token_limit_custom_default(self):
        """Verify custom default can be specified for unknown models.

        Allow callers to specify their own default token limit when
        model is not in the known models dictionary.
        """
        limit = get_model_token_limit("unknown-model", default=4096)
        assert limit == 4096, "Should return custom default for unknown models"

        # Known model should ignore custom default
        limit = get_model_token_limit("nomic-embed-text", default=4096)
        assert limit == 2048, "Known model should ignore custom default"

    def test_embedding_model_limits_dictionary_exists(self):
        """Verify EMBEDDING_MODEL_LIMITS dictionary contains expected models.

        The dictionary should be importable and contain at least the
        known nomic models with correct token limits.
        """
        assert isinstance(EMBEDDING_MODEL_LIMITS, dict), "Should be a dictionary"
        assert "nomic-embed-text" in EMBEDDING_MODEL_LIMITS, "Should contain nomic-embed-text"
        assert "nomic-embed-text-v1.5" in EMBEDDING_MODEL_LIMITS, (
            "Should contain nomic-embed-text-v1.5"
        )
        assert EMBEDDING_MODEL_LIMITS["nomic-embed-text"] == 2048
        assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v1.5"] == 2048
        assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v2"] == 512
        # OpenAI models
        assert EMBEDDING_MODEL_LIMITS["text-embedding-3-small"] == 8192

class TestTokenTruncation:
    """Tests for truncating texts to token limits."""

    @pytest.fixture
    def tokenizer(self):
        """Provide tiktoken tokenizer for token counting verification."""
        return tiktoken.get_encoding("cl100k_base")

    def test_truncate_single_text_under_limit(self, tokenizer):
        """Verify text under token limit remains unchanged.

        When text is already within the token limit, it should be
        returned unchanged with no truncation.
        """
        text = "This is a short text that is well under the token limit."
        token_count = len(tokenizer.encode(text))
        assert token_count < 100, f"Test setup: text should be short (has {token_count} tokens)"

        # Truncate with generous limit
        result = truncate_to_token_limit([text], token_limit=512)

        assert len(result) == 1, "Should return same number of texts"
        assert result[0] == text, "Text under limit should be unchanged"

    def test_truncate_single_text_over_limit(self, tokenizer):
        """Verify text over token limit is truncated correctly.

        When text exceeds the token limit, it should be truncated to
        fit within the limit while maintaining valid token boundaries.
        """
        # Create a text that definitely exceeds limit
        text = "word " * 200  # ~200 tokens (each "word " is typically 1-2 tokens)
        original_token_count = len(tokenizer.encode(text))
        assert original_token_count > 50, (
            f"Test setup: text should be long (has {original_token_count} tokens)"
        )

        # Truncate to 50 tokens
        result = truncate_to_token_limit([text], token_limit=50)

        assert len(result) == 1, "Should return same number of texts"
        assert result[0] != text, "Text over limit should be truncated"
        assert len(result[0]) < len(text), "Truncated text should be shorter"

        # Verify truncated text is within token limit
        truncated_token_count = len(tokenizer.encode(result[0]))
        assert truncated_token_count <= 50, (
            f"Truncated text should be ≤50 tokens, got {truncated_token_count}"
        )

    def test_truncate_multiple_texts_mixed_lengths(self, tokenizer):
        """Verify multiple texts with mixed lengths are handled correctly.

        When processing multiple texts:
        - Texts under limit should remain unchanged
        - Texts over limit should be truncated independently
        - Output list should maintain same order and length
        """
        texts = [
            "Short text.",  # Under limit
            "word " * 200,  # Over limit
            "Another short one.",  # Under limit
            "token " * 150,  # Over limit
        ]

        # Verify test setup
        for i, text in enumerate(texts):
            token_count = len(tokenizer.encode(text))
            if i in [1, 3]:
                assert token_count > 50, f"Text {i} should be over limit (has {token_count} tokens)"
            else:
                assert token_count < 50, (
                    f"Text {i} should be under limit (has {token_count} tokens)"
                )

        # Truncate with 50 token limit
        result = truncate_to_token_limit(texts, token_limit=50)

        assert len(result) == len(texts), "Should return same number of texts"

        # Verify each text individually
        for i, (original, truncated) in enumerate(zip(texts, result)):
            token_count = len(tokenizer.encode(truncated))
            assert token_count <= 50, f"Text {i} should be ≤50 tokens, got {token_count}"

            # Short texts should be unchanged
            if i in [0, 2]:
                assert truncated == original, f"Short text {i} should be unchanged"
            # Long texts should be truncated
            else:
                assert len(truncated) < len(original), f"Long text {i} should be truncated"

    def test_truncate_empty_list(self):
        """Verify empty input list returns empty output list.

        Edge case: empty list should return empty list without errors.
        """
        result = truncate_to_token_limit([], token_limit=512)
        assert result == [], "Empty input should return empty output"

    def test_truncate_preserves_order(self, tokenizer):
        """Verify truncation preserves original text order.

        Output list should maintain the same order as input list,
        regardless of which texts were truncated.
        """
        texts = [
            "First text " * 50,  # Will be truncated
            "Second text.",  # Won't be truncated
            "Third text " * 50,  # Will be truncated
        ]

        result = truncate_to_token_limit(texts, token_limit=20)

        assert len(result) == 3, "Should preserve list length"
        # Check that order is maintained by looking for distinctive words
        assert "First" in result[0], "First text should remain in first position"
        assert "Second" in result[1], "Second text should remain in second position"
        assert "Third" in result[2], "Third text should remain in third position"

    def test_truncate_extremely_long_text(self, tokenizer):
        """Verify extremely long texts are truncated efficiently.

        Test with text that far exceeds token limit to ensure
        truncation handles extreme cases without performance issues.
        """
        # Create very long text (simulate real-world scenario)
        text = "token " * 5000  # ~5000+ tokens
        original_token_count = len(tokenizer.encode(text))
        assert original_token_count > 1000, "Test setup: text should be very long"

        # Truncate to small limit
        result = truncate_to_token_limit([text], token_limit=100)

        assert len(result) == 1
        truncated_token_count = len(tokenizer.encode(result[0]))
        assert truncated_token_count <= 100, (
            f"Should truncate to ≤100 tokens, got {truncated_token_count}"
        )
        assert len(result[0]) < len(text) // 10, "Should significantly reduce text length"

    def test_truncate_exact_token_limit(self, tokenizer):
        """Verify text at exactly token limit is handled correctly.

        Edge case: text with exactly the token limit should either
        remain unchanged or be safely truncated by 1 token.
        """
        # Create text with approximately 50 tokens
        # We'll adjust to get exactly 50
        target_tokens = 50
        text = "word " * 50
        tokens = tokenizer.encode(text)

        # Adjust to get exactly target_tokens
        if len(tokens) > target_tokens:
            tokens = tokens[:target_tokens]
            text = tokenizer.decode(tokens)
        elif len(tokens) < target_tokens:
            # Add more words
            while len(tokenizer.encode(text)) < target_tokens:
                text += "word "
            tokens = tokenizer.encode(text)[:target_tokens]
            text = tokenizer.decode(tokens)

        # Verify we have exactly target_tokens
        assert len(tokenizer.encode(text)) == target_tokens, (
            "Test setup: should have exactly 50 tokens"
        )

        result = truncate_to_token_limit([text], token_limit=target_tokens)

        assert len(result) == 1
        result_tokens = len(tokenizer.encode(result[0]))
        assert result_tokens <= target_tokens, (
            f"Should be ≤{target_tokens} tokens, got {result_tokens}"
        )

class TestLMStudioHybridDiscovery:
    """Tests for LM Studio integration in get_model_token_limit() hybrid discovery.

    These tests verify that get_model_token_limit() properly integrates with
    the LM Studio SDK bridge for dynamic token limit discovery. The integration
    should:

    1. Detect LM Studio URLs (port 1234 or 'lmstudio'/'lm.studio' in URL)
    2. Convert HTTP URLs to WebSocket format for SDK queries
    3. Query LM Studio SDK and use discovered limit
    4. Fall back to registry when SDK returns None
    5. Execute AFTER Ollama detection but BEFORE registry fallback

    All tests are written in Red Phase - they should FAIL initially because the
    LM Studio detection and integration logic does not exist yet in get_model_token_limit().
    """

    def test_get_model_token_limit_lmstudio_success(self, monkeypatch):
        """Verify LM Studio SDK query succeeds and returns detected limit.

        When an LM Studio base_url is detected and the SDK query succeeds,
        get_model_token_limit() should return the dynamically discovered
        context length without falling back to the registry.
        """

        # Mock _query_lmstudio_context_limit to return successful SDK query
        def mock_query_lmstudio(model_name, base_url):
            # Verify WebSocket URL was passed (not HTTP)
            assert base_url.startswith("ws://"), (
                f"Should convert HTTP to WebSocket format, got: {base_url}"
            )
            return 8192  # Successful SDK query

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        # Test with HTTP URL that should be converted to WebSocket
        limit = get_model_token_limit(
            model_name="custom-model", base_url="http://localhost:1234/v1"
        )

        assert limit == 8192, "Should return limit from LM Studio SDK query"

    def test_get_model_token_limit_lmstudio_fallback_to_registry(self, monkeypatch):
        """Verify fallback to registry when LM Studio SDK returns None.

        When LM Studio SDK query fails (returns None), get_model_token_limit()
        should fall back to the EMBEDDING_MODEL_LIMITS registry.
        """

        # Mock _query_lmstudio_context_limit to return None (SDK failure)
        def mock_query_lmstudio(model_name, base_url):
            return None  # SDK query failed

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        # Test with known model that exists in registry
        limit = get_model_token_limit(
            model_name="nomic-embed-text", base_url="http://localhost:1234/v1"
        )

        # Should fall back to registry value
        assert limit == 2048, "Should fall back to registry when SDK returns None"

    def test_get_model_token_limit_lmstudio_port_detection(self, monkeypatch):
        """Verify detection of LM Studio via port 1234.

        get_model_token_limit() should recognize port 1234 as an LM Studio
        server and attempt SDK query, regardless of hostname.
        """
        query_called = False

        def mock_query_lmstudio(model_name, base_url):
            nonlocal query_called
            query_called = True
            return 4096

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        # Test with port 1234 (default LM Studio port)
        limit = get_model_token_limit(model_name="test-model", base_url="http://127.0.0.1:1234/v1")

        assert query_called, "Should detect port 1234 and call LM Studio SDK query"
        assert limit == 4096, "Should return SDK query result"

    @pytest.mark.parametrize(
        "test_url,expected_limit,keyword",
        [
            ("http://lmstudio.local:8080/v1", 16384, "lmstudio"),
            ("http://api.lm.studio:5000/v1", 32768, "lm.studio"),
        ],
    )
    def test_get_model_token_limit_lmstudio_url_keyword_detection(
        self, monkeypatch, test_url, expected_limit, keyword
    ):
        """Verify detection of LM Studio via keywords in URL.

        get_model_token_limit() should recognize 'lmstudio' or 'lm.studio'
        in the URL as indicating an LM Studio server.
        """
        query_called = False

        def mock_query_lmstudio(model_name, base_url):
            nonlocal query_called
            query_called = True
            return expected_limit

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        limit = get_model_token_limit(model_name="test-model", base_url=test_url)

        assert query_called, f"Should detect '{keyword}' keyword and call SDK query"
        assert limit == expected_limit, f"Should return SDK query result for {keyword}"

    @pytest.mark.parametrize(
        "input_url,expected_protocol,expected_host",
        [
            ("http://localhost:1234/v1", "ws://", "localhost:1234"),
            ("https://lmstudio.example.com:1234/v1", "wss://", "lmstudio.example.com:1234"),
        ],
    )
    def test_get_model_token_limit_protocol_conversion(
        self, monkeypatch, input_url, expected_protocol, expected_host
    ):
        """Verify HTTP/HTTPS URL is converted to WebSocket format for SDK query.

        LM Studio SDK requires WebSocket URLs. get_model_token_limit() should:
        1. Convert 'http://' to 'ws://'
        2. Convert 'https://' to 'wss://'
        3. Remove '/v1' or other path suffixes (SDK expects base URL)
        4. Preserve host and port
        """
        conversions_tested = []

        def mock_query_lmstudio(model_name, base_url):
            conversions_tested.append(base_url)
            return 8192

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        get_model_token_limit(model_name="test-model", base_url=input_url)

        # Verify conversion happened
        assert len(conversions_tested) == 1, "Should have called SDK query once"
        assert conversions_tested[0].startswith(expected_protocol), (
            f"Should convert to {expected_protocol}"
        )
        assert expected_host in conversions_tested[0], (
            f"Should preserve host and port: {expected_host}"
        )

    def test_get_model_token_limit_lmstudio_executes_after_ollama(self, monkeypatch):
        """Verify LM Studio detection happens AFTER Ollama detection.

        The hybrid discovery order should be:
        1. Ollama dynamic discovery (port 11434 or 'ollama' in URL)
        2. LM Studio dynamic discovery (port 1234 or 'lmstudio' in URL)
        3. Registry fallback

        If both Ollama and LM Studio patterns match, Ollama should take precedence.
        This test verifies that LM Studio is checked but doesn't interfere with Ollama.
        """
        ollama_called = False
        lmstudio_called = False

        def mock_query_ollama(model_name, base_url):
            nonlocal ollama_called
            ollama_called = True
            return 2048  # Ollama query succeeds

        def mock_query_lmstudio(model_name, base_url):
            nonlocal lmstudio_called
            lmstudio_called = True
            return None  # Should not be reached if Ollama succeeds

        monkeypatch.setattr(
            "leann.embedding_compute._query_ollama_context_limit",
            mock_query_ollama,
        )
        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        # Test with Ollama URL
        limit = get_model_token_limit(
            model_name="test-model", base_url="http://localhost:11434/api"
        )

        assert ollama_called, "Should attempt Ollama query first"
        assert not lmstudio_called, "Should not attempt LM Studio query when Ollama succeeds"
        assert limit == 2048, "Should return Ollama result"

    def test_get_model_token_limit_lmstudio_not_detected_for_non_lmstudio_urls(self, monkeypatch):
        """Verify LM Studio SDK query is NOT called for non-LM Studio URLs.

        Only URLs with port 1234 or 'lmstudio'/'lm.studio' keywords should
        trigger LM Studio SDK queries. Other URLs should skip to registry fallback.
        """
        lmstudio_called = False

        def mock_query_lmstudio(model_name, base_url):
            nonlocal lmstudio_called
            lmstudio_called = True
            return 8192

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        # Test with non-LM Studio URLs
        test_cases = [
            "http://localhost:8080/v1",  # Different port
            "http://openai.example.com/v1",  # Different service
            "http://localhost:3000/v1",  # Another port
        ]

        for base_url in test_cases:
            lmstudio_called = False  # Reset for each test
            get_model_token_limit(model_name="nomic-embed-text", base_url=base_url)
            assert not lmstudio_called, f"Should NOT call LM Studio SDK for URL: {base_url}"

    def test_get_model_token_limit_lmstudio_case_insensitive_detection(self, monkeypatch):
        """Verify LM Studio detection is case-insensitive for keywords.

        Keywords 'lmstudio' and 'lm.studio' should be detected regardless
        of case (LMStudio, LMSTUDIO, LmStudio, etc.).
        """
        query_called = False

        def mock_query_lmstudio(model_name, base_url):
            nonlocal query_called
            query_called = True
            return 8192

        monkeypatch.setattr(
            "leann.embedding_compute._query_lmstudio_context_limit",
            mock_query_lmstudio,
        )

        # Test various case variations
        test_cases = [
            "http://LMStudio.local:8080/v1",
            "http://LMSTUDIO.example.com/v1",
            "http://LmStudio.local/v1",
            "http://api.LM.STUDIO:5000/v1",
        ]

        for base_url in test_cases:
            query_called = False  # Reset for each test
            limit = get_model_token_limit(model_name="test-model", base_url=base_url)
            assert query_called, f"Should detect LM Studio in URL: {base_url}"
            assert limit == 8192, f"Should return SDK result for URL: {base_url}"

class TestTokenLimitCaching:
    """Tests for token limit caching to prevent repeated SDK/API calls.

    Caching prevents duplicate SDK/API calls within the same Python process,
    which is important because:
    1. LM Studio SDK load() can load duplicate model instances
    2. Ollama /api/show queries add latency
    3. Registry lookups are pure overhead

    Cache is process-scoped and resets between leann build invocations.
    """

    def setup_method(self):
        """Clear cache before each test."""
        from leann.embedding_compute import _token_limit_cache

        _token_limit_cache.clear()

    def test_registry_lookup_is_cached(self):
        """Verify that registry lookups are cached."""
        from leann.embedding_compute import _token_limit_cache

        # First call
        limit1 = get_model_token_limit("text-embedding-3-small")
        assert limit1 == 8192

        # Verify it's in cache
        cache_key = ("text-embedding-3-small", "")
        assert cache_key in _token_limit_cache
        assert _token_limit_cache[cache_key] == 8192

        # Second call should use cache
        limit2 = get_model_token_limit("text-embedding-3-small")
        assert limit2 == 8192

    def test_default_fallback_is_cached(self):
        """Verify that default fallbacks are cached."""
        from leann.embedding_compute import _token_limit_cache

        # First call with unknown model
        limit1 = get_model_token_limit("unknown-model-xyz", default=512)
        assert limit1 == 512

        # Verify it's in cache
        cache_key = ("unknown-model-xyz", "")
        assert cache_key in _token_limit_cache
        assert _token_limit_cache[cache_key] == 512

        # Second call should use cache
        limit2 = get_model_token_limit("unknown-model-xyz", default=512)
        assert limit2 == 512

    def test_different_urls_create_separate_cache_entries(self):
        """Verify that different base_urls create separate cache entries."""
        from leann.embedding_compute import _token_limit_cache

        # Same model, different URLs
        limit1 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:11434")
        limit2 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:1234/v1")

        # Both should find the model in registry (2048)
        assert limit1 == 2048
        assert limit2 == 2048

        # But they should be separate cache entries
        cache_key1 = ("nomic-embed-text", "http://localhost:11434")
        cache_key2 = ("nomic-embed-text", "http://localhost:1234/v1")
        assert cache_key1 in _token_limit_cache
        assert cache_key2 in _token_limit_cache
        assert len(_token_limit_cache) == 2

    def test_cache_prevents_repeated_lookups(self):
        """Verify that cache prevents repeated registry/API lookups."""
        from leann.embedding_compute import _token_limit_cache

        model_name = "text-embedding-ada-002"

        # First call - should add to cache
        assert len(_token_limit_cache) == 0
        limit1 = get_model_token_limit(model_name)
        cache_size_after_first = len(_token_limit_cache)
        assert cache_size_after_first == 1

        # Multiple subsequent calls - cache size should not change
        for _ in range(5):
            limit = get_model_token_limit(model_name)
            assert limit == limit1
            assert len(_token_limit_cache) == cache_size_after_first

    def test_versioned_model_names_cached_correctly(self):
        """Verify that versioned model names (e.g., model:tag) are cached."""
        from leann.embedding_compute import _token_limit_cache

        # Model with version tag
        limit = get_model_token_limit("nomic-embed-text:latest", base_url="http://localhost:11434")
        assert limit == 2048

        # Should be cached with full name including version
        cache_key = ("nomic-embed-text:latest", "http://localhost:11434")
        assert cache_key in _token_limit_cache
        assert _token_limit_cache[cache_key] == 2048