Compare commits

..

53 Commits

Author SHA1 Message Date
aakash
877fbe81f4 Fix: Prevent duplicate PDF processing when using --file-types .pdf
Fixes #175

Problem:
When --file-types .pdf is specified, PDFs were being processed twice:
1. Separately with PyMuPDF/pdfplumber extractors
2. Again in the 'other file types' section via SimpleDirectoryReader

This caused duplicate processing and potential conflicts.

Solution:
- Exclude .pdf from other_file_extensions when PDFs are already
  processed separately
- Only load other file types if there are extensions to process
- Prevents duplicate PDF processing

Changes:
- Added logic to filter out .pdf from code_extensions when loading
  other file types if PDFs were processed separately
- Updated SimpleDirectoryReader to use filtered extensions
- Added check to skip loading if no other extensions to process
2025-11-30 11:18:57 -08:00
Andy Lee
eb909ccec5 docs: survey 2025-11-20 09:20:45 -08:00
ww26
969f514564 Fix prompt template bugs: build template ignored and runtime override not wired (#173)
* Fix prompt template bugs in build and search

Bug 1: Build template ignored in new format
- Updated compute_embeddings_openai() to read build_prompt_template or prompt_template
- Updated compute_embeddings_ollama() with same fix
- Maintains backward compatibility with old single-template format

Bug 2: Runtime override not wired up
- Wired CLI search to pass provider_options to searcher.search()
- Enables runtime template override during search via --embedding-prompt-template

All 42 prompt template tests passing.

Fixes #155

* Fix: Prevent embedding server from applying templates during search

- Filter out all prompt templates (build_prompt_template, query_prompt_template, prompt_template) from provider_options when launching embedding server during search
- Templates are already applied in compute_query_embedding() before server call
- Prevents double-templating and ensures runtime override works correctly

This fixes the issue where --embedding-prompt-template during search was ignored because the server was applying build_prompt_template instead.

* Format code with ruff
2025-11-16 20:56:42 -08:00
ww26
1ef9cba7de Feature/prompt templates and lmstudio sdk (#171)
* Add prompt template support and LM Studio SDK integration

Features:

- Prompt template support for embedding models (via --embedding-prompt-template)

- LM Studio SDK integration for automatic context length detection

- Hybrid token limit discovery (Ollama → LM Studio → Registry → Default)

- Client-side token truncation to prevent silent failures

- Automatic persistence of embedding_options to .meta.json

Implementation:

- Added _query_lmstudio_context_limit() with Node.js subprocess bridge

- Modified compute_embeddings_openai() to apply prompt templates before truncation

- Extended CLI with --embedding-prompt-template flag for build and search

- URL detection for LM Studio (port 1234 or lmstudio/lm.studio keywords)

- HTTP→WebSocket URL conversion for SDK compatibility

Tests:

- 60 passing tests across 5 test files

- Comprehensive coverage of prompt templates, LM Studio integration, and token handling

- Parametrized tests for maintainability and clarity

* Add integration tests and fix LM Studio SDK bridge

Features:
- End-to-end integration tests for prompt template with EmbeddingGemma
- Integration tests for hybrid token limit discovery mechanism
- Tests verify real-world functionality with live services (LM Studio, Ollama)

Fixes:
- LM Studio SDK bridge now uses client.embedding.load() for embedding models
- Fixed NODE_PATH resolution to include npm global modules
- Fixed integration test to use WebSocket URL (ws://) for SDK bridge

Tests:
- test_prompt_template_e2e.py: 8 integration tests covering:
  - Prompt template prepending with LM Studio (EmbeddingGemma)
  - LM Studio SDK bridge for context length detection
  - Ollama dynamic token limit detection
  - Hybrid discovery fallback mechanism (registry, default)
- All tests marked with @pytest.mark.integration for selective execution
- Tests gracefully skip when services unavailable

Documentation:
- Updated tests/README.md with integration test section
- Added prerequisites and running instructions
- Documented that prompt templates are ONLY for EmbeddingGemma
- Added integration marker to pyproject.toml

Test Results:
- All 8 integration tests passing with live services
- Confirmed prompt templates work correctly with EmbeddingGemma
- Verified LM Studio SDK bridge auto-detects context length (2048)
- Validated hybrid token limit discovery across all backends

* Add prompt template support to Ollama mode

Extends prompt template functionality from OpenAI mode to Ollama for backend consistency.

Changes:
- Add provider_options parameter to compute_embeddings_ollama()
- Apply prompt template before token truncation (lines 1005-1011)
- Pass provider_options through compute_embeddings() call chain

Tests:
- test_ollama_embedding_with_prompt_template: Verifies templates work with Ollama
- test_ollama_prompt_template_affects_embeddings: Confirms embeddings differ with/without template
- Both tests pass with live Ollama service (2/2 passing)

Usage:
leann build --embedding-mode ollama --embedding-prompt-template "query: " ...

* Fix LM Studio SDK bridge to respect JIT auto-evict settings

Problem: SDK bridge called client.embedding.load() which loaded models into
LM Studio memory and bypassed JIT auto-evict settings, causing duplicate
model instances to accumulate.

Root cause analysis (from Perplexity research):
- Explicit SDK load() commands are treated as "pinned" models
- JIT auto-evict only applies to models loaded reactively via API requests
- SDK-loaded models remain in memory until explicitly unloaded

Solutions implemented:

1. Add model.unload() after metadata query (line 243)
   - Load model temporarily to get context length
   - Unload immediately to hand control back to JIT system
   - Subsequent API requests trigger JIT load with auto-evict

2. Add token limit caching to prevent repeated SDK calls
   - Cache discovered limits in _token_limit_cache dict (line 48)
   - Key: (model_name, base_url), Value: token_limit
   - Prevents duplicate load/unload cycles within same process
   - Cache shared across all discovery methods (Ollama, SDK, registry)

Tests:
- TestTokenLimitCaching: 5 tests for cache behavior (integrated into test_token_truncation.py)
- Manual testing confirmed no duplicate models in LM Studio after fix
- All existing tests pass

Impact:
- Respects user's LM Studio JIT and auto-evict settings
- Reduces model memory footprint
- Faster subsequent builds (cached limits)

* Document prompt template and LM Studio SDK features

Added comprehensive documentation for new optional embedding features:

Configuration Guide (docs/configuration-guide.md):
- New section: "Optional Embedding Features"
- Task-Specific Prompt Templates subsection:
  - Explains EmbeddingGemma use case with document/query prompts
  - CLI and Python API examples
  - Clear warnings about compatible vs incompatible models
  - References to GitHub issue #155 and HuggingFace blog
- LM Studio Auto-Detection subsection:
  - Prerequisites (Node.js + @lmstudio/sdk)
  - How auto-detection works (4-step process)
  - Benefits and optional nature clearly stated

FAQ (docs/faq.md):
- FAQ #2: When should I use prompt templates?
  - DO/DON'T guidance with examples
  - Links to detailed configuration guide
- FAQ #3: Why is LM Studio loading multiple copies?
  - Explains the JIT auto-evict fix
  - Troubleshooting steps if still seeing issues
- FAQ #4: Do I need Node.js and @lmstudio/sdk?
  - Clarifies it's completely optional
  - Lists benefits if installed
  - Installation instructions

Cross-references between documents for easy navigation between quick reference and detailed guides.

* Add separate build/query template support for task-specific models

Task-specific models like EmbeddingGemma require different templates for indexing vs searching. Store both templates at build time and auto-apply query template during search with backward compatibility.

* Consolidate prompt template tests from 44 to 37 tests

Merged redundant no-op tests, removed low-value implementation tests, consolidated parameterized CLI tests, and removed hanging over-mocked test. All tests pass with improved focus on behavioral testing.

* Fix query template application in compute_query_embedding

Query templates were only applied in the fallback code path, not when using the embedding server (default path). This meant stored query templates in index metadata were ignored during MCP and CLI searches.

Changes:

- Move template application to before any computation path (searcher_base.py:109-110)

- Add comprehensive tests for both server and fallback paths

- Consolidate tests into test_prompt_template_persistence.py

Tests verify:

- Template applied when using embedding server

- Template applied in fallback path

- Consistent behavior between both paths

* Apply ruff formatting and fix linting issues

- Remove unused imports

- Fix import ordering

- Remove unused variables

- Apply code formatting

* Fix CI test failures: mock OPENAI_API_KEY in tests

Tests were failing in CI because compute_embeddings_openai() checks for OPENAI_API_KEY before using the mocked client. Added monkeypatch to set fake API key in test fixture.
2025-11-14 15:25:17 -08:00
yichuan520030910320
a63550944b update pkg name 2025-11-13 14:57:28 -08:00
yichuan520030910320
97493a2896 update module to install 2025-11-13 14:53:25 -08:00
Aakash Suresh
f7d2dc6e7c Merge pull request #163 from orbisai0security/fix-semgrep-python.lang.security.audit.eval-detected.eval-detected-apps-slack-data-slack-mc-b134f52c-f326
Fix: Unsafe Code Execution Function Could Allow External Code Injection in apps/slack_data/slack_mcp_reader.py
2025-11-13 14:35:44 -08:00
Gil Vernik
ea86b283cb [bug] fix when no package name found (#167)
* fix when no package name

* fix pre-commit style issue

* fix to ruff (legacy alias) fail test
2025-11-13 14:23:13 -08:00
aakash
e7519bceaa Fix CI: sync uv.lock from main and remove .lycheeignore (workflow exclusion is sufficient) 2025-11-13 13:10:07 -08:00
aakash
abf0b2c676 Fix CI: improve security fix and add link checker configuration
- Fix import order (ast before asyncio)
- Remove NameError from exception handling (ast.literal_eval doesn't raise it)
- Add .lycheeignore to exclude intermittently unavailable star-history API
- Update link-check workflow to exclude star-history API and accept 503 status codes
2025-11-13 13:05:00 -08:00
GitHub Actions
3c4785bb63 chore: release v0.3.5 2025-11-12 06:01:25 +00:00
orbisai0security
930b79cc98 fix: semgrep_python.lang.security.audit.eval-detected.eval-detected_apps/slack_data/slack_mcp_reader.py_157 2025-11-12 03:47:18 +00:00
yichuan-w
3766ad1fd2 robust multi-vector 2025-11-09 02:34:53 +00:00
ww26
c3aceed1e0 metadata reveal for ast-chunking; smart detection of seq length in ollama; auto adjust chunk length for ast to prevent silent truncation (#157)
* feat: enhance token limits with dynamic discovery + AST metadata

Improves upon upstream PR #154 with two major enhancements:

1. **Hybrid Token Limit Discovery**
   - Dynamic: Query Ollama /api/show for context limits
   - Fallback: Registry for LM Studio/OpenAI
   - Zero maintenance for Ollama users
   - Respects custom num_ctx settings

2. **AST Metadata Preservation**
   - create_ast_chunks() returns dict format with metadata
   - Preserves file_path, file_name, timestamps
   - Includes astchunk metadata (line numbers, node counts)
   - Fixes content extraction bug (checks "content" key)
   - Enables --show-metadata flag

3. **Better Token Limits**
   - nomic-embed-text: 2048 tokens (vs 512)
   - nomic-embed-text-v1.5: 2048 tokens
   - Added OpenAI models: 8192 tokens

4. **Comprehensive Tests**
   - 11 tests for token truncation
   - 545 new lines in test_astchunk_integration.py
   - All metadata preservation tests passing

* fix: merge EMBEDDING_MODEL_LIMITS and remove redundant validation

- Merged upstream's model list with our corrected token limits
- Kept our corrected nomic-embed-text: 2048 (not 512)
- Removed post-chunking validation (redundant with embedding-time truncation)
- All tests passing except 2 pre-existing integration test failures

* style: apply ruff formatting and restore PR #154 version handling

- Remove duplicate truncate_to_token_limit and get_model_token_limit functions
- Restore version handling logic (model:latest -> model) from PR #154
- Restore partial matching fallback for model name variations
- Apply ruff formatting to all modified files
- All 11 token truncation tests passing

* style: sort imports alphabetically (pre-commit auto-fix)

* fix: show AST token limit warning only once per session

- Add module-level flag to track if warning shown
- Prevents spam when processing multiple files
- Add clarifying note that auto-truncation happens at embedding time
- Addresses issue where warning appeared for every code file

* enhance: add detailed logging for token truncation

- Track and report truncation statistics (count, tokens removed, max length)
- Show first 3 individual truncations with exact token counts
- Provide comprehensive summary when truncation occurs
- Use WARNING level for data loss visibility
- Silent (DEBUG level only) when no truncation needed

Replaces misleading "truncated where necessary" message that appeared
even when nothing was truncated.
2025-11-08 17:37:31 -08:00
yichuan-w
dc6c9f696e update some search in copali 2025-11-08 08:53:03 +00:00
CalebZ9909
2406c41eef Update faiss submodule to latest commit
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-08 00:47:21 +00:00
Andy Lee
d4f5f2896f Faster Update (#148)
* stash

* stash

* add std err in add and trace progress

* fix.

* docs

* style: format

* docs

* better figs

* better figs

* update results

* fotmat

---------

Co-authored-by: yichuan-w <yichuan-w@users.noreply.github.com>
2025-11-05 13:37:47 -08:00
Aakash Suresh
366984e92e Merge pull request #154 from yichuan-w/fix/chunking-token-limit-behavior
Fix/chunking token limit behavior
2025-11-02 21:37:47 -08:00
aakash
64b92a04a7 fixing chunking token issues within limit for embedding models 2025-10-31 17:15:00 -07:00
ww26
a85d0ad4a7 Feature/optimize ollama batching (#152)
* feat: add metadata output to search results

- Add --show-metadata flag to display file paths in search results
- Preserve document metadata (file_path, file_name, timestamps) during chunking
- Update MCP tool schema to support show_metadata parameter
- Enhance CLI search output to display metadata when requested
- Fix pre-existing bug: args.backend -> args.backend_name

Resolves yichuan-w/LEANN#144

* fix: resolve ZMQ linking issues in Python extension

- Use pkg_check_modules IMPORTED_TARGET to create PkgConfig::ZMQ
- Set PKG_CONFIG_PATH to prioritize ARM64 Homebrew on Apple Silicon
- Override macOS -undefined dynamic_lookup to force proper symbol resolution
- Use PUBLIC linkage for ZMQ in faiss library for transitive linking
- Mark cppzmq includes as SYSTEM to suppress warnings

Fixes editable install ZMQ symbol errors while maintaining compatibility
across Linux, macOS Intel, and macOS ARM64 platforms.

* style: apply ruff formatting

* chore: update faiss submodule to use ww2283 fork

Use ww2283/faiss fork with fix/zmq-linking branch to resolve CI checkout
failures. The ZMQ linking fixes are not yet merged upstream.

* feat: implement true batch processing for Ollama embeddings

Migrate from deprecated /api/embeddings to modern /api/embed endpoint
which supports batch inputs. This reduces HTTP overhead by sending
32 texts per request instead of making individual API calls.

Changes:
- Update endpoint from /api/embeddings to /api/embed
- Change parameter from 'prompt' (single) to 'input' (array)
- Update response parsing for batch embeddings array
- Increase timeout to 60s for batch processing
- Improve error handling for batch requests

Performance:
- Reduces API calls by 32x (batch size)
- Eliminates HTTP connection overhead per text
- Note: Ollama still processes batch items sequentially internally

Related: #151

* fall back to original faiss as i merge the PR

---------

Co-authored-by: yichuan520030910320 <yichuan_wang@berkeley.edu>
2025-10-30 16:39:14 -07:00
yichuan-w
dbb5f4d352 Fix CI failure by removing paru-bin submodule
Remove paru-bin directory that was incorrectly added as a git submodule.
This directory is an AUR build artifact and should not be tracked.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-25 14:51:06 -07:00
yichuan-w
f180b83589 add deep wiki 2025-10-25 14:46:17 -07:00
CelineNi2
abf312d998 Display context chunks in ask and search results (#149)
* Printing querying time

* Adding source name to chunks

Adding source name as metadata to chunks, then printing the sources when searching

* Printing the context provided to LLM

To check the data transmitted to the LLMs : display the relevance, ID, content, and source of each sent chunk.

* Correcting source as metadata for chunks

* Applying ruff format

* Applying Ruff formatting

* Ruff formatting
2025-10-23 15:03:59 -07:00
Aakash Suresh
ab251ab751 Fix/twitter bookmarks anchor link (#143)
* fix: Fix Twitter bookmarks anchor link

- Convert Twitter Bookmarks from collapsible details to proper header
- Update internal link to match new anchor format
- Ensures external links to #twitter-bookmarks-your-personal-tweet-library work correctly

Fixes broken link: https://github.com/yichuan-w/LEANN?tab=readme-ov-file#twitter-bookmarks-your-personal-tweet-library

* fix: Fix Slack messages anchor link as well

- Convert Slack Messages from collapsible details to proper header
- Update internal link to match new anchor format
- Ensures external links to #slack-messages-search-your-team-conversations work correctly

Both Twitter and Slack MCP sections now have reliable anchor links.

* fix: Point Slack and Twitter links to main MCP section

- Both Slack and Twitter are subsections under MCP Integration
- Links should point to #mcp-integration-rag-on-live-data-from-any-platform
- Users will land on the MCP section and can find both Slack and Twitter subsections there

This matches the actual document structure where Slack and Twitter are under the MCP Integration section.

* Improve Slack MCP integration with retry logic and comprehensive setup guide

- Add retry mechanism with exponential backoff for cache sync issues
- Handle 'users cache is not ready yet' errors gracefully
- Add max-retries and retry-delay CLI arguments for better control
- Create comprehensive Slack setup guide with troubleshooting
- Update README with link to detailed setup guide
- Improve error messages and user experience

* Fix trailing whitespace in slack setup guide

Pre-commit hooks formatting fixes

* Add comprehensive Slack setup guide with success screenshot

- Create detailed setup guide with step-by-step instructions
- Add troubleshooting section for common issues like cache sync errors
- Include real terminal output example from successful integration
- Add screenshot showing VS Code interface with Slack channel data
- Remove excessive emojis for more professional documentation
- Document retry logic improvements and CLI arguments

* Fix formatting issues in Slack setup guide

- Remove trailing whitespace
- Fix end of file formatting
- Pre-commit hooks formatting fixes

* Add real RAG example showing intelligent Slack query functionality

- Add detailed example of asking 'What is LEANN about?'
- Show retrieved messages from Slack channels
- Demonstrate intelligent answer generation based on context
- Add command example for running real RAG queries
- Explain the 4-step process: retrieve, index, generate, cite

* Update Slack setup guide with bot invitation requirements

- Add important section about inviting bot to channels before RAG queries
- Explain the 'not_in_channel' errors and their meaning
- Provide clear steps for bot invitation process
- Document realistic scenario where bot needs explicit channel access
- Update documentation to be more professional and less cursor-style

* Docs: add real RAG example for Sky Lab #random

- Embed screenshot videos/rag-sky-random.png
- Add step-by-step commands and notes
- Include helper test script tests/test_channel_by_id_or_name.py
- Redact example tokens from docs

* Docs/CI: fix broken image paths and ruff lint\n\n- Move screenshot to docs/videos and update references\n- Remove obsolete rag-query-results image\n- Rename variable to satisfy ruff

* Docs: fix image path for lychee (use videos/ relative under docs/)

* Docs: finalize Slack setup guide with Sky random RAG example and image path fixes\n\n- Redact example tokens from docs

* Fix Slack MCP integration and update documentation

- Fix SlackMCPReader to use conversations_history instead of channels_list
- Add fallback imports for leann.interactive_utils and leann.settings
- Update slack-setup-guide.md with real screenshots and improved text
- Remove old screenshot files

* Add Slack integration screenshots to docs/videos

- Add slack_integration.png showing RAG query results
- Add slack_integration_2.png showing additional demo functionality
- Fixes lychee link checker errors for missing image files

* Update Slack integration screenshot with latest changes

* Remove test_channel_by_id_or_name.py

- Clean up temporary test file that was used for debugging
- Keep only the main slack_rag.py application for production use

* Update Slack RAG example to show LEANN announcement retrieval

- Change query from 'PUBPOL 290' to 'What is LEANN about?' for more challenging retrieval
- Update command to use python -m apps.slack_rag instead of test script
- Add expected response showing Yichuan Wang's LEANN announcement message
- Emphasize this demonstrates ability to find specific announcements in conversation history
- Update description to highlight challenging query capabilities

* Update Slack RAG integration with improved CSV parsing and new screenshots

- Fixed CSV message parsing in slack_mcp_reader.py to properly handle individual messages
- Updated slack_rag.py to filter empty channel strings
- Enhanced slack-setup-guide.md with two new query examples:
  - Advisor Models query: 'train black-box models to adopt to your personal data'
  - Barbarians at the Gate query: 'AI-driven research systems ADRS'
- Replaced old screenshots with four new ones showing both query examples
- Updated documentation to use User OAuth Token (xoxp-) instead of Bot Token (xoxb-)
- Added proper command examples with --no-concatenate-conversations and --force-rebuild flags

* Update Slack RAG documentation with Ollama integration and new screenshots

- Updated slack-setup-guide.md with comprehensive Ollama setup instructions
- Added 6 new screenshots showing complete RAG workflow:
  - Command setup, search results, and LLM responses for both queries
- Removed simulated LLM references, now uses real Ollama with llama3.2:1b
- Enhanced documentation with step-by-step Ollama installation
- Updated troubleshooting checklist to include Ollama-specific checks
- Fixed command syntax and added proper Ollama configuration
- Demonstrates working Slack RAG with real AI-generated responses

* Remove Key Features section from Slack RAG examples

- Simplified documentation by removing the bullet point list
- Keeps the focus on the actual examples and screenshots
2025-10-19 11:47:29 -07:00
CelineNi2
28085f6f04 Add messages regarding the use of token during query (#147)
* Add messages regarding the use of token during query

* fix: apply ruff format
2025-10-15 16:48:48 -07:00
CelineNi2
6495833887 Changing the option name "--backend" for "--backend-name" as written in the documentation (#146) 2025-10-14 13:35:10 -07:00
yichuan520030910320
5543b3c5f7 [minor] format fix 2025-10-09 15:10:54 -07:00
yichuan-w
a99983b3d9 fix readme 2025-10-08 21:51:25 +00:00
Aakash Suresh
36482e016c fix: Fix Twitter bookmarks anchor link (#140)
* fix: Fix Twitter bookmarks anchor link

- Convert Twitter Bookmarks from collapsible details to proper header
- Update internal link to match new anchor format
- Ensures external links to #twitter-bookmarks-your-personal-tweet-library work correctly

Fixes broken link: https://github.com/yichuan-w/LEANN?tab=readme-ov-file#twitter-bookmarks-your-personal-tweet-library

* fix: Fix Slack messages anchor link as well

- Convert Slack Messages from collapsible details to proper header
- Update internal link to match new anchor format
- Ensures external links to #slack-messages-search-your-team-conversations work correctly

Both Twitter and Slack MCP sections now have reliable anchor links.

* fix: Point Slack and Twitter links to main MCP section

- Both Slack and Twitter are subsections under MCP Integration
- Links should point to #mcp-integration-rag-on-live-data-from-any-platform
- Users will land on the MCP section and can find both Slack and Twitter subsections there

This matches the actual document structure where Slack and Twitter are under the MCP Integration section.
2025-10-08 02:32:02 -07:00
Aakash Suresh
32967daf81 security: Enhance Hugging Face model loading security - resolves #136 (#138)
BREAKING CHANGE: trust_remote_code now defaults to False for security

- Set trust_remote_code=False by default in HFChat class
- Add explicit trust_remote_code parameter to HFChat.__init__()
- Add security warning when trust_remote_code=True is used
- Update get_llm() function to support trust_remote_code parameter
- Update benchmark utilities (load_hf_model, load_vllm_model, load_qwen_vl_model)
- Add comprehensive documentation for security implications

Security Benefits:
- Prevents arbitrary code execution from compromised model repositories
- Requires explicit opt-in for models that need remote code execution
- Shows clear warnings when security is reduced
- Follows security-by-default principle

Migration Guide:
- Most users: No changes needed (more secure by default)
- Users with models requiring remote code: Add trust_remote_code=True explicitly
- Config users: Add 'trust_remote_code': true to LLM config if needed

Fixes #136
2025-10-07 13:13:44 -07:00
Aakash Suresh
b4bb8dec75 feat: Add MCP integration support for Slack and Twitter (#134)
* feat: Add MCP integration support for Slack and Twitter

- Implement SlackMCPReader for connecting to Slack MCP servers
- Implement TwitterMCPReader for connecting to Twitter MCP servers
- Add SlackRAG and TwitterRAG applications with full CLI support
- Support live data fetching via Model Context Protocol (MCP)
- Add comprehensive documentation and usage examples
- Include connection testing capabilities with --test-connection flag
- Add standalone tests for core functionality
- Update README with detailed MCP integration guide
- Add Aakash Suresh to Active Contributors

Resolves #36

* fix: Resolve linting issues in MCP integration

- Replace deprecated typing.Dict/List with built-in dict/list
- Fix boolean comparisons (== True/False) to direct checks
- Remove unused variables in demo script
- Update type annotations to use modern Python syntax

All pre-commit hooks should now pass.

* fix: Apply final formatting fixes for pre-commit hooks

- Remove unused imports (asyncio, pathlib.Path)
- Remove unused class imports in demo script
- Ensure all files pass ruff format and pre-commit checks

This should resolve all remaining CI linting issues.

* fix: Apply pre-commit formatting changes

- Fix trailing whitespace in all files
- Apply ruff formatting to match project standards
- Ensure consistent code style across all MCP integration files

This commit applies the exact changes that pre-commit hooks expect.

* fix: Apply pre-commit hooks formatting fixes

- Remove trailing whitespace from all files
- Fix ruff formatting issues (2 errors resolved)
- Apply consistent code formatting across 3 files
- Ensure all files pass pre-commit validation

This resolves all CI formatting failures.

* fix: Update MCP RAG classes to match BaseRAGExample signature

- Fix SlackMCPRAG and TwitterMCPRAG __init__ methods to provide required parameters
- Add name, description, and default_index_name to super().__init__ calls
- Resolves test failures: test_slack_rag_initialization and test_twitter_rag_initialization

This fixes the TypeError caused by BaseRAGExample requiring additional parameters.

* style: Apply ruff formatting - add trailing commas

- Add trailing commas to super().__init__ calls in SlackMCPRAG and TwitterMCPRAG
- Fixes ruff format pre-commit hook requirements

* fix: Resolve SentenceTransformer model_kwargs parameter conflict

- Fix local_files_only parameter conflict in embedding_compute.py
- Create separate copies of model_kwargs and tokenizer_kwargs for local vs network loading
- Prevents parameter conflicts when falling back from local to network loading
- Resolves TypeError in test_readme_examples.py tests

This addresses the SentenceTransformer initialization issues in CI tests.

* fix: Add comprehensive SentenceTransformer version compatibility

- Handle both old and new sentence-transformers versions
- Gracefully fallback from advanced parameters to basic initialization
- Catch TypeError for model_kwargs/tokenizer_kwargs and use basic SentenceTransformer init
- Ensures compatibility across different CI environments and local setups
- Maintains optimization benefits where supported while ensuring broad compatibility

This resolves test failures in CI environments with older sentence-transformers versions.

* style: Apply ruff formatting to embedding_compute.py

- Break long logger.warning lines for better readability
- Fixes pre-commit hook formatting requirements

* docs: Comprehensive documentation improvements for better user experience

- Add clear step-by-step Getting Started Guide for new users
- Add comprehensive CLI Reference with all commands and options
- Improve installation instructions with clear steps and verification
- Add detailed troubleshooting section for common issues (Ollama, OpenAI, etc.)
- Clarify difference between CLI commands and specialized apps
- Add environment variables documentation
- Improve MCP integration documentation with CLI integration examples
- Address user feedback about confusing installation and setup process

This resolves documentation gaps that made LEANN difficult for non-specialists to use.

* style: Remove trailing whitespace from README.md

- Fix trailing whitespace issues found by pre-commit hooks
- Ensures consistent formatting across documentation

* docs: Simplify README by removing excessive documentation

- Remove overly complex CLI reference and getting started sections (lines 61-334)
- Remove emojis from section headers for cleaner appearance
- Keep README simple and focused as requested
- Maintain essential MCP integration documentation

This addresses feedback to keep documentation minimal and avoid auto-generated content.

* docs: Address maintainer feedback on README improvements

- Restore emojis in section headers (Prerequisites and Quick Install)
- Add MCP live data feature mention in line 23 with links to Slack and Twitter
- Add detailed API credential setup instructions for Slack:
  - Step-by-step Slack App creation process
  - Required OAuth scopes and permissions
  - Clear token identification (xoxb- vs xapp-)
- Add detailed API credential setup instructions for Twitter:
  - Twitter Developer Account application process
  - API v2 requirements for bookmarks access
  - Required permissions and scopes

This addresses maintainer feedback to make API setup more user-friendly.
2025-10-07 02:18:32 -07:00
Andy Lee
5ba9cf6442 chore: require sentence-transformers >=3 and pin transformers <4.46 2025-10-06 15:52:56 -07:00
Andy Lee
1484406a8d chore: align core deps with transformers pin 2025-10-05 19:01:58 -07:00
Andy Lee
761ec1f0ac chore: pin transformers for py39 2025-10-05 18:29:45 -07:00
Andy Lee
4808afc686 docs: point DiskANN link to public PDF 2025-10-05 17:58:57 -07:00
Jon Haddad
0bba4b2157 Add readline support to interactive command line interfaces (#121)
* Add readline support to interactive command line interfaces

- Implement readline history, navigation, and editing for CLI, API, and RAG chat modes
- Create shared InteractiveSession class to consolidate readline functionality
- Add command history persistence across sessions with separate files per context
- Support built-in commands: help, clear, history, quit/exit
- Enable arrow key navigation and command editing in all interactive modes

* Improvements based on feedback
2025-10-05 17:38:15 -07:00
Kishlay Kisu
e67b5f44fa Implement FileSystem wide semantic file search engine with temporal awareness using LEANN. (#103)
* system wide semantic file search with temporal awareness

* ruff checking passed

* graceful exit for empty dump

* error thrown for time only search

* fixes
2025-10-05 17:26:48 -07:00
Aakash Suresh
658bce47ef Feature/imessage rag support (#131) 2025-10-02 10:40:57 -07:00
Andy Lee
6b399ad8d2 fix: launch another port when updating 2025-09-30 13:00:00 -07:00
Andy Lee
16f35aa067 Update faiss for batch distances calc & caching when updating 2025-09-30 12:42:40 -07:00
Andy Lee
ab9c6bd69e Fix update. Should launch embedding server first (#130)
* fix: set ntotal for storage as well

* fix: launch embedding server before adding
2025-09-30 00:58:17 -07:00
yichuan520030910320
e2b37914ce add dynamic add test 2025-09-30 00:48:46 -07:00
Andy Lee
e588100674 fix: set ntotal for storage as well (#129) 2025-09-29 20:43:16 -07:00
Andy Lee
fecee94af1 Experiments (#68)
* feat: finance bench

* docs: results

* chore: ignroe data README

* feat: fix financebench

* feat: laion, also required idmaps support

* style: format

* style: format

* fix: resolve ruff linting errors

- Remove unused variables in benchmark scripts
- Rename unused loop variables to follow convention

* feat: enron email bench

* experiments for running DiskANN & BM25 on Arch 4090

* style: format

* chore(ci): remove paru-bin submodule and config to fix checkout --recurse-submodules

* docs: data

* docs: data updated

* fix: as package

* fix(ci): only run pre-commit

* chore: use http url of astchunk; use group for some dev deps

* fix(ci): should checkout modules as well since `uv sync` checks

* fix(ci): run with lint only

* fix: find links to install wheels available

* CI: force local wheels in uv install step

* CI: install local wheels via file paths

* CI: pick wheels matching current Python tag

* CI: handle python tag mismatches for local wheels

* CI: use matrix python venv and set macOS deployment target

* CI: revert install step to match main

* CI: use uv group install with local wheel selection

* CI: rely on setup-uv for Python and tighten group install

* CI: install build deps with uv python interpreter

* CI: use temporary uv venv for build deps

* CI: add build venv scripts path for wheel repair
2025-09-24 11:19:04 -07:00
yichuan520030910320
01475c10a0 add img 2025-09-23 23:25:05 -07:00
yichuan520030910320
c8aa063f48 merge main 2025-09-23 23:21:53 -07:00
yichuan520030910320
576beb13db add doc about multimodal 2025-09-23 23:21:03 -07:00
Andy Lee
63c7b0c8a3 Fix restart embedding server when passages change (#117)
* fix: restart embedding server when passages change

* fix: restore python 3.9 typing compatibility
2025-09-23 22:28:36 -07:00
Andy Lee
ec889f7ef4 Allow 'leann ask' to accept a positional question (#116) 2025-09-23 21:18:57 -07:00
Yi-Ting Chiu
322e5c162d docs: open ai api compatibility (#118) 2025-09-23 21:17:50 -07:00
Yichuan Wang
edde0cdeb2 [Feat] ColQwen intergration (#111)
* add colqwen stuff

* add colqwen stuff and pass ruff

* remove ipynb
2025-09-23 17:51:29 -07:00
Andy Lee
db7ba27ff6 feat: Add support for configurable local LLM endpoints (#115)
* feat: support configurable local llm endpoints

* docs
2025-09-23 15:12:13 -07:00
Andy Lee
5f7806e16f Introducing dynamic index update (#108)
* feat: Add GitHub PR and issue templates for better contributor experience

* simplify: Make templates more concise and user-friendly

* fix: enable is_compact=False, is_recompute=True

* feat: update when recompute

* test

* fix: real recompute

* refactor

* fix: compare with no-recompute

* fix: test
2025-09-21 22:56:27 -07:00
51 changed files with 12059 additions and 4310 deletions

View File

@@ -14,6 +14,6 @@ jobs:
- uses: actions/checkout@v4
- uses: lycheeverse/lychee-action@v2
with:
args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
args: --no-progress --insecure --user-agent 'curl/7.68.0' --exclude '.*api\.star-history\.com.*' --accept 200,201,202,203,204,205,206,207,208,226,300,301,302,303,304,305,306,307,308,503 README.md docs/ apps/ examples/ benchmarks/
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

3
.gitignore vendored
View File

@@ -105,3 +105,6 @@ apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weavia
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
!apps/multimodal/vision-based-pdf-multi-vector/fig/*
# AUR build directory (Arch Linux)
paru-bin/

View File

@@ -8,16 +8,32 @@
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q"><img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
<a href="assets/wechat_user_group.JPG" title="Join WeChat group"><img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group"></a>
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q">
<img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
</a>
<a href="assets/wechat_user_group.JPG" title="Join WeChat group">
<img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group">
</a>
</p>
<div align="center">
<a href="https://forms.gle/rDbZf864gMNxhpTq8">
<img src="https://img.shields.io/badge/📣_Community_Survey-Help_Shape_v0.4-007ec6?style=for-the-badge&logo=google-forms&logoColor=white" alt="Take Survey">
</a>
<p>
We track <b>zero telemetry</b>. This survey is the ONLY way to tell us if you want <br>
<b>GPU Acceleration</b> or <b>More Integrations</b> next.<br>
👉 <a href="https://forms.gle/rDbZf864gMNxhpTq8"><b>Click here to cast your vote (2 mins)</b></a>
</p>
</div>
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
The smallest vector index in the world. RAG Everything with LEANN!
</h2>
LEANN is an innovative vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#mcp-integration-rag-on-live-data-from-any-platform), [Twitter](#mcp-integration-rag-on-live-data-from-any-platform)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
@@ -777,7 +793,7 @@ Once your iMessage conversations are indexed, you can search with queries like:
### MCP Integration: RAG on Live Data from Any Platform
**NEW!** Connect to live data sources through the Model Context Protocol (MCP). LEANN now supports real-time RAG on platforms like Slack, Twitter, and more through standardized MCP servers.
Connect to live data sources through the Model Context Protocol (MCP). LEANN now supports real-time RAG on platforms like Slack, Twitter, and more through standardized MCP servers.
**Key Benefits:**
- **Live Data Access**: Fetch real-time data without manual exports
@@ -801,18 +817,17 @@ python -m apps.slack_rag \
--query "What did we decide about the product launch?"
```
**Setup Requirements:**
**📖 Comprehensive Setup Guide**: For detailed setup instructions, troubleshooting common issues (like "users cache is not ready yet"), and advanced configuration options, see our [**Slack Setup Guide**](docs/slack-setup-guide.md).
**Quick Setup:**
1. Install a Slack MCP server (e.g., `npm install -g slack-mcp-server`)
2. Create a Slack App and get API credentials:
- Go to [api.slack.com/apps](https://api.slack.com/apps) and create a new app
- Under "OAuth & Permissions", add these Bot Token Scopes: `channels:read`, `channels:history`, `groups:read`, `groups:history`, `im:read`, `im:history`, `mpim:read`, `mpim:history`
- Install the app to your workspace and copy the "Bot User OAuth Token" (starts with `xoxb-`)
- Under "App-Level Tokens", create a token with `connections:write` scope (starts with `xapp-`)
2. Create a Slack App and get API credentials (see detailed guide above)
3. Set environment variables:
```bash
export SLACK_BOT_TOKEN="xoxb-your-bot-token"
export SLACK_APP_TOKEN="xapp-your-app-token"
export SLACK_APP_TOKEN="xapp-your-app-token" # Optional
```
3. Test connection with `--test-connection` flag
4. Test connection with `--test-connection` flag
**Arguments:**
- `--mcp-server`: Command to start the Slack MCP server
@@ -820,6 +835,8 @@ python -m apps.slack_rag \
- `--channels`: Specific channels to index (optional)
- `--concatenate-conversations`: Group messages by channel (default: true)
- `--max-messages-per-channel`: Limit messages per channel (default: 100)
- `--max-retries`: Maximum retries for cache sync issues (default: 5)
- `--retry-delay`: Initial delay between retries in seconds (default: 2.0)
#### 🐦 Twitter Bookmarks: Your Personal Tweet Library
@@ -858,7 +875,7 @@ python -m apps.twitter_rag \
- `--no-tweet-content`: Exclude tweet content, only metadata
- `--no-metadata`: Exclude engagement metadata
<!-- </details> -->
</details>
<details>
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
@@ -875,7 +892,7 @@ python -m apps.twitter_rag \
- "Show me bookmarked threads about startup advice"
- "What Python tutorials did I save?"
<details>
</details>
<summary><strong>🔧 Using MCP with CLI Commands</strong></summary>
**Want to use MCP data with regular LEANN CLI?** You can combine MCP apps with CLI commands:
@@ -921,7 +938,7 @@ Want to add support for other platforms? LEANN's MCP integration is designed for
### 🚀 Claude Code Integration: Transform Your Development Workflow!
<details>
<summary><strong>NEW!! ASTAware Code Chunking</strong></summary>
<summary><strong>ASTAware Code Chunking</strong></summary>
LEANN features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript, improving code understanding compared to text-based chunking.
@@ -1208,3 +1225,7 @@ This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.ed
<p align="center">
Made with ❤️ by the Leann team
</p>
## 🤖 Explore LEANN with AI
LEANN is indexed on [DeepWiki](https://deepwiki.com/yichuan-w/LEANN), so you can ask questions to LLMs using Deep Research to explore the codebase and get help to add new features.

View File

@@ -10,9 +10,39 @@ from typing import Any
import dotenv
from leann.api import LeannBuilder, LeannChat
from leann.interactive_utils import create_rag_session
# Optional import: older PyPI builds may not include interactive_utils
try:
from leann.interactive_utils import create_rag_session
except ImportError:
def create_rag_session(app_name: str, data_description: str):
class _SimpleSession:
def run_interactive_loop(self, handler):
print(f"Interactive session for {app_name}: {data_description}")
print("Interactive mode not available in this build")
return _SimpleSession()
from leann.registry import register_project_directory
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
# Optional import: older PyPI builds may not include settings
try:
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
except ImportError:
# Minimal fallbacks if settings helpers are unavailable
import os
def resolve_ollama_host(value: str | None) -> str | None:
return value or os.getenv("LEANN_OLLAMA_HOST") or os.getenv("OLLAMA_HOST")
def resolve_openai_api_key(value: str | None) -> str | None:
return value or os.getenv("OPENAI_API_KEY")
def resolve_openai_base_url(value: str | None) -> str | None:
return value or os.getenv("OPENAI_BASE_URL")
dotenv.load_dotenv()
@@ -150,14 +180,14 @@ class BaseRAGExample(ABC):
ast_group.add_argument(
"--ast-chunk-size",
type=int,
default=512,
help="Maximum characters per AST chunk (default: 512)",
default=300,
help="Maximum CHARACTERS per AST chunk (default: 300). Final chunks may be larger due to overlap. For 512 token models: recommended 300 chars",
)
ast_group.add_argument(
"--ast-chunk-overlap",
type=int,
default=64,
help="Overlap between AST chunks (default: 64)",
help="Overlap between AST chunks in CHARACTERS (default: 64). Added to chunk size, not included in it",
)
ast_group.add_argument(
"--code-file-extensions",

View File

@@ -12,6 +12,7 @@ from pathlib import Path
try:
from leann.chunking_utils import (
CODE_EXTENSIONS,
_traditional_chunks_as_dicts,
create_ast_chunks,
create_text_chunks,
create_traditional_chunks,
@@ -25,6 +26,7 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
sys.path.insert(0, str(leann_src))
from leann.chunking_utils import (
CODE_EXTENSIONS,
_traditional_chunks_as_dicts,
create_ast_chunks,
create_text_chunks,
create_traditional_chunks,
@@ -36,6 +38,7 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
__all__ = [
"CODE_EXTENSIONS",
"_traditional_chunks_as_dicts",
"create_ast_chunks",
"create_text_chunks",
"create_traditional_chunks",

View File

@@ -1,12 +1,18 @@
from __future__ import annotations
import concurrent.futures
import json
import os
import re
import sys
from pathlib import Path
from typing import Any, Optional, cast
import numpy as np
from PIL import Image
from tqdm import tqdm
def _ensure_repo_paths_importable(current_file: str) -> None:
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
_repo_root = Path(current_file).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
@@ -16,6 +22,380 @@ def _ensure_repo_paths_importable(current_file: str) -> None:
sys.path.append(str(_leann_hnsw_pkg))
def _find_backend_module_file() -> Optional[Path]:
"""Best-effort locate the backend leann_multi_vector.py file, avoiding this file."""
this_file = Path(__file__).resolve()
candidates: list[Path] = []
# Common in-repo location
repo_root = this_file.parents[3]
candidates.append(repo_root / "packages" / "leann-backend-hnsw" / "leann_multi_vector.py")
candidates.append(
repo_root / "packages" / "leann-backend-hnsw" / "src" / "leann_multi_vector.py"
)
for cand in candidates:
try:
if cand.exists() and cand.resolve() != this_file:
return cand.resolve()
except Exception:
pass
# Fallback: scan sys.path for another leann_multi_vector.py different from this file
for p in list(sys.path):
try:
cand = Path(p) / "leann_multi_vector.py"
if cand.exists() and cand.resolve() != this_file:
return cand.resolve()
except Exception:
continue
return None
_BACKEND_LEANN_CLASS: Optional[type] = None
def _get_backend_leann_multi_vector() -> type:
"""Load backend LeannMultiVector class even if this file shadows its module name."""
global _BACKEND_LEANN_CLASS
if _BACKEND_LEANN_CLASS is not None:
return _BACKEND_LEANN_CLASS
backend_path = _find_backend_module_file()
if backend_path is None:
# Fallback to local implementation in this module
try:
cls = LeannMultiVector # type: ignore[name-defined]
_BACKEND_LEANN_CLASS = cls
return cls
except Exception as e:
raise ImportError(
"Could not locate backend 'leann_multi_vector.py' and no local implementation found. "
"Ensure the leann backend is available under packages/leann-backend-hnsw or installed."
) from e
import importlib.util
module_name = "leann_hnsw_backend_module"
spec = importlib.util.spec_from_file_location(module_name, str(backend_path))
if spec is None or spec.loader is None:
raise ImportError(f"Failed to create spec for backend module at {backend_path}")
backend_module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = backend_module
spec.loader.exec_module(backend_module) # type: ignore[assignment]
if not hasattr(backend_module, "LeannMultiVector"):
raise ImportError(f"'LeannMultiVector' not found in backend module at {backend_path}")
_BACKEND_LEANN_CLASS = backend_module.LeannMultiVector
return _BACKEND_LEANN_CLASS
def _natural_sort_key(name: str) -> int:
m = re.search(r"\d+", name)
return int(m.group()) if m else 0
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
filenames = sorted(filenames, key=_natural_sort_key)
filepaths = [os.path.join(pages_dir, n) for n in filenames]
images = [Image.open(p) for p in filepaths]
return filepaths, images
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
if not pdf_path:
return
os.makedirs(pages_dir, exist_ok=True)
try:
from pdf2image import convert_from_path
except Exception as e:
raise RuntimeError(
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
) from e
images = convert_from_path(pdf_path, dpi=dpi)
for i, image in enumerate(images):
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
def _select_device_and_dtype():
import torch
from colpali_engine.utils.torch_utils import get_torch_device
device_str = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(device_str)
# Stable dtype selection to avoid NaNs:
# - CUDA: prefer bfloat16 if supported, else float16
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
# - CPU: float32
if device_str == "cuda":
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
try:
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
except Exception:
pass
elif device_str == "mps":
dtype = torch.float32
else:
dtype = torch.float32
return device_str, device, dtype
def _load_colvision(model_choice: str):
import torch
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
device_str, device, dtype = _select_device_and_dtype()
if model_choice == "colqwen2":
model_name = "vidore/colqwen2-v1.0"
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
attn_implementation = (
"flash_attention_2"
if (device_str == "cuda" and is_flash_attn_2_available())
else "eager"
)
model = ColQwen2.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
).eval()
processor = ColQwen2Processor.from_pretrained(model_name)
else:
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
).eval()
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
return model_name, model, processor, device_str, device, dtype
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
# Ensure deterministic eval and autocast for stability
model.eval()
dataloader = DataLoader(
dataset=ListDataset[Image.Image](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
doc_vecs: list[Any] = []
for batch_doc in tqdm(dataloader, desc="Embedding images"):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_doc = model(**batch_doc)
else:
embeddings_doc = model(**batch_doc)
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
return doc_vecs
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
model.eval()
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
q_vecs: list[Any] = []
for batch_query in tqdm(dataloader, desc="Embedding queries"):
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_query = model(**batch_query)
else:
embeddings_query = model(**batch_query)
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
return q_vecs
def _build_index(
index_path: str, doc_vecs: list[Any], filepaths: list[str], images: list[Image.Image]
) -> Any:
LeannMultiVector = _get_backend_leann_multi_vector()
dim = int(doc_vecs[0].shape[-1])
retriever = LeannMultiVector(index_path=index_path, dim=dim)
retriever.create_collection()
for i, vec in enumerate(doc_vecs):
data = {
"colbert_vecs": vec.float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
"image": images[i], # Include the original image
}
retriever.insert(data)
retriever.create_index()
return retriever
def _load_retriever_if_index_exists(index_path: str) -> Optional[Any]:
LeannMultiVector = _get_backend_leann_multi_vector()
index_base = Path(index_path)
# Check for the actual HNSW index file written by the backend + our sidecar files
index_file = index_base.parent / f"{index_base.stem}.index"
meta = index_base.parent / f"{index_base.name}.meta.json"
labels = index_base.parent / f"{index_base.name}.labels.json"
if index_file.exists() and meta.exists() and labels.exists():
try:
with open(meta, encoding="utf-8") as f:
meta_json = json.load(f)
dim = int(meta_json.get("dimensions", 128))
except Exception:
dim = 128
return LeannMultiVector(index_path=index_path, dim=dim)
return None
def _generate_similarity_map(
model,
processor,
image: Image.Image,
query: str,
token_idx: Optional[int] = None,
output_path: Optional[str] = None,
) -> tuple[int, float]:
import torch
from colpali_engine.interpretability import (
get_similarity_maps_from_embeddings,
plot_similarity_map,
)
batch_images = processor.process_images([image]).to(model.device)
batch_queries = processor.process_queries([query]).to(model.device)
with torch.no_grad():
image_embeddings = model.forward(**batch_images)
query_embeddings = model.forward(**batch_queries)
n_patches = processor.get_n_patches(
image_size=image.size,
spatial_merge_size=getattr(model, "spatial_merge_size", None),
)
image_mask = processor.get_image_mask(batch_images)
batched_similarity_maps = get_similarity_maps_from_embeddings(
image_embeddings=image_embeddings,
query_embeddings=query_embeddings,
n_patches=n_patches,
image_mask=image_mask,
)
similarity_maps = batched_similarity_maps[0]
# Determine token index if not provided: choose the token with highest max score
if token_idx is None:
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
token_idx = int(per_token_max.argmax().item())
max_sim_score = similarity_maps[token_idx, :, :].max().item()
if output_path:
import matplotlib.pyplot as plt
fig, ax = plot_similarity_map(
image=image,
similarity_map=similarity_maps[token_idx],
figsize=(14, 14),
show_colorbar=False,
)
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches="tight")
plt.close(fig)
return token_idx, float(max_sim_score)
class QwenVL:
def __init__(self, device: str):
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from transformers.utils.import_utils import is_flash_attn_2_available
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct",
torch_dtype="auto",
device_map=device,
attn_implementation=attn_implementation,
)
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
self.processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
import base64
from io import BytesIO
from qwen_vl_utils import process_vision_info
content = []
for img in images:
buffer = BytesIO()
img.save(buffer, format="jpeg")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
content.append({"type": "text", "text": query})
messages = [{"role": "user", "content": content}]
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
)
inputs = inputs.to(self.model.device)
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
return self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
# Ensure repo paths are importable for dynamic backend loading
_ensure_repo_paths_importable(__file__)
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402
@@ -45,6 +425,7 @@ class LeannMultiVector:
"is_recompute": is_recompute,
}
self._labels_meta: list[dict] = []
self._docid_to_indices: dict[int, list[int]] | None = None
def _meta_dict(self) -> dict:
return {
@@ -69,6 +450,7 @@ class LeannMultiVector:
"doc_id": int(data["doc_id"]),
"filepath": data.get("filepath", ""),
"colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
"image": data.get("image"), # PIL Image object (optional)
}
)
@@ -80,6 +462,15 @@ class LeannMultiVector:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
def _embeddings_path(self) -> Path:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.emb.npy"
def _images_dir_path(self) -> Path:
"""Directory where original images are stored."""
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.images"
def create_index(self) -> None:
if not self._pending_items:
return
@@ -87,10 +478,23 @@ class LeannMultiVector:
embeddings: list[np.ndarray] = []
labels_meta: list[dict] = []
# Create images directory if needed
images_dir = self._images_dir_path()
images_dir.mkdir(parents=True, exist_ok=True)
for item in self._pending_items:
doc_id = int(item["doc_id"])
filepath = item.get("filepath", "")
colbert_vecs = item["colbert_vecs"]
image = item.get("image")
# Save image if provided
image_path = ""
if image is not None and isinstance(image, Image.Image):
image_filename = f"doc_{doc_id}.png"
image_path = str(images_dir / image_filename)
image.save(image_path, "PNG")
for seq_id, vec in enumerate(colbert_vecs):
vec_np = np.asarray(vec, dtype=np.float32)
embeddings.append(vec_np)
@@ -100,6 +504,7 @@ class LeannMultiVector:
"doc_id": doc_id,
"seq_id": int(seq_id),
"filepath": filepath,
"image_path": image_path, # Store the path to the saved image
}
)
@@ -107,7 +512,6 @@ class LeannMultiVector:
return
embeddings_np = np.vstack(embeddings).astype(np.float32)
# print shape of embeddings_np
print(embeddings_np.shape)
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
@@ -121,6 +525,9 @@ class LeannMultiVector:
with open(self._labels_path(), "w", encoding="utf-8") as f:
_json.dump(labels_meta, f)
# Persist embeddings for exact reranking
np.save(self._embeddings_path(), embeddings_np)
self._labels_meta = labels_meta
def _load_labels_meta_if_needed(self) -> None:
@@ -133,6 +540,19 @@ class LeannMultiVector:
with open(labels_path, encoding="utf-8") as f:
self._labels_meta = _json.load(f)
def _build_docid_to_indices_if_needed(self) -> None:
if self._docid_to_indices is not None:
return
self._load_labels_meta_if_needed()
mapping: dict[int, list[int]] = {}
for idx, meta in enumerate(self._labels_meta):
try:
doc_id = int(meta["doc_id"]) # type: ignore[index]
except Exception:
continue
mapping.setdefault(doc_id, []).append(idx)
self._docid_to_indices = mapping
def search(
self, data: np.ndarray, topk: int, first_stage_k: int = 50
) -> list[tuple[float, int]]:
@@ -180,3 +600,181 @@ class LeannMultiVector:
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores
def search_exact(
self,
data: np.ndarray,
topk: int,
*,
first_stage_k: int = 200,
max_workers: int = 32,
) -> list[tuple[float, int]]:
"""
High-precision MaxSim reranking over candidate documents.
Steps:
1) Run a first-stage ANN to collect candidate doc_ids (using seq-level neighbors).
2) For each candidate doc, load all its token embeddings and compute
MaxSim(query_tokens, doc_tokens) exactly: sum(max(dot(q_i, d_j))).
Returns top-k list of (score, doc_id).
"""
# Normalize inputs
if data.ndim == 1:
data = data.reshape(1, -1)
if data.dtype != np.float32:
data = data.astype(np.float32)
self._load_labels_meta_if_needed()
self._build_docid_to_indices_if_needed()
emb_path = self._embeddings_path()
if not emb_path.exists():
# Fallback to approximate if we don't have persisted embeddings
return self.search(data, topk, first_stage_k=first_stage_k)
# Memory-map embeddings to avoid loading all into RAM
all_embeddings = np.load(emb_path, mmap_mode="r")
if all_embeddings.dtype != np.float32:
all_embeddings = all_embeddings.astype(np.float32)
# First-stage ANN to collect candidate doc_ids
searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
raw = searcher.search(
data,
first_stage_k,
recompute_embeddings=False,
complexity=128,
beam_width=1,
prune_ratio=0.0,
batch_size=0,
)
labels = raw.get("labels")
if labels is None:
return []
candidate_doc_ids: set[int] = set()
for batch in labels:
for sid in batch:
try:
idx = int(sid)
except Exception:
continue
if 0 <= idx < len(self._labels_meta):
candidate_doc_ids.add(int(self._labels_meta[idx]["doc_id"])) # type: ignore[index]
# Exact scoring per doc (parallelized)
assert self._docid_to_indices is not None
def _score_one(doc_id: int) -> tuple[float, int]:
token_indices = self._docid_to_indices.get(doc_id, [])
if not token_indices:
return (0.0, doc_id)
doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
# (Q, D) x (P, D)^T -> (Q, P) then MaxSim over P, sum over Q
sim = np.dot(data, doc_vecs.T)
# nan-safe
sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
return (float(score), doc_id)
scores: list[tuple[float, int]] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids]
for fut in concurrent.futures.as_completed(futures):
scores.append(fut.result())
scores.sort(key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores
def search_exact_all(
self,
data: np.ndarray,
topk: int,
*,
max_workers: int = 32,
) -> list[tuple[float, int]]:
"""
Exact MaxSim over ALL documents (no ANN pre-filtering).
This computes, for each document, sum_i max_j dot(q_i, d_j).
It memory-maps the persisted token-embedding matrix for scalability.
"""
if data.ndim == 1:
data = data.reshape(1, -1)
if data.dtype != np.float32:
data = data.astype(np.float32)
self._load_labels_meta_if_needed()
self._build_docid_to_indices_if_needed()
emb_path = self._embeddings_path()
if not emb_path.exists():
return self.search(data, topk)
all_embeddings = np.load(emb_path, mmap_mode="r")
if all_embeddings.dtype != np.float32:
all_embeddings = all_embeddings.astype(np.float32)
assert self._docid_to_indices is not None
candidate_doc_ids = list(self._docid_to_indices.keys())
def _score_one(doc_id: int) -> tuple[float, int]:
token_indices = self._docid_to_indices.get(doc_id, [])
if not token_indices:
return (0.0, doc_id)
doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
sim = np.dot(data, doc_vecs.T)
sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
return (float(score), doc_id)
scores: list[tuple[float, int]] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
futures = [ex.submit(_score_one, d) for d in candidate_doc_ids]
for fut in concurrent.futures.as_completed(futures):
scores.append(fut.result())
scores.sort(key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores
def get_image(self, doc_id: int) -> Optional[Image.Image]:
"""
Retrieve the original image for a given doc_id from the index.
Args:
doc_id: The document ID
Returns:
PIL Image object if found, None otherwise
"""
self._load_labels_meta_if_needed()
# Find the image_path for this doc_id (all seq_ids for same doc share the same image_path)
for meta in self._labels_meta:
if meta.get("doc_id") == doc_id:
image_path = meta.get("image_path", "")
if image_path and Path(image_path).exists():
return Image.open(image_path)
break
return None
def get_metadata(self, doc_id: int) -> Optional[dict]:
"""
Retrieve metadata for a given doc_id.
Args:
doc_id: The document ID
Returns:
Dictionary with metadata (filepath, image_path, etc.) if found, None otherwise
"""
self._load_labels_meta_if_needed()
for meta in self._labels_meta:
if meta.get("doc_id") == doc_id:
return {
"doc_id": doc_id,
"filepath": meta.get("filepath", ""),
"image_path": meta.get("image_path", ""),
}
return None

View File

@@ -2,34 +2,31 @@
# %%
# uv pip install matplotlib qwen_vl_utils
import os
import re
import sys
from pathlib import Path
from typing import Any, Optional, cast
from typing import Any, Optional
from PIL import Image
from tqdm import tqdm
def _ensure_repo_paths_importable(current_file: str) -> None:
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
_repo_root = Path(current_file).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
from leann_multi_vector import ( # utility functions/classes
_ensure_repo_paths_importable,
_load_images_from_dir,
_maybe_convert_pdf_to_images,
_load_colvision,
_embed_images,
_embed_queries,
_build_index,
_load_retriever_if_index_exists,
_generate_similarity_map,
QwenVL,
)
_ensure_repo_paths_importable(__file__)
from leann_multi_vector import LeannMultiVector # noqa: E402
# %%
# Config
os.environ["TOKENIZERS_PARALLELISM"] = "false"
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
QUERY = "The paper talk about the latent video generative model and data curation in the related work part?"
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
# Data source: set to True to use the Hugging Face dataset example (recommended)
@@ -44,7 +41,7 @@ PAGES_DIR: str = "./pages"
# Index + retrieval settings
INDEX_PATH: str = "./indexes/colvision.leann"
TOPK: int = 1
TOPK: int = 3
FIRST_STAGE_K: int = 500
REBUILD_INDEX: bool = False
@@ -54,332 +51,57 @@ SIMILARITY_MAP: bool = True
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
SIM_OUTPUT: str = "./figures/similarity_map.png"
ANSWER: bool = True
MAX_NEW_TOKENS: int = 128
# %%
# Helpers
def _natural_sort_key(name: str) -> int:
m = re.search(r"\d+", name)
return int(m.group()) if m else 0
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
filenames = sorted(filenames, key=_natural_sort_key)
filepaths = [os.path.join(pages_dir, n) for n in filenames]
images = [Image.open(p) for p in filepaths]
return filepaths, images
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
if not pdf_path:
return
os.makedirs(pages_dir, exist_ok=True)
try:
from pdf2image import convert_from_path
except Exception as e:
raise RuntimeError(
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
) from e
images = convert_from_path(pdf_path, dpi=dpi)
for i, image in enumerate(images):
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
def _select_device_and_dtype():
import torch
from colpali_engine.utils.torch_utils import get_torch_device
device_str = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(device_str)
# Stable dtype selection to avoid NaNs:
# - CUDA: prefer bfloat16 if supported, else float16
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
# - CPU: float32
if device_str == "cuda":
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
try:
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
except Exception:
pass
elif device_str == "mps":
dtype = torch.float32
else:
dtype = torch.float32
return device_str, device, dtype
def _load_colvision(model_choice: str):
import torch
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
device_str, device, dtype = _select_device_and_dtype()
if model_choice == "colqwen2":
model_name = "vidore/colqwen2-v1.0"
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
attn_implementation = (
"flash_attention_2"
if (device_str == "cuda" and is_flash_attn_2_available())
else "eager"
)
model = ColQwen2.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
).eval()
processor = ColQwen2Processor.from_pretrained(model_name)
else:
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
).eval()
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
return model_name, model, processor, device_str, device, dtype
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
# Ensure deterministic eval and autocast for stability
model.eval()
dataloader = DataLoader(
dataset=ListDataset[Image.Image](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
doc_vecs: list[Any] = []
for batch_doc in tqdm(dataloader, desc="Embedding images"):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_doc = model(**batch_doc)
else:
embeddings_doc = model(**batch_doc)
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
return doc_vecs
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
model.eval()
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
q_vecs: list[Any] = []
for batch_query in tqdm(dataloader, desc="Embedding queries"):
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_query = model(**batch_query)
else:
embeddings_query = model(**batch_query)
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
return q_vecs
def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) -> LeannMultiVector:
dim = int(doc_vecs[0].shape[-1])
retriever = LeannMultiVector(index_path=index_path, dim=dim)
retriever.create_collection()
for i, vec in enumerate(doc_vecs):
data = {
"colbert_vecs": vec.float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
}
retriever.insert(data)
retriever.create_index()
return retriever
def _load_retriever_if_index_exists(index_path: str, dim: int) -> Optional[LeannMultiVector]:
index_base = Path(index_path)
# Rough heuristic: index dir exists AND meta+labels files exist
meta = index_base.parent / f"{index_base.name}.meta.json"
labels = index_base.parent / f"{index_base.name}.labels.json"
if index_base.exists() and meta.exists() and labels.exists():
return LeannMultiVector(index_path=index_path, dim=dim)
return None
def _generate_similarity_map(
model,
processor,
image: Image.Image,
query: str,
token_idx: Optional[int] = None,
output_path: Optional[str] = None,
) -> tuple[int, float]:
import torch
from colpali_engine.interpretability import (
get_similarity_maps_from_embeddings,
plot_similarity_map,
)
batch_images = processor.process_images([image]).to(model.device)
batch_queries = processor.process_queries([query]).to(model.device)
with torch.no_grad():
image_embeddings = model.forward(**batch_images)
query_embeddings = model.forward(**batch_queries)
n_patches = processor.get_n_patches(
image_size=image.size,
spatial_merge_size=getattr(model, "spatial_merge_size", None),
)
image_mask = processor.get_image_mask(batch_images)
batched_similarity_maps = get_similarity_maps_from_embeddings(
image_embeddings=image_embeddings,
query_embeddings=query_embeddings,
n_patches=n_patches,
image_mask=image_mask,
)
similarity_maps = batched_similarity_maps[0]
# Determine token index if not provided: choose the token with highest max score
if token_idx is None:
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
token_idx = int(per_token_max.argmax().item())
max_sim_score = similarity_maps[token_idx, :, :].max().item()
if output_path:
import matplotlib.pyplot as plt
fig, ax = plot_similarity_map(
image=image,
similarity_map=similarity_maps[token_idx],
figsize=(14, 14),
show_colorbar=False,
)
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches="tight")
plt.close(fig)
return token_idx, float(max_sim_score)
class QwenVL:
def __init__(self, device: str):
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from transformers.utils.import_utils import is_flash_attn_2_available
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct",
torch_dtype="auto",
device_map=device,
attn_implementation=attn_implementation,
)
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
self.processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
import base64
from io import BytesIO
from qwen_vl_utils import process_vision_info
content = []
for img in images:
buffer = BytesIO()
img.save(buffer, format="jpeg")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
content.append({"type": "text", "text": query})
messages = [{"role": "user", "content": content}]
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
)
inputs = inputs.to(self.model.device)
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
return self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
MAX_NEW_TOKENS: int = 1024
# %%
# Step 1: Prepare data
if USE_HF_DATASET:
from datasets import load_dataset
# Step 1: Check if we can skip data loading (index already exists)
retriever: Optional[Any] = None
need_to_build_index = REBUILD_INDEX
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
filepaths: list[str] = []
images: list[Image.Image] = []
for i in tqdm(range(N), desc="Loading dataset", total=N ):
p = dataset[i]
# Compose a descriptive identifier for printing later
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
print(identifier)
filepaths.append(identifier)
images.append(p["page_image"]) # PIL Image
if not REBUILD_INDEX:
retriever = _load_retriever_if_index_exists(INDEX_PATH)
if retriever is not None:
print(f"✓ Index loaded from {INDEX_PATH}")
print(f"✓ Images available at: {retriever._images_dir_path()}")
need_to_build_index = False
else:
print(f"Index not found, will build new index")
need_to_build_index = True
# Step 2: Load data only if we need to build the index
if need_to_build_index:
print("Loading dataset...")
if USE_HF_DATASET:
from datasets import load_dataset
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
filepaths: list[str] = []
images: list[Image.Image] = []
for i in tqdm(range(N), desc="Loading dataset", total=N):
p = dataset[i]
# Compose a descriptive identifier for printing later
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
filepaths.append(identifier)
images.append(p["page_image"]) # PIL Image
else:
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
filepaths, images = _load_images_from_dir(PAGES_DIR)
if not images:
raise RuntimeError(
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
)
print(f"Loaded {len(images)} images")
else:
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
filepaths, images = _load_images_from_dir(PAGES_DIR)
if not images:
raise RuntimeError(
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
)
print("Skipping dataset loading (using existing index)")
filepaths = [] # Not needed when using existing index
images = [] # Not needed when using existing index
# %%
# Step 2: Load model and processor
# Step 3: Load model and processor (only if we need to build index or perform search)
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
@@ -387,34 +109,39 @@ print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
# %%
# %%
# Step 3: Build or load index
retriever: Optional[LeannMultiVector] = None
if not REBUILD_INDEX:
try:
one_vec = _embed_images(model, processor, [images[0]])[0]
retriever = _load_retriever_if_index_exists(INDEX_PATH, dim=int(one_vec.shape[-1]))
except Exception:
retriever = None
if retriever is None:
# Step 4: Build index if needed
if need_to_build_index and retriever is None:
print("Building index...")
doc_vecs = _embed_images(model, processor, images)
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
# Clear memory
del images, filepaths, doc_vecs
# Note: Images are now stored in the index, retriever will load them on-demand from disk
# %%
# Step 4: Embed query and search
# Step 5: Embed query and search
q_vec = _embed_queries(model, processor, [QUERY])[0]
results = retriever.search(q_vec.float().numpy(), topk=TOPK, first_stage_k=FIRST_STAGE_K)
results = retriever.search(q_vec.float().numpy(), topk=TOPK)
if not results:
print("No results found.")
else:
print(f'Top {len(results)} results for query: "{QUERY}"')
top_images: list[Image.Image] = []
for rank, (score, doc_id) in enumerate(results, start=1):
path = filepaths[doc_id]
# Retrieve image from index instead of memory
image = retriever.get_image(doc_id)
if image is None:
print(f"Warning: Could not retrieve image for doc_id {doc_id}")
continue
metadata = retriever.get_metadata(doc_id)
path = metadata.get("filepath", "unknown") if metadata else "unknown"
# For HF dataset, path is a descriptive identifier, not a real file path
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
top_images.append(images[doc_id])
top_images.append(image)
if SAVE_TOP_IMAGE:
from pathlib import Path as _Path
@@ -427,12 +154,17 @@ else:
else:
out_path = base / f"retrieved_page_rank{rank}.png"
img.save(str(out_path))
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
# Print the retrieval score (document-level MaxSim) alongside the saved path
try:
score, _doc_id = results[rank - 1]
print(f"Saved retrieved page (rank {rank}) [MaxSim={score:.4f}] to: {out_path}")
except Exception:
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
## TODO stange results of second page of DeepSeek-V2 rather than the first page
# %%
# Step 5: Similarity maps for top-K results
# Step 6: Similarity maps for top-K results
if results and SIMILARITY_MAP:
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
from pathlib import Path as _Path
@@ -469,7 +201,7 @@ if results and SIMILARITY_MAP:
# %%
# Step 6: Optional answer generation
# Step 7: Optional answer generation
if results and ANSWER:
qwen = QwenVL(device=device_str)
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)

View File

@@ -7,6 +7,7 @@ for indexing in LEANN. It supports various Slack MCP server implementations and
flexible message processing options.
"""
import ast
import asyncio
import json
import logging
@@ -29,6 +30,8 @@ class SlackMCPReader:
workspace_name: Optional[str] = None,
concatenate_conversations: bool = True,
max_messages_per_conversation: int = 100,
max_retries: int = 5,
retry_delay: float = 2.0,
):
"""
Initialize the Slack MCP Reader.
@@ -38,11 +41,15 @@ class SlackMCPReader:
workspace_name: Optional workspace name to filter messages
concatenate_conversations: Whether to group messages by channel/thread
max_messages_per_conversation: Maximum messages to include per conversation
max_retries: Maximum number of retries for failed operations
retry_delay: Initial delay between retries in seconds
"""
self.mcp_server_command = mcp_server_command
self.workspace_name = workspace_name
self.concatenate_conversations = concatenate_conversations
self.max_messages_per_conversation = max_messages_per_conversation
self.max_retries = max_retries
self.retry_delay = retry_delay
self.mcp_process = None
async def start_mcp_server(self):
@@ -110,11 +117,73 @@ class SlackMCPReader:
return response.get("result", {}).get("tools", [])
def _is_cache_sync_error(self, error: dict) -> bool:
"""Check if the error is related to users cache not being ready."""
if isinstance(error, dict):
message = error.get("message", "").lower()
return (
"users cache is not ready" in message or "sync process is still running" in message
)
return False
async def _retry_with_backoff(self, func, *args, **kwargs):
"""Retry a function with exponential backoff, especially for cache sync issues."""
last_exception = None
for attempt in range(self.max_retries + 1):
try:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
# Check if this is a cache sync error
error_dict = {}
if hasattr(e, "args") and e.args and isinstance(e.args[0], dict):
error_dict = e.args[0]
elif "Failed to fetch messages" in str(e):
# Try to extract error from the exception message
import re
match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
if match:
try:
error_dict = ast.literal_eval(match.group(1))
except (ValueError, SyntaxError):
pass
else:
# Try alternative format
match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
if match:
try:
error_dict = ast.literal_eval(match.group(1))
except (ValueError, SyntaxError):
pass
if self._is_cache_sync_error(error_dict):
if attempt < self.max_retries:
delay = self.retry_delay * (2**attempt) # Exponential backoff
logger.info(
f"Cache sync not ready, waiting {delay:.1f}s before retry {attempt + 1}/{self.max_retries}"
)
await asyncio.sleep(delay)
continue
else:
logger.warning(
f"Cache sync still not ready after {self.max_retries} retries, giving up"
)
break
else:
# Not a cache sync error, don't retry
break
# If we get here, all retries failed or it's not a retryable error
raise last_exception
async def fetch_slack_messages(
self, channel: Optional[str] = None, limit: int = 100
) -> list[dict[str, Any]]:
"""
Fetch Slack messages using MCP tools.
Fetch Slack messages using MCP tools with retry logic for cache sync issues.
Args:
channel: Optional channel name to filter messages
@@ -123,32 +192,59 @@ class SlackMCPReader:
Returns:
List of message dictionaries
"""
return await self._retry_with_backoff(self._fetch_slack_messages_impl, channel, limit)
async def _fetch_slack_messages_impl(
self, channel: Optional[str] = None, limit: int = 100
) -> list[dict[str, Any]]:
"""
Internal implementation of fetch_slack_messages without retry logic.
"""
# This is a generic implementation - specific MCP servers may have different tool names
# Common tool names might be: 'get_messages', 'list_messages', 'fetch_channel_history'
tools = await self.list_available_tools()
logger.info(f"Available tools: {[tool.get('name') for tool in tools]}")
message_tool = None
# Look for a tool that can fetch messages
# Look for a tool that can fetch messages - prioritize conversations_history
message_tool = None
# First, try to find conversations_history specifically
for tool in tools:
tool_name = tool.get("name", "").lower()
if any(
keyword in tool_name
for keyword in ["message", "history", "channel", "conversation"]
):
if "conversations_history" in tool_name:
message_tool = tool
logger.info(f"Found conversations_history tool: {tool}")
break
# If not found, look for other message-fetching tools
if not message_tool:
for tool in tools:
tool_name = tool.get("name", "").lower()
if any(
keyword in tool_name
for keyword in ["conversations_search", "message", "history"]
):
message_tool = tool
break
if not message_tool:
raise RuntimeError("No message fetching tool found in MCP server")
# Prepare tool call parameters
tool_params = {"limit": limit}
tool_params = {"limit": "180d"} # Use 180 days to get older messages
if channel:
# Try common parameter names for channel specification
for param_name in ["channel", "channel_id", "channel_name"]:
tool_params[param_name] = channel
break
# For conversations_history, use channel_id parameter
if message_tool["name"] == "conversations_history":
tool_params["channel_id"] = channel
else:
# Try common parameter names for channel specification
for param_name in ["channel", "channel_id", "channel_name"]:
tool_params[param_name] = channel
break
logger.info(f"Tool parameters: {tool_params}")
fetch_request = {
"jsonrpc": "2.0",
@@ -170,8 +266,8 @@ class SlackMCPReader:
try:
messages = json.loads(content["text"])
except json.JSONDecodeError:
# If not JSON, treat as plain text
messages = [{"text": content["text"], "channel": channel or "unknown"}]
# If not JSON, try to parse as CSV format (Slack MCP server format)
messages = self._parse_csv_messages(content["text"], channel)
else:
messages = result["content"]
else:
@@ -180,6 +276,56 @@ class SlackMCPReader:
return messages if isinstance(messages, list) else [messages]
def _parse_csv_messages(self, csv_text: str, channel: str) -> list[dict[str, Any]]:
"""Parse CSV format messages from Slack MCP server."""
import csv
import io
messages = []
try:
# Split by lines and process each line as a CSV row
lines = csv_text.strip().split("\n")
if not lines:
return messages
# Skip header line if it exists
start_idx = 0
if lines[0].startswith("MsgID,UserID,UserName"):
start_idx = 1
for line in lines[start_idx:]:
if not line.strip():
continue
# Parse CSV line
reader = csv.reader(io.StringIO(line))
try:
row = next(reader)
if len(row) >= 7: # Ensure we have enough columns
message = {
"ts": row[0],
"user": row[1],
"username": row[2],
"real_name": row[3],
"channel": row[4],
"thread_ts": row[5],
"text": row[6],
"time": row[7] if len(row) > 7 else "",
"reactions": row[8] if len(row) > 8 else "",
"cursor": row[9] if len(row) > 9 else "",
}
messages.append(message)
except Exception as e:
logger.warning(f"Failed to parse CSV line: {line[:100]}... Error: {e}")
continue
except Exception as e:
logger.warning(f"Failed to parse CSV messages: {e}")
# Fallback: treat entire text as one message
messages = [{"text": csv_text, "channel": channel or "unknown"}]
return messages
def _format_message(self, message: dict[str, Any]) -> str:
"""Format a single message for indexing."""
text = message.get("text", "")
@@ -251,6 +397,40 @@ class SlackMCPReader:
return "\n".join(content_parts)
async def get_all_channels(self) -> list[str]:
"""Get list of all available channels."""
try:
channels_list_request = {
"jsonrpc": "2.0",
"id": 4,
"method": "tools/call",
"params": {"name": "channels_list", "arguments": {}},
}
channels_response = await self.send_mcp_request(channels_list_request)
if "result" in channels_response:
result = channels_response["result"]
if "content" in result and isinstance(result["content"], list):
content = result["content"][0] if result["content"] else {}
if "text" in content:
# Parse the channels from the response
channels = []
lines = content["text"].split("\n")
for line in lines:
if line.strip() and ("#" in line or "C" in line[:10]):
# Extract channel ID or name
parts = line.split()
for part in parts:
if part.startswith("C") and len(part) > 5:
channels.append(part)
elif part.startswith("#"):
channels.append(part[1:]) # Remove #
logger.info(f"Found {len(channels)} channels: {channels}")
return channels
return []
except Exception as e:
logger.warning(f"Failed to get channels list: {e}")
return []
async def read_slack_data(self, channels: Optional[list[str]] = None) -> list[str]:
"""
Read Slack data and return formatted text chunks.
@@ -287,36 +467,33 @@ class SlackMCPReader:
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
continue
else:
# Fetch from all available channels/conversations
# This is a simplified approach - real implementation would need to
# discover available channels first
try:
messages = await self.fetch_slack_messages(limit=1000)
if messages:
# Group messages by channel if concatenating
if self.concatenate_conversations:
channel_messages = {}
for message in messages:
channel = message.get(
"channel", message.get("channel_name", "general")
)
if channel not in channel_messages:
channel_messages[channel] = []
channel_messages[channel].append(message)
# Fetch from all available channels
logger.info("Fetching from all available channels...")
all_channels = await self.get_all_channels()
# Create concatenated content for each channel
for channel, msgs in channel_messages.items():
text_content = self._create_concatenated_content(msgs, channel)
if not all_channels:
# Fallback to common channel names if we can't get the list
all_channels = ["general", "random", "announcements", "C0GN5BX0F"]
logger.info(f"Using fallback channels: {all_channels}")
for channel in all_channels:
try:
logger.info(f"Searching channel: {channel}")
messages = await self.fetch_slack_messages(channel=channel, limit=1000)
if messages:
if self.concatenate_conversations:
text_content = self._create_concatenated_content(messages, channel)
if text_content.strip():
all_texts.append(text_content)
else:
# Process individual messages
for message in messages:
formatted_msg = self._format_message(message)
if formatted_msg.strip():
all_texts.append(formatted_msg)
except Exception as e:
logger.error(f"Failed to fetch messages: {e}")
else:
# Process individual messages
for message in messages:
formatted_msg = self._format_message(message)
if formatted_msg.strip():
all_texts.append(formatted_msg)
except Exception as e:
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
continue
return all_texts

View File

@@ -78,6 +78,20 @@ class SlackMCPRAG(BaseRAGExample):
help="Test MCP server connection and list available tools without indexing",
)
parser.add_argument(
"--max-retries",
type=int,
default=5,
help="Maximum number of retries for failed operations (default: 5)",
)
parser.add_argument(
"--retry-delay",
type=float,
default=2.0,
help="Initial delay between retries in seconds (default: 2.0)",
)
async def test_mcp_connection(self, args) -> bool:
"""Test the MCP server connection and display available tools."""
print(f"Testing connection to MCP server: {args.mcp_server}")
@@ -88,12 +102,14 @@ class SlackMCPRAG(BaseRAGExample):
workspace_name=args.workspace_name,
concatenate_conversations=not args.no_concatenate_conversations,
max_messages_per_conversation=args.max_messages_per_channel,
max_retries=args.max_retries,
retry_delay=args.retry_delay,
)
async with reader:
tools = await reader.list_available_tools()
print("\nSuccessfully connected to MCP server!")
print("Successfully connected to MCP server!")
print(f"Available tools ({len(tools)}):")
for i, tool in enumerate(tools, 1):
@@ -115,7 +131,7 @@ class SlackMCPRAG(BaseRAGExample):
return True
except Exception as e:
print(f"\nFailed to connect to MCP server: {e}")
print(f"Failed to connect to MCP server: {e}")
print("\nTroubleshooting tips:")
print("1. Make sure the MCP server is installed and accessible")
print("2. Check if the server command is correct")
@@ -130,8 +146,11 @@ class SlackMCPRAG(BaseRAGExample):
if args.workspace_name:
print(f"Workspace: {args.workspace_name}")
if args.channels:
print(f"Channels: {', '.join(args.channels)}")
# Filter out empty strings from channels
channels = [ch for ch in args.channels if ch.strip()] if args.channels else None
if channels:
print(f"Channels: {', '.join(channels)}")
else:
print("Fetching from all available channels")
@@ -146,18 +165,20 @@ class SlackMCPRAG(BaseRAGExample):
workspace_name=args.workspace_name,
concatenate_conversations=concatenate,
max_messages_per_conversation=args.max_messages_per_channel,
max_retries=args.max_retries,
retry_delay=args.retry_delay,
)
texts = await reader.read_slack_data(channels=args.channels)
texts = await reader.read_slack_data(channels=channels)
if not texts:
print("No messages found! This could mean:")
print("No messages found! This could mean:")
print("- The MCP server couldn't fetch messages")
print("- The specified channels don't exist or are empty")
print("- Authentication issues with the Slack workspace")
return []
print(f"Successfully loaded {len(texts)} text chunks from Slack")
print(f"Successfully loaded {len(texts)} text chunks from Slack")
# Show sample of what was loaded
if texts:
@@ -170,7 +191,7 @@ class SlackMCPRAG(BaseRAGExample):
return texts
except Exception as e:
print(f"Error loading Slack data: {e}")
print(f"Error loading Slack data: {e}")
print("\nThis might be due to:")
print("- MCP server connection issues")
print("- Authentication problems")
@@ -188,7 +209,7 @@ class SlackMCPRAG(BaseRAGExample):
if not success:
return
print(
"\n🎉 MCP server is working! You can now run without --test-connection to start indexing."
"MCP server is working! You can now run without --test-connection to start indexing."
)
return

143
benchmarks/update/README.md Normal file
View File

@@ -0,0 +1,143 @@
# Update Benchmarks
This directory hosts two benchmark suites that exercise LEANNs HNSW “update +
search” pipeline under different assumptions:
1. **RNG recompute latency** measure how random-neighbour pruning and cache
settings influence incremental `add()` latency when embeddings are fetched
over the ZMQ embedding server.
2. **Update strategy comparison** compare a fully sequential update pipeline
against an offline approach that keeps the graph static and fuses results.
Both suites build a non-compact, `is_recompute=True` index so that new
embeddings are pulled from the embedding server. Benchmark outputs are written
under `.leann/bench/` by default and appended to CSV files for later plotting.
## Benchmarks
### 1. HNSW RNG Recompute Benchmark
`bench_hnsw_rng_recompute.py` evaluates incremental update latency under four
random-neighbour (RNG) configurations. Each scenario uses the same dataset but
changes the forward / reverse RNG pruning flags and whether the embedding cache
is enabled:
| Scenario name | Forward RNG | Reverse RNG | ZMQ embedding cache |
| ---------------------------------- | ----------- | ----------- | ------------------- |
| `baseline` | Enabled | Enabled | Enabled |
| `no_cache_baseline` | Enabled | Enabled | **Disabled** |
| `disable_forward_rng` | **Disabled**| Enabled | Enabled |
| `disable_forward_and_reverse_rng` | **Disabled**| **Disabled**| Enabled |
For each scenario the script:
1. (Re)builds a `is_recompute=True` index and writes it to `.leann/bench/`.
2. Starts `leann_backend_hnsw.hnsw_embedding_server` for remote embeddings.
3. Appends the requested updates using the scenarios RNG flags.
4. Records total time, latency per passage, ZMQ fetch counts, and stage-level
timings before appending a row to the CSV output.
**Run:**
```bash
LEANN_HNSW_LOG_PATH=.leann/bench/hnsw_server.log \
LEANN_LOG_LEVEL=INFO \
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
--runs 1 \
--index-path .leann/bench/test.leann \
--initial-files data/PrideandPrejudice.txt \
--update-files data/huawei_pangu.md \
--max-initial 300 \
--max-updates 1 \
--add-timeout 120
```
**Output:**
- `benchmarks/update/bench_results.csv` per-scenario timing statistics
(including ms/passage) for each run.
- `.leann/bench/hnsw_server.log` detailed ZMQ/server logs (path controlled by
`LEANN_HNSW_LOG_PATH`).
_The reference CSVs checked into this branch were generated on a workstation with an NVIDIA RTX 4090 GPU; throughput numbers will differ on other hardware._
### 2. Sequential vs. Offline Update Benchmark
`bench_update_vs_offline_search.py` compares two end-to-end strategies on the
same dataset:
- **Scenario A Sequential Update**
- Start an embedding server.
- Sequentially call `index.add()`; each call fetches embeddings via ZMQ and
mutates the HNSW graph.
- After all inserts, run a search on the updated graph.
- Metrics recorded: update time (`add_total_s`), post-update search time
(`search_time_s`), combined total (`total_time_s`), and per-passage
latency.
- **Scenario B Offline Embedding + Concurrent Search**
- Stop Scenario As server and start a fresh embedding server.
- Spawn two threads: one generates embeddings for the new passages offline
(graph unchanged); the other computes the query embedding and searches the
existing graph.
- Merge offline similarities with the graph search results to emulate late
fusion, then report the merged topk preview.
- Metrics recorded: embedding time (`emb_time_s`), search time
(`search_time_s`), concurrent makespan (`makespan_s`), and scenario total.
**Run (both scenarios):**
```bash
uv run -m benchmarks.update.bench_update_vs_offline_search \
--index-path .leann/bench/offline_vs_update.leann \
--max-initial 300 \
--num-updates 1
```
You can pass `--only A` or `--only B` to run a single scenario. The script will
print timing summaries to stdout and append the results to CSV.
**Output:**
- `benchmarks/update/offline_vs_update.csv` per-scenario timing statistics for
Scenario A and B.
- Console output includes Scenario Bs merged topk preview for quick sanity
checks.
_The sample results committed here come from runs on an RTX 4090-equipped machine; expect variations if you benchmark on different GPUs._
### 3. Visualisation
`plot_bench_results.py` combines the RNG benchmark and the update strategy
benchmark into a single two-panel plot.
**Run:**
```bash
uv run -m benchmarks.update.plot_bench_results \
--csv benchmarks/update/bench_results.csv \
--csv-right benchmarks/update/offline_vs_update.csv \
--out benchmarks/update/bench_latency_from_csv.png
```
**Options:**
- `--broken-y` Enable a broken Y-axis (default: true when appropriate).
- `--csv` RNG benchmark results CSV (left panel).
- `--csv-right` Update strategy results CSV (right panel).
- `--out` Output image path (PNG/PDF supported).
**Output:**
- `benchmarks/update/bench_latency_from_csv.png` visual comparison of the two
suites.
- `benchmarks/update/bench_latency_from_csv.pdf` PDF version, suitable for
slides/papers.
## Parameters & Environment
### Common CLI Flags
- `--max-initial` Number of initial passages used to seed the index.
- `--max-updates` / `--num-updates` Number of passages to treat as updates.
- `--index-path` Base path (without extension) where the LEANN index is stored.
- `--runs` Number of repetitions (RNG benchmark only).
### Environment Variables
- `LEANN_HNSW_LOG_PATH` File to receive embedding-server logs (optional).
- `LEANN_LOG_LEVEL` Logging verbosity (DEBUG/INFO/WARNING/ERROR).
- `CUDA_VISIBLE_DEVICES` Set to empty string if you want to force CPU
execution of the embedding model.
With these scripts you can easily replicate LEANNs update benchmarks, compare
multiple RNG strategies, and evaluate whether sequential updates or offline
fusion better match your latency/accuracy trade-offs.

View File

@@ -0,0 +1,16 @@
"""Benchmarks for LEANN update workflows."""
# Expose helper to locate repository root for other modules that need it.
from pathlib import Path
def find_repo_root() -> Path:
"""Return the project root containing pyproject.toml."""
current = Path(__file__).resolve()
for parent in current.parents:
if (parent / "pyproject.toml").exists():
return parent
return current.parents[1]
__all__ = ["find_repo_root"]

View File

@@ -0,0 +1,804 @@
"""Benchmark incremental HNSW add() under different RNG pruning modes with real
embedding recomputation.
This script clones the structure of ``examples/dynamic_update_no_recompute.py``
so that we build a non-compact ``is_recompute=True`` index, spin up the
standard HNSW embedding server, and measure how long incremental ``add`` takes
when RNG pruning is fully enabled vs. partially/fully disabled.
Example usage (run from the repo root; downloads the model on first run)::
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
--index-path .leann/bench/leann-demo.leann \
--runs 1
You can tweak the input documents with ``--initial-files`` / ``--update-files``
if you want a larger or different workload, and change the embedding model via
``--model-name``.
"""
import argparse
import json
import logging
import os
import pickle
import re
import sys
import time
from pathlib import Path
from typing import Any
import msgpack
import numpy as np
import zmq
from leann.api import LeannBuilder
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
from leann.embedding_compute import compute_embeddings
from leann.embedding_server_manager import EmbeddingServerManager
from leann.registry import register_project_directory
from leann_backend_hnsw import faiss # type: ignore
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
logger = logging.getLogger(__name__)
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO)
def _find_repo_root() -> Path:
"""Locate project root by walking up until pyproject.toml is found."""
current = Path(__file__).resolve()
for parent in current.parents:
if (parent / "pyproject.toml").exists():
return parent
# Fallback: assume repo is two levels up (../..)
return current.parents[2]
REPO_ROOT = _find_repo_root()
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from apps.chunking import create_text_chunks # noqa: E402
DEFAULT_INITIAL_FILES = [
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
REPO_ROOT / "data" / "huawei_pangu.md",
]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
DEFAULT_HNSW_LOG = Path(".leann/bench/hnsw_server.log")
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
from llama_index.core import SimpleDirectoryReader
documents = []
for path in paths:
p = path.expanduser().resolve()
if not p.exists():
raise FileNotFoundError(f"Input path not found: {p}")
if p.is_dir():
reader = SimpleDirectoryReader(str(p), recursive=False)
documents.extend(reader.load_data(show_progress=True))
else:
reader = SimpleDirectoryReader(input_files=[str(p)])
documents.extend(reader.load_data(show_progress=True))
if not documents:
return []
chunks = create_text_chunks(
documents,
chunk_size=512,
chunk_overlap=128,
use_ast_chunking=False,
)
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
if limit is not None:
cleaned = cleaned[:limit]
return cleaned
def ensure_index_dir(index_path: Path) -> None:
index_path.parent.mkdir(parents=True, exist_ok=True)
def cleanup_index_files(index_path: Path) -> None:
parent = index_path.parent
if not parent.exists():
return
stem = index_path.stem
for file in parent.glob(f"{stem}*"):
if file.is_file():
file.unlink()
def build_initial_index(
index_path: Path,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
distance_metric: str,
ef_construction: int,
) -> None:
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=True,
distance_metric=distance_metric,
backend_kwargs={
"distance_metric": distance_metric,
"is_compact": False,
"is_recompute": True,
"efConstruction": ef_construction,
},
)
for idx, passage in enumerate(paragraphs):
builder.add_text(passage, metadata={"id": str(idx)})
builder.build_index(str(index_path))
def prepare_new_chunks(paragraphs: list[str]) -> list[dict[str, Any]]:
return [{"text": text, "metadata": {}} for text in paragraphs]
def benchmark_update_with_mode(
index_path: Path,
new_chunks: list[dict[str, Any]],
model_name: str,
embedding_mode: str,
distance_metric: str,
disable_forward_rng: bool,
disable_reverse_rng: bool,
server_port: int,
add_timeout: int,
ef_construction: int,
) -> tuple[float, float]:
meta_path = index_path.parent / f"{index_path.name}.meta.json"
passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
offset_file = index_path.parent / f"{index_path.name}.passages.idx"
index_file = index_path.parent / f"{index_path.stem}.index"
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
with open(offset_file, "rb") as f:
offset_map: dict[str, int] = pickle.load(f)
existing_ids = set(offset_map.keys())
valid_chunks: list[dict[str, Any]] = []
for chunk in new_chunks:
text = chunk.get("text", "")
if not isinstance(text, str) or not text.strip():
continue
metadata = chunk.setdefault("metadata", {})
passage_id = chunk.get("id") or metadata.get("id")
if passage_id and passage_id in existing_ids:
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
valid_chunks.append(chunk)
if not valid_chunks:
raise ValueError("No valid chunks to append.")
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
embeddings = compute_embeddings(
texts_to_embed,
model_name,
mode=embedding_mode,
is_build=False,
batch_size=16,
)
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
if distance_metric == "cosine":
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norms[norms == 0] = 1
embeddings = embeddings / norms
index = faiss.read_index(str(index_file))
index.is_recompute = True
if getattr(index, "storage", None) is None:
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
storage_index = faiss.IndexFlatIP(index.d)
else:
storage_index = faiss.IndexFlatL2(index.d)
index.storage = storage_index
index.own_fields = True
try:
storage_index.ntotal = index.ntotal
except AttributeError:
pass
try:
index.hnsw.set_disable_rng_during_add(disable_forward_rng)
index.hnsw.set_disable_reverse_prune(disable_reverse_rng)
if ef_construction is not None:
index.hnsw.efConstruction = ef_construction
except AttributeError:
pass
applied_forward = getattr(index.hnsw, "disable_rng_during_add", None)
applied_reverse = getattr(index.hnsw, "disable_reverse_prune", None)
logger.info(
"HNSW RNG config -> requested forward=%s, reverse=%s | applied forward=%s, reverse=%s",
disable_forward_rng,
disable_reverse_rng,
applied_forward,
applied_reverse,
)
base_id = index.ntotal
for offset, chunk in enumerate(valid_chunks):
new_id = str(base_id + offset)
chunk.setdefault("metadata", {})["id"] = new_id
chunk["id"] = new_id
rollback_size = passages_file.stat().st_size if passages_file.exists() else 0
offset_map_backup = offset_map.copy()
try:
with open(passages_file, "a", encoding="utf-8") as f:
for chunk in valid_chunks:
offset = f.tell()
json.dump(
{
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk.get("metadata", {}),
},
f,
ensure_ascii=False,
)
f.write("\n")
offset_map[chunk["id"]] = offset
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
server_started, actual_port = server_manager.start_server(
port=server_port,
model_name=model_name,
embedding_mode=embedding_mode,
passages_file=str(meta_path),
distance_metric=distance_metric,
)
if not server_started:
raise RuntimeError("Failed to start embedding server.")
if hasattr(index.hnsw, "set_zmq_port"):
index.hnsw.set_zmq_port(actual_port)
elif hasattr(index, "set_zmq_port"):
index.set_zmq_port(actual_port)
_warmup_embedding_server(actual_port)
total_start = time.time()
add_elapsed = 0.0
try:
import signal
def _timeout_handler(signum, frame):
raise TimeoutError("incremental add timed out")
if add_timeout > 0:
signal.signal(signal.SIGALRM, _timeout_handler)
signal.alarm(add_timeout)
add_start = time.time()
for i in range(embeddings.shape[0]):
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
add_elapsed = time.time() - add_start
if add_timeout > 0:
signal.alarm(0)
faiss.write_index(index, str(index_file))
finally:
server_manager.stop_server()
except TimeoutError:
raise
except Exception:
if passages_file.exists():
with open(passages_file, "rb+") as f:
f.truncate(rollback_size)
with open(offset_file, "wb") as f:
pickle.dump(offset_map_backup, f)
raise
prune_hnsw_embeddings_inplace(str(index_file))
meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
# Reset toggles so the index on disk returns to baseline behaviour.
try:
index.hnsw.set_disable_rng_during_add(False)
index.hnsw.set_disable_reverse_prune(False)
except AttributeError:
pass
faiss.write_index(index, str(index_file))
total_elapsed = time.time() - total_start
return total_elapsed, add_elapsed
def _total_zmq_nodes(log_path: Path) -> int:
if not log_path.exists():
return 0
with log_path.open("r", encoding="utf-8") as log_file:
text = log_file.read()
return sum(int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", text))
def _warmup_embedding_server(port: int) -> None:
"""Send a dummy REQ so the embedding server loads its model."""
ctx = zmq.Context()
try:
sock = ctx.socket(zmq.REQ)
sock.setsockopt(zmq.LINGER, 0)
sock.setsockopt(zmq.RCVTIMEO, 5000)
sock.setsockopt(zmq.SNDTIMEO, 5000)
sock.connect(f"tcp://127.0.0.1:{port}")
payload = msgpack.packb(["__WARMUP__"], use_bin_type=True)
sock.send(payload)
try:
sock.recv()
except zmq.error.Again:
pass
finally:
sock.close()
ctx.term()
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--index-path",
type=Path,
default=Path(".leann/bench/leann-demo.leann"),
help="Output index base path (without extension).",
)
parser.add_argument(
"--initial-files",
nargs="*",
type=Path,
default=DEFAULT_INITIAL_FILES,
help="Files used to build the initial index.",
)
parser.add_argument(
"--update-files",
nargs="*",
type=Path,
default=DEFAULT_UPDATE_FILES,
help="Files appended during the benchmark.",
)
parser.add_argument(
"--runs", type=int, default=1, help="How many times to repeat each scenario."
)
parser.add_argument(
"--model-name",
default="sentence-transformers/all-MiniLM-L6-v2",
help="Embedding model used for build/update.",
)
parser.add_argument(
"--embedding-mode",
default="sentence-transformers",
help="Embedding mode passed to LeannBuilder/embedding server.",
)
parser.add_argument(
"--distance-metric",
default="mips",
choices=["mips", "l2", "cosine"],
help="Distance metric for HNSW backend.",
)
parser.add_argument(
"--ef-construction",
type=int,
default=200,
help="efConstruction setting for initial build.",
)
parser.add_argument(
"--server-port",
type=int,
default=5557,
help="Port for the real embedding server.",
)
parser.add_argument(
"--max-initial",
type=int,
default=300,
help="Optional cap on initial passages (after chunking).",
)
parser.add_argument(
"--max-updates",
type=int,
default=1,
help="Optional cap on update passages (after chunking).",
)
parser.add_argument(
"--add-timeout",
type=int,
default=900,
help="Timeout in seconds for the incremental add loop (0 = no timeout).",
)
parser.add_argument(
"--plot-path",
type=Path,
default=Path("bench_latency.png"),
help="Where to save the latency bar plot.",
)
parser.add_argument(
"--cap-y",
type=float,
default=None,
help="Cap Y-axis (ms). Bars above are hatched and annotated.",
)
parser.add_argument(
"--broken-y",
action="store_true",
help="Use broken Y-axis (two stacked axes with gap). Overrides --cap-y unless both provided.",
)
parser.add_argument(
"--lower-cap-y",
type=float,
default=None,
help="Lower axes upper bound for broken Y (ms). Default=1.1x second-highest.",
)
parser.add_argument(
"--upper-start-y",
type=float,
default=None,
help="Upper axes lower bound for broken Y (ms). Default=1.2x second-highest.",
)
parser.add_argument(
"--csv-path",
type=Path,
default=Path("benchmarks/update/bench_results.csv"),
help="Where to append per-scenario results as CSV.",
)
args = parser.parse_args()
register_project_directory(REPO_ROOT)
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
update_paragraphs = load_chunks_from_files(args.update_files, args.max_updates)
if not update_paragraphs:
raise ValueError("No update passages found; please provide --update-files with content.")
update_chunks = prepare_new_chunks(update_paragraphs)
ensure_index_dir(args.index_path)
scenarios = [
("baseline", False, False, True),
("no_cache_baseline", False, False, False),
("disable_forward_rng", True, False, True),
("disable_forward_and_reverse_rng", True, True, True),
]
log_path = Path(os.environ.get("LEANN_HNSW_LOG_PATH", DEFAULT_HNSW_LOG))
log_path.parent.mkdir(parents=True, exist_ok=True)
os.environ["LEANN_HNSW_LOG_PATH"] = str(log_path.resolve())
os.environ.setdefault("LEANN_LOG_LEVEL", "INFO")
results_total: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_add: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_zmq: dict[str, list[int]] = {name: [] for name, *_ in scenarios}
results_stageA: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_stageBC: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_ms_per_passage: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
# CSV setup
import csv
run_id = time.strftime("%Y%m%d-%H%M%S")
csv_fields = [
"run_id",
"scenario",
"cache_enabled",
"ef_construction",
"max_initial",
"max_updates",
"total_time_s",
"add_only_s",
"latency_ms_per_passage",
"zmq_nodes",
"stageA_time_s",
"stageBC_time_s",
"model_name",
"embedding_mode",
"distance_metric",
]
# Create CSV with header if missing
if args.csv_path:
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writeheader()
for run in range(args.runs):
print(f"\n=== Benchmark run {run + 1}/{args.runs} ===")
for name, disable_forward, disable_reverse, cache_enabled in scenarios:
print(f"\nScenario: {name}")
cleanup_index_files(args.index_path)
if log_path.exists():
try:
log_path.unlink()
except OSError:
pass
os.environ["LEANN_ZMQ_EMBED_CACHE"] = "1" if cache_enabled else "0"
build_initial_index(
args.index_path,
initial_paragraphs,
args.model_name,
args.embedding_mode,
args.distance_metric,
args.ef_construction,
)
prev_size = log_path.stat().st_size if log_path.exists() else 0
try:
total_elapsed, add_elapsed = benchmark_update_with_mode(
args.index_path,
update_chunks,
args.model_name,
args.embedding_mode,
args.distance_metric,
disable_forward,
disable_reverse,
args.server_port,
args.add_timeout,
args.ef_construction,
)
except TimeoutError as exc:
print(f"Scenario {name} timed out: {exc}")
continue
curr_size = log_path.stat().st_size if log_path.exists() else 0
if curr_size < prev_size:
prev_size = 0
zmq_count = 0
if log_path.exists():
with log_path.open("r", encoding="utf-8") as log_file:
log_file.seek(prev_size)
new_entries = log_file.read()
zmq_count = sum(
int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", new_entries)
)
stageA = sum(
float(x)
for x in re.findall(r"Distance calculation E2E time: ([0-9.]+)s", new_entries)
)
stageBC = sum(
float(x) for x in re.findall(r"ZMQ E2E time: ([0-9.]+)s", new_entries)
)
else:
stageA = 0.0
stageBC = 0.0
per_chunk = add_elapsed / len(update_chunks)
print(
f"Total time: {total_elapsed:.3f} s | add-only: {add_elapsed:.3f} s "
f"for {len(update_chunks)} passages => {per_chunk * 1e3:.3f} ms/passage"
)
print(f"ZMQ node fetch total: {zmq_count}")
results_total[name].append(total_elapsed)
results_add[name].append(add_elapsed)
results_zmq[name].append(zmq_count)
results_ms_per_passage[name].append(per_chunk * 1e3)
results_stageA[name].append(stageA)
results_stageBC[name].append(stageBC)
# Append row to CSV
if args.csv_path:
row = {
"run_id": run_id,
"scenario": name,
"cache_enabled": 1 if cache_enabled else 0,
"ef_construction": args.ef_construction,
"max_initial": args.max_initial,
"max_updates": args.max_updates,
"total_time_s": round(total_elapsed, 6),
"add_only_s": round(add_elapsed, 6),
"latency_ms_per_passage": round(per_chunk * 1e3, 6),
"zmq_nodes": int(zmq_count),
"stageA_time_s": round(stageA, 6),
"stageBC_time_s": round(stageBC, 6),
"model_name": args.model_name,
"embedding_mode": args.embedding_mode,
"distance_metric": args.distance_metric,
}
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writerow(row)
print("\n=== Summary ===")
for name in results_add:
add_values = results_add[name]
total_values = results_total[name]
zmq_values = results_zmq[name]
latency_values = results_ms_per_passage[name]
if not add_values:
print(f"{name}: no successful runs")
continue
avg_add = sum(add_values) / len(add_values)
avg_total = sum(total_values) / len(total_values)
avg_zmq = sum(zmq_values) / len(zmq_values) if zmq_values else 0.0
avg_latency = sum(latency_values) / len(latency_values) if latency_values else 0.0
runs = len(add_values)
print(
f"{name}: add-only avg {avg_add:.3f} s | total avg {avg_total:.3f} s "
f"| ZMQ avg {avg_zmq:.1f} node fetches | latency {avg_latency:.2f} ms/passage over {runs} run(s)"
)
if args.plot_path:
try:
import matplotlib.pyplot as plt
labels = [name for name, *_ in scenarios]
values = [
sum(results_ms_per_passage[name]) / len(results_ms_per_passage[name])
if results_ms_per_passage[name]
else 0.0
for name in labels
]
def _auto_cap(vals: list[float]) -> float | None:
s = sorted(vals, reverse=True)
if len(s) < 2:
return None
if s[1] > 0 and s[0] >= 2.5 * s[1]:
return s[1] * 1.1
return None
def _fmt_ms(v: float) -> str:
return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.1f}"
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
if args.broken_y:
s = sorted(values, reverse=True)
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
upper_start = (
args.upper_start_y
if args.upper_start_y is not None
else max(second * 1.2, lower_cap * 1.02)
)
ymax = max(values) * 1.10 if values else 1.0
fig, (ax_top, ax_bottom) = plt.subplots(
2,
1,
sharex=True,
figsize=(7.4, 5.0),
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.05},
)
x = list(range(len(labels)))
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_bottom.set_ylim(0, lower_cap)
ax_top.set_ylim(upper_start, ymax)
for i, v in enumerate(values):
if v <= lower_cap:
ax_bottom.text(
i,
v + lower_cap * 0.02,
_fmt_ms(v),
ha="center",
va="bottom",
fontsize=9,
)
else:
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
ax_top.spines["bottom"].set_visible(False)
ax_bottom.spines["top"].set_visible(False)
ax_top.tick_params(labeltop=False)
ax_bottom.xaxis.tick_bottom()
d = 0.015
kwargs = {"transform": ax_top.transAxes, "color": "k", "clip_on": False}
ax_top.plot((-d, +d), (-d, +d), **kwargs)
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
kwargs.update({"transform": ax_bottom.transAxes})
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
ax_bottom.set_xticks(range(len(labels)))
ax_bottom.set_xticklabels(labels)
ax = ax_bottom
else:
cap = args.cap_y or _auto_cap(values)
plt.figure(figsize=(7.2, 4.2))
ax = plt.gca()
if cap is not None:
show_vals = [min(v, cap) for v in values]
bars = []
for i, (v, show) in enumerate(zip(values, show_vals)):
b = ax.bar(i, show, color=colors[i], width=0.8)
bars.append(b[0])
if v > cap:
bars[-1].set_hatch("//")
ax.text(i, cap * 1.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
else:
ax.text(
i,
show + max(1.0, 0.01 * (cap or show)),
_fmt_ms(v),
ha="center",
va="bottom",
fontsize=9,
)
ax.set_ylim(0, cap * 1.10)
ax.plot(
[0.02 - 0.02, 0.02 + 0.02],
[0.98 + 0.02, 0.98 - 0.02],
transform=ax.transAxes,
color="k",
lw=1,
)
ax.plot(
[0.98 - 0.02, 0.98 + 0.02],
[0.98 + 0.02, 0.98 - 0.02],
transform=ax.transAxes,
color="k",
lw=1,
)
if any(v > cap for v in values):
ax.legend(
[bars[0]], ["capped"], fontsize=8, frameon=False, loc="upper right"
)
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels)
else:
ax.bar(labels, values, color=colors[: len(labels)])
for idx, val in enumerate(values):
ax.text(idx, val + 1.0, f"{val:.1f}", ha="center", va="bottom")
plt.ylabel("Average add latency (ms per passage)")
plt.title(f"Initial passages {args.max_initial}, updates {args.max_updates}")
plt.tight_layout()
plt.savefig(args.plot_path)
print(f"Saved latency bar plot to {args.plot_path}")
# ZMQ time split (Stage A vs B/C)
try:
plt.figure(figsize=(6, 4))
a_vals = [sum(results_stageA[n]) / max(1, len(results_stageA[n])) for n in labels]
bc_vals = [
sum(results_stageBC[n]) / max(1, len(results_stageBC[n])) for n in labels
]
ind = range(len(labels))
plt.bar(ind, a_vals, color="#4e79a7", label="Stage A distance (s)")
plt.bar(
ind, bc_vals, bottom=a_vals, color="#e15759", label="Stage B/C embed-by-id (s)"
)
plt.xticks(list(ind), labels, rotation=10)
plt.ylabel("Server ZMQ time (s)")
plt.title(
f"ZMQ time split (initial {args.max_initial}, updates {args.max_updates})"
)
plt.legend()
out2 = args.plot_path.with_name(
args.plot_path.stem + "_zmq_split" + args.plot_path.suffix
)
plt.tight_layout()
plt.savefig(out2)
print(f"Saved ZMQ time split plot to {out2}")
except Exception as e:
print("Failed to plot ZMQ split:", e)
except ImportError:
print("matplotlib not available; skipping plot generation")
# leave the last build on disk for inspection
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,5 @@
run_id,scenario,cache_enabled,ef_construction,max_initial,max_updates,total_time_s,add_only_s,latency_ms_per_passage,zmq_nodes,stageA_time_s,stageBC_time_s,model_name,embedding_mode,distance_metric
20251024-133101,baseline,1,200,300,1,3.391856,1.120359,1120.359421,126,0.507821,0.601608,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,no_cache_baseline,0,200,300,1,34.941514,32.91376,32913.760185,4033,0.506933,32.159928,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,disable_forward_rng,1,200,300,1,2.746756,0.8202,820.200443,66,0.474354,0.338454,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,disable_forward_and_reverse_rng,1,200,300,1,2.396566,0.521478,521.478415,1,0.508973,0.006938,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
1 run_id scenario cache_enabled ef_construction max_initial max_updates total_time_s add_only_s latency_ms_per_passage zmq_nodes stageA_time_s stageBC_time_s model_name embedding_mode distance_metric
2 20251024-133101 baseline 1 200 300 1 3.391856 1.120359 1120.359421 126 0.507821 0.601608 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
3 20251024-133101 no_cache_baseline 0 200 300 1 34.941514 32.91376 32913.760185 4033 0.506933 32.159928 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
4 20251024-133101 disable_forward_rng 1 200 300 1 2.746756 0.8202 820.200443 66 0.474354 0.338454 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
5 20251024-133101 disable_forward_and_reverse_rng 1 200 300 1 2.396566 0.521478 521.478415 1 0.508973 0.006938 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips

View File

@@ -0,0 +1,704 @@
"""
Compare two latency models for small incremental updates vs. search:
Scenario A (sequential update then search):
- Build initial HNSW (is_recompute=True)
- Start embedding server (ZMQ) for recompute
- Add N passages one-by-one (each triggers recompute over ZMQ)
- Then run a search query on the updated index
- Report total time = sum(add_i) + search_time, with breakdowns
Scenario B (offline embeds + concurrent search; no graph updates):
- Do NOT insert the N passages into the graph
- In parallel: (1) compute embeddings for the N passages; (2) compute query
embedding and run a search on the existing index
- After both finish, compute similarity between the query embedding and the N
new passage embeddings, merge with the index search results by score, and
report time = max(embed_time, search_time) (i.e., no blocking on updates)
This script reuses the model/data loading conventions of
examples/bench_hnsw_rng_recompute.py but focuses on end-to-end latency
comparison for the two execution strategies above.
Example (from the repository root):
uv run -m benchmarks.update.bench_update_vs_offline_search \
--index-path .leann/bench/offline_vs_update.leann \
--max-initial 300 --num-updates 5 --k 10
"""
import argparse
import csv
import json
import logging
import os
import pickle
import sys
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import psutil # type: ignore
from leann.api import LeannBuilder
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
from leann.embedding_compute import compute_embeddings
from leann.embedding_server_manager import EmbeddingServerManager
from leann.registry import register_project_directory
from leann_backend_hnsw import faiss # type: ignore
logger = logging.getLogger(__name__)
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO)
def _find_repo_root() -> Path:
"""Locate project root by walking up until pyproject.toml is found."""
current = Path(__file__).resolve()
for parent in current.parents:
if (parent / "pyproject.toml").exists():
return parent
# Fallback: assume repo is two levels up (../..)
return current.parents[2]
REPO_ROOT = _find_repo_root()
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from apps.chunking import create_text_chunks # noqa: E402
DEFAULT_INITIAL_FILES = [
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
REPO_ROOT / "data" / "huawei_pangu.md",
]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
from llama_index.core import SimpleDirectoryReader
documents = []
for path in paths:
p = path.expanduser().resolve()
if not p.exists():
raise FileNotFoundError(f"Input path not found: {p}")
if p.is_dir():
reader = SimpleDirectoryReader(str(p), recursive=False)
documents.extend(reader.load_data(show_progress=True))
else:
reader = SimpleDirectoryReader(input_files=[str(p)])
documents.extend(reader.load_data(show_progress=True))
if not documents:
return []
chunks = create_text_chunks(
documents,
chunk_size=512,
chunk_overlap=128,
use_ast_chunking=False,
)
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
if limit is not None:
cleaned = cleaned[:limit]
return cleaned
def ensure_index_dir(index_path: Path) -> None:
index_path.parent.mkdir(parents=True, exist_ok=True)
def cleanup_index_files(index_path: Path) -> None:
parent = index_path.parent
if not parent.exists():
return
stem = index_path.stem
for file in parent.glob(f"{stem}*"):
if file.is_file():
file.unlink()
def build_initial_index(
index_path: Path,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
distance_metric: str,
ef_construction: int,
) -> None:
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=True,
distance_metric=distance_metric,
backend_kwargs={
"distance_metric": distance_metric,
"is_compact": False,
"is_recompute": True,
"efConstruction": ef_construction,
},
)
for idx, passage in enumerate(paragraphs):
builder.add_text(passage, metadata={"id": str(idx)})
builder.build_index(str(index_path))
def _maybe_norm_cosine(vecs: np.ndarray, metric: str) -> np.ndarray:
if metric == "cosine":
vecs = np.ascontiguousarray(vecs, dtype=np.float32)
norms = np.linalg.norm(vecs, axis=1, keepdims=True)
norms[norms == 0] = 1
vecs = vecs / norms
return vecs
def _read_index_for_search(index_path: Path) -> Any:
index_file = index_path.parent / f"{index_path.stem}.index"
# Force-disable experimental disk cache when loading the index so that
# incremental benchmarks don't pick up stale top-degree bitmaps.
cfg = faiss.HNSWIndexConfig()
cfg.is_recompute = True
if hasattr(cfg, "disk_cache_ratio"):
cfg.disk_cache_ratio = 0.0
if hasattr(cfg, "external_storage_path"):
cfg.external_storage_path = None
io_flags = getattr(faiss, "IO_FLAG_MMAP", 0)
index = faiss.read_index(str(index_file), io_flags, cfg)
# ensure recompute mode persists after reload
try:
index.is_recompute = True
except AttributeError:
pass
try:
actual_ntotal = index.hnsw.levels.size()
except AttributeError:
actual_ntotal = index.ntotal
if actual_ntotal != index.ntotal:
print(
f"[bench_update_vs_offline_search] Correcting ntotal from {index.ntotal} to {actual_ntotal}",
flush=True,
)
index.ntotal = actual_ntotal
if getattr(index, "storage", None) is None:
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
storage_index = faiss.IndexFlatIP(index.d)
else:
storage_index = faiss.IndexFlatL2(index.d)
index.storage = storage_index
index.own_fields = True
return index
def _append_passages_for_updates(
meta_path: Path,
start_id: int,
texts: list[str],
) -> list[str]:
"""Append update passages so the embedding server can serve recompute fetches."""
if not texts:
return []
index_dir = meta_path.parent
meta_name = meta_path.name
if not meta_name.endswith(".meta.json"):
raise ValueError(f"Unexpected meta filename: {meta_path}")
index_base = meta_name[: -len(".meta.json")]
passages_file = index_dir / f"{index_base}.passages.jsonl"
offsets_file = index_dir / f"{index_base}.passages.idx"
if not passages_file.exists() or not offsets_file.exists():
raise FileNotFoundError(
"Passage store missing; cannot register update passages for recompute mode."
)
with open(offsets_file, "rb") as f:
offset_map: dict[str, int] = pickle.load(f)
assigned_ids: list[str] = []
with open(passages_file, "a", encoding="utf-8") as f:
for i, text in enumerate(texts):
passage_id = str(start_id + i)
offset = f.tell()
json.dump({"id": passage_id, "text": text, "metadata": {}}, f, ensure_ascii=False)
f.write("\n")
offset_map[passage_id] = offset
assigned_ids.append(passage_id)
with open(offsets_file, "wb") as f:
pickle.dump(offset_map, f)
try:
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
except json.JSONDecodeError:
meta = {}
meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
return assigned_ids
def _search(index: Any, q: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
q = np.ascontiguousarray(q, dtype=np.float32)
distances = np.zeros((1, k), dtype=np.float32)
indices = np.zeros((1, k), dtype=np.int64)
index.search(
1,
faiss.swig_ptr(q),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(indices),
)
return distances[0], indices[0]
def _score_for_metric(dist: float, metric: str) -> float:
# Convert FAISS distance to a "higher is better" score
if metric in ("mips", "cosine"):
return float(dist)
# l2 distance (smaller better) -> negative distance as score
return -float(dist)
def _merge_results(
index_results: tuple[np.ndarray, np.ndarray],
offline_scores: list[tuple[int, float]],
k: int,
metric: str,
) -> list[tuple[str, float]]:
distances, indices = index_results
merged: list[tuple[str, float]] = []
for distance, idx in zip(distances.tolist(), indices.tolist()):
merged.append((f"idx:{idx}", _score_for_metric(distance, metric)))
for j, s in offline_scores:
merged.append((f"offline:{j}", s))
merged.sort(key=lambda x: x[1], reverse=True)
return merged[:k]
@dataclass
class ScenarioResult:
name: str
update_total_s: float
search_s: float
overall_s: float
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--index-path",
type=Path,
default=Path(".leann/bench/offline-vs-update.leann"),
)
parser.add_argument(
"--initial-files",
nargs="*",
type=Path,
default=DEFAULT_INITIAL_FILES,
)
parser.add_argument(
"--update-files",
nargs="*",
type=Path,
default=DEFAULT_UPDATE_FILES,
)
parser.add_argument("--max-initial", type=int, default=300)
parser.add_argument("--num-updates", type=int, default=5)
parser.add_argument("--k", type=int, default=10, help="Top-k for search/merge")
parser.add_argument(
"--query",
type=str,
default="neural network",
help="Query text used for the search benchmark.",
)
parser.add_argument("--server-port", type=int, default=5557)
parser.add_argument("--add-timeout", type=int, default=600)
parser.add_argument("--model-name", default="sentence-transformers/all-MiniLM-L6-v2")
parser.add_argument("--embedding-mode", default="sentence-transformers")
parser.add_argument(
"--distance-metric",
default="mips",
choices=["mips", "l2", "cosine"],
)
parser.add_argument("--ef-construction", type=int, default=200)
parser.add_argument(
"--only",
choices=["A", "B", "both"],
default="both",
help="Run only Scenario A, Scenario B, or both",
)
parser.add_argument(
"--csv-path",
type=Path,
default=Path("benchmarks/update/offline_vs_update.csv"),
help="Where to append results (CSV).",
)
args = parser.parse_args()
register_project_directory(REPO_ROOT)
# Load data
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
update_paragraphs = load_chunks_from_files(args.update_files, None)
if not update_paragraphs:
raise ValueError("No update passages loaded from --update-files")
update_paragraphs = update_paragraphs[: args.num_updates]
if len(update_paragraphs) < args.num_updates:
raise ValueError(
f"Not enough update passages ({len(update_paragraphs)}) for --num-updates={args.num_updates}"
)
ensure_index_dir(args.index_path)
cleanup_index_files(args.index_path)
# Build initial index
build_initial_index(
args.index_path,
initial_paragraphs,
args.model_name,
args.embedding_mode,
args.distance_metric,
args.ef_construction,
)
# Prepare index object and meta
meta_path = args.index_path.parent / f"{args.index_path.name}.meta.json"
index = _read_index_for_search(args.index_path)
# CSV setup
run_id = time.strftime("%Y%m%d-%H%M%S")
if args.csv_path:
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
csv_fields = [
"run_id",
"scenario",
"max_initial",
"num_updates",
"k",
"total_time_s",
"add_total_s",
"search_time_s",
"emb_time_s",
"makespan_s",
"model_name",
"embedding_mode",
"distance_metric",
]
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writeheader()
# Debug: list existing HNSW server PIDs before starting
try:
existing = [
p
for p in psutil.process_iter(attrs=["pid", "cmdline"])
if any(
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
for arg in (p.info.get("cmdline") or [])
)
]
if existing:
print("[debug] Found existing hnsw_embedding_server processes before run:")
for p in existing:
print(f"[debug] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}")
except Exception as _e:
pass
add_total = 0.0
search_after_add = 0.0
total_seq = 0.0
port_a = None
if args.only in ("A", "both"):
# Scenario A: sequential update then search
start_id = index.ntotal
assigned_ids = _append_passages_for_updates(meta_path, start_id, update_paragraphs)
if assigned_ids:
logger.debug(
"Registered %d update passages starting at id %s",
len(assigned_ids),
assigned_ids[0],
)
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
ok, port = server_manager.start_server(
port=args.server_port,
model_name=args.model_name,
embedding_mode=args.embedding_mode,
passages_file=str(meta_path),
distance_metric=args.distance_metric,
)
if not ok:
raise RuntimeError("Failed to start embedding server")
try:
# Set ZMQ port for recompute mode
if hasattr(index.hnsw, "set_zmq_port"):
index.hnsw.set_zmq_port(port)
elif hasattr(index, "set_zmq_port"):
index.set_zmq_port(port)
# Start A overall timer BEFORE computing update embeddings
t0 = time.time()
# Compute embeddings for updates (counted into A's overall)
t_emb0 = time.time()
upd_embs = compute_embeddings(
update_paragraphs,
args.model_name,
mode=args.embedding_mode,
is_build=False,
batch_size=16,
)
emb_time_updates = time.time() - t_emb0
upd_embs = np.asarray(upd_embs, dtype=np.float32)
upd_embs = _maybe_norm_cosine(upd_embs, args.distance_metric)
# Perform sequential adds
for i in range(upd_embs.shape[0]):
t_add0 = time.time()
index.add(1, faiss.swig_ptr(upd_embs[i : i + 1]))
add_total += time.time() - t_add0
# Don't persist index after adds to avoid contaminating Scenario B
# index_file = args.index_path.parent / f"{args.index_path.stem}.index"
# faiss.write_index(index, str(index_file))
# Search after updates
q_emb = compute_embeddings(
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
)
q_emb = np.asarray(q_emb, dtype=np.float32)
q_emb = _maybe_norm_cosine(q_emb, args.distance_metric)
# Warm up search with a dummy query first
print("[DEBUG] Warming up search...")
_ = _search(index, q_emb, 1)
t_s0 = time.time()
D_upd, I_upd = _search(index, q_emb, args.k)
search_after_add = time.time() - t_s0
total_seq = time.time() - t0
finally:
server_manager.stop_server()
port_a = port
print("\n=== Scenario A: update->search (sequential) ===")
# emb_time_updates is defined only when A runs
try:
_emb_a = emb_time_updates
except NameError:
_emb_a = 0.0
print(
f"Adds: {args.num_updates} passages; embeds={_emb_a:.3f}s; add_total={add_total:.3f}s; "
f"search={search_after_add:.3f}s; overall={total_seq:.3f}s"
)
# CSV row for A
if args.csv_path:
row_a = {
"run_id": run_id,
"scenario": "A",
"max_initial": args.max_initial,
"num_updates": args.num_updates,
"k": args.k,
"total_time_s": round(total_seq, 6),
"add_total_s": round(add_total, 6),
"search_time_s": round(search_after_add, 6),
"emb_time_s": round(_emb_a, 6),
"makespan_s": 0.0,
"model_name": args.model_name,
"embedding_mode": args.embedding_mode,
"distance_metric": args.distance_metric,
}
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writerow(row_a)
# Verify server cleanup
try:
# short sleep to allow signal handling to finish
time.sleep(0.5)
leftovers = [
p
for p in psutil.process_iter(attrs=["pid", "cmdline"])
if any(
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
for arg in (p.info.get("cmdline") or [])
)
]
if leftovers:
print("[warn] hnsw_embedding_server process(es) still alive after A-stop:")
for p in leftovers:
print(
f"[warn] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}"
)
else:
print("[debug] server cleanup confirmed: no hnsw_embedding_server found")
except Exception:
pass
# Scenario B: offline embeds + concurrent search (no graph updates)
if args.only in ("B", "both"):
# ensure a server is available for recompute search
server_manager_b = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
requested_port = args.server_port if port_a is None else port_a
ok_b, port_b = server_manager_b.start_server(
port=requested_port,
model_name=args.model_name,
embedding_mode=args.embedding_mode,
passages_file=str(meta_path),
distance_metric=args.distance_metric,
)
if not ok_b:
raise RuntimeError("Failed to start embedding server for Scenario B")
# Wait for server to fully initialize
print("[DEBUG] Waiting 2s for embedding server to fully initialize...")
time.sleep(2)
try:
# Read the index first
index_no_update = _read_index_for_search(args.index_path) # unchanged index
# Then configure ZMQ port on the correct index object
if hasattr(index_no_update.hnsw, "set_zmq_port"):
index_no_update.hnsw.set_zmq_port(port_b)
elif hasattr(index_no_update, "set_zmq_port"):
index_no_update.set_zmq_port(port_b)
# Warmup the embedding model before benchmarking (do this for both --only B and --only both)
# This ensures fair comparison as Scenario A has warmed up the model during update embeddings
logger.info("Warming up embedding model for Scenario B...")
_ = compute_embeddings(
["warmup text"], args.model_name, mode=args.embedding_mode, is_build=False
)
# Prepare worker A: compute embeddings for the same N passages
emb_time = 0.0
updates_embs_offline: np.ndarray | None = None
def _worker_emb():
nonlocal emb_time, updates_embs_offline
t = time.time()
updates_embs_offline = compute_embeddings(
update_paragraphs,
args.model_name,
mode=args.embedding_mode,
is_build=False,
batch_size=16,
)
emb_time = time.time() - t
# Pre-compute query embedding and warm up search outside of timed section.
q_vec = compute_embeddings(
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
)
q_vec = np.asarray(q_vec, dtype=np.float32)
q_vec = _maybe_norm_cosine(q_vec, args.distance_metric)
print("[DEBUG B] Warming up search...")
_ = _search(index_no_update, q_vec, 1)
# Worker B: timed search on the warmed index
search_time = 0.0
offline_elapsed = 0.0
index_results: tuple[np.ndarray, np.ndarray] | None = None
def _worker_search():
nonlocal search_time, index_results
t = time.time()
distances, indices = _search(index_no_update, q_vec, args.k)
search_time = time.time() - t
index_results = (distances, indices)
# Run two workers concurrently
t0 = time.time()
th1 = threading.Thread(target=_worker_emb)
th2 = threading.Thread(target=_worker_search)
th1.start()
th2.start()
th1.join()
th2.join()
offline_elapsed = time.time() - t0
# For mixing: compute query vs. offline update similarities (pure client-side)
offline_scores: list[tuple[int, float]] = []
if updates_embs_offline is not None:
upd2 = np.asarray(updates_embs_offline, dtype=np.float32)
upd2 = _maybe_norm_cosine(upd2, args.distance_metric)
# For mips/cosine, score = dot; for l2, score = -||x-y||^2
for j in range(upd2.shape[0]):
if args.distance_metric in ("mips", "cosine"):
s = float(np.dot(q_vec[0], upd2[j]))
else:
diff = q_vec[0] - upd2[j]
s = -float(np.dot(diff, diff))
offline_scores.append((j, s))
merged_topk = (
_merge_results(index_results, offline_scores, args.k, args.distance_metric)
if index_results
else []
)
print("\n=== Scenario B: offline embeds + concurrent search (no add) ===")
print(
f"embeddings({args.num_updates})={emb_time:.3f}s; search={search_time:.3f}s; makespan≈{offline_elapsed:.3f}s (≈max)"
)
if merged_topk:
preview = ", ".join([f"{lab}:{score:.3f}" for lab, score in merged_topk[:5]])
print(f"Merged top-5 preview: {preview}")
# CSV row for B
if args.csv_path:
row_b = {
"run_id": run_id,
"scenario": "B",
"max_initial": args.max_initial,
"num_updates": args.num_updates,
"k": args.k,
"total_time_s": 0.0,
"add_total_s": 0.0,
"search_time_s": round(search_time, 6),
"emb_time_s": round(emb_time, 6),
"makespan_s": round(offline_elapsed, 6),
"model_name": args.model_name,
"embedding_mode": args.embedding_mode,
"distance_metric": args.distance_metric,
}
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writerow(row_b)
finally:
server_manager_b.stop_server()
# Summary
print("\n=== Summary ===")
msg_a = (
f"A: seq-add+search overall={total_seq:.3f}s (adds={add_total:.3f}s, search={search_after_add:.3f}s)"
if args.only in ("A", "both")
else "A: skipped"
)
msg_b = (
f"B: offline+concurrent overall≈{offline_elapsed:.3f}s (emb={emb_time:.3f}s, search={search_time:.3f}s)"
if args.only in ("B", "both")
else "B: skipped"
)
print(msg_a + "\n" + msg_b)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,5 @@
run_id,scenario,max_initial,num_updates,k,total_time_s,add_total_s,search_time_s,emb_time_s,makespan_s,model_name,embedding_mode,distance_metric
20251024-141607,A,300,1,10,3.273957,3.050168,0.097825,0.017339,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-141607,B,300,1,10,0.0,0.0,0.111892,0.007869,0.112635,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251025-160652,A,300,5,10,5.061945,4.805962,0.123271,0.015008,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251025-160652,B,300,5,10,0.0,0.0,0.101809,0.008817,0.102447,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
1 run_id scenario max_initial num_updates k total_time_s add_total_s search_time_s emb_time_s makespan_s model_name embedding_mode distance_metric
2 20251024-141607 A 300 1 10 3.273957 3.050168 0.097825 0.017339 0.0 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
3 20251024-141607 B 300 1 10 0.0 0.0 0.111892 0.007869 0.112635 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
4 20251025-160652 A 300 5 10 5.061945 4.805962 0.123271 0.015008 0.0 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
5 20251025-160652 B 300 5 10 0.0 0.0 0.101809 0.008817 0.102447 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips

View File

@@ -0,0 +1,645 @@
#!/usr/bin/env python3
"""
Plot latency bars from the benchmark CSV produced by
benchmarks/update/bench_hnsw_rng_recompute.py.
If you also provide an offline_vs_update.csv via --csv-right
(from benchmarks/update/bench_update_vs_offline_search.py), this script will
output a side-by-side figure:
- Left: ms/passage bars (four RNG scenarios).
- Right: seconds bars (Scenario A seq add+search vs Scenario B offline+search).
Usage:
uv run python benchmarks/update/plot_bench_results.py \
--csv benchmarks/update/bench_results.csv \
--out benchmarks/update/bench_latency_from_csv.png
The script selects the latest run_id in the CSV and plots four bars for
the default scenarios:
- baseline
- no_cache_baseline
- disable_forward_rng
- disable_forward_and_reverse_rng
If multiple rows exist per scenario for that run_id, the script averages
their latency_ms_per_passage values.
"""
import argparse
import csv
from collections import defaultdict
from pathlib import Path
DEFAULT_SCENARIOS = [
"no_cache_baseline",
"baseline",
"disable_forward_rng",
"disable_forward_and_reverse_rng",
]
SCENARIO_LABELS = {
"baseline": "+ Cache",
"no_cache_baseline": "Naive \n Recompute",
"disable_forward_rng": "+ w/o \n Fwd RNG",
"disable_forward_and_reverse_rng": "+ w/o \n Bwd RNG",
}
# Paper-style colors and hatches for scenarios
SCENARIO_STYLES = {
"no_cache_baseline": {"edgecolor": "dimgrey", "hatch": "/////"},
"baseline": {"edgecolor": "#63B8B6", "hatch": "xxxxx"},
"disable_forward_rng": {"edgecolor": "green", "hatch": "....."},
"disable_forward_and_reverse_rng": {"edgecolor": "tomato", "hatch": "\\\\\\\\\\"},
}
def load_latest_run(csv_path: Path):
rows = []
with csv_path.open("r", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
rows.append(row)
if not rows:
raise SystemExit("CSV is empty: no rows to plot")
# Choose latest run_id lexicographically (YYYYMMDD-HHMMSS)
run_ids = [r.get("run_id", "") for r in rows]
latest = max(run_ids)
latest_rows = [r for r in rows if r.get("run_id", "") == latest]
if not latest_rows:
# Fallback: take last 4 rows
latest_rows = rows[-4:]
latest = latest_rows[-1].get("run_id", "unknown")
return latest, latest_rows
def aggregate_latency(rows):
acc = defaultdict(list)
for r in rows:
sc = r.get("scenario", "")
try:
val = float(r.get("latency_ms_per_passage", "nan"))
except ValueError:
continue
acc[sc].append(val)
avg = {k: (sum(v) / len(v) if v else 0.0) for k, v in acc.items()}
return avg
def _auto_cap(values: list[float]) -> float | None:
if not values:
return None
sorted_vals = sorted(values, reverse=True)
if len(sorted_vals) < 2:
return None
max_v, second = sorted_vals[0], sorted_vals[1]
if second <= 0:
return None
# If the tallest bar dwarfs the second by 2.5x+, cap near the second
if max_v >= 2.5 * second:
return second * 1.1
return None
def _add_break_marker(ax, y, rel_x0=0.02, rel_x1=0.98, size=0.02):
# Draw small diagonal ticks near left/right to signal cap
x0, x1 = rel_x0, rel_x1
ax.plot([x0 - size, x0 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
ax.plot([x1 - size, x1 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
def _fmt_ms(v: float) -> str:
if v >= 1000:
return f"{v / 1000:.1f}k"
return f"{v:.1f}"
def main():
# Set LaTeX style for paper figures (matching paper_fig.py)
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.5
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
ap = argparse.ArgumentParser(description=__doc__)
ap.add_argument(
"--csv",
type=Path,
default=Path("benchmarks/update/bench_results.csv"),
help="Path to results CSV (defaults to bench_results.csv)",
)
ap.add_argument(
"--out",
type=Path,
default=Path("add_ablation.pdf"),
help="Output image path",
)
ap.add_argument(
"--csv-right",
type=Path,
default=Path("benchmarks/update/offline_vs_update.csv"),
help="Optional: offline_vs_update.csv to render right subplot (A vs B)",
)
ap.add_argument(
"--cap-y",
type=float,
default=None,
help="Cap Y-axis at this ms value; bars above are hatched and annotated.",
)
ap.add_argument(
"--no-auto-cap",
action="store_true",
help="Disable auto-cap heuristic when --cap-y is not provided.",
)
ap.add_argument(
"--broken-y",
action="store_true",
default=True,
help="Use a broken Y-axis (two stacked axes with a gap). Overrides --cap-y unless both provided.",
)
ap.add_argument(
"--lower-cap-y",
type=float,
default=None,
help="Lower axes upper bound for broken Y (ms). Default = 1.1x second-highest.",
)
ap.add_argument(
"--upper-start-y",
type=float,
default=None,
help="Upper axes lower bound for broken Y (ms). Default = 1.2x second-highest.",
)
args = ap.parse_args()
latest_run, latest_rows = load_latest_run(args.csv)
avg = aggregate_latency(latest_rows)
try:
import matplotlib.pyplot as plt
except Exception as e:
raise SystemExit(f"matplotlib not available: {e}")
scenarios = DEFAULT_SCENARIOS
values = [avg.get(name, 0.0) for name in scenarios]
labels = [SCENARIO_LABELS.get(name, name) for name in scenarios]
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
# If right CSV is provided, build side-by-side figure
if args.csv_right is not None:
try:
right_rows_all = []
with args.csv_right.open("r", encoding="utf-8") as f:
rreader = csv.DictReader(f)
right_rows_all = list(rreader)
if right_rows_all:
r_latest = max(r.get("run_id", "") for r in right_rows_all)
right_rows = [r for r in right_rows_all if r.get("run_id", "") == r_latest]
else:
r_latest = None
right_rows = []
except Exception:
r_latest = None
right_rows = []
a_total = 0.0
b_makespan = 0.0
for r in right_rows:
sc = (r.get("scenario", "") or "").strip().upper()
if sc == "A":
try:
a_total = float(r.get("total_time_s", 0.0))
except Exception:
pass
elif sc == "B":
try:
b_makespan = float(r.get("makespan_s", 0.0))
except Exception:
pass
import matplotlib.pyplot as plt
from matplotlib import gridspec
# Left subplot (reuse current style, with optional cap)
cap = args.cap_y
if cap is None and not args.no_auto_cap:
cap = _auto_cap(values)
x = list(range(len(labels)))
if args.broken_y:
# Use broken axis for left subplot
# Auto-adjust width ratios: left has 4 bars, right has 2 bars
fig = plt.figure(figsize=(4.8, 1.8)) # Scaled down to 80%
gs = gridspec.GridSpec(
2, 2, height_ratios=[1, 3], width_ratios=[1.5, 1], hspace=0.08, wspace=0.35
)
ax_left_top = fig.add_subplot(gs[0, 0])
ax_left_bottom = fig.add_subplot(gs[1, 0], sharex=ax_left_top)
ax_right = fig.add_subplot(gs[:, 1])
# Determine break points
s = sorted(values, reverse=True)
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
lower_cap = (
args.lower_cap_y if args.lower_cap_y is not None else second * 1.4
) # Increased to show more range
upper_start = (
args.upper_start_y
if args.upper_start_y is not None
else max(second * 1.5, lower_cap * 1.02)
)
ymax = (
max(values) * 1.90 if values else 1.0
) # Increase headroom to 1.90 for text label and tick range
# Draw bars on both axes
ax_left_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_left_top.bar(x, values, color=colors[: len(labels)], width=0.8)
# Set limits
ax_left_bottom.set_ylim(0, lower_cap)
ax_left_top.set_ylim(upper_start, ymax)
# Annotate values (convert ms to s)
values_s = [v / 1000.0 for v in values]
lower_cap_s = lower_cap / 1000.0
upper_start_s = upper_start / 1000.0
ymax_s = ymax / 1000.0
ax_left_bottom.set_ylim(0, lower_cap_s)
ax_left_top.set_ylim(upper_start_s, ymax_s)
# Redraw bars with s values (paper style: white fill + colored edge + hatch)
ax_left_bottom.clear()
ax_left_top.clear()
bar_width = 0.50 # Reduced for wider spacing between bars
for i, (scenario_name, v) in enumerate(zip(scenarios, values_s)):
style = SCENARIO_STYLES.get(scenario_name, {"edgecolor": "black", "hatch": ""})
# Draw in bottom axis for all bars
ax_left_bottom.bar(
i,
v,
width=bar_width,
color="white",
edgecolor=style["edgecolor"],
hatch=style["hatch"],
linewidth=1.2,
)
# Only draw in top axis if the bar is tall enough to reach the upper range
if v > upper_start_s:
ax_left_top.bar(
i,
v,
width=bar_width,
color="white",
edgecolor=style["edgecolor"],
hatch=style["hatch"],
linewidth=1.2,
)
ax_left_bottom.set_ylim(0, lower_cap_s)
ax_left_top.set_ylim(upper_start_s, ymax_s)
for i, v in enumerate(values_s):
if v <= lower_cap_s:
ax_left_bottom.text(
i,
v + lower_cap_s * 0.02,
f"{v:.2f}",
ha="center",
va="bottom",
fontsize=8,
fontweight="bold",
)
else:
ax_left_top.text(
i,
v + (ymax_s - upper_start_s) * 0.02,
f"{v:.2f}",
ha="center",
va="bottom",
fontsize=8,
fontweight="bold",
)
# Hide spines between axes
ax_left_top.spines["bottom"].set_visible(False)
ax_left_bottom.spines["top"].set_visible(False)
ax_left_top.tick_params(
labeltop=False, labelbottom=False, bottom=False
) # Hide tick marks
ax_left_bottom.xaxis.tick_bottom()
ax_left_bottom.tick_params(top=False) # Hide top tick marks
# Draw break marks (matching paper_fig.py style)
d = 0.015
kwargs = {
"transform": ax_left_top.transAxes,
"color": "k",
"clip_on": False,
"linewidth": 0.8,
"zorder": 10,
}
ax_left_top.plot((-d, +d), (-d, +d), **kwargs)
ax_left_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
kwargs.update({"transform": ax_left_bottom.transAxes})
ax_left_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
ax_left_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
ax_left_bottom.set_xticks(x)
ax_left_bottom.set_xticklabels(labels, rotation=0, fontsize=7)
# Don't set ylabel here - will use fig.text for alignment
ax_left_bottom.tick_params(axis="y", labelsize=10)
ax_left_top.tick_params(axis="y", labelsize=10)
# Add subtle grid for better readability
ax_left_bottom.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
ax_left_top.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
ax_left_top.set_title("Single Add Operation", fontsize=11, pad=10, fontweight="bold")
# Set x-axis limits to match bar width with right subplot
ax_left_bottom.set_xlim(-0.6, 3.6)
ax_left_top.set_xlim(-0.6, 3.6)
ax_left = ax_left_bottom # for compatibility
else:
# Regular side-by-side layout
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8.4, 3.15))
if cap is not None:
show_vals = [min(v, cap) for v in values]
bars = ax_left.bar(x, show_vals, color=colors[: len(labels)], width=0.8)
for i, (val, show) in enumerate(zip(values, show_vals)):
if val > cap:
bars[i].set_hatch("//")
ax_left.text(
i, cap * 1.02, _fmt_ms(val), ha="center", va="bottom", fontsize=9
)
else:
ax_left.text(
i,
show + max(1.0, 0.01 * (cap or show)),
_fmt_ms(val),
ha="center",
va="bottom",
fontsize=9,
)
ax_left.set_ylim(0, cap * 1.10)
_add_break_marker(ax_left, y=0.98)
ax_left.set_xticks(x)
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
else:
ax_left.bar(x, values, color=colors[: len(labels)], width=0.8)
for i, v in enumerate(values):
ax_left.text(i, v + 1.0, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
ax_left.set_xticks(x)
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
ax_left.set_ylabel("Latency (ms per passage)")
max_initial = latest_rows[0].get("max_initial", "?")
max_updates = latest_rows[0].get("max_updates", "?")
ax_left.set_title(
f"HNSW RNG (run {latest_run}) | init={max_initial}, upd={max_updates}"
)
# Right subplot (A vs B, seconds) - paper style
r_labels = ["Sequential", "Delayed \n Add+Search"]
r_values = [a_total or 0.0, b_makespan or 0.0]
r_styles = [
{"edgecolor": "#59a14f", "hatch": "xxxxx"},
{"edgecolor": "#edc948", "hatch": "/////"},
]
# 2 bars, centered with proper spacing
xr = [0, 1]
bar_width = 0.50 # Reduced for wider spacing between bars
for i, (v, style) in enumerate(zip(r_values, r_styles)):
ax_right.bar(
xr[i],
v,
width=bar_width,
color="white",
edgecolor=style["edgecolor"],
hatch=style["hatch"],
linewidth=1.2,
)
for i, v in enumerate(r_values):
max_v = max(r_values) if r_values else 1.0
offset = max(0.0002, 0.02 * max_v)
ax_right.text(
xr[i],
v + offset,
f"{v:.2f}",
ha="center",
va="bottom",
fontsize=8,
fontweight="bold",
)
ax_right.set_xticks(xr)
ax_right.set_xticklabels(r_labels, rotation=0, fontsize=7)
# Don't set ylabel here - will use fig.text for alignment
ax_right.tick_params(axis="y", labelsize=10)
# Add subtle grid for better readability
ax_right.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
ax_right.set_title("Batched Add Operation", fontsize=11, pad=10, fontweight="bold")
# Set x-axis limits to match left subplot's bar width visually
# Accounting for width_ratios=[1.5, 1]:
# Left: 4 bars, xlim(-0.6, 3.6), range=4.2, physical_width=1.5*unit
# bar_width_visual = 0.72 * (1.5*unit / 4.2)
# Right: 2 bars, need same visual width
# 0.72 * (1.0*unit / range_right) = 0.72 * (1.5*unit / 4.2)
# range_right = 4.2 / 1.5 = 2.8
# For bars at 0, 1: padding = (2.8 - 1) / 2 = 0.9
ax_right.set_xlim(-0.9, 1.9)
# Set y-axis limit with headroom for text labels
if r_values:
max_v = max(r_values)
ax_right.set_ylim(0, max_v * 1.15)
# Format y-axis to avoid scientific notation
ax_right.ticklabel_format(style="plain", axis="y")
plt.tight_layout()
# Add aligned ylabels using fig.text (after tight_layout)
# Get the vertical center of the entire figure
fig_center_y = 0.5
# Left ylabel - closer to left plot
left_x = 0.05
fig.text(
left_x,
fig_center_y,
"Latency (s)",
va="center",
rotation="vertical",
fontsize=11,
fontweight="bold",
)
# Right ylabel - closer to right plot
right_bbox = ax_right.get_position()
right_x = right_bbox.x0 - 0.07
fig.text(
right_x,
fig_center_y,
"Latency (s)",
va="center",
rotation="vertical",
fontsize=11,
fontweight="bold",
)
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
# Also save PDF for paper
pdf_out = args.out.with_suffix(".pdf")
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
print(f"Saved: {args.out}")
print(f"Saved: {pdf_out}")
return
# Broken-Y mode
if args.broken_y:
import matplotlib.pyplot as plt
fig, (ax_top, ax_bottom) = plt.subplots(
2,
1,
sharex=True,
figsize=(7.5, 6.75),
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.08},
)
# Determine default breaks from second-highest
s = sorted(values, reverse=True)
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
upper_start = (
args.upper_start_y
if args.upper_start_y is not None
else max(second * 1.2, lower_cap * 1.02)
)
ymax = max(values) * 1.10 if values else 1.0
x = list(range(len(labels)))
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
# Limits
ax_bottom.set_ylim(0, lower_cap)
ax_top.set_ylim(upper_start, ymax)
# Annotate values
for i, v in enumerate(values):
if v <= lower_cap:
ax_bottom.text(
i, v + lower_cap * 0.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9
)
else:
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
# Hide spines between axes and draw diagonal break marks
ax_top.spines["bottom"].set_visible(False)
ax_bottom.spines["top"].set_visible(False)
ax_top.tick_params(labeltop=False) # don't put tick labels at the top
ax_bottom.xaxis.tick_bottom()
# Diagonal lines at the break (matching paper_fig.py style)
d = 0.015
kwargs = {
"transform": ax_top.transAxes,
"color": "k",
"clip_on": False,
"linewidth": 0.8,
"zorder": 10,
}
ax_top.plot((-d, +d), (-d, +d), **kwargs) # top-left diagonal
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs) # top-right diagonal
kwargs.update({"transform": ax_bottom.transAxes})
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs) # bottom-left diagonal
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs) # bottom-right diagonal
ax_bottom.set_xticks(x)
ax_bottom.set_xticklabels(labels, rotation=0, fontsize=10)
ax = ax_bottom # for labeling below
else:
cap = args.cap_y
if cap is None and not args.no_auto_cap:
cap = _auto_cap(values)
plt.figure(figsize=(5.4, 3.15))
ax = plt.gca()
if cap is not None:
show_vals = [min(v, cap) for v in values]
bars = []
for i, (_label, val, show) in enumerate(zip(labels, values, show_vals)):
bar = ax.bar(i, show, color=colors[i], width=0.8)
bars.append(bar[0])
# Hatch and annotate when capped
if val > cap:
bars[-1].set_hatch("//")
ax.text(i, cap * 1.02, f"{_fmt_ms(val)}", ha="center", va="bottom", fontsize=9)
else:
ax.text(
i,
show + max(1.0, 0.01 * (cap or show)),
f"{_fmt_ms(val)}",
ha="center",
va="bottom",
fontsize=9,
)
ax.set_ylim(0, cap * 1.10)
_add_break_marker(ax, y=0.98)
ax.legend([bars[1]], ["capped"], fontsize=8, frameon=False, loc="upper right") if any(
v > cap for v in values
) else None
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
else:
ax.bar(labels, values, color=colors[: len(labels)])
for idx, val in enumerate(values):
ax.text(
idx,
val + 1.0,
f"{_fmt_ms(val)}",
ha="center",
va="bottom",
fontsize=10,
fontweight="bold",
)
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
# Try to extract some context for title
max_initial = latest_rows[0].get("max_initial", "?")
max_updates = latest_rows[0].get("max_updates", "?")
if args.broken_y:
fig.text(
0.02,
0.5,
"Latency (s)",
va="center",
rotation="vertical",
fontsize=11,
fontweight="bold",
)
fig.suptitle(
"Add Operation Latency",
fontsize=11,
y=0.98,
fontweight="bold",
)
plt.tight_layout(rect=(0.03, 0.04, 1, 0.96))
else:
plt.ylabel("Latency (s)", fontsize=11, fontweight="bold")
plt.title("Add Operation Latency", fontsize=11, fontweight="bold")
plt.tight_layout()
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
# Also save PDF for paper
pdf_out = args.out.with_suffix(".pdf")
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
print(f"Saved: {args.out}")
print(f"Saved: {pdf_out}")
if __name__ == "__main__":
main()

View File

@@ -158,6 +158,95 @@ builder.build_index("./indexes/my-notes", chunks)
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
## Optional Embedding Features
### Task-Specific Prompt Templates
Some embedding models are trained with task-specific prompts to differentiate between documents and queries. The most notable example is **Google's EmbeddingGemma**, which requires different prompts depending on the use case:
- **Indexing documents**: `"title: none | text: "`
- **Search queries**: `"task: search result | query: "`
LEANN supports automatic prompt prepending via the `--embedding-prompt-template` flag:
```bash
# Build index with EmbeddingGemma (via LM Studio or Ollama)
leann build my-docs \
--docs ./documents \
--embedding-mode openai \
--embedding-model text-embedding-embeddinggemma-300m-qat \
--embedding-api-base http://localhost:1234/v1 \
--embedding-prompt-template "title: none | text: " \
--force
# Search with query-specific prompt
leann search my-docs \
--query "What is quantum computing?" \
--embedding-prompt-template "task: search result | query: "
```
**Important Notes:**
- **Only use with compatible models**: EmbeddingGemma and similar task-specific models
- **NOT for regular models**: Adding prompts to models like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` will corrupt embeddings
- **Template is saved**: Build-time templates are saved to `.meta.json` for reference
- **Flexible prompts**: You can use any prompt string, or leave it empty (`""`)
**Python API:**
```python
from leann.api import LeannBuilder
builder = LeannBuilder(
embedding_mode="openai",
embedding_model="text-embedding-embeddinggemma-300m-qat",
embedding_options={
"base_url": "http://localhost:1234/v1",
"api_key": "lm-studio",
"prompt_template": "title: none | text: ",
},
)
builder.build_index("./indexes/my-docs", chunks)
```
**References:**
- [HuggingFace Blog: EmbeddingGemma](https://huggingface.co/blog/embeddinggemma) - Technical details
### LM Studio Auto-Detection (Optional)
When using LM Studio with the OpenAI-compatible API, LEANN can optionally auto-detect model context lengths via the LM Studio SDK. This eliminates manual configuration for token limits.
**Prerequisites:**
```bash
# Install Node.js (if not already installed)
# Then install the LM Studio SDK globally
npm install -g @lmstudio/sdk
```
**How it works:**
1. LEANN detects LM Studio URLs (`:1234`, `lmstudio` in URL)
2. Queries model metadata via Node.js subprocess
3. Automatically unloads model after query (respects your JIT auto-evict settings)
4. Falls back to static registry if SDK unavailable
**No configuration needed** - it works automatically when SDK is installed:
```bash
leann build my-docs \
--docs ./documents \
--embedding-mode openai \
--embedding-model text-embedding-nomic-embed-text-v1.5 \
--embedding-api-base http://localhost:1234/v1
# Context length auto-detected if SDK available
# Falls back to registry (2048) if not
```
**Benefits:**
- ✅ Automatic token limit detection
- ✅ Respects LM Studio JIT auto-evict settings
- ✅ No manual registry maintenance
- ✅ Graceful fallback if SDK unavailable
**Note:** This is completely optional. LEANN works perfectly fine without the SDK using the built-in token limit registry.
## Index Selection: Matching Your Scale
### HNSW (Hierarchical Navigable Small World)

View File

@@ -8,3 +8,51 @@ You can speed up the process by using a lightweight embedding model. Add this to
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
## 2. When should I use prompt templates?
**Use prompt templates ONLY with task-specific embedding models** like Google's EmbeddingGemma. These models are specially trained to use different prompts for documents vs queries.
**DO NOT use with regular models** like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` - adding prompts to these models will corrupt the embeddings.
**Example usage with EmbeddingGemma:**
```bash
# Build with document prompt
leann build my-docs --embedding-prompt-template "title: none | text: "
# Search with query prompt
leann search my-docs --query "your question" --embedding-prompt-template "task: search result | query: "
```
See the [Configuration Guide: Task-Specific Prompt Templates](configuration-guide.md#task-specific-prompt-templates) for detailed usage.
## 3. Why is LM Studio loading multiple copies of my model?
This was fixed in recent versions. LEANN now properly unloads models after querying metadata, respecting your LM Studio JIT auto-evict settings.
**If you still see duplicates:**
- Update to the latest LEANN version
- Restart LM Studio to clear loaded models
- Check that you have JIT auto-evict enabled in LM Studio settings
**How it works now:**
1. LEANN loads model temporarily to get context length
2. Immediately unloads after query
3. LM Studio JIT loads model on-demand for actual embeddings
4. Auto-evicts per your settings
## 4. Do I need Node.js and @lmstudio/sdk?
**No, it's completely optional.** LEANN works perfectly fine without them using a built-in token limit registry.
**Benefits if you install it:**
- Automatic context length detection for LM Studio models
- No manual registry maintenance
- Always gets accurate token limits from the model itself
**To install (optional):**
```bash
npm install -g @lmstudio/sdk
```
See [Configuration Guide: LM Studio Auto-Detection](configuration-guide.md#lm-studio-auto-detection-optional) for details.

395
docs/slack-setup-guide.md Normal file
View File

@@ -0,0 +1,395 @@
# Slack Integration Setup Guide
This guide provides step-by-step instructions for setting up Slack integration with LEANN.
## Overview
LEANN's Slack integration uses MCP (Model Context Protocol) servers to fetch and index your Slack messages for RAG (Retrieval-Augmented Generation). This allows you to search through your Slack conversations using natural language queries.
## Prerequisites
1. **Slack Workspace Access**: You need admin or owner permissions in your Slack workspace to create apps and configure OAuth tokens.
2. **Slack MCP Server**: Install a Slack MCP server (e.g., `slack-mcp-server` via npm)
3. **LEANN**: Ensure you have LEANN installed and working
## Step 1: Create a Slack App
### 1.1 Go to Slack API Dashboard
1. Visit [https://api.slack.com/apps](https://api.slack.com/apps)
2. Click **"Create New App"**
3. Choose **"From scratch"**
4. Enter your app name (e.g., "LEANN Slack Integration")
5. Select your workspace
6. Click **"Create App"**
### 1.2 Configure App Permissions
#### Token Scopes
1. In your app dashboard, go to **"OAuth & Permissions"** in the left sidebar
2. Scroll down to **"Scopes"** section
3. Under **"Bot Token Scopes & OAuth Scope"**, click **"Add an OAuth Scope"**
4. Add the following scopes:
- `channels:read` - Read public channel information
- `channels:history` - Read messages in public channels
- `groups:read` - Read private channel information
- `groups:history` - Read messages in private channels
- `im:read` - Read direct message information
- `im:history` - Read direct messages
- `mpim:read` - Read group direct message information
- `mpim:history` - Read group direct messages
- `users:read` - Read user information
- `team:read` - Read workspace information
#### App-Level Tokens (Optional)
Some MCP servers may require app-level tokens:
1. Go to **"Basic Information"** in the left sidebar
2. Scroll down to **"App-Level Tokens"**
3. Click **"Generate Token and Scopes"**
4. Enter a name (e.g., "LEANN Integration")
5. Add the `connections:write` scope
6. Click **"Generate"**
7. Copy the token (starts with `xapp-`)
### 1.3 Install App to Workspace
1. Go to **"OAuth & Permissions"** in the left sidebar
2. Click **"Install to Workspace"**
3. Review the permissions and click **"Allow"**
4. Copy the **"Bot User OAuth Token"** (starts with `xoxb-`)
5. Copy the **"User OAuth Token"** (starts with `xoxp-`)
## Step 2: Install Slack MCP Server
### Option A: Using npm (Recommended)
```bash
# Install globally
npm install -g slack-mcp-server
# Or install locally
npm install slack-mcp-server
```
### Option B: Using npx (No installation required)
```bash
# Use directly without installation
npx slack-mcp-server
```
## Step 3: Install and Configure Ollama (for Real LLM Responses)
### 3.1 Install Ollama
```bash
# Install Ollama using Homebrew (macOS)
brew install ollama
# Or download from https://ollama.ai/
```
### 3.2 Start Ollama Service
```bash
# Start Ollama as a service
brew services start ollama
# Or start manually
ollama serve
```
### 3.3 Pull a Model
```bash
# Pull a lightweight model for testing
ollama pull llama3.2:1b
# Verify the model is available
ollama list
```
## Step 4: Configure Environment Variables
Create a `.env` file or set environment variables:
```bash
# Required: User OAuth Token
SLACK_OAUTH_TOKEN=xoxp-your-user-oauth-token-here
# Optional: App-Level Token (if your MCP server requires it)
SLACK_APP_TOKEN=xapp-your-app-token-here
# Optional: Workspace-specific settings
SLACK_WORKSPACE_ID=T1234567890 # Your workspace ID (optional)
```
## Step 5: Test the Setup
### 5.1 Test MCP Server Connection
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--test-connection \
--workspace-name "Your Workspace Name"
```
This will test the connection and list available tools without indexing any data.
### 5.2 Index a Specific Channel
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "Your Workspace Name" \
--channels general \
--query "What did we discuss about the project?"
```
### 5.3 Real RAG Query Examples
This section demonstrates successful Slack RAG integration queries against the Sky Lab Computing workspace's "random" channel. The system successfully retrieves actual conversation messages and performs semantic search with high relevance scores, including finding specific research paper announcements and technical discussions.
### Example 1: Advisor Models Query
**Query:** "train black-box models to adopt to your personal data"
This query demonstrates the system's ability to find specific research announcements about training black-box models for personal data adaptation.
![Advisor Models Query - Command Setup](videos/slack_integration_1.1.png)
![Advisor Models Query - Search Results](videos/slack_integration_1.2.png)
![Advisor Models Query - LLM Response](videos/slack_integration_1.3.png)
### Example 2: Barbarians at the Gate Query
**Query:** "AI-driven research systems ADRS"
This query demonstrates the system's ability to find specific research announcements about AI-driven research systems and algorithm discovery.
![Barbarians Query - Command Setup](videos/slack_integration_2.1.png)
![Barbarians Query - Search Results](videos/slack_integration_2.2.png)
![Barbarians Query - LLM Response](videos/slack_integration_2.3.png)
### Prerequisites
- Bot is installed in the Sky Lab Computing workspace and invited to the target channel (run `/invite @YourBotName` in the channel if needed)
- Bot token available and exported in the same terminal session
### Commands
1) Set the workspace token for this shell
```bash
export SLACK_MCP_XOXP_TOKEN="xoxp-***-redacted-***"
```
2) Run queries against the "random" channel by channel ID (C0GN5BX0F)
**Advisor Models Query:**
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "Sky Lab Computing" \
--channels C0GN5BX0F \
--max-messages-per-channel 100000 \
--query "train black-box models to adopt to your personal data" \
--llm ollama \
--llm-model "llama3.2:1b" \
--llm-host "http://localhost:11434" \
--no-concatenate-conversations
```
**Barbarians at the Gate Query:**
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "Sky Lab Computing" \
--channels C0GN5BX0F \
--max-messages-per-channel 100000 \
--query "AI-driven research systems ADRS" \
--llm ollama \
--llm-model "llama3.2:1b" \
--llm-host "http://localhost:11434" \
--no-concatenate-conversations
```
These examples demonstrate the system's ability to find and retrieve specific research announcements and technical discussions from the conversation history, showcasing the power of semantic search in Slack data.
3) Optional: Ask a broader question
```bash
python test_channel_by_id_or_name.py \
--channel-id C0GN5BX0F \
--workspace-name "Sky Lab Computing" \
--query "What is LEANN about?"
```
Notes:
- If you see `not_in_channel`, invite the bot to the channel and re-run.
- If you see `channel_not_found`, confirm the channel ID and workspace.
- Deep search via server-side “search” tools may require additional Slack scopes; the example above performs client-side filtering over retrieved history.
## Common Issues and Solutions
### Issue 1: "users cache is not ready yet" Error
**Problem**: You see this warning:
```
WARNING - Failed to fetch messages from channel random: Failed to fetch messages: {'code': -32603, 'message': 'users cache is not ready yet, sync process is still running... please wait'}
```
**Solution**: This is a common timing issue. The LEANN integration now includes automatic retry logic:
1. **Wait and Retry**: The system will automatically retry with exponential backoff (2s, 4s, 8s, etc.)
2. **Increase Retry Parameters**: If needed, you can customize retry behavior:
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--max-retries 10 \
--retry-delay 3.0 \
--channels general \
--query "Your query here"
```
3. **Keep MCP Server Running**: Start the MCP server separately and keep it running:
```bash
# Terminal 1: Start MCP server
slack-mcp-server
# Terminal 2: Run LEANN (it will connect to the running server)
python -m apps.slack_rag --mcp-server "slack-mcp-server" --channels general --query "test"
```
### Issue 2: "No message fetching tool found"
**Problem**: The MCP server doesn't have the expected tools.
**Solution**:
1. Check if your MCP server is properly installed and configured
2. Verify your Slack tokens are correct
3. Try a different MCP server implementation
4. Check the MCP server documentation for required configuration
### Issue 3: Permission Denied Errors
**Problem**: You get permission errors when trying to access channels.
**Solutions**:
1. **Check Bot Permissions**: Ensure your bot has been added to the channels you want to access
2. **Verify Token Scopes**: Make sure you have all required scopes configured
3. **Channel Access**: For private channels, the bot needs to be explicitly invited
4. **Workspace Permissions**: Ensure your Slack app has the necessary workspace permissions
### Issue 4: Empty Results
**Problem**: No messages are returned even though the channel has messages.
**Solutions**:
1. **Check Channel Names**: Ensure channel names are correct (without the # symbol)
2. **Verify Bot Access**: Make sure the bot can access the channels
3. **Check Date Ranges**: Some MCP servers have limitations on message history
4. **Increase Message Limits**: Try increasing the message limit:
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--channels general \
--max-messages-per-channel 1000 \
--query "test"
```
## Advanced Configuration
### Custom MCP Server Commands
If you need to pass additional parameters to your MCP server:
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server --token-file /path/to/tokens.json" \
--workspace-name "Your Workspace" \
--channels general \
--query "Your query"
```
### Multiple Workspaces
To work with multiple Slack workspaces, you can:
1. Create separate apps for each workspace
2. Use different environment variables
3. Run separate instances with different configurations
### Performance Optimization
For better performance with large workspaces:
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "Your Workspace" \
--max-messages-per-channel 500 \
--no-concatenate-conversations \
--query "Your query"
```
---
## Troubleshooting Checklist
- [ ] Slack app created with proper permissions
- [ ] Bot token (xoxb-) copied correctly
- [ ] App-level token (xapp-) created if needed
- [ ] MCP server installed and accessible
- [ ] Ollama installed and running (`brew services start ollama`)
- [ ] Ollama model pulled (`ollama pull llama3.2:1b`)
- [ ] Environment variables set correctly
- [ ] Bot invited to relevant channels
- [ ] Channel names specified without # symbol
- [ ] Sufficient retry attempts configured
- [ ] Network connectivity to Slack APIs
## Getting Help
If you continue to have issues:
1. **Check Logs**: Look for detailed error messages in the console output
2. **Test MCP Server**: Use `--test-connection` to verify the MCP server is working
3. **Verify Tokens**: Double-check that your Slack tokens are valid and have the right scopes
4. **Check Ollama**: Ensure Ollama is running (`ollama serve`) and the model is available (`ollama list`)
5. **Community Support**: Reach out to the LEANN community for help
## Example Commands
### Basic Usage
```bash
# Test connection
python -m apps.slack_rag --mcp-server "slack-mcp-server" --test-connection
# Index specific channels
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "My Company" \
--channels general random \
--query "What did we decide about the project timeline?"
```
### Advanced Usage
```bash
# With custom retry settings
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "My Company" \
--channels general \
--max-retries 10 \
--retry-delay 5.0 \
--max-messages-per-channel 2000 \
--query "Show me all decisions made in the last month"
```

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 445 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 508 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 437 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 474 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 501 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 454 KiB

View File

@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-diskann"
version = "0.3.4"
dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
version = "0.3.5"
dependencies = ["leann-core==0.3.5", "numpy", "protobuf>=3.19.0"]
[tool.scikit-build]
# Key: simplified CMake path

View File

@@ -29,12 +29,25 @@ if(APPLE)
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
endif()
# Use system ZeroMQ instead of building from source
# Find ZMQ using pkg-config with IMPORTED_TARGET for automatic target creation
find_package(PkgConfig REQUIRED)
pkg_check_modules(ZMQ REQUIRED libzmq)
# On ARM64 macOS, ensure pkg-config finds ARM64 Homebrew packages first
if(APPLE AND CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
set(ENV{PKG_CONFIG_PATH} "/opt/homebrew/lib/pkgconfig:/opt/homebrew/share/pkgconfig:$ENV{PKG_CONFIG_PATH}")
endif()
pkg_check_modules(ZMQ REQUIRED IMPORTED_TARGET libzmq)
# This creates PkgConfig::ZMQ target automatically with correct properties
if(TARGET PkgConfig::ZMQ)
message(STATUS "Found and configured ZMQ target: PkgConfig::ZMQ")
else()
message(FATAL_ERROR "pkg_check_modules did not create IMPORTED target for ZMQ.")
endif()
# Add cppzmq headers
include_directories(third_party/cppzmq)
include_directories(SYSTEM third_party/cppzmq)
# Configure msgpack-c - disable boost dependency
set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)

View File

@@ -215,6 +215,8 @@ class HNSWSearcher(BaseSearcher):
if recompute_embeddings:
if zmq_port is None:
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
if hasattr(self._index, "set_zmq_port"):
self._index.set_zmq_port(zmq_port)
if query.dtype != np.float32:
query = query.astype(np.float32)

View File

@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-hnsw"
version = "0.3.4"
version = "0.3.5"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = [
"leann-core==0.3.4",
"leann-core==0.3.5",
"numpy",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann-core"
version = "0.3.4"
version = "0.3.5"
description = "Core API and plugin system for LEANN"
readme = "README.md"
requires-python = ">=3.9"

View File

@@ -820,10 +820,10 @@ class LeannBuilder:
actual_port,
requested_zmq_port,
)
try:
index.hnsw.zmq_port = actual_port
except AttributeError:
pass
if hasattr(index.hnsw, "set_zmq_port"):
index.hnsw.set_zmq_port(actual_port)
elif hasattr(index, "set_zmq_port"):
index.set_zmq_port(actual_port)
if needs_recompute:
for i in range(embeddings.shape[0]):
@@ -916,6 +916,7 @@ class LeannSearcher:
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
batch_size: int = 0,
use_grep: bool = False,
provider_options: Optional[dict[str, Any]] = None,
**kwargs,
) -> list[SearchResult]:
"""
@@ -979,10 +980,24 @@ class LeannSearcher:
start_time = time.time()
# Extract query template from stored embedding_options with fallback chain:
# 1. Check provider_options override (highest priority)
# 2. Check query_prompt_template (new format)
# 3. Check prompt_template (old format for backward compat)
# 4. None (no template)
query_template = None
if provider_options and "prompt_template" in provider_options:
query_template = provider_options["prompt_template"]
elif "query_prompt_template" in self.embedding_options:
query_template = self.embedding_options["query_prompt_template"]
elif "prompt_template" in self.embedding_options:
query_template = self.embedding_options["prompt_template"]
query_embedding = self.backend_impl.compute_query_embedding(
query,
use_server_if_available=recompute_embeddings,
zmq_port=zmq_port,
query_template=query_template,
)
logger.info(f" Generated embedding shape: {query_embedding.shape}")
embedding_time = time.time() - start_time
@@ -1236,6 +1251,17 @@ class LeannChat:
"Please provide the best answer you can based on this context and your knowledge."
)
print("The context provided to the LLM is:")
print(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
print("-" * 150)
for r in results:
chunk_relevance = f"{r.score:.3f}"
chunk_id = r.id
chunk_content = r.text[:60]
chunk_source = r.metadata.get("source", "")[:80]
print(
f"{chunk_relevance:<10} | {chunk_id:<10} | {chunk_content:<60} | {chunk_source:<80}"
)
ask_time = time.time()
ans = self.llm.ask(prompt, **llm_kwargs)
ask_time = time.time() - ask_time

View File

@@ -834,6 +834,11 @@ class OpenAIChat(LLMInterface):
try:
response = self.client.chat.completions.create(**params)
print(
f"Total tokens = {response.usage.total_tokens}, prompt tokens = {response.usage.prompt_tokens}, completion tokens = {response.usage.completion_tokens}"
)
if response.choices[0].finish_reason == "length":
print("The query is exceeding the maximum allowed number of tokens")
return response.choices[0].message.content.strip()
except Exception as e:
logger.error(f"Error communicating with OpenAI: {e}")

View File

@@ -5,12 +5,128 @@ Packaged within leann-core so installed wheels can import it reliably.
import logging
from pathlib import Path
from typing import Optional
from typing import Any, Optional
from llama_index.core.node_parser import SentenceSplitter
logger = logging.getLogger(__name__)
# Flag to ensure AST token warning only shown once per session
_ast_token_warning_shown = False
def estimate_token_count(text: str) -> int:
"""
Estimate token count for a text string.
Uses conservative estimation: ~4 characters per token for natural text,
~1.2 tokens per character for code (worse tokenization).
Args:
text: Input text to estimate tokens for
Returns:
Estimated token count
"""
try:
import tiktoken
encoder = tiktoken.get_encoding("cl100k_base")
return len(encoder.encode(text))
except ImportError:
# Fallback: Conservative character-based estimation
# Assume worst case for code: 1.2 tokens per character
return int(len(text) * 1.2)
def calculate_safe_chunk_size(
model_token_limit: int,
overlap_tokens: int,
chunking_mode: str = "traditional",
safety_factor: float = 0.9,
) -> int:
"""
Calculate safe chunk size accounting for overlap and safety margin.
Args:
model_token_limit: Maximum tokens supported by embedding model
overlap_tokens: Overlap size (tokens for traditional, chars for AST)
chunking_mode: "traditional" (tokens) or "ast" (characters)
safety_factor: Safety margin (0.9 = 10% safety margin)
Returns:
Safe chunk size: tokens for traditional, characters for AST
"""
safe_limit = int(model_token_limit * safety_factor)
if chunking_mode == "traditional":
# Traditional chunking uses tokens
# Max chunk = chunk_size + overlap, so chunk_size = limit - overlap
return max(1, safe_limit - overlap_tokens)
else: # AST chunking
# AST uses characters, need to convert
# Conservative estimate: 1.2 tokens per char for code
overlap_chars = int(overlap_tokens * 3) # ~3 chars per token for code
safe_chars = int(safe_limit / 1.2)
return max(1, safe_chars - overlap_chars)
def validate_chunk_token_limits(chunks: list[str], max_tokens: int = 512) -> tuple[list[str], int]:
"""
Validate that chunks don't exceed token limits and truncate if necessary.
Args:
chunks: List of text chunks to validate
max_tokens: Maximum tokens allowed per chunk
Returns:
Tuple of (validated_chunks, num_truncated)
"""
validated_chunks = []
num_truncated = 0
for i, chunk in enumerate(chunks):
estimated_tokens = estimate_token_count(chunk)
if estimated_tokens > max_tokens:
# Truncate chunk to fit token limit
try:
import tiktoken
encoder = tiktoken.get_encoding("cl100k_base")
tokens = encoder.encode(chunk)
if len(tokens) > max_tokens:
truncated_tokens = tokens[:max_tokens]
truncated_chunk = encoder.decode(truncated_tokens)
validated_chunks.append(truncated_chunk)
num_truncated += 1
logger.warning(
f"Truncated chunk {i} from {len(tokens)} to {max_tokens} tokens "
f"(from {len(chunk)} to {len(truncated_chunk)} characters)"
)
else:
validated_chunks.append(chunk)
except ImportError:
# Fallback: Conservative character truncation
char_limit = int(max_tokens / 1.2) # Conservative for code
if len(chunk) > char_limit:
truncated_chunk = chunk[:char_limit]
validated_chunks.append(truncated_chunk)
num_truncated += 1
logger.warning(
f"Truncated chunk {i} from {len(chunk)} to {char_limit} characters "
f"(conservative estimate for {max_tokens} tokens)"
)
else:
validated_chunks.append(chunk)
else:
validated_chunks.append(chunk)
if num_truncated > 0:
logger.warning(f"Truncated {num_truncated}/{len(chunks)} chunks to fit token limits")
return validated_chunks, num_truncated
# Code file extensions supported by astchunk
CODE_EXTENSIONS = {
".py": "python",
@@ -61,27 +177,45 @@ def create_ast_chunks(
max_chunk_size: int = 512,
chunk_overlap: int = 64,
metadata_template: str = "default",
) -> list[str]:
) -> list[dict[str, Any]]:
"""Create AST-aware chunks from code documents using astchunk.
Falls back to traditional chunking if astchunk is unavailable.
Returns:
List of dicts with {"text": str, "metadata": dict}
"""
try:
from astchunk import ASTChunkBuilder # optional dependency
except ImportError as e:
logger.error(f"astchunk not available: {e}")
logger.info("Falling back to traditional chunking for code files")
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
return _traditional_chunks_as_dicts(documents, max_chunk_size, chunk_overlap)
all_chunks = []
for doc in documents:
language = doc.metadata.get("language")
if not language:
logger.warning("No language detected; falling back to traditional chunking")
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
continue
try:
# Warn once if AST chunk size + overlap might exceed common token limits
# Note: Actual truncation happens at embedding time with dynamic model limits
global _ast_token_warning_shown
estimated_max_tokens = int(
(max_chunk_size + chunk_overlap) * 1.2
) # Conservative estimate
if estimated_max_tokens > 512 and not _ast_token_warning_shown:
logger.warning(
f"AST chunk size ({max_chunk_size}) + overlap ({chunk_overlap}) = {max_chunk_size + chunk_overlap} chars "
f"may exceed 512 token limit (~{estimated_max_tokens} tokens estimated). "
f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}. "
f"Note: Chunks will be auto-truncated at embedding time based on your model's actual token limit."
)
_ast_token_warning_shown = True
configs = {
"max_chunk_size": max_chunk_size,
"language": language,
@@ -105,17 +239,40 @@ def create_ast_chunks(
chunks = chunk_builder.chunkify(code_content)
for chunk in chunks:
chunk_text = None
astchunk_metadata = {}
if hasattr(chunk, "text"):
chunk_text = chunk.text
elif isinstance(chunk, dict) and "text" in chunk:
chunk_text = chunk["text"]
elif isinstance(chunk, str):
chunk_text = chunk
elif isinstance(chunk, dict):
# Handle astchunk format: {"content": "...", "metadata": {...}}
if "content" in chunk:
chunk_text = chunk["content"]
astchunk_metadata = chunk.get("metadata", {})
elif "text" in chunk:
chunk_text = chunk["text"]
else:
chunk_text = str(chunk) # Last resort
else:
chunk_text = str(chunk)
if chunk_text and chunk_text.strip():
all_chunks.append(chunk_text.strip())
# Extract document-level metadata
doc_metadata = {
"file_path": doc.metadata.get("file_path", ""),
"file_name": doc.metadata.get("file_name", ""),
}
if "creation_date" in doc.metadata:
doc_metadata["creation_date"] = doc.metadata["creation_date"]
if "last_modified_date" in doc.metadata:
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
# Merge document metadata + astchunk metadata
combined_metadata = {**doc_metadata, **astchunk_metadata}
all_chunks.append({"text": chunk_text.strip(), "metadata": combined_metadata})
logger.info(
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
@@ -123,15 +280,19 @@ def create_ast_chunks(
except Exception as e:
logger.warning(f"AST chunking failed for {language} file: {e}")
logger.info("Falling back to traditional chunking")
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
return all_chunks
def create_traditional_chunks(
documents, chunk_size: int = 256, chunk_overlap: int = 128
) -> list[str]:
"""Create traditional text chunks using LlamaIndex SentenceSplitter."""
) -> list[dict[str, Any]]:
"""Create traditional text chunks using LlamaIndex SentenceSplitter.
Returns:
List of dicts with {"text": str, "metadata": dict}
"""
if chunk_size <= 0:
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
chunk_size = 256
@@ -147,19 +308,40 @@ def create_traditional_chunks(
paragraph_separator="\n\n",
)
all_texts = []
result = []
for doc in documents:
# Extract document-level metadata
doc_metadata = {
"file_path": doc.metadata.get("file_path", ""),
"file_name": doc.metadata.get("file_name", ""),
}
if "creation_date" in doc.metadata:
doc_metadata["creation_date"] = doc.metadata["creation_date"]
if "last_modified_date" in doc.metadata:
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
try:
nodes = node_parser.get_nodes_from_documents([doc])
if nodes:
all_texts.extend(node.get_content() for node in nodes)
for node in nodes:
result.append({"text": node.get_content(), "metadata": doc_metadata})
except Exception as e:
logger.error(f"Traditional chunking failed for document: {e}")
content = doc.get_content()
if content and content.strip():
all_texts.append(content.strip())
result.append({"text": content.strip(), "metadata": doc_metadata})
return all_texts
return result
def _traditional_chunks_as_dicts(
documents, chunk_size: int = 256, chunk_overlap: int = 128
) -> list[dict[str, Any]]:
"""Helper: Traditional chunking that returns dict format for consistency.
This is now just an alias for create_traditional_chunks for backwards compatibility.
"""
return create_traditional_chunks(documents, chunk_size, chunk_overlap)
def create_text_chunks(
@@ -171,8 +353,12 @@ def create_text_chunks(
ast_chunk_overlap: int = 64,
code_file_extensions: Optional[list[str]] = None,
ast_fallback_traditional: bool = True,
) -> list[str]:
"""Create text chunks from documents with optional AST support for code files."""
) -> list[dict[str, Any]]:
"""Create text chunks from documents with optional AST support for code files.
Returns:
List of dicts with {"text": str, "metadata": dict}
"""
if not documents:
logger.warning("No documents provided for chunking")
return []
@@ -207,14 +393,17 @@ def create_text_chunks(
logger.error(f"AST chunking failed: {e}")
if ast_fallback_traditional:
all_chunks.extend(
create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
_traditional_chunks_as_dicts(code_docs, chunk_size, chunk_overlap)
)
else:
raise
if text_docs:
all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
all_chunks.extend(_traditional_chunks_as_dicts(text_docs, chunk_size, chunk_overlap))
else:
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
all_chunks = _traditional_chunks_as_dicts(documents, chunk_size, chunk_overlap)
logger.info(f"Total chunks created: {len(all_chunks)}")
# Note: Token truncation is now handled at embedding time with dynamic model limits
# See get_model_token_limit() and truncate_to_token_limit() in embedding_compute.py
return all_chunks

View File

@@ -1,5 +1,6 @@
import argparse
import asyncio
import time
from pathlib import Path
from typing import Any, Optional, Union
@@ -106,7 +107,7 @@ Examples:
help="Documents directories and/or files (default: current directory)",
)
build_parser.add_argument(
"--backend",
"--backend-name",
type=str,
default="hnsw",
choices=["hnsw", "diskann"],
@@ -143,6 +144,18 @@ Examples:
default=None,
help="API key for embedding service (defaults to OPENAI_API_KEY)",
)
build_parser.add_argument(
"--embedding-prompt-template",
type=str,
default=None,
help="Prompt template to prepend to all texts for embedding (e.g., 'query: ' for search)",
)
build_parser.add_argument(
"--query-prompt-template",
type=str,
default=None,
help="Prompt template for queries (different from build template for task-specific models)",
)
build_parser.add_argument(
"--force", "-f", action="store_true", help="Force rebuild existing index"
)
@@ -180,25 +193,25 @@ Examples:
"--doc-chunk-size",
type=int,
default=256,
help="Document chunk size in tokens/characters (default: 256)",
help="Document chunk size in TOKENS (default: 256). Final chunks may be larger due to overlap. For 512 token models: recommended 350 tokens (350 + 128 overlap = 478 max)",
)
build_parser.add_argument(
"--doc-chunk-overlap",
type=int,
default=128,
help="Document chunk overlap (default: 128)",
help="Document chunk overlap in TOKENS (default: 128). Added to chunk size, not included in it",
)
build_parser.add_argument(
"--code-chunk-size",
type=int,
default=512,
help="Code chunk size in tokens/lines (default: 512)",
help="Code chunk size in TOKENS (default: 512). Final chunks may be larger due to overlap. For 512 token models: recommended 400 tokens (400 + 50 overlap = 450 max)",
)
build_parser.add_argument(
"--code-chunk-overlap",
type=int,
default=50,
help="Code chunk overlap (default: 50)",
help="Code chunk overlap in TOKENS (default: 50). Added to chunk size, not included in it",
)
build_parser.add_argument(
"--use-ast-chunking",
@@ -208,14 +221,14 @@ Examples:
build_parser.add_argument(
"--ast-chunk-size",
type=int,
default=768,
help="AST chunk size in characters (default: 768)",
default=300,
help="AST chunk size in CHARACTERS (non-whitespace) (default: 300). Final chunks may be larger due to overlap and expansion. For 512 token models: recommended 300 chars (300 + 64 overlap ~= 480 tokens)",
)
build_parser.add_argument(
"--ast-chunk-overlap",
type=int,
default=96,
help="AST chunk overlap in characters (default: 96)",
default=64,
help="AST chunk overlap in CHARACTERS (default: 64). Added to chunk size, not included in it. ~1.2 tokens per character for code",
)
build_parser.add_argument(
"--ast-fallback-traditional",
@@ -254,6 +267,17 @@ Examples:
action="store_true",
help="Non-interactive mode: automatically select index without prompting",
)
search_parser.add_argument(
"--show-metadata",
action="store_true",
help="Display file paths and metadata in search results",
)
search_parser.add_argument(
"--embedding-prompt-template",
type=str,
default=None,
help="Prompt template to prepend to query for embedding (e.g., 'query: ' for search)",
)
# Ask command
ask_parser = subparsers.add_parser("ask", help="Ask questions")
@@ -1156,6 +1180,11 @@ Examples:
print(f"Warning: Could not process {file_path}: {e}")
# Load other file types with default reader
# Exclude PDFs from code_extensions if they were already processed separately
other_file_extensions = code_extensions
if should_process_pdfs and ".pdf" in code_extensions:
other_file_extensions = [ext for ext in code_extensions if ext != ".pdf"]
try:
# Create a custom file filter function using our PathSpec
def file_filter(
@@ -1171,21 +1200,26 @@ Examples:
except (ValueError, OSError):
return True # Include files that can't be processed
other_docs = SimpleDirectoryReader(
docs_dir,
recursive=True,
encoding="utf-8",
required_exts=code_extensions,
file_extractor={}, # Use default extractors
exclude_hidden=not include_hidden,
filename_as_id=True,
).load_data(show_progress=True)
# Only load other file types if there are extensions to process
if other_file_extensions:
other_docs = SimpleDirectoryReader(
docs_dir,
recursive=True,
encoding="utf-8",
required_exts=other_file_extensions,
file_extractor={}, # Use default extractors
exclude_hidden=not include_hidden,
filename_as_id=True,
).load_data(show_progress=True)
else:
other_docs = []
# Filter documents after loading based on gitignore rules
filtered_docs = []
for doc in other_docs:
file_path = doc.metadata.get("file_path", "")
if file_filter(file_path):
doc.metadata["source"] = file_path
filtered_docs.append(doc)
documents.extend(filtered_docs)
@@ -1261,7 +1295,7 @@ Examples:
from .chunking_utils import create_text_chunks
# Use enhanced chunking with AST support
all_texts = create_text_chunks(
chunk_texts = create_text_chunks(
documents,
chunk_size=self.node_parser.chunk_size,
chunk_overlap=self.node_parser.chunk_overlap,
@@ -1272,6 +1306,9 @@ Examples:
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
)
# create_text_chunks now returns list[dict] with metadata preserved
all_texts.extend(chunk_texts)
except ImportError as e:
print(
f"⚠️ AST chunking utilities not available in package ({e}), falling back to traditional chunking"
@@ -1283,14 +1320,27 @@ Examples:
for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
# Check if this is a code file based on source path
source_path = doc.metadata.get("source", "")
file_path = doc.metadata.get("file_path", "")
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
# Extract metadata to preserve with chunks
chunk_metadata = {
"file_path": file_path or source_path,
"file_name": doc.metadata.get("file_name", ""),
}
# Add optional metadata if available
if "creation_date" in doc.metadata:
chunk_metadata["creation_date"] = doc.metadata["creation_date"]
if "last_modified_date" in doc.metadata:
chunk_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
# Use appropriate parser based on file type
parser = self.code_parser if is_code_file else self.node_parser
nodes = parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
all_texts.append({"text": node.get_content(), "metadata": chunk_metadata})
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
return all_texts
@@ -1365,7 +1415,7 @@ Examples:
index_dir.mkdir(parents=True, exist_ok=True)
print(f"Building index '{index_name}' with {args.backend} backend...")
print(f"Building index '{index_name}' with {args.backend_name} backend...")
embedding_options: dict[str, Any] = {}
if args.embedding_mode == "ollama":
@@ -1375,9 +1425,17 @@ Examples:
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
if resolved_embedding_key:
embedding_options["api_key"] = resolved_embedding_key
if args.query_prompt_template:
# New format: separate templates
if args.embedding_prompt_template:
embedding_options["build_prompt_template"] = args.embedding_prompt_template
embedding_options["query_prompt_template"] = args.query_prompt_template
elif args.embedding_prompt_template:
# Old format: single template (backward compat)
embedding_options["prompt_template"] = args.embedding_prompt_template
builder = LeannBuilder(
backend_name=args.backend,
backend_name=args.backend_name,
embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
embedding_options=embedding_options or None,
@@ -1388,8 +1446,8 @@ Examples:
num_threads=args.num_threads,
)
for chunk_text in all_texts:
builder.add_text(chunk_text)
for chunk in all_texts:
builder.add_text(chunk["text"], metadata=chunk["metadata"])
builder.build_index(index_path)
print(f"Index built at {index_path}")
@@ -1496,6 +1554,11 @@ Examples:
print("Invalid input. Aborting search.")
return
# Build provider_options for runtime override
provider_options = {}
if args.embedding_prompt_template:
provider_options["prompt_template"] = args.embedding_prompt_template
searcher = LeannSearcher(index_path=index_path)
results = searcher.search(
query,
@@ -1505,12 +1568,31 @@ Examples:
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
provider_options=provider_options if provider_options else None,
)
print(f"Search results for '{query}' (top {len(results)}):")
for i, result in enumerate(results, 1):
print(f"{i}. Score: {result.score:.3f}")
# Display metadata if flag is set
if args.show_metadata and result.metadata:
file_path = result.metadata.get("file_path", "")
if file_path:
print(f" 📄 File: {file_path}")
file_name = result.metadata.get("file_name", "")
if file_name and file_name != file_path:
print(f" 📝 Name: {file_name}")
# Show timestamps if available
if "creation_date" in result.metadata:
print(f" 🕐 Created: {result.metadata['creation_date']}")
if "last_modified_date" in result.metadata:
print(f" 🕑 Modified: {result.metadata['last_modified_date']}")
print(f" {result.text[:200]}...")
print(f" Source: {result.metadata.get('source', '')}")
print()
async def ask_questions(self, args):
@@ -1542,6 +1624,7 @@ Examples:
llm_kwargs["thinking_budget"] = args.thinking_budget
def _ask_once(prompt: str) -> None:
query_start_time = time.time()
response = chat.ask(
prompt,
top_k=args.top_k,
@@ -1552,7 +1635,9 @@ Examples:
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
query_completion_time = time.time() - query_start_time
print(f"LEANN: {response}")
print(f"The query took {query_completion_time:.3f} seconds to finish")
initial_query = (args.query or "").strip()

View File

@@ -4,12 +4,15 @@ Consolidates all embedding computation logic using SentenceTransformer
Preserves all optimization parameters to ensure performance
"""
import json
import logging
import os
import subprocess
import time
from typing import Any, Optional
import numpy as np
import tiktoken
import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
@@ -20,6 +23,288 @@ LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Token limit registry for embedding models
# Used as fallback when dynamic discovery fails (e.g., LM Studio, OpenAI)
# Ollama models use dynamic discovery via /api/show
EMBEDDING_MODEL_LIMITS = {
# Nomic models (common across servers)
"nomic-embed-text": 2048, # Corrected from 512 - verified via /api/show
"nomic-embed-text-v1.5": 2048,
"nomic-embed-text-v2": 512,
# Other embedding models
"mxbai-embed-large": 512,
"all-minilm": 512,
"bge-m3": 8192,
"snowflake-arctic-embed": 512,
# OpenAI models
"text-embedding-3-small": 8192,
"text-embedding-3-large": 8192,
"text-embedding-ada-002": 8192,
}
# Runtime cache for dynamically discovered token limits
# Key: (model_name, base_url), Value: token_limit
# Prevents repeated SDK/API calls for the same model
_token_limit_cache: dict[tuple[str, str], int] = {}
def get_model_token_limit(
model_name: str,
base_url: Optional[str] = None,
default: int = 2048,
) -> int:
"""
Get token limit for a given embedding model.
Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others.
Caches discovered limits to prevent repeated API/SDK calls.
Args:
model_name: Name of the embedding model
base_url: Base URL of the embedding server (for dynamic discovery)
default: Default token limit if model not found
Returns:
Token limit for the model in tokens
"""
# Check cache first to avoid repeated SDK/API calls
cache_key = (model_name, base_url or "")
if cache_key in _token_limit_cache:
cached_limit = _token_limit_cache[cache_key]
logger.debug(f"Using cached token limit for {model_name}: {cached_limit}")
return cached_limit
# Try Ollama dynamic discovery if base_url provided
if base_url:
# Detect Ollama servers by port or "ollama" in URL
if "11434" in base_url or "ollama" in base_url.lower():
limit = _query_ollama_context_limit(model_name, base_url)
if limit:
_token_limit_cache[cache_key] = limit
return limit
# Try LM Studio SDK discovery
if "1234" in base_url or "lmstudio" in base_url.lower() or "lm.studio" in base_url.lower():
# Convert HTTP to WebSocket URL
ws_url = base_url.replace("https://", "wss://").replace("http://", "ws://")
# Remove /v1 suffix if present
if ws_url.endswith("/v1"):
ws_url = ws_url[:-3]
limit = _query_lmstudio_context_limit(model_name, ws_url)
if limit:
_token_limit_cache[cache_key] = limit
return limit
# Fallback to known model registry with version handling (from PR #154)
# Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
base_model_name = model_name.split(":")[0]
# Check exact match first
if model_name in EMBEDDING_MODEL_LIMITS:
limit = EMBEDDING_MODEL_LIMITS[model_name]
_token_limit_cache[cache_key] = limit
return limit
# Check base name match
if base_model_name in EMBEDDING_MODEL_LIMITS:
limit = EMBEDDING_MODEL_LIMITS[base_model_name]
_token_limit_cache[cache_key] = limit
return limit
# Check partial matches for common patterns
for known_model, registry_limit in EMBEDDING_MODEL_LIMITS.items():
if known_model in base_model_name or base_model_name in known_model:
_token_limit_cache[cache_key] = registry_limit
return registry_limit
# Default fallback
logger.warning(f"Unknown model '{model_name}', using default {default} token limit")
_token_limit_cache[cache_key] = default
return default
def truncate_to_token_limit(texts: list[str], token_limit: int) -> list[str]:
"""
Truncate texts to fit within token limit using tiktoken.
Args:
texts: List of text strings to truncate
token_limit: Maximum number of tokens allowed
Returns:
List of truncated texts (same length as input)
"""
if not texts:
return []
# Use tiktoken with cl100k_base encoding
enc = tiktoken.get_encoding("cl100k_base")
truncated_texts = []
truncation_count = 0
total_tokens_removed = 0
max_original_length = 0
for i, text in enumerate(texts):
tokens = enc.encode(text)
original_length = len(tokens)
if original_length <= token_limit:
# Text is within limit, keep as is
truncated_texts.append(text)
else:
# Truncate to token_limit
truncated_tokens = tokens[:token_limit]
truncated_text = enc.decode(truncated_tokens)
truncated_texts.append(truncated_text)
# Track truncation statistics
truncation_count += 1
tokens_removed = original_length - token_limit
total_tokens_removed += tokens_removed
max_original_length = max(max_original_length, original_length)
# Log individual truncation at WARNING level (first few only)
if truncation_count <= 3:
logger.warning(
f"Text {i + 1} truncated: {original_length}{token_limit} tokens "
f"({tokens_removed} tokens removed)"
)
elif truncation_count == 4:
logger.warning("Further truncation warnings suppressed...")
# Log summary at INFO level
if truncation_count > 0:
logger.warning(
f"Truncation summary: {truncation_count}/{len(texts)} texts truncated "
f"(removed {total_tokens_removed} tokens total, longest was {max_original_length} tokens)"
)
else:
logger.debug(
f"No truncation needed - all {len(texts)} texts within {token_limit} token limit"
)
return truncated_texts
def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int]:
"""
Query Ollama /api/show for model context limit.
Args:
model_name: Name of the Ollama model
base_url: Base URL of the Ollama server
Returns:
Context limit in tokens if found, None otherwise
"""
try:
import requests
response = requests.post(
f"{base_url}/api/show",
json={"name": model_name},
timeout=5,
)
if response.status_code == 200:
data = response.json()
if "model_info" in data:
# Look for *.context_length in model_info
for key, value in data["model_info"].items():
if "context_length" in key and isinstance(value, int):
logger.info(f"Detected {model_name} context limit: {value} tokens")
return value
except Exception as e:
logger.debug(f"Failed to query Ollama context limit: {e}")
return None
def _query_lmstudio_context_limit(model_name: str, base_url: str) -> Optional[int]:
"""
Query LM Studio SDK for model context length via Node.js subprocess.
Args:
model_name: Name of the LM Studio model
base_url: Base URL of the LM Studio server (WebSocket format, e.g., "ws://localhost:1234")
Returns:
Context limit in tokens if found, None otherwise
"""
# Inline JavaScript using @lmstudio/sdk
# Note: Load model temporarily for metadata, then unload to respect JIT auto-evict
js_code = f"""
const {{ LMStudioClient }} = require('@lmstudio/sdk');
(async () => {{
try {{
const client = new LMStudioClient({{ baseUrl: '{base_url}' }});
const model = await client.embedding.load('{model_name}', {{ verbose: false }});
const contextLength = await model.getContextLength();
await model.unload(); // Unload immediately to respect JIT auto-evict settings
console.log(JSON.stringify({{ contextLength, identifier: '{model_name}' }}));
}} catch (error) {{
console.error(JSON.stringify({{ error: error.message }}));
process.exit(1);
}}
}})();
"""
try:
# Set NODE_PATH to include global modules for @lmstudio/sdk resolution
env = os.environ.copy()
# Try to get npm global root (works with nvm, brew node, etc.)
try:
npm_root = subprocess.run(
["npm", "root", "-g"],
capture_output=True,
text=True,
timeout=5,
)
if npm_root.returncode == 0:
global_modules = npm_root.stdout.strip()
# Append to existing NODE_PATH if present
existing_node_path = env.get("NODE_PATH", "")
env["NODE_PATH"] = (
f"{global_modules}:{existing_node_path}"
if existing_node_path
else global_modules
)
except Exception:
# If npm not available, continue with existing NODE_PATH
pass
result = subprocess.run(
["node", "-e", js_code],
capture_output=True,
text=True,
timeout=10,
env=env,
)
if result.returncode != 0:
logger.debug(f"LM Studio SDK error: {result.stderr}")
return None
data = json.loads(result.stdout)
context_length = data.get("contextLength")
if context_length and context_length > 0:
logger.info(f"LM Studio SDK detected {model_name} context length: {context_length}")
return context_length
except FileNotFoundError:
logger.debug("Node.js not found - install Node.js for LM Studio SDK features")
except subprocess.TimeoutExpired:
logger.debug("LM Studio SDK query timeout")
except json.JSONDecodeError:
logger.debug("LM Studio SDK returned invalid JSON")
except Exception as e:
logger.debug(f"LM Studio SDK query failed: {e}")
return None
# Global model cache to avoid repeated loading
_model_cache: dict[str, Any] = {}
@@ -67,6 +352,7 @@ def compute_embeddings(
model_name,
base_url=provider_options.get("base_url"),
api_key=provider_options.get("api_key"),
provider_options=provider_options,
)
elif mode == "mlx":
return compute_embeddings_mlx(texts, model_name)
@@ -76,6 +362,7 @@ def compute_embeddings(
model_name,
is_build=is_build,
host=provider_options.get("host"),
provider_options=provider_options,
)
elif mode == "gemini":
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
@@ -414,6 +701,7 @@ def compute_embeddings_openai(
model_name: str,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Compute embeddings using OpenAI API"""
@@ -432,26 +720,40 @@ def compute_embeddings_openai(
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
)
resolved_base_url = resolve_openai_base_url(base_url)
resolved_api_key = resolve_openai_api_key(api_key)
# Extract base_url and api_key from provider_options if not provided directly
provider_options = provider_options or {}
effective_base_url = base_url or provider_options.get("base_url")
effective_api_key = api_key or provider_options.get("api_key")
resolved_base_url = resolve_openai_base_url(effective_base_url)
resolved_api_key = resolve_openai_api_key(effective_api_key)
if not resolved_api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
# Cache OpenAI client
cache_key = f"openai_client::{resolved_base_url}"
if cache_key in _model_cache:
client = _model_cache[cache_key]
else:
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
_model_cache[cache_key] = client
logger.info("OpenAI client cached")
# Create OpenAI client
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
logger.info(
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
)
print(f"len of texts: {len(texts)}")
# Apply prompt template if provided
# Priority: build_prompt_template (new format) > prompt_template (old format)
prompt_template = provider_options.get("build_prompt_template") or provider_options.get(
"prompt_template"
)
if prompt_template:
logger.warning(f"Applying prompt template: '{prompt_template}'")
texts = [f"{prompt_template}{text}" for text in texts]
# Query token limit and apply truncation
token_limit = get_model_token_limit(model_name, base_url=effective_base_url)
logger.info(f"Using token limit: {token_limit} for model '{model_name}'")
texts = truncate_to_token_limit(texts, token_limit)
# OpenAI has limits on batch size and input length
max_batch_size = 800 # Conservative batch size because the token limit is 300K
all_embeddings = []
@@ -482,7 +784,15 @@ def compute_embeddings_openai(
try:
response = client.embeddings.create(model=model_name, input=batch_texts)
batch_embeddings = [embedding.embedding for embedding in response.data]
all_embeddings.extend(batch_embeddings)
# Verify we got the expected number of embeddings
if len(batch_embeddings) != len(batch_texts):
logger.warning(
f"Expected {len(batch_texts)} embeddings but got {len(batch_embeddings)}"
)
# Only take the number of embeddings that match the batch size
all_embeddings.extend(batch_embeddings[: len(batch_texts)])
except Exception as e:
logger.error(f"Batch {i} failed: {e}")
raise
@@ -572,17 +882,20 @@ def compute_embeddings_ollama(
model_name: str,
is_build: bool = False,
host: Optional[str] = None,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray:
"""
Compute embeddings using Ollama API with simplified batch processing.
Compute embeddings using Ollama API with true batch processing.
Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
Uses the /api/embed endpoint which supports batch inputs.
Batch size: 32 for MPS/CPU, 128 for CUDA to optimize performance.
Args:
texts: List of texts to compute embeddings for
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
is_build: Whether this is a build operation (shows progress bar)
host: Ollama host URL (defaults to environment or http://localhost:11434)
provider_options: Optional provider-specific options (e.g., prompt_template)
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
@@ -681,11 +994,11 @@ def compute_embeddings_ollama(
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
model_name = resolved_model_name
# Verify the model supports embeddings by testing it
# Verify the model supports embeddings by testing it with /api/embed
try:
test_response = requests.post(
f"{resolved_host}/api/embeddings",
json={"model": model_name, "prompt": "test"},
f"{resolved_host}/api/embed",
json={"model": model_name, "input": "test"},
timeout=10,
)
if test_response.status_code != 200:
@@ -717,56 +1030,82 @@ def compute_embeddings_ollama(
# If torch is not available, use conservative batch size
batch_size = 32
logger.info(f"Using batch size: {batch_size}")
logger.info(f"Using batch size: {batch_size} for true batch processing")
# Apply prompt template if provided
provider_options = provider_options or {}
# Priority: build_prompt_template (new format) > prompt_template (old format)
prompt_template = provider_options.get("build_prompt_template") or provider_options.get(
"prompt_template"
)
if prompt_template:
logger.warning(f"Applying prompt template: '{prompt_template}'")
texts = [f"{prompt_template}{text}" for text in texts]
# Get model token limit and apply truncation before batching
token_limit = get_model_token_limit(model_name, base_url=resolved_host)
logger.info(f"Model '{model_name}' token limit: {token_limit}")
# Apply truncation to all texts before batch processing
# Function logs truncation details internally
texts = truncate_to_token_limit(texts, token_limit)
def get_batch_embeddings(batch_texts):
"""Get embeddings for a batch of texts."""
all_embeddings = []
failed_indices = []
"""Get embeddings for a batch of texts using /api/embed endpoint."""
max_retries = 3
retry_count = 0
for i, text in enumerate(batch_texts):
max_retries = 3
retry_count = 0
# Texts are already truncated to token limit by the outer function
while retry_count < max_retries:
try:
# Use /api/embed endpoint with "input" parameter for batch processing
response = requests.post(
f"{resolved_host}/api/embed",
json={"model": model_name, "input": batch_texts},
timeout=60, # Increased timeout for batch processing
)
response.raise_for_status()
# Truncate very long texts to avoid API issues
truncated_text = text[:8000] if len(text) > 8000 else text
while retry_count < max_retries:
try:
response = requests.post(
f"{resolved_host}/api/embeddings",
json={"model": model_name, "prompt": truncated_text},
timeout=30,
result = response.json()
batch_embeddings = result.get("embeddings")
if batch_embeddings is None:
raise ValueError("No embeddings returned from API")
if not isinstance(batch_embeddings, list):
raise ValueError(f"Invalid embeddings format: {type(batch_embeddings)}")
if len(batch_embeddings) != len(batch_texts):
raise ValueError(
f"Mismatch: requested {len(batch_texts)} embeddings, got {len(batch_embeddings)}"
)
response.raise_for_status()
result = response.json()
embedding = result.get("embedding")
return batch_embeddings, []
if embedding is None:
raise ValueError(f"No embedding returned for text {i}")
except requests.exceptions.Timeout:
retry_count += 1
if retry_count >= max_retries:
logger.warning(f"Timeout for batch after {max_retries} retries")
return None, list(range(len(batch_texts)))
if not isinstance(embedding, list) or len(embedding) == 0:
raise ValueError(f"Invalid embedding format for text {i}")
except Exception as e:
retry_count += 1
if retry_count >= max_retries:
# Enhanced error detection for token limit violations
error_msg = str(e).lower()
if "token" in error_msg and (
"limit" in error_msg or "exceed" in error_msg or "length" in error_msg
):
logger.error(
f"Token limit exceeded for batch. Error: {e}. "
f"Consider reducing chunk sizes or check token truncation."
)
else:
logger.error(f"Failed to get embeddings for batch: {e}")
return None, list(range(len(batch_texts)))
all_embeddings.append(embedding)
break
except requests.exceptions.Timeout:
retry_count += 1
if retry_count >= max_retries:
logger.warning(f"Timeout for text {i} after {max_retries} retries")
failed_indices.append(i)
all_embeddings.append(None)
break
except Exception as e:
retry_count += 1
if retry_count >= max_retries:
logger.error(f"Failed to get embedding for text {i}: {e}")
failed_indices.append(i)
all_embeddings.append(None)
break
return all_embeddings, failed_indices
return None, list(range(len(batch_texts)))
# Process texts in batches
all_embeddings = []
@@ -784,7 +1123,7 @@ def compute_embeddings_ollama(
num_batches = (len(texts) + batch_size - 1) // batch_size
if show_progress:
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
else:
batch_iterator = range(num_batches)
@@ -795,10 +1134,14 @@ def compute_embeddings_ollama(
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
# Adjust failed indices to global indices
global_failed = [start_idx + idx for idx in batch_failed]
all_failed_indices.extend(global_failed)
all_embeddings.extend(batch_embeddings)
if batch_embeddings is not None:
all_embeddings.extend(batch_embeddings)
else:
# Entire batch failed, add None placeholders
all_embeddings.extend([None] * len(batch_texts))
# Adjust failed indices to global indices
global_failed = [start_idx + idx for idx in batch_failed]
all_failed_indices.extend(global_failed)
# Handle failed embeddings
if all_failed_indices:

View File

@@ -77,6 +77,7 @@ class LeannBackendSearcherInterface(ABC):
query: str,
use_server_if_available: bool = True,
zmq_port: Optional[int] = None,
query_template: Optional[str] = None,
) -> np.ndarray:
"""Compute embedding for a query string
@@ -84,6 +85,7 @@ class LeannBackendSearcherInterface(ABC):
query: The query string to embed
zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first
query_template: Optional prompt template to prepend to query
Returns:
Query embedding as numpy array with shape (1, D)

View File

@@ -60,6 +60,11 @@ def handle_request(request):
"maximum": 128,
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
},
"show_metadata": {
"type": "boolean",
"default": False,
"description": "Include file paths and metadata in search results. Useful for understanding which files contain the results.",
},
},
"required": ["index_name", "query"],
},
@@ -104,6 +109,8 @@ def handle_request(request):
f"--complexity={args.get('complexity', 32)}",
"--non-interactive",
]
if args.get("show_metadata", False):
cmd.append("--show-metadata")
result = subprocess.run(cmd, capture_output=True, text=True)
elif tool_name == "leann_list":

View File

@@ -33,6 +33,8 @@ def autodiscover_backends():
discovered_backends = []
for dist in importlib.metadata.distributions():
dist_name = dist.metadata["name"]
if dist_name is None:
continue
if dist_name.startswith("leann-backend-"):
backend_module_name = dist_name.replace("-", "_")
discovered_backends.append(backend_module_name)

View File

@@ -71,6 +71,15 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
or "mips"
)
# Filter out ALL prompt templates from provider_options during search
# Templates are applied in compute_query_embedding (line 109-110) BEFORE server call
# The server should never apply templates during search to avoid double-templating
search_provider_options = {
k: v
for k, v in self.embedding_options.items()
if k not in ("build_prompt_template", "query_prompt_template", "prompt_template")
}
server_started, actual_port = self.embedding_server_manager.start_server(
port=port,
model_name=self.embedding_model,
@@ -78,7 +87,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
passages_file=passages_source_file,
distance_metric=distance_metric,
enable_warmup=kwargs.get("enable_warmup", False),
provider_options=self.embedding_options,
provider_options=search_provider_options,
)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
@@ -90,6 +99,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
query: str,
use_server_if_available: bool = True,
zmq_port: int = 5557,
query_template: Optional[str] = None,
) -> np.ndarray:
"""
Compute embedding for a query string.
@@ -98,10 +108,16 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
query: The query string to embed
zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first
query_template: Optional prompt template to prepend to query
Returns:
Query embedding as numpy array
"""
# Apply query template BEFORE any computation path
# This ensures template is applied consistently for both server and fallback paths
if query_template:
query = f"{query_template}{query}"
# Try to use embedding server if available and requested
if use_server_if_available:
try:

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann"
version = "0.3.4"
version = "0.3.5"
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
readme = "README.md"
requires-python = ">=3.9"

View File

@@ -57,6 +57,8 @@ dependencies = [
"tree-sitter-c-sharp>=0.20.0",
"tree-sitter-typescript>=0.20.0",
"torchvision>=0.23.0",
"einops",
"seaborn",
]
[project.optional-dependencies]
@@ -67,7 +69,8 @@ diskann = [
# Add a new optional dependency group for document processing
documents = [
"beautifulsoup4>=4.13.0", # For HTML parsing
"python-docx>=0.8.11", # For Word documents
"python-docx>=0.8.11", # For Word documents (creating/editing)
"docx2txt>=0.9", # For Word documents (text extraction)
"openpyxl>=3.1.0", # For Excel files
"pandas>=2.2.0", # For data processing
]
@@ -162,6 +165,7 @@ python_functions = ["test_*"]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"openai: marks tests that require OpenAI API key",
"integration: marks tests that require live services (Ollama, LM Studio, etc.)",
]
timeout = 300 # Reduced from 600s (10min) to 300s (5min) for CI safety
addopts = [

View File

@@ -36,6 +36,14 @@ Tests DiskANN graph partitioning functionality:
- Includes performance comparison between DiskANN (with partition) and HNSW
- **Note**: These tests are skipped in CI due to hardware requirements and computation time
### `test_prompt_template_e2e.py`
Integration tests for prompt template feature with live embedding services:
- Tests prompt template prepending with EmbeddingGemma (OpenAI-compatible API via LM Studio)
- Tests hybrid token limit discovery (Ollama dynamic detection, registry fallback, default)
- Tests LM Studio SDK bridge for automatic context length detection (requires Node.js + @lmstudio/sdk)
- **Note**: These tests require live services (LM Studio, Ollama) and are marked with `@pytest.mark.integration`
- **Important**: Prompt templates are ONLY for EmbeddingGemma and similar task-specific models, NOT regular embedding models
## Running Tests
### Install test dependencies:
@@ -66,6 +74,12 @@ pytest tests/ -m "not openai"
# Skip slow tests
pytest tests/ -m "not slow"
# Skip integration tests (that require live services)
pytest tests/ -m "not integration"
# Run only integration tests (requires LM Studio or Ollama running)
pytest tests/test_prompt_template_e2e.py -v -s
# Run DiskANN partition tests (requires local machine, not CI)
pytest tests/test_diskann_partition.py
```
@@ -101,6 +115,20 @@ The `pytest.ini` file configures:
- Custom markers for slow and OpenAI tests
- Verbose output with short tracebacks
### Integration Test Prerequisites
Integration tests (`test_prompt_template_e2e.py`) require live services:
**Required:**
- LM Studio running at `http://localhost:1234` with EmbeddingGemma model loaded
**Optional:**
- Ollama running at `http://localhost:11434` for token limit detection tests
- Node.js + @lmstudio/sdk installed (`npm install -g @lmstudio/sdk`) for SDK bridge tests
Tests gracefully skip if services are unavailable.
### Known Issues
- OpenAI tests are automatically skipped if no API key is provided
- Integration tests require live embedding services and may fail due to proxy settings (set `unset ALL_PROXY all_proxy` if needed)

View File

@@ -8,7 +8,7 @@ import subprocess
import sys
import tempfile
from pathlib import Path
from unittest.mock import patch
from unittest.mock import Mock, patch
import pytest
@@ -116,8 +116,10 @@ class TestChunkingFunctions:
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
assert len(chunks) > 0
assert all(isinstance(chunk, str) for chunk in chunks)
assert all(len(chunk.strip()) > 0 for chunk in chunks)
# Traditional chunks now return dict format for consistency
assert all(isinstance(chunk, dict) for chunk in chunks)
assert all("text" in chunk and "metadata" in chunk for chunk in chunks)
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks)
def test_create_traditional_chunks_empty_docs(self):
"""Test traditional chunking with empty documents."""
@@ -158,11 +160,22 @@ class Calculator:
# Should have multiple chunks due to different functions/classes
assert len(chunks) > 0
assert all(isinstance(chunk, str) for chunk in chunks)
assert all(len(chunk.strip()) > 0 for chunk in chunks)
# R3: Expect dict format with "text" and "metadata" keys
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
"Each chunk should have 'text' and 'metadata' keys"
)
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks), (
"Each chunk text should be non-empty"
)
# Check metadata is present
assert all("file_path" in chunk["metadata"] for chunk in chunks), (
"Each chunk should have file_path metadata"
)
# Check that code structure is somewhat preserved
combined_content = " ".join(chunks)
combined_content = " ".join([c["text"] for c in chunks])
assert "def hello_world" in combined_content
assert "class Calculator" in combined_content
@@ -194,7 +207,11 @@ class Calculator:
chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10)
assert len(chunks) > 0
assert all(isinstance(chunk, str) for chunk in chunks)
# R3: Traditional chunking should also return dict format for consistency
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
"Each chunk should have 'text' and 'metadata' keys"
)
def test_create_text_chunks_ast_mode(self):
"""Test text chunking in AST mode."""
@@ -213,7 +230,11 @@ class Calculator:
)
assert len(chunks) > 0
assert all(isinstance(chunk, str) for chunk in chunks)
# R3: AST mode should also return dict format
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
"Each chunk should have 'text' and 'metadata' keys"
)
def test_create_text_chunks_custom_extensions(self):
"""Test text chunking with custom code file extensions."""
@@ -353,6 +374,552 @@ class MathUtils:
pytest.skip("Test timed out - likely due to model download in CI")
class TestASTContentExtraction:
"""Test AST content extraction bug fix.
These tests verify that astchunk's dict format with 'content' key is handled correctly,
and that the extraction logic doesn't fall through to stringifying entire dicts.
"""
def test_extract_content_from_astchunk_dict(self):
"""Test that astchunk dict format with 'content' key is handled correctly.
Bug: Current code checks for chunk["text"] but astchunk returns chunk["content"].
This causes fallthrough to str(chunk), stringifying the entire dict.
This test will FAIL until the bug is fixed because:
- Current code will stringify the dict: "{'content': '...', 'metadata': {...}}"
- Fixed code should extract just the content value
"""
# Mock the ASTChunkBuilder class
mock_builder = Mock()
# Astchunk returns this format
astchunk_format_chunk = {
"content": "def hello():\n print('world')",
"metadata": {
"filepath": "test.py",
"line_count": 2,
"start_line_no": 0,
"end_line_no": 1,
"node_count": 1,
},
}
mock_builder.chunkify.return_value = [astchunk_format_chunk]
# Create mock document
doc = MockDocument(
"def hello():\n print('world')", "/test/test.py", {"language": "python"}
)
# Mock the astchunk module and its ASTChunkBuilder class
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
# Patch sys.modules to inject our mock before the import
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should return dict format with proper metadata
assert len(chunks) > 0, "Should return at least one chunk"
# R3: Each chunk should be a dict
chunk = chunks[0]
assert isinstance(chunk, dict), "Chunk should be a dict"
assert "text" in chunk, "Chunk should have 'text' key"
assert "metadata" in chunk, "Chunk should have 'metadata' key"
chunk_text = chunk["text"]
# CRITICAL: Should NOT contain stringified dict markers in the text field
# These assertions will FAIL with current buggy code
assert "'content':" not in chunk_text, (
f"Chunk text contains stringified dict - extraction failed! Got: {chunk_text[:100]}..."
)
assert "'metadata':" not in chunk_text, (
"Chunk text contains stringified metadata - extraction failed! "
f"Got: {chunk_text[:100]}..."
)
assert "{" not in chunk_text or "def hello" in chunk_text.split("{")[0], (
"Chunk text appears to be a stringified dict"
)
# Should contain actual content
assert "def hello()" in chunk_text, "Should extract actual code content"
assert "print('world')" in chunk_text, "Should extract complete code content"
# R3: Should preserve astchunk metadata
assert "filepath" in chunk["metadata"] or "file_path" in chunk["metadata"], (
"Should preserve file path metadata"
)
def test_extract_text_key_fallback(self):
"""Test that 'text' key still works for backward compatibility.
Some chunks might use 'text' instead of 'content' - ensure backward compatibility.
This test should PASS even with current code.
"""
mock_builder = Mock()
# Some chunks might use "text" key
text_key_chunk = {"text": "def legacy_function():\n return True"}
mock_builder.chunkify.return_value = [text_key_chunk]
# Create mock document
doc = MockDocument(
"def legacy_function():\n return True", "/test/legacy.py", {"language": "python"}
)
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should extract text correctly as dict format
assert len(chunks) > 0
chunk = chunks[0]
assert isinstance(chunk, dict), "Chunk should be a dict"
assert "text" in chunk, "Chunk should have 'text' key"
chunk_text = chunk["text"]
# Should NOT be stringified
assert "'text':" not in chunk_text, "Should not stringify dict with 'text' key"
# Should contain actual content
assert "def legacy_function()" in chunk_text
assert "return True" in chunk_text
def test_handles_string_chunks(self):
"""Test that plain string chunks still work.
Some chunkers might return plain strings - verify these are preserved.
This test should PASS with current code.
"""
mock_builder = Mock()
# Plain string chunk
plain_string_chunk = "def simple_function():\n pass"
mock_builder.chunkify.return_value = [plain_string_chunk]
# Create mock document
doc = MockDocument(
"def simple_function():\n pass", "/test/simple.py", {"language": "python"}
)
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should wrap string in dict format
assert len(chunks) > 0
chunk = chunks[0]
assert isinstance(chunk, dict), "Even string chunks should be wrapped in dict"
assert "text" in chunk, "Chunk should have 'text' key"
chunk_text = chunk["text"]
assert chunk_text == plain_string_chunk.strip(), (
"Should preserve plain string chunk content"
)
assert "def simple_function()" in chunk_text
assert "pass" in chunk_text
def test_multiple_chunks_with_mixed_formats(self):
"""Test handling of multiple chunks with different formats.
Real-world scenario: astchunk might return a mix of formats.
This test will FAIL if any chunk with 'content' key gets stringified.
"""
mock_builder = Mock()
# Mix of formats
mixed_chunks = [
{"content": "def first():\n return 1", "metadata": {"line_count": 2}},
"def second():\n return 2", # Plain string
{"text": "def third():\n return 3"}, # Old format
{"content": "class MyClass:\n pass", "metadata": {"node_count": 1}},
]
mock_builder.chunkify.return_value = mixed_chunks
# Create mock document
code = "def first():\n return 1\n\ndef second():\n return 2\n\ndef third():\n return 3\n\nclass MyClass:\n pass"
doc = MockDocument(code, "/test/mixed.py", {"language": "python"})
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should extract all chunks correctly as dicts
assert len(chunks) == 4, "Should extract all 4 chunks"
# Check each chunk
for i, chunk in enumerate(chunks):
assert isinstance(chunk, dict), f"Chunk {i} should be a dict"
assert "text" in chunk, f"Chunk {i} should have 'text' key"
assert "metadata" in chunk, f"Chunk {i} should have 'metadata' key"
chunk_text = chunk["text"]
# None should be stringified dicts
assert "'content':" not in chunk_text, f"Chunk {i} text is stringified (has 'content':)"
assert "'metadata':" not in chunk_text, (
f"Chunk {i} text is stringified (has 'metadata':)"
)
assert "'text':" not in chunk_text, f"Chunk {i} text is stringified (has 'text':)"
# Verify actual content is present
combined = "\n".join([c["text"] for c in chunks])
assert "def first()" in combined
assert "def second()" in combined
assert "def third()" in combined
assert "class MyClass:" in combined
def test_empty_content_value_handling(self):
"""Test handling of chunks with empty content values.
Edge case: chunk has 'content' key but value is empty.
Should skip these chunks, not stringify them.
"""
mock_builder = Mock()
chunks_with_empty = [
{"content": "", "metadata": {"line_count": 0}}, # Empty content
{"content": " ", "metadata": {"line_count": 1}}, # Whitespace only
{"content": "def valid():\n return True", "metadata": {"line_count": 2}}, # Valid
]
mock_builder.chunkify.return_value = chunks_with_empty
doc = MockDocument(
"def valid():\n return True", "/test/empty.py", {"language": "python"}
)
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
chunks = create_ast_chunks([doc])
# R3: Should only have the valid chunk (empty ones filtered out)
assert len(chunks) == 1, "Should filter out empty content chunks"
chunk = chunks[0]
assert isinstance(chunk, dict), "Chunk should be a dict"
assert "text" in chunk, "Chunk should have 'text' key"
assert "def valid()" in chunk["text"]
# Should not have stringified the empty dict
assert "'content': ''" not in chunk["text"]
class TestASTMetadataPreservation:
"""Test metadata preservation in AST chunk dictionaries.
R3: These tests define the contract for metadata preservation when returning
chunk dictionaries instead of plain strings. Each chunk dict should have:
- "text": str - the actual chunk content
- "metadata": dict - all metadata from document AND astchunk
These tests will FAIL until G3 implementation changes return type to list[dict].
"""
def test_ast_chunks_preserve_file_metadata(self):
"""Test that document metadata is preserved in chunk metadata.
This test verifies that all document-level metadata (file_path, file_name,
creation_date, last_modified_date) is included in each chunk's metadata dict.
This will FAIL because current code returns list[str], not list[dict].
"""
# Create mock document with rich metadata
python_code = '''
def calculate_sum(numbers):
"""Calculate sum of numbers."""
return sum(numbers)
class DataProcessor:
"""Process data records."""
def process(self, data):
return [x * 2 for x in data]
'''
doc = MockDocument(
python_code,
file_path="/project/src/utils.py",
metadata={
"language": "python",
"file_path": "/project/src/utils.py",
"file_name": "utils.py",
"creation_date": "2024-01-15T10:30:00",
"last_modified_date": "2024-10-31T15:45:00",
},
)
# Mock astchunk to return chunks with metadata
mock_builder = Mock()
astchunk_chunks = [
{
"content": "def calculate_sum(numbers):\n return sum(numbers)",
"metadata": {
"filepath": "/project/src/utils.py",
"line_count": 2,
"start_line_no": 1,
"end_line_no": 2,
"node_count": 1,
},
},
{
"content": "class DataProcessor:\n def process(self, data):\n return [x * 2 for x in data]",
"metadata": {
"filepath": "/project/src/utils.py",
"line_count": 3,
"start_line_no": 5,
"end_line_no": 7,
"node_count": 2,
},
},
]
mock_builder.chunkify.return_value = astchunk_chunks
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
chunks = create_ast_chunks([doc])
# CRITICAL: These assertions will FAIL with current list[str] return type
assert len(chunks) == 2, "Should return 2 chunks"
for i, chunk in enumerate(chunks):
# Structure assertions - WILL FAIL: current code returns strings
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
assert "text" in chunk, f"Chunk {i} must have 'text' key"
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
assert isinstance(chunk["metadata"], dict), f"Chunk {i} metadata should be dict"
# Document metadata preservation - WILL FAIL
metadata = chunk["metadata"]
assert "file_path" in metadata, f"Chunk {i} should preserve file_path"
assert metadata["file_path"] == "/project/src/utils.py", (
f"Chunk {i} file_path incorrect"
)
assert "file_name" in metadata, f"Chunk {i} should preserve file_name"
assert metadata["file_name"] == "utils.py", f"Chunk {i} file_name incorrect"
assert "creation_date" in metadata, f"Chunk {i} should preserve creation_date"
assert metadata["creation_date"] == "2024-01-15T10:30:00", (
f"Chunk {i} creation_date incorrect"
)
assert "last_modified_date" in metadata, f"Chunk {i} should preserve last_modified_date"
assert metadata["last_modified_date"] == "2024-10-31T15:45:00", (
f"Chunk {i} last_modified_date incorrect"
)
# Verify metadata is consistent across chunks from same document
assert chunks[0]["metadata"]["file_path"] == chunks[1]["metadata"]["file_path"], (
"All chunks from same document should have same file_path"
)
# Verify text content is present and not stringified
assert "def calculate_sum" in chunks[0]["text"]
assert "class DataProcessor" in chunks[1]["text"]
def test_ast_chunks_include_astchunk_metadata(self):
"""Test that astchunk-specific metadata is merged into chunk metadata.
This test verifies that astchunk's metadata (line_count, start_line_no,
end_line_no, node_count) is merged with document metadata.
This will FAIL because current code returns list[str], not list[dict].
"""
python_code = '''
def function_one():
"""First function."""
x = 1
y = 2
return x + y
def function_two():
"""Second function."""
return 42
'''
doc = MockDocument(
python_code,
file_path="/test/code.py",
metadata={
"language": "python",
"file_path": "/test/code.py",
"file_name": "code.py",
},
)
# Mock astchunk with detailed metadata
mock_builder = Mock()
astchunk_chunks = [
{
"content": "def function_one():\n x = 1\n y = 2\n return x + y",
"metadata": {
"filepath": "/test/code.py",
"line_count": 4,
"start_line_no": 1,
"end_line_no": 4,
"node_count": 5, # function, assignments, return
},
},
{
"content": "def function_two():\n return 42",
"metadata": {
"filepath": "/test/code.py",
"line_count": 2,
"start_line_no": 7,
"end_line_no": 8,
"node_count": 2, # function, return
},
},
]
mock_builder.chunkify.return_value = astchunk_chunks
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
chunks = create_ast_chunks([doc])
# CRITICAL: These will FAIL with current list[str] return
assert len(chunks) == 2
# First chunk - function_one
chunk1 = chunks[0]
assert isinstance(chunk1, dict), "Chunk should be dict"
assert "metadata" in chunk1
metadata1 = chunk1["metadata"]
# Check astchunk metadata is present
assert "line_count" in metadata1, "Should include astchunk line_count"
assert metadata1["line_count"] == 4, "line_count should be 4"
assert "start_line_no" in metadata1, "Should include astchunk start_line_no"
assert metadata1["start_line_no"] == 1, "start_line_no should be 1"
assert "end_line_no" in metadata1, "Should include astchunk end_line_no"
assert metadata1["end_line_no"] == 4, "end_line_no should be 4"
assert "node_count" in metadata1, "Should include astchunk node_count"
assert metadata1["node_count"] == 5, "node_count should be 5"
# Second chunk - function_two
chunk2 = chunks[1]
metadata2 = chunk2["metadata"]
assert metadata2["line_count"] == 2, "line_count should be 2"
assert metadata2["start_line_no"] == 7, "start_line_no should be 7"
assert metadata2["end_line_no"] == 8, "end_line_no should be 8"
assert metadata2["node_count"] == 2, "node_count should be 2"
# Verify document metadata is ALSO present (merged, not replaced)
assert metadata1["file_path"] == "/test/code.py"
assert metadata1["file_name"] == "code.py"
assert metadata2["file_path"] == "/test/code.py"
assert metadata2["file_name"] == "code.py"
# Verify text content is correct
assert "def function_one" in chunk1["text"]
assert "def function_two" in chunk2["text"]
def test_traditional_chunks_as_dicts_helper(self):
"""Test the helper function that wraps traditional chunks as dicts.
This test verifies that when create_traditional_chunks is called,
its plain string chunks are wrapped into dict format with metadata.
This will FAIL because the helper function _traditional_chunks_as_dicts()
doesn't exist yet, and create_traditional_chunks returns list[str].
"""
# Create documents with various metadata
docs = [
MockDocument(
"This is the first paragraph of text. It contains multiple sentences. "
"This should be split into chunks based on size.",
file_path="/docs/readme.txt",
metadata={
"file_path": "/docs/readme.txt",
"file_name": "readme.txt",
"creation_date": "2024-01-01",
},
),
MockDocument(
"Second document with different metadata. It also has content that needs chunking.",
file_path="/docs/guide.md",
metadata={
"file_path": "/docs/guide.md",
"file_name": "guide.md",
"last_modified_date": "2024-10-31",
},
),
]
# Call create_traditional_chunks (which should now return list[dict])
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
# CRITICAL: Will FAIL - current code returns list[str]
assert len(chunks) > 0, "Should return chunks"
for i, chunk in enumerate(chunks):
# Structure assertions - WILL FAIL
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
assert "text" in chunk, f"Chunk {i} must have 'text' key"
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
# Text should be non-empty
assert len(chunk["text"].strip()) > 0, f"Chunk {i} text should be non-empty"
# Metadata should include document info
metadata = chunk["metadata"]
assert "file_path" in metadata, f"Chunk {i} should have file_path in metadata"
assert "file_name" in metadata, f"Chunk {i} should have file_name in metadata"
# Verify metadata tracking works correctly
# At least one chunk should be from readme.txt
readme_chunks = [c for c in chunks if "readme.txt" in c["metadata"]["file_name"]]
assert len(readme_chunks) > 0, "Should have chunks from readme.txt"
# At least one chunk should be from guide.md
guide_chunks = [c for c in chunks if "guide.md" in c["metadata"]["file_name"]]
assert len(guide_chunks) > 0, "Should have chunks from guide.md"
# Verify creation_date is preserved for readme chunks
for chunk in readme_chunks:
assert chunk["metadata"].get("creation_date") == "2024-01-01", (
"readme.txt chunks should preserve creation_date"
)
# Verify last_modified_date is preserved for guide chunks
for chunk in guide_chunks:
assert chunk["metadata"].get("last_modified_date") == "2024-10-31", (
"guide.md chunks should preserve last_modified_date"
)
# Verify text content is present
all_text = " ".join([c["text"] for c in chunks])
assert "first paragraph" in all_text
assert "Second document" in all_text
class TestErrorHandling:
"""Test error handling and edge cases."""

View File

@@ -0,0 +1,533 @@
"""
Tests for CLI argument integration of --embedding-prompt-template.
These tests verify that:
1. The --embedding-prompt-template flag is properly registered on build and search commands
2. The template value flows from CLI args to embedding_options dict
3. The template is passed through to compute_embeddings() function
4. Default behavior (no flag) is handled correctly
"""
from unittest.mock import Mock, patch
from leann.cli import LeannCLI
class TestCLIPromptTemplateArgument:
"""Tests for --embedding-prompt-template on build and search commands."""
def test_commands_accept_prompt_template_argument(self):
"""Verify that build and search parsers accept --embedding-prompt-template flag."""
cli = LeannCLI()
parser = cli.create_parser()
template_value = "search_query: "
# Test build command
build_args = parser.parse_args(
[
"build",
"test-index",
"--docs",
"/tmp/test-docs",
"--embedding-prompt-template",
template_value,
]
)
assert build_args.command == "build"
assert hasattr(build_args, "embedding_prompt_template"), (
"build command should have embedding_prompt_template attribute"
)
assert build_args.embedding_prompt_template == template_value
# Test search command
search_args = parser.parse_args(
["search", "test-index", "my query", "--embedding-prompt-template", template_value]
)
assert search_args.command == "search"
assert hasattr(search_args, "embedding_prompt_template"), (
"search command should have embedding_prompt_template attribute"
)
assert search_args.embedding_prompt_template == template_value
def test_commands_default_to_none(self):
"""Verify default value is None when flag not provided (backward compatibility)."""
cli = LeannCLI()
parser = cli.create_parser()
# Test build command default
build_args = parser.parse_args(["build", "test-index", "--docs", "/tmp/test-docs"])
assert hasattr(build_args, "embedding_prompt_template"), (
"build command should have embedding_prompt_template attribute"
)
assert build_args.embedding_prompt_template is None, (
"Build default value should be None when flag not provided"
)
# Test search command default
search_args = parser.parse_args(["search", "test-index", "my query"])
assert hasattr(search_args, "embedding_prompt_template"), (
"search command should have embedding_prompt_template attribute"
)
assert search_args.embedding_prompt_template is None, (
"Search default value should be None when flag not provided"
)
class TestBuildCommandPromptTemplateArgumentExtras:
"""Additional build-specific tests for prompt template argument."""
def test_build_command_prompt_template_with_multiword_value(self):
"""
Verify that template values with spaces are handled correctly.
Templates like "search_document: " or "Represent this sentence for searching: "
should be accepted as a single string argument.
"""
cli = LeannCLI()
parser = cli.create_parser()
template = "Represent this sentence for searching: "
args = parser.parse_args(
[
"build",
"test-index",
"--docs",
"/tmp/test-docs",
"--embedding-prompt-template",
template,
]
)
assert args.embedding_prompt_template == template
class TestPromptTemplateStoredInEmbeddingOptions:
"""Tests for template storage in embedding_options dict."""
@patch("leann.cli.LeannBuilder")
def test_prompt_template_stored_in_embedding_options_on_build(
self, mock_builder_class, tmp_path
):
"""
Verify that when --embedding-prompt-template is provided to build command,
the value is stored in embedding_options dict passed to LeannBuilder.
This test will fail because the CLI doesn't currently process this argument
and add it to embedding_options.
"""
# Setup mocks
mock_builder = Mock()
mock_builder_class.return_value = mock_builder
# Create CLI and run build command
cli = LeannCLI()
# Mock load_documents to return a document so builder is created
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
parser = cli.create_parser()
template = "search_query: "
args = parser.parse_args(
[
"build",
"test-index",
"--docs",
str(tmp_path),
"--embedding-prompt-template",
template,
"--force", # Force rebuild to ensure LeannBuilder is called
]
)
# Run the build command
import asyncio
asyncio.run(cli.build_index(args))
# Check that LeannBuilder was called with embedding_options containing prompt_template
call_kwargs = mock_builder_class.call_args.kwargs
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
embedding_options = call_kwargs["embedding_options"]
assert embedding_options is not None, (
"embedding_options should not be None when template provided"
)
assert "prompt_template" in embedding_options, (
"embedding_options should contain 'prompt_template' key"
)
assert embedding_options["prompt_template"] == template, (
f"Template should be '{template}', got {embedding_options.get('prompt_template')}"
)
@patch("leann.cli.LeannBuilder")
def test_prompt_template_not_in_options_when_not_provided(self, mock_builder_class, tmp_path):
"""
Verify that when --embedding-prompt-template is NOT provided,
embedding_options either doesn't have the key or it's None.
This ensures we don't pass empty/None values unnecessarily.
"""
# Setup mocks
mock_builder = Mock()
mock_builder_class.return_value = mock_builder
cli = LeannCLI()
# Mock load_documents to return a document so builder is created
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
parser = cli.create_parser()
args = parser.parse_args(
[
"build",
"test-index",
"--docs",
str(tmp_path),
"--force", # Force rebuild to ensure LeannBuilder is called
]
)
import asyncio
asyncio.run(cli.build_index(args))
# Check that if embedding_options is passed, it doesn't have prompt_template
call_kwargs = mock_builder_class.call_args.kwargs
if call_kwargs.get("embedding_options"):
embedding_options = call_kwargs["embedding_options"]
# Either the key shouldn't exist, or it should be None
assert (
"prompt_template" not in embedding_options
or embedding_options["prompt_template"] is None
), "prompt_template should not be set when flag not provided"
# R1 Tests: Build-time separate template storage
@patch("leann.cli.LeannBuilder")
def test_build_stores_separate_templates(self, mock_builder_class, tmp_path):
"""
R1 Test 1: Verify that when both --embedding-prompt-template and
--query-prompt-template are provided to build command, both values
are stored separately in embedding_options dict as build_prompt_template
and query_prompt_template.
This test will fail because:
1. CLI doesn't accept --query-prompt-template flag yet
2. CLI doesn't store templates as separate build_prompt_template and
query_prompt_template keys
Expected behavior after implementation:
- .meta.json contains: {"embedding_options": {
"build_prompt_template": "doc: ",
"query_prompt_template": "query: "
}}
"""
# Setup mocks
mock_builder = Mock()
mock_builder_class.return_value = mock_builder
cli = LeannCLI()
# Mock load_documents to return a document so builder is created
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
parser = cli.create_parser()
build_template = "doc: "
query_template = "query: "
args = parser.parse_args(
[
"build",
"test-index",
"--docs",
str(tmp_path),
"--embedding-prompt-template",
build_template,
"--query-prompt-template",
query_template,
"--force",
]
)
# Run the build command
import asyncio
asyncio.run(cli.build_index(args))
# Check that LeannBuilder was called with separate template keys
call_kwargs = mock_builder_class.call_args.kwargs
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
embedding_options = call_kwargs["embedding_options"]
assert embedding_options is not None, (
"embedding_options should not be None when templates provided"
)
assert "build_prompt_template" in embedding_options, (
"embedding_options should contain 'build_prompt_template' key"
)
assert embedding_options["build_prompt_template"] == build_template, (
f"build_prompt_template should be '{build_template}'"
)
assert "query_prompt_template" in embedding_options, (
"embedding_options should contain 'query_prompt_template' key"
)
assert embedding_options["query_prompt_template"] == query_template, (
f"query_prompt_template should be '{query_template}'"
)
# Old key should NOT be present when using new separate template format
assert "prompt_template" not in embedding_options, (
"Old 'prompt_template' key should not be present with separate templates"
)
@patch("leann.cli.LeannBuilder")
def test_build_backward_compat_single_template(self, mock_builder_class, tmp_path):
"""
R1 Test 2: Verify backward compatibility - when only
--embedding-prompt-template is provided (old behavior), it should
still be stored as 'prompt_template' in embedding_options.
This ensures existing workflows continue to work unchanged.
This test currently passes because it matches existing behavior, but it
documents the requirement that this behavior must be preserved after
implementing the separate template feature.
Expected behavior:
- .meta.json contains: {"embedding_options": {"prompt_template": "prompt: "}}
- No build_prompt_template or query_prompt_template keys
"""
# Setup mocks
mock_builder = Mock()
mock_builder_class.return_value = mock_builder
cli = LeannCLI()
# Mock load_documents to return a document so builder is created
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
parser = cli.create_parser()
template = "prompt: "
args = parser.parse_args(
[
"build",
"test-index",
"--docs",
str(tmp_path),
"--embedding-prompt-template",
template,
"--force",
]
)
# Run the build command
import asyncio
asyncio.run(cli.build_index(args))
# Check that LeannBuilder was called with old format
call_kwargs = mock_builder_class.call_args.kwargs
assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options"
embedding_options = call_kwargs["embedding_options"]
assert embedding_options is not None, (
"embedding_options should not be None when template provided"
)
assert "prompt_template" in embedding_options, (
"embedding_options should contain old 'prompt_template' key for backward compat"
)
assert embedding_options["prompt_template"] == template, (
f"prompt_template should be '{template}'"
)
# New keys should NOT be present in backward compat mode
assert "build_prompt_template" not in embedding_options, (
"build_prompt_template should not be present with single template flag"
)
assert "query_prompt_template" not in embedding_options, (
"query_prompt_template should not be present with single template flag"
)
@patch("leann.cli.LeannBuilder")
def test_build_no_templates(self, mock_builder_class, tmp_path):
"""
R1 Test 3: Verify that when no template flags are provided,
embedding_options has no prompt template keys.
This ensures clean defaults and no unnecessary keys in .meta.json.
This test currently passes because it matches existing behavior, but it
documents the requirement that this behavior must be preserved after
implementing the separate template feature.
Expected behavior:
- .meta.json has no prompt_template, build_prompt_template, or
query_prompt_template keys (or embedding_options is empty/None)
"""
# Setup mocks
mock_builder = Mock()
mock_builder_class.return_value = mock_builder
cli = LeannCLI()
# Mock load_documents to return a document so builder is created
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
parser = cli.create_parser()
args = parser.parse_args(["build", "test-index", "--docs", str(tmp_path), "--force"])
# Run the build command
import asyncio
asyncio.run(cli.build_index(args))
# Check that no template keys are present
call_kwargs = mock_builder_class.call_args.kwargs
if call_kwargs.get("embedding_options"):
embedding_options = call_kwargs["embedding_options"]
# None of the template keys should be present
assert "prompt_template" not in embedding_options, (
"prompt_template should not be present when no flags provided"
)
assert "build_prompt_template" not in embedding_options, (
"build_prompt_template should not be present when no flags provided"
)
assert "query_prompt_template" not in embedding_options, (
"query_prompt_template should not be present when no flags provided"
)
class TestPromptTemplateFlowsToComputeEmbeddings:
"""Tests for template flowing through to compute_embeddings function."""
@patch("leann.api.compute_embeddings")
def test_prompt_template_flows_to_compute_embeddings_via_provider_options(
self, mock_compute_embeddings, tmp_path
):
"""
Verify that the prompt template flows from CLI args through LeannBuilder
to compute_embeddings() function via provider_options parameter.
This is an integration test that verifies the complete flow:
CLI → embedding_options → LeannBuilder → compute_embeddings(provider_options)
This test will fail because:
1. CLI doesn't capture the argument yet
2. embedding_options doesn't include prompt_template
3. LeannBuilder doesn't pass it through to compute_embeddings
"""
# Mock compute_embeddings to return dummy embeddings as numpy array
import numpy as np
mock_compute_embeddings.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
# Use real LeannBuilder (not mocked) to test the actual flow
cli = LeannCLI()
# Mock load_documents to return a simple document
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
parser = cli.create_parser()
template = "search_document: "
args = parser.parse_args(
[
"build",
"test-index",
"--docs",
str(tmp_path),
"--embedding-prompt-template",
template,
"--backend-name",
"hnsw", # Use hnsw backend
"--force", # Force rebuild to ensure index is created
]
)
# This should fail because the flow isn't implemented yet
import asyncio
asyncio.run(cli.build_index(args))
# Verify compute_embeddings was called with provider_options containing prompt_template
assert mock_compute_embeddings.called, "compute_embeddings should have been called"
# Check the call arguments
call_kwargs = mock_compute_embeddings.call_args.kwargs
assert "provider_options" in call_kwargs, (
"compute_embeddings should receive provider_options parameter"
)
provider_options = call_kwargs["provider_options"]
assert provider_options is not None, "provider_options should not be None"
assert "prompt_template" in provider_options, (
"provider_options should contain prompt_template key"
)
assert provider_options["prompt_template"] == template, (
f"Template should be '{template}', got {provider_options.get('prompt_template')}"
)
class TestPromptTemplateArgumentHelp:
"""Tests for argument help text and documentation."""
def test_build_command_prompt_template_has_help_text(self):
"""
Verify that --embedding-prompt-template has descriptive help text.
Good help text is crucial for CLI usability.
"""
cli = LeannCLI()
parser = cli.create_parser()
# Get the build subparser
# This is a bit tricky - we need to parse to get the help
# We'll check that the help includes relevant keywords
import io
from contextlib import redirect_stdout
f = io.StringIO()
try:
with redirect_stdout(f):
parser.parse_args(["build", "--help"])
except SystemExit:
pass # --help causes sys.exit(0)
help_text = f.getvalue()
assert "--embedding-prompt-template" in help_text, (
"Help text should mention --embedding-prompt-template"
)
# Check for keywords that should be in the help
help_lower = help_text.lower()
assert any(keyword in help_lower for keyword in ["template", "prompt", "prepend"]), (
"Help text should explain what the prompt template does"
)
def test_search_command_prompt_template_has_help_text(self):
"""
Verify that search command also has help text for --embedding-prompt-template.
"""
cli = LeannCLI()
parser = cli.create_parser()
import io
from contextlib import redirect_stdout
f = io.StringIO()
try:
with redirect_stdout(f):
parser.parse_args(["search", "--help"])
except SystemExit:
pass # --help causes sys.exit(0)
help_text = f.getvalue()
assert "--embedding-prompt-template" in help_text, (
"Search help text should mention --embedding-prompt-template"
)

View File

@@ -0,0 +1,281 @@
"""Unit tests for prompt template prepending in OpenAI embeddings.
This test suite defines the contract for prompt template functionality that allows
users to prepend a consistent prompt to all embedding inputs. These tests verify:
1. Template prepending to all input texts before embedding computation
2. Graceful handling of None/missing provider_options
3. Empty string template behavior (no-op)
4. Logging of template application for observability
5. Template application before token truncation
All tests are written in Red Phase - they should FAIL initially because the
implementation does not exist yet.
"""
from unittest.mock import MagicMock, Mock, patch
import numpy as np
import pytest
from leann.embedding_compute import compute_embeddings_openai
class TestPromptTemplatePrepending:
"""Tests for prompt template prepending in compute_embeddings_openai."""
@pytest.fixture
def mock_openai_client(self):
"""Create mock OpenAI client that captures input texts."""
mock_client = MagicMock()
# Mock the embeddings.create response
mock_response = Mock()
mock_response.data = [
Mock(embedding=[0.1, 0.2, 0.3]),
Mock(embedding=[0.4, 0.5, 0.6]),
]
mock_client.embeddings.create.return_value = mock_response
return mock_client
@pytest.fixture
def mock_openai_module(self, mock_openai_client, monkeypatch):
"""Mock the openai module to return our mock client."""
# Mock the API key environment variable
monkeypatch.setenv("OPENAI_API_KEY", "fake-test-key-for-mocking")
# openai is imported inside the function, so we need to patch it there
with patch("openai.OpenAI", return_value=mock_openai_client) as mock_openai:
yield mock_openai
def test_prompt_template_prepended_to_all_texts(self, mock_openai_module, mock_openai_client):
"""Verify template is prepended to all input texts.
When provider_options contains "prompt_template", that template should
be prepended to every text in the input list before sending to OpenAI API.
This is the core functionality: the template acts as a consistent prefix
that provides context or instruction for the embedding model.
"""
texts = ["First document", "Second document"]
template = "search_document: "
provider_options = {"prompt_template": template}
# Call compute_embeddings_openai with provider_options
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
# Verify embeddings.create was called with templated texts
mock_openai_client.embeddings.create.assert_called_once()
call_args = mock_openai_client.embeddings.create.call_args
# Extract the input texts sent to API
sent_texts = call_args.kwargs["input"]
# Verify template was prepended to all texts
assert len(sent_texts) == 2, "Should send same number of texts"
assert sent_texts[0] == "search_document: First document", (
"Template should be prepended to first text"
)
assert sent_texts[1] == "search_document: Second document", (
"Template should be prepended to second text"
)
# Verify result is valid embeddings array
assert isinstance(result, np.ndarray)
assert result.shape == (2, 3), "Should return correct shape"
def test_template_not_applied_when_missing_or_empty(
self, mock_openai_module, mock_openai_client
):
"""Verify template not applied when provider_options is None, missing key, or empty string.
This consolidated test covers three scenarios where templates should NOT be applied:
1. provider_options is None (default behavior)
2. provider_options exists but missing 'prompt_template' key
3. prompt_template is explicitly set to empty string ""
In all cases, texts should be sent to the API unchanged.
"""
# Scenario 1: None provider_options
texts = ["Original text one", "Original text two"]
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=None,
)
call_args = mock_openai_client.embeddings.create.call_args
sent_texts = call_args.kwargs["input"]
assert sent_texts[0] == "Original text one", (
"Text should be unchanged with None provider_options"
)
assert sent_texts[1] == "Original text two"
assert isinstance(result, np.ndarray)
assert result.shape == (2, 3)
# Reset mock for next scenario
mock_openai_client.reset_mock()
mock_response = Mock()
mock_response.data = [
Mock(embedding=[0.1, 0.2, 0.3]),
Mock(embedding=[0.4, 0.5, 0.6]),
]
mock_openai_client.embeddings.create.return_value = mock_response
# Scenario 2: Missing 'prompt_template' key
texts = ["Text without template", "Another text"]
provider_options = {"base_url": "https://api.openai.com/v1"}
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
call_args = mock_openai_client.embeddings.create.call_args
sent_texts = call_args.kwargs["input"]
assert sent_texts[0] == "Text without template", "Text should be unchanged with missing key"
assert sent_texts[1] == "Another text"
assert isinstance(result, np.ndarray)
# Reset mock for next scenario
mock_openai_client.reset_mock()
mock_openai_client.embeddings.create.return_value = mock_response
# Scenario 3: Empty string template
texts = ["Text one", "Text two"]
provider_options = {"prompt_template": ""}
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
call_args = mock_openai_client.embeddings.create.call_args
sent_texts = call_args.kwargs["input"]
assert sent_texts[0] == "Text one", "Empty template should not modify text"
assert sent_texts[1] == "Text two"
assert isinstance(result, np.ndarray)
def test_prompt_template_with_multiple_batches(self, mock_openai_module, mock_openai_client):
"""Verify template is prepended in all batches when texts exceed batch size.
OpenAI API has batch size limits. When input texts are split into
multiple batches, the template should be prepended to texts in every batch.
This ensures consistency across all API calls.
"""
# Create many texts that will be split into multiple batches
texts = [f"Document {i}" for i in range(1000)]
template = "passage: "
provider_options = {"prompt_template": template}
# Mock multiple batch responses
mock_response = Mock()
mock_response.data = [Mock(embedding=[0.1, 0.2, 0.3]) for _ in range(1000)]
mock_openai_client.embeddings.create.return_value = mock_response
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
# Verify embeddings.create was called multiple times (batching)
assert mock_openai_client.embeddings.create.call_count >= 2, (
"Should make multiple API calls for large text list"
)
# Verify template was prepended in ALL batches
for call in mock_openai_client.embeddings.create.call_args_list:
sent_texts = call.kwargs["input"]
for text in sent_texts:
assert text.startswith(template), (
f"All texts in all batches should start with template. Got: {text}"
)
# Verify result shape
assert result.shape[0] == 1000, "Should return embeddings for all texts"
def test_prompt_template_with_special_characters(self, mock_openai_module, mock_openai_client):
"""Verify template with special characters is handled correctly.
Templates may contain special characters, Unicode, newlines, etc.
These should all be prepended correctly without encoding issues.
"""
texts = ["Document content"]
# Template with various special characters
template = "🔍 Search query [EN]: "
provider_options = {"prompt_template": template}
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
# Verify special characters in template were preserved
call_args = mock_openai_client.embeddings.create.call_args
sent_texts = call_args.kwargs["input"]
assert sent_texts[0] == "🔍 Search query [EN]: Document content", (
"Special characters in template should be preserved"
)
assert isinstance(result, np.ndarray)
def test_prompt_template_integration_with_existing_validation(
self, mock_openai_module, mock_openai_client
):
"""Verify template works with existing input validation.
compute_embeddings_openai has validation for empty texts and whitespace.
Template prepending should happen AFTER validation, so validation errors
are thrown based on original texts, not templated texts.
This ensures users get clear error messages about their input.
"""
# Empty text should still raise ValueError even with template
texts = [""]
provider_options = {"prompt_template": "prefix: "}
with pytest.raises(ValueError, match="empty/invalid"):
compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
def test_prompt_template_with_api_key_and_base_url(
self, mock_openai_module, mock_openai_client
):
"""Verify template works alongside other provider_options.
provider_options may contain multiple settings: prompt_template,
base_url, api_key. All should work together correctly.
"""
texts = ["Test document"]
provider_options = {
"prompt_template": "embed: ",
"base_url": "https://custom.api.com/v1",
"api_key": "test-key-123",
}
result = compute_embeddings_openai(
texts=texts,
model_name="text-embedding-3-small",
provider_options=provider_options,
)
# Verify template was applied
call_args = mock_openai_client.embeddings.create.call_args
sent_texts = call_args.kwargs["input"]
assert sent_texts[0] == "embed: Test document"
# Verify OpenAI client was created with correct base_url
mock_openai_module.assert_called()
client_init_kwargs = mock_openai_module.call_args.kwargs
assert client_init_kwargs["base_url"] == "https://custom.api.com/v1"
assert client_init_kwargs["api_key"] == "test-key-123"
assert isinstance(result, np.ndarray)

View File

@@ -0,0 +1,315 @@
"""Unit tests for LM Studio TypeScript SDK bridge functionality.
This test suite defines the contract for the LM Studio SDK bridge that queries
model context length via Node.js subprocess. These tests verify:
1. Successful SDK query returns context length
2. Graceful fallback when Node.js not installed (FileNotFoundError)
3. Graceful fallback when SDK not installed (npm error)
4. Timeout handling (subprocess.TimeoutExpired)
5. Invalid JSON response handling
All tests are written in Red Phase - they should FAIL initially because the
`_query_lmstudio_context_limit` function does not exist yet.
The function contract:
- Inputs: model_name (str), base_url (str, WebSocket format "ws://localhost:1234")
- Outputs: context_length (int) or None on error
- Requirements:
1. Call Node.js with inline JavaScript using @lmstudio/sdk
2. 10-second timeout (accounts for Node.js startup)
3. Graceful fallback on any error (returns None, doesn't raise)
4. Parse JSON response with contextLength field
5. Log errors at debug level (not warning/error)
"""
import subprocess
from unittest.mock import Mock
import pytest
# Try to import the function - if it doesn't exist, tests will fail as expected
try:
from leann.embedding_compute import _query_lmstudio_context_limit
except ImportError:
# Function doesn't exist yet (Red Phase) - create a placeholder that will fail
def _query_lmstudio_context_limit(*args, **kwargs):
raise NotImplementedError(
"_query_lmstudio_context_limit not implemented yet - this is the Red Phase"
)
class TestLMStudioBridge:
"""Tests for LM Studio TypeScript SDK bridge integration."""
def test_query_lmstudio_success(self, monkeypatch):
"""Verify successful SDK query returns context length.
When the Node.js subprocess successfully queries the LM Studio SDK,
it should return a JSON response with contextLength field. The function
should parse this and return the integer context length.
"""
def mock_run(*args, **kwargs):
# Verify timeout is set to 10 seconds
assert kwargs.get("timeout") == 10, "Should use 10-second timeout for Node.js startup"
# Verify capture_output and text=True are set
assert kwargs.get("capture_output") is True, "Should capture stdout/stderr"
assert kwargs.get("text") is True, "Should decode output as text"
# Return successful JSON response
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = '{"contextLength": 8192, "identifier": "custom-model"}'
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
# Test with typical LM Studio model
limit = _query_lmstudio_context_limit(
model_name="custom-model", base_url="ws://localhost:1234"
)
assert limit == 8192, "Should return context length from SDK response"
def test_query_lmstudio_nodejs_not_found(self, monkeypatch):
"""Verify graceful fallback when Node.js not installed.
When Node.js is not installed, subprocess.run will raise FileNotFoundError.
The function should catch this and return None (graceful fallback to registry).
"""
def mock_run(*args, **kwargs):
raise FileNotFoundError("node: command not found")
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="custom-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None when Node.js not installed"
def test_query_lmstudio_sdk_not_installed(self, monkeypatch):
"""Verify graceful fallback when @lmstudio/sdk not installed.
When the SDK npm package is not installed, Node.js will return non-zero
exit code with error message in stderr. The function should detect this
and return None.
"""
def mock_run(*args, **kwargs):
mock_result = Mock()
mock_result.returncode = 1
mock_result.stdout = ""
mock_result.stderr = (
"Error: Cannot find module '@lmstudio/sdk'\nRequire stack:\n- /path/to/script.js"
)
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="custom-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None when SDK not installed"
def test_query_lmstudio_timeout(self, monkeypatch):
"""Verify graceful fallback when subprocess times out.
When the Node.js process takes longer than 10 seconds (e.g., LM Studio
not responding), subprocess.TimeoutExpired should be raised. The function
should catch this and return None.
"""
def mock_run(*args, **kwargs):
raise subprocess.TimeoutExpired(cmd=["node", "lmstudio_bridge.js"], timeout=10)
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="custom-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None on timeout"
def test_query_lmstudio_invalid_json(self, monkeypatch):
"""Verify graceful fallback when response is invalid JSON.
When the subprocess returns malformed JSON (e.g., due to SDK error),
json.loads will raise ValueError/JSONDecodeError. The function should
catch this and return None.
"""
def mock_run(*args, **kwargs):
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = "This is not valid JSON"
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="custom-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None when JSON parsing fails"
def test_query_lmstudio_missing_context_length_field(self, monkeypatch):
"""Verify graceful fallback when JSON lacks contextLength field.
When the SDK returns valid JSON but without the expected contextLength
field (e.g., error response), the function should return None.
"""
def mock_run(*args, **kwargs):
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = '{"identifier": "test-model", "error": "Model not found"}'
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="nonexistent-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None when contextLength field missing"
def test_query_lmstudio_null_context_length(self, monkeypatch):
"""Verify graceful fallback when contextLength is null.
When the SDK returns contextLength: null (model couldn't be loaded),
the function should return None for registry fallback.
"""
def mock_run(*args, **kwargs):
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = '{"contextLength": null, "identifier": "test-model"}'
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="test-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None when contextLength is null"
def test_query_lmstudio_zero_context_length(self, monkeypatch):
"""Verify graceful fallback when contextLength is zero.
When the SDK returns contextLength: 0 (invalid value), the function
should return None to trigger registry fallback.
"""
def mock_run(*args, **kwargs):
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = '{"contextLength": 0, "identifier": "test-model"}'
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="test-model", base_url="ws://localhost:1234"
)
assert limit is None, "Should return None when contextLength is zero"
def test_query_lmstudio_with_custom_port(self, monkeypatch):
"""Verify SDK query works with non-default WebSocket port.
LM Studio can run on custom ports. The function should pass the
provided base_url to the Node.js subprocess.
"""
def mock_run(*args, **kwargs):
# Verify the base_url argument is passed correctly
command = args[0] if args else kwargs.get("args", [])
assert "ws://localhost:8080" in " ".join(command), (
"Should pass custom port to subprocess"
)
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = '{"contextLength": 4096, "identifier": "custom-model"}'
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="custom-model", base_url="ws://localhost:8080"
)
assert limit == 4096, "Should work with custom WebSocket port"
@pytest.mark.parametrize(
"context_length,expected",
[
(512, 512), # Small context
(2048, 2048), # Common context
(8192, 8192), # Large context
(32768, 32768), # Very large context
],
)
def test_query_lmstudio_various_context_lengths(self, monkeypatch, context_length, expected):
"""Verify SDK query handles various context length values.
Different models have different context lengths. The function should
correctly parse and return any positive integer value.
"""
def mock_run(*args, **kwargs):
mock_result = Mock()
mock_result.returncode = 0
mock_result.stdout = f'{{"contextLength": {context_length}, "identifier": "test"}}'
mock_result.stderr = ""
return mock_result
monkeypatch.setattr("subprocess.run", mock_run)
limit = _query_lmstudio_context_limit(
model_name="test-model", base_url="ws://localhost:1234"
)
assert limit == expected, f"Should return {expected} for context length {context_length}"
def test_query_lmstudio_logs_at_debug_level(self, monkeypatch, caplog):
"""Verify errors are logged at DEBUG level, not WARNING/ERROR.
Following the graceful fallback pattern from Ollama implementation,
errors should be logged at debug level to avoid alarming users when
fallback to registry works fine.
"""
import logging
caplog.set_level(logging.DEBUG, logger="leann.embedding_compute")
def mock_run(*args, **kwargs):
raise FileNotFoundError("node: command not found")
monkeypatch.setattr("subprocess.run", mock_run)
_query_lmstudio_context_limit(model_name="test-model", base_url="ws://localhost:1234")
# Check that debug logging occurred (not warning/error)
debug_logs = [record for record in caplog.records if record.levelname == "DEBUG"]
assert len(debug_logs) > 0, "Should log error at DEBUG level"
# Verify no WARNING or ERROR logs
warning_or_error_logs = [
record for record in caplog.records if record.levelname in ["WARNING", "ERROR"]
]
assert len(warning_or_error_logs) == 0, (
"Should not log at WARNING/ERROR level for expected failures"
)

View File

@@ -0,0 +1,400 @@
"""End-to-end integration tests for prompt template and token limit features.
These tests verify real-world functionality with live services:
- OpenAI-compatible APIs (OpenAI, LM Studio) with prompt template support
- Ollama with dynamic token limit detection
- Hybrid token limit discovery mechanism
Run with: pytest tests/test_prompt_template_e2e.py -v -s
Skip if services unavailable: pytest tests/test_prompt_template_e2e.py -m "not integration"
Prerequisites:
1. LM Studio running with embedding model: http://localhost:1234
2. [Optional] Ollama running: ollama serve
3. [Optional] Ollama model: ollama pull nomic-embed-text
4. [Optional] Node.js + @lmstudio/sdk for context length detection
"""
import logging
import socket
import numpy as np
import pytest
import requests
from leann.embedding_compute import (
compute_embeddings_ollama,
compute_embeddings_openai,
get_model_token_limit,
)
# Test markers for conditional execution
pytestmark = pytest.mark.integration
logger = logging.getLogger(__name__)
def check_service_available(host: str, port: int, timeout: float = 2.0) -> bool:
"""Check if a service is available on the given host:port."""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
result = sock.connect_ex((host, port))
sock.close()
return result == 0
except Exception:
return False
def check_ollama_available() -> bool:
"""Check if Ollama service is available."""
if not check_service_available("localhost", 11434):
return False
try:
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
return response.status_code == 200
except Exception:
return False
def check_lmstudio_available() -> bool:
"""Check if LM Studio service is available."""
if not check_service_available("localhost", 1234):
return False
try:
response = requests.get("http://localhost:1234/v1/models", timeout=2.0)
return response.status_code == 200
except Exception:
return False
def get_lmstudio_first_model() -> str:
"""Get the first available model from LM Studio."""
try:
response = requests.get("http://localhost:1234/v1/models", timeout=5.0)
data = response.json()
models = data.get("data", [])
if models:
return models[0]["id"]
except Exception:
pass
return None
class TestPromptTemplateOpenAI:
"""End-to-end tests for prompt template with OpenAI-compatible APIs (LM Studio)."""
@pytest.mark.skipif(
not check_lmstudio_available(), reason="LM Studio service not available on localhost:1234"
)
def test_lmstudio_embedding_with_prompt_template(self):
"""Test prompt templates with LM Studio using OpenAI-compatible API."""
model_name = get_lmstudio_first_model()
if not model_name:
pytest.skip("No models loaded in LM Studio")
texts = ["artificial intelligence", "machine learning"]
prompt_template = "search_query: "
# Get embeddings with prompt template via provider_options
provider_options = {"prompt_template": prompt_template}
embeddings = compute_embeddings_openai(
texts=texts,
model_name=model_name,
base_url="http://localhost:1234/v1",
api_key="lm-studio", # LM Studio doesn't require real key
provider_options=provider_options,
)
assert embeddings is not None
assert len(embeddings) == 2
assert all(isinstance(emb, np.ndarray) for emb in embeddings)
assert all(len(emb) > 0 for emb in embeddings)
logger.info(
f"✓ LM Studio embeddings with prompt template: {len(embeddings)} vectors, {len(embeddings[0])} dimensions"
)
@pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
def test_lmstudio_prompt_template_affects_embeddings(self):
"""Verify that prompt templates actually change embedding values."""
model_name = get_lmstudio_first_model()
if not model_name:
pytest.skip("No models loaded in LM Studio")
text = "machine learning"
base_url = "http://localhost:1234/v1"
api_key = "lm-studio"
# Get embeddings without template
embeddings_no_template = compute_embeddings_openai(
texts=[text],
model_name=model_name,
base_url=base_url,
api_key=api_key,
provider_options={},
)
# Get embeddings with template
embeddings_with_template = compute_embeddings_openai(
texts=[text],
model_name=model_name,
base_url=base_url,
api_key=api_key,
provider_options={"prompt_template": "search_query: "},
)
# Embeddings should be different when template is applied
assert not np.allclose(embeddings_no_template[0], embeddings_with_template[0])
logger.info("✓ Prompt template changes embedding values as expected")
class TestPromptTemplateOllama:
"""End-to-end tests for prompt template with Ollama."""
@pytest.mark.skipif(
not check_ollama_available(), reason="Ollama service not available on localhost:11434"
)
def test_ollama_embedding_with_prompt_template(self):
"""Test prompt templates with Ollama using any available embedding model."""
# Get any available embedding model
try:
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
models = response.json().get("models", [])
embedding_models = []
for model in models:
name = model["name"]
base_name = name.split(":")[0]
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
embedding_models.append(name)
if not embedding_models:
pytest.skip("No embedding models available in Ollama")
model_name = embedding_models[0]
texts = ["artificial intelligence", "machine learning"]
prompt_template = "search_query: "
# Get embeddings with prompt template via provider_options
provider_options = {"prompt_template": prompt_template}
embeddings = compute_embeddings_ollama(
texts=texts,
model_name=model_name,
is_build=False,
host="http://localhost:11434",
provider_options=provider_options,
)
assert embeddings is not None
assert len(embeddings) == 2
assert all(isinstance(emb, np.ndarray) for emb in embeddings)
assert all(len(emb) > 0 for emb in embeddings)
logger.info(
f"✓ Ollama embeddings with prompt template: {len(embeddings)} vectors, {len(embeddings[0])} dimensions"
)
except Exception as e:
pytest.skip(f"Could not test Ollama prompt template: {e}")
@pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
def test_ollama_prompt_template_affects_embeddings(self):
"""Verify that prompt templates actually change embedding values with Ollama."""
# Get any available embedding model
try:
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
models = response.json().get("models", [])
embedding_models = []
for model in models:
name = model["name"]
base_name = name.split(":")[0]
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
embedding_models.append(name)
if not embedding_models:
pytest.skip("No embedding models available in Ollama")
model_name = embedding_models[0]
text = "machine learning"
host = "http://localhost:11434"
# Get embeddings without template
embeddings_no_template = compute_embeddings_ollama(
texts=[text], model_name=model_name, is_build=False, host=host, provider_options={}
)
# Get embeddings with template
embeddings_with_template = compute_embeddings_ollama(
texts=[text],
model_name=model_name,
is_build=False,
host=host,
provider_options={"prompt_template": "search_query: "},
)
# Embeddings should be different when template is applied
assert not np.allclose(embeddings_no_template[0], embeddings_with_template[0])
logger.info("✓ Ollama prompt template changes embedding values as expected")
except Exception as e:
pytest.skip(f"Could not test Ollama prompt template: {e}")
class TestLMStudioSDK:
"""End-to-end tests for LM Studio SDK integration."""
@pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
def test_lmstudio_model_listing(self):
"""Test that we can list models from LM Studio."""
try:
response = requests.get("http://localhost:1234/v1/models", timeout=5.0)
assert response.status_code == 200
data = response.json()
assert "data" in data
models = data["data"]
logger.info(f"✓ LM Studio models available: {len(models)}")
if models:
logger.info(f" First model: {models[0].get('id', 'unknown')}")
except Exception as e:
pytest.skip(f"LM Studio API error: {e}")
@pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available")
def test_lmstudio_sdk_context_length_detection(self):
"""Test context length detection via LM Studio SDK bridge (requires Node.js + SDK)."""
model_name = get_lmstudio_first_model()
if not model_name:
pytest.skip("No models loaded in LM Studio")
try:
from leann.embedding_compute import _query_lmstudio_context_limit
# SDK requires WebSocket URL (ws://)
context_length = _query_lmstudio_context_limit(
model_name=model_name, base_url="ws://localhost:1234"
)
if context_length is None:
logger.warning(
"⚠ LM Studio SDK bridge returned None (Node.js or SDK may not be available)"
)
pytest.skip("Node.js or @lmstudio/sdk not available - SDK bridge unavailable")
else:
assert context_length > 0
logger.info(
f"✓ LM Studio context length detected via SDK: {context_length} for {model_name}"
)
except ImportError:
pytest.skip("_query_lmstudio_context_limit not implemented yet")
except Exception as e:
logger.error(f"LM Studio SDK test error: {e}")
raise
class TestOllamaTokenLimit:
"""End-to-end tests for Ollama token limit discovery."""
@pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
def test_ollama_token_limit_detection(self):
"""Test dynamic token limit detection from Ollama /api/show endpoint."""
# Get any available embedding model
try:
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
models = response.json().get("models", [])
embedding_models = []
for model in models:
name = model["name"]
base_name = name.split(":")[0]
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
embedding_models.append(name)
if not embedding_models:
pytest.skip("No embedding models available in Ollama")
test_model = embedding_models[0]
# Test token limit detection
limit = get_model_token_limit(model_name=test_model, base_url="http://localhost:11434")
assert limit > 0
logger.info(f"✓ Ollama token limit detected: {limit} for {test_model}")
except Exception as e:
pytest.skip(f"Could not test Ollama token detection: {e}")
class TestHybridTokenLimit:
"""End-to-end tests for hybrid token limit discovery mechanism."""
def test_hybrid_discovery_registry_fallback(self):
"""Test fallback to static registry for known OpenAI models."""
# Use a known OpenAI model (should be in registry)
limit = get_model_token_limit(
model_name="text-embedding-3-small",
base_url="http://fake-server:9999", # Fake URL to force registry lookup
)
# text-embedding-3-small should have 8192 in registry
assert limit == 8192
logger.info(f"✓ Hybrid discovery (registry fallback): {limit} tokens")
def test_hybrid_discovery_default_fallback(self):
"""Test fallback to safe default for completely unknown models."""
limit = get_model_token_limit(
model_name="completely-unknown-model-xyz-12345",
base_url="http://fake-server:9999",
default=512,
)
# Should get the specified default
assert limit == 512
logger.info(f"✓ Hybrid discovery (default fallback): {limit} tokens")
@pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available")
def test_hybrid_discovery_ollama_dynamic_first(self):
"""Test that Ollama models use dynamic discovery first."""
# Get any available embedding model
try:
response = requests.get("http://localhost:11434/api/tags", timeout=2.0)
models = response.json().get("models", [])
embedding_models = []
for model in models:
name = model["name"]
base_name = name.split(":")[0]
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]):
embedding_models.append(name)
if not embedding_models:
pytest.skip("No embedding models available in Ollama")
test_model = embedding_models[0]
# Should query Ollama /api/show dynamically
limit = get_model_token_limit(model_name=test_model, base_url="http://localhost:11434")
assert limit > 0
logger.info(f"✓ Hybrid discovery (Ollama dynamic): {limit} tokens for {test_model}")
except Exception as e:
pytest.skip(f"Could not test hybrid Ollama discovery: {e}")
if __name__ == "__main__":
print("\n" + "=" * 70)
print("INTEGRATION TEST SUITE - Real Service Testing")
print("=" * 70)
print("\nThese tests require live services:")
print(" • LM Studio: http://localhost:1234 (with embedding model loaded)")
print(" • [Optional] Ollama: http://localhost:11434")
print(" • [Optional] Node.js + @lmstudio/sdk for SDK bridge tests")
print("\nRun with: pytest tests/test_prompt_template_e2e.py -v -s")
print("=" * 70 + "\n")

View File

@@ -0,0 +1,808 @@
"""
Integration tests for prompt template metadata persistence and reuse.
These tests verify the complete lifecycle of prompt template persistence:
1. Template is saved to .meta.json during index build
2. Template is automatically loaded during search operations
3. Template can be overridden with explicit flag during search
4. Template is reused during chat/ask operations
These are integration tests that:
- Use real file system with temporary directories
- Run actual build and search operations
- Inspect .meta.json file contents directly
- Mock embedding servers to avoid external dependencies
- Use small test codebases for fast execution
Expected to FAIL in Red Phase because metadata persistence verification is not yet implemented.
"""
import json
import tempfile
from pathlib import Path
from unittest.mock import Mock, patch
import numpy as np
import pytest
from leann.api import LeannBuilder, LeannSearcher
class TestPromptTemplateMetadataPersistence:
"""Tests for prompt template storage in .meta.json during build."""
@pytest.fixture
def temp_index_dir(self):
"""Create temporary directory for test indexes."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
@pytest.fixture
def mock_embeddings(self):
"""Mock compute_embeddings to return dummy embeddings."""
with patch("leann.api.compute_embeddings") as mock_compute:
# Return dummy embeddings as numpy array
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
yield mock_compute
def test_prompt_template_saved_to_metadata(self, temp_index_dir, mock_embeddings):
"""
Verify that when build is run with embedding_options containing prompt_template,
the template value is saved to .meta.json file.
This is the core persistence requirement - templates must be saved to allow
reuse in subsequent search operations without re-specifying the flag.
Expected failure: .meta.json exists but doesn't contain embedding_options
with prompt_template, or the value is not persisted correctly.
"""
# Setup test data
index_path = temp_index_dir / "test_index.leann"
template = "search_document: "
# Build index with prompt template in embedding_options
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
embedding_options={"prompt_template": template},
)
# Add a simple document
builder.add_text("This is a test document for indexing")
# Build the index
builder.build_index(str(index_path))
# Verify .meta.json was created and contains the template
meta_path = temp_index_dir / "test_index.leann.meta.json"
assert meta_path.exists(), ".meta.json file should be created during build"
# Read and parse metadata
with open(meta_path, encoding="utf-8") as f:
meta_data = json.load(f)
# Verify embedding_options exists in metadata
assert "embedding_options" in meta_data, (
"embedding_options should be saved to .meta.json when provided"
)
# Verify prompt_template is in embedding_options
embedding_options = meta_data["embedding_options"]
assert "prompt_template" in embedding_options, (
"prompt_template should be saved within embedding_options"
)
# Verify the template value matches what we provided
assert embedding_options["prompt_template"] == template, (
f"Template should be '{template}', got '{embedding_options.get('prompt_template')}'"
)
def test_prompt_template_absent_when_not_provided(self, temp_index_dir, mock_embeddings):
"""
Verify that when no prompt template is provided during build,
.meta.json either doesn't have embedding_options or prompt_template key.
This ensures clean metadata without unnecessary keys when features aren't used.
Expected behavior: Build succeeds, .meta.json doesn't contain prompt_template.
"""
index_path = temp_index_dir / "test_no_template.leann"
# Build index WITHOUT prompt template
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
# No embedding_options provided
)
builder.add_text("Document without template")
builder.build_index(str(index_path))
# Verify metadata
meta_path = temp_index_dir / "test_no_template.leann.meta.json"
assert meta_path.exists()
with open(meta_path, encoding="utf-8") as f:
meta_data = json.load(f)
# If embedding_options exists, it should not contain prompt_template
if "embedding_options" in meta_data:
embedding_options = meta_data["embedding_options"]
assert "prompt_template" not in embedding_options, (
"prompt_template should not be in metadata when not provided"
)
class TestPromptTemplateAutoLoadOnSearch:
"""Tests for automatic loading of prompt template during search operations.
NOTE: Over-mocked test removed (test_prompt_template_auto_loaded_on_search).
This functionality is now comprehensively tested by TestQueryPromptTemplateAutoLoad
which uses simpler mocking and doesn't hang.
"""
@pytest.fixture
def temp_index_dir(self):
"""Create temporary directory for test indexes."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
@pytest.fixture
def mock_embeddings(self):
"""Mock compute_embeddings to capture calls and return dummy embeddings."""
with patch("leann.api.compute_embeddings") as mock_compute:
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
yield mock_compute
def test_search_without_template_in_metadata(self, temp_index_dir, mock_embeddings):
"""
Verify that searching an index built WITHOUT a prompt template
works correctly (backward compatibility).
The searcher should handle missing prompt_template gracefully.
Expected behavior: Search succeeds, no template is used.
"""
# Build index without template
index_path = temp_index_dir / "no_template.leann"
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
)
builder.add_text("Document without template")
builder.build_index(str(index_path))
# Reset mocks
mock_embeddings.reset_mock()
# Create searcher and search
searcher = LeannSearcher(index_path=str(index_path))
# Verify no template in embedding_options
assert "prompt_template" not in searcher.embedding_options, (
"Searcher should not have prompt_template when not in metadata"
)
class TestQueryPromptTemplateAutoLoad:
"""Tests for automatic loading of separate query_prompt_template during search (R2).
These tests verify the new two-template system where:
- build_prompt_template: Applied during index building
- query_prompt_template: Applied during search operations
Expected to FAIL in Red Phase (R2) because query template extraction
and application is not yet implemented in LeannSearcher.search().
"""
@pytest.fixture
def temp_index_dir(self):
"""Create temporary directory for test indexes."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
@pytest.fixture
def mock_compute_embeddings(self):
"""Mock compute_embeddings to capture calls and return dummy embeddings."""
with patch("leann.embedding_compute.compute_embeddings") as mock_compute:
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
yield mock_compute
def test_search_auto_loads_query_template(self, temp_index_dir, mock_compute_embeddings):
"""
Verify that search() automatically loads and applies query_prompt_template from .meta.json.
Given: Index built with separate build_prompt_template and query_prompt_template
When: LeannSearcher.search("my query") is called
Then: Query embedding is computed with "query: my query" (query template applied)
This is the core R2 requirement - query templates must be auto-loaded and applied
during search without user intervention.
Expected failure: compute_embeddings called with raw "my query" instead of
"query: my query" because query template extraction is not implemented.
"""
# Setup: Build index with separate templates in new format
index_path = temp_index_dir / "query_template.leann"
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
embedding_options={
"build_prompt_template": "doc: ",
"query_prompt_template": "query: ",
},
)
builder.add_text("Test document")
builder.build_index(str(index_path))
# Reset mock to ignore build calls
mock_compute_embeddings.reset_mock()
# Act: Search with query
searcher = LeannSearcher(index_path=str(index_path))
# Mock the backend search to avoid actual search
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
mock_backend_search.return_value = {
"labels": [["test_id_0"]], # IDs (nested list for batch support)
"distances": [[0.9]], # Distances (nested list for batch support)
}
searcher.search("my query", top_k=1, recompute_embeddings=False)
# Assert: compute_embeddings was called with query template applied
assert mock_compute_embeddings.called, "compute_embeddings should be called during search"
# Get the actual text passed to compute_embeddings
call_args = mock_compute_embeddings.call_args
texts_arg = call_args[0][0] # First positional arg (list of texts)
assert len(texts_arg) == 1, "Should compute embedding for one query"
assert texts_arg[0] == "query: my query", (
f"Query template should be applied: expected 'query: my query', got '{texts_arg[0]}'"
)
def test_search_backward_compat_single_template(self, temp_index_dir, mock_compute_embeddings):
"""
Verify backward compatibility with old single prompt_template format.
Given: Index with old format (single prompt_template, no query_prompt_template)
When: LeannSearcher.search("my query") is called
Then: Query embedding computed with "doc: my query" (old template applied)
This ensures indexes built with the old single-template system continue
to work correctly with the new search implementation.
Expected failure: Old template not recognized/applied because backward
compatibility logic is not implemented.
"""
# Setup: Build index with old single-template format
index_path = temp_index_dir / "old_template.leann"
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
embedding_options={"prompt_template": "doc: "}, # Old format
)
builder.add_text("Test document")
builder.build_index(str(index_path))
# Reset mock
mock_compute_embeddings.reset_mock()
# Act: Search
searcher = LeannSearcher(index_path=str(index_path))
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
searcher.search("my query", top_k=1, recompute_embeddings=False)
# Assert: Old template was applied
call_args = mock_compute_embeddings.call_args
texts_arg = call_args[0][0]
assert texts_arg[0] == "doc: my query", (
f"Old prompt_template should be applied for backward compatibility: "
f"expected 'doc: my query', got '{texts_arg[0]}'"
)
def test_search_backward_compat_no_template(self, temp_index_dir, mock_compute_embeddings):
"""
Verify backward compatibility when no template is present in .meta.json.
Given: Index with no template in .meta.json (very old indexes)
When: LeannSearcher.search("my query") is called
Then: Query embedding computed with "my query" (no template, raw query)
This ensures the most basic backward compatibility - indexes without
any template support continue to work as before.
Expected failure: May fail if default template is incorrectly applied,
or if missing template causes error.
"""
# Setup: Build index without any template
index_path = temp_index_dir / "no_template.leann"
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
# No embedding_options at all
)
builder.add_text("Test document")
builder.build_index(str(index_path))
# Reset mock
mock_compute_embeddings.reset_mock()
# Act: Search
searcher = LeannSearcher(index_path=str(index_path))
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
searcher.search("my query", top_k=1, recompute_embeddings=False)
# Assert: No template applied (raw query)
call_args = mock_compute_embeddings.call_args
texts_arg = call_args[0][0]
assert texts_arg[0] == "my query", (
f"No template should be applied when missing from metadata: "
f"expected 'my query', got '{texts_arg[0]}'"
)
def test_search_override_via_provider_options(self, temp_index_dir, mock_compute_embeddings):
"""
Verify that explicit provider_options can override stored query template.
Given: Index with query_prompt_template: "query: "
When: search() called with provider_options={"prompt_template": "override: "}
Then: Query embedding computed with "override: test" (override takes precedence)
This enables users to experiment with different query templates without
rebuilding the index, or to handle special query types differently.
Expected failure: provider_options parameter is accepted via **kwargs but
not used. Query embedding computed with raw "test" instead of "override: test"
because override logic is not implemented.
"""
# Setup: Build index with query template
index_path = temp_index_dir / "override_template.leann"
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
embedding_options={
"build_prompt_template": "doc: ",
"query_prompt_template": "query: ",
},
)
builder.add_text("Test document")
builder.build_index(str(index_path))
# Reset mock
mock_compute_embeddings.reset_mock()
# Act: Search with override
searcher = LeannSearcher(index_path=str(index_path))
with patch.object(searcher.backend_impl, "search") as mock_backend_search:
mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]}
# This should accept provider_options parameter
searcher.search(
"test",
top_k=1,
recompute_embeddings=False,
provider_options={"prompt_template": "override: "},
)
# Assert: Override template was applied
call_args = mock_compute_embeddings.call_args
texts_arg = call_args[0][0]
assert texts_arg[0] == "override: test", (
f"Override template should take precedence: "
f"expected 'override: test', got '{texts_arg[0]}'"
)
class TestPromptTemplateReuseInChat:
"""Tests for prompt template reuse in chat/ask operations."""
@pytest.fixture
def temp_index_dir(self):
"""Create temporary directory for test indexes."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
@pytest.fixture
def mock_embeddings(self):
"""Mock compute_embeddings to return dummy embeddings."""
with patch("leann.api.compute_embeddings") as mock_compute:
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
yield mock_compute
@pytest.fixture
def mock_embedding_server_manager(self):
"""Mock EmbeddingServerManager for chat tests."""
with patch("leann.searcher_base.EmbeddingServerManager") as mock_manager_class:
mock_manager = Mock()
mock_manager.start_server.return_value = (True, 5557)
mock_manager_class.return_value = mock_manager
yield mock_manager
@pytest.fixture
def index_with_template(self, temp_index_dir, mock_embeddings):
"""Build an index with a prompt template."""
index_path = temp_index_dir / "chat_template_index.leann"
template = "document_query: "
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
embedding_options={"prompt_template": template},
)
builder.add_text("Test document for chat")
builder.build_index(str(index_path))
return str(index_path), template
class TestPromptTemplateIntegrationWithEmbeddingModes:
"""Tests for prompt template compatibility with different embedding modes."""
@pytest.fixture
def temp_index_dir(self):
"""Create temporary directory for test indexes."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
@pytest.mark.parametrize(
"mode,model,template,filename_prefix",
[
(
"openai",
"text-embedding-3-small",
"Represent this for searching: ",
"openai_template",
),
("ollama", "nomic-embed-text", "search_query: ", "ollama_template"),
("sentence-transformers", "facebook/contriever", "query: ", "st_template"),
],
)
def test_prompt_template_metadata_with_embedding_modes(
self, temp_index_dir, mode, model, template, filename_prefix
):
"""Verify prompt template is saved correctly across different embedding modes.
Tests that prompt templates are persisted to .meta.json for:
- OpenAI mode (primary use case)
- Ollama mode (also supports templates)
- Sentence-transformers mode (saved for forward compatibility)
Expected behavior: Template is saved to .meta.json regardless of mode.
"""
with patch("leann.api.compute_embeddings") as mock_compute:
mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
index_path = temp_index_dir / f"{filename_prefix}.leann"
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=model,
embedding_mode=mode,
embedding_options={"prompt_template": template},
)
builder.add_text(f"{mode.capitalize()} test document")
builder.build_index(str(index_path))
# Verify metadata
meta_path = temp_index_dir / f"{filename_prefix}.leann.meta.json"
with open(meta_path, encoding="utf-8") as f:
meta_data = json.load(f)
assert meta_data["embedding_mode"] == mode
# Template should be saved for all modes (even if not used by some)
if "embedding_options" in meta_data:
assert meta_data["embedding_options"]["prompt_template"] == template
class TestQueryTemplateApplicationInComputeEmbedding:
"""Tests for query template application in compute_query_embedding() (Bug Fix).
These tests verify that query templates are applied consistently in BOTH
code paths (server and fallback) when computing query embeddings.
This addresses the bug where query templates were only applied in the
fallback path, not when using the embedding server (the default path).
Bug Context:
- Issue: Query templates were stored in metadata but only applied during
fallback (direct) computation, not when using embedding server
- Fix: Move template application to BEFORE any computation path in
compute_query_embedding() (searcher_base.py:107-110)
- Impact: Critical for models like EmbeddingGemma that require task-specific
templates for optimal performance
These tests ensure the fix works correctly and prevent regression.
"""
@pytest.fixture
def temp_index_with_template(self):
"""Create a temporary index with query template in metadata"""
with tempfile.TemporaryDirectory() as tmpdir:
index_dir = Path(tmpdir)
index_file = index_dir / "test.leann"
meta_file = index_dir / "test.leann.meta.json"
# Create minimal metadata with query template
metadata = {
"version": "1.0",
"backend_name": "hnsw",
"embedding_model": "text-embedding-embeddinggemma-300m-qat",
"dimensions": 768,
"embedding_mode": "openai",
"backend_kwargs": {
"graph_degree": 32,
"complexity": 64,
"distance_metric": "cosine",
},
"embedding_options": {
"base_url": "http://localhost:1234/v1",
"api_key": "test-key",
"build_prompt_template": "title: none | text: ",
"query_prompt_template": "task: search result | query: ",
},
}
meta_file.write_text(json.dumps(metadata, indent=2))
# Create minimal HNSW index file (empty is okay for this test)
index_file.write_bytes(b"")
yield str(index_file)
def test_query_template_applied_in_fallback_path(self, temp_index_with_template):
"""Test that query template is applied when using fallback (direct) path"""
from leann.searcher_base import BaseSearcher
# Create a concrete implementation for testing
class TestSearcher(BaseSearcher):
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)
searcher.index_path = Path(temp_index_with_template)
searcher.index_dir = searcher.index_path.parent
# Load metadata
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
with open(meta_file) as f:
searcher.meta = json.load(f)
searcher.embedding_model = searcher.meta["embedding_model"]
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
searcher.embedding_options = searcher.meta.get("embedding_options", {})
# Mock compute_embeddings to capture the query text
captured_queries = []
def mock_compute_embeddings(texts, model, mode, provider_options=None):
captured_queries.extend(texts)
return np.random.rand(len(texts), 768).astype(np.float32)
with patch(
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
):
# Call compute_query_embedding with template (fallback path)
result = searcher.compute_query_embedding(
query="vector database",
use_server_if_available=False, # Force fallback path
query_template="task: search result | query: ",
)
# Verify template was applied
assert len(captured_queries) == 1
assert captured_queries[0] == "task: search result | query: vector database"
assert result.shape == (1, 768)
def test_query_template_applied_in_server_path(self, temp_index_with_template):
"""Test that query template is applied when using server path"""
from leann.searcher_base import BaseSearcher
# Create a concrete implementation for testing
class TestSearcher(BaseSearcher):
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)
searcher.index_path = Path(temp_index_with_template)
searcher.index_dir = searcher.index_path.parent
# Load metadata
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
with open(meta_file) as f:
searcher.meta = json.load(f)
searcher.embedding_model = searcher.meta["embedding_model"]
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
searcher.embedding_options = searcher.meta.get("embedding_options", {})
# Mock the server methods to capture the query text
captured_queries = []
def mock_ensure_server_running(passages_file, port):
return port
def mock_compute_embedding_via_server(chunks, port):
captured_queries.extend(chunks)
return np.random.rand(len(chunks), 768).astype(np.float32)
searcher._ensure_server_running = mock_ensure_server_running
searcher._compute_embedding_via_server = mock_compute_embedding_via_server
# Call compute_query_embedding with template (server path)
result = searcher.compute_query_embedding(
query="vector database",
use_server_if_available=True, # Use server path
query_template="task: search result | query: ",
)
# Verify template was applied BEFORE calling server
assert len(captured_queries) == 1
assert captured_queries[0] == "task: search result | query: vector database"
assert result.shape == (1, 768)
def test_query_template_without_template_parameter(self, temp_index_with_template):
"""Test that query is unchanged when no template is provided"""
from leann.searcher_base import BaseSearcher
class TestSearcher(BaseSearcher):
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)
searcher.index_path = Path(temp_index_with_template)
searcher.index_dir = searcher.index_path.parent
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
with open(meta_file) as f:
searcher.meta = json.load(f)
searcher.embedding_model = searcher.meta["embedding_model"]
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
searcher.embedding_options = searcher.meta.get("embedding_options", {})
captured_queries = []
def mock_compute_embeddings(texts, model, mode, provider_options=None):
captured_queries.extend(texts)
return np.random.rand(len(texts), 768).astype(np.float32)
with patch(
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
):
searcher.compute_query_embedding(
query="vector database",
use_server_if_available=False,
query_template=None, # No template
)
# Verify query is unchanged
assert len(captured_queries) == 1
assert captured_queries[0] == "vector database"
def test_query_template_consistency_between_paths(self, temp_index_with_template):
"""Test that both paths apply template identically"""
from leann.searcher_base import BaseSearcher
class TestSearcher(BaseSearcher):
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)
searcher.index_path = Path(temp_index_with_template)
searcher.index_dir = searcher.index_path.parent
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
with open(meta_file) as f:
searcher.meta = json.load(f)
searcher.embedding_model = searcher.meta["embedding_model"]
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
searcher.embedding_options = searcher.meta.get("embedding_options", {})
query_template = "task: search result | query: "
original_query = "vector database"
# Capture queries from fallback path
fallback_queries = []
def mock_compute_embeddings(texts, model, mode, provider_options=None):
fallback_queries.extend(texts)
return np.random.rand(len(texts), 768).astype(np.float32)
with patch(
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
):
searcher.compute_query_embedding(
query=original_query,
use_server_if_available=False,
query_template=query_template,
)
# Capture queries from server path
server_queries = []
def mock_ensure_server_running(passages_file, port):
return port
def mock_compute_embedding_via_server(chunks, port):
server_queries.extend(chunks)
return np.random.rand(len(chunks), 768).astype(np.float32)
searcher._ensure_server_running = mock_ensure_server_running
searcher._compute_embedding_via_server = mock_compute_embedding_via_server
searcher.compute_query_embedding(
query=original_query,
use_server_if_available=True,
query_template=query_template,
)
# Verify both paths produced identical templated queries
assert len(fallback_queries) == 1
assert len(server_queries) == 1
assert fallback_queries[0] == server_queries[0]
assert fallback_queries[0] == f"{query_template}{original_query}"
def test_query_template_with_empty_string(self, temp_index_with_template):
"""Test behavior with empty template string"""
from leann.searcher_base import BaseSearcher
class TestSearcher(BaseSearcher):
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)
searcher.index_path = Path(temp_index_with_template)
searcher.index_dir = searcher.index_path.parent
meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
with open(meta_file) as f:
searcher.meta = json.load(f)
searcher.embedding_model = searcher.meta["embedding_model"]
searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers")
searcher.embedding_options = searcher.meta.get("embedding_options", {})
captured_queries = []
def mock_compute_embeddings(texts, model, mode, provider_options=None):
captured_queries.extend(texts)
return np.random.rand(len(texts), 768).astype(np.float32)
with patch(
"leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings
):
searcher.compute_query_embedding(
query="vector database",
use_server_if_available=False,
query_template="", # Empty string
)
# Empty string is falsy, so no template should be applied
assert captured_queries[0] == "vector database"

View File

@@ -0,0 +1,643 @@
"""Unit tests for token-aware truncation functionality.
This test suite defines the contract for token truncation functions that prevent
500 errors from Ollama when text exceeds model token limits. These tests verify:
1. Model token limit retrieval (known and unknown models)
2. Text truncation behavior for single and multiple texts
3. Token counting and truncation accuracy using tiktoken
All tests are written in Red Phase - they should FAIL initially because the
implementation does not exist yet.
"""
import pytest
import tiktoken
from leann.embedding_compute import (
EMBEDDING_MODEL_LIMITS,
get_model_token_limit,
truncate_to_token_limit,
)
class TestModelTokenLimits:
"""Tests for retrieving model-specific token limits."""
def test_get_model_token_limit_known_model(self):
"""Verify correct token limit is returned for known models.
Known models should return their specific token limits from
EMBEDDING_MODEL_LIMITS dictionary.
"""
# Test nomic-embed-text (2048 tokens)
limit = get_model_token_limit("nomic-embed-text")
assert limit == 2048, "nomic-embed-text should have 2048 token limit"
# Test nomic-embed-text-v1.5 (2048 tokens)
limit = get_model_token_limit("nomic-embed-text-v1.5")
assert limit == 2048, "nomic-embed-text-v1.5 should have 2048 token limit"
# Test nomic-embed-text-v2 (512 tokens)
limit = get_model_token_limit("nomic-embed-text-v2")
assert limit == 512, "nomic-embed-text-v2 should have 512 token limit"
# Test OpenAI models (8192 tokens)
limit = get_model_token_limit("text-embedding-3-small")
assert limit == 8192, "text-embedding-3-small should have 8192 token limit"
def test_get_model_token_limit_unknown_model(self):
"""Verify default token limit is returned for unknown models.
Unknown models should return the default limit (2048) to allow
operation with reasonable safety margin.
"""
# Test with completely unknown model
limit = get_model_token_limit("unknown-model-xyz")
assert limit == 2048, "Unknown models should return default 2048"
# Test with empty string
limit = get_model_token_limit("")
assert limit == 2048, "Empty model name should return default 2048"
def test_get_model_token_limit_custom_default(self):
"""Verify custom default can be specified for unknown models.
Allow callers to specify their own default token limit when
model is not in the known models dictionary.
"""
limit = get_model_token_limit("unknown-model", default=4096)
assert limit == 4096, "Should return custom default for unknown models"
# Known model should ignore custom default
limit = get_model_token_limit("nomic-embed-text", default=4096)
assert limit == 2048, "Known model should ignore custom default"
def test_embedding_model_limits_dictionary_exists(self):
"""Verify EMBEDDING_MODEL_LIMITS dictionary contains expected models.
The dictionary should be importable and contain at least the
known nomic models with correct token limits.
"""
assert isinstance(EMBEDDING_MODEL_LIMITS, dict), "Should be a dictionary"
assert "nomic-embed-text" in EMBEDDING_MODEL_LIMITS, "Should contain nomic-embed-text"
assert "nomic-embed-text-v1.5" in EMBEDDING_MODEL_LIMITS, (
"Should contain nomic-embed-text-v1.5"
)
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text"] == 2048
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v1.5"] == 2048
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v2"] == 512
# OpenAI models
assert EMBEDDING_MODEL_LIMITS["text-embedding-3-small"] == 8192
class TestTokenTruncation:
"""Tests for truncating texts to token limits."""
@pytest.fixture
def tokenizer(self):
"""Provide tiktoken tokenizer for token counting verification."""
return tiktoken.get_encoding("cl100k_base")
def test_truncate_single_text_under_limit(self, tokenizer):
"""Verify text under token limit remains unchanged.
When text is already within the token limit, it should be
returned unchanged with no truncation.
"""
text = "This is a short text that is well under the token limit."
token_count = len(tokenizer.encode(text))
assert token_count < 100, f"Test setup: text should be short (has {token_count} tokens)"
# Truncate with generous limit
result = truncate_to_token_limit([text], token_limit=512)
assert len(result) == 1, "Should return same number of texts"
assert result[0] == text, "Text under limit should be unchanged"
def test_truncate_single_text_over_limit(self, tokenizer):
"""Verify text over token limit is truncated correctly.
When text exceeds the token limit, it should be truncated to
fit within the limit while maintaining valid token boundaries.
"""
# Create a text that definitely exceeds limit
text = "word " * 200 # ~200 tokens (each "word " is typically 1-2 tokens)
original_token_count = len(tokenizer.encode(text))
assert original_token_count > 50, (
f"Test setup: text should be long (has {original_token_count} tokens)"
)
# Truncate to 50 tokens
result = truncate_to_token_limit([text], token_limit=50)
assert len(result) == 1, "Should return same number of texts"
assert result[0] != text, "Text over limit should be truncated"
assert len(result[0]) < len(text), "Truncated text should be shorter"
# Verify truncated text is within token limit
truncated_token_count = len(tokenizer.encode(result[0]))
assert truncated_token_count <= 50, (
f"Truncated text should be ≤50 tokens, got {truncated_token_count}"
)
def test_truncate_multiple_texts_mixed_lengths(self, tokenizer):
"""Verify multiple texts with mixed lengths are handled correctly.
When processing multiple texts:
- Texts under limit should remain unchanged
- Texts over limit should be truncated independently
- Output list should maintain same order and length
"""
texts = [
"Short text.", # Under limit
"word " * 200, # Over limit
"Another short one.", # Under limit
"token " * 150, # Over limit
]
# Verify test setup
for i, text in enumerate(texts):
token_count = len(tokenizer.encode(text))
if i in [1, 3]:
assert token_count > 50, f"Text {i} should be over limit (has {token_count} tokens)"
else:
assert token_count < 50, (
f"Text {i} should be under limit (has {token_count} tokens)"
)
# Truncate with 50 token limit
result = truncate_to_token_limit(texts, token_limit=50)
assert len(result) == len(texts), "Should return same number of texts"
# Verify each text individually
for i, (original, truncated) in enumerate(zip(texts, result)):
token_count = len(tokenizer.encode(truncated))
assert token_count <= 50, f"Text {i} should be ≤50 tokens, got {token_count}"
# Short texts should be unchanged
if i in [0, 2]:
assert truncated == original, f"Short text {i} should be unchanged"
# Long texts should be truncated
else:
assert len(truncated) < len(original), f"Long text {i} should be truncated"
def test_truncate_empty_list(self):
"""Verify empty input list returns empty output list.
Edge case: empty list should return empty list without errors.
"""
result = truncate_to_token_limit([], token_limit=512)
assert result == [], "Empty input should return empty output"
def test_truncate_preserves_order(self, tokenizer):
"""Verify truncation preserves original text order.
Output list should maintain the same order as input list,
regardless of which texts were truncated.
"""
texts = [
"First text " * 50, # Will be truncated
"Second text.", # Won't be truncated
"Third text " * 50, # Will be truncated
]
result = truncate_to_token_limit(texts, token_limit=20)
assert len(result) == 3, "Should preserve list length"
# Check that order is maintained by looking for distinctive words
assert "First" in result[0], "First text should remain in first position"
assert "Second" in result[1], "Second text should remain in second position"
assert "Third" in result[2], "Third text should remain in third position"
def test_truncate_extremely_long_text(self, tokenizer):
"""Verify extremely long texts are truncated efficiently.
Test with text that far exceeds token limit to ensure
truncation handles extreme cases without performance issues.
"""
# Create very long text (simulate real-world scenario)
text = "token " * 5000 # ~5000+ tokens
original_token_count = len(tokenizer.encode(text))
assert original_token_count > 1000, "Test setup: text should be very long"
# Truncate to small limit
result = truncate_to_token_limit([text], token_limit=100)
assert len(result) == 1
truncated_token_count = len(tokenizer.encode(result[0]))
assert truncated_token_count <= 100, (
f"Should truncate to ≤100 tokens, got {truncated_token_count}"
)
assert len(result[0]) < len(text) // 10, "Should significantly reduce text length"
def test_truncate_exact_token_limit(self, tokenizer):
"""Verify text at exactly token limit is handled correctly.
Edge case: text with exactly the token limit should either
remain unchanged or be safely truncated by 1 token.
"""
# Create text with approximately 50 tokens
# We'll adjust to get exactly 50
target_tokens = 50
text = "word " * 50
tokens = tokenizer.encode(text)
# Adjust to get exactly target_tokens
if len(tokens) > target_tokens:
tokens = tokens[:target_tokens]
text = tokenizer.decode(tokens)
elif len(tokens) < target_tokens:
# Add more words
while len(tokenizer.encode(text)) < target_tokens:
text += "word "
tokens = tokenizer.encode(text)[:target_tokens]
text = tokenizer.decode(tokens)
# Verify we have exactly target_tokens
assert len(tokenizer.encode(text)) == target_tokens, (
"Test setup: should have exactly 50 tokens"
)
result = truncate_to_token_limit([text], token_limit=target_tokens)
assert len(result) == 1
result_tokens = len(tokenizer.encode(result[0]))
assert result_tokens <= target_tokens, (
f"Should be ≤{target_tokens} tokens, got {result_tokens}"
)
class TestLMStudioHybridDiscovery:
"""Tests for LM Studio integration in get_model_token_limit() hybrid discovery.
These tests verify that get_model_token_limit() properly integrates with
the LM Studio SDK bridge for dynamic token limit discovery. The integration
should:
1. Detect LM Studio URLs (port 1234 or 'lmstudio'/'lm.studio' in URL)
2. Convert HTTP URLs to WebSocket format for SDK queries
3. Query LM Studio SDK and use discovered limit
4. Fall back to registry when SDK returns None
5. Execute AFTER Ollama detection but BEFORE registry fallback
All tests are written in Red Phase - they should FAIL initially because the
LM Studio detection and integration logic does not exist yet in get_model_token_limit().
"""
def test_get_model_token_limit_lmstudio_success(self, monkeypatch):
"""Verify LM Studio SDK query succeeds and returns detected limit.
When a LM Studio base_url is detected and the SDK query succeeds,
get_model_token_limit() should return the dynamically discovered
context length without falling back to the registry.
"""
# Mock _query_lmstudio_context_limit to return successful SDK query
def mock_query_lmstudio(model_name, base_url):
# Verify WebSocket URL was passed (not HTTP)
assert base_url.startswith("ws://"), (
f"Should convert HTTP to WebSocket format, got: {base_url}"
)
return 8192 # Successful SDK query
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
# Test with HTTP URL that should be converted to WebSocket
limit = get_model_token_limit(
model_name="custom-model", base_url="http://localhost:1234/v1"
)
assert limit == 8192, "Should return limit from LM Studio SDK query"
def test_get_model_token_limit_lmstudio_fallback_to_registry(self, monkeypatch):
"""Verify fallback to registry when LM Studio SDK returns None.
When LM Studio SDK query fails (returns None), get_model_token_limit()
should fall back to the EMBEDDING_MODEL_LIMITS registry.
"""
# Mock _query_lmstudio_context_limit to return None (SDK failure)
def mock_query_lmstudio(model_name, base_url):
return None # SDK query failed
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
# Test with known model that exists in registry
limit = get_model_token_limit(
model_name="nomic-embed-text", base_url="http://localhost:1234/v1"
)
# Should fall back to registry value
assert limit == 2048, "Should fall back to registry when SDK returns None"
def test_get_model_token_limit_lmstudio_port_detection(self, monkeypatch):
"""Verify detection of LM Studio via port 1234.
get_model_token_limit() should recognize port 1234 as a LM Studio
server and attempt SDK query, regardless of hostname.
"""
query_called = False
def mock_query_lmstudio(model_name, base_url):
nonlocal query_called
query_called = True
return 4096
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
# Test with port 1234 (default LM Studio port)
limit = get_model_token_limit(model_name="test-model", base_url="http://127.0.0.1:1234/v1")
assert query_called, "Should detect port 1234 and call LM Studio SDK query"
assert limit == 4096, "Should return SDK query result"
@pytest.mark.parametrize(
"test_url,expected_limit,keyword",
[
("http://lmstudio.local:8080/v1", 16384, "lmstudio"),
("http://api.lm.studio:5000/v1", 32768, "lm.studio"),
],
)
def test_get_model_token_limit_lmstudio_url_keyword_detection(
self, monkeypatch, test_url, expected_limit, keyword
):
"""Verify detection of LM Studio via keywords in URL.
get_model_token_limit() should recognize 'lmstudio' or 'lm.studio'
in the URL as indicating a LM Studio server.
"""
query_called = False
def mock_query_lmstudio(model_name, base_url):
nonlocal query_called
query_called = True
return expected_limit
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
limit = get_model_token_limit(model_name="test-model", base_url=test_url)
assert query_called, f"Should detect '{keyword}' keyword and call SDK query"
assert limit == expected_limit, f"Should return SDK query result for {keyword}"
@pytest.mark.parametrize(
"input_url,expected_protocol,expected_host",
[
("http://localhost:1234/v1", "ws://", "localhost:1234"),
("https://lmstudio.example.com:1234/v1", "wss://", "lmstudio.example.com:1234"),
],
)
def test_get_model_token_limit_protocol_conversion(
self, monkeypatch, input_url, expected_protocol, expected_host
):
"""Verify HTTP/HTTPS URL is converted to WebSocket format for SDK query.
LM Studio SDK requires WebSocket URLs. get_model_token_limit() should:
1. Convert 'http://' to 'ws://'
2. Convert 'https://' to 'wss://'
3. Remove '/v1' or other path suffixes (SDK expects base URL)
4. Preserve host and port
"""
conversions_tested = []
def mock_query_lmstudio(model_name, base_url):
conversions_tested.append(base_url)
return 8192
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
get_model_token_limit(model_name="test-model", base_url=input_url)
# Verify conversion happened
assert len(conversions_tested) == 1, "Should have called SDK query once"
assert conversions_tested[0].startswith(expected_protocol), (
f"Should convert to {expected_protocol}"
)
assert expected_host in conversions_tested[0], (
f"Should preserve host and port: {expected_host}"
)
def test_get_model_token_limit_lmstudio_executes_after_ollama(self, monkeypatch):
"""Verify LM Studio detection happens AFTER Ollama detection.
The hybrid discovery order should be:
1. Ollama dynamic discovery (port 11434 or 'ollama' in URL)
2. LM Studio dynamic discovery (port 1234 or 'lmstudio' in URL)
3. Registry fallback
If both Ollama and LM Studio patterns match, Ollama should take precedence.
This test verifies that LM Studio is checked but doesn't interfere with Ollama.
"""
ollama_called = False
lmstudio_called = False
def mock_query_ollama(model_name, base_url):
nonlocal ollama_called
ollama_called = True
return 2048 # Ollama query succeeds
def mock_query_lmstudio(model_name, base_url):
nonlocal lmstudio_called
lmstudio_called = True
return None # Should not be reached if Ollama succeeds
monkeypatch.setattr(
"leann.embedding_compute._query_ollama_context_limit",
mock_query_ollama,
)
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
# Test with Ollama URL
limit = get_model_token_limit(
model_name="test-model", base_url="http://localhost:11434/api"
)
assert ollama_called, "Should attempt Ollama query first"
assert not lmstudio_called, "Should not attempt LM Studio query when Ollama succeeds"
assert limit == 2048, "Should return Ollama result"
def test_get_model_token_limit_lmstudio_not_detected_for_non_lmstudio_urls(self, monkeypatch):
"""Verify LM Studio SDK query is NOT called for non-LM Studio URLs.
Only URLs with port 1234 or 'lmstudio'/'lm.studio' keywords should
trigger LM Studio SDK queries. Other URLs should skip to registry fallback.
"""
lmstudio_called = False
def mock_query_lmstudio(model_name, base_url):
nonlocal lmstudio_called
lmstudio_called = True
return 8192
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
# Test with non-LM Studio URLs
test_cases = [
"http://localhost:8080/v1", # Different port
"http://openai.example.com/v1", # Different service
"http://localhost:3000/v1", # Another port
]
for base_url in test_cases:
lmstudio_called = False # Reset for each test
get_model_token_limit(model_name="nomic-embed-text", base_url=base_url)
assert not lmstudio_called, f"Should NOT call LM Studio SDK for URL: {base_url}"
def test_get_model_token_limit_lmstudio_case_insensitive_detection(self, monkeypatch):
"""Verify LM Studio detection is case-insensitive for keywords.
Keywords 'lmstudio' and 'lm.studio' should be detected regardless
of case (LMStudio, LMSTUDIO, LmStudio, etc.).
"""
query_called = False
def mock_query_lmstudio(model_name, base_url):
nonlocal query_called
query_called = True
return 8192
monkeypatch.setattr(
"leann.embedding_compute._query_lmstudio_context_limit",
mock_query_lmstudio,
)
# Test various case variations
test_cases = [
"http://LMStudio.local:8080/v1",
"http://LMSTUDIO.example.com/v1",
"http://LmStudio.local/v1",
"http://api.LM.STUDIO:5000/v1",
]
for base_url in test_cases:
query_called = False # Reset for each test
limit = get_model_token_limit(model_name="test-model", base_url=base_url)
assert query_called, f"Should detect LM Studio in URL: {base_url}"
assert limit == 8192, f"Should return SDK result for URL: {base_url}"
class TestTokenLimitCaching:
"""Tests for token limit caching to prevent repeated SDK/API calls.
Caching prevents duplicate SDK/API calls within the same Python process,
which is important because:
1. LM Studio SDK load() can load duplicate model instances
2. Ollama /api/show queries add latency
3. Registry lookups are pure overhead
Cache is process-scoped and resets between leann build invocations.
"""
def setup_method(self):
"""Clear cache before each test."""
from leann.embedding_compute import _token_limit_cache
_token_limit_cache.clear()
def test_registry_lookup_is_cached(self):
"""Verify that registry lookups are cached."""
from leann.embedding_compute import _token_limit_cache
# First call
limit1 = get_model_token_limit("text-embedding-3-small")
assert limit1 == 8192
# Verify it's in cache
cache_key = ("text-embedding-3-small", "")
assert cache_key in _token_limit_cache
assert _token_limit_cache[cache_key] == 8192
# Second call should use cache
limit2 = get_model_token_limit("text-embedding-3-small")
assert limit2 == 8192
def test_default_fallback_is_cached(self):
"""Verify that default fallbacks are cached."""
from leann.embedding_compute import _token_limit_cache
# First call with unknown model
limit1 = get_model_token_limit("unknown-model-xyz", default=512)
assert limit1 == 512
# Verify it's in cache
cache_key = ("unknown-model-xyz", "")
assert cache_key in _token_limit_cache
assert _token_limit_cache[cache_key] == 512
# Second call should use cache
limit2 = get_model_token_limit("unknown-model-xyz", default=512)
assert limit2 == 512
def test_different_urls_create_separate_cache_entries(self):
"""Verify that different base_urls create separate cache entries."""
from leann.embedding_compute import _token_limit_cache
# Same model, different URLs
limit1 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:11434")
limit2 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:1234/v1")
# Both should find the model in registry (2048)
assert limit1 == 2048
assert limit2 == 2048
# But they should be separate cache entries
cache_key1 = ("nomic-embed-text", "http://localhost:11434")
cache_key2 = ("nomic-embed-text", "http://localhost:1234/v1")
assert cache_key1 in _token_limit_cache
assert cache_key2 in _token_limit_cache
assert len(_token_limit_cache) == 2
def test_cache_prevents_repeated_lookups(self):
"""Verify that cache prevents repeated registry/API lookups."""
from leann.embedding_compute import _token_limit_cache
model_name = "text-embedding-ada-002"
# First call - should add to cache
assert len(_token_limit_cache) == 0
limit1 = get_model_token_limit(model_name)
cache_size_after_first = len(_token_limit_cache)
assert cache_size_after_first == 1
# Multiple subsequent calls - cache size should not change
for _ in range(5):
limit = get_model_token_limit(model_name)
assert limit == limit1
assert len(_token_limit_cache) == cache_size_after_first
def test_versioned_model_names_cached_correctly(self):
"""Verify that versioned model names (e.g., model:tag) are cached."""
from leann.embedding_compute import _token_limit_cache
# Model with version tag
limit = get_model_token_limit("nomic-embed-text:latest", base_url="http://localhost:11434")
assert limit == 2048
# Should be cached with full name including version
cache_key = ("nomic-embed-text:latest", "http://localhost:11434")
assert cache_key in _token_limit_cache
assert _token_limit_cache[cache_key] == 2048

7553
uv.lock generated
View File

File diff suppressed because it is too large Load Diff