Compare commits

...

38 Commits

Author SHA1 Message Date
Andy Lee
8e6aa34afd Exclude macos-15-intel + Python 3.13 (no PyTorch wheels available) 2025-12-25 01:22:32 +00:00
Andy Lee
5791367d13 Fix macos-15-intel deployment target
The macos-15-intel runner runs macOS 15.7, so Homebrew libraries are
built for macOS 14+. Setting MACOSX_DEPLOYMENT_TARGET=13.0 causes
delocate to fail because system libraries require newer macOS.

Fix by setting deployment target to 15.0 for macos-15-intel, matching
the actual OS version. Intel Mac users will need macOS 15+.
2025-12-24 05:35:36 +00:00
Andy Lee
674977a950 Add macOS 26 (beta) to build matrix
Add macos-26 (arm64) runner to the build matrix for testing future
macOS compatibility. This is currently a beta runner that helps ensure
wheels work on upcoming macOS versions.
2025-12-24 01:48:09 +00:00
Andy Lee
56785d30ee Add macos-15-intel for Intel Mac builds (free runner)
Use macos-15-intel (free standard runner) instead of macos-15-large
(paid). This provides Intel Mac wheel support until Aug 2027.

- MACOSX_DEPLOYMENT_TARGET=13.0 for backward compatibility
- Replaces deprecated macos-13 runner
2025-12-24 01:44:12 +00:00
Andy Lee
a73640f95e Remove Intel Mac builds (macos-15-large requires paid plan)
Intel Mac users can build from source. This avoids:
- Paid GitHub Actions runners (macos-15-large)
- Complex cross-compilation setup
2025-12-24 01:06:07 +00:00
Andy Lee
47b91f7313 Set MACOSX_DEPLOYMENT_TARGET=13.x for Intel builds
Intel Mac wheels (macos-15-large) now target macOS 13.0/13.3 for
backward compatibility, allowing macOS 13/14/15 Intel users to
install pre-built wheels.
2025-12-24 01:03:15 +00:00
Andy Lee
7601e0b112 Add macos-15-large for Intel Mac builds
Replace deprecated macos-13 with macos-15-large (x86_64 Intel)
to continue supporting Intel Mac users.
2025-12-24 01:01:08 +00:00
Andy Lee
2a22ec1b26 Remove macos-13 from CI build matrix
GitHub Actions deprecated macos-13 runner (brownout started Sept 2025,
fully retired Dec 2025). See: https://github.blog/changelog/2025-09-19-github-actions-macos-13-runner-image-is-closing-down/
2025-12-24 00:59:15 +00:00
Andy Lee
530507d39d Drop Python 3.9 support, require Python 3.10+
Python 3.9 reached end-of-life and the codebase uses PEP 604 union
type syntax (str | None) which requires Python 3.10+.

Changes:
- Remove Python 3.9 from CI build matrix
- Update requires-python to >=3.10 in all pyproject.toml files
- Update classifiers to reflect supported Python versions (3.10-3.13)
2025-12-24 00:54:44 +00:00
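For context, PEP 604 union syntax in annotations is evaluated at function definition time, so it fails outright on Python 3.9; a minimal illustration (hypothetical function, not from the codebase):

```python
# Works on Python 3.10+; on 3.9 the `str | None` annotation raises
# TypeError at definition time (unless `from __future__ import annotations` is used).
def greet(name: str | None = None) -> str:
    return f"Hello, {name}" if name else "Hello"
```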
Andy Lee
8a2ea37871 Fix: handle dict format from create_text_chunks (introduced in PR #157)
PR #157 changed create_text_chunks() to return list[dict] instead of
list[str] to preserve metadata, but base_rag_example.py was not updated
to handle the new format. This caused all chunks to fail validation
with "All provided chunks are empty or invalid".
2025-12-23 08:50:31 +00:00
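A minimal sketch of the fix described above (the actual change appears in the base_rag_example.py hunk further down): accept both plain-string chunks and dict chunks carrying a "text" key plus optional "metadata".

```python
# Sketch only: mirrors the handling shown in the diff below.
def add_chunks(builder, chunks):
    for item in chunks:
        if isinstance(item, dict):
            builder.add_text(item.get("text", ""), item.get("metadata"))
        else:
            builder.add_text(item)
```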
Yichuan Wang
7ddb4772c0 Feature/custom folder multi vector/ add Readme to LEANN MCP (#189)
* Add custom folder support and improve image loading for multi-vector retrieval

- Enhanced _load_images_from_dir with recursive search support and better error handling
- Added support for WebP format and RGB conversion for all image modes
- Added custom folder CLI arguments (--custom-folder, --recursive, --rebuild-index)
- Improved documentation and removed completed TODO comment

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* Format code style in leann_multi_vector.py for better readability

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* docs: polish README performance tip section

- Fix typo: 'matrilize' -> 'materialize'
- Improve clarity and formatting of --no-recompute flag explanation
- Add code block for better readability

* format

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-19 17:29:14 -08:00
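A hedged sketch of the recursive image loading this PR describes (illustrative code, not the project's `_load_images_from_dir`): recurse when requested, accept WebP alongside common formats, convert every image to RGB, and skip unreadable files instead of failing.

```python
from pathlib import Path
from PIL import Image

IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".webp"}

def load_images_from_dir(root: str, recursive: bool = False) -> list[Image.Image]:
    pattern = "**/*" if recursive else "*"
    images = []
    for path in sorted(Path(root).glob(pattern)):
        if path.suffix.lower() in IMAGE_EXTS:
            try:
                # Normalize all modes (e.g. palette, RGBA) to RGB for the model.
                images.append(Image.open(path).convert("RGB"))
            except OSError as exc:
                print(f"Skipping unreadable image {path}: {exc}")
    return images
```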
aakash
a1c21adbce Move COLQWEN_GUIDE.md to docs and remove test_colqwen_reproduction.py 2025-12-19 13:57:47 -08:00
Aakash Suresh
d1b3c93a5a Merge pull request #162 from yichuan-w/feature/colqwen-integration
add ColQwen multimodal PDF retrieval integration
2025-12-19 13:53:29 -08:00
Yichuan Wang
a6ee95b18a Add custom folder support and improve image loading for multi-vector … (#188)
* Add custom folder support and improve image loading for multi-vector retrieval

- Enhanced _load_images_from_dir with recursive search support and better error handling
- Added support for WebP format and RGB conversion for all image modes
- Added custom folder CLI arguments (--custom-folder, --recursive, --rebuild-index)
- Improved documentation and removed completed TODO comment

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* Format code style in leann_multi_vector.py for better readability

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-17 01:03:45 -08:00
Alex
17cbd07b25 Add Anthropic LLM support (#185)
* Add Anthropic LLM support

Signed-off-by: droctothorpe <mythicalsunlight@gmail.com>

* Update skypilot link

Signed-off-by: droctothorpe <mythicalsunlight@gmail.com>

* Handle anthropic base_url

Signed-off-by: droctothorpe <mythicalsunlight@gmail.com>

* Address ruff format finding

Signed-off-by: droctothorpe <mythicalsunlight@gmail.com>

---------

Signed-off-by: droctothorpe <mythicalsunlight@gmail.com>
2025-12-12 10:53:41 -08:00
Alex
3629ccf8f7 Use logger instead of print (#186)
Signed-off-by: droctothorpe <mythicalsunlight@gmail.com>
2025-12-10 13:48:57 -08:00
aakash
0175bc9c20 docs: Add ColQwen guide to docs directory
Add COLQWEN_GUIDE.md to docs/ directory for proper documentation structure.
This file is referenced in the README and needs to be tracked in git.
2025-12-07 09:57:14 -08:00
aakash
af47dfdde7 fix: Update ColQwen guide link to docs/ directory 2025-12-06 03:33:02 -08:00
aakash
f13bd02fbd docs: Add ColQwen multimodal PDF retrieval to README
Add brief introduction and usage guide for ColQwen integration,
similar to other RAG application sections in the README.

- Quick start examples for building, searching, and interactive Q&A
- Setup instructions with prerequisites
- Model options (ColQwen2 vs ColPali)
- Link to detailed ColQwen guide
2025-12-06 03:28:08 -08:00
Yichuan Wang
a0bbf831db Add ColQwen2.5 model support and improve model selection (#183)
- Add ColQwen2.5 and ColQwen2_5_Processor imports
- Implement smart model type detection for colqwen2, colqwen2.5, and colpali
- Add task name aliases for easier benchmark invocation
- Add safe model name handling for file paths and index naming
- Support custom model paths including LoRA adapters
- Improve model choice validation and error handling

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-05 03:36:55 -08:00
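A rough sketch of the "smart model type detection" idea mentioned above (illustrative only; the real selection logic lives in the multi-vector code and also handles custom paths and LoRA adapters):

```python
def detect_model_type(model_name: str) -> str:
    """Infer whether a model name or path refers to colqwen2.5, colqwen2, or colpali."""
    name = model_name.lower()
    if "colqwen2.5" in name or "colqwen2_5" in name:
        return "colqwen2.5"
    if "colqwen" in name:
        return "colqwen2"
    if "colpali" in name:
        return "colpali"
    raise ValueError(f"Unrecognized multi-vector model: {model_name}")
```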
aakash
86287d8832 Revert unnecessary faiss submodule update
Reset faiss submodule to match main branch to avoid unnecessary changes
2025-12-03 18:32:04 -08:00
Yichuan Wang
76cc798e3e Feat/multi vector timing and dataset improvements (#181)
* Add timing instrumentation and multi-dataset support for multi-vector retrieval

- Add timing measurements for search operations (load and core time)
- Increase embedding batch size from 1 to 32 for better performance
- Add explicit memory cleanup with del all_embeddings
- Support loading and merging multiple datasets with different splits
- Add CLI arguments for search method selection (ann/exact/exact-all)
- Auto-detect image field names across different dataset structures
- Print candidate doc counts for performance monitoring

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* update vidore

* reproduce docvqa results

* reproduce docvqa results and add debug file

* fix: format colqwen_forward.py to pass pre-commit checks

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-03 01:10:49 -08:00
Yichuan Wang
d599566fd7 Revert "[Multi-vector]Add timing instrumentation and multi-dataset support fo…" (#180)
This reverts commit 00770aebbb.
2025-12-03 01:09:39 -08:00
Yichuan Wang
00770aebbb [Multi-vector]Add timing instrumentation and multi-dataset support for multi-vector… (#161)
* Add timing instrumentation and multi-dataset support for multi-vector retrieval

- Add timing measurements for search operations (load and core time)
- Increase embedding batch size from 1 to 32 for better performance
- Add explicit memory cleanup with del all_embeddings
- Support loading and merging multiple datasets with different splits
- Add CLI arguments for search method selection (ann/exact/exact-all)
- Auto-detect image field names across different dataset structures
- Print candidate doc counts for performance monitoring

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* update vidore

* reproduce docvqa results

* reproduce docvqa results and add debug file

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-03 00:55:42 -08:00
Aakash Suresh
e268392d5b Fix: Prevent duplicate PDF processing when using --file-types .pdf (#179)
Fixes #175

Problem:
When --file-types .pdf is specified, PDFs were being processed twice:
1. Separately with PyMuPDF/pdfplumber extractors
2. Again in the 'other file types' section via SimpleDirectoryReader

This caused duplicate processing and potential conflicts.

Solution:
- Exclude .pdf from other_file_extensions when PDFs are already
  processed separately
- Only load other file types if there are extensions to process
- Prevents duplicate PDF processing

Changes:
- Added logic to filter out .pdf from code_extensions when loading
  other file types if PDFs were processed separately
- Updated SimpleDirectoryReader to use filtered extensions
- Added check to skip loading if no other extensions to process
2025-12-01 13:48:44 -08:00
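A rough sketch of the dedup logic described above (names are illustrative, not the exact implementation): drop `.pdf` from the extension list handed to the generic reader when PDFs were already handled by the dedicated extractors, and skip the generic pass entirely if nothing remains.

```python
def other_extensions(file_types: list[str], pdfs_processed_separately: bool) -> list[str]:
    """Return the extensions left for the generic 'other file types' pass."""
    exts = list(file_types)
    if pdfs_processed_separately:
        exts = [e for e in exts if e.lower() != ".pdf"]
    return exts

# Usage sketch: only invoke SimpleDirectoryReader if anything is left to process.
# remaining = other_extensions(args.file_types, pdfs_processed_separately=True)
# if remaining:
#     docs = reader_for(remaining).load_data()
```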
Andy Lee
eb909ccec5 docs: survey 2025-11-20 09:20:45 -08:00
ww26
969f514564 Fix prompt template bugs: build template ignored and runtime override not wired (#173)
* Fix prompt template bugs in build and search

Bug 1: Build template ignored in new format
- Updated compute_embeddings_openai() to read build_prompt_template or prompt_template
- Updated compute_embeddings_ollama() with same fix
- Maintains backward compatibility with old single-template format

Bug 2: Runtime override not wired up
- Wired CLI search to pass provider_options to searcher.search()
- Enables runtime template override during search via --embedding-prompt-template

All 42 prompt template tests passing.

Fixes #155

* Fix: Prevent embedding server from applying templates during search

- Filter out all prompt templates (build_prompt_template, query_prompt_template, prompt_template) from provider_options when launching embedding server during search
- Templates are already applied in compute_query_embedding() before server call
- Prevents double-templating and ensures runtime override works correctly

This fixes the issue where --embedding-prompt-template during search was ignored because the server was applying build_prompt_template instead.

* Format code with ruff
2025-11-16 20:56:42 -08:00
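A hedged sketch of the server-launch filtering described above (key names come from the commit message; the surrounding code is assumed): strip every template key from provider_options before starting the embedding server, since the query template has already been applied in compute_query_embedding().

```python
TEMPLATE_KEYS = {"build_prompt_template", "query_prompt_template", "prompt_template"}

def server_options(provider_options: dict) -> dict:
    # Templates are applied client-side before the server call, so the server
    # must not apply them a second time (prevents double-templating).
    return {k: v for k, v in provider_options.items() if k not in TEMPLATE_KEYS}
```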
ww26
1ef9cba7de Feature/prompt templates and lmstudio sdk (#171)
* Add prompt template support and LM Studio SDK integration

Features:

- Prompt template support for embedding models (via --embedding-prompt-template)

- LM Studio SDK integration for automatic context length detection

- Hybrid token limit discovery (Ollama → LM Studio → Registry → Default)

- Client-side token truncation to prevent silent failures

- Automatic persistence of embedding_options to .meta.json

Implementation:

- Added _query_lmstudio_context_limit() with Node.js subprocess bridge

- Modified compute_embeddings_openai() to apply prompt templates before truncation

- Extended CLI with --embedding-prompt-template flag for build and search

- URL detection for LM Studio (port 1234 or lmstudio/lm.studio keywords)

- HTTP→WebSocket URL conversion for SDK compatibility

Tests:

- 60 passing tests across 5 test files

- Comprehensive coverage of prompt templates, LM Studio integration, and token handling

- Parametrized tests for maintainability and clarity

* Add integration tests and fix LM Studio SDK bridge

Features:
- End-to-end integration tests for prompt template with EmbeddingGemma
- Integration tests for hybrid token limit discovery mechanism
- Tests verify real-world functionality with live services (LM Studio, Ollama)

Fixes:
- LM Studio SDK bridge now uses client.embedding.load() for embedding models
- Fixed NODE_PATH resolution to include npm global modules
- Fixed integration test to use WebSocket URL (ws://) for SDK bridge

Tests:
- test_prompt_template_e2e.py: 8 integration tests covering:
  - Prompt template prepending with LM Studio (EmbeddingGemma)
  - LM Studio SDK bridge for context length detection
  - Ollama dynamic token limit detection
  - Hybrid discovery fallback mechanism (registry, default)
- All tests marked with @pytest.mark.integration for selective execution
- Tests gracefully skip when services unavailable

Documentation:
- Updated tests/README.md with integration test section
- Added prerequisites and running instructions
- Documented that prompt templates are ONLY for EmbeddingGemma
- Added integration marker to pyproject.toml

Test Results:
- All 8 integration tests passing with live services
- Confirmed prompt templates work correctly with EmbeddingGemma
- Verified LM Studio SDK bridge auto-detects context length (2048)
- Validated hybrid token limit discovery across all backends

* Add prompt template support to Ollama mode

Extends prompt template functionality from OpenAI mode to Ollama for backend consistency.

Changes:
- Add provider_options parameter to compute_embeddings_ollama()
- Apply prompt template before token truncation (lines 1005-1011)
- Pass provider_options through compute_embeddings() call chain

Tests:
- test_ollama_embedding_with_prompt_template: Verifies templates work with Ollama
- test_ollama_prompt_template_affects_embeddings: Confirms embeddings differ with/without template
- Both tests pass with live Ollama service (2/2 passing)

Usage:
leann build --embedding-mode ollama --embedding-prompt-template "query: " ...

* Fix LM Studio SDK bridge to respect JIT auto-evict settings

Problem: SDK bridge called client.embedding.load() which loaded models into
LM Studio memory and bypassed JIT auto-evict settings, causing duplicate
model instances to accumulate.

Root cause analysis (from Perplexity research):
- Explicit SDK load() commands are treated as "pinned" models
- JIT auto-evict only applies to models loaded reactively via API requests
- SDK-loaded models remain in memory until explicitly unloaded

Solutions implemented:

1. Add model.unload() after metadata query (line 243)
   - Load model temporarily to get context length
   - Unload immediately to hand control back to JIT system
   - Subsequent API requests trigger JIT load with auto-evict

2. Add token limit caching to prevent repeated SDK calls
   - Cache discovered limits in _token_limit_cache dict (line 48)
   - Key: (model_name, base_url), Value: token_limit
   - Prevents duplicate load/unload cycles within same process
   - Cache shared across all discovery methods (Ollama, SDK, registry)

Tests:
- TestTokenLimitCaching: 5 tests for cache behavior (integrated into test_token_truncation.py)
- Manual testing confirmed no duplicate models in LM Studio after fix
- All existing tests pass

Impact:
- Respects user's LM Studio JIT and auto-evict settings
- Reduces model memory footprint
- Faster subsequent builds (cached limits)

* Document prompt template and LM Studio SDK features

Added comprehensive documentation for new optional embedding features:

Configuration Guide (docs/configuration-guide.md):
- New section: "Optional Embedding Features"
- Task-Specific Prompt Templates subsection:
  - Explains EmbeddingGemma use case with document/query prompts
  - CLI and Python API examples
  - Clear warnings about compatible vs incompatible models
  - References to GitHub issue #155 and HuggingFace blog
- LM Studio Auto-Detection subsection:
  - Prerequisites (Node.js + @lmstudio/sdk)
  - How auto-detection works (4-step process)
  - Benefits and optional nature clearly stated

FAQ (docs/faq.md):
- FAQ #2: When should I use prompt templates?
  - DO/DON'T guidance with examples
  - Links to detailed configuration guide
- FAQ #3: Why is LM Studio loading multiple copies?
  - Explains the JIT auto-evict fix
  - Troubleshooting steps if still seeing issues
- FAQ #4: Do I need Node.js and @lmstudio/sdk?
  - Clarifies it's completely optional
  - Lists benefits if installed
  - Installation instructions

Cross-references between documents for easy navigation between quick reference and detailed guides.

* Add separate build/query template support for task-specific models

Task-specific models like EmbeddingGemma require different templates for indexing vs searching. Store both templates at build time and auto-apply query template during search with backward compatibility.

* Consolidate prompt template tests from 44 to 37 tests

Merged redundant no-op tests, removed low-value implementation tests, consolidated parameterized CLI tests, and removed hanging over-mocked test. All tests pass with improved focus on behavioral testing.

* Fix query template application in compute_query_embedding

Query templates were only applied in the fallback code path, not when using the embedding server (default path). This meant stored query templates in index metadata were ignored during MCP and CLI searches.

Changes:

- Move template application to before any computation path (searcher_base.py:109-110)

- Add comprehensive tests for both server and fallback paths

- Consolidate tests into test_prompt_template_persistence.py

Tests verify:

- Template applied when using embedding server

- Template applied in fallback path

- Consistent behavior between both paths

* Apply ruff formatting and fix linting issues

- Remove unused imports

- Fix import ordering

- Remove unused variables

- Apply code formatting

* Fix CI test failures: mock OPENAI_API_KEY in tests

Tests were failing in CI because compute_embeddings_openai() checks for OPENAI_API_KEY before using the mocked client. Added monkeypatch to set fake API key in test fixture.
2025-11-14 15:25:17 -08:00
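A hedged sketch of the token-limit cache described above: the cache key and the discovery order (Ollama, then LM Studio SDK, then registry, then default) come from the commit message, while the helper is an illustrative stand-in rather than the project's code.

```python
from typing import Optional

_token_limit_cache: dict[tuple[str, str], int] = {}

def _discover_limit(model_name: str, base_url: str) -> Optional[int]:
    # Stand-in for the real steps: query Ollama, then the LM Studio SDK bridge,
    # then a static registry; return None if nothing is found.
    return None

def get_token_limit(model_name: str, base_url: str, default: int = 512) -> int:
    """Cache discovered limits per (model_name, base_url) to avoid repeated SDK load/unload cycles."""
    key = (model_name, base_url)
    if key not in _token_limit_cache:
        _token_limit_cache[key] = _discover_limit(model_name, base_url) or default
    return _token_limit_cache[key]
```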
yichuan520030910320
a63550944b update pkg name 2025-11-13 14:57:28 -08:00
yichuan520030910320
97493a2896 update module to install 2025-11-13 14:53:25 -08:00
Aakash Suresh
f7d2dc6e7c Merge pull request #163 from orbisai0security/fix-semgrep-python.lang.security.audit.eval-detected.eval-detected-apps-slack-data-slack-mc-b134f52c-f326
Fix: Unsafe Code Execution Function Could Allow External Code Injection in apps/slack_data/slack_mcp_reader.py
2025-11-13 14:35:44 -08:00
Gil Vernik
ea86b283cb [bug] fix when no package name found (#167)
* fix when no package name

* fix pre-commit style issue

* fix to ruff (legacy alias) fail test
2025-11-13 14:23:13 -08:00
aakash
e7519bceaa Fix CI: sync uv.lock from main and remove .lycheeignore (workflow exclusion is sufficient) 2025-11-13 13:10:07 -08:00
aakash
abf0b2c676 Fix CI: improve security fix and add link checker configuration
- Fix import order (ast before asyncio)
- Remove NameError from exception handling (ast.literal_eval doesn't raise it)
- Add .lycheeignore to exclude intermittently unavailable star-history API
- Update link-check workflow to exclude star-history API and accept 503 status codes
2025-11-13 13:05:00 -08:00
GitHub Actions
3c4785bb63 chore: release v0.3.5 2025-11-12 06:01:25 +00:00
orbisai0security
930b79cc98 fix: semgrep_python.lang.security.audit.eval-detected.eval-detected_apps/slack_data/slack_mcp_reader.py_157 2025-11-12 03:47:18 +00:00
yichuan-w
3766ad1fd2 robust multi-vector 2025-11-09 02:34:53 +00:00
ww26
c3aceed1e0 metadata reveal for ast-chunking; smart detection of seq length in ollama; auto adjust chunk length for ast to prevent silent truncation (#157)
* feat: enhance token limits with dynamic discovery + AST metadata

Improves upon upstream PR #154 with two major enhancements:

1. **Hybrid Token Limit Discovery**
   - Dynamic: Query Ollama /api/show for context limits
   - Fallback: Registry for LM Studio/OpenAI
   - Zero maintenance for Ollama users
   - Respects custom num_ctx settings

2. **AST Metadata Preservation**
   - create_ast_chunks() returns dict format with metadata
   - Preserves file_path, file_name, timestamps
   - Includes astchunk metadata (line numbers, node counts)
   - Fixes content extraction bug (checks "content" key)
   - Enables --show-metadata flag

3. **Better Token Limits**
   - nomic-embed-text: 2048 tokens (vs 512)
   - nomic-embed-text-v1.5: 2048 tokens
   - Added OpenAI models: 8192 tokens

4. **Comprehensive Tests**
   - 11 tests for token truncation
   - 545 new lines in test_astchunk_integration.py
   - All metadata preservation tests passing

* fix: merge EMBEDDING_MODEL_LIMITS and remove redundant validation

- Merged upstream's model list with our corrected token limits
- Kept our corrected nomic-embed-text: 2048 (not 512)
- Removed post-chunking validation (redundant with embedding-time truncation)
- All tests passing except 2 pre-existing integration test failures

* style: apply ruff formatting and restore PR #154 version handling

- Remove duplicate truncate_to_token_limit and get_model_token_limit functions
- Restore version handling logic (model:latest -> model) from PR #154
- Restore partial matching fallback for model name variations
- Apply ruff formatting to all modified files
- All 11 token truncation tests passing

* style: sort imports alphabetically (pre-commit auto-fix)

* fix: show AST token limit warning only once per session

- Add module-level flag to track if warning shown
- Prevents spam when processing multiple files
- Add clarifying note that auto-truncation happens at embedding time
- Addresses issue where warning appeared for every code file

* enhance: add detailed logging for token truncation

- Track and report truncation statistics (count, tokens removed, max length)
- Show first 3 individual truncations with exact token counts
- Provide comprehensive summary when truncation occurs
- Use WARNING level for data loss visibility
- Silent (DEBUG level only) when no truncation needed

Replaces misleading "truncated where necessary" message that appeared
even when nothing was truncated.
2025-11-08 17:37:31 -08:00
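A hedged sketch of the dynamic discovery step named above: POST to Ollama's `/api/show` and look for an architecture-prefixed `*.context_length` entry in `model_info`. The exact response keys vary by model, and this is an assumption about the approach, not the project's implementation.

```python
import requests

def ollama_context_length(model: str, base_url: str = "http://localhost:11434") -> int | None:
    """Best-effort query of Ollama for a model's context window."""
    resp = requests.post(f"{base_url}/api/show", json={"model": model}, timeout=10)
    resp.raise_for_status()
    info = resp.json().get("model_info", {})
    for key, value in info.items():
        if key.endswith("context_length"):
            return int(value)
    return None  # fall back to registry/default limits
```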
42 changed files with 10404 additions and 4891 deletions

View File

@@ -35,8 +35,8 @@ jobs:
     strategy:
       matrix:
         include:
-          - os: ubuntu-22.04
-            python: '3.9'
+          # Note: Python 3.9 dropped - uses PEP 604 union syntax (str | None)
+          # which requires Python 3.10+
           - os: ubuntu-22.04
             python: '3.10'
           - os: ubuntu-22.04
@@ -46,8 +46,6 @@
           - os: ubuntu-22.04
             python: '3.13'
           # ARM64 Linux builds
-          - os: ubuntu-24.04-arm
-            python: '3.9'
           - os: ubuntu-24.04-arm
             python: '3.10'
           - os: ubuntu-24.04-arm
@@ -56,8 +54,6 @@
             python: '3.12'
           - os: ubuntu-24.04-arm
             python: '3.13'
-          - os: macos-14
-            python: '3.9'
           - os: macos-14
             python: '3.10'
           - os: macos-14
@@ -66,8 +62,6 @@
             python: '3.12'
           - os: macos-14
             python: '3.13'
-          - os: macos-15
-            python: '3.9'
           - os: macos-15
             python: '3.10'
           - os: macos-15
@@ -76,16 +70,24 @@
             python: '3.12'
           - os: macos-15
             python: '3.13'
-          - os: macos-13
-            python: '3.9'
-          - os: macos-13
-            python: '3.10'
-          - os: macos-13
-            python: '3.11'
-          - os: macos-13
-            python: '3.12'
-          # Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
-          # (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
+          # Intel Mac builds (x86_64) - replaces deprecated macos-13
+          # Note: Python 3.13 excluded - PyTorch has no wheels for macOS x86_64 + Python 3.13
+          # (PyTorch <=2.4.1 lacks cp313, PyTorch >=2.5.0 dropped Intel Mac support)
+          - os: macos-15-intel
+            python: '3.10'
+          - os: macos-15-intel
+            python: '3.11'
+          - os: macos-15-intel
+            python: '3.12'
+          # macOS 26 (beta) - arm64
+          - os: macos-26
+            python: '3.10'
+          - os: macos-26
+            python: '3.11'
+          - os: macos-26
+            python: '3.12'
+          - os: macos-26
+            python: '3.13'
     runs-on: ${{ matrix.os }}
     steps:
@@ -204,13 +206,16 @@
           # Use system clang for better compatibility
           export CC=clang
           export CXX=clang++
-          # Homebrew libraries on each macOS version require matching minimum version
-          if [[ "${{ matrix.os }}" == "macos-13" ]]; then
-            export MACOSX_DEPLOYMENT_TARGET=13.0
-          elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
-            export MACOSX_DEPLOYMENT_TARGET=14.0
-          elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
+          # Set deployment target based on runner
+          # macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
+          if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
             export MACOSX_DEPLOYMENT_TARGET=15.0
+          elif [[ "${{ matrix.os }}" == macos-14* ]]; then
+            export MACOSX_DEPLOYMENT_TARGET=14.0
+          elif [[ "${{ matrix.os }}" == macos-15* ]]; then
+            export MACOSX_DEPLOYMENT_TARGET=15.0
+          elif [[ "${{ matrix.os }}" == macos-26* ]]; then
+            export MACOSX_DEPLOYMENT_TARGET=26.0
           fi
           uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
         else
@@ -224,14 +229,16 @@
           # Use system clang for better compatibility
           export CC=clang
           export CXX=clang++
-          # DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
-          # But Homebrew libraries on each macOS version require matching minimum version
-          if [[ "${{ matrix.os }}" == "macos-13" ]]; then
-            export MACOSX_DEPLOYMENT_TARGET=13.3
-          elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
-            export MACOSX_DEPLOYMENT_TARGET=14.0
-          elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
+          # Set deployment target based on runner
+          # macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
+          if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
             export MACOSX_DEPLOYMENT_TARGET=15.0
+          elif [[ "${{ matrix.os }}" == macos-14* ]]; then
+            export MACOSX_DEPLOYMENT_TARGET=14.0
+          elif [[ "${{ matrix.os }}" == macos-15* ]]; then
+            export MACOSX_DEPLOYMENT_TARGET=15.0
+          elif [[ "${{ matrix.os }}" == macos-26* ]]; then
+            export MACOSX_DEPLOYMENT_TARGET=26.0
           fi
           uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
         else
@@ -269,16 +276,19 @@
         if: runner.os == 'macOS'
         run: |
           # Determine deployment target based on runner OS
-          # Must match the Homebrew libraries for each macOS version
-          if [[ "${{ matrix.os }}" == "macos-13" ]]; then
-            HNSW_TARGET="13.0"
-            DISKANN_TARGET="13.3"
-          elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
-            HNSW_TARGET="14.0"
-            DISKANN_TARGET="14.0"
-          elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
+          # macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
+          if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
             HNSW_TARGET="15.0"
             DISKANN_TARGET="15.0"
+          elif [[ "${{ matrix.os }}" == macos-14* ]]; then
+            HNSW_TARGET="14.0"
+            DISKANN_TARGET="14.0"
+          elif [[ "${{ matrix.os }}" == macos-15* ]]; then
+            HNSW_TARGET="15.0"
+            DISKANN_TARGET="15.0"
+          elif [[ "${{ matrix.os }}" == macos-26* ]]; then
+            HNSW_TARGET="26.0"
+            DISKANN_TARGET="26.0"
           fi

           # Repair HNSW wheel
@@ -334,12 +344,15 @@
           PY_TAG=$($UV_PY -c "import sys; print(f'cp{sys.version_info[0]}{sys.version_info[1]}')")
           if [[ "$RUNNER_OS" == "macOS" ]]; then
-            if [[ "${{ matrix.os }}" == "macos-13" ]]; then
-              export MACOSX_DEPLOYMENT_TARGET=13.3
-            elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
-              export MACOSX_DEPLOYMENT_TARGET=14.0
-            elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
+            # macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
+            if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
               export MACOSX_DEPLOYMENT_TARGET=15.0
+            elif [[ "${{ matrix.os }}" == macos-14* ]]; then
+              export MACOSX_DEPLOYMENT_TARGET=14.0
+            elif [[ "${{ matrix.os }}" == macos-15* ]]; then
+              export MACOSX_DEPLOYMENT_TARGET=15.0
+            elif [[ "${{ matrix.os }}" == macos-26* ]]; then
+              export MACOSX_DEPLOYMENT_TARGET=26.0
             fi
           fi

View File

@@ -14,6 +14,6 @@
       - uses: actions/checkout@v4
       - uses: lycheeverse/lychee-action@v2
         with:
-          args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
+          args: --no-progress --insecure --user-agent 'curl/7.68.0' --exclude '.*api\.star-history\.com.*' --accept 200,201,202,203,204,205,206,207,208,226,300,301,302,303,304,305,306,307,308,503 README.md docs/ apps/ examples/ benchmarks/
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

.gitignore
View File

@@ -91,7 +91,8 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
 *.meta.json
 *.passages.json
+*.npy
+*.db
 batchtest.py
 tests/__pytest_cache__/
 tests/__pycache__/

View File

@@ -16,12 +16,24 @@
   </a>
 </p>

+<div align="center">
+  <a href="https://forms.gle/rDbZf864gMNxhpTq8">
+    <img src="https://img.shields.io/badge/📣_Community_Survey-Help_Shape_v0.4-007ec6?style=for-the-badge&logo=google-forms&logoColor=white" alt="Take Survey">
+  </a>
+  <p>
+    We track <b>zero telemetry</b>. This survey is the ONLY way to tell us if you want <br>
+    <b>GPU Acceleration</b> or <b>More Integrations</b> next.<br>
+    👉 <a href="https://forms.gle/rDbZf864gMNxhpTq8"><b>Click here to cast your vote (2 mins)</b></a>
+  </p>
+</div>
+
 <h2 align="center" tabindex="-1" class="heading-element" dir="auto">
 The smallest vector index in the world. RAG Everything with LEANN!
 </h2>

 LEANN is an innovative vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.

 LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)

 **Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#slack-messages-search-your-team-conversations), [Twitter](#-twitter-bookmarks-your-personal-tweet-library)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
@@ -189,7 +201,7 @@ LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`,
 #### LLM Backend

-LEANN supports many LLM providers for text generation (HuggingFace, Ollama, and Any OpenAI compatible API).
+LEANN supports many LLM providers for text generation (HuggingFace, Ollama, Anthropic, and Any OpenAI compatible API).

 <details>
@@ -257,6 +269,7 @@ Below is a list of base URLs for common providers to get you started.
 | **SiliconFlow** | `https://api.siliconflow.cn/v1` |
 | **Zhipu (BigModel)** | `https://open.bigmodel.cn/api/paas/v4/` |
 | **Mistral AI** | `https://api.mistral.ai/v1` |
+| **Anthropic** | `https://api.anthropic.com/v1` |
@@ -316,7 +329,7 @@ All RAG examples share these common parameters. **Interactive mode** is availabl
 --embedding-mode MODE      # sentence-transformers, openai, mlx, or ollama

 # LLM Parameters (Text generation models)
---llm TYPE                 # LLM backend: openai, ollama, or hf (default: openai)
+--llm TYPE                 # LLM backend: openai, ollama, hf, or anthropic (default: openai)
 --llm-model MODEL          # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
 --thinking-budget LEVEL    # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
@@ -379,6 +392,54 @@ python -m apps.code_rag --repo-dir "./my_codebase" --query "How does authenticat
 </details>

+### 🎨 ColQwen: Multimodal PDF Retrieval with Vision-Language Models
+
+Search through PDFs using both text and visual understanding with ColQwen2/ColPali models. Perfect for research papers, technical documents, and any PDFs with complex layouts, figures, or diagrams.
+
+> **🍎 Mac Users**: ColQwen is optimized for Apple Silicon with MPS acceleration for faster inference!
+
+```bash
+# Build index from PDFs
+python -m apps.colqwen_rag build --pdfs ./my_papers/ --index research_papers
+
+# Search with text queries
+python -m apps.colqwen_rag search research_papers "How does attention mechanism work?"
+
+# Interactive Q&A
+python -m apps.colqwen_rag ask research_papers --interactive
+```
+
+<details>
+<summary><strong>📋 Click to expand: ColQwen Setup & Usage</strong></summary>
+
+#### Prerequisites
+
+```bash
+# Install dependencies
+uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
+brew install poppler  # macOS only, for PDF processing
+```
+
+#### Build Index
+
+```bash
+python -m apps.colqwen_rag build \
+    --pdfs ./pdf_directory/ \
+    --index my_index \
+    --model colqwen2  # or colpali
+```
+
+#### Search
+
+```bash
+python -m apps.colqwen_rag search my_index "your question here" --top-k 5
+```
+
+#### Models
+
+- **ColQwen2** (`colqwen2`): Latest vision-language model with improved performance
+- **ColPali** (`colpali`): Proven multimodal retriever
+
+For detailed usage, see the [ColQwen Guide](docs/COLQWEN_GUIDE.md).
+
+</details>
+
 ### 📧 Your Personal Email Secretary: RAG on Apple Mail!

 > **Note:** The examples below currently support macOS only. Windows support coming soon.
@@ -1045,7 +1106,7 @@ Options:
 leann ask INDEX_NAME [OPTIONS]

 Options:
-  --llm {ollama,openai,hf}            LLM provider (default: ollama)
+  --llm {ollama,openai,hf,anthropic}  LLM provider (default: ollama)
   --model MODEL                       Model name (default: qwen3:8b)
   --interactive                       Interactive chat mode
   --top-k N                           Retrieval count (default: 20)

View File

@@ -6,7 +6,7 @@ Provides common parameters and functionality for all RAG examples.
 import argparse
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any
+from typing import Any, Union

 import dotenv
 from leann.api import LeannBuilder, LeannChat
@@ -257,8 +257,8 @@
         pass

     @abstractmethod
-    async def load_data(self, args) -> list[str]:
-        """Load data from the source. Returns list of text chunks."""
+    async def load_data(self, args) -> list[Union[str, dict[str, Any]]]:
+        """Load data from the source. Returns list of text chunks (strings or dicts with 'text' key)."""
         pass

     def get_llm_config(self, args) -> dict[str, Any]:
@@ -282,8 +282,8 @@
         return config

-    async def build_index(self, args, texts: list[str]) -> str:
-        """Build LEANN index from texts."""
+    async def build_index(self, args, texts: list[Union[str, dict[str, Any]]]) -> str:
+        """Build LEANN index from texts (accepts strings or dicts with 'text' key)."""
         index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")

         print(f"\n[Building Index] Creating {self.name} index...")
@@ -314,8 +314,14 @@
         batch_size = 1000
         for i in range(0, len(texts), batch_size):
             batch = texts[i : i + batch_size]
-            for text in batch:
-                builder.add_text(text)
+            for item in batch:
+                # Handle both dict format (from create_text_chunks) and plain strings
+                if isinstance(item, dict):
+                    text = item.get("text", "")
+                    metadata = item.get("metadata")
+                    builder.add_text(text, metadata)
+                else:
+                    builder.add_text(item)
             print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")

         print("Building index structure...")

View File

@@ -12,6 +12,7 @@ from pathlib import Path
 try:
     from leann.chunking_utils import (
         CODE_EXTENSIONS,
+        _traditional_chunks_as_dicts,
         create_ast_chunks,
         create_text_chunks,
         create_traditional_chunks,
@@ -25,6 +26,7 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
     sys.path.insert(0, str(leann_src))
     from leann.chunking_utils import (
         CODE_EXTENSIONS,
+        _traditional_chunks_as_dicts,
         create_ast_chunks,
         create_text_chunks,
         create_traditional_chunks,
@@ -36,6 +38,7 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
 __all__ = [
     "CODE_EXTENSIONS",
+    "_traditional_chunks_as_dicts",
     "create_ast_chunks",
     "create_text_chunks",
     "create_traditional_chunks",

View File

@@ -5,6 +5,7 @@ Supports PDF, TXT, MD, and other document formats.

 import sys
 from pathlib import Path
+from typing import Any, Union

 # Add parent directory to path for imports
 sys.path.insert(0, str(Path(__file__).parent))
@@ -51,7 +52,7 @@ class DocumentRAG(BaseRAGExample):
             help="Enable AST-aware chunking for code files in the data directory",
         )

-    async def load_data(self, args) -> list[str]:
+    async def load_data(self, args) -> list[Union[str, dict[str, Any]]]:
         """Load documents and convert to text chunks."""
         print(f"Loading documents from: {args.data_dir}")

         if args.file_types:

View File

@@ -0,0 +1,132 @@
#!/usr/bin/env python3
"""Simple test script to test colqwen2 forward pass with a single image."""

import os
import sys
from pathlib import Path

# Add the current directory to path to import leann_multi_vector
sys.path.insert(0, str(Path(__file__).parent))

import torch
from leann_multi_vector import _embed_images, _ensure_repo_paths_importable, _load_colvision
from PIL import Image

# Ensure repo paths are importable
_ensure_repo_paths_importable(__file__)

# Set environment variable
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def create_test_image():
    """Create a simple test image."""
    # Create a simple RGB image (800x600)
    img = Image.new("RGB", (800, 600), color="white")
    return img


def load_test_image_from_file():
    """Try to load an image from the indexes directory if available."""
    # Try to find an existing image in the indexes directory
    indexes_dir = Path(__file__).parent / "indexes"

    # Look for images in common locations
    possible_paths = [
        indexes_dir / "vidore_fastplaid" / "images",
        indexes_dir / "colvision_large.leann.images",
        indexes_dir / "colvision.leann.images",
    ]

    for img_dir in possible_paths:
        if img_dir.exists():
            # Find first image file
            for ext in [".png", ".jpg", ".jpeg"]:
                for img_file in img_dir.glob(f"*{ext}"):
                    print(f"Loading test image from: {img_file}")
                    return Image.open(img_file)
    return None


def main():
    print("=" * 60)
    print("Testing ColQwen2 Forward Pass")
    print("=" * 60)

    # Step 1: Load or create test image
    print("\n[Step 1] Loading test image...")
    test_image = load_test_image_from_file()
    if test_image is None:
        print("No existing image found, creating a simple test image...")
        test_image = create_test_image()
    else:
        print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")

    # Convert to RGB if needed
    if test_image.mode != "RGB":
        test_image = test_image.convert("RGB")
        print(f"✓ Converted to RGB: {test_image.size}")

    # Step 2: Load model
    print("\n[Step 2] Loading ColQwen2 model...")
    try:
        model_name, model, processor, device_str, device, dtype = _load_colvision("colqwen2")
        print(f"✓ Model loaded: {model_name}")
        print(f"✓ Device: {device_str}, dtype: {dtype}")

        # Print model info
        if hasattr(model, "device"):
            print(f"✓ Model device: {model.device}")
        if hasattr(model, "dtype"):
            print(f"✓ Model dtype: {model.dtype}")
    except Exception as e:
        print(f"✗ Error loading model: {e}")
        import traceback
        traceback.print_exc()
        return

    # Step 3: Test forward pass
    print("\n[Step 3] Running forward pass...")
    try:
        # Use the _embed_images function which handles batching and forward pass
        images = [test_image]
        print(f"Processing {len(images)} image(s)...")

        doc_vecs = _embed_images(model, processor, images)

        print("✓ Forward pass completed!")
        print(f"✓ Number of embeddings: {len(doc_vecs)}")

        if len(doc_vecs) > 0:
            emb = doc_vecs[0]
            print(f"✓ Embedding shape: {emb.shape}")
            print(f"✓ Embedding dtype: {emb.dtype}")
            print("✓ Embedding stats:")
            print(f"  - Min: {emb.min().item():.4f}")
            print(f"  - Max: {emb.max().item():.4f}")
            print(f"  - Mean: {emb.mean().item():.4f}")
            print(f"  - Std: {emb.std().item():.4f}")

            # Check for NaN or Inf
            if torch.isnan(emb).any():
                print("⚠ Warning: Embedding contains NaN values!")
            if torch.isinf(emb).any():
                print("⚠ Warning: Embedding contains Inf values!")
    except Exception as e:
        print(f"✗ Error during forward pass: {e}")
        import traceback
        traceback.print_exc()
        return

    print("\n" + "=" * 60)
    print("Test completed successfully!")
    print("=" * 60)


if __name__ == "__main__":
    main()

View File

File diff suppressed because it is too large

View File

File diff suppressed because it is too large

View File

@@ -0,0 +1,448 @@
#!/usr/bin/env python3
"""
Modular script to reproduce NDCG results for ViDoRe v1 benchmark.
This script uses the interface from leann_multi_vector.py to:
1. Download ViDoRe v1 datasets
2. Build indexes (LEANN or Fast-Plaid)
3. Perform retrieval
4. Evaluate using NDCG metrics
Usage:
# Evaluate all ViDoRe v1 tasks
python vidore_v1_benchmark.py --model colqwen2 --tasks all
# Evaluate specific task
python vidore_v1_benchmark.py --model colqwen2 --task VidoreArxivQARetrieval
# Use Fast-Plaid index
python vidore_v1_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
# Rebuild index
python vidore_v1_benchmark.py --model colqwen2 --rebuild-index
"""
import argparse
import json
import os
from typing import Optional
from datasets import load_dataset
from leann_multi_vector import (
ViDoReBenchmarkEvaluator,
_ensure_repo_paths_importable,
)
_ensure_repo_paths_importable(__file__)
# ViDoRe v1 task configurations
# Prompts match MTEB task metadata prompts
VIDORE_V1_TASKS = {
"VidoreArxivQARetrieval": {
"dataset_path": "vidore/arxivqa_test_subsampled_beir",
"revision": "7d94d570960eac2408d3baa7a33f9de4822ae3e4",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreDocVQARetrieval": {
"dataset_path": "vidore/docvqa_test_subsampled_beir",
"revision": "162ba2fc1a8437eda8b6c37b240bc1c0f0deb092",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreInfoVQARetrieval": {
"dataset_path": "vidore/infovqa_test_subsampled_beir",
"revision": "b802cc5fd6c605df2d673a963667d74881d2c9a4",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreTabfquadRetrieval": {
"dataset_path": "vidore/tabfquad_test_subsampled_beir",
"revision": "61a2224bcd29b7b261a4892ff4c8bea353527a31",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreTatdqaRetrieval": {
"dataset_path": "vidore/tatdqa_test_beir",
"revision": "5feb5630fdff4d8d189ffedb2dba56862fdd45c0",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreShiftProjectRetrieval": {
"dataset_path": "vidore/shiftproject_test_beir",
"revision": "84a382e05c4473fed9cff2bbae95fe2379416117",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreSyntheticDocQAAIRetrieval": {
"dataset_path": "vidore/syntheticDocQA_artificial_intelligence_test_beir",
"revision": "2d9ebea5a1c6e9ef4a3b902a612f605dca11261c",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreSyntheticDocQAEnergyRetrieval": {
"dataset_path": "vidore/syntheticDocQA_energy_test_beir",
"revision": "9935aadbad5c8deec30910489db1b2c7133ae7a7",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreSyntheticDocQAGovernmentReportsRetrieval": {
"dataset_path": "vidore/syntheticDocQA_government_reports_test_beir",
"revision": "b4909afa930f81282fd20601e860668073ad02aa",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreSyntheticDocQAHealthcareIndustryRetrieval": {
"dataset_path": "vidore/syntheticDocQA_healthcare_industry_test_beir",
"revision": "f9e25d5b6e13e1ad9f5c3cce202565031b3ab164",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
}
# Task name aliases (short names -> full names)
TASK_ALIASES = {
"arxivqa": "VidoreArxivQARetrieval",
"docvqa": "VidoreDocVQARetrieval",
"infovqa": "VidoreInfoVQARetrieval",
"tabfquad": "VidoreTabfquadRetrieval",
"tatdqa": "VidoreTatdqaRetrieval",
"shiftproject": "VidoreShiftProjectRetrieval",
"syntheticdocqa_ai": "VidoreSyntheticDocQAAIRetrieval",
"syntheticdocqa_energy": "VidoreSyntheticDocQAEnergyRetrieval",
"syntheticdocqa_government": "VidoreSyntheticDocQAGovernmentReportsRetrieval",
"syntheticdocqa_healthcare": "VidoreSyntheticDocQAHealthcareIndustryRetrieval",
}
def normalize_task_name(task_name: str) -> str:
"""Normalize task name (handle aliases)."""
task_name_lower = task_name.lower()
if task_name in VIDORE_V1_TASKS:
return task_name
if task_name_lower in TASK_ALIASES:
return TASK_ALIASES[task_name_lower]
# Try partial match
for alias, full_name in TASK_ALIASES.items():
if alias in task_name_lower or task_name_lower in alias:
return full_name
return task_name
def get_safe_model_name(model_name: str) -> str:
"""Get a safe model name for use in file paths."""
import hashlib
import os
# If it's a path, use basename or hash
if os.path.exists(model_name) and os.path.isdir(model_name):
# Use basename if it's reasonable, otherwise use hash
basename = os.path.basename(model_name.rstrip("/"))
if basename and len(basename) < 100 and not basename.startswith("."):
return basename
# Use hash for very long or problematic paths
return hashlib.md5(model_name.encode()).hexdigest()[:16]
# For HuggingFace model names, replace / with _
return model_name.replace("/", "_").replace(":", "_")
def load_vidore_v1_data(
dataset_path: str,
revision: Optional[str] = None,
split: str = "test",
):
"""
Load ViDoRe v1 dataset.
Returns:
corpus: dict mapping corpus_id to PIL Image
queries: dict mapping query_id to query text
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
"""
print(f"Loading dataset: {dataset_path} (split={split})")
# Load queries
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
queries = {}
for row in query_ds:
query_id = f"query-{split}-{row['query-id']}"
queries[query_id] = row["query"]
# Load corpus (images)
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
corpus = {}
for row in corpus_ds:
corpus_id = f"corpus-{split}-{row['corpus-id']}"
# Extract image from the dataset row
if "image" in row:
corpus[corpus_id] = row["image"]
elif "page_image" in row:
corpus[corpus_id] = row["page_image"]
else:
raise ValueError(
f"No image field found in corpus. Available fields: {list(row.keys())}"
)
# Load qrels (relevance judgments)
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
qrels = {}
for row in qrels_ds:
query_id = f"query-{split}-{row['query-id']}"
corpus_id = f"corpus-{split}-{row['corpus-id']}"
if query_id not in qrels:
qrels[query_id] = {}
qrels[query_id][corpus_id] = int(row["score"])
print(
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
)
# Filter qrels to only include queries that exist
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
# Filter out queries without any relevant documents (matching MTEB behavior)
# This is important for correct NDCG calculation
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
queries_filtered = {
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
}
print(
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
)
return corpus, queries_filtered, qrels_filtered
def evaluate_task(
task_name: str,
model_name: str,
index_path: str,
use_fast_plaid: bool = False,
fast_plaid_index_path: Optional[str] = None,
rebuild_index: bool = False,
top_k: int = 1000,
first_stage_k: int = 500,
k_values: Optional[list[int]] = None,
output_dir: Optional[str] = None,
):
"""
Evaluate a single ViDoRe v1 task.
"""
print(f"\n{'=' * 80}")
print(f"Evaluating task: {task_name}")
print(f"{'=' * 80}")
# Normalize task name (handle aliases)
task_name = normalize_task_name(task_name)
# Get task config
if task_name not in VIDORE_V1_TASKS:
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
task_config = VIDORE_V1_TASKS[task_name]
dataset_path = task_config["dataset_path"]
revision = task_config["revision"]
# Load data
corpus, queries, qrels = load_vidore_v1_data(
dataset_path=dataset_path,
revision=revision,
split="test",
)
# Initialize k_values if not provided
if k_values is None:
k_values = [1, 3, 5, 10, 20, 100, 1000]
# Check if we have any queries
if len(queries) == 0:
print(f"\nWarning: No queries found for task {task_name}. Skipping evaluation.")
# Return zero scores
scores = {}
for k in k_values:
scores[f"ndcg_at_{k}"] = 0.0
scores[f"map_at_{k}"] = 0.0
scores[f"recall_at_{k}"] = 0.0
scores[f"precision_at_{k}"] = 0.0
scores[f"mrr_at_{k}"] = 0.0
return scores
# Initialize evaluator
evaluator = ViDoReBenchmarkEvaluator(
model_name=model_name,
use_fast_plaid=use_fast_plaid,
top_k=top_k,
first_stage_k=first_stage_k,
k_values=k_values,
)
# Build or load index
# Use safe model name for index path (different models need different indexes)
safe_model_name = get_safe_model_name(model_name)
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
if index_path_full is None:
index_path_full = f"./indexes/{task_name}_{safe_model_name}"
if use_fast_plaid:
index_path_full = f"./indexes/{task_name}_{safe_model_name}_fastplaid"
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
corpus=corpus,
index_path=index_path_full,
rebuild=rebuild_index,
)
# Search queries
task_prompt = task_config.get("prompt")
results = evaluator.search_queries(
queries=queries,
corpus_ids=corpus_ids_ordered,
index_or_retriever=index_or_retriever,
fast_plaid_index_path=fast_plaid_index_path,
task_prompt=task_prompt,
)
# Evaluate
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
# Print results
print(f"\n{'=' * 80}")
print(f"Results for {task_name}:")
print(f"{'=' * 80}")
for metric, value in scores.items():
if isinstance(value, (int, float)):
print(f" {metric}: {value:.5f}")
# Save results
if output_dir:
os.makedirs(output_dir, exist_ok=True)
results_file = os.path.join(output_dir, f"{task_name}_results.json")
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
with open(results_file, "w") as f:
json.dump(results, f, indent=2)
print(f"\nSaved results to: {results_file}")
with open(scores_file, "w") as f:
json.dump(scores, f, indent=2)
print(f"Saved scores to: {scores_file}")
return scores
def main():
parser = argparse.ArgumentParser(
description="Evaluate ViDoRe v1 benchmark using LEANN/Fast-Plaid indexing"
)
parser.add_argument(
"--model",
type=str,
default="colqwen2",
help="Model to use: 'colqwen2', 'colpali', or path to a model directory (supports LoRA adapters)",
)
parser.add_argument(
"--task",
type=str,
default=None,
help="Specific task to evaluate (or 'all' for all tasks)",
)
parser.add_argument(
"--tasks",
type=str,
default="all",
help="Tasks to evaluate: 'all' or comma-separated list",
)
parser.add_argument(
"--index-path",
type=str,
default=None,
help="Path to LEANN index (auto-generated if not provided)",
)
parser.add_argument(
"--use-fast-plaid",
action="store_true",
help="Use Fast-Plaid instead of LEANN",
)
parser.add_argument(
"--fast-plaid-index-path",
type=str,
default=None,
help="Path to Fast-Plaid index (auto-generated if not provided)",
)
parser.add_argument(
"--rebuild-index",
action="store_true",
help="Rebuild index even if it exists",
)
parser.add_argument(
"--top-k",
type=int,
default=1000,
help="Top-k results to retrieve (MTEB default is max(k_values)=1000)",
)
parser.add_argument(
"--first-stage-k",
type=int,
default=500,
help="First stage k for LEANN search",
)
parser.add_argument(
"--k-values",
type=str,
default="1,3,5,10,20,100,1000",
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
)
parser.add_argument(
"--output-dir",
type=str,
default="./vidore_v1_results",
help="Output directory for results",
)
args = parser.parse_args()
# Parse k_values
k_values = [int(k.strip()) for k in args.k_values.split(",")]
# Determine tasks to evaluate
if args.task:
tasks_to_eval = [normalize_task_name(args.task)]
elif args.tasks.lower() == "all":
tasks_to_eval = list(VIDORE_V1_TASKS.keys())
else:
tasks_to_eval = [normalize_task_name(t.strip()) for t in args.tasks.split(",")]
print(f"Tasks to evaluate: {tasks_to_eval}")
# Evaluate each task
all_scores = {}
for task_name in tasks_to_eval:
try:
scores = evaluate_task(
task_name=task_name,
model_name=args.model,
index_path=args.index_path,
use_fast_plaid=args.use_fast_plaid,
fast_plaid_index_path=args.fast_plaid_index_path,
rebuild_index=args.rebuild_index,
top_k=args.top_k,
first_stage_k=args.first_stage_k,
k_values=k_values,
output_dir=args.output_dir,
)
all_scores[task_name] = scores
except Exception as e:
print(f"\nError evaluating {task_name}: {e}")
import traceback
traceback.print_exc()
continue
# Print summary
if all_scores:
print(f"\n{'=' * 80}")
print("SUMMARY")
print(f"{'=' * 80}")
for task_name, scores in all_scores.items():
print(f"\n{task_name}:")
# Print main metrics
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
if metric in scores:
print(f" {metric}: {scores[metric]:.5f}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,439 @@
#!/usr/bin/env python3
"""
Modular script to reproduce NDCG results for ViDoRe v2 benchmark.
This script uses the interface from leann_multi_vector.py to:
1. Download ViDoRe v2 datasets
2. Build indexes (LEANN or Fast-Plaid)
3. Perform retrieval
4. Evaluate using NDCG metrics
Usage:
# Evaluate all ViDoRe v2 tasks
python vidore_v2_benchmark.py --model colqwen2 --tasks all
# Evaluate specific task
python vidore_v2_benchmark.py --model colqwen2 --task Vidore2ESGReportsRetrieval
# Use Fast-Plaid index
python vidore_v2_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
# Rebuild index
python vidore_v2_benchmark.py --model colqwen2 --rebuild-index
"""
import argparse
import json
import os
from typing import Optional
from datasets import load_dataset
from leann_multi_vector import (
ViDoReBenchmarkEvaluator,
_ensure_repo_paths_importable,
)
_ensure_repo_paths_importable(__file__)
# Language name to dataset language field value mapping
# Dataset uses ISO 639-3 + ISO 15924 format (e.g., "eng-Latn")
LANGUAGE_MAPPING = {
"english": "eng-Latn",
"french": "fra-Latn",
"spanish": "spa-Latn",
"german": "deu-Latn",
}
# ViDoRe v2 task configurations
# Prompts match MTEB task metadata prompts
VIDORE_V2_TASKS = {
"Vidore2ESGReportsRetrieval": {
"dataset_path": "vidore/esg_reports_v2",
"revision": "0542c0d03da0ec1c8cbc517c8d78e7e95c75d3d3",
"languages": ["french", "spanish", "english", "german"],
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"Vidore2EconomicsReportsRetrieval": {
"dataset_path": "vidore/economics_reports_v2",
"revision": "b3e3a04b07fbbaffe79be49dabf92f691fbca252",
"languages": ["french", "spanish", "english", "german"],
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"Vidore2BioMedicalLecturesRetrieval": {
"dataset_path": "vidore/biomedical_lectures_v2",
"revision": "a29202f0da409034d651614d87cd8938d254e2ea",
"languages": ["french", "spanish", "english", "german"],
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"Vidore2ESGReportsHLRetrieval": {
"dataset_path": "vidore/esg_reports_human_labeled_v2",
"revision": "6d467dedb09a75144ede1421747e47cf036857dd",
# Note: This dataset doesn't have language filtering - all queries are English
"languages": None, # No language filtering needed
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
}
def load_vidore_v2_data(
dataset_path: str,
revision: Optional[str] = None,
split: str = "test",
language: Optional[str] = None,
):
"""
Load ViDoRe v2 dataset.
Returns:
corpus: dict mapping corpus_id to PIL Image
queries: dict mapping query_id to query text
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
"""
print(f"Loading dataset: {dataset_path} (split={split}, language={language})")
# Load queries
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
# Check if dataset has language field before filtering
has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
if language and has_language_field:
# Map language name to dataset language field value (e.g., "english" -> "eng-Latn")
dataset_language = LANGUAGE_MAPPING.get(language, language)
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language)
# Check if filtering resulted in empty dataset
if len(query_ds_filtered) == 0:
print(
f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}')."
)
# Try with original language value (dataset might use simple names like 'english')
print(f"Trying with original language value '{language}'...")
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language)
if len(query_ds_filtered) == 0:
# Try to get a sample to see actual language values
try:
sample_ds = load_dataset(
dataset_path, "queries", split=split, revision=revision
)
if len(sample_ds) > 0 and "language" in sample_ds.column_names:
sample_langs = set(sample_ds["language"])
print(f"Available language values in dataset: {sample_langs}")
except Exception:
pass
else:
print(
f"Found {len(query_ds_filtered)} queries using original language value '{language}'"
)
query_ds = query_ds_filtered
queries = {}
for row in query_ds:
query_id = f"query-{split}-{row['query-id']}"
queries[query_id] = row["query"]
# Load corpus (images)
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
corpus = {}
for row in corpus_ds:
corpus_id = f"corpus-{split}-{row['corpus-id']}"
# Extract image from the dataset row
if "image" in row:
corpus[corpus_id] = row["image"]
elif "page_image" in row:
corpus[corpus_id] = row["page_image"]
else:
raise ValueError(
f"No image field found in corpus. Available fields: {list(row.keys())}"
)
# Load qrels (relevance judgments)
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
qrels = {}
for row in qrels_ds:
query_id = f"query-{split}-{row['query-id']}"
corpus_id = f"corpus-{split}-{row['corpus-id']}"
if query_id not in qrels:
qrels[query_id] = {}
qrels[query_id][corpus_id] = int(row["score"])
print(
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
)
# Filter qrels to only include queries that exist
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
# Filter out queries without any relevant documents (matching MTEB behavior)
# This is important for correct NDCG calculation
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
queries_filtered = {
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
}
print(
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
)
return corpus, queries_filtered, qrels_filtered
def evaluate_task(
task_name: str,
model_name: str,
index_path: str,
use_fast_plaid: bool = False,
fast_plaid_index_path: Optional[str] = None,
language: Optional[str] = None,
rebuild_index: bool = False,
top_k: int = 100,
first_stage_k: int = 500,
k_values: Optional[list[int]] = None,
output_dir: Optional[str] = None,
):
"""
Evaluate a single ViDoRe v2 task.
"""
print(f"\n{'=' * 80}")
print(f"Evaluating task: {task_name}")
print(f"{'=' * 80}")
# Get task config
if task_name not in VIDORE_V2_TASKS:
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")
task_config = VIDORE_V2_TASKS[task_name]
dataset_path = task_config["dataset_path"]
revision = task_config["revision"]
# Determine language
if language is None:
# Use first language if multiple available
languages = task_config.get("languages")
if languages is None:
# Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval)
language = None
elif len(languages) == 1:
language = languages[0]
else:
language = None
# Initialize k_values if not provided
if k_values is None:
k_values = [1, 3, 5, 10, 100]
# Load data
corpus, queries, qrels = load_vidore_v2_data(
dataset_path=dataset_path,
revision=revision,
split="test",
language=language,
)
# Check if we have any queries
if len(queries) == 0:
print(
f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation."
)
# Return zero scores
scores = {}
for k in k_values:
scores[f"ndcg_at_{k}"] = 0.0
scores[f"map_at_{k}"] = 0.0
scores[f"recall_at_{k}"] = 0.0
scores[f"precision_at_{k}"] = 0.0
scores[f"mrr_at_{k}"] = 0.0
return scores
# Initialize evaluator
evaluator = ViDoReBenchmarkEvaluator(
model_name=model_name,
use_fast_plaid=use_fast_plaid,
top_k=top_k,
first_stage_k=first_stage_k,
k_values=k_values,
)
# Build or load index
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
if index_path_full is None:
index_path_full = f"./indexes/{task_name}_{model_name}"
if use_fast_plaid:
index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
corpus=corpus,
index_path=index_path_full,
rebuild=rebuild_index,
)
# Search queries
task_prompt = task_config.get("prompt")
results = evaluator.search_queries(
queries=queries,
corpus_ids=corpus_ids_ordered,
index_or_retriever=index_or_retriever,
fast_plaid_index_path=fast_plaid_index_path,
task_prompt=task_prompt,
)
# Evaluate
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
# Print results
print(f"\n{'=' * 80}")
print(f"Results for {task_name}:")
print(f"{'=' * 80}")
for metric, value in scores.items():
if isinstance(value, (int, float)):
print(f" {metric}: {value:.5f}")
# Save results
if output_dir:
os.makedirs(output_dir, exist_ok=True)
results_file = os.path.join(output_dir, f"{task_name}_results.json")
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
with open(results_file, "w") as f:
json.dump(results, f, indent=2)
print(f"\nSaved results to: {results_file}")
with open(scores_file, "w") as f:
json.dump(scores, f, indent=2)
print(f"Saved scores to: {scores_file}")
return scores
def main():
parser = argparse.ArgumentParser(
description="Evaluate ViDoRe v2 benchmark using LEANN/Fast-Plaid indexing"
)
parser.add_argument(
"--model",
type=str,
default="colqwen2",
choices=["colqwen2", "colpali"],
help="Model to use",
)
parser.add_argument(
"--task",
type=str,
default=None,
help="Specific task to evaluate (or 'all' for all tasks)",
)
parser.add_argument(
"--tasks",
type=str,
default="all",
help="Tasks to evaluate: 'all' or comma-separated list",
)
parser.add_argument(
"--index-path",
type=str,
default=None,
help="Path to LEANN index (auto-generated if not provided)",
)
parser.add_argument(
"--use-fast-plaid",
action="store_true",
help="Use Fast-Plaid instead of LEANN",
)
parser.add_argument(
"--fast-plaid-index-path",
type=str,
default=None,
help="Path to Fast-Plaid index (auto-generated if not provided)",
)
parser.add_argument(
"--rebuild-index",
action="store_true",
help="Rebuild index even if it exists",
)
parser.add_argument(
"--language",
type=str,
default=None,
help="Language to evaluate (default: first available)",
)
parser.add_argument(
"--top-k",
type=int,
default=100,
help="Top-k results to retrieve",
)
parser.add_argument(
"--first-stage-k",
type=int,
default=500,
help="First stage k for LEANN search",
)
parser.add_argument(
"--k-values",
type=str,
default="1,3,5,10,100",
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
)
parser.add_argument(
"--output-dir",
type=str,
default="./vidore_v2_results",
help="Output directory for results",
)
args = parser.parse_args()
# Parse k_values
k_values = [int(k.strip()) for k in args.k_values.split(",")]
# Determine tasks to evaluate
if args.task:
tasks_to_eval = [args.task]
elif args.tasks.lower() == "all":
tasks_to_eval = list(VIDORE_V2_TASKS.keys())
else:
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
print(f"Tasks to evaluate: {tasks_to_eval}")
# Evaluate each task
all_scores = {}
for task_name in tasks_to_eval:
try:
scores = evaluate_task(
task_name=task_name,
model_name=args.model,
index_path=args.index_path,
use_fast_plaid=args.use_fast_plaid,
fast_plaid_index_path=args.fast_plaid_index_path,
language=args.language,
rebuild_index=args.rebuild_index,
top_k=args.top_k,
first_stage_k=args.first_stage_k,
k_values=k_values,
output_dir=args.output_dir,
)
all_scores[task_name] = scores
except Exception as e:
print(f"\nError evaluating {task_name}: {e}")
import traceback
traceback.print_exc()
continue
# Print summary
if all_scores:
print(f"\n{'=' * 80}")
print("SUMMARY")
print(f"{'=' * 80}")
for task_name, scores in all_scores.items():
print(f"\n{task_name}:")
# Print main metrics
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
if metric in scores:
print(f" {metric}: {scores[metric]:.5f}")
if __name__ == "__main__":
main()
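
For a quick programmatic check of a single task, the same entry point can also be driven directly from Python; a minimal sketch, following the module name and defaults shown above (when no index path is given, one is auto-generated under `./indexes/`):

```python
from vidore_v2_benchmark import evaluate_task

scores = evaluate_task(
    task_name="Vidore2ESGReportsRetrieval",
    model_name="colqwen2",
    index_path=None,              # auto-generated as ./indexes/<task>_<model>
    use_fast_plaid=False,
    rebuild_index=False,
    k_values=[1, 5, 10, 100],
    output_dir="./vidore_v2_results",
)
print(scores.get("ndcg_at_10"))
```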

View File

@@ -7,6 +7,7 @@ for indexing in LEANN. It supports various Slack MCP server implementations and
flexible message processing options.
"""
import ast
import asyncio
import json
import logging
@@ -146,16 +147,16 @@ class SlackMCPReader:
match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
if match:
try:
- error_dict = eval(match.group(1))
+ error_dict = ast.literal_eval(match.group(1))
- except (ValueError, SyntaxError, NameError):
+ except (ValueError, SyntaxError):
pass
else:
# Try alternative format
match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
if match:
try:
- error_dict = eval(match.group(1))
+ error_dict = ast.literal_eval(match.group(1))
- except (ValueError, SyntaxError, NameError):
+ except (ValueError, SyntaxError):
pass
if self._is_cache_sync_error(error_dict):

View File

View File

@@ -158,6 +158,95 @@ builder.build_index("./indexes/my-notes", chunks)
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
## Optional Embedding Features
### Task-Specific Prompt Templates
Some embedding models are trained with task-specific prompts to differentiate between documents and queries. The most notable example is **Google's EmbeddingGemma**, which requires different prompts depending on the use case:
- **Indexing documents**: `"title: none | text: "`
- **Search queries**: `"task: search result | query: "`
LEANN supports automatic prompt prepending via the `--embedding-prompt-template` flag:
```bash
# Build index with EmbeddingGemma (via LM Studio or Ollama)
leann build my-docs \
--docs ./documents \
--embedding-mode openai \
--embedding-model text-embedding-embeddinggemma-300m-qat \
--embedding-api-base http://localhost:1234/v1 \
--embedding-prompt-template "title: none | text: " \
--force
# Search with query-specific prompt
leann search my-docs \
--query "What is quantum computing?" \
--embedding-prompt-template "task: search result | query: "
```
**Important Notes:**
- **Only use with compatible models**: EmbeddingGemma and similar task-specific models
- **NOT for regular models**: Adding prompts to models like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` will corrupt embeddings
- **Template is saved**: Build-time templates are saved to `.meta.json` for reference
- **Flexible prompts**: You can use any prompt string, or leave it empty (`""`)
**Python API:**
```python
from leann.api import LeannBuilder
builder = LeannBuilder(
embedding_mode="openai",
embedding_model="text-embedding-embeddinggemma-300m-qat",
embedding_options={
"base_url": "http://localhost:1234/v1",
"api_key": "lm-studio",
"prompt_template": "title: none | text: ",
},
)
builder.build_index("./indexes/my-docs", chunks)
```
**References:**
- [HuggingFace Blog: EmbeddingGemma](https://huggingface.co/blog/embeddinggemma) - Technical details
### LM Studio Auto-Detection (Optional)
When using LM Studio with the OpenAI-compatible API, LEANN can optionally auto-detect model context lengths via the LM Studio SDK. This eliminates manual configuration for token limits.
**Prerequisites:**
```bash
# Install Node.js (if not already installed)
# Then install the LM Studio SDK globally
npm install -g @lmstudio/sdk
```
**How it works:**
1. LEANN detects LM Studio URLs (`:1234`, `lmstudio` in URL)
2. Queries model metadata via Node.js subprocess
3. Automatically unloads model after query (respects your JIT auto-evict settings)
4. Falls back to static registry if SDK unavailable
**No configuration needed** - it works automatically when SDK is installed:
```bash
leann build my-docs \
--docs ./documents \
--embedding-mode openai \
--embedding-model text-embedding-nomic-embed-text-v1.5 \
--embedding-api-base http://localhost:1234/v1
# Context length auto-detected if SDK available
# Falls back to registry (2048) if not
```
**Benefits:**
- ✅ Automatic token limit detection
- ✅ Respects LM Studio JIT auto-evict settings
- ✅ No manual registry maintenance
- ✅ Graceful fallback if SDK unavailable
**Note:** This is completely optional. LEANN works perfectly fine without the SDK using the built-in token limit registry.
## Index Selection: Matching Your Scale
### HNSW (Hierarchical Navigable Small World)
@@ -365,7 +454,7 @@ leann search my-index "your query" \
### 2) Run remote builds with SkyPilot (cloud GPU)
- Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
+ Offload embedding generation and index building to a GPU VM using [SkyPilot](https://docs.skypilot.co/en/latest/docs/index.html). A template is provided at `sky/leann-build.yaml`.
```bash
# One-time: install and configure SkyPilot

View File

@@ -8,3 +8,51 @@ You can speed up the process by using a lightweight embedding model. Add this to
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
## 2. When should I use prompt templates?
**Use prompt templates ONLY with task-specific embedding models** like Google's EmbeddingGemma. These models are specially trained to use different prompts for documents vs queries.
**DO NOT use with regular models** like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` - adding prompts to these models will corrupt the embeddings.
**Example usage with EmbeddingGemma:**
```bash
# Build with document prompt
leann build my-docs --embedding-prompt-template "title: none | text: "
# Search with query prompt
leann search my-docs --query "your question" --embedding-prompt-template "task: search result | query: "
```
See the [Configuration Guide: Task-Specific Prompt Templates](configuration-guide.md#task-specific-prompt-templates) for detailed usage.
## 3. Why is LM Studio loading multiple copies of my model?
This was fixed in recent versions. LEANN now properly unloads models after querying metadata, respecting your LM Studio JIT auto-evict settings.
**If you still see duplicates:**
- Update to the latest LEANN version
- Restart LM Studio to clear loaded models
- Check that you have JIT auto-evict enabled in LM Studio settings
**How it works now:**
1. LEANN loads model temporarily to get context length
2. Immediately unloads after query
3. LM Studio JIT loads model on-demand for actual embeddings
4. Auto-evicts per your settings
## 4. Do I need Node.js and @lmstudio/sdk?
**No, it's completely optional.** LEANN works perfectly fine without them using a built-in token limit registry.
**Benefits if you install it:**
- Automatic context length detection for LM Studio models
- No manual registry maintenance
- Always gets accurate token limits from the model itself
**To install (optional):**
```bash
npm install -g @lmstudio/sdk
```
See [Configuration Guide: LM Studio Auto-Detection](configuration-guide.md#lm-studio-auto-detection-optional) for details.

View File

@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-diskann"
- version = "0.3.4"
+ version = "0.3.5"
- dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
+ dependencies = ["leann-core==0.3.5", "numpy", "protobuf>=3.19.0"]
[tool.scikit-build]
# Key: simplified CMake path

View File

@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-hnsw"
- version = "0.3.4"
+ version = "0.3.5"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = [
- "leann-core==0.3.4",
+ "leann-core==0.3.5",
"numpy",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",

View File

@@ -4,10 +4,10 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann-core"
- version = "0.3.4"
+ version = "0.3.5"
description = "Core API and plugin system for LEANN"
readme = "README.md"
- requires-python = ">=3.9"
+ requires-python = ">=3.10"
license = { text = "MIT" }
# All required dependencies included

View File

@@ -916,6 +916,7 @@ class LeannSearcher:
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
batch_size: int = 0,
use_grep: bool = False,
provider_options: Optional[dict[str, Any]] = None,
**kwargs,
) -> list[SearchResult]:
"""
@@ -979,10 +980,24 @@ class LeannSearcher:
start_time = time.time()
# Extract query template from stored embedding_options with fallback chain:
# 1. Check provider_options override (highest priority)
# 2. Check query_prompt_template (new format)
# 3. Check prompt_template (old format for backward compat)
# 4. None (no template)
query_template = None
if provider_options and "prompt_template" in provider_options:
query_template = provider_options["prompt_template"]
elif "query_prompt_template" in self.embedding_options:
query_template = self.embedding_options["query_prompt_template"]
elif "prompt_template" in self.embedding_options:
query_template = self.embedding_options["prompt_template"]
query_embedding = self.backend_impl.compute_query_embedding(
query,
use_server_if_available=recompute_embeddings,
zmq_port=zmq_port,
query_template=query_template,
)
logger.info(f" Generated embedding shape: {query_embedding.shape}")
embedding_time = time.time() - start_time
@@ -1236,15 +1251,15 @@ class LeannChat:
"Please provide the best answer you can based on this context and your knowledge." "Please provide the best answer you can based on this context and your knowledge."
) )
print("The context provided to the LLM is:") logger.info("The context provided to the LLM is:")
print(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}") logger.info(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
print("-" * 150) logger.info("-" * 150)
for r in results: for r in results:
chunk_relevance = f"{r.score:.3f}" chunk_relevance = f"{r.score:.3f}"
chunk_id = r.id chunk_id = r.id
chunk_content = r.text[:60] chunk_content = r.text[:60]
chunk_source = r.metadata.get("source", "")[:80] chunk_source = r.metadata.get("source", "")[:80]
print( logger.info(
f"{chunk_relevance:<10} | {chunk_id:<10} | {chunk_content:<60} | {chunk_source:<80}" f"{chunk_relevance:<10} | {chunk_id:<10} | {chunk_content:<60} | {chunk_source:<80}"
) )
ask_time = time.time() ask_time = time.time()
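
For reference, the new `provider_options` override can also be exercised from the Python API; a minimal sketch, assuming an index already exists at the hypothetical path `./indexes/my-docs`:

```python
from leann.api import LeannSearcher

searcher = LeannSearcher(index_path="./indexes/my-docs")

# provider_options takes precedence over any template stored with the index
results = searcher.search(
    "What is quantum computing?",
    provider_options={"prompt_template": "task: search result | query: "},
)
for r in results:
    print(f"{r.score:.3f}", r.text[:60])
```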

View File

@@ -12,7 +12,13 @@ from typing import Any, Optional
import torch
- from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
+ from .settings import (
resolve_anthropic_api_key,
resolve_anthropic_base_url,
resolve_ollama_host,
resolve_openai_api_key,
resolve_openai_base_url,
)
# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -845,6 +851,81 @@ class OpenAIChat(LLMInterface):
return f"Error: Could not get a response from OpenAI. Details: {e}" return f"Error: Could not get a response from OpenAI. Details: {e}"
class AnthropicChat(LLMInterface):
"""LLM interface for Anthropic Claude models."""
def __init__(
self,
model: str = "claude-haiku-4-5",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
):
self.model = model
self.base_url = resolve_anthropic_base_url(base_url)
self.api_key = resolve_anthropic_api_key(api_key)
if not self.api_key:
raise ValueError(
"Anthropic API key is required. Set ANTHROPIC_API_KEY environment variable or pass api_key parameter."
)
logger.info(
"Initializing Anthropic Chat with model='%s' and base_url='%s'",
model,
self.base_url,
)
try:
import anthropic
# Allow custom Anthropic-compatible endpoints via base_url
self.client = anthropic.Anthropic(
api_key=self.api_key,
base_url=self.base_url,
)
except ImportError:
raise ImportError(
"The 'anthropic' library is required for Anthropic models. Please install it with 'pip install anthropic'."
)
def ask(self, prompt: str, **kwargs) -> str:
logger.info(f"Sending request to Anthropic with model {self.model}")
try:
# Anthropic API parameters
params = {
"model": self.model,
"max_tokens": kwargs.get("max_tokens", 1000),
"messages": [{"role": "user", "content": prompt}],
}
# Add optional parameters
if "temperature" in kwargs:
params["temperature"] = kwargs["temperature"]
if "top_p" in kwargs:
params["top_p"] = kwargs["top_p"]
response = self.client.messages.create(**params)
# Extract text from response
response_text = response.content[0].text
# Log token usage
print(
f"Total tokens = {response.usage.input_tokens + response.usage.output_tokens}, "
f"input tokens = {response.usage.input_tokens}, "
f"output tokens = {response.usage.output_tokens}"
)
if response.stop_reason == "max_tokens":
print("The query is exceeding the maximum allowed number of tokens")
return response_text.strip()
except Exception as e:
logger.error(f"Error communicating with Anthropic: {e}")
return f"Error: Could not get a response from Anthropic. Details: {e}"
class SimulatedChat(LLMInterface):
"""A simple simulated chat for testing and development."""
@@ -897,6 +978,12 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
)
elif llm_type == "gemini":
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
elif llm_type == "anthropic":
return AnthropicChat(
model=model or "claude-3-5-sonnet-20241022",
api_key=llm_config.get("api_key"),
base_url=llm_config.get("base_url"),
)
elif llm_type == "simulated": elif llm_type == "simulated":
return SimulatedChat() return SimulatedChat()
else: else:
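
A minimal sketch of calling the new Anthropic path directly (the import path is an assumption, since the module name is not shown in this diff; it also assumes the `anthropic` package is installed and `ANTHROPIC_API_KEY` is exported):

```python
# NOTE: import path assumed; adjust to wherever AnthropicChat lives in leann-core
from leann.llm_interface import AnthropicChat

chat = AnthropicChat(model="claude-haiku-4-5")  # default model from the class above
answer = chat.ask(
    "Summarize the retrieved context in one sentence.",
    max_tokens=200,
    temperature=0.2,
)
print(answer)
```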

View File

@@ -5,12 +5,15 @@ Packaged within leann-core so installed wheels can import it reliably.
import logging
from pathlib import Path
- from typing import Optional
+ from typing import Any, Optional
from llama_index.core.node_parser import SentenceSplitter
logger = logging.getLogger(__name__)
# Flag to ensure AST token warning only shown once per session
_ast_token_warning_shown = False
def estimate_token_count(text: str) -> int:
"""
@@ -174,37 +177,44 @@ def create_ast_chunks(
max_chunk_size: int = 512,
chunk_overlap: int = 64,
metadata_template: str = "default",
- ) -> list[str]:
+ ) -> list[dict[str, Any]]:
"""Create AST-aware chunks from code documents using astchunk.
Falls back to traditional chunking if astchunk is unavailable.
Returns:
List of dicts with {"text": str, "metadata": dict}
"""
try:
from astchunk import ASTChunkBuilder # optional dependency
except ImportError as e:
logger.error(f"astchunk not available: {e}")
logger.info("Falling back to traditional chunking for code files")
- return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
+ return _traditional_chunks_as_dicts(documents, max_chunk_size, chunk_overlap)
all_chunks = []
for doc in documents:
language = doc.metadata.get("language")
if not language:
logger.warning("No language detected; falling back to traditional chunking")
- all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
+ all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
continue
try:
- # Warn if AST chunk size + overlap might exceed common token limits
+ # Warn once if AST chunk size + overlap might exceed common token limits
# Note: Actual truncation happens at embedding time with dynamic model limits
global _ast_token_warning_shown
estimated_max_tokens = int(
(max_chunk_size + chunk_overlap) * 1.2
) # Conservative estimate
- if estimated_max_tokens > 512:
+ if estimated_max_tokens > 512 and not _ast_token_warning_shown:
logger.warning(
f"AST chunk size ({max_chunk_size}) + overlap ({chunk_overlap}) = {max_chunk_size + chunk_overlap} chars "
f"may exceed 512 token limit (~{estimated_max_tokens} tokens estimated). "
- f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}"
+ f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}. "
f"Note: Chunks will be auto-truncated at embedding time based on your model's actual token limit."
)
_ast_token_warning_shown = True
configs = {
"max_chunk_size": max_chunk_size,
@@ -229,17 +239,40 @@ def create_ast_chunks(
chunks = chunk_builder.chunkify(code_content)
for chunk in chunks:
chunk_text = None
astchunk_metadata = {}
if hasattr(chunk, "text"):
chunk_text = chunk.text
elif isinstance(chunk, dict) and "text" in chunk:
chunk_text = chunk["text"]
elif isinstance(chunk, str):
chunk_text = chunk
elif isinstance(chunk, dict):
# Handle astchunk format: {"content": "...", "metadata": {...}}
if "content" in chunk:
chunk_text = chunk["content"]
astchunk_metadata = chunk.get("metadata", {})
elif "text" in chunk:
chunk_text = chunk["text"]
else:
chunk_text = str(chunk) # Last resort
else:
chunk_text = str(chunk)
if chunk_text and chunk_text.strip():
- all_chunks.append(chunk_text.strip())
+ # Extract document-level metadata
doc_metadata = {
"file_path": doc.metadata.get("file_path", ""),
"file_name": doc.metadata.get("file_name", ""),
}
if "creation_date" in doc.metadata:
doc_metadata["creation_date"] = doc.metadata["creation_date"]
if "last_modified_date" in doc.metadata:
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
# Merge document metadata + astchunk metadata
combined_metadata = {**doc_metadata, **astchunk_metadata}
all_chunks.append({"text": chunk_text.strip(), "metadata": combined_metadata})
logger.info(
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
@@ -247,15 +280,19 @@ def create_ast_chunks(
except Exception as e:
logger.warning(f"AST chunking failed for {language} file: {e}")
logger.info("Falling back to traditional chunking")
- all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
+ all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
return all_chunks
def create_traditional_chunks(
documents, chunk_size: int = 256, chunk_overlap: int = 128
- ) -> list[str]:
+ ) -> list[dict[str, Any]]:
- """Create traditional text chunks using LlamaIndex SentenceSplitter."""
+ """Create traditional text chunks using LlamaIndex SentenceSplitter.
Returns:
List of dicts with {"text": str, "metadata": dict}
"""
if chunk_size <= 0:
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
chunk_size = 256
@@ -271,19 +308,40 @@ def create_traditional_chunks(
paragraph_separator="\n\n",
)
- all_texts = []
+ result = []
for doc in documents:
# Extract document-level metadata
doc_metadata = {
"file_path": doc.metadata.get("file_path", ""),
"file_name": doc.metadata.get("file_name", ""),
}
if "creation_date" in doc.metadata:
doc_metadata["creation_date"] = doc.metadata["creation_date"]
if "last_modified_date" in doc.metadata:
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
try:
nodes = node_parser.get_nodes_from_documents([doc])
if nodes:
- all_texts.extend(node.get_content() for node in nodes)
+ for node in nodes:
result.append({"text": node.get_content(), "metadata": doc_metadata})
except Exception as e:
logger.error(f"Traditional chunking failed for document: {e}")
content = doc.get_content()
if content and content.strip():
- all_texts.append(content.strip())
+ result.append({"text": content.strip(), "metadata": doc_metadata})
- return all_texts
+ return result
def _traditional_chunks_as_dicts(
documents, chunk_size: int = 256, chunk_overlap: int = 128
) -> list[dict[str, Any]]:
"""Helper: Traditional chunking that returns dict format for consistency.
This is now just an alias for create_traditional_chunks for backwards compatibility.
"""
return create_traditional_chunks(documents, chunk_size, chunk_overlap)
def create_text_chunks(
@@ -295,8 +353,12 @@ def create_text_chunks(
ast_chunk_overlap: int = 64,
code_file_extensions: Optional[list[str]] = None,
ast_fallback_traditional: bool = True,
- ) -> list[str]:
+ ) -> list[dict[str, Any]]:
- """Create text chunks from documents with optional AST support for code files."""
+ """Create text chunks from documents with optional AST support for code files.
Returns:
List of dicts with {"text": str, "metadata": dict}
"""
if not documents:
logger.warning("No documents provided for chunking")
return []
@@ -331,24 +393,17 @@ def create_text_chunks(
logger.error(f"AST chunking failed: {e}") logger.error(f"AST chunking failed: {e}")
if ast_fallback_traditional: if ast_fallback_traditional:
all_chunks.extend( all_chunks.extend(
create_traditional_chunks(code_docs, chunk_size, chunk_overlap) _traditional_chunks_as_dicts(code_docs, chunk_size, chunk_overlap)
) )
else: else:
raise raise
if text_docs: if text_docs:
all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap)) all_chunks.extend(_traditional_chunks_as_dicts(text_docs, chunk_size, chunk_overlap))
else: else:
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap) all_chunks = _traditional_chunks_as_dicts(documents, chunk_size, chunk_overlap)
logger.info(f"Total chunks created: {len(all_chunks)}") logger.info(f"Total chunks created: {len(all_chunks)}")
# Validate chunk token limits (default to 512 for safety) # Note: Token truncation is now handled at embedding time with dynamic model limits
# This provides a safety net for embedding models with token limits # See get_model_token_limit() and truncate_to_token_limit() in embedding_compute.py
validated_chunks, num_truncated = validate_chunk_token_limits(all_chunks, max_tokens=512) return all_chunks
if num_truncated > 0:
logger.info(
f"Post-chunking validation: {num_truncated} chunks were truncated to fit 512 token limit"
)
return validated_chunks
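
With this change, chunk consumers iterate dicts instead of plain strings; a minimal sketch of the new format (the import path and the `documents` variable are assumptions, standing in for the module name and whatever LlamaIndex documents you already load):

```python
from leann.chunking_utils import create_text_chunks  # import path assumed

chunks = create_text_chunks(documents)  # list[dict] with "text" and "metadata" keys
for chunk in chunks:
    text = chunk["text"]
    source = chunk["metadata"].get("file_path", "")
    print(source, text[:60])
```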

View File

@@ -11,7 +11,12 @@ from tqdm import tqdm
from .api import LeannBuilder, LeannChat, LeannSearcher
from .interactive_utils import create_cli_session
from .registry import register_project_directory
- from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
+ from .settings import (
resolve_anthropic_base_url,
resolve_ollama_host,
resolve_openai_api_key,
resolve_openai_base_url,
)
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
@@ -144,6 +149,18 @@ Examples:
default=None,
help="API key for embedding service (defaults to OPENAI_API_KEY)",
)
build_parser.add_argument(
"--embedding-prompt-template",
type=str,
default=None,
help="Prompt template to prepend to all texts for embedding (e.g., 'query: ' for search)",
)
build_parser.add_argument(
"--query-prompt-template",
type=str,
default=None,
help="Prompt template for queries (different from build template for task-specific models)",
)
build_parser.add_argument(
"--force", "-f", action="store_true", help="Force rebuild existing index"
)
@@ -260,6 +277,12 @@ Examples:
action="store_true", action="store_true",
help="Display file paths and metadata in search results", help="Display file paths and metadata in search results",
) )
search_parser.add_argument(
"--embedding-prompt-template",
type=str,
default=None,
help="Prompt template to prepend to query for embedding (e.g., 'query: ' for search)",
)
# Ask command # Ask command
ask_parser = subparsers.add_parser("ask", help="Ask questions") ask_parser = subparsers.add_parser("ask", help="Ask questions")
@@ -273,7 +296,7 @@ Examples:
"--llm", "--llm",
type=str, type=str,
default="ollama", default="ollama",
choices=["simulated", "ollama", "hf", "openai"], choices=["simulated", "ollama", "hf", "openai", "anthropic"],
help="LLM provider (default: ollama)", help="LLM provider (default: ollama)",
) )
ask_parser.add_argument( ask_parser.add_argument(
@@ -323,7 +346,7 @@ Examples:
"--api-key", "--api-key",
type=str, type=str,
default=None, default=None,
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)", help="API key for cloud LLM providers (OpenAI, Anthropic)",
) )
# List command # List command
@@ -1162,6 +1185,11 @@ Examples:
print(f"Warning: Could not process {file_path}: {e}") print(f"Warning: Could not process {file_path}: {e}")
# Load other file types with default reader # Load other file types with default reader
# Exclude PDFs from code_extensions if they were already processed separately
other_file_extensions = code_extensions
if should_process_pdfs and ".pdf" in code_extensions:
other_file_extensions = [ext for ext in code_extensions if ext != ".pdf"]
try: try:
# Create a custom file filter function using our PathSpec # Create a custom file filter function using our PathSpec
def file_filter( def file_filter(
@@ -1177,15 +1205,19 @@ Examples:
except (ValueError, OSError):
return True # Include files that can't be processed
# Only load other file types if there are extensions to process
if other_file_extensions:
other_docs = SimpleDirectoryReader(
docs_dir,
recursive=True,
encoding="utf-8",
- required_exts=code_extensions,
+ required_exts=other_file_extensions,
file_extractor={}, # Use default extractors
exclude_hidden=not include_hidden,
filename_as_id=True,
).load_data(show_progress=True)
else:
other_docs = []
# Filter documents after loading based on gitignore rules
filtered_docs = []
@@ -1279,13 +1311,8 @@ Examples:
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
)
- # Note: AST chunking currently returns plain text chunks without metadata
+ # create_text_chunks now returns list[dict] with metadata preserved
- # We preserve basic file info by associating chunks with their source documents
+ all_texts.extend(chunk_texts)
# For better metadata preservation, documents list order should be maintained
for chunk_text in chunk_texts:
# TODO: Enhance create_text_chunks to return metadata alongside text
# For now, we store chunks with empty metadata
all_texts.append({"text": chunk_text, "metadata": {}})
except ImportError as e:
print(
@@ -1403,6 +1430,14 @@ Examples:
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
if resolved_embedding_key:
embedding_options["api_key"] = resolved_embedding_key
if args.query_prompt_template:
# New format: separate templates
if args.embedding_prompt_template:
embedding_options["build_prompt_template"] = args.embedding_prompt_template
embedding_options["query_prompt_template"] = args.query_prompt_template
elif args.embedding_prompt_template:
# Old format: single template (backward compat)
embedding_options["prompt_template"] = args.embedding_prompt_template
builder = LeannBuilder(
backend_name=args.backend_name,
@@ -1524,6 +1559,11 @@ Examples:
print("Invalid input. Aborting search.") print("Invalid input. Aborting search.")
return return
# Build provider_options for runtime override
provider_options = {}
if args.embedding_prompt_template:
provider_options["prompt_template"] = args.embedding_prompt_template
searcher = LeannSearcher(index_path=index_path) searcher = LeannSearcher(index_path=index_path)
results = searcher.search( results = searcher.search(
query, query,
@@ -1533,6 +1573,7 @@ Examples:
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
provider_options=provider_options if provider_options else None,
)
print(f"Search results for '{query}' (top {len(results)}):")
@@ -1580,6 +1621,12 @@ Examples:
resolved_api_key = resolve_openai_api_key(args.api_key)
if resolved_api_key:
llm_config["api_key"] = resolved_api_key
elif args.llm == "anthropic":
# For Anthropic, pass base_url and API key if provided
if args.api_base:
llm_config["base_url"] = resolve_anthropic_base_url(args.api_base)
if args.api_key:
llm_config["api_key"] = args.api_key
chat = LeannChat(index_path=index_path, llm_config=llm_config)
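
The same build/query template split is available through the Python API; a minimal sketch, assuming an EmbeddingGemma-style model behind an OpenAI-compatible endpoint as in the configuration guide above (the backend name and the `chunks` variable are placeholders):

```python
from leann.api import LeannBuilder

builder = LeannBuilder(
    backend_name="hnsw",  # backend name assumed
    embedding_mode="openai",
    embedding_model="text-embedding-embeddinggemma-300m-qat",
    embedding_options={
        "base_url": "http://localhost:1234/v1",
        "api_key": "lm-studio",
        "build_prompt_template": "title: none | text: ",
        "query_prompt_template": "task: search result | query: ",
    },
)
builder.build_index("./indexes/my-docs", chunks)  # `chunks` prepared by create_text_chunks(...)
# Subsequent searches read query_prompt_template from the stored embedding options.
```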

View File

@@ -4,119 +4,310 @@ Consolidates all embedding computation logic using SentenceTransformer
Preserves all optimization parameters to ensure performance
"""
import json
import logging
import os
import subprocess
import time
from typing import Any, Optional
import numpy as np
import tiktoken
import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
def truncate_to_token_limit(texts: list[str], max_tokens: int = 512) -> list[str]:
"""
Truncate texts to token limit using tiktoken or conservative character truncation.
Args:
texts: List of texts to truncate
max_tokens: Maximum tokens allowed per text
Returns:
List of truncated texts that should fit within token limit
"""
try:
import tiktoken
encoder = tiktoken.get_encoding("cl100k_base")
truncated = []
for text in texts:
tokens = encoder.encode(text)
if len(tokens) > max_tokens:
# Truncate to max_tokens and decode back to text
truncated_tokens = tokens[:max_tokens]
truncated_text = encoder.decode(truncated_tokens)
truncated.append(truncated_text)
logger.warning(
f"Truncated text from {len(tokens)} to {max_tokens} tokens "
f"(from {len(text)} to {len(truncated_text)} characters)"
)
else:
truncated.append(text)
return truncated
except ImportError:
# Fallback: Conservative character truncation
# Assume worst case: 1.5 tokens per character for code content
char_limit = int(max_tokens / 1.5)
truncated = []
for text in texts:
if len(text) > char_limit:
truncated_text = text[:char_limit]
truncated.append(truncated_text)
logger.warning(
f"Truncated text from {len(text)} to {char_limit} characters "
f"(conservative estimate for {max_tokens} tokens)"
)
else:
truncated.append(text)
return truncated
def get_model_token_limit(model_name: str) -> int:
"""
Get token limit for a given embedding model.
Args:
model_name: Name of the embedding model
Returns:
Token limit for the model, defaults to 512 if unknown
"""
# Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
base_model_name = model_name.split(":")[0]
# Check exact match first
if model_name in EMBEDDING_MODEL_LIMITS:
return EMBEDDING_MODEL_LIMITS[model_name]
# Check base name match
if base_model_name in EMBEDDING_MODEL_LIMITS:
return EMBEDDING_MODEL_LIMITS[base_model_name]
# Check partial matches for common patterns
for known_model, limit in EMBEDDING_MODEL_LIMITS.items():
if known_model in base_model_name or base_model_name in known_model:
return limit
# Default to conservative 512 token limit
logger.warning(f"Unknown model '{model_name}', using default 512 token limit")
return 512
# Set up logger with proper level
logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
- # Global model cache to avoid repeated loading
+ # Token limit registry for embedding models
- _model_cache: dict[str, Any] = {}
+ # Used as fallback when dynamic discovery fails (e.g., LM Studio, OpenAI)
# Ollama models use dynamic discovery via /api/show
# Known embedding model token limits
EMBEDDING_MODEL_LIMITS = {
- "nomic-embed-text": 512,
+ # Nomic models (common across servers)
"nomic-embed-text": 2048, # Corrected from 512 - verified via /api/show
"nomic-embed-text-v1.5": 2048,
"nomic-embed-text-v2": 512,
# Other embedding models
"mxbai-embed-large": 512,
"all-minilm": 512,
"bge-m3": 8192,
"snowflake-arctic-embed": 512,
- # Add more models as needed
+ # OpenAI models
"text-embedding-3-small": 8192,
"text-embedding-3-large": 8192,
"text-embedding-ada-002": 8192,
}
# Runtime cache for dynamically discovered token limits
# Key: (model_name, base_url), Value: token_limit
# Prevents repeated SDK/API calls for the same model
_token_limit_cache: dict[tuple[str, str], int] = {}
def get_model_token_limit(
model_name: str,
base_url: Optional[str] = None,
default: int = 2048,
) -> int:
"""
Get token limit for a given embedding model.
Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others.
Caches discovered limits to prevent repeated API/SDK calls.
Args:
model_name: Name of the embedding model
base_url: Base URL of the embedding server (for dynamic discovery)
default: Default token limit if model not found
Returns:
Token limit for the model in tokens
"""
# Check cache first to avoid repeated SDK/API calls
cache_key = (model_name, base_url or "")
if cache_key in _token_limit_cache:
cached_limit = _token_limit_cache[cache_key]
logger.debug(f"Using cached token limit for {model_name}: {cached_limit}")
return cached_limit
# Try Ollama dynamic discovery if base_url provided
if base_url:
# Detect Ollama servers by port or "ollama" in URL
if "11434" in base_url or "ollama" in base_url.lower():
limit = _query_ollama_context_limit(model_name, base_url)
if limit:
_token_limit_cache[cache_key] = limit
return limit
# Try LM Studio SDK discovery
if "1234" in base_url or "lmstudio" in base_url.lower() or "lm.studio" in base_url.lower():
# Convert HTTP to WebSocket URL
ws_url = base_url.replace("https://", "wss://").replace("http://", "ws://")
# Remove /v1 suffix if present
if ws_url.endswith("/v1"):
ws_url = ws_url[:-3]
limit = _query_lmstudio_context_limit(model_name, ws_url)
if limit:
_token_limit_cache[cache_key] = limit
return limit
# Fallback to known model registry with version handling (from PR #154)
# Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
base_model_name = model_name.split(":")[0]
# Check exact match first
if model_name in EMBEDDING_MODEL_LIMITS:
limit = EMBEDDING_MODEL_LIMITS[model_name]
_token_limit_cache[cache_key] = limit
return limit
# Check base name match
if base_model_name in EMBEDDING_MODEL_LIMITS:
limit = EMBEDDING_MODEL_LIMITS[base_model_name]
_token_limit_cache[cache_key] = limit
return limit
# Check partial matches for common patterns
for known_model, registry_limit in EMBEDDING_MODEL_LIMITS.items():
if known_model in base_model_name or base_model_name in known_model:
_token_limit_cache[cache_key] = registry_limit
return registry_limit
# Default fallback
logger.warning(f"Unknown model '{model_name}', using default {default} token limit")
_token_limit_cache[cache_key] = default
return default
def truncate_to_token_limit(texts: list[str], token_limit: int) -> list[str]:
"""
Truncate texts to fit within token limit using tiktoken.
Args:
texts: List of text strings to truncate
token_limit: Maximum number of tokens allowed
Returns:
List of truncated texts (same length as input)
"""
if not texts:
return []
# Use tiktoken with cl100k_base encoding
enc = tiktoken.get_encoding("cl100k_base")
truncated_texts = []
truncation_count = 0
total_tokens_removed = 0
max_original_length = 0
for i, text in enumerate(texts):
tokens = enc.encode(text)
original_length = len(tokens)
if original_length <= token_limit:
# Text is within limit, keep as is
truncated_texts.append(text)
else:
# Truncate to token_limit
truncated_tokens = tokens[:token_limit]
truncated_text = enc.decode(truncated_tokens)
truncated_texts.append(truncated_text)
# Track truncation statistics
truncation_count += 1
tokens_removed = original_length - token_limit
total_tokens_removed += tokens_removed
max_original_length = max(max_original_length, original_length)
# Log individual truncation at WARNING level (first few only)
if truncation_count <= 3:
logger.warning(
f"Text {i + 1} truncated: {original_length}{token_limit} tokens "
f"({tokens_removed} tokens removed)"
)
elif truncation_count == 4:
logger.warning("Further truncation warnings suppressed...")
# Log summary at INFO level
if truncation_count > 0:
logger.warning(
f"Truncation summary: {truncation_count}/{len(texts)} texts truncated "
f"(removed {total_tokens_removed} tokens total, longest was {max_original_length} tokens)"
)
else:
logger.debug(
f"No truncation needed - all {len(texts)} texts within {token_limit} token limit"
)
return truncated_texts
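
A hedged illustration of how these two helpers resolve in practice, assuming no local Ollama or LM Studio server is reachable (so dynamic discovery is skipped and the registry above applies):

```python
# All values below follow the registry and defaults defined earlier in this file.
get_model_token_limit("nomic-embed-text:latest")  # -> 2048 (base-name match in registry)
get_model_token_limit("text-embedding-3-small")   # -> 8192 (exact match in registry)
get_model_token_limit("my-custom-embedder")       # -> 2048 (default fallback, with a warning)
truncate_to_token_limit(["short text"], 512)      # -> ["short text"] (under the limit, unchanged)
```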
def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int]:
"""
Query Ollama /api/show for model context limit.
Args:
model_name: Name of the Ollama model
base_url: Base URL of the Ollama server
Returns:
Context limit in tokens if found, None otherwise
"""
try:
import requests
response = requests.post(
f"{base_url}/api/show",
json={"name": model_name},
timeout=5,
)
if response.status_code == 200:
data = response.json()
if "model_info" in data:
# Look for *.context_length in model_info
for key, value in data["model_info"].items():
if "context_length" in key and isinstance(value, int):
logger.info(f"Detected {model_name} context limit: {value} tokens")
return value
except Exception as e:
logger.debug(f"Failed to query Ollama context limit: {e}")
return None
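The scan above only relies on `model_info` containing some `<architecture>.context_length` key. A self-contained sketch against a hand-written payload (the key names are assumptions that mirror that pattern, not a captured Ollama response):

```python
# Hand-written example payload; key names are illustrative assumptions.
example_show_response = {
    "model_info": {
        "general.architecture": "nomic-bert",
        "nomic-bert.context_length": 2048,
        "nomic-bert.embedding_length": 768,
    }
}

def scan_context_length(data: dict) -> int | None:
    # Same pattern as the loop above: any *.context_length integer wins.
    for key, value in data.get("model_info", {}).items():
        if "context_length" in key and isinstance(value, int):
            return value
    return None

assert scan_context_length(example_show_response) == 2048
```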
def _query_lmstudio_context_limit(model_name: str, base_url: str) -> Optional[int]:
"""
Query LM Studio SDK for model context length via Node.js subprocess.
Args:
model_name: Name of the LM Studio model
base_url: Base URL of the LM Studio server (WebSocket format, e.g., "ws://localhost:1234")
Returns:
Context limit in tokens if found, None otherwise
"""
# Inline JavaScript using @lmstudio/sdk
# Note: Load model temporarily for metadata, then unload to respect JIT auto-evict
js_code = f"""
const {{ LMStudioClient }} = require('@lmstudio/sdk');
(async () => {{
try {{
const client = new LMStudioClient({{ baseUrl: '{base_url}' }});
const model = await client.embedding.load('{model_name}', {{ verbose: false }});
const contextLength = await model.getContextLength();
await model.unload(); // Unload immediately to respect JIT auto-evict settings
console.log(JSON.stringify({{ contextLength, identifier: '{model_name}' }}));
}} catch (error) {{
console.error(JSON.stringify({{ error: error.message }}));
process.exit(1);
}}
}})();
"""
try:
# Set NODE_PATH to include global modules for @lmstudio/sdk resolution
env = os.environ.copy()
# Try to get npm global root (works with nvm, brew node, etc.)
try:
npm_root = subprocess.run(
["npm", "root", "-g"],
capture_output=True,
text=True,
timeout=5,
)
if npm_root.returncode == 0:
global_modules = npm_root.stdout.strip()
# Append to existing NODE_PATH if present
existing_node_path = env.get("NODE_PATH", "")
env["NODE_PATH"] = (
f"{global_modules}:{existing_node_path}"
if existing_node_path
else global_modules
)
except Exception:
# If npm not available, continue with existing NODE_PATH
pass
result = subprocess.run(
["node", "-e", js_code],
capture_output=True,
text=True,
timeout=10,
env=env,
)
if result.returncode != 0:
logger.debug(f"LM Studio SDK error: {result.stderr}")
return None
data = json.loads(result.stdout)
context_length = data.get("contextLength")
if context_length and context_length > 0:
logger.info(f"LM Studio SDK detected {model_name} context length: {context_length}")
return context_length
except FileNotFoundError:
logger.debug("Node.js not found - install Node.js for LM Studio SDK features")
except subprocess.TimeoutExpired:
logger.debug("LM Studio SDK query timeout")
except json.JSONDecodeError:
logger.debug("LM Studio SDK returned invalid JSON")
except Exception as e:
logger.debug(f"LM Studio SDK query failed: {e}")
return None
# Global model cache to avoid repeated loading
_model_cache: dict[str, Any] = {}
def compute_embeddings( def compute_embeddings(
texts: list[str], texts: list[str],
@@ -161,6 +352,7 @@ def compute_embeddings(
model_name, model_name,
base_url=provider_options.get("base_url"), base_url=provider_options.get("base_url"),
api_key=provider_options.get("api_key"), api_key=provider_options.get("api_key"),
provider_options=provider_options,
) )
elif mode == "mlx": elif mode == "mlx":
return compute_embeddings_mlx(texts, model_name) return compute_embeddings_mlx(texts, model_name)
@@ -170,6 +362,7 @@ def compute_embeddings(
model_name, model_name,
is_build=is_build, is_build=is_build,
host=provider_options.get("host"), host=provider_options.get("host"),
provider_options=provider_options,
) )
elif mode == "gemini": elif mode == "gemini":
return compute_embeddings_gemini(texts, model_name, is_build=is_build) return compute_embeddings_gemini(texts, model_name, is_build=is_build)
@@ -508,6 +701,7 @@ def compute_embeddings_openai(
model_name: str, model_name: str,
base_url: Optional[str] = None, base_url: Optional[str] = None,
api_key: Optional[str] = None, api_key: Optional[str] = None,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray: ) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode # TODO: @yichuan-w add progress bar only in build mode
"""Compute embeddings using OpenAI API""" """Compute embeddings using OpenAI API"""
@@ -526,26 +720,40 @@ def compute_embeddings_openai(
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI." f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
) )
resolved_base_url = resolve_openai_base_url(base_url) # Extract base_url and api_key from provider_options if not provided directly
resolved_api_key = resolve_openai_api_key(api_key) provider_options = provider_options or {}
effective_base_url = base_url or provider_options.get("base_url")
effective_api_key = api_key or provider_options.get("api_key")
resolved_base_url = resolve_openai_base_url(effective_base_url)
resolved_api_key = resolve_openai_api_key(effective_api_key)
if not resolved_api_key: if not resolved_api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set") raise RuntimeError("OPENAI_API_KEY environment variable not set")
# Cache OpenAI client # Create OpenAI client
cache_key = f"openai_client::{resolved_base_url}"
if cache_key in _model_cache:
client = _model_cache[cache_key]
else:
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url) client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
_model_cache[cache_key] = client
logger.info("OpenAI client cached")
logger.info( logger.info(
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'" f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
) )
print(f"len of texts: {len(texts)}") print(f"len of texts: {len(texts)}")
# Apply prompt template if provided
# Priority: build_prompt_template (new format) > prompt_template (old format)
prompt_template = provider_options.get("build_prompt_template") or provider_options.get(
"prompt_template"
)
if prompt_template:
logger.warning(f"Applying prompt template: '{prompt_template}'")
texts = [f"{prompt_template}{text}" for text in texts]
# Query token limit and apply truncation
token_limit = get_model_token_limit(model_name, base_url=effective_base_url)
logger.info(f"Using token limit: {token_limit} for model '{model_name}'")
texts = truncate_to_token_limit(texts, token_limit)
# OpenAI has limits on batch size and input length # OpenAI has limits on batch size and input length
max_batch_size = 800 # Conservative batch size because the token limit is 300K max_batch_size = 800 # Conservative batch size because the token limit is 300K
all_embeddings = [] all_embeddings = []
@@ -576,7 +784,15 @@ def compute_embeddings_openai(
try: try:
response = client.embeddings.create(model=model_name, input=batch_texts) response = client.embeddings.create(model=model_name, input=batch_texts)
batch_embeddings = [embedding.embedding for embedding in response.data] batch_embeddings = [embedding.embedding for embedding in response.data]
all_embeddings.extend(batch_embeddings)
# Verify we got the expected number of embeddings
if len(batch_embeddings) != len(batch_texts):
logger.warning(
f"Expected {len(batch_texts)} embeddings but got {len(batch_embeddings)}"
)
# Only take the number of embeddings that match the batch size
all_embeddings.extend(batch_embeddings[: len(batch_texts)])
except Exception as e: except Exception as e:
logger.error(f"Batch {i} failed: {e}") logger.error(f"Batch {i} failed: {e}")
raise raise
@@ -666,6 +882,7 @@ def compute_embeddings_ollama(
model_name: str, model_name: str,
is_build: bool = False, is_build: bool = False,
host: Optional[str] = None, host: Optional[str] = None,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray: ) -> np.ndarray:
""" """
Compute embeddings using Ollama API with true batch processing. Compute embeddings using Ollama API with true batch processing.
@@ -678,6 +895,7 @@ def compute_embeddings_ollama(
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large") model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
is_build: Whether this is a build operation (shows progress bar) is_build: Whether this is a build operation (shows progress bar)
host: Ollama host URL (defaults to environment or http://localhost:11434) host: Ollama host URL (defaults to environment or http://localhost:11434)
provider_options: Optional provider-specific options (e.g., prompt_template)
Returns: Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim) Normalized embeddings array, shape: (len(texts), embedding_dim)
@@ -814,15 +1032,24 @@ def compute_embeddings_ollama(
logger.info(f"Using batch size: {batch_size} for true batch processing") logger.info(f"Using batch size: {batch_size} for true batch processing")
# Get model token limit and apply truncation # Apply prompt template if provided
token_limit = get_model_token_limit(model_name) provider_options = provider_options or {}
# Priority: build_prompt_template (new format) > prompt_template (old format)
prompt_template = provider_options.get("build_prompt_template") or provider_options.get(
"prompt_template"
)
if prompt_template:
logger.warning(f"Applying prompt template: '{prompt_template}'")
texts = [f"{prompt_template}{text}" for text in texts]
# Get model token limit and apply truncation before batching
token_limit = get_model_token_limit(model_name, base_url=resolved_host)
logger.info(f"Model '{model_name}' token limit: {token_limit}") logger.info(f"Model '{model_name}' token limit: {token_limit}")
# Apply token-aware truncation to all texts # Apply truncation to all texts before batch processing
truncated_texts = truncate_to_token_limit(texts, token_limit) # Function logs truncation details internally
if len(truncated_texts) != len(texts): texts = truncate_to_token_limit(texts, token_limit)
logger.error("Truncation failed - text count mismatch")
truncated_texts = texts # Fallback to original texts
def get_batch_embeddings(batch_texts): def get_batch_embeddings(batch_texts):
"""Get embeddings for a batch of texts using /api/embed endpoint.""" """Get embeddings for a batch of texts using /api/embed endpoint."""
@@ -880,12 +1107,12 @@ def compute_embeddings_ollama(
return None, list(range(len(batch_texts))) return None, list(range(len(batch_texts)))
# Process truncated texts in batches # Process texts in batches
all_embeddings = [] all_embeddings = []
all_failed_indices = [] all_failed_indices = []
# Setup progress bar if needed # Setup progress bar if needed
show_progress = is_build or len(truncated_texts) > 10 show_progress = is_build or len(texts) > 10
try: try:
if show_progress: if show_progress:
from tqdm import tqdm from tqdm import tqdm
@@ -893,7 +1120,7 @@ def compute_embeddings_ollama(
show_progress = False show_progress = False
# Process batches # Process batches
num_batches = (len(truncated_texts) + batch_size - 1) // batch_size num_batches = (len(texts) + batch_size - 1) // batch_size
if show_progress: if show_progress:
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)") batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
@@ -902,8 +1129,8 @@ def compute_embeddings_ollama(
for batch_idx in batch_iterator: for batch_idx in batch_iterator:
start_idx = batch_idx * batch_size start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, len(truncated_texts)) end_idx = min(start_idx + batch_size, len(texts))
batch_texts = truncated_texts[start_idx:end_idx] batch_texts = texts[start_idx:end_idx]
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts) batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
@@ -918,11 +1145,11 @@ def compute_embeddings_ollama(
# Handle failed embeddings # Handle failed embeddings
if all_failed_indices: if all_failed_indices:
if len(all_failed_indices) == len(truncated_texts): if len(all_failed_indices) == len(texts):
raise RuntimeError("Failed to compute any embeddings") raise RuntimeError("Failed to compute any embeddings")
logger.warning( logger.warning(
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(truncated_texts)} texts" f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts"
) )
# Use zero embeddings as fallback for failed ones # Use zero embeddings as fallback for failed ones

View File

@@ -77,6 +77,7 @@ class LeannBackendSearcherInterface(ABC):
query: str, query: str,
use_server_if_available: bool = True, use_server_if_available: bool = True,
zmq_port: Optional[int] = None, zmq_port: Optional[int] = None,
query_template: Optional[str] = None,
) -> np.ndarray: ) -> np.ndarray:
"""Compute embedding for a query string """Compute embedding for a query string
@@ -84,6 +85,7 @@ class LeannBackendSearcherInterface(ABC):
query: The query string to embed query: The query string to embed
zmq_port: ZMQ port for embedding server zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first use_server_if_available: Whether to try using embedding server first
query_template: Optional prompt template to prepend to query
Returns: Returns:
Query embedding as numpy array with shape (1, D) Query embedding as numpy array with shape (1, D)

View File

@@ -33,6 +33,8 @@ def autodiscover_backends():
discovered_backends = [] discovered_backends = []
for dist in importlib.metadata.distributions(): for dist in importlib.metadata.distributions():
dist_name = dist.metadata["name"] dist_name = dist.metadata["name"]
if dist_name is None:
continue
if dist_name.startswith("leann-backend-"): if dist_name.startswith("leann-backend-"):
backend_module_name = dist_name.replace("-", "_") backend_module_name = dist_name.replace("-", "_")
discovered_backends.append(backend_module_name) discovered_backends.append(backend_module_name)

View File

@@ -71,6 +71,15 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
or "mips" or "mips"
) )
# Filter out ALL prompt templates from provider_options during search
# Templates are applied in compute_query_embedding (line 109-110) BEFORE server call
# The server should never apply templates during search to avoid double-templating
search_provider_options = {
k: v
for k, v in self.embedding_options.items()
if k not in ("build_prompt_template", "query_prompt_template", "prompt_template")
}
server_started, actual_port = self.embedding_server_manager.start_server( server_started, actual_port = self.embedding_server_manager.start_server(
port=port, port=port,
model_name=self.embedding_model, model_name=self.embedding_model,
@@ -78,7 +87,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
passages_file=passages_source_file, passages_file=passages_source_file,
distance_metric=distance_metric, distance_metric=distance_metric,
enable_warmup=kwargs.get("enable_warmup", False), enable_warmup=kwargs.get("enable_warmup", False),
provider_options=self.embedding_options, provider_options=search_provider_options,
) )
if not server_started: if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {actual_port}") raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
@@ -90,6 +99,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
query: str, query: str,
use_server_if_available: bool = True, use_server_if_available: bool = True,
zmq_port: int = 5557, zmq_port: int = 5557,
query_template: Optional[str] = None,
) -> np.ndarray: ) -> np.ndarray:
""" """
Compute embedding for a query string. Compute embedding for a query string.
@@ -98,10 +108,16 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
query: The query string to embed query: The query string to embed
zmq_port: ZMQ port for embedding server zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first use_server_if_available: Whether to try using embedding server first
query_template: Optional prompt template to prepend to query
Returns: Returns:
Query embedding as numpy array Query embedding as numpy array
""" """
# Apply query template BEFORE any computation path
# This ensures template is applied consistently for both server and fallback paths
if query_template:
query = f"{query_template}{query}"
# Try to use embedding server if available and requested # Try to use embedding server if available and requested
if use_server_if_available: if use_server_if_available:
try: try:

View File

@@ -9,6 +9,7 @@ from typing import Any
# Default fallbacks to preserve current behaviour while keeping them in one place. # Default fallbacks to preserve current behaviour while keeping them in one place.
_DEFAULT_OLLAMA_HOST = "http://localhost:11434" _DEFAULT_OLLAMA_HOST = "http://localhost:11434"
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1" _DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
_DEFAULT_ANTHROPIC_BASE_URL = "https://api.anthropic.com"
def _clean_url(value: str) -> str: def _clean_url(value: str) -> str:
@@ -52,6 +53,23 @@ def resolve_openai_base_url(explicit: str | None = None) -> str:
return _clean_url(_DEFAULT_OPENAI_BASE_URL) return _clean_url(_DEFAULT_OPENAI_BASE_URL)
def resolve_anthropic_base_url(explicit: str | None = None) -> str:
"""Resolve the base URL for Anthropic-compatible services."""
candidates = (
explicit,
os.getenv("LEANN_ANTHROPIC_BASE_URL"),
os.getenv("ANTHROPIC_BASE_URL"),
os.getenv("LOCAL_ANTHROPIC_BASE_URL"),
)
for candidate in candidates:
if candidate:
return _clean_url(candidate)
return _clean_url(_DEFAULT_ANTHROPIC_BASE_URL)
def resolve_openai_api_key(explicit: str | None = None) -> str | None: def resolve_openai_api_key(explicit: str | None = None) -> str | None:
"""Resolve the API key for OpenAI-compatible services.""" """Resolve the API key for OpenAI-compatible services."""
@@ -61,6 +79,15 @@ def resolve_openai_api_key(explicit: str | None = None) -> str | None:
return os.getenv("OPENAI_API_KEY") return os.getenv("OPENAI_API_KEY")
def resolve_anthropic_api_key(explicit: str | None = None) -> str | None:
"""Resolve the API key for Anthropic services."""
if explicit:
return explicit
return os.getenv("ANTHROPIC_API_KEY")
def encode_provider_options(options: dict[str, Any] | None) -> str | None: def encode_provider_options(options: dict[str, Any] | None) -> str | None:
"""Serialize provider options for child processes.""" """Serialize provider options for child processes."""

View File

@@ -53,6 +53,11 @@ leann build my-project --docs $(git ls-files)
# Start Claude Code # Start Claude Code
claude claude
``` ```
**Performance tip**: For maximum speed when storage space is not a concern, add the `--no-recompute` flag to your build command. This materializes all embeddings and stores them on disk, so they do not have to be recomputed at search time:
```bash
leann build my-project --docs $(git ls-files) --no-recompute
```
## 🚀 Advanced Usage Examples to build the index ## 🚀 Advanced Usage Examples to build the index

View File

@@ -4,10 +4,10 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "leann" name = "leann"
version = "0.3.4" version = "0.3.5"
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!" description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
readme = "README.md" readme = "README.md"
requires-python = ">=3.9" requires-python = ">=3.10"
license = { text = "MIT" } license = { text = "MIT" }
authors = [ authors = [
{ name = "LEANN Team" } { name = "LEANN Team" }
@@ -18,10 +18,10 @@ classifiers = [
"Intended Audience :: Developers", "Intended Audience :: Developers",
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
] ]
# Default installation: core + hnsw + diskann # Default installation: core + hnsw + diskann

View File

@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "leann-workspace" name = "leann-workspace"
version = "0.1.0" version = "0.1.0"
requires-python = ">=3.9" requires-python = ">=3.10"
dependencies = [ dependencies = [
"leann-core", "leann-core",
@@ -57,6 +57,8 @@ dependencies = [
"tree-sitter-c-sharp>=0.20.0", "tree-sitter-c-sharp>=0.20.0",
"tree-sitter-typescript>=0.20.0", "tree-sitter-typescript>=0.20.0",
"torchvision>=0.23.0", "torchvision>=0.23.0",
"einops",
"seaborn",
] ]
[project.optional-dependencies] [project.optional-dependencies]
@@ -67,7 +69,8 @@ diskann = [
# Add a new optional dependency group for document processing # Add a new optional dependency group for document processing
documents = [ documents = [
"beautifulsoup4>=4.13.0", # For HTML parsing "beautifulsoup4>=4.13.0", # For HTML parsing
"python-docx>=0.8.11", # For Word documents "python-docx>=0.8.11", # For Word documents (creating/editing)
"docx2txt>=0.9", # For Word documents (text extraction)
"openpyxl>=3.1.0", # For Excel files "openpyxl>=3.1.0", # For Excel files
"pandas>=2.2.0", # For data processing "pandas>=2.2.0", # For data processing
] ]
@@ -162,6 +165,7 @@ python_functions = ["test_*"]
markers = [ markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')", "slow: marks tests as slow (deselect with '-m \"not slow\"')",
"openai: marks tests that require OpenAI API key", "openai: marks tests that require OpenAI API key",
"integration: marks tests that require live services (Ollama, LM Studio, etc.)",
] ]
timeout = 300 # Reduced from 600s (10min) to 300s (5min) for CI safety timeout = 300 # Reduced from 600s (10min) to 300s (5min) for CI safety
addopts = [ addopts = [

View File

@@ -1,162 +0,0 @@
#!/usr/bin/env python3
"""
Test script to reproduce ColQwen results from issue #119
https://github.com/yichuan-w/LEANN/issues/119
This script demonstrates the ColQwen workflow:
1. Download sample PDF
2. Convert to images
3. Build multimodal index
4. Run test queries
5. Generate similarity maps
"""
import importlib.util
import os
from pathlib import Path
def main():
print("🧪 ColQwen Reproduction Test - Issue #119")
print("=" * 50)
# Check if we're in the right directory
repo_root = Path.cwd()
if not (repo_root / "apps" / "colqwen_rag.py").exists():
print("❌ Please run this script from the LEANN repository root")
print(" cd /path/to/LEANN && python test_colqwen_reproduction.py")
return
print("✅ Repository structure looks good")
# Step 1: Check dependencies
print("\n📦 Checking dependencies...")
try:
import torch
# Check if pdf2image is available
if importlib.util.find_spec("pdf2image") is None:
raise ImportError("pdf2image not found")
# Check if colpali_engine is available
if importlib.util.find_spec("colpali_engine") is None:
raise ImportError("colpali_engine not found")
print("✅ Core dependencies available")
print(f" - PyTorch: {torch.__version__}")
print(f" - CUDA available: {torch.cuda.is_available()}")
print(
f" - MPS available: {hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()}"
)
except ImportError as e:
print(f"❌ Missing dependency: {e}")
print("\n📥 Install missing dependencies:")
print(
" uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn"
)
return
# Step 2: Download sample PDF
print("\n📄 Setting up sample PDF...")
pdf_dir = repo_root / "test_pdfs"
pdf_dir.mkdir(exist_ok=True)
sample_pdf = pdf_dir / "attention_paper.pdf"
if not sample_pdf.exists():
print("📥 Downloading sample paper (Attention Is All You Need)...")
import urllib.request
try:
urllib.request.urlretrieve("https://arxiv.org/pdf/1706.03762.pdf", sample_pdf)
print(f"✅ Downloaded: {sample_pdf}")
except Exception as e:
print(f"❌ Download failed: {e}")
print(" Please manually download a PDF to test_pdfs/attention_paper.pdf")
return
else:
print(f"✅ Using existing PDF: {sample_pdf}")
# Step 3: Test ColQwen RAG
print("\n🚀 Testing ColQwen RAG...")
# Build index
print("\n1⃣ Building multimodal index...")
build_cmd = f"python -m apps.colqwen_rag build --pdfs {pdf_dir} --index test_attention --model colqwen2 --pages-dir test_pages"
print(f" Command: {build_cmd}")
try:
result = os.system(build_cmd)
if result == 0:
print("✅ Index built successfully!")
else:
print("❌ Index building failed")
return
except Exception as e:
print(f"❌ Error building index: {e}")
return
# Test search
print("\n2⃣ Testing search...")
test_queries = [
"How does attention mechanism work?",
"What is the transformer architecture?",
"How do you compute self-attention?",
]
for query in test_queries:
print(f"\n🔍 Query: '{query}'")
search_cmd = f'python -m apps.colqwen_rag search test_attention "{query}" --top-k 3'
print(f" Command: {search_cmd}")
try:
result = os.system(search_cmd)
if result == 0:
print("✅ Search completed")
else:
print("❌ Search failed")
except Exception as e:
print(f"❌ Search error: {e}")
# Test interactive mode (briefly)
print("\n3⃣ Testing interactive mode...")
print(" You can test interactive mode with:")
print(" python -m apps.colqwen_rag ask test_attention --interactive")
# Step 4: Test similarity maps (using existing script)
print("\n4⃣ Testing similarity maps...")
similarity_script = (
repo_root
/ "apps"
/ "multimodal"
/ "vision-based-pdf-multi-vector"
/ "multi-vector-leann-similarity-map.py"
)
if similarity_script.exists():
print(" You can generate similarity maps with:")
print(f" cd {similarity_script.parent}")
print(" python multi-vector-leann-similarity-map.py")
print(" (Edit the script to use your local PDF)")
print("\n🎉 ColQwen reproduction test completed!")
print("\n📋 Summary:")
print(" ✅ Dependencies checked")
print(" ✅ Sample PDF prepared")
print(" ✅ Index building tested")
print(" ✅ Search functionality tested")
print(" ✅ Interactive mode available")
print(" ✅ Similarity maps available")
print("\n🔗 Related repositories to check:")
print(" - https://github.com/lightonai/fast-plaid")
print(" - https://github.com/lightonai/pylate")
print(" - https://github.com/stanford-futuredata/ColBERT")
print("\n📝 Next steps:")
print(" 1. Test with your own PDFs")
print(" 2. Experiment with different queries")
print(" 3. Generate similarity maps for visual analysis")
print(" 4. Compare ColQwen2 vs ColPali performance")
if __name__ == "__main__":
main()

View File

@@ -36,6 +36,14 @@ Tests DiskANN graph partitioning functionality:
- Includes performance comparison between DiskANN (with partition) and HNSW - Includes performance comparison between DiskANN (with partition) and HNSW
- **Note**: These tests are skipped in CI due to hardware requirements and computation time - **Note**: These tests are skipped in CI due to hardware requirements and computation time
### `test_prompt_template_e2e.py`
Integration tests for the prompt template feature with live embedding services (a short sketch of the template behavior follows this list):
- Tests prompt template prepending with EmbeddingGemma (OpenAI-compatible API via LM Studio)
- Tests hybrid token limit discovery (Ollama dynamic detection, registry fallback, default)
- Tests LM Studio SDK bridge for automatic context length detection (requires Node.js + @lmstudio/sdk)
- **Note**: These tests require live services (LM Studio, Ollama) and are marked with `@pytest.mark.integration`
- **Important**: Prompt templates are ONLY for EmbeddingGemma and similar task-specific models, NOT regular embedding models
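A minimal sketch of the asymmetric-template behavior these tests exercise; the prefixes are illustrative (they match the values used in the CLI tests further down), so check your model card before reusing them:

```python
# Illustrative prefixes only; EmbeddingGemma-style models define their own task prefixes.
build_prompt_template = "search_document: "   # prepended to passages at build time
query_prompt_template = "search_query: "      # prepended to the query at search time

passages = ["LEANN stores graphs, not embeddings."]
query = "how does leann save storage?"

templated_passages = [f"{build_prompt_template}{p}" for p in passages]
templated_query = f"{query_prompt_template}{query}"

assert templated_passages[0].startswith("search_document: ")
assert templated_query.startswith("search_query: ")
```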
## Running Tests ## Running Tests
### Install test dependencies: ### Install test dependencies:
@@ -66,6 +74,12 @@ pytest tests/ -m "not openai"
# Skip slow tests # Skip slow tests
pytest tests/ -m "not slow" pytest tests/ -m "not slow"
# Skip integration tests (that require live services)
pytest tests/ -m "not integration"
# Run only integration tests (requires LM Studio or Ollama running)
pytest tests/test_prompt_template_e2e.py -v -s
# Run DiskANN partition tests (requires local machine, not CI) # Run DiskANN partition tests (requires local machine, not CI)
pytest tests/test_diskann_partition.py pytest tests/test_diskann_partition.py
``` ```
@@ -101,6 +115,20 @@ The `pytest.ini` file configures:
- Custom markers for slow and OpenAI tests - Custom markers for slow and OpenAI tests
- Verbose output with short tracebacks - Verbose output with short tracebacks
### Integration Test Prerequisites
Integration tests (`test_prompt_template_e2e.py`) require live services:
**Required:**
- LM Studio running at `http://localhost:1234` with EmbeddingGemma model loaded
**Optional:**
- Ollama running at `http://localhost:11434` for token limit detection tests
- Node.js + @lmstudio/sdk installed (`npm install -g @lmstudio/sdk`) for SDK bridge tests
Tests gracefully skip if services are unavailable.
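One hypothetical way the graceful skip can look (the actual tests may use a different helper): probe the service URL first and call `pytest.skip` when it is unreachable.

```python
# Hypothetical availability guard; shown only to illustrate the skip behavior.
import pytest
import requests

def _service_up(url: str) -> bool:
    try:
        requests.get(url, timeout=2)
        return True
    except requests.RequestException:
        return False

@pytest.mark.integration
def test_lmstudio_prompt_template():
    if not _service_up("http://localhost:1234/v1/models"):
        pytest.skip("LM Studio not running at localhost:1234")
    ...  # exercise the prompt template against the live server
```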
### Known Issues ### Known Issues
- OpenAI tests are automatically skipped if no API key is provided - OpenAI tests are automatically skipped if no API key is provided
- Integration tests require live embedding services and may fail due to proxy settings (run `unset ALL_PROXY all_proxy` if needed)

View File

@@ -8,7 +8,7 @@ import subprocess
import sys import sys
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from unittest.mock import patch from unittest.mock import Mock, patch
import pytest import pytest
@@ -116,8 +116,10 @@ class TestChunkingFunctions:
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10) chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
assert len(chunks) > 0 assert len(chunks) > 0
assert all(isinstance(chunk, str) for chunk in chunks) # Traditional chunks now return dict format for consistency
assert all(len(chunk.strip()) > 0 for chunk in chunks) assert all(isinstance(chunk, dict) for chunk in chunks)
assert all("text" in chunk and "metadata" in chunk for chunk in chunks)
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks)
def test_create_traditional_chunks_empty_docs(self): def test_create_traditional_chunks_empty_docs(self):
"""Test traditional chunking with empty documents.""" """Test traditional chunking with empty documents."""
@@ -158,11 +160,22 @@ class Calculator:
# Should have multiple chunks due to different functions/classes # Should have multiple chunks due to different functions/classes
assert len(chunks) > 0 assert len(chunks) > 0
assert all(isinstance(chunk, str) for chunk in chunks) # R3: Expect dict format with "text" and "metadata" keys
assert all(len(chunk.strip()) > 0 for chunk in chunks) assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
"Each chunk should have 'text' and 'metadata' keys"
)
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks), (
"Each chunk text should be non-empty"
)
# Check metadata is present
assert all("file_path" in chunk["metadata"] for chunk in chunks), (
"Each chunk should have file_path metadata"
)
# Check that code structure is somewhat preserved # Check that code structure is somewhat preserved
combined_content = " ".join(chunks) combined_content = " ".join([c["text"] for c in chunks])
assert "def hello_world" in combined_content assert "def hello_world" in combined_content
assert "class Calculator" in combined_content assert "class Calculator" in combined_content
@@ -194,7 +207,11 @@ class Calculator:
chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10) chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10)
assert len(chunks) > 0 assert len(chunks) > 0
assert all(isinstance(chunk, str) for chunk in chunks) # R3: Traditional chunking should also return dict format for consistency
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
"Each chunk should have 'text' and 'metadata' keys"
)
def test_create_text_chunks_ast_mode(self): def test_create_text_chunks_ast_mode(self):
"""Test text chunking in AST mode.""" """Test text chunking in AST mode."""
@@ -213,7 +230,11 @@ class Calculator:
) )
assert len(chunks) > 0 assert len(chunks) > 0
assert all(isinstance(chunk, str) for chunk in chunks) # R3: AST mode should also return dict format
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
"Each chunk should have 'text' and 'metadata' keys"
)
def test_create_text_chunks_custom_extensions(self): def test_create_text_chunks_custom_extensions(self):
"""Test text chunking with custom code file extensions.""" """Test text chunking with custom code file extensions."""
@@ -353,6 +374,552 @@ class MathUtils:
pytest.skip("Test timed out - likely due to model download in CI") pytest.skip("Test timed out - likely due to model download in CI")
class TestASTContentExtraction:
"""Test AST content extraction bug fix.
These tests verify that astchunk's dict format with 'content' key is handled correctly,
and that the extraction logic doesn't fall through to stringifying entire dicts.
"""
def test_extract_content_from_astchunk_dict(self):
"""Test that astchunk dict format with 'content' key is handled correctly.
Bug: Current code checks for chunk["text"] but astchunk returns chunk["content"].
This causes fallthrough to str(chunk), stringifying the entire dict.
This test will FAIL until the bug is fixed because:
- Current code will stringify the dict: "{'content': '...', 'metadata': {...}}"
- Fixed code should extract just the content value
"""
# Mock the ASTChunkBuilder class
mock_builder = Mock()
# Astchunk returns this format
astchunk_format_chunk = {
"content": "def hello():\n print('world')",
"metadata": {
"filepath": "test.py",
"line_count": 2,
"start_line_no": 0,
"end_line_no": 1,
"node_count": 1,
},
}
mock_builder.chunkify.return_value = [astchunk_format_chunk]
# Create mock document
doc = MockDocument(
"def hello():\n print('world')", "/test/test.py", {"language": "python"}
)
# Mock the astchunk module and its ASTChunkBuilder class
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
# Patch sys.modules to inject our mock before the import
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should return dict format with proper metadata
assert len(chunks) > 0, "Should return at least one chunk"
# R3: Each chunk should be a dict
chunk = chunks[0]
assert isinstance(chunk, dict), "Chunk should be a dict"
assert "text" in chunk, "Chunk should have 'text' key"
assert "metadata" in chunk, "Chunk should have 'metadata' key"
chunk_text = chunk["text"]
# CRITICAL: Should NOT contain stringified dict markers in the text field
# These assertions will FAIL with current buggy code
assert "'content':" not in chunk_text, (
f"Chunk text contains stringified dict - extraction failed! Got: {chunk_text[:100]}..."
)
assert "'metadata':" not in chunk_text, (
"Chunk text contains stringified metadata - extraction failed! "
f"Got: {chunk_text[:100]}..."
)
assert "{" not in chunk_text or "def hello" in chunk_text.split("{")[0], (
"Chunk text appears to be a stringified dict"
)
# Should contain actual content
assert "def hello()" in chunk_text, "Should extract actual code content"
assert "print('world')" in chunk_text, "Should extract complete code content"
# R3: Should preserve astchunk metadata
assert "filepath" in chunk["metadata"] or "file_path" in chunk["metadata"], (
"Should preserve file path metadata"
)
def test_extract_text_key_fallback(self):
"""Test that 'text' key still works for backward compatibility.
Some chunks might use 'text' instead of 'content' - ensure backward compatibility.
This test should PASS even with current code.
"""
mock_builder = Mock()
# Some chunks might use "text" key
text_key_chunk = {"text": "def legacy_function():\n return True"}
mock_builder.chunkify.return_value = [text_key_chunk]
# Create mock document
doc = MockDocument(
"def legacy_function():\n return True", "/test/legacy.py", {"language": "python"}
)
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should extract text correctly as dict format
assert len(chunks) > 0
chunk = chunks[0]
assert isinstance(chunk, dict), "Chunk should be a dict"
assert "text" in chunk, "Chunk should have 'text' key"
chunk_text = chunk["text"]
# Should NOT be stringified
assert "'text':" not in chunk_text, "Should not stringify dict with 'text' key"
# Should contain actual content
assert "def legacy_function()" in chunk_text
assert "return True" in chunk_text
def test_handles_string_chunks(self):
"""Test that plain string chunks still work.
Some chunkers might return plain strings - verify these are preserved.
This test should PASS with current code.
"""
mock_builder = Mock()
# Plain string chunk
plain_string_chunk = "def simple_function():\n pass"
mock_builder.chunkify.return_value = [plain_string_chunk]
# Create mock document
doc = MockDocument(
"def simple_function():\n pass", "/test/simple.py", {"language": "python"}
)
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should wrap string in dict format
assert len(chunks) > 0
chunk = chunks[0]
assert isinstance(chunk, dict), "Even string chunks should be wrapped in dict"
assert "text" in chunk, "Chunk should have 'text' key"
chunk_text = chunk["text"]
assert chunk_text == plain_string_chunk.strip(), (
"Should preserve plain string chunk content"
)
assert "def simple_function()" in chunk_text
assert "pass" in chunk_text
def test_multiple_chunks_with_mixed_formats(self):
"""Test handling of multiple chunks with different formats.
Real-world scenario: astchunk might return a mix of formats.
This test will FAIL if any chunk with 'content' key gets stringified.
"""
mock_builder = Mock()
# Mix of formats
mixed_chunks = [
{"content": "def first():\n return 1", "metadata": {"line_count": 2}},
"def second():\n return 2", # Plain string
{"text": "def third():\n return 3"}, # Old format
{"content": "class MyClass:\n pass", "metadata": {"node_count": 1}},
]
mock_builder.chunkify.return_value = mixed_chunks
# Create mock document
code = "def first():\n return 1\n\ndef second():\n return 2\n\ndef third():\n return 3\n\nclass MyClass:\n pass"
doc = MockDocument(code, "/test/mixed.py", {"language": "python"})
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should extract all chunks correctly as dicts
assert len(chunks) == 4, "Should extract all 4 chunks"
# Check each chunk
for i, chunk in enumerate(chunks):
assert isinstance(chunk, dict), f"Chunk {i} should be a dict"
assert "text" in chunk, f"Chunk {i} should have 'text' key"
assert "metadata" in chunk, f"Chunk {i} should have 'metadata' key"
chunk_text = chunk["text"]
# None should be stringified dicts
assert "'content':" not in chunk_text, f"Chunk {i} text is stringified (has 'content':)"
assert "'metadata':" not in chunk_text, (
f"Chunk {i} text is stringified (has 'metadata':)"
)
assert "'text':" not in chunk_text, f"Chunk {i} text is stringified (has 'text':)"
# Verify actual content is present
combined = "\n".join([c["text"] for c in chunks])
assert "def first()" in combined
assert "def second()" in combined
assert "def third()" in combined
assert "class MyClass:" in combined
def test_empty_content_value_handling(self):
"""Test handling of chunks with empty content values.
Edge case: chunk has 'content' key but value is empty.
Should skip these chunks, not stringify them.
"""
mock_builder = Mock()
chunks_with_empty = [
{"content": "", "metadata": {"line_count": 0}}, # Empty content
{"content": " ", "metadata": {"line_count": 1}}, # Whitespace only
{"content": "def valid():\n return True", "metadata": {"line_count": 2}}, # Valid
]
mock_builder.chunkify.return_value = chunks_with_empty
doc = MockDocument(
"def valid():\n return True", "/test/empty.py", {"language": "python"}
)
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
chunks = create_ast_chunks([doc])
# R3: Should only have the valid chunk (empty ones filtered out)
assert len(chunks) == 1, "Should filter out empty content chunks"
chunk = chunks[0]
assert isinstance(chunk, dict), "Chunk should be a dict"
assert "text" in chunk, "Chunk should have 'text' key"
assert "def valid()" in chunk["text"]
# Should not have stringified the empty dict
assert "'content': ''" not in chunk["text"]
class TestASTMetadataPreservation:
"""Test metadata preservation in AST chunk dictionaries.
R3: These tests define the contract for metadata preservation when returning
chunk dictionaries instead of plain strings. Each chunk dict should have:
- "text": str - the actual chunk content
- "metadata": dict - all metadata from document AND astchunk
These tests will FAIL until G3 implementation changes return type to list[dict].
"""
def test_ast_chunks_preserve_file_metadata(self):
"""Test that document metadata is preserved in chunk metadata.
This test verifies that all document-level metadata (file_path, file_name,
creation_date, last_modified_date) is included in each chunk's metadata dict.
This will FAIL because current code returns list[str], not list[dict].
"""
# Create mock document with rich metadata
python_code = '''
def calculate_sum(numbers):
"""Calculate sum of numbers."""
return sum(numbers)
class DataProcessor:
"""Process data records."""
def process(self, data):
return [x * 2 for x in data]
'''
doc = MockDocument(
python_code,
file_path="/project/src/utils.py",
metadata={
"language": "python",
"file_path": "/project/src/utils.py",
"file_name": "utils.py",
"creation_date": "2024-01-15T10:30:00",
"last_modified_date": "2024-10-31T15:45:00",
},
)
# Mock astchunk to return chunks with metadata
mock_builder = Mock()
astchunk_chunks = [
{
"content": "def calculate_sum(numbers):\n return sum(numbers)",
"metadata": {
"filepath": "/project/src/utils.py",
"line_count": 2,
"start_line_no": 1,
"end_line_no": 2,
"node_count": 1,
},
},
{
"content": "class DataProcessor:\n def process(self, data):\n return [x * 2 for x in data]",
"metadata": {
"filepath": "/project/src/utils.py",
"line_count": 3,
"start_line_no": 5,
"end_line_no": 7,
"node_count": 2,
},
},
]
mock_builder.chunkify.return_value = astchunk_chunks
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
chunks = create_ast_chunks([doc])
# CRITICAL: These assertions will FAIL with current list[str] return type
assert len(chunks) == 2, "Should return 2 chunks"
for i, chunk in enumerate(chunks):
# Structure assertions - WILL FAIL: current code returns strings
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
assert "text" in chunk, f"Chunk {i} must have 'text' key"
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
assert isinstance(chunk["metadata"], dict), f"Chunk {i} metadata should be dict"
# Document metadata preservation - WILL FAIL
metadata = chunk["metadata"]
assert "file_path" in metadata, f"Chunk {i} should preserve file_path"
assert metadata["file_path"] == "/project/src/utils.py", (
f"Chunk {i} file_path incorrect"
)
assert "file_name" in metadata, f"Chunk {i} should preserve file_name"
assert metadata["file_name"] == "utils.py", f"Chunk {i} file_name incorrect"
assert "creation_date" in metadata, f"Chunk {i} should preserve creation_date"
assert metadata["creation_date"] == "2024-01-15T10:30:00", (
f"Chunk {i} creation_date incorrect"
)
assert "last_modified_date" in metadata, f"Chunk {i} should preserve last_modified_date"
assert metadata["last_modified_date"] == "2024-10-31T15:45:00", (
f"Chunk {i} last_modified_date incorrect"
)
# Verify metadata is consistent across chunks from same document
assert chunks[0]["metadata"]["file_path"] == chunks[1]["metadata"]["file_path"], (
"All chunks from same document should have same file_path"
)
# Verify text content is present and not stringified
assert "def calculate_sum" in chunks[0]["text"]
assert "class DataProcessor" in chunks[1]["text"]
def test_ast_chunks_include_astchunk_metadata(self):
"""Test that astchunk-specific metadata is merged into chunk metadata.
This test verifies that astchunk's metadata (line_count, start_line_no,
end_line_no, node_count) is merged with document metadata.
This will FAIL because current code returns list[str], not list[dict].
"""
python_code = '''
def function_one():
"""First function."""
x = 1
y = 2
return x + y
def function_two():
"""Second function."""
return 42
'''
doc = MockDocument(
python_code,
file_path="/test/code.py",
metadata={
"language": "python",
"file_path": "/test/code.py",
"file_name": "code.py",
},
)
# Mock astchunk with detailed metadata
mock_builder = Mock()
astchunk_chunks = [
{
"content": "def function_one():\n x = 1\n y = 2\n return x + y",
"metadata": {
"filepath": "/test/code.py",
"line_count": 4,
"start_line_no": 1,
"end_line_no": 4,
"node_count": 5, # function, assignments, return
},
},
{
"content": "def function_two():\n return 42",
"metadata": {
"filepath": "/test/code.py",
"line_count": 2,
"start_line_no": 7,
"end_line_no": 8,
"node_count": 2, # function, return
},
},
]
mock_builder.chunkify.return_value = astchunk_chunks
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
chunks = create_ast_chunks([doc])
# CRITICAL: These will FAIL with current list[str] return
assert len(chunks) == 2
# First chunk - function_one
chunk1 = chunks[0]
assert isinstance(chunk1, dict), "Chunk should be dict"
assert "metadata" in chunk1
metadata1 = chunk1["metadata"]
# Check astchunk metadata is present
assert "line_count" in metadata1, "Should include astchunk line_count"
assert metadata1["line_count"] == 4, "line_count should be 4"
assert "start_line_no" in metadata1, "Should include astchunk start_line_no"
assert metadata1["start_line_no"] == 1, "start_line_no should be 1"
assert "end_line_no" in metadata1, "Should include astchunk end_line_no"
assert metadata1["end_line_no"] == 4, "end_line_no should be 4"
assert "node_count" in metadata1, "Should include astchunk node_count"
assert metadata1["node_count"] == 5, "node_count should be 5"
# Second chunk - function_two
chunk2 = chunks[1]
metadata2 = chunk2["metadata"]
assert metadata2["line_count"] == 2, "line_count should be 2"
assert metadata2["start_line_no"] == 7, "start_line_no should be 7"
assert metadata2["end_line_no"] == 8, "end_line_no should be 8"
assert metadata2["node_count"] == 2, "node_count should be 2"
# Verify document metadata is ALSO present (merged, not replaced)
assert metadata1["file_path"] == "/test/code.py"
assert metadata1["file_name"] == "code.py"
assert metadata2["file_path"] == "/test/code.py"
assert metadata2["file_name"] == "code.py"
# Verify text content is correct
assert "def function_one" in chunk1["text"]
assert "def function_two" in chunk2["text"]
def test_traditional_chunks_as_dicts_helper(self):
"""Test the helper function that wraps traditional chunks as dicts.
This test verifies that when create_traditional_chunks is called,
its plain string chunks are wrapped into dict format with metadata.
This will FAIL because the helper function _traditional_chunks_as_dicts()
doesn't exist yet, and create_traditional_chunks returns list[str].
"""
# Create documents with various metadata
docs = [
MockDocument(
"This is the first paragraph of text. It contains multiple sentences. "
"This should be split into chunks based on size.",
file_path="/docs/readme.txt",
metadata={
"file_path": "/docs/readme.txt",
"file_name": "readme.txt",
"creation_date": "2024-01-01",
},
),
MockDocument(
"Second document with different metadata. It also has content that needs chunking.",
file_path="/docs/guide.md",
metadata={
"file_path": "/docs/guide.md",
"file_name": "guide.md",
"last_modified_date": "2024-10-31",
},
),
]
# Call create_traditional_chunks (which should now return list[dict])
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
# CRITICAL: Will FAIL - current code returns list[str]
assert len(chunks) > 0, "Should return chunks"
for i, chunk in enumerate(chunks):
# Structure assertions - WILL FAIL
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
assert "text" in chunk, f"Chunk {i} must have 'text' key"
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
# Text should be non-empty
assert len(chunk["text"].strip()) > 0, f"Chunk {i} text should be non-empty"
# Metadata should include document info
metadata = chunk["metadata"]
assert "file_path" in metadata, f"Chunk {i} should have file_path in metadata"
assert "file_name" in metadata, f"Chunk {i} should have file_name in metadata"
# Verify metadata tracking works correctly
# At least one chunk should be from readme.txt
readme_chunks = [c for c in chunks if "readme.txt" in c["metadata"]["file_name"]]
assert len(readme_chunks) > 0, "Should have chunks from readme.txt"
# At least one chunk should be from guide.md
guide_chunks = [c for c in chunks if "guide.md" in c["metadata"]["file_name"]]
assert len(guide_chunks) > 0, "Should have chunks from guide.md"
# Verify creation_date is preserved for readme chunks
for chunk in readme_chunks:
assert chunk["metadata"].get("creation_date") == "2024-01-01", (
"readme.txt chunks should preserve creation_date"
)
# Verify last_modified_date is preserved for guide chunks
for chunk in guide_chunks:
assert chunk["metadata"].get("last_modified_date") == "2024-10-31", (
"guide.md chunks should preserve last_modified_date"
)
# Verify text content is present
all_text = " ".join([c["text"] for c in chunks])
assert "first paragraph" in all_text
assert "Second document" in all_text
class TestErrorHandling: class TestErrorHandling:
"""Test error handling and edge cases.""" """Test error handling and edge cases."""

View File

@@ -0,0 +1,533 @@
"""
Tests for CLI argument integration of --embedding-prompt-template.
These tests verify that:
1. The --embedding-prompt-template flag is properly registered on build and search commands
2. The template value flows from CLI args to embedding_options dict
3. The template is passed through to compute_embeddings() function
4. Default behavior (no flag) is handled correctly
"""
from unittest.mock import Mock, patch
from leann.cli import LeannCLI
class TestCLIPromptTemplateArgument:
"""Tests for --embedding-prompt-template on build and search commands."""
def test_commands_accept_prompt_template_argument(self):
"""Verify that build and search parsers accept --embedding-prompt-template flag."""
cli = LeannCLI()
parser = cli.create_parser()
template_value = "search_query: "
# Test build command
build_args = parser.parse_args(
[
"build",
"test-index",
"--docs",
"/tmp/test-docs",
"--embedding-prompt-template",
template_value,
]
)
assert build_args.command == "build"
assert hasattr(build_args, "embedding_prompt_template"), (
"build command should have embedding_prompt_template attribute"
)
assert build_args.embedding_prompt_template == template_value
# Test search command
search_args = parser.parse_args(
["search", "test-index", "my query", "--embedding-prompt-template", template_value]
)
assert search_args.command == "search"
assert hasattr(search_args, "embedding_prompt_template"), (
"search command should have embedding_prompt_template attribute"
)
assert search_args.embedding_prompt_template == template_value
def test_commands_default_to_none(self):
"""Verify default value is None when flag not provided (backward compatibility)."""
cli = LeannCLI()
parser = cli.create_parser()
# Test build command default
build_args = parser.parse_args(["build", "test-index", "--docs", "/tmp/test-docs"])
assert hasattr(build_args, "embedding_prompt_template"), (
"build command should have embedding_prompt_template attribute"
)
assert build_args.embedding_prompt_template is None, (
"Build default value should be None when flag not provided"
)
# Test search command default
search_args = parser.parse_args(["search", "test-index", "my query"])
assert hasattr(search_args, "embedding_prompt_template"), (
"search command should have embedding_prompt_template attribute"
)
assert search_args.embedding_prompt_template is None, (
"Search default value should be None when flag not provided"
)
class TestBuildCommandPromptTemplateArgumentExtras:
"""Additional build-specific tests for prompt template argument."""
def test_build_command_prompt_template_with_multiword_value(self):
"""
Verify that template values with spaces are handled correctly.
Templates like "search_document: " or "Represent this sentence for searching: "
should be accepted as a single string argument.
"""
cli = LeannCLI()
parser = cli.create_parser()
template = "Represent this sentence for searching: "
args = parser.parse_args(
[
"build",
"test-index",
"--docs",
"/tmp/test-docs",
"--embedding-prompt-template",
template,
]
)
assert args.embedding_prompt_template == template
class TestPromptTemplateStoredInEmbeddingOptions:
"""Tests for template storage in embedding_options dict."""
@patch("leann.cli.LeannBuilder")
def test_prompt_template_stored_in_embedding_options_on_build(
self, mock_builder_class, tmp_path
):
"""
Verify that when --embedding-prompt-template is provided to build command,
the value is stored in embedding_options dict passed to LeannBuilder.
This test will fail because the CLI doesn't currently process this argument
and add it to embedding_options.
"""
# Setup mocks
mock_builder = Mock()
mock_builder_class.return_value = mock_builder
# Create CLI and run build command
cli = LeannCLI()
# Mock load_documents to return a document so builder is created
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
parser = cli.create_parser()
template = "search_query: "
args = parser.parse_args(
[
"build",
"test-index",
"--docs",
str(tmp_path),
"--embedding-prompt-template",
template,
"--force", # Force rebuild to ensure LeannBuilder is called
]
)
# Run the build command
import asyncio
asyncio.run(cli.build_index(args))
# Check that LeannBuilder was called with embedding_options containing prompt_template
call_kwargs = mock_builder_class.call_args.kwargs
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
embedding_options = call_kwargs["embedding_options"]
assert embedding_options is not None, (
"embedding_options should not be None when template provided"
)
assert "prompt_template" in embedding_options, (
"embedding_options should contain 'prompt_template' key"
)
assert embedding_options["prompt_template"] == template, (
f"Template should be '{template}', got {embedding_options.get('prompt_template')}"
)
@patch("leann.cli.LeannBuilder")
def test_prompt_template_not_in_options_when_not_provided(self, mock_builder_class, tmp_path):
"""
Verify that when --embedding-prompt-template is NOT provided,
embedding_options either doesn't have the key or it's None.
This ensures we don't pass empty/None values unnecessarily.
"""
# Setup mocks
mock_builder = Mock()
mock_builder_class.return_value = mock_builder
cli = LeannCLI()
# Mock load_documents to return a document so builder is created
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
parser = cli.create_parser()
args = parser.parse_args(
[
"build",
"test-index",
"--docs",
str(tmp_path),
"--force", # Force rebuild to ensure LeannBuilder is called
]
)
import asyncio
asyncio.run(cli.build_index(args))
# Check that if embedding_options is passed, it doesn't have prompt_template
call_kwargs = mock_builder_class.call_args.kwargs
if call_kwargs.get("embedding_options"):
embedding_options = call_kwargs["embedding_options"]
# Either the key shouldn't exist, or it should be None
assert (
"prompt_template" not in embedding_options
or embedding_options["prompt_template"] is None
), "prompt_template should not be set when flag not provided"
# R1 Tests: Build-time separate template storage
@patch("leann.cli.LeannBuilder")
def test_build_stores_separate_templates(self, mock_builder_class, tmp_path):
"""
R1 Test 1: Verify that when both --embedding-prompt-template and
--query-prompt-template are provided to build command, both values
are stored separately in embedding_options dict as build_prompt_template
and query_prompt_template.
This test will fail because:
1. CLI doesn't accept --query-prompt-template flag yet
2. CLI doesn't store templates as separate build_prompt_template and
query_prompt_template keys
Expected behavior after implementation:
- .meta.json contains: {"embedding_options": {
"build_prompt_template": "doc: ",
"query_prompt_template": "query: "
}}
"""
# Setup mocks
mock_builder = Mock()
mock_builder_class.return_value = mock_builder
cli = LeannCLI()
# Mock load_documents to return a document so builder is created
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
parser = cli.create_parser()
build_template = "doc: "
query_template = "query: "
args = parser.parse_args(
[
"build",
"test-index",
"--docs",
str(tmp_path),
"--embedding-prompt-template",
build_template,
"--query-prompt-template",
query_template,
"--force",
]
)
# Run the build command
import asyncio
asyncio.run(cli.build_index(args))
# Check that LeannBuilder was called with separate template keys
call_kwargs = mock_builder_class.call_args.kwargs
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
embedding_options = call_kwargs["embedding_options"]
assert embedding_options is not None, (
"embedding_options should not be None when templates provided"
)
assert "build_prompt_template" in embedding_options, (
"embedding_options should contain 'build_prompt_template' key"
)
assert embedding_options["build_prompt_template"] == build_template, (
f"build_prompt_template should be '{build_template}'"
)
assert "query_prompt_template" in embedding_options, (
"embedding_options should contain 'query_prompt_template' key"
)
assert embedding_options["query_prompt_template"] == query_template, (
f"query_prompt_template should be '{query_template}'"
)
# Old key should NOT be present when using new separate template format
assert "prompt_template" not in embedding_options, (
"Old 'prompt_template' key should not be present with separate templates"
)
@patch("leann.cli.LeannBuilder")
def test_build_backward_compat_single_template(self, mock_builder_class, tmp_path):
"""
R1 Test 2: Verify backward compatibility - when only
--embedding-prompt-template is provided (old behavior), it should
still be stored as 'prompt_template' in embedding_options.
This ensures existing workflows continue to work unchanged.
This test currently passes because it matches existing behavior, but it
documents the requirement that this behavior must be preserved after
implementing the separate template feature.
Expected behavior:
- .meta.json contains: {"embedding_options": {"prompt_template": "prompt: "}}
- No build_prompt_template or query_prompt_template keys
"""
# Setup mocks
mock_builder = Mock()
mock_builder_class.return_value = mock_builder
cli = LeannCLI()
# Mock load_documents to return a document so builder is created
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
parser = cli.create_parser()
template = "prompt: "
args = parser.parse_args(
[
"build",
"test-index",
"--docs",
str(tmp_path),
"--embedding-prompt-template",
template,
"--force",
]
)
# Run the build command
import asyncio
asyncio.run(cli.build_index(args))
# Check that LeannBuilder was called with old format
call_kwargs = mock_builder_class.call_args.kwargs
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
embedding_options = call_kwargs["embedding_options"]
assert embedding_options is not None, (
"embedding_options should not be None when template provided"
)
assert "prompt_template" in embedding_options, (
"embedding_options should contain old 'prompt_template' key for backward compat"
)
assert embedding_options["prompt_template"] == template, (
f"prompt_template should be '{template}'"
)
# New keys should NOT be present in backward compat mode
assert "build_prompt_template" not in embedding_options, (
"build_prompt_template should not be present with single template flag"
)
assert "query_prompt_template" not in embedding_options, (
"query_prompt_template should not be present with single template flag"
)
@patch("leann.cli.LeannBuilder")
def test_build_no_templates(self, mock_builder_class, tmp_path):
"""
R1 Test 3: Verify that when no template flags are provided,
embedding_options has no prompt template keys.
This ensures clean defaults and no unnecessary keys in .meta.json.
This test currently passes because it matches existing behavior, but it
documents the requirement that this behavior must be preserved after
implementing the separate template feature.
Expected behavior:
- .meta.json has no prompt_template, build_prompt_template, or
query_prompt_template keys (or embedding_options is empty/None)
"""
# Setup mocks
mock_builder = Mock()
mock_builder_class.return_value = mock_builder
cli = LeannCLI()
# Mock load_documents to return a document so builder is created
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
parser = cli.create_parser()
args = parser.parse_args(["build", "test-index", "--docs", str(tmp_path), "--force"])
# Run the build command
import asyncio
asyncio.run(cli.build_index(args))
# Check that no template keys are present
call_kwargs = mock_builder_class.call_args.kwargs
if call_kwargs.get("embedding_options"):
embedding_options = call_kwargs["embedding_options"]
# None of the template keys should be present
assert "prompt_template" not in embedding_options, (
"prompt_template should not be present when no flags provided"
)
assert "build_prompt_template" not in embedding_options, (
"build_prompt_template should not be present when no flags provided"
)
assert "query_prompt_template" not in embedding_options, (
"query_prompt_template should not be present when no flags provided"
)
class TestPromptTemplateFlowsToComputeEmbeddings:
"""Tests for template flowing through to compute_embeddings function."""
@patch("leann.api.compute_embeddings")
def test_prompt_template_flows_to_compute_embeddings_via_provider_options(
self, mock_compute_embeddings, tmp_path
):
"""
Verify that the prompt template flows from CLI args through LeannBuilder
to compute_embeddings() function via provider_options parameter.
This is an integration test that verifies the complete flow:
CLI → embedding_options → LeannBuilder → compute_embeddings(provider_options)
This test will fail because:
1. CLI doesn't capture the argument yet
2. embedding_options doesn't include prompt_template
3. LeannBuilder doesn't pass it through to compute_embeddings
"""
# Mock compute_embeddings to return dummy embeddings as numpy array
import numpy as np
mock_compute_embeddings.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
# Use real LeannBuilder (not mocked) to test the actual flow
cli = LeannCLI()
# Mock load_documents to return a simple document
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
parser = cli.create_parser()
template = "search_document: "
args = parser.parse_args(
[
"build",
"test-index",
"--docs",
str(tmp_path),
"--embedding-prompt-template",
template,
"--backend-name",
"hnsw", # Use hnsw backend
"--force", # Force rebuild to ensure index is created
]
)
# This should fail because the flow isn't implemented yet
import asyncio
asyncio.run(cli.build_index(args))
# Verify compute_embeddings was called with provider_options containing prompt_template
assert mock_compute_embeddings.called, "compute_embeddings should have been called"
# Check the call arguments
call_kwargs = mock_compute_embeddings.call_args.kwargs
assert "provider_options" in call_kwargs, (
"compute_embeddings should receive provider_options parameter"
)
provider_options = call_kwargs["provider_options"]
assert provider_options is not None, "provider_options should not be None"
assert "prompt_template" in provider_options, (
"provider_options should contain prompt_template key"
)
assert provider_options["prompt_template"] == template, (
f"Template should be '{template}', got {provider_options.get('prompt_template')}"
)
class TestPromptTemplateArgumentHelp:
"""Tests for argument help text and documentation."""
def test_build_command_prompt_template_has_help_text(self):
"""
Verify that --embedding-prompt-template has descriptive help text.
Good help text is crucial for CLI usability.
"""
cli = LeannCLI()
parser = cli.create_parser()
# Get the build subparser
# This is a bit tricky - we need to parse to get the help
# We'll check that the help includes relevant keywords
import io
from contextlib import redirect_stdout
f = io.StringIO()
try:
with redirect_stdout(f):
parser.parse_args(["build", "--help"])
except SystemExit:
pass # --help causes sys.exit(0)
help_text = f.getvalue()
assert "--embedding-prompt-template" in help_text, (
"Help text should mention --embedding-prompt-template"
)
# Check for keywords that should be in the help
help_lower = help_text.lower()
assert any(keyword in help_lower for keyword in ["template", "prompt", "prepend"]), (
"Help text should explain what the prompt template does"
)
def test_search_command_prompt_template_has_help_text(self):
"""
Verify that search command also has help text for --embedding-prompt-template.
"""
cli = LeannCLI()
parser = cli.create_parser()
import io
from contextlib import redirect_stdout
f = io.StringIO()
try:
with redirect_stdout(f):
parser.parse_args(["search", "--help"])
except SystemExit:
pass # --help causes sys.exit(0)
help_text = f.getvalue()
assert "--embedding-prompt-template" in help_text, (
"Search help text should mention --embedding-prompt-template"
)

View File

@@ -0,0 +1,281 @@
"""Unit tests for prompt template prepending in OpenAI embeddings.
This test suite defines the contract for prompt template functionality that allows
users to prepend a consistent prompt to all embedding inputs. These tests verify:
1. Template prepending to all input texts before embedding computation
2. Graceful handling of None/missing provider_options
3. Empty string template behavior (no-op)
4. Logging of template application for observability
5. Template application before token truncation
All tests are written in Red Phase - they should FAIL initially because the
implementation does not exist yet.
"""
from unittest.mock import MagicMock, Mock, patch
import numpy as np
import pytest
from leann.embedding_compute import compute_embeddings_openai
class TestPromptTemplatePrepending:
"""Tests for prompt template prepending in compute_embeddings_openai."""
@pytest.fixture
def mock_openai_client(self):
"""Create mock OpenAI client that captures input texts."""
mock_client = MagicMock()
# Mock the embeddings.create response
mock_response = Mock()
mock_response.data = [
Mock(embedding=[0.1, 0.2, 0.3]),
Mock(embedding=[0.4, 0.5, 0.6]),
]
mock_client.embeddings.create.return_value = mock_response
return mock_client
@pytest.fixture
def mock_openai_module(self, mock_openai_client, monkeypatch):
"""Mock the openai module to return our mock client."""
# Mock the API key environment variable
monkeypatch.setenv("OPENAI_API_KEY", "fake-test-key-for-mocking")
# openai is imported inside the function, so we need to patch it there
with patch("openai.OpenAI", return_value=mock_openai_client) as mock_openai:
yield mock_openai
def test_prompt_template_prepended_to_all_texts(self, mock_openai_module, mock_openai_client):
"""Verify template is prepended to all input texts.
When provider_options contains "prompt_template", that template should
be prepended to every text in the input list before sending to OpenAI API.
This is the core functionality: the template acts as a consistent prefix
that provides context or instruction for the embedding model.
"""
texts = ["First document", "Second document"]
template = "search_document: "
provider_options = {"prompt_template": template}
# Call compute_embeddings_openai with provider_options
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
# Verify embeddings.create was called with templated texts
mock_openai_client.embeddings.create.assert_called_once()
call_args = mock_openai_client.embeddings.create.call_args
# Extract the input texts sent to API
sent_texts = call_args.kwargs["input"]
# Verify template was prepended to all texts
assert len(sent_texts) == 2, "Should send same number of texts"
assert sent_texts[0] == "search_document: First document", (
"Template should be prepended to first text"
)
assert sent_texts[1] == "search_document: Second document", (
"Template should be prepended to second text"
)
# Verify result is valid embeddings array
assert isinstance(result, np.ndarray)
assert result.shape == (2, 3), "Should return correct shape"
def test_template_not_applied_when_missing_or_empty(
self, mock_openai_module, mock_openai_client
):
"""Verify template not applied when provider_options is None, missing key, or empty string.
This consolidated test covers three scenarios where templates should NOT be applied:
1. provider_options is None (default behavior)
2. provider_options exists but missing 'prompt_template' key
3. prompt_template is explicitly set to empty string ""
In all cases, texts should be sent to the API unchanged.
"""
# Scenario 1: None provider_options
texts = ["Original text one", "Original text two"]
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=None,
)
call_args = mock_openai_client.embeddings.create.call_args
sent_texts = call_args.kwargs["input"]
assert sent_texts[0] == "Original text one", (
"Text should be unchanged with None provider_options"
)
assert sent_texts[1] == "Original text two"
assert isinstance(result, np.ndarray)
assert result.shape == (2, 3)
# Reset mock for next scenario
mock_openai_client.reset_mock()
mock_response = Mock()
mock_response.data = [
Mock(embedding=[0.1, 0.2, 0.3]),
Mock(embedding=[0.4, 0.5, 0.6]),
]
mock_openai_client.embeddings.create.return_value = mock_response
# Scenario 2: Missing 'prompt_template' key
texts = ["Text without template", "Another text"]
provider_options = {"base_url": "https://api.openai.com/v1"}
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
call_args = mock_openai_client.embeddings.create.call_args
sent_texts = call_args.kwargs["input"]
assert sent_texts[0] == "Text without template", "Text should be unchanged with missing key"
assert sent_texts[1] == "Another text"
assert isinstance(result, np.ndarray)
# Reset mock for next scenario
mock_openai_client.reset_mock()
mock_openai_client.embeddings.create.return_value = mock_response
# Scenario 3: Empty string template
texts = ["Text one", "Text two"]
provider_options = {"prompt_template": ""}
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
call_args = mock_openai_client.embeddings.create.call_args
sent_texts = call_args.kwargs["input"]
assert sent_texts[0] == "Text one", "Empty template should not modify text"
assert sent_texts[1] == "Text two"
assert isinstance(result, np.ndarray)
def test_prompt_template_with_multiple_batches(self, mock_openai_module, mock_openai_client):
"""Verify template is prepended in all batches when texts exceed batch size.
OpenAI API has batch size limits. When input texts are split into
multiple batches, the template should be prepended to texts in every batch.
This ensures consistency across all API calls.
"""
# Create many texts that will be split into multiple batches
texts = [f"Document {i}" for i in range(1000)]
template = "passage: "
provider_options = {"prompt_template": template}
# Single mock response reused for every batched API call
mock_response = Mock()
mock_response.data = [Mock(embedding=[0.1, 0.2, 0.3]) for _ in range(1000)]
mock_openai_client.embeddings.create.return_value = mock_response
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
# Verify embeddings.create was called multiple times (batching)
assert mock_openai_client.embeddings.create.call_count >= 2, (
"Should make multiple API calls for large text list"
)
# Verify template was prepended in ALL batches
for call in mock_openai_client.embeddings.create.call_args_list:
sent_texts = call.kwargs["input"]
for text in sent_texts:
assert text.startswith(template), (
f"All texts in all batches should start with template. Got: {text}"
)
# Verify result shape
assert result.shape[0] == 1000, "Should return embeddings for all texts"
def test_prompt_template_with_special_characters(self, mock_openai_module, mock_openai_client):
"""Verify template with special characters is handled correctly.
Templates may contain special characters, Unicode, newlines, etc.
These should all be prepended correctly without encoding issues.
"""
texts = ["Document content"]
# Template with various special characters
template = "🔍 Search query [EN]: "
provider_options = {"prompt_template": template}
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
# Verify special characters in template were preserved
call_args = mock_openai_client.embeddings.create.call_args
sent_texts = call_args.kwargs["input"]
assert sent_texts[0] == "🔍 Search query [EN]: Document content", (
"Special characters in template should be preserved"
)
assert isinstance(result, np.ndarray)
def test_prompt_template_integration_with_existing_validation(
self, mock_openai_module, mock_openai_client
):
"""Verify template works with existing input validation.
compute_embeddings_openai has validation for empty texts and whitespace.
Template prepending should happen AFTER validation, so validation errors
are thrown based on original texts, not templated texts.
This ensures users get clear error messages about their input.
"""
# Empty text should still raise ValueError even with template
texts = [""]
provider_options = {"prompt_template": "prefix: "}
with pytest.raises(ValueError, match="empty/invalid"):
compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
def test_prompt_template_with_api_key_and_base_url(
self, mock_openai_module, mock_openai_client
):
"""Verify template works alongside other provider_options.
provider_options may contain multiple settings: prompt_template,
base_url, api_key. All should work together correctly.
"""
texts = ["Test document"]
provider_options = {
"prompt_template": "embed: ",
"base_url": "https://custom.api.com/v1",
"api_key": "test-key-123",
}
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
# Verify template was applied
call_args = mock_openai_client.embeddings.create.call_args
sent_texts = call_args.kwargs["input"]
assert sent_texts[0] == "embed: Test document"
# Verify OpenAI client was created with correct base_url
mock_openai_module.assert_called()
client_init_kwargs = mock_openai_module.call_args.kwargs
assert client_init_kwargs["base_url"] == "https://custom.api.com/v1"
assert client_init_kwargs["api_key"] == "test-key-123"
assert isinstance(result, np.ndarray)
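
The unit tests above pin down the prepending contract for compute_embeddings_openai: a non-empty prompt_template is prefixed to every input text before the API call, while None provider_options, a missing key, or an empty string leaves the texts unchanged. A minimal sketch of that step, with an assumed helper name rather than the actual implementation:

def _apply_prompt_template(texts: list[str], provider_options: dict | None) -> list[str]:
    """Sketch: prepend the configured template to every text, or pass texts through."""
    template = (provider_options or {}).get("prompt_template")
    if not template:  # covers None provider_options, missing key, and ""
        return texts
    return [f"{template}{text}" for text in texts]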

View File

@@ -0,0 +1,315 @@
"""Unit tests for LM Studio TypeScript SDK bridge functionality.
This test suite defines the contract for the LM Studio SDK bridge that queries
model context length via Node.js subprocess. These tests verify:
1. Successful SDK query returns context length
2. Graceful fallback when Node.js not installed (FileNotFoundError)
3. Graceful fallback when SDK not installed (npm error)
4. Timeout handling (subprocess.TimeoutExpired)
5. Invalid JSON response handling
All tests are written in Red Phase - they should FAIL initially because the
`_query_lmstudio_context_limit` function does not exist yet.
The function contract:
- Inputs: model_name (str), base_url (str, WebSocket format "ws://localhost:1234")
- Outputs: context_length (int) or None on error
- Requirements:
1. Call Node.js with inline JavaScript using @lmstudio/sdk
2. 10-second timeout (accounts for Node.js startup)
3. Graceful fallback on any error (returns None, doesn't raise)
4. Parse JSON response with contextLength field
5. Log errors at debug level (not warning/error)
"""
import subprocess
from unittest.mock import Mock
import pytest
# Try to import the function - if it doesn't exist, tests will fail as expected
try:
from leann.embedding_compute import _query_lmstudio_context_limit
except ImportError:
# Function doesn't exist yet (Red Phase) - create a placeholder that will fail
def _query_lmstudio_context_limit(*args, **kwargs):
raise NotImplementedError(
"_query_lmstudio_context_limit not implemented yet - this is the Red Phase"
)
class TestLMStudioBridge:
"""Tests for LM Studio TypeScript SDK bridge integration."""
def test_query_lmstudio_success(self, monkeypatch):
"""Verify successful SDK query returns context length.
When the Node.js subprocess successfully queries the LM Studio SDK,
it should return a JSON response with contextLength field. The function
should parse this and return the integer context length.
"""
def mock_run(*args, **kwargs):
# Verify timeout is set to 10 seconds
assert kwargs.get("timeout") == 10, "Should use 10-second timeout for Node.js startup"
# Verify capture_output and text=True are set
assert kwargs.get("capture_output") is True, "Should capture stdout/stderr"
assert kwargs.get("text") is True, "Should decode output as text"
# Return successful JSON response
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = '{"contextLength": 8192, "identifier": "custom-model"}'
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
# Test with typical LM Studio model
limit = _query_lmstudio_context_limit(
model_name="custom-model", base_url="ws://localhost:1234"
)
assert limit == 8192, "Should return context length from SDK response"
def test_query_lmstudio_nodejs_not_found(self, monkeypatch):
"""Verify graceful fallback when Node.js not installed.
When Node.js is not installed, subprocess.run will raise FileNotFoundError.
The function should catch this and return None (graceful fallback to registry).
"""
def mock_run(*args, **kwargs):
raise FileNotFoundError("node: command not found")
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="custom-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None when Node.js not installed"
def test_query_lmstudio_sdk_not_installed(self, monkeypatch):
"""Verify graceful fallback when @lmstudio/sdk not installed.
When the SDK npm package is not installed, Node.js will return non-zero
exit code with error message in stderr. The function should detect this
and return None.
"""
def mock_run(*args, **kwargs):
mock_result = Mock()
mock_result.returncode = 1
mock_result.stdout = ""
mock_result.stderr = (
"Error: Cannot find module '@lmstudio/sdk'\nRequire stack:\n- /path/to/script.js"
)
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="custom-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None when SDK not installed"
def test_query_lmstudio_timeout(self, monkeypatch):
"""Verify graceful fallback when subprocess times out.
When the Node.js process takes longer than 10 seconds (e.g., LM Studio
not responding), subprocess.TimeoutExpired should be raised. The function
should catch this and return None.
"""
def mock_run(*args, **kwargs):
raise subprocess.TimeoutExpired(cmd=["node", "lmstudio_bridge.js"], timeout=10)
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="custom-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None on timeout"
def test_query_lmstudio_invalid_json(self, monkeypatch):
"""Verify graceful fallback when response is invalid JSON.
When the subprocess returns malformed JSON (e.g., due to SDK error),
json.loads will raise ValueError/JSONDecodeError. The function should
catch this and return None.
"""
def mock_run(*args, **kwargs):
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = "This is not valid JSON"
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="custom-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None when JSON parsing fails"
def test_query_lmstudio_missing_context_length_field(self, monkeypatch):
"""Verify graceful fallback when JSON lacks contextLength field.
When the SDK returns valid JSON but without the expected contextLength
field (e.g., error response), the function should return None.
"""
def mock_run(*args, **kwargs):
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = '{"identifier": "test-model", "error": "Model not found"}'
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="nonexistent-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None when contextLength field missing"
def test_query_lmstudio_null_context_length(self, monkeypatch):
"""Verify graceful fallback when contextLength is null.
When the SDK returns contextLength: null (model couldn't be loaded),
the function should return None for registry fallback.
"""
def mock_run(*args, **kwargs):
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = '{"contextLength": null, "identifier": "test-model"}'
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="test-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None when contextLength is null"
def test_query_lmstudio_zero_context_length(self, monkeypatch):
"""Verify graceful fallback when contextLength is zero.
When the SDK returns contextLength: 0 (invalid value), the function
should return None to trigger registry fallback.
"""
def mock_run(*args, **kwargs):
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = '{"contextLength": 0, "identifier": "test-model"}'
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="test-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None when contextLength is zero"
def test_query_lmstudio_with_custom_port(self, monkeypatch):
"""Verify SDK query works with non-default WebSocket port.
LM Studio can run on custom ports. The function should pass the
provided base_url to the Node.js subprocess.
"""
def mock_run(*args, **kwargs):
# Verify the base_url argument is passed correctly
command = args[0] if args else kwargs.get("args", [])
assert "ws://localhost:8080" in " ".join(command), (
"Should pass custom port to subprocess"
)
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = '{"contextLength": 4096, "identifier": "custom-model"}'
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="custom-model", base_url="ws://localhost:8080"
)
assert limit == 4096, "Should work with custom WebSocket port"
@pytest.mark.parametrize(
"context_length,expected",
[
(512, 512), # Small context
(2048, 2048), # Common context
(8192, 8192), # Large context
(32768, 32768), # Very large context
],
)
def test_query_lmstudio_various_context_lengths(self, monkeypatch, context_length, expected):
"""Verify SDK query handles various context length values.
Different models have different context lengths. The function should
correctly parse and return any positive integer value.
"""
def mock_run(*args, **kwargs):
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = f'{{"contextLength": {context_length}, "identifier": "test"}}'
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="test-model", base_url="ws://localhost:1234"
)
assert limit == expected, f"Should return {expected} for context length {context_length}"
def test_query_lmstudio_logs_at_debug_level(self, monkeypatch, caplog):
"""Verify errors are logged at DEBUG level, not WARNING/ERROR.
Following the graceful fallback pattern from Ollama implementation,
errors should be logged at debug level to avoid alarming users when
fallback to registry works fine.
"""
import logging
caplog.set_level(logging.DEBUG, logger="leann.embedding_compute")
def mock_run(*args, **kwargs):
raise FileNotFoundError("node: command not found")
monkeypatch.setattr("subprocess.run", mock_run)
_query_lmstudio_context_limit(model_name="test-model", base_url="ws://localhost:1234")
# Check that debug logging occurred (not warning/error)
debug_logs = [record for record in caplog.records if record.levelname == "DEBUG"]
assert len(debug_logs) > 0, "Should log error at DEBUG level"
# Verify no WARNING or ERROR logs
warning_or_error_logs = [
record for record in caplog.records if record.levelname in ["WARNING", "ERROR"]
]
assert len(warning_or_error_logs) == 0, (
"Should not log at WARNING/ERROR level for expected failures"
)
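
The bridge tests above describe the full contract for _query_lmstudio_context_limit: call Node.js with a 10-second timeout, parse a JSON contextLength field, and fall back to None with debug-level logging on any failure. A sketch consistent with those expectations (the inline JavaScript payload is elided and the structure is an assumption, not the shipped function):

import json
import logging
import subprocess

logger = logging.getLogger(__name__)

def _query_lmstudio_context_limit(model_name: str, base_url: str) -> int | None:
    """Sketch: ask @lmstudio/sdk (via a Node.js subprocess) for a model's context length."""
    script = "..."  # inline JavaScript using @lmstudio/sdk, omitted here
    try:
        result = subprocess.run(
            ["node", "-e", script, base_url, model_name],
            capture_output=True,
            text=True,
            timeout=10,  # generous enough for Node.js startup
        )
        if result.returncode != 0:
            logger.debug("LM Studio SDK bridge failed: %s", result.stderr)
            return None
        context_length = json.loads(result.stdout).get("contextLength")
        # Missing, null, or zero values all trigger the registry fallback
        return context_length if context_length else None
    except (FileNotFoundError, subprocess.TimeoutExpired, ValueError) as exc:
        logger.debug("LM Studio SDK bridge unavailable: %s", exc)
        return None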

View File

@@ -0,0 +1,400 @@
"""End-to-end integration tests for prompt template and token limit features.
These tests verify real-world functionality with live services:
- OpenAI-compatible APIs (OpenAI, LM Studio) with prompt template support
- Ollama with dynamic token limit detection
- Hybrid token limit discovery mechanism
Run with: pytest tests/test_prompt_template_e2e.py -v -s
Skip if services unavailable: pytest tests/test_prompt_template_e2e.py -m "not integration"
Prerequisites:
1. LM Studio running with embedding model: http://localhost:1234
2. [Optional] Ollama running: ollama serve
3. [Optional] Ollama model: ollama pull nomic-embed-text
4. [Optional] Node.js + @lmstudio/sdk for context length detection
"""
import logging
import socket
import numpy as np
import pytest
import requests
from leann.embedding_compute import (
compute_embeddings_ollama,
compute_embeddings_openai,
get_model_token_limit,
)
# Test markers for conditional execution
pytestmark = pytest.mark.integration
logger = logging.getLogger(__name__)
def check_service_available(host: str, port: int, timeout: float = 2.0) -> bool:
"""Check if a service is available on the given host:port."""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
result = sock.connect_ex((host, port))
sock.close()
return result == 0
except Exception:
return False
def check_ollama_available() -> bool:
"""Check if Ollama service is available."""
if not check_service_available("localhost", 11434):
return False
try:
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
return response.status_code == 200
except Exception:
return False
def check_lmstudio_available() -> bool:
"""Check if LM Studio service is available."""
if not check_service_available("localhost", 1234):
return False
try:
response = requests.get("http://localhost:1234/v1/models", timeout=2.0)
return response.status_code == 200
except Exception:
return False
def get_lmstudio_first_model() -> str | None:
"""Get the first available model from LM Studio."""
try:
response = requests.get("http://localhost:1234/v1/models", timeout=5.0)
data = response.json()
models = data.get("data", [])
if models:
return models[0]["id"]
except Exception:
pass
return None
class TestPromptTemplateOpenAI:
"""End-to-end tests for prompt template with OpenAI-compatible APIs (LM Studio)."""
@pytest.mark.skipif(
not check_lmstudio_available(), reason="LM Studio service not available on localhost:1234"
)
def test_lmstudio_embedding_with_prompt_template(self):
"""Test prompt templates with LM Studio using OpenAI-compatible API."""
model_name = get_lmstudio_first_model()
if not model_name:
pytest.skip("No models loaded in LM Studio")
texts = ["artificial intelligence", "machine learning"]
prompt_template = "search_query: "
# Get embeddings with prompt template via provider_options
provider_options = {"prompt_template": prompt_template}
embeddings = compute_embeddings_openai(
texts=texts,
model_name=model_name,
base_url="http://localhost:1234/v1",
api_key="lm-studio", # LM Studio doesn't require real key
provider_options=provider_options,
)
assert embeddings is not None
assert len(embeddings) == 2
assert all(isinstance(emb, np.ndarray) for emb in embeddings)
assert all(len(emb) > 0 for emb in embeddings)
logger.info(
f"✓ LM Studio embeddings with prompt template: {len(embeddings)} vectors, {len(embeddings[0])} dimensions"
)
@pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
def test_lmstudio_prompt_template_affects_embeddings(self):
"""Verify that prompt templates actually change embedding values."""
model_name = get_lmstudio_first_model()
if not model_name:
pytest.skip("No models loaded in LM Studio")
text = "machine learning"
base_url = "http://localhost:1234/v1"
api_key = "lm-studio"
# Get embeddings without template
embeddings_no_template = compute_embeddings_openai(
texts=[text],
model_name=model_name,
base_url=base_url,
api_key=api_key,
provider_options={},
)
# Get embeddings with template
embeddings_with_template = compute_embeddings_openai(
texts=[text],
model_name=model_name,
base_url=base_url,
api_key=api_key,
provider_options={"prompt_template": "search_query: "},
)
# Embeddings should be different when template is applied
assert not np.allclose(embeddings_no_template[0], embeddings_with_template[0])
logger.info("✓ Prompt template changes embedding values as expected")
class TestPromptTemplateOllama:
"""End-to-end tests for prompt template with Ollama."""
@pytest.mark.skipif(
not check_ollama_available(), reason="Ollama service not available on localhost:11434"
)
def test_ollama_embedding_with_prompt_template(self):
"""Test prompt templates with Ollama using any available embedding model."""
# Get any available embedding model
try:
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
models = response.json().get("models", [])
embedding_models = []
for model in models:
name = model["name"]
base_name = name.split(":")[0]
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
embedding_models.append(name)
if not embedding_models:
pytest.skip("No embedding models available in Ollama")
model_name = embedding_models[0]
texts = ["artificial intelligence", "machine learning"]
prompt_template = "search_query: "
# Get embeddings with prompt template via provider_options
provider_options = {"prompt_template": prompt_template}
embeddings = compute_embeddings_ollama(
texts=texts,
model_name=model_name,
is_build=False,
host="http://localhost:11434",
provider_options=provider_options,
)
assert embeddings is not None
assert len(embeddings) == 2
assert all(isinstance(emb, np.ndarray) for emb in embeddings)
assert all(len(emb) > 0 for emb in embeddings)
logger.info(
f"✓ Ollama embeddings with prompt template: {len(embeddings)} vectors, {len(embeddings[0])} dimensions"
)
except Exception as e:
pytest.skip(f"Could not test Ollama prompt template: {e}")
@pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
def test_ollama_prompt_template_affects_embeddings(self):
"""Verify that prompt templates actually change embedding values with Ollama."""
# Get any available embedding model
try:
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
models = response.json().get("models", [])
embedding_models = []
for model in models:
name = model["name"]
base_name = name.split(":")[0]
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
embedding_models.append(name)
if not embedding_models:
pytest.skip("No embedding models available in Ollama")
model_name = embedding_models[0]
text = "machine learning"
host = "http://localhost:11434"
# Get embeddings without template
embeddings_no_template = compute_embeddings_ollama(
texts=[text], model_name=model_name, is_build=False, host=host, provider_options={}
)
# Get embeddings with template
embeddings_with_template = compute_embeddings_ollama(
texts=[text],
model_name=model_name,
is_build=False,
host=host,
provider_options={"prompt_template": "search_query: "},
)
# Embeddings should be different when template is applied
assert not np.allclose(embeddings_no_template[0], embeddings_with_template[0])
logger.info("✓ Ollama prompt template changes embedding values as expected")
except Exception as e:
pytest.skip(f"Could not test Ollama prompt template: {e}")
class TestLMStudioSDK:
"""End-to-end tests for LM Studio SDK integration."""
@pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
def test_lmstudio_model_listing(self):
"""Test that we can list models from LM Studio."""
try:
response = requests.get("http://localhost:1234/v1/models", timeout=5.0)
assert response.status_code == 200
data = response.json()
assert "data" in data
models = data["data"]
logger.info(f"✓ LM Studio models available: {len(models)}")
if models:
logger.info(f" First model: {models[0].get('id', 'unknown')}")
except Exception as e:
pytest.skip(f"LM Studio API error: {e}")
@pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
def test_lmstudio_sdk_context_length_detection(self):
"""Test context length detection via LM Studio SDK bridge (requires Node.js + SDK)."""
model_name = get_lmstudio_first_model()
if not model_name:
pytest.skip("No models loaded in LM Studio")
try:
from leann.embedding_compute import _query_lmstudio_context_limit
# SDK requires WebSocket URL (ws://)
context_length = _query_lmstudio_context_limit(
model_name=model_name, base_url="ws://localhost:1234"
)
if context_length is None:
logger.warning(
"⚠ LM Studio SDK bridge returned None (Node.js or SDK may not be available)"
)
pytest.skip("Node.js or @lmstudio/sdk not available - SDK bridge unavailable")
else:
assert context_length > 0
logger.info(
f"✓ LM Studio context length detected via SDK: {context_length} for {model_name}"
)
except ImportError:
pytest.skip("_query_lmstudio_context_limit not implemented yet")
except Exception as e:
logger.error(f"LM Studio SDK test error: {e}")
raise
class TestOllamaTokenLimit:
"""End-to-end tests for Ollama token limit discovery."""
@pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
def test_ollama_token_limit_detection(self):
"""Test dynamic token limit detection from Ollama /api/show endpoint."""
# Get any available embedding model
try:
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
models = response.json().get("models", [])
embedding_models = []
for model in models:
name = model["name"]
base_name = name.split(":")[0]
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
embedding_models.append(name)
if not embedding_models:
pytest.skip("No embedding models available in Ollama")
test_model = embedding_models[0]
# Test token limit detection
limit = get_model_token_limit(model_name=test_model, base_url="http://localhost:11434")
assert limit > 0
logger.info(f"✓ Ollama token limit detected: {limit} for {test_model}")
except Exception as e:
pytest.skip(f"Could not test Ollama token detection: {e}")
class TestHybridTokenLimit:
"""End-to-end tests for hybrid token limit discovery mechanism."""
def test_hybrid_discovery_registry_fallback(self):
"""Test fallback to static registry for known OpenAI models."""
# Use a known OpenAI model (should be in registry)
limit = get_model_token_limit(
model_name="text-embedding-3-small",
base_url="http://fake-server:9999", # Fake URL to force registry lookup
)
# text-embedding-3-small should have 8192 in registry
assert limit == 8192
logger.info(f"✓ Hybrid discovery (registry fallback): {limit} tokens")
def test_hybrid_discovery_default_fallback(self):
"""Test fallback to safe default for completely unknown models."""
limit = get_model_token_limit(
model_name="completely-unknown-model-xyz-12345",
base_url="http://fake-server:9999",
default=512,
)
# Should get the specified default
assert limit == 512
logger.info(f"✓ Hybrid discovery (default fallback): {limit} tokens")
@pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
def test_hybrid_discovery_ollama_dynamic_first(self):
"""Test that Ollama models use dynamic discovery first."""
# Get any available embedding model
try:
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
models = response.json().get("models", [])
embedding_models = []
for model in models:
name = model["name"]
base_name = name.split(":")[0]
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
embedding_models.append(name)
if not embedding_models:
pytest.skip("No embedding models available in Ollama")
test_model = embedding_models[0]
# Should query Ollama /api/show dynamically
limit = get_model_token_limit(model_name=test_model, base_url="http://localhost:11434")
assert limit > 0
logger.info(f"✓ Hybrid discovery (Ollama dynamic): {limit} tokens for {test_model}")
except Exception as e:
pytest.skip(f"Could not test hybrid Ollama discovery: {e}")
if __name__ == "__main__":
print("\n" + "=" * 70)
print("INTEGRATION TEST SUITE - Real Service Testing")
print("=" * 70)
print("\nThese tests require live services:")
print(" • LM Studio: http://localhost:1234 (with embedding model loaded)")
print(" • [Optional] Ollama: http://localhost:11434")
print(" • [Optional] Node.js + @lmstudio/sdk for SDK bridge tests")
print("\nRun with: pytest tests/test_prompt_template_e2e.py -v -s")
print("=" * 70 + "\n")

View File

@@ -0,0 +1,808 @@
"""
Integration tests for prompt template metadata persistence and reuse.
These tests verify the complete lifecycle of prompt template persistence:
1. Template is saved to .meta.json during index build
2. Template is automatically loaded during search operations
3. Template can be overridden with explicit flag during search
4. Template is reused during chat/ask operations
These are integration tests that:
- Use real file system with temporary directories
- Run actual build and search operations
- Inspect .meta.json file contents directly
- Mock embedding servers to avoid external dependencies
- Use small test codebases for fast execution
Expected to FAIL in Red Phase because metadata persistence verification is not yet implemented.
"""
import json
import tempfile
from pathlib import Path
from unittest.mock import Mock, patch
import numpy as np
import pytest
from leann.api import LeannBuilder, LeannSearcher
class TestPromptTemplateMetadataPersistence:
"""Tests for prompt template storage in .meta.json during build."""
@pytest.fixture
def temp_index_dir(self):
"""Create temporary directory for test indexes."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
@pytest.fixture
def mock_embeddings(self):
"""Mock compute_embeddings to return dummy embeddings."""
with patch("leann.api.compute_embeddings") as mock_compute:
# Return dummy embeddings as numpy array
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
yield mock_compute
def test_prompt_template_saved_to_metadata(self, temp_index_dir, mock_embeddings):
"""
Verify that when build is run with embedding_options containing prompt_template,
the template value is saved to .meta.json file.
This is the core persistence requirement - templates must be saved to allow
reuse in subsequent search operations without re-specifying the flag.
Expected failure: .meta.json exists but doesn't contain embedding_options
with prompt_template, or the value is not persisted correctly.
"""
# Setup test data
index_path = temp_index_dir / "test_index.leann"
template = "search_document: "
# Build index with prompt template in embedding_options
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
embedding_options={"prompt_template": template},
)
# Add a simple document
builder.add_text("This is a test document for indexing")
# Build the index
builder.build_index(str(index_path))
# Verify .meta.json was created and contains the template
meta_path = temp_index_dir / "test_index.leann.meta.json"
assert meta_path.exists(), ".meta.json file should be created during build"
# Read and parse metadata
with open(meta_path, encoding="utf-8") as f:
meta_data = json.load(f)
# Verify embedding_options exists in metadata
assert "embedding_options" in meta_data, (
"embedding_options should be saved to .meta.json when provided"
)
# Verify prompt_template is in embedding_options
embedding_options = meta_data["embedding_options"]
assert "prompt_template" in embedding_options, (
"prompt_template should be saved within embedding_options"
)
# Verify the template value matches what we provided
assert embedding_options["prompt_template"] == template, (
f"Template should be '{template}', got '{embedding_options.get('prompt_template')}'"
)
def test_prompt_template_absent_when_not_provided(self, temp_index_dir, mock_embeddings):
"""
Verify that when no prompt template is provided during build,
.meta.json either doesn't have embedding_options or prompt_template key.
This ensures clean metadata without unnecessary keys when features aren't used.
Expected behavior: Build succeeds, .meta.json doesn't contain prompt_template.
"""
index_path = temp_index_dir / "test_no_template.leann"
# Build index WITHOUT prompt template
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
# No embedding_options provided
)
builder.add_text("Document without template")
builder.build_index(str(index_path))
# Verify metadata
meta_path = temp_index_dir / "test_no_template.leann.meta.json"
assert meta_path.exists()
with open(meta_path, encoding="utf-8") as f:
meta_data = json.load(f)
# If embedding_options exists, it should not contain prompt_template
if "embedding_options" in meta_data:
embedding_options = meta_data["embedding_options"]
assert "prompt_template" not in embedding_options, (
"prompt_template should not be in metadata when not provided"
)
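# For reference, the shape of .meta.json that the persistence tests above read back
# (values are illustrative, taken from the test fixtures; unrelated keys are omitted):
META_JSON_SKETCH = {
    "backend_name": "hnsw",
    "embedding_model": "text-embedding-3-small",
    "embedding_mode": "openai",
    "embedding_options": {"prompt_template": "search_document: "},
}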
class TestPromptTemplateAutoLoadOnSearch:
"""Tests for automatic loading of prompt template during search operations.
NOTE: Over-mocked test removed (test_prompt_template_auto_loaded_on_search).
This functionality is now comprehensively tested by TestQueryPromptTemplateAutoLoad
which uses simpler mocking and doesn't hang.
"""
@pytest.fixture
def temp_index_dir(self):
"""Create temporary directory for test indexes."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
@pytest.fixture
def mock_embeddings(self):
"""Mock compute_embeddings to capture calls and return dummy embeddings."""
with patch("leann.api.compute_embeddings") as mock_compute:
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
yield mock_compute
def test_search_without_template_in_metadata(self, temp_index_dir, mock_embeddings):
"""
Verify that searching an index built WITHOUT a prompt template
works correctly (backward compatibility).
The searcher should handle missing prompt_template gracefully.
Expected behavior: Search succeeds, no template is used.
"""
# Build index without template
index_path = temp_index_dir / "no_template.leann"
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
)
builder.add_text("Document without template")
builder.build_index(str(index_path))
# Reset mocks
mock_embeddings.reset_mock()
# Create searcher and search
searcher = LeannSearcher(index_path=str(index_path))
# Verify no template in embedding_options
assert "prompt_template" not in searcher.embedding_options, (
"Searcher should not have prompt_template when not in metadata"
)
class TestQueryPromptTemplateAutoLoad:
"""Tests for automatic loading of separate query_prompt_template during search (R2).
These tests verify the new two-template system where:
- build_prompt_template: Applied during index building
- query_prompt_template: Applied during search operations
Expected to FAIL in Red Phase (R2) because query template extraction
and application is not yet implemented in LeannSearcher.search().
"""
@pytest.fixture
def temp_index_dir(self):
"""Create temporary directory for test indexes."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
@pytest.fixture
def mock_compute_embeddings(self):
"""Mock compute_embeddings to capture calls and return dummy embeddings."""
with patch("leann.embedding_compute.compute_embeddings") as mock_compute:
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
yield mock_compute
def test_search_auto_loads_query_template(self, temp_index_dir, mock_compute_embeddings):
"""
Verify that search() automatically loads and applies query_prompt_template from .meta.json.
Given: Index built with separate build_prompt_template and query_prompt_template
When: LeannSearcher.search("my query") is called
Then: Query embedding is computed with "query: my query" (query template applied)
This is the core R2 requirement - query templates must be auto-loaded and applied
during search without user intervention.
Expected failure: compute_embeddings called with raw "my query" instead of
"query: my query" because query template extraction is not implemented.
"""
# Setup: Build index with separate templates in new format
index_path = temp_index_dir / "query_template.leann"
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
embedding_options={
"build_prompt_template": "doc: ",
"query_prompt_template": "query: ",
},
)
builder.add_text("Test document")
builder.build_index(str(index_path))
# Reset mock to ignore build calls
mock_compute_embeddings.reset_mock()
# Act: Search with query
searcher = LeannSearcher(index_path=str(index_path))
# Mock the backend search to avoid actual search
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
mock_backend_search.return_value = {
"labels": [["test_id_0"]], # IDs (nested list for batch support)
"distances": [[0.9]], # Distances (nested list for batch support)
}
searcher.search("my query", top_k=1, recompute_embeddings=False)
# Assert: compute_embeddings was called with query template applied
assert mock_compute_embeddings.called, "compute_embeddings should be called during search"
# Get the actual text passed to compute_embeddings
call_args = mock_compute_embeddings.call_args
texts_arg = call_args[0][0] # First positional arg (list of texts)
assert len(texts_arg) == 1, "Should compute embedding for one query"
assert texts_arg[0] == "query: my query", (
f"Query template should be applied: expected 'query: my query', got '{texts_arg[0]}'"
)
def test_search_backward_compat_single_template(self, temp_index_dir, mock_compute_embeddings):
"""
Verify backward compatibility with old single prompt_template format.
Given: Index with old format (single prompt_template, no query_prompt_template)
When: LeannSearcher.search("my query") is called
Then: Query embedding computed with "doc: my query" (old template applied)
This ensures indexes built with the old single-template system continue
to work correctly with the new search implementation.
Expected failure: Old template not recognized/applied because backward
compatibility logic is not implemented.
"""
# Setup: Build index with old single-template format
index_path = temp_index_dir / "old_template.leann"
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
embedding_options={"prompt_template": "doc: "}, # Old format
)
builder.add_text("Test document")
builder.build_index(str(index_path))
# Reset mock
mock_compute_embeddings.reset_mock()
# Act: Search
searcher = LeannSearcher(index_path=str(index_path))
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
searcher.search("my query", top_k=1, recompute_embeddings=False)
# Assert: Old template was applied
call_args = mock_compute_embeddings.call_args
texts_arg = call_args[0][0]
assert texts_arg[0] == "doc: my query", (
f"Old prompt_template should be applied for backward compatibility: "
f"expected 'doc: my query', got '{texts_arg[0]}'"
)
def test_search_backward_compat_no_template(self, temp_index_dir, mock_compute_embeddings):
"""
Verify backward compatibility when no template is present in .meta.json.
Given: Index with no template in .meta.json (very old indexes)
When: LeannSearcher.search("my query") is called
Then: Query embedding computed with "my query" (no template, raw query)
This ensures the most basic backward compatibility - indexes without
any template support continue to work as before.
Expected failure: May fail if default template is incorrectly applied,
or if missing template causes error.
"""
# Setup: Build index without any template
index_path = temp_index_dir / "no_template.leann"
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
# No embedding_options at all
)
builder.add_text("Test document")
builder.build_index(str(index_path))
# Reset mock
mock_compute_embeddings.reset_mock()
# Act: Search
searcher = LeannSearcher(index_path=str(index_path))
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
searcher.search("my query", top_k=1, recompute_embeddings=False)
# Assert: No template applied (raw query)
call_args = mock_compute_embeddings.call_args
texts_arg = call_args[0][0]
assert texts_arg[0] == "my query", (
f"No template should be applied when missing from metadata: "
f"expected 'my query', got '{texts_arg[0]}'"
)
def test_search_override_via_provider_options(self, temp_index_dir, mock_compute_embeddings):
"""
Verify that explicit provider_options can override stored query template.
Given: Index with query_prompt_template: "query: "
When: search() called with provider_options={"prompt_template": "override: "}
Then: Query embedding computed with "override: test" (override takes precedence)
This enables users to experiment with different query templates without
rebuilding the index, or to handle special query types differently.
Expected failure: provider_options parameter is accepted via **kwargs but
not used. Query embedding computed with raw "test" instead of "override: test"
because override logic is not implemented.
"""
# Setup: Build index with query template
index_path = temp_index_dir / "override_template.leann"
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
embedding_options={
"build_prompt_template": "doc: ",
"query_prompt_template": "query: ",
},
)
builder.add_text("Test document")
builder.build_index(str(index_path))
# Reset mock
mock_compute_embeddings.reset_mock()
# Act: Search with override
searcher = LeannSearcher(index_path=str(index_path))
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
# This should accept provider_options parameter
searcher.search(
"test",
top_k=1,
recompute_embeddings=False,
provider_options={"prompt_template": "override: "},
)
# Assert: Override template was applied
call_args = mock_compute_embeddings.call_args
texts_arg = call_args[0][0]
assert texts_arg[0] == "override: test", (
f"Override template should take precedence: "
f"expected 'override: test', got '{texts_arg[0]}'"
)
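# The R2 auto-load and override tests above fix a precedence order for the template
# applied to query text. A minimal sketch of that resolution; the helper name and
# structure are assumptions, not the actual LeannSearcher internals:
def _resolve_query_template_sketch(
    embedding_options: dict | None, provider_options: dict | None
) -> str:
    # An explicit override passed to search() wins over anything stored in .meta.json
    if provider_options and provider_options.get("prompt_template"):
        return provider_options["prompt_template"]
    opts = embedding_options or {}
    # New two-template key first, then the legacy single-template key, else no prefix
    return opts.get("query_prompt_template") or opts.get("prompt_template") or ""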
class TestPromptTemplateReuseInChat:
"""Tests for prompt template reuse in chat/ask operations."""
@pytest.fixture
def temp_index_dir(self):
"""Create temporary directory for test indexes."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
@pytest.fixture
def mock_embeddings(self):
"""Mock compute_embeddings to return dummy embeddings."""
with patch("leann.api.compute_embeddings") as mock_compute:
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
yield mock_compute
@pytest.fixture
def mock_embedding_server_manager(self):
"""Mock EmbeddingServerManager for chat tests."""
with patch("leann.searcher_base.EmbeddingServerManager") as mock_manager_class:
mock_manager = Mock()
mock_manager.start_server.return_value = (True, 5557)
mock_manager_class.return_value = mock_manager
yield mock_manager
@pytest.fixture
def index_with_template(self, temp_index_dir, mock_embeddings):
"""Build an index with a prompt template."""
index_path = temp_index_dir / "chat_template_index.leann"
template = "document_query: "
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
embedding_options={"prompt_template": template},
)
builder.add_text("Test document for chat")
builder.build_index(str(index_path))
return str(index_path), template
class TestPromptTemplateIntegrationWithEmbeddingModes:
"""Tests for prompt template compatibility with different embedding modes."""
@pytest.fixture
def temp_index_dir(self):
"""Create temporary directory for test indexes."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
@pytest.mark.parametrize(
"mode,model,template,filename_prefix",
[
(
"openai",
"text-embedding-3-small",
"Represent this for searching: ",
"openai_template",
),
("ollama", "nomic-embed-text", "search_query: ", "ollama_template"),
("sentence-transformers", "facebook/contriever", "query: ", "st_template"),
],
)
def test_prompt_template_metadata_with_embedding_modes(
self, temp_index_dir, mode, model, template, filename_prefix
):
"""Verify prompt template is saved correctly across different embedding modes.
Tests that prompt templates are persisted to .meta.json for:
- OpenAI mode (primary use case)
- Ollama mode (also supports templates)
- Sentence-transformers mode (saved for forward compatibility)
Expected behavior: Template is saved to .meta.json regardless of mode.
"""
with patch("leann.api.compute_embeddings") as mock_compute:
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
index_path = temp_index_dir / f"{filename_prefix}.leann"
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=model,
embedding_mode=mode,
embedding_options={"prompt_template": template},
)
builder.add_text(f"{mode.capitalize()} test document")
builder.build_index(str(index_path))
# Verify metadata
meta_path = temp_index_dir / f"{filename_prefix}.leann.meta.json"
with open(meta_path, encoding="utf-8") as f:
meta_data = json.load(f)
assert meta_data["embedding_mode"] == mode
# Template should be saved for all modes (even if not used by some)
if "embedding_options" in meta_data:
assert meta_data["embedding_options"]["prompt_template"] == template
class TestQueryTemplateApplicationInComputeEmbedding:
"""Tests for query template application in compute_query_embedding() (Bug Fix).
These tests verify that query templates are applied consistently in BOTH
code paths (server and fallback) when computing query embeddings.
This addresses the bug where query templates were only applied in the
fallback path, not when using the embedding server (the default path).
Bug Context:
- Issue: Query templates were stored in metadata but only applied during
fallback (direct) computation, not when using embedding server
- Fix: Move template application to BEFORE any computation path in
compute_query_embedding() (searcher_base.py:107-110)
- Impact: Critical for models like EmbeddingGemma that require task-specific
templates for optimal performance
These tests ensure the fix works correctly and prevent regression.
"""
@pytest.fixture
def temp_index_with_template(self):
"""Create a temporary index with query template in metadata"""
with tempfile.TemporaryDirectory() as tmpdir:
index_dir = Path(tmpdir)
index_file = index_dir / "test.leann"
meta_file = index_dir / "test.leann.meta.json"
# Create minimal metadata with query template
metadata = {
"version": "1.0",
"backend_name": "hnsw",
"embedding_model": "text-embedding-embeddinggemma-300m-qat",
"dimensions": 768,
"embedding_mode": "openai",
"backend_kwargs": {
"graph_degree": 32,
"complexity": 64,
"distance_metric": "cosine",
},
"embedding_options": {
"base_url": "http://localhost:1234/v1",
"api_key": "test-key",
"build_prompt_template": "title: none | text: ",
"query_prompt_template": "task: search result | query: ",
},
}
meta_file.write_text(json.dumps(metadata, indent=2))
# Create minimal HNSW index file (empty is okay for this test)
index_file.write_bytes(b"")
yield str(index_file)
def test_query_template_applied_in_fallback_path(self, temp_index_with_template):
"""Test that query template is applied when using fallback (direct) path"""
from leann.searcher_base import BaseSearcher
# Create a concrete implementation for testing
class TestSearcher(BaseSearcher):
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)
searcher.index_path = Path(temp_index_with_template)
searcher.index_dir = searcher.index_path.parent
# Load metadata
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
with open(meta_file) as f:
searcher.meta = json.load(f)
searcher.embedding_model = searcher.meta["embedding_model"]
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
searcher.embedding_options = searcher.meta.get("embedding_options", {})
# Mock compute_embeddings to capture the query text
captured_queries = []
def mock_compute_embeddings(texts, model, mode, provider_options=None):
captured_queries.extend(texts)
return np.random.rand(len(texts), 768).astype(np.float32)
with patch(
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
):
# Call compute_query_embedding with template (fallback path)
result = searcher.compute_query_embedding(
query="vector database",
use_server_if_available=False, # Force fallback path
query_template="task: search result | query: ",
)
# Verify template was applied
assert len(captured_queries) == 1
assert captured_queries[0] == "task: search result | query: vector database"
assert result.shape == (1, 768)
def test_query_template_applied_in_server_path(self, temp_index_with_template):
"""Test that query template is applied when using server path"""
from leann.searcher_base import BaseSearcher
# Create a concrete implementation for testing
class TestSearcher(BaseSearcher):
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)
searcher.index_path = Path(temp_index_with_template)
searcher.index_dir = searcher.index_path.parent
# Load metadata
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
with open(meta_file) as f:
searcher.meta = json.load(f)
searcher.embedding_model = searcher.meta["embedding_model"]
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
searcher.embedding_options = searcher.meta.get("embedding_options", {})
# Mock the server methods to capture the query text
captured_queries = []
def mock_ensure_server_running(passages_file, port):
return port
def mock_compute_embedding_via_server(chunks, port):
captured_queries.extend(chunks)
return np.random.rand(len(chunks), 768).astype(np.float32)
searcher._ensure_server_running = mock_ensure_server_running
searcher._compute_embedding_via_server = mock_compute_embedding_via_server
# Call compute_query_embedding with template (server path)
result = searcher.compute_query_embedding(
query="vector database",
use_server_if_available=True, # Use server path
query_template="task: search result | query: ",
)
# Verify template was applied BEFORE calling server
assert len(captured_queries) == 1
assert captured_queries[0] == "task: search result | query: vector database"
assert result.shape == (1, 768)
def test_query_template_without_template_parameter(self, temp_index_with_template):
"""Test that query is unchanged when no template is provided"""
from leann.searcher_base import BaseSearcher
class TestSearcher(BaseSearcher):
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)
searcher.index_path = Path(temp_index_with_template)
searcher.index_dir = searcher.index_path.parent
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
with open(meta_file) as f:
searcher.meta = json.load(f)
searcher.embedding_model = searcher.meta["embedding_model"]
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
searcher.embedding_options = searcher.meta.get("embedding_options", {})
captured_queries = []
def mock_compute_embeddings(texts, model, mode, provider_options=None):
captured_queries.extend(texts)
return np.random.rand(len(texts), 768).astype(np.float32)
with patch(
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
):
searcher.compute_query_embedding(
query="vector database",
use_server_if_available=False,
query_template=None, # No template
)
# Verify query is unchanged
assert len(captured_queries) == 1
assert captured_queries[0] == "vector database"
def test_query_template_consistency_between_paths(self, temp_index_with_template):
"""Test that both paths apply template identically"""
from leann.searcher_base import BaseSearcher
class TestSearcher(BaseSearcher):
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)
searcher.index_path = Path(temp_index_with_template)
searcher.index_dir = searcher.index_path.parent
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
with open(meta_file) as f:
searcher.meta = json.load(f)
searcher.embedding_model = searcher.meta["embedding_model"]
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
searcher.embedding_options = searcher.meta.get("embedding_options", {})
query_template = "task: search result | query: "
original_query = "vector database"
# Capture queries from fallback path
fallback_queries = []
def mock_compute_embeddings(texts, model, mode, provider_options=None):
fallback_queries.extend(texts)
return np.random.rand(len(texts), 768).astype(np.float32)
with patch(
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
):
searcher.compute_query_embedding(
query=original_query,
use_server_if_available=False,
query_template=query_template,
)
# Capture queries from server path
server_queries = []
def mock_ensure_server_running(passages_file, port):
return port
def mock_compute_embedding_via_server(chunks, port):
server_queries.extend(chunks)
return np.random.rand(len(chunks), 768).astype(np.float32)
searcher._ensure_server_running = mock_ensure_server_running
searcher._compute_embedding_via_server = mock_compute_embedding_via_server
searcher.compute_query_embedding(
query=original_query,
use_server_if_available=True,
query_template=query_template,
)
# Verify both paths produced identical templated queries
assert len(fallback_queries) == 1
assert len(server_queries) == 1
assert fallback_queries[0] == server_queries[0]
assert fallback_queries[0] == f"{query_template}{original_query}"
def test_query_template_with_empty_string(self, temp_index_with_template):
"""Test behavior with empty template string"""
from leann.searcher_base import BaseSearcher
class TestSearcher(BaseSearcher):
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)
searcher.index_path = Path(temp_index_with_template)
searcher.index_dir = searcher.index_path.parent
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
with open(meta_file) as f:
searcher.meta = json.load(f)
searcher.embedding_model = searcher.meta["embedding_model"]
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
searcher.embedding_options = searcher.meta.get("embedding_options", {})
captured_queries = []
def mock_compute_embeddings(texts, model, mode, provider_options=None):
captured_queries.extend(texts)
return np.random.rand(len(texts), 768).astype(np.float32)
with patch(
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
):
searcher.compute_query_embedding(
query="vector database",
use_server_if_available=False,
query_template="", # Empty string
)
# Empty string is falsy, so no template should be applied
assert captured_queries[0] == "vector database"

@@ -0,0 +1,643 @@
"""Unit tests for token-aware truncation functionality.
This test suite defines the contract for token truncation functions that prevent
500 errors from Ollama when text exceeds model token limits. These tests verify:
1. Model token limit retrieval (known and unknown models)
2. Text truncation behavior for single and multiple texts
3. Token counting and truncation accuracy using tiktoken
All tests are written in Red Phase - they should FAIL initially because the
implementation does not exist yet.
"""
import pytest
import tiktoken
from leann.embedding_compute import (
EMBEDDING_MODEL_LIMITS,
get_model_token_limit,
truncate_to_token_limit,
)
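# The lookup contract exercised below is, roughly (an illustrative sketch with the
# registry abbreviated, not the actual implementation):
#
#     EMBEDDING_MODEL_LIMITS = {"nomic-embed-text": 2048, "text-embedding-3-small": 8192, ...}
#
#     def get_model_token_limit(model_name, default=2048, base_url=""):
#         return EMBEDDING_MODEL_LIMITS.get(model_name, default)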
class TestModelTokenLimits:
"""Tests for retrieving model-specific token limits."""
def test_get_model_token_limit_known_model(self):
"""Verify correct token limit is returned for known models.
Known models should return their specific token limits from
EMBEDDING_MODEL_LIMITS dictionary.
"""
# Test nomic-embed-text (2048 tokens)
limit = get_model_token_limit("nomic-embed-text")
assert limit == 2048, "nomic-embed-text should have 2048 token limit"
# Test nomic-embed-text-v1.5 (2048 tokens)
limit = get_model_token_limit("nomic-embed-text-v1.5")
assert limit == 2048, "nomic-embed-text-v1.5 should have 2048 token limit"
# Test nomic-embed-text-v2 (512 tokens)
limit = get_model_token_limit("nomic-embed-text-v2")
assert limit == 512, "nomic-embed-text-v2 should have 512 token limit"
# Test OpenAI models (8192 tokens)
limit = get_model_token_limit("text-embedding-3-small")
assert limit == 8192, "text-embedding-3-small should have 8192 token limit"
def test_get_model_token_limit_unknown_model(self):
"""Verify default token limit is returned for unknown models.
Unknown models should return the default limit (2048) so operation can
continue with a reasonable safety margin.
"""
# Test with completely unknown model
limit = get_model_token_limit("unknown-model-xyz")
assert limit == 2048, "Unknown models should return default 2048"
# Test with empty string
limit = get_model_token_limit("")
assert limit == 2048, "Empty model name should return default 2048"
def test_get_model_token_limit_custom_default(self):
"""Verify custom default can be specified for unknown models.
Allow callers to specify their own default token limit when the
model is not in the known-models dictionary.
"""
limit = get_model_token_limit("unknown-model", default=4096)
assert limit == 4096, "Should return custom default for unknown models"
# Known model should ignore custom default
limit = get_model_token_limit("nomic-embed-text", default=4096)
assert limit == 2048, "Known model should ignore custom default"
def test_embedding_model_limits_dictionary_exists(self):
"""Verify EMBEDDING_MODEL_LIMITS dictionary contains expected models.
The dictionary should be importable and contain at least the
known nomic models with correct token limits.
"""
assert isinstance(EMBEDDING_MODEL_LIMITS, dict), "Should be a dictionary"
assert "nomic-embed-text" in EMBEDDING_MODEL_LIMITS, "Should contain nomic-embed-text"
assert "nomic-embed-text-v1.5" in EMBEDDING_MODEL_LIMITS, (
"Should contain nomic-embed-text-v1.5"
)
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text"] == 2048
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v1.5"] == 2048
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v2"] == 512
# OpenAI models
assert EMBEDDING_MODEL_LIMITS["text-embedding-3-small"] == 8192
class TestTokenTruncation:
"""Tests for truncating texts to token limits."""
@pytest.fixture
def tokenizer(self):
"""Provide tiktoken tokenizer for token counting verification."""
return tiktoken.get_encoding("cl100k_base")
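    # A minimal sketch of the truncation behavior these tests specify, assuming a
    # tiktoken cl100k_base encoding (consistent with the fixture above, but not
    # necessarily the real implementation):
    #
    #     def truncate_to_token_limit(texts, token_limit):
    #         enc = tiktoken.get_encoding("cl100k_base")
    #         out = []
    #         for text in texts:
    #             tokens = enc.encode(text)
    #             out.append(text if len(tokens) <= token_limit else enc.decode(tokens[:token_limit]))
    #         return out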
def test_truncate_single_text_under_limit(self, tokenizer):
"""Verify text under token limit remains unchanged.
When text is already within the token limit, it should be
returned unchanged with no truncation.
"""
text = "This is a short text that is well under the token limit."
token_count = len(tokenizer.encode(text))
assert token_count < 100, f"Test setup: text should be short (has {token_count} tokens)"
# Truncate with generous limit
result = truncate_to_token_limit([text], token_limit=512)
assert len(result) == 1, "Should return same number of texts"
assert result[0] == text, "Text under limit should be unchanged"
def test_truncate_single_text_over_limit(self, tokenizer):
"""Verify text over token limit is truncated correctly.
When text exceeds the token limit, it should be truncated to
fit within the limit while maintaining valid token boundaries.
"""
# Create a text that definitely exceeds limit
text = "word " * 200 # ~200 tokens (each "word " is typically 1-2 tokens)
original_token_count = len(tokenizer.encode(text))
assert original_token_count > 50, (
f"Test setup: text should be long (has {original_token_count} tokens)"
)
# Truncate to 50 tokens
result = truncate_to_token_limit([text], token_limit=50)
assert len(result) == 1, "Should return same number of texts"
assert result[0] != text, "Text over limit should be truncated"
assert len(result[0]) < len(text), "Truncated text should be shorter"
# Verify truncated text is within token limit
truncated_token_count = len(tokenizer.encode(result[0]))
assert truncated_token_count <= 50, (
f"Truncated text should be ≤50 tokens, got {truncated_token_count}"
)
def test_truncate_multiple_texts_mixed_lengths(self, tokenizer):
"""Verify multiple texts with mixed lengths are handled correctly.
When processing multiple texts:
- Texts under limit should remain unchanged
- Texts over limit should be truncated independently
- Output list should maintain same order and length
"""
texts = [
"Short text.", # Under limit
"word " * 200, # Over limit
"Another short one.", # Under limit
"token " * 150, # Over limit
]
# Verify test setup
for i, text in enumerate(texts):
token_count = len(tokenizer.encode(text))
if i in [1, 3]:
assert token_count > 50, f"Text {i} should be over limit (has {token_count} tokens)"
else:
assert token_count < 50, (
f"Text {i} should be under limit (has {token_count} tokens)"
)
# Truncate with 50 token limit
result = truncate_to_token_limit(texts, token_limit=50)
assert len(result) == len(texts), "Should return same number of texts"
# Verify each text individually
for i, (original, truncated) in enumerate(zip(texts, result)):
token_count = len(tokenizer.encode(truncated))
assert token_count <= 50, f"Text {i} should be ≤50 tokens, got {token_count}"
# Short texts should be unchanged
if i in [0, 2]:
assert truncated == original, f"Short text {i} should be unchanged"
# Long texts should be truncated
else:
assert len(truncated) < len(original), f"Long text {i} should be truncated"
def test_truncate_empty_list(self):
"""Verify empty input list returns empty output list.
Edge case: empty list should return empty list without errors.
"""
result = truncate_to_token_limit([], token_limit=512)
assert result == [], "Empty input should return empty output"
def test_truncate_preserves_order(self, tokenizer):
"""Verify truncation preserves original text order.
Output list should maintain the same order as input list,
regardless of which texts were truncated.
"""
texts = [
"First text " * 50, # Will be truncated
"Second text.", # Won't be truncated
"Third text " * 50, # Will be truncated
]
result = truncate_to_token_limit(texts, token_limit=20)
assert len(result) == 3, "Should preserve list length"
# Check that order is maintained by looking for distinctive words
assert "First" in result[0], "First text should remain in first position"
assert "Second" in result[1], "Second text should remain in second position"
assert "Third" in result[2], "Third text should remain in third position"
def test_truncate_extremely_long_text(self, tokenizer):
"""Verify extremely long texts are truncated efficiently.
Test with text that far exceeds token limit to ensure
truncation handles extreme cases without performance issues.
"""
# Create very long text (simulate real-world scenario)
text = "token " * 5000 # ~5000+ tokens
original_token_count = len(tokenizer.encode(text))
assert original_token_count > 1000, "Test setup: text should be very long"
# Truncate to small limit
result = truncate_to_token_limit([text], token_limit=100)
assert len(result) == 1
truncated_token_count = len(tokenizer.encode(result[0]))
assert truncated_token_count <= 100, (
f"Should truncate to ≤100 tokens, got {truncated_token_count}"
)
assert len(result[0]) < len(text) // 10, "Should significantly reduce text length"
def test_truncate_exact_token_limit(self, tokenizer):
"""Verify text at exactly token limit is handled correctly.
Edge case: text with exactly the token limit should either
remain unchanged or be safely truncated by 1 token.
"""
# Create text with approximately 50 tokens
# We'll adjust to get exactly 50
target_tokens = 50
text = "word " * 50
tokens = tokenizer.encode(text)
# Adjust to get exactly target_tokens
if len(tokens) > target_tokens:
tokens = tokens[:target_tokens]
text = tokenizer.decode(tokens)
elif len(tokens) < target_tokens:
# Add more words
while len(tokenizer.encode(text)) < target_tokens:
text += "word "
tokens = tokenizer.encode(text)[:target_tokens]
text = tokenizer.decode(tokens)
# Verify we have exactly target_tokens
assert len(tokenizer.encode(text)) == target_tokens, (
"Test setup: should have exactly 50 tokens"
)
result = truncate_to_token_limit([text], token_limit=target_tokens)
assert len(result) == 1
result_tokens = len(tokenizer.encode(result[0]))
assert result_tokens <= target_tokens, (
f"Should be ≤{target_tokens} tokens, got {result_tokens}"
)
class TestLMStudioHybridDiscovery:
"""Tests for LM Studio integration in get_model_token_limit() hybrid discovery.
These tests verify that get_model_token_limit() properly integrates with
the LM Studio SDK bridge for dynamic token limit discovery. The integration
should:
1. Detect LM Studio URLs (port 1234 or 'lmstudio'/'lm.studio' in URL)
2. Convert HTTP URLs to WebSocket format for SDK queries
3. Query LM Studio SDK and use discovered limit
4. Fall back to registry when SDK returns None
5. Execute AFTER Ollama detection but BEFORE registry fallback
All tests are written in Red Phase - they should FAIL initially because the
LM Studio detection and integration logic does not exist yet in get_model_token_limit().
"""
def test_get_model_token_limit_lmstudio_success(self, monkeypatch):
"""Verify LM Studio SDK query succeeds and returns detected limit.
When an LM Studio base_url is detected and the SDK query succeeds,
get_model_token_limit() should return the dynamically discovered
context length without falling back to the registry.
"""
# Mock _query_lmstudio_context_limit to return successful SDK query
def mock_query_lmstudio(model_name, base_url):
# Verify WebSocket URL was passed (not HTTP)
assert base_url.startswith("ws://"), (
f"Should convert HTTP to WebSocket format, got: {base_url}"
)
return 8192 # Successful SDK query
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
# Test with HTTP URL that should be converted to WebSocket
limit = get_model_token_limit(
model_name="custom-model", base_url="http://localhost:1234/v1"
)
assert limit == 8192, "Should return limit from LM Studio SDK query"
def test_get_model_token_limit_lmstudio_fallback_to_registry(self, monkeypatch):
"""Verify fallback to registry when LM Studio SDK returns None.
When LM Studio SDK query fails (returns None), get_model_token_limit()
should fall back to the EMBEDDING_MODEL_LIMITS registry.
"""
# Mock _query_lmstudio_context_limit to return None (SDK failure)
def mock_query_lmstudio(model_name, base_url):
return None # SDK query failed
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
# Test with known model that exists in registry
limit = get_model_token_limit(
model_name="nomic-embed-text", base_url="http://localhost:1234/v1"
)
# Should fall back to registry value
assert limit == 2048, "Should fall back to registry when SDK returns None"
def test_get_model_token_limit_lmstudio_port_detection(self, monkeypatch):
"""Verify detection of LM Studio via port 1234.
get_model_token_limit() should recognize port 1234 as an LM Studio
server and attempt an SDK query, regardless of hostname.
"""
query_called = False
def mock_query_lmstudio(model_name, base_url):
nonlocal query_called
query_called = True
return 4096
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
# Test with port 1234 (default LM Studio port)
limit = get_model_token_limit(model_name="test-model", base_url="http://127.0.0.1:1234/v1")
assert query_called, "Should detect port 1234 and call LM Studio SDK query"
assert limit == 4096, "Should return SDK query result"
@pytest.mark.parametrize(
"test_url,expected_limit,keyword",
[
("http://lmstudio.local:8080/v1", 16384, "lmstudio"),
("http://api.lm.studio:5000/v1", 32768, "lm.studio"),
],
)
def test_get_model_token_limit_lmstudio_url_keyword_detection(
self, monkeypatch, test_url, expected_limit, keyword
):
"""Verify detection of LM Studio via keywords in URL.
get_model_token_limit() should recognize 'lmstudio' or 'lm.studio'
in the URL as indicating an LM Studio server.
"""
query_called = False
def mock_query_lmstudio(model_name, base_url):
nonlocal query_called
query_called = True
return expected_limit
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
limit = get_model_token_limit(model_name="test-model", base_url=test_url)
assert query_called, f"Should detect '{keyword}' keyword and call SDK query"
assert limit == expected_limit, f"Should return SDK query result for {keyword}"
@pytest.mark.parametrize(
"input_url,expected_protocol,expected_host",
[
("http://localhost:1234/v1", "ws://", "localhost:1234"),
("https://lmstudio.example.com:1234/v1", "wss://", "lmstudio.example.com:1234"),
],
)
def test_get_model_token_limit_protocol_conversion(
self, monkeypatch, input_url, expected_protocol, expected_host
):
"""Verify HTTP/HTTPS URL is converted to WebSocket format for SDK query.
LM Studio SDK requires WebSocket URLs. get_model_token_limit() should:
1. Convert 'http://' to 'ws://'
2. Convert 'https://' to 'wss://'
3. Remove '/v1' or other path suffixes (SDK expects base URL)
4. Preserve host and port
"""
conversions_tested = []
def mock_query_lmstudio(model_name, base_url):
conversions_tested.append(base_url)
return 8192
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
get_model_token_limit(model_name="test-model", base_url=input_url)
# Verify conversion happened
assert len(conversions_tested) == 1, "Should have called SDK query once"
assert conversions_tested[0].startswith(expected_protocol), (
f"Should convert to {expected_protocol}"
)
assert expected_host in conversions_tested[0], (
f"Should preserve host and port: {expected_host}"
)
def test_get_model_token_limit_lmstudio_executes_after_ollama(self, monkeypatch):
"""Verify LM Studio detection happens AFTER Ollama detection.
The hybrid discovery order should be:
1. Ollama dynamic discovery (port 11434 or 'ollama' in URL)
2. LM Studio dynamic discovery (port 1234 or 'lmstudio' in URL)
3. Registry fallback
If both Ollama and LM Studio patterns match, Ollama should take precedence.
This test verifies that the LM Studio query is never attempted once the Ollama query succeeds.
"""
ollama_called = False
lmstudio_called = False
def mock_query_ollama(model_name, base_url):
nonlocal ollama_called
ollama_called = True
return 2048 # Ollama query succeeds
def mock_query_lmstudio(model_name, base_url):
nonlocal lmstudio_called
lmstudio_called = True
return None # Should not be reached if Ollama succeeds
monkeypatch.setattr(
"leann.embedding_compute._query_ollama_context_limit",
mock_query_ollama,
)
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
# Test with Ollama URL
limit = get_model_token_limit(
model_name="test-model", base_url="http://localhost:11434/api"
)
assert ollama_called, "Should attempt Ollama query first"
assert not lmstudio_called, "Should not attempt LM Studio query when Ollama succeeds"
assert limit == 2048, "Should return Ollama result"
def test_get_model_token_limit_lmstudio_not_detected_for_non_lmstudio_urls(self, monkeypatch):
"""Verify LM Studio SDK query is NOT called for non-LM Studio URLs.
Only URLs with port 1234 or 'lmstudio'/'lm.studio' keywords should
trigger LM Studio SDK queries. Other URLs should skip to registry fallback.
"""
lmstudio_called = False
def mock_query_lmstudio(model_name, base_url):
nonlocal lmstudio_called
lmstudio_called = True
return 8192
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
# Test with non-LM Studio URLs
test_cases = [
"http://localhost:8080/v1", # Different port
"http://openai.example.com/v1", # Different service
"http://localhost:3000/v1", # Another port
]
for base_url in test_cases:
lmstudio_called = False # Reset for each test
get_model_token_limit(model_name="nomic-embed-text", base_url=base_url)
assert not lmstudio_called, f"Should NOT call LM Studio SDK for URL: {base_url}"
def test_get_model_token_limit_lmstudio_case_insensitive_detection(self, monkeypatch):
"""Verify LM Studio detection is case-insensitive for keywords.
Keywords 'lmstudio' and 'lm.studio' should be detected regardless
of case (LMStudio, LMSTUDIO, LmStudio, etc.).
"""
query_called = False
def mock_query_lmstudio(model_name, base_url):
nonlocal query_called
query_called = True
return 8192
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
# Test various case variations
test_cases = [
"http://LMStudio.local:8080/v1",
"http://LMSTUDIO.example.com/v1",
"http://LmStudio.local/v1",
"http://api.LM.STUDIO:5000/v1",
]
for base_url in test_cases:
query_called = False # Reset for each test
limit = get_model_token_limit(model_name="test-model", base_url=base_url)
assert query_called, f"Should detect LM Studio in URL: {base_url}"
assert limit == 8192, f"Should return SDK result for URL: {base_url}"
class TestTokenLimitCaching:
"""Tests for token limit caching to prevent repeated SDK/API calls.
Caching prevents duplicate SDK/API calls within the same Python process,
which is important because:
1. LM Studio SDK load() can load duplicate model instances
2. Ollama /api/show queries add latency
3. Registry lookups are pure overhead
The cache is process-scoped and resets between leann build invocations.
"""
def setup_method(self):
"""Clear cache before each test."""
from leann.embedding_compute import _token_limit_cache
_token_limit_cache.clear()
def test_registry_lookup_is_cached(self):
"""Verify that registry lookups are cached."""
from leann.embedding_compute import _token_limit_cache
# First call
limit1 = get_model_token_limit("text-embedding-3-small")
assert limit1 == 8192
# Verify it's in cache
cache_key = ("text-embedding-3-small", "")
assert cache_key in _token_limit_cache
assert _token_limit_cache[cache_key] == 8192
# Second call should use cache
limit2 = get_model_token_limit("text-embedding-3-small")
assert limit2 == 8192
def test_default_fallback_is_cached(self):
"""Verify that default fallbacks are cached."""
from leann.embedding_compute import _token_limit_cache
# First call with unknown model
limit1 = get_model_token_limit("unknown-model-xyz", default=512)
assert limit1 == 512
# Verify it's in cache
cache_key = ("unknown-model-xyz", "")
assert cache_key in _token_limit_cache
assert _token_limit_cache[cache_key] == 512
# Second call should use cache
limit2 = get_model_token_limit("unknown-model-xyz", default=512)
assert limit2 == 512
def test_different_urls_create_separate_cache_entries(self):
"""Verify that different base_urls create separate cache entries."""
from leann.embedding_compute import _token_limit_cache
# Same model, different URLs
limit1 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:11434")
limit2 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:1234/v1")
# Both should find the model in registry (2048)
assert limit1 == 2048
assert limit2 == 2048
# But they should be separate cache entries
cache_key1 = ("nomic-embed-text", "http://localhost:11434")
cache_key2 = ("nomic-embed-text", "http://localhost:1234/v1")
assert cache_key1 in _token_limit_cache
assert cache_key2 in _token_limit_cache
assert len(_token_limit_cache) == 2
def test_cache_prevents_repeated_lookups(self):
"""Verify that cache prevents repeated registry/API lookups."""
from leann.embedding_compute import _token_limit_cache
model_name = "text-embedding-ada-002"
# First call - should add to cache
assert len(_token_limit_cache) == 0
limit1 = get_model_token_limit(model_name)
cache_size_after_first = len(_token_limit_cache)
assert cache_size_after_first == 1
# Multiple subsequent calls - cache size should not change
for _ in range(5):
limit = get_model_token_limit(model_name)
assert limit == limit1
assert len(_token_limit_cache) == cache_size_after_first
def test_versioned_model_names_cached_correctly(self):
"""Verify that versioned model names (e.g., model:tag) are cached."""
from leann.embedding_compute import _token_limit_cache
# Model with version tag
limit = get_model_token_limit("nomic-embed-text:latest", base_url="http://localhost:11434")
assert limit == 2048
# Should be cached with full name including version
cache_key = ("nomic-embed-text:latest", "http://localhost:11434")
assert cache_key in _token_limit_cache
assert _token_limit_cache[cache_key] == 2048

uv.lock (generated, 7238 changed lines): diff suppressed because it is too large.