Compare commits

..

37 Commits

Author SHA1 Message Date
GitHub Actions
3c4785bb63 chore: release v0.3.5 2025-11-12 06:01:25 +00:00
yichuan-w
3766ad1fd2 robust multi-vector 2025-11-09 02:34:53 +00:00
ww26
c3aceed1e0 metadata reveal for ast-chunking; smart detection of seq length in ollama; auto adjust chunk length for ast to prevent silent truncation (#157)
* feat: enhance token limits with dynamic discovery + AST metadata

Improves upon upstream PR #154 with two major enhancements:

1. **Hybrid Token Limit Discovery**
   - Dynamic: Query Ollama /api/show for context limits
   - Fallback: Registry for LM Studio/OpenAI
   - Zero maintenance for Ollama users
   - Respects custom num_ctx settings

2. **AST Metadata Preservation**
   - create_ast_chunks() returns dict format with metadata
   - Preserves file_path, file_name, timestamps
   - Includes astchunk metadata (line numbers, node counts)
   - Fixes content extraction bug (checks "content" key)
   - Enables --show-metadata flag

3. **Better Token Limits**
   - nomic-embed-text: 2048 tokens (vs 512)
   - nomic-embed-text-v1.5: 2048 tokens
   - Added OpenAI models: 8192 tokens

4. **Comprehensive Tests**
   - 11 tests for token truncation
   - 545 new lines in test_astchunk_integration.py
   - All metadata preservation tests passing

* fix: merge EMBEDDING_MODEL_LIMITS and remove redundant validation

- Merged upstream's model list with our corrected token limits
- Kept our corrected nomic-embed-text: 2048 (not 512)
- Removed post-chunking validation (redundant with embedding-time truncation)
- All tests passing except 2 pre-existing integration test failures

* style: apply ruff formatting and restore PR #154 version handling

- Remove duplicate truncate_to_token_limit and get_model_token_limit functions
- Restore version handling logic (model:latest -> model) from PR #154
- Restore partial matching fallback for model name variations
- Apply ruff formatting to all modified files
- All 11 token truncation tests passing

* style: sort imports alphabetically (pre-commit auto-fix)

* fix: show AST token limit warning only once per session

- Add module-level flag to track if warning shown
- Prevents spam when processing multiple files
- Add clarifying note that auto-truncation happens at embedding time
- Addresses issue where warning appeared for every code file

* enhance: add detailed logging for token truncation

- Track and report truncation statistics (count, tokens removed, max length)
- Show first 3 individual truncations with exact token counts
- Provide comprehensive summary when truncation occurs
- Use WARNING level for data loss visibility
- Silent (DEBUG level only) when no truncation needed

Replaces misleading "truncated where necessary" message that appeared
even when nothing was truncated.
2025-11-08 17:37:31 -08:00
yichuan-w
dc6c9f696e update some search in copali 2025-11-08 08:53:03 +00:00
CalebZ9909
2406c41eef Update faiss submodule to latest commit
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-08 00:47:21 +00:00
Andy Lee
d4f5f2896f Faster Update (#148)
* stash

* stash

* add std err in add and trace progress

* fix.

* docs

* style: format

* docs

* better figs

* better figs

* update results

* fotmat

---------

Co-authored-by: yichuan-w <yichuan-w@users.noreply.github.com>
2025-11-05 13:37:47 -08:00
Aakash Suresh
366984e92e Merge pull request #154 from yichuan-w/fix/chunking-token-limit-behavior
Fix/chunking token limit behavior
2025-11-02 21:37:47 -08:00
aakash
64b92a04a7 fixing chunking token issues within limit for embedding models 2025-10-31 17:15:00 -07:00
ww26
a85d0ad4a7 Feature/optimize ollama batching (#152)
* feat: add metadata output to search results

- Add --show-metadata flag to display file paths in search results
- Preserve document metadata (file_path, file_name, timestamps) during chunking
- Update MCP tool schema to support show_metadata parameter
- Enhance CLI search output to display metadata when requested
- Fix pre-existing bug: args.backend -> args.backend_name

Resolves yichuan-w/LEANN#144

* fix: resolve ZMQ linking issues in Python extension

- Use pkg_check_modules IMPORTED_TARGET to create PkgConfig::ZMQ
- Set PKG_CONFIG_PATH to prioritize ARM64 Homebrew on Apple Silicon
- Override macOS -undefined dynamic_lookup to force proper symbol resolution
- Use PUBLIC linkage for ZMQ in faiss library for transitive linking
- Mark cppzmq includes as SYSTEM to suppress warnings

Fixes editable install ZMQ symbol errors while maintaining compatibility
across Linux, macOS Intel, and macOS ARM64 platforms.

* style: apply ruff formatting

* chore: update faiss submodule to use ww2283 fork

Use ww2283/faiss fork with fix/zmq-linking branch to resolve CI checkout
failures. The ZMQ linking fixes are not yet merged upstream.

* feat: implement true batch processing for Ollama embeddings

Migrate from deprecated /api/embeddings to modern /api/embed endpoint
which supports batch inputs. This reduces HTTP overhead by sending
32 texts per request instead of making individual API calls.

Changes:
- Update endpoint from /api/embeddings to /api/embed
- Change parameter from 'prompt' (single) to 'input' (array)
- Update response parsing for batch embeddings array
- Increase timeout to 60s for batch processing
- Improve error handling for batch requests

Performance:
- Reduces API calls by 32x (batch size)
- Eliminates HTTP connection overhead per text
- Note: Ollama still processes batch items sequentially internally

Related: #151

* fall back to original faiss as i merge the PR

---------

Co-authored-by: yichuan520030910320 <yichuan_wang@berkeley.edu>
2025-10-30 16:39:14 -07:00
yichuan-w
dbb5f4d352 Fix CI failure by removing paru-bin submodule
Remove paru-bin directory that was incorrectly added as a git submodule.
This directory is an AUR build artifact and should not be tracked.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-25 14:51:06 -07:00
yichuan-w
f180b83589 add deep wiki 2025-10-25 14:46:17 -07:00
CelineNi2
abf312d998 Display context chunks in ask and search results (#149)
* Printing querying time

* Adding source name to chunks

Adding source name as metadata to chunks, then printing the sources when searching

* Printing the context provided to LLM

To check the data transmitted to the LLMs : display the relevance, ID, content, and source of each sent chunk.

* Correcting source as metadata for chunks

* Applying ruff format

* Applying Ruff formatting

* Ruff formatting
2025-10-23 15:03:59 -07:00
Aakash Suresh
ab251ab751 Fix/twitter bookmarks anchor link (#143)
* fix: Fix Twitter bookmarks anchor link

- Convert Twitter Bookmarks from collapsible details to proper header
- Update internal link to match new anchor format
- Ensures external links to #twitter-bookmarks-your-personal-tweet-library work correctly

Fixes broken link: https://github.com/yichuan-w/LEANN?tab=readme-ov-file#twitter-bookmarks-your-personal-tweet-library

* fix: Fix Slack messages anchor link as well

- Convert Slack Messages from collapsible details to proper header
- Update internal link to match new anchor format
- Ensures external links to #slack-messages-search-your-team-conversations work correctly

Both Twitter and Slack MCP sections now have reliable anchor links.

* fix: Point Slack and Twitter links to main MCP section

- Both Slack and Twitter are subsections under MCP Integration
- Links should point to #mcp-integration-rag-on-live-data-from-any-platform
- Users will land on the MCP section and can find both Slack and Twitter subsections there

This matches the actual document structure where Slack and Twitter are under the MCP Integration section.

* Improve Slack MCP integration with retry logic and comprehensive setup guide

- Add retry mechanism with exponential backoff for cache sync issues
- Handle 'users cache is not ready yet' errors gracefully
- Add max-retries and retry-delay CLI arguments for better control
- Create comprehensive Slack setup guide with troubleshooting
- Update README with link to detailed setup guide
- Improve error messages and user experience

* Fix trailing whitespace in slack setup guide

Pre-commit hooks formatting fixes

* Add comprehensive Slack setup guide with success screenshot

- Create detailed setup guide with step-by-step instructions
- Add troubleshooting section for common issues like cache sync errors
- Include real terminal output example from successful integration
- Add screenshot showing VS Code interface with Slack channel data
- Remove excessive emojis for more professional documentation
- Document retry logic improvements and CLI arguments

* Fix formatting issues in Slack setup guide

- Remove trailing whitespace
- Fix end of file formatting
- Pre-commit hooks formatting fixes

* Add real RAG example showing intelligent Slack query functionality

- Add detailed example of asking 'What is LEANN about?'
- Show retrieved messages from Slack channels
- Demonstrate intelligent answer generation based on context
- Add command example for running real RAG queries
- Explain the 4-step process: retrieve, index, generate, cite

* Update Slack setup guide with bot invitation requirements

- Add important section about inviting bot to channels before RAG queries
- Explain the 'not_in_channel' errors and their meaning
- Provide clear steps for bot invitation process
- Document realistic scenario where bot needs explicit channel access
- Update documentation to be more professional and less cursor-style

* Docs: add real RAG example for Sky Lab #random

- Embed screenshot videos/rag-sky-random.png
- Add step-by-step commands and notes
- Include helper test script tests/test_channel_by_id_or_name.py
- Redact example tokens from docs

* Docs/CI: fix broken image paths and ruff lint\n\n- Move screenshot to docs/videos and update references\n- Remove obsolete rag-query-results image\n- Rename variable to satisfy ruff

* Docs: fix image path for lychee (use videos/ relative under docs/)

* Docs: finalize Slack setup guide with Sky random RAG example and image path fixes\n\n- Redact example tokens from docs

* Fix Slack MCP integration and update documentation

- Fix SlackMCPReader to use conversations_history instead of channels_list
- Add fallback imports for leann.interactive_utils and leann.settings
- Update slack-setup-guide.md with real screenshots and improved text
- Remove old screenshot files

* Add Slack integration screenshots to docs/videos

- Add slack_integration.png showing RAG query results
- Add slack_integration_2.png showing additional demo functionality
- Fixes lychee link checker errors for missing image files

* Update Slack integration screenshot with latest changes

* Remove test_channel_by_id_or_name.py

- Clean up temporary test file that was used for debugging
- Keep only the main slack_rag.py application for production use

* Update Slack RAG example to show LEANN announcement retrieval

- Change query from 'PUBPOL 290' to 'What is LEANN about?' for more challenging retrieval
- Update command to use python -m apps.slack_rag instead of test script
- Add expected response showing Yichuan Wang's LEANN announcement message
- Emphasize this demonstrates ability to find specific announcements in conversation history
- Update description to highlight challenging query capabilities

* Update Slack RAG integration with improved CSV parsing and new screenshots

- Fixed CSV message parsing in slack_mcp_reader.py to properly handle individual messages
- Updated slack_rag.py to filter empty channel strings
- Enhanced slack-setup-guide.md with two new query examples:
  - Advisor Models query: 'train black-box models to adopt to your personal data'
  - Barbarians at the Gate query: 'AI-driven research systems ADRS'
- Replaced old screenshots with four new ones showing both query examples
- Updated documentation to use User OAuth Token (xoxp-) instead of Bot Token (xoxb-)
- Added proper command examples with --no-concatenate-conversations and --force-rebuild flags

* Update Slack RAG documentation with Ollama integration and new screenshots

- Updated slack-setup-guide.md with comprehensive Ollama setup instructions
- Added 6 new screenshots showing complete RAG workflow:
  - Command setup, search results, and LLM responses for both queries
- Removed simulated LLM references, now uses real Ollama with llama3.2:1b
- Enhanced documentation with step-by-step Ollama installation
- Updated troubleshooting checklist to include Ollama-specific checks
- Fixed command syntax and added proper Ollama configuration
- Demonstrates working Slack RAG with real AI-generated responses

* Remove Key Features section from Slack RAG examples

- Simplified documentation by removing the bullet point list
- Keeps the focus on the actual examples and screenshots
2025-10-19 11:47:29 -07:00
CelineNi2
28085f6f04 Add messages regarding the use of token during query (#147)
* Add messages regarding the use of token during query

* fix: apply ruff format
2025-10-15 16:48:48 -07:00
CelineNi2
6495833887 Changing the option name "--backend" for "--backend-name" as written in the documentation (#146) 2025-10-14 13:35:10 -07:00
yichuan520030910320
5543b3c5f7 [minor] format fix 2025-10-09 15:10:54 -07:00
yichuan-w
a99983b3d9 fix readme 2025-10-08 21:51:25 +00:00
Aakash Suresh
36482e016c fix: Fix Twitter bookmarks anchor link (#140)
* fix: Fix Twitter bookmarks anchor link

- Convert Twitter Bookmarks from collapsible details to proper header
- Update internal link to match new anchor format
- Ensures external links to #twitter-bookmarks-your-personal-tweet-library work correctly

Fixes broken link: https://github.com/yichuan-w/LEANN?tab=readme-ov-file#twitter-bookmarks-your-personal-tweet-library

* fix: Fix Slack messages anchor link as well

- Convert Slack Messages from collapsible details to proper header
- Update internal link to match new anchor format
- Ensures external links to #slack-messages-search-your-team-conversations work correctly

Both Twitter and Slack MCP sections now have reliable anchor links.

* fix: Point Slack and Twitter links to main MCP section

- Both Slack and Twitter are subsections under MCP Integration
- Links should point to #mcp-integration-rag-on-live-data-from-any-platform
- Users will land on the MCP section and can find both Slack and Twitter subsections there

This matches the actual document structure where Slack and Twitter are under the MCP Integration section.
2025-10-08 02:32:02 -07:00
Aakash Suresh
32967daf81 security: Enhance Hugging Face model loading security - resolves #136 (#138)
BREAKING CHANGE: trust_remote_code now defaults to False for security

- Set trust_remote_code=False by default in HFChat class
- Add explicit trust_remote_code parameter to HFChat.__init__()
- Add security warning when trust_remote_code=True is used
- Update get_llm() function to support trust_remote_code parameter
- Update benchmark utilities (load_hf_model, load_vllm_model, load_qwen_vl_model)
- Add comprehensive documentation for security implications

Security Benefits:
- Prevents arbitrary code execution from compromised model repositories
- Requires explicit opt-in for models that need remote code execution
- Shows clear warnings when security is reduced
- Follows security-by-default principle

Migration Guide:
- Most users: No changes needed (more secure by default)
- Users with models requiring remote code: Add trust_remote_code=True explicitly
- Config users: Add 'trust_remote_code': true to LLM config if needed

Fixes #136
2025-10-07 13:13:44 -07:00
Aakash Suresh
b4bb8dec75 feat: Add MCP integration support for Slack and Twitter (#134)
* feat: Add MCP integration support for Slack and Twitter

- Implement SlackMCPReader for connecting to Slack MCP servers
- Implement TwitterMCPReader for connecting to Twitter MCP servers
- Add SlackRAG and TwitterRAG applications with full CLI support
- Support live data fetching via Model Context Protocol (MCP)
- Add comprehensive documentation and usage examples
- Include connection testing capabilities with --test-connection flag
- Add standalone tests for core functionality
- Update README with detailed MCP integration guide
- Add Aakash Suresh to Active Contributors

Resolves #36

* fix: Resolve linting issues in MCP integration

- Replace deprecated typing.Dict/List with built-in dict/list
- Fix boolean comparisons (== True/False) to direct checks
- Remove unused variables in demo script
- Update type annotations to use modern Python syntax

All pre-commit hooks should now pass.

* fix: Apply final formatting fixes for pre-commit hooks

- Remove unused imports (asyncio, pathlib.Path)
- Remove unused class imports in demo script
- Ensure all files pass ruff format and pre-commit checks

This should resolve all remaining CI linting issues.

* fix: Apply pre-commit formatting changes

- Fix trailing whitespace in all files
- Apply ruff formatting to match project standards
- Ensure consistent code style across all MCP integration files

This commit applies the exact changes that pre-commit hooks expect.

* fix: Apply pre-commit hooks formatting fixes

- Remove trailing whitespace from all files
- Fix ruff formatting issues (2 errors resolved)
- Apply consistent code formatting across 3 files
- Ensure all files pass pre-commit validation

This resolves all CI formatting failures.

* fix: Update MCP RAG classes to match BaseRAGExample signature

- Fix SlackMCPRAG and TwitterMCPRAG __init__ methods to provide required parameters
- Add name, description, and default_index_name to super().__init__ calls
- Resolves test failures: test_slack_rag_initialization and test_twitter_rag_initialization

This fixes the TypeError caused by BaseRAGExample requiring additional parameters.

* style: Apply ruff formatting - add trailing commas

- Add trailing commas to super().__init__ calls in SlackMCPRAG and TwitterMCPRAG
- Fixes ruff format pre-commit hook requirements

* fix: Resolve SentenceTransformer model_kwargs parameter conflict

- Fix local_files_only parameter conflict in embedding_compute.py
- Create separate copies of model_kwargs and tokenizer_kwargs for local vs network loading
- Prevents parameter conflicts when falling back from local to network loading
- Resolves TypeError in test_readme_examples.py tests

This addresses the SentenceTransformer initialization issues in CI tests.

* fix: Add comprehensive SentenceTransformer version compatibility

- Handle both old and new sentence-transformers versions
- Gracefully fallback from advanced parameters to basic initialization
- Catch TypeError for model_kwargs/tokenizer_kwargs and use basic SentenceTransformer init
- Ensures compatibility across different CI environments and local setups
- Maintains optimization benefits where supported while ensuring broad compatibility

This resolves test failures in CI environments with older sentence-transformers versions.

* style: Apply ruff formatting to embedding_compute.py

- Break long logger.warning lines for better readability
- Fixes pre-commit hook formatting requirements

* docs: Comprehensive documentation improvements for better user experience

- Add clear step-by-step Getting Started Guide for new users
- Add comprehensive CLI Reference with all commands and options
- Improve installation instructions with clear steps and verification
- Add detailed troubleshooting section for common issues (Ollama, OpenAI, etc.)
- Clarify difference between CLI commands and specialized apps
- Add environment variables documentation
- Improve MCP integration documentation with CLI integration examples
- Address user feedback about confusing installation and setup process

This resolves documentation gaps that made LEANN difficult for non-specialists to use.

* style: Remove trailing whitespace from README.md

- Fix trailing whitespace issues found by pre-commit hooks
- Ensures consistent formatting across documentation

* docs: Simplify README by removing excessive documentation

- Remove overly complex CLI reference and getting started sections (lines 61-334)
- Remove emojis from section headers for cleaner appearance
- Keep README simple and focused as requested
- Maintain essential MCP integration documentation

This addresses feedback to keep documentation minimal and avoid auto-generated content.

* docs: Address maintainer feedback on README improvements

- Restore emojis in section headers (Prerequisites and Quick Install)
- Add MCP live data feature mention in line 23 with links to Slack and Twitter
- Add detailed API credential setup instructions for Slack:
  - Step-by-step Slack App creation process
  - Required OAuth scopes and permissions
  - Clear token identification (xoxb- vs xapp-)
- Add detailed API credential setup instructions for Twitter:
  - Twitter Developer Account application process
  - API v2 requirements for bookmarks access
  - Required permissions and scopes

This addresses maintainer feedback to make API setup more user-friendly.
2025-10-07 02:18:32 -07:00
Andy Lee
5ba9cf6442 chore: require sentence-transformers >=3 and pin transformers <4.46 2025-10-06 15:52:56 -07:00
Andy Lee
1484406a8d chore: align core deps with transformers pin 2025-10-05 19:01:58 -07:00
Andy Lee
761ec1f0ac chore: pin transformers for py39 2025-10-05 18:29:45 -07:00
Andy Lee
4808afc686 docs: point DiskANN link to public PDF 2025-10-05 17:58:57 -07:00
Jon Haddad
0bba4b2157 Add readline support to interactive command line interfaces (#121)
* Add readline support to interactive command line interfaces

- Implement readline history, navigation, and editing for CLI, API, and RAG chat modes
- Create shared InteractiveSession class to consolidate readline functionality
- Add command history persistence across sessions with separate files per context
- Support built-in commands: help, clear, history, quit/exit
- Enable arrow key navigation and command editing in all interactive modes

* Improvements based on feedback
2025-10-05 17:38:15 -07:00
Kishlay Kisu
e67b5f44fa Implement FileSystem wide semantic file search engine with temporal awareness using LEANN. (#103)
* system wide semantic file search with temporal awareness

* ruff checking passed

* graceful exit for empty dump

* error thrown for time only search

* fixes
2025-10-05 17:26:48 -07:00
Aakash Suresh
658bce47ef Feature/imessage rag support (#131) 2025-10-02 10:40:57 -07:00
Andy Lee
6b399ad8d2 fix: launch another port when updating 2025-09-30 13:00:00 -07:00
Andy Lee
16f35aa067 Update faiss for batch distances calc & caching when updating 2025-09-30 12:42:40 -07:00
Andy Lee
ab9c6bd69e Fix update. Should launch embedding server first (#130)
* fix: set ntotal for storage as well

* fix: launch embedding server before adding
2025-09-30 00:58:17 -07:00
yichuan520030910320
e2b37914ce add dynamic add test 2025-09-30 00:48:46 -07:00
Andy Lee
e588100674 fix: set ntotal for storage as well (#129) 2025-09-29 20:43:16 -07:00
Andy Lee
fecee94af1 Experiments (#68)
* feat: finance bench

* docs: results

* chore: ignroe data README

* feat: fix financebench

* feat: laion, also required idmaps support

* style: format

* style: format

* fix: resolve ruff linting errors

- Remove unused variables in benchmark scripts
- Rename unused loop variables to follow convention

* feat: enron email bench

* experiments for running DiskANN & BM25 on Arch 4090

* style: format

* chore(ci): remove paru-bin submodule and config to fix checkout --recurse-submodules

* docs: data

* docs: data updated

* fix: as package

* fix(ci): only run pre-commit

* chore: use http url of astchunk; use group for some dev deps

* fix(ci): should checkout modules as well since `uv sync` checks

* fix(ci): run with lint only

* fix: find links to install wheels available

* CI: force local wheels in uv install step

* CI: install local wheels via file paths

* CI: pick wheels matching current Python tag

* CI: handle python tag mismatches for local wheels

* CI: use matrix python venv and set macOS deployment target

* CI: revert install step to match main

* CI: use uv group install with local wheel selection

* CI: rely on setup-uv for Python and tighten group install

* CI: install build deps with uv python interpreter

* CI: use temporary uv venv for build deps

* CI: add build venv scripts path for wheel repair
2025-09-24 11:19:04 -07:00
yichuan520030910320
01475c10a0 add img 2025-09-23 23:25:05 -07:00
yichuan520030910320
c8aa063f48 merge main 2025-09-23 23:21:53 -07:00
yichuan520030910320
576beb13db add doc about multimodal 2025-09-23 23:21:03 -07:00
Andy Lee
63c7b0c8a3 Fix restart embedding server when passages change (#117)
* fix: restart embedding server when passages change

* fix: restore python 3.9 typing compatibility
2025-09-23 22:28:36 -07:00
67 changed files with 14027 additions and 4596 deletions

9
.gitignore vendored
View File

@@ -99,3 +99,12 @@ benchmarks/data/
## multi vector
apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py
# Ignore all PDFs (keep data exceptions above) and do not track demo PDFs
# If you need to commit a specific demo PDF, remove this negation locally.
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
!apps/multimodal/vision-based-pdf-multi-vector/fig/*
# AUR build directory (Arch Linux)
paru-bin/

401
README.md
View File

@@ -8,8 +8,12 @@
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q"><img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
<a href="assets/wechat_user_group.JPG" title="Join WeChat group"><img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group"></a>
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q">
<img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
</a>
<a href="assets/wechat_user_group.JPG" title="Join WeChat group">
<img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group">
</a>
</p>
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
@@ -20,7 +24,7 @@ LEANN is an innovative vector database that democratizes personal AI. Transform
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#mcp-integration-rag-on-live-data-from-any-platform), [Twitter](#mcp-integration-rag-on-live-data-from-any-platform)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
@@ -72,8 +76,9 @@ uv venv
source .venv/bin/activate
uv pip install leann
```
<!--
> Low-resource? See Low-resource setups in the [Configuration Guide](docs/configuration-guide.md#low-resource-setups). -->
> Low-resource? See "Low-resource setups" in the [Configuration Guide](docs/configuration-guide.md#low-resource-setups). -->
<details>
<summary>
@@ -176,7 +181,7 @@ response = chat.ask("How much storage does LEANN save?", top_k=1)
## RAG on Everything!
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, ChatGPT conversations, Claude conversations, iMessage conversations, and **live data from any platform through MCP (Model Context Protocol) servers** - including Slack, Twitter, and more.
@@ -542,10 +547,386 @@ Once the index is built, you can ask questions like:
</details>
### 🤖 ChatGPT Chat History: Your Personal AI Conversation Archive!
Transform your ChatGPT conversations into a searchable knowledge base! Search through all your ChatGPT discussions about coding, research, brainstorming, and more.
```bash
python -m apps.chatgpt_rag --export-path chatgpt_export.html --query "How do I create a list in Python?"
```
**Unlock your AI conversation history.** Never lose track of valuable insights from your ChatGPT discussions again.
<details>
<summary><strong>📋 Click to expand: How to Export ChatGPT Data</strong></summary>
**Step-by-step export process:**
1. **Sign in to ChatGPT**
2. **Click your profile icon** in the top right corner
3. **Navigate to Settings** → **Data Controls**
4. **Click "Export"** under Export Data
5. **Confirm the export** request
6. **Download the ZIP file** from the email link (expires in 24 hours)
7. **Extract or use directly** with LEANN
**Supported formats:**
- `.html` files from ChatGPT exports
- `.zip` archives from ChatGPT
- Directories with multiple export files
</details>
<details>
<summary><strong>📋 Click to expand: ChatGPT-Specific Arguments</strong></summary>
#### Parameters
```bash
--export-path PATH # Path to ChatGPT export file (.html/.zip) or directory (default: ./chatgpt_export)
--separate-messages # Process each message separately instead of concatenated conversations
--chunk-size N # Text chunk size (default: 512)
--chunk-overlap N # Overlap between chunks (default: 128)
```
#### Example Commands
```bash
# Basic usage with HTML export
python -m apps.chatgpt_rag --export-path conversations.html
# Process ZIP archive from ChatGPT
python -m apps.chatgpt_rag --export-path chatgpt_export.zip
# Search with specific query
python -m apps.chatgpt_rag --export-path chatgpt_data.html --query "Python programming help"
# Process individual messages for fine-grained search
python -m apps.chatgpt_rag --separate-messages --export-path chatgpt_export.html
# Process directory containing multiple exports
python -m apps.chatgpt_rag --export-path ./chatgpt_exports/ --max-items 1000
```
</details>
<details>
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
Once your ChatGPT conversations are indexed, you can search with queries like:
- "What did I ask ChatGPT about Python programming?"
- "Show me conversations about machine learning algorithms"
- "Find discussions about web development frameworks"
- "What coding advice did ChatGPT give me?"
- "Search for conversations about debugging techniques"
- "Find ChatGPT's recommendations for learning resources"
</details>
### 🤖 Claude Chat History: Your Personal AI Conversation Archive!
Transform your Claude conversations into a searchable knowledge base! Search through all your Claude discussions about coding, research, brainstorming, and more.
```bash
python -m apps.claude_rag --export-path claude_export.json --query "What did I ask about Python dictionaries?"
```
**Unlock your AI conversation history.** Never lose track of valuable insights from your Claude discussions again.
<details>
<summary><strong>📋 Click to expand: How to Export Claude Data</strong></summary>
**Step-by-step export process:**
1. **Open Claude** in your browser
2. **Navigate to Settings** (look for gear icon or settings menu)
3. **Find Export/Download** options in your account settings
4. **Download conversation data** (usually in JSON format)
5. **Place the file** in your project directory
*Note: Claude export methods may vary depending on the interface you're using. Check Claude's help documentation for the most current export instructions.*
**Supported formats:**
- `.json` files (recommended)
- `.zip` archives containing JSON data
- Directories with multiple export files
</details>
<details>
<summary><strong>📋 Click to expand: Claude-Specific Arguments</strong></summary>
#### Parameters
```bash
--export-path PATH # Path to Claude export file (.json/.zip) or directory (default: ./claude_export)
--separate-messages # Process each message separately instead of concatenated conversations
--chunk-size N # Text chunk size (default: 512)
--chunk-overlap N # Overlap between chunks (default: 128)
```
#### Example Commands
```bash
# Basic usage with JSON export
python -m apps.claude_rag --export-path my_claude_conversations.json
# Process ZIP archive from Claude
python -m apps.claude_rag --export-path claude_export.zip
# Search with specific query
python -m apps.claude_rag --export-path claude_data.json --query "machine learning advice"
# Process individual messages for fine-grained search
python -m apps.claude_rag --separate-messages --export-path claude_export.json
# Process directory containing multiple exports
python -m apps.claude_rag --export-path ./claude_exports/ --max-items 1000
```
</details>
<details>
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
Once your Claude conversations are indexed, you can search with queries like:
- "What did I ask Claude about Python programming?"
- "Show me conversations about machine learning algorithms"
- "Find discussions about software architecture patterns"
- "What debugging advice did Claude give me?"
- "Search for conversations about data structures"
- "Find Claude's recommendations for learning resources"
</details>
### 💬 iMessage History: Your Personal Conversation Archive!
Transform your iMessage conversations into a searchable knowledge base! Search through all your text messages, group chats, and conversations with friends, family, and colleagues.
```bash
python -m apps.imessage_rag --query "What did we discuss about the weekend plans?"
```
**Unlock your message history.** Never lose track of important conversations, shared links, or memorable moments from your iMessage history.
<details>
<summary><strong>📋 Click to expand: How to Access iMessage Data</strong></summary>
**iMessage data location:**
iMessage conversations are stored in a SQLite database on your Mac at:
```
~/Library/Messages/chat.db
```
**Important setup requirements:**
1. **Grant Full Disk Access** to your terminal or IDE:
- Open **System Preferences** → **Security & Privacy** → **Privacy**
- Select **Full Disk Access** from the left sidebar
- Click the **+** button and add your terminal app (Terminal, iTerm2) or IDE (VS Code, etc.)
- Restart your terminal/IDE after granting access
2. **Alternative: Use a backup database**
- If you have Time Machine backups or manual copies of the database
- Use `--db-path` to specify a custom location
**Supported formats:**
- Direct access to `~/Library/Messages/chat.db` (default)
- Custom database path with `--db-path`
- Works with backup copies of the database
</details>
<details>
<summary><strong>📋 Click to expand: iMessage-Specific Arguments</strong></summary>
#### Parameters
```bash
--db-path PATH # Path to chat.db file (default: ~/Library/Messages/chat.db)
--concatenate-conversations # Group messages by conversation (default: True)
--no-concatenate-conversations # Process each message individually
--chunk-size N # Text chunk size (default: 1000)
--chunk-overlap N # Overlap between chunks (default: 200)
```
#### Example Commands
```bash
# Basic usage (requires Full Disk Access)
python -m apps.imessage_rag
# Search with specific query
python -m apps.imessage_rag --query "family dinner plans"
# Use custom database path
python -m apps.imessage_rag --db-path /path/to/backup/chat.db
# Process individual messages instead of conversations
python -m apps.imessage_rag --no-concatenate-conversations
# Limit processing for testing
python -m apps.imessage_rag --max-items 100 --query "weekend"
```
</details>
<details>
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
Once your iMessage conversations are indexed, you can search with queries like:
- "What did we discuss about vacation plans?"
- "Find messages about restaurant recommendations"
- "Show me conversations with John about the project"
- "Search for shared links about technology"
- "Find group chat discussions about weekend events"
- "What did mom say about the family gathering?"
</details>
### MCP Integration: RAG on Live Data from Any Platform
Connect to live data sources through the Model Context Protocol (MCP). LEANN now supports real-time RAG on platforms like Slack, Twitter, and more through standardized MCP servers.
**Key Benefits:**
- **Live Data Access**: Fetch real-time data without manual exports
- **Standardized Protocol**: Use any MCP-compatible server
- **Easy Extension**: Add new platforms with minimal code
- **Secure Access**: MCP servers handle authentication
#### 💬 Slack Messages: Search Your Team Conversations
Transform your Slack workspace into a searchable knowledge base! Find discussions, decisions, and shared knowledge across all your channels.
```bash
# Test MCP server connection
python -m apps.slack_rag --mcp-server "slack-mcp-server" --test-connection
# Index and search Slack messages
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "my-team" \
--channels general dev-team random \
--query "What did we decide about the product launch?"
```
**📖 Comprehensive Setup Guide**: For detailed setup instructions, troubleshooting common issues (like "users cache is not ready yet"), and advanced configuration options, see our [**Slack Setup Guide**](docs/slack-setup-guide.md).
**Quick Setup:**
1. Install a Slack MCP server (e.g., `npm install -g slack-mcp-server`)
2. Create a Slack App and get API credentials (see detailed guide above)
3. Set environment variables:
```bash
export SLACK_BOT_TOKEN="xoxb-your-bot-token"
export SLACK_APP_TOKEN="xapp-your-app-token" # Optional
```
4. Test connection with `--test-connection` flag
**Arguments:**
- `--mcp-server`: Command to start the Slack MCP server
- `--workspace-name`: Slack workspace name for organization
- `--channels`: Specific channels to index (optional)
- `--concatenate-conversations`: Group messages by channel (default: true)
- `--max-messages-per-channel`: Limit messages per channel (default: 100)
- `--max-retries`: Maximum retries for cache sync issues (default: 5)
- `--retry-delay`: Initial delay between retries in seconds (default: 2.0)
#### 🐦 Twitter Bookmarks: Your Personal Tweet Library
Search through your Twitter bookmarks! Find that perfect article, thread, or insight you saved for later.
```bash
# Test MCP server connection
python -m apps.twitter_rag --mcp-server "twitter-mcp-server" --test-connection
# Index and search Twitter bookmarks
python -m apps.twitter_rag \
--mcp-server "twitter-mcp-server" \
--max-bookmarks 1000 \
--query "What AI articles did I bookmark about machine learning?"
```
**Setup Requirements:**
1. Install a Twitter MCP server (e.g., `npm install -g twitter-mcp-server`)
2. Get Twitter API credentials:
- Apply for a Twitter Developer Account at [developer.twitter.com](https://developer.twitter.com)
- Create a new app in the Twitter Developer Portal
- Generate API keys and access tokens with "Read" permissions
- For bookmarks access, you may need Twitter API v2 with appropriate scopes
```bash
export TWITTER_API_KEY="your-api-key"
export TWITTER_API_SECRET="your-api-secret"
export TWITTER_ACCESS_TOKEN="your-access-token"
export TWITTER_ACCESS_TOKEN_SECRET="your-access-token-secret"
```
3. Test connection with `--test-connection` flag
**Arguments:**
- `--mcp-server`: Command to start the Twitter MCP server
- `--username`: Filter bookmarks by username (optional)
- `--max-bookmarks`: Maximum bookmarks to fetch (default: 1000)
- `--no-tweet-content`: Exclude tweet content, only metadata
- `--no-metadata`: Exclude engagement metadata
</details>
<details>
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
**Slack Queries:**
- "What did the team discuss about the project deadline?"
- "Find messages about the new feature launch"
- "Show me conversations about budget planning"
- "What decisions were made in the dev-team channel?"
**Twitter Queries:**
- "What AI articles did I bookmark last month?"
- "Find tweets about machine learning techniques"
- "Show me bookmarked threads about startup advice"
- "What Python tutorials did I save?"
</details>
<summary><strong>🔧 Using MCP with CLI Commands</strong></summary>
**Want to use MCP data with regular LEANN CLI?** You can combine MCP apps with CLI commands:
```bash
# Step 1: Use MCP app to fetch and index data
python -m apps.slack_rag --mcp-server "slack-mcp-server" --workspace-name "my-team"
# Step 2: The data is now indexed and available via CLI
leann search slack_messages "project deadline"
leann ask slack_messages "What decisions were made about the product launch?"
# Same for Twitter bookmarks
python -m apps.twitter_rag --mcp-server "twitter-mcp-server"
leann search twitter_bookmarks "machine learning articles"
```
**MCP vs Manual Export:**
- **MCP**: Live data, automatic updates, requires server setup
- **Manual Export**: One-time setup, works offline, requires manual data export
</details>
<details>
<summary><strong>🔧 Adding New MCP Platforms</strong></summary>
Want to add support for other platforms? LEANN's MCP integration is designed for easy extension:
1. **Find or create an MCP server** for your platform
2. **Create a reader class** following the pattern in `apps/slack_data/slack_mcp_reader.py`
3. **Create a RAG application** following the pattern in `apps/slack_rag.py`
4. **Test and contribute** back to the community!
**Popular MCP servers to explore:**
- GitHub repositories and issues
- Discord messages
- Notion pages
- Google Drive documents
- And many more in the MCP ecosystem!
</details>
### 🚀 Claude Code Integration: Transform Your Development Workflow!
<details>
<summary><strong>NEW!! ASTAware Code Chunking</strong></summary>
<summary><strong>ASTAware Code Chunking</strong></summary>
LEANN features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript, improving code understanding compared to text-based chunking.
@@ -573,7 +954,7 @@ Try our fully agentic pipeline with auto query rewriting, semantic search planni
**🔥 Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
## 🖥️ Command Line Interface
## Command Line Interface
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
@@ -815,7 +1196,7 @@ MIT License - see [LICENSE](LICENSE) for details.
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan)
Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan), [Aakash Suresh](https://github.com/ASuresh0524)
We welcome more contributors! Feel free to open issues or submit PRs.
@@ -832,3 +1213,7 @@ This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.ed
<p align="center">
Made with ❤️ by the Leann team
</p>
## 🤖 Explore LEANN with AI
LEANN is indexed on [DeepWiki](https://deepwiki.com/yichuan-w/LEANN), so you can ask questions to LLMs using Deep Research to explore the codebase and get help to add new features.

View File

@@ -10,8 +10,39 @@ from typing import Any
import dotenv
from leann.api import LeannBuilder, LeannChat
# Optional import: older PyPI builds may not include interactive_utils
try:
from leann.interactive_utils import create_rag_session
except ImportError:
def create_rag_session(app_name: str, data_description: str):
class _SimpleSession:
def run_interactive_loop(self, handler):
print(f"Interactive session for {app_name}: {data_description}")
print("Interactive mode not available in this build")
return _SimpleSession()
from leann.registry import register_project_directory
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
# Optional import: older PyPI builds may not include settings
try:
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
except ImportError:
# Minimal fallbacks if settings helpers are unavailable
import os
def resolve_ollama_host(value: str | None) -> str | None:
return value or os.getenv("LEANN_OLLAMA_HOST") or os.getenv("OLLAMA_HOST")
def resolve_openai_api_key(value: str | None) -> str | None:
return value or os.getenv("OPENAI_API_KEY")
def resolve_openai_base_url(value: str | None) -> str | None:
return value or os.getenv("OPENAI_BASE_URL")
dotenv.load_dotenv()
@@ -149,14 +180,14 @@ class BaseRAGExample(ABC):
ast_group.add_argument(
"--ast-chunk-size",
type=int,
default=512,
help="Maximum characters per AST chunk (default: 512)",
default=300,
help="Maximum CHARACTERS per AST chunk (default: 300). Final chunks may be larger due to overlap. For 512 token models: recommended 300 chars",
)
ast_group.add_argument(
"--ast-chunk-overlap",
type=int,
default=64,
help="Overlap between AST chunks (default: 64)",
help="Overlap between AST chunks in CHARACTERS (default: 64). Added to chunk size, not included in it",
)
ast_group.add_argument(
"--code-file-extensions",
@@ -307,37 +338,26 @@ class BaseRAGExample(ABC):
complexity=args.search_complexity,
)
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
print("Type 'quit' or 'exit' to stop.\n")
# Create interactive session
session = create_rag_session(
app_name=self.name.lower().replace(" ", "_"), data_description=self.name
)
while True:
try:
query = input("You: ").strip()
if query.lower() in ["quit", "exit", "q"]:
print("Goodbye!")
break
def handle_query(query: str):
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if hasattr(args, "thinking_budget") and args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
if not query:
continue
response = chat.ask(
query,
top_k=args.top_k,
complexity=args.search_complexity,
llm_kwargs=llm_kwargs,
)
print(f"\nAssistant: {response}\n")
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if hasattr(args, "thinking_budget") and args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask(
query,
top_k=args.top_k,
complexity=args.search_complexity,
llm_kwargs=llm_kwargs,
)
print(f"\nAssistant: {response}\n")
except KeyboardInterrupt:
print("\nGoodbye!")
break
except Exception as e:
print(f"Error: {e}")
session.run_interactive_loop(handle_query)
async def run_single_query(self, args, index_path: str, query: str):
"""Run a single query against the index."""

View File

View File

@@ -0,0 +1,413 @@
"""
ChatGPT export data reader.
Reads and processes ChatGPT export data from chat.html files.
"""
import re
from pathlib import Path
from typing import Any
from zipfile import ZipFile
from bs4 import BeautifulSoup
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
class ChatGPTReader(BaseReader):
"""
ChatGPT export data reader.
Reads ChatGPT conversation data from exported chat.html files or zip archives.
Processes conversations into structured documents with metadata.
"""
def __init__(self, concatenate_conversations: bool = True) -> None:
"""
Initialize.
Args:
concatenate_conversations: Whether to concatenate messages within conversations for better context
"""
try:
from bs4 import BeautifulSoup # noqa
except ImportError:
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
self.concatenate_conversations = concatenate_conversations
def _extract_html_from_zip(self, zip_path: Path) -> str | None:
"""
Extract chat.html from ChatGPT export zip file.
Args:
zip_path: Path to the ChatGPT export zip file
Returns:
HTML content as string, or None if not found
"""
try:
with ZipFile(zip_path, "r") as zip_file:
# Look for chat.html or conversations.html
html_files = [
f
for f in zip_file.namelist()
if f.endswith(".html") and ("chat" in f.lower() or "conversation" in f.lower())
]
if not html_files:
print(f"No HTML chat file found in {zip_path}")
return None
# Use the first HTML file found
html_file = html_files[0]
print(f"Found HTML file: {html_file}")
with zip_file.open(html_file) as f:
return f.read().decode("utf-8", errors="ignore")
except Exception as e:
print(f"Error extracting HTML from zip {zip_path}: {e}")
return None
def _parse_chatgpt_html(self, html_content: str) -> list[dict]:
"""
Parse ChatGPT HTML export to extract conversations.
Args:
html_content: HTML content from ChatGPT export
Returns:
List of conversation dictionaries
"""
soup = BeautifulSoup(html_content, "html.parser")
conversations = []
# Try different possible structures for ChatGPT exports
# Structure 1: Look for conversation containers
conversation_containers = soup.find_all(
["div", "section"], class_=re.compile(r"conversation|chat", re.I)
)
if not conversation_containers:
# Structure 2: Look for message containers directly
conversation_containers = [soup] # Use the entire document as one conversation
for container in conversation_containers:
conversation = self._extract_conversation_from_container(container)
if conversation and conversation.get("messages"):
conversations.append(conversation)
# If no structured conversations found, try to extract all text as one conversation
if not conversations:
all_text = soup.get_text(separator="\n", strip=True)
if all_text:
conversations.append(
{
"title": "ChatGPT Conversation",
"messages": [{"role": "mixed", "content": all_text, "timestamp": None}],
"timestamp": None,
}
)
return conversations
def _extract_conversation_from_container(self, container) -> dict | None:
"""
Extract conversation data from a container element.
Args:
container: BeautifulSoup element containing conversation
Returns:
Dictionary with conversation data or None
"""
messages = []
# Look for message elements with various possible structures
message_selectors = ['[class*="message"]', '[class*="chat"]', "[data-message]", "p", "div"]
for selector in message_selectors:
message_elements = container.select(selector)
if message_elements:
break
else:
message_elements = []
# If no structured messages found, treat the entire container as one message
if not message_elements:
text_content = container.get_text(separator="\n", strip=True)
if text_content:
messages.append({"role": "mixed", "content": text_content, "timestamp": None})
else:
for element in message_elements:
message = self._extract_message_from_element(element)
if message:
messages.append(message)
if not messages:
return None
# Try to extract conversation title
title_element = container.find(["h1", "h2", "h3", "title"])
title = title_element.get_text(strip=True) if title_element else "ChatGPT Conversation"
# Try to extract timestamp from various possible locations
timestamp = self._extract_timestamp_from_container(container)
return {"title": title, "messages": messages, "timestamp": timestamp}
def _extract_message_from_element(self, element) -> dict | None:
"""
Extract message data from an element.
Args:
element: BeautifulSoup element containing message
Returns:
Dictionary with message data or None
"""
text_content = element.get_text(separator=" ", strip=True)
# Skip empty or very short messages
if not text_content or len(text_content.strip()) < 3:
return None
# Try to determine role (user/assistant) from class names or content
role = "mixed" # Default role
class_names = " ".join(element.get("class", [])).lower()
if "user" in class_names or "human" in class_names:
role = "user"
elif "assistant" in class_names or "ai" in class_names or "gpt" in class_names:
role = "assistant"
elif text_content.lower().startswith(("you:", "user:", "me:")):
role = "user"
text_content = re.sub(r"^(you|user|me):\s*", "", text_content, flags=re.IGNORECASE)
elif text_content.lower().startswith(("chatgpt:", "assistant:", "ai:")):
role = "assistant"
text_content = re.sub(
r"^(chatgpt|assistant|ai):\s*", "", text_content, flags=re.IGNORECASE
)
# Try to extract timestamp
timestamp = self._extract_timestamp_from_element(element)
return {"role": role, "content": text_content, "timestamp": timestamp}
def _extract_timestamp_from_element(self, element) -> str | None:
"""Extract timestamp from element."""
# Look for timestamp in various attributes and child elements
timestamp_attrs = ["data-timestamp", "timestamp", "datetime"]
for attr in timestamp_attrs:
if element.get(attr):
return element.get(attr)
# Look for time elements
time_element = element.find("time")
if time_element:
return time_element.get("datetime") or time_element.get_text(strip=True)
# Look for date-like text patterns
text = element.get_text()
date_patterns = [r"\d{4}-\d{2}-\d{2}", r"\d{1,2}/\d{1,2}/\d{4}", r"\w+ \d{1,2}, \d{4}"]
for pattern in date_patterns:
match = re.search(pattern, text)
if match:
return match.group()
return None
def _extract_timestamp_from_container(self, container) -> str | None:
"""Extract timestamp from conversation container."""
return self._extract_timestamp_from_element(container)
def _create_concatenated_content(self, conversation: dict) -> str:
"""
Create concatenated content from conversation messages.
Args:
conversation: Dictionary containing conversation data
Returns:
Formatted concatenated content
"""
title = conversation.get("title", "ChatGPT Conversation")
messages = conversation.get("messages", [])
timestamp = conversation.get("timestamp", "Unknown")
# Build message content
message_parts = []
for message in messages:
role = message.get("role", "mixed")
content = message.get("content", "")
msg_timestamp = message.get("timestamp", "")
if role == "user":
prefix = "[You]"
elif role == "assistant":
prefix = "[ChatGPT]"
else:
prefix = "[Message]"
# Add timestamp if available
if msg_timestamp:
prefix += f" ({msg_timestamp})"
message_parts.append(f"{prefix}: {content}")
concatenated_text = "\n\n".join(message_parts)
# Create final document content
doc_content = f"""Conversation: {title}
Date: {timestamp}
Messages ({len(messages)} messages):
{concatenated_text}
"""
return doc_content
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
"""
Load ChatGPT export data.
Args:
input_dir: Directory containing ChatGPT export files or path to specific file
**load_kwargs:
max_count (int): Maximum number of conversations to process
chatgpt_export_path (str): Specific path to ChatGPT export file/directory
include_metadata (bool): Whether to include metadata in documents
"""
docs: list[Document] = []
max_count = load_kwargs.get("max_count", -1)
chatgpt_export_path = load_kwargs.get("chatgpt_export_path", input_dir)
include_metadata = load_kwargs.get("include_metadata", True)
if not chatgpt_export_path:
print("No ChatGPT export path provided")
return docs
export_path = Path(chatgpt_export_path)
if not export_path.exists():
print(f"ChatGPT export path not found: {export_path}")
return docs
html_content = None
# Handle different input types
if export_path.is_file():
if export_path.suffix.lower() == ".zip":
# Extract HTML from zip file
html_content = self._extract_html_from_zip(export_path)
elif export_path.suffix.lower() == ".html":
# Read HTML file directly
try:
with open(export_path, encoding="utf-8", errors="ignore") as f:
html_content = f.read()
except Exception as e:
print(f"Error reading HTML file {export_path}: {e}")
return docs
else:
print(f"Unsupported file type: {export_path.suffix}")
return docs
elif export_path.is_dir():
# Look for HTML files in directory
html_files = list(export_path.glob("*.html"))
zip_files = list(export_path.glob("*.zip"))
if html_files:
# Use first HTML file found
html_file = html_files[0]
print(f"Found HTML file: {html_file}")
try:
with open(html_file, encoding="utf-8", errors="ignore") as f:
html_content = f.read()
except Exception as e:
print(f"Error reading HTML file {html_file}: {e}")
return docs
elif zip_files:
# Use first zip file found
zip_file = zip_files[0]
print(f"Found zip file: {zip_file}")
html_content = self._extract_html_from_zip(zip_file)
else:
print(f"No HTML or zip files found in {export_path}")
return docs
if not html_content:
print("No HTML content found to process")
return docs
# Parse conversations from HTML
print("Parsing ChatGPT conversations from HTML...")
conversations = self._parse_chatgpt_html(html_content)
if not conversations:
print("No conversations found in HTML content")
return docs
print(f"Found {len(conversations)} conversations")
# Process conversations into documents
count = 0
for conversation in conversations:
if max_count > 0 and count >= max_count:
break
if self.concatenate_conversations:
# Create one document per conversation with concatenated messages
doc_content = self._create_concatenated_content(conversation)
metadata = {}
if include_metadata:
metadata = {
"title": conversation.get("title", "ChatGPT Conversation"),
"timestamp": conversation.get("timestamp", "Unknown"),
"message_count": len(conversation.get("messages", [])),
"source": "ChatGPT Export",
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
else:
# Create separate documents for each message
for message in conversation.get("messages", []):
if max_count > 0 and count >= max_count:
break
role = message.get("role", "mixed")
content = message.get("content", "")
msg_timestamp = message.get("timestamp", "")
if not content.strip():
continue
# Create document content with context
doc_content = f"""Conversation: {conversation.get("title", "ChatGPT Conversation")}
Role: {role}
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
Message: {content}
"""
metadata = {}
if include_metadata:
metadata = {
"conversation_title": conversation.get("title", "ChatGPT Conversation"),
"role": role,
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
"source": "ChatGPT Export",
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
print(f"Created {len(docs)} documents from ChatGPT export")
return docs

186
apps/chatgpt_rag.py Normal file
View File

@@ -0,0 +1,186 @@
"""
ChatGPT RAG example using the unified interface.
Supports ChatGPT export data from chat.html files.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from .chatgpt_data.chatgpt_reader import ChatGPTReader
class ChatGPTRAG(BaseRAGExample):
"""RAG example for ChatGPT conversation data."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.max_items_default = -1 # Process all conversations by default
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="ChatGPT",
description="Process and query ChatGPT conversation exports with LEANN",
default_index_name="chatgpt_conversations_index",
)
def _add_specific_arguments(self, parser):
"""Add ChatGPT-specific arguments."""
chatgpt_group = parser.add_argument_group("ChatGPT Parameters")
chatgpt_group.add_argument(
"--export-path",
type=str,
default="./chatgpt_export",
help="Path to ChatGPT export file (.zip or .html) or directory containing exports (default: ./chatgpt_export)",
)
chatgpt_group.add_argument(
"--concatenate-conversations",
action="store_true",
default=True,
help="Concatenate messages within conversations for better context (default: True)",
)
chatgpt_group.add_argument(
"--separate-messages",
action="store_true",
help="Process each message as a separate document (overrides --concatenate-conversations)",
)
chatgpt_group.add_argument(
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
)
chatgpt_group.add_argument(
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
)
def _find_chatgpt_exports(self, export_path: Path) -> list[Path]:
"""
Find ChatGPT export files in the given path.
Args:
export_path: Path to search for exports
Returns:
List of paths to ChatGPT export files
"""
export_files = []
if export_path.is_file():
if export_path.suffix.lower() in [".zip", ".html"]:
export_files.append(export_path)
elif export_path.is_dir():
# Look for zip and html files
export_files.extend(export_path.glob("*.zip"))
export_files.extend(export_path.glob("*.html"))
return export_files
async def load_data(self, args) -> list[str]:
"""Load ChatGPT export data and convert to text chunks."""
export_path = Path(args.export_path)
if not export_path.exists():
print(f"ChatGPT export path not found: {export_path}")
print(
"Please ensure you have exported your ChatGPT data and placed it in the correct location."
)
print("\nTo export your ChatGPT data:")
print("1. Sign in to ChatGPT")
print("2. Click on your profile icon → Settings → Data Controls")
print("3. Click 'Export' under Export Data")
print("4. Download the zip file from the email link")
print("5. Extract or place the file/directory at the specified path")
return []
# Find export files
export_files = self._find_chatgpt_exports(export_path)
if not export_files:
print(f"No ChatGPT export files (.zip or .html) found in: {export_path}")
return []
print(f"Found {len(export_files)} ChatGPT export files")
# Create reader with appropriate settings
concatenate = args.concatenate_conversations and not args.separate_messages
reader = ChatGPTReader(concatenate_conversations=concatenate)
# Process each export file
all_documents = []
total_processed = 0
for i, export_file in enumerate(export_files):
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
try:
# Apply max_items limit per file
max_per_file = -1
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_file = remaining
# Load conversations
documents = reader.load_data(
chatgpt_export_path=str(export_file),
max_count=max_per_file,
include_metadata=True,
)
if documents:
all_documents.extend(documents)
total_processed += len(documents)
print(f"Processed {len(documents)} conversations from this file")
else:
print(f"No conversations loaded from {export_file}")
except Exception as e:
print(f"Error processing {export_file}: {e}")
continue
if not all_documents:
print("No conversations found to process!")
print("\nTroubleshooting:")
print("- Ensure the export file is a valid ChatGPT export")
print("- Check that the HTML file contains conversation data")
print("- Try extracting the zip file and pointing to the HTML file directly")
return []
print(f"\nTotal conversations processed: {len(all_documents)}")
print("Now starting to split into text chunks... this may take some time")
# Convert to text chunks
all_texts = create_text_chunks(
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for ChatGPT RAG
print("\n🤖 ChatGPT RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What did I ask about Python programming?'")
print("- 'Show me conversations about machine learning'")
print("- 'Find discussions about travel planning'")
print("- 'What advice did ChatGPT give me about career development?'")
print("- 'Search for conversations about cooking recipes'")
print("\nTo get started:")
print("1. Export your ChatGPT data from Settings → Data Controls → Export")
print("2. Place the downloaded zip file or extracted HTML in ./chatgpt_export/")
print("3. Run this script to build your personal ChatGPT knowledge base!")
print("\nOr run without --query for interactive mode\n")
rag = ChatGPTRAG()
asyncio.run(rag.run())

View File

@@ -12,6 +12,7 @@ from pathlib import Path
try:
from leann.chunking_utils import (
CODE_EXTENSIONS,
_traditional_chunks_as_dicts,
create_ast_chunks,
create_text_chunks,
create_traditional_chunks,
@@ -25,6 +26,7 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
sys.path.insert(0, str(leann_src))
from leann.chunking_utils import (
CODE_EXTENSIONS,
_traditional_chunks_as_dicts,
create_ast_chunks,
create_text_chunks,
create_traditional_chunks,
@@ -36,6 +38,7 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
__all__ = [
"CODE_EXTENSIONS",
"_traditional_chunks_as_dicts",
"create_ast_chunks",
"create_text_chunks",
"create_traditional_chunks",

View File

View File

@@ -0,0 +1,420 @@
"""
Claude export data reader.
Reads and processes Claude conversation data from exported JSON files.
"""
import json
from pathlib import Path
from typing import Any
from zipfile import ZipFile
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
class ClaudeReader(BaseReader):
"""
Claude export data reader.
Reads Claude conversation data from exported JSON files or zip archives.
Processes conversations into structured documents with metadata.
"""
def __init__(self, concatenate_conversations: bool = True) -> None:
"""
Initialize.
Args:
concatenate_conversations: Whether to concatenate messages within conversations for better context
"""
self.concatenate_conversations = concatenate_conversations
def _extract_json_from_zip(self, zip_path: Path) -> list[str]:
"""
Extract JSON files from Claude export zip file.
Args:
zip_path: Path to the Claude export zip file
Returns:
List of JSON content strings, or empty list if not found
"""
json_contents = []
try:
with ZipFile(zip_path, "r") as zip_file:
# Look for JSON files
json_files = [f for f in zip_file.namelist() if f.endswith(".json")]
if not json_files:
print(f"No JSON files found in {zip_path}")
return []
print(f"Found {len(json_files)} JSON files in archive")
for json_file in json_files:
with zip_file.open(json_file) as f:
content = f.read().decode("utf-8", errors="ignore")
json_contents.append(content)
except Exception as e:
print(f"Error extracting JSON from zip {zip_path}: {e}")
return json_contents
def _parse_claude_json(self, json_content: str) -> list[dict]:
"""
Parse Claude JSON export to extract conversations.
Args:
json_content: JSON content from Claude export
Returns:
List of conversation dictionaries
"""
try:
data = json.loads(json_content)
except json.JSONDecodeError as e:
print(f"Error parsing JSON: {e}")
return []
conversations = []
# Handle different possible JSON structures
if isinstance(data, list):
# If data is a list of conversations
for item in data:
conversation = self._extract_conversation_from_json(item)
if conversation:
conversations.append(conversation)
elif isinstance(data, dict):
# Check for common structures
if "conversations" in data:
# Structure: {"conversations": [...]}
for item in data["conversations"]:
conversation = self._extract_conversation_from_json(item)
if conversation:
conversations.append(conversation)
elif "messages" in data:
# Single conversation with messages
conversation = self._extract_conversation_from_json(data)
if conversation:
conversations.append(conversation)
else:
# Try to treat the whole object as a conversation
conversation = self._extract_conversation_from_json(data)
if conversation:
conversations.append(conversation)
return conversations
def _extract_conversation_from_json(self, conv_data: dict) -> dict | None:
"""
Extract conversation data from a JSON object.
Args:
conv_data: Dictionary containing conversation data
Returns:
Dictionary with conversation data or None
"""
if not isinstance(conv_data, dict):
return None
messages = []
# Look for messages in various possible structures
message_sources = []
if "messages" in conv_data:
message_sources = conv_data["messages"]
elif "chat" in conv_data:
message_sources = conv_data["chat"]
elif "conversation" in conv_data:
message_sources = conv_data["conversation"]
else:
# If no clear message structure, try to extract from the object itself
if "content" in conv_data and "role" in conv_data:
message_sources = [conv_data]
for msg_data in message_sources:
message = self._extract_message_from_json(msg_data)
if message:
messages.append(message)
if not messages:
return None
# Extract conversation metadata
title = self._extract_title_from_conversation(conv_data, messages)
timestamp = self._extract_timestamp_from_conversation(conv_data)
return {"title": title, "messages": messages, "timestamp": timestamp}
def _extract_message_from_json(self, msg_data: dict) -> dict | None:
"""
Extract message data from a JSON message object.
Args:
msg_data: Dictionary containing message data
Returns:
Dictionary with message data or None
"""
if not isinstance(msg_data, dict):
return None
# Extract content from various possible fields
content = ""
content_fields = ["content", "text", "message", "body"]
for field in content_fields:
if msg_data.get(field):
content = str(msg_data[field])
break
if not content or len(content.strip()) < 3:
return None
# Extract role (user/assistant/human/ai/claude)
role = "mixed" # Default role
role_fields = ["role", "sender", "from", "author", "type"]
for field in role_fields:
if msg_data.get(field):
role_value = str(msg_data[field]).lower()
if role_value in ["user", "human", "person"]:
role = "user"
elif role_value in ["assistant", "ai", "claude", "bot"]:
role = "assistant"
break
# Extract timestamp
timestamp = self._extract_timestamp_from_message(msg_data)
return {"role": role, "content": content, "timestamp": timestamp}
def _extract_timestamp_from_message(self, msg_data: dict) -> str | None:
"""Extract timestamp from message data."""
timestamp_fields = ["timestamp", "created_at", "date", "time"]
for field in timestamp_fields:
if msg_data.get(field):
return str(msg_data[field])
return None
def _extract_timestamp_from_conversation(self, conv_data: dict) -> str | None:
"""Extract timestamp from conversation data."""
timestamp_fields = ["timestamp", "created_at", "date", "updated_at", "last_updated"]
for field in timestamp_fields:
if conv_data.get(field):
return str(conv_data[field])
return None
def _extract_title_from_conversation(self, conv_data: dict, messages: list) -> str:
"""Extract or generate title for conversation."""
# Try to find explicit title
title_fields = ["title", "name", "subject", "topic"]
for field in title_fields:
if conv_data.get(field):
return str(conv_data[field])
# Generate title from first user message
for message in messages:
if message.get("role") == "user":
content = message.get("content", "")
if content:
# Use first 50 characters as title
title = content[:50].strip()
if len(content) > 50:
title += "..."
return title
return "Claude Conversation"
def _create_concatenated_content(self, conversation: dict) -> str:
"""
Create concatenated content from conversation messages.
Args:
conversation: Dictionary containing conversation data
Returns:
Formatted concatenated content
"""
title = conversation.get("title", "Claude Conversation")
messages = conversation.get("messages", [])
timestamp = conversation.get("timestamp", "Unknown")
# Build message content
message_parts = []
for message in messages:
role = message.get("role", "mixed")
content = message.get("content", "")
msg_timestamp = message.get("timestamp", "")
if role == "user":
prefix = "[You]"
elif role == "assistant":
prefix = "[Claude]"
else:
prefix = "[Message]"
# Add timestamp if available
if msg_timestamp:
prefix += f" ({msg_timestamp})"
message_parts.append(f"{prefix}: {content}")
concatenated_text = "\n\n".join(message_parts)
# Create final document content
doc_content = f"""Conversation: {title}
Date: {timestamp}
Messages ({len(messages)} messages):
{concatenated_text}
"""
return doc_content
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
"""
Load Claude export data.
Args:
input_dir: Directory containing Claude export files or path to specific file
**load_kwargs:
max_count (int): Maximum number of conversations to process
claude_export_path (str): Specific path to Claude export file/directory
include_metadata (bool): Whether to include metadata in documents
"""
docs: list[Document] = []
max_count = load_kwargs.get("max_count", -1)
claude_export_path = load_kwargs.get("claude_export_path", input_dir)
include_metadata = load_kwargs.get("include_metadata", True)
if not claude_export_path:
print("No Claude export path provided")
return docs
export_path = Path(claude_export_path)
if not export_path.exists():
print(f"Claude export path not found: {export_path}")
return docs
json_contents = []
# Handle different input types
if export_path.is_file():
if export_path.suffix.lower() == ".zip":
# Extract JSON from zip file
json_contents = self._extract_json_from_zip(export_path)
elif export_path.suffix.lower() == ".json":
# Read JSON file directly
try:
with open(export_path, encoding="utf-8", errors="ignore") as f:
json_contents.append(f.read())
except Exception as e:
print(f"Error reading JSON file {export_path}: {e}")
return docs
else:
print(f"Unsupported file type: {export_path.suffix}")
return docs
elif export_path.is_dir():
# Look for JSON files in directory
json_files = list(export_path.glob("*.json"))
zip_files = list(export_path.glob("*.zip"))
if json_files:
print(f"Found {len(json_files)} JSON files in directory")
for json_file in json_files:
try:
with open(json_file, encoding="utf-8", errors="ignore") as f:
json_contents.append(f.read())
except Exception as e:
print(f"Error reading JSON file {json_file}: {e}")
continue
if zip_files:
print(f"Found {len(zip_files)} ZIP files in directory")
for zip_file in zip_files:
zip_contents = self._extract_json_from_zip(zip_file)
json_contents.extend(zip_contents)
if not json_files and not zip_files:
print(f"No JSON or ZIP files found in {export_path}")
return docs
if not json_contents:
print("No JSON content found to process")
return docs
# Parse conversations from JSON content
print("Parsing Claude conversations from JSON...")
all_conversations = []
for json_content in json_contents:
conversations = self._parse_claude_json(json_content)
all_conversations.extend(conversations)
if not all_conversations:
print("No conversations found in JSON content")
return docs
print(f"Found {len(all_conversations)} conversations")
# Process conversations into documents
count = 0
for conversation in all_conversations:
if max_count > 0 and count >= max_count:
break
if self.concatenate_conversations:
# Create one document per conversation with concatenated messages
doc_content = self._create_concatenated_content(conversation)
metadata = {}
if include_metadata:
metadata = {
"title": conversation.get("title", "Claude Conversation"),
"timestamp": conversation.get("timestamp", "Unknown"),
"message_count": len(conversation.get("messages", [])),
"source": "Claude Export",
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
else:
# Create separate documents for each message
for message in conversation.get("messages", []):
if max_count > 0 and count >= max_count:
break
role = message.get("role", "mixed")
content = message.get("content", "")
msg_timestamp = message.get("timestamp", "")
if not content.strip():
continue
# Create document content with context
doc_content = f"""Conversation: {conversation.get("title", "Claude Conversation")}
Role: {role}
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
Message: {content}
"""
metadata = {}
if include_metadata:
metadata = {
"conversation_title": conversation.get("title", "Claude Conversation"),
"role": role,
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
"source": "Claude Export",
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
print(f"Created {len(docs)} documents from Claude export")
return docs

189
apps/claude_rag.py Normal file
View File

@@ -0,0 +1,189 @@
"""
Claude RAG example using the unified interface.
Supports Claude export data from JSON files.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from .claude_data.claude_reader import ClaudeReader
class ClaudeRAG(BaseRAGExample):
"""RAG example for Claude conversation data."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.max_items_default = -1 # Process all conversations by default
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="Claude",
description="Process and query Claude conversation exports with LEANN",
default_index_name="claude_conversations_index",
)
def _add_specific_arguments(self, parser):
"""Add Claude-specific arguments."""
claude_group = parser.add_argument_group("Claude Parameters")
claude_group.add_argument(
"--export-path",
type=str,
default="./claude_export",
help="Path to Claude export file (.json or .zip) or directory containing exports (default: ./claude_export)",
)
claude_group.add_argument(
"--concatenate-conversations",
action="store_true",
default=True,
help="Concatenate messages within conversations for better context (default: True)",
)
claude_group.add_argument(
"--separate-messages",
action="store_true",
help="Process each message as a separate document (overrides --concatenate-conversations)",
)
claude_group.add_argument(
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
)
claude_group.add_argument(
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
)
def _find_claude_exports(self, export_path: Path) -> list[Path]:
"""
Find Claude export files in the given path.
Args:
export_path: Path to search for exports
Returns:
List of paths to Claude export files
"""
export_files = []
if export_path.is_file():
if export_path.suffix.lower() in [".zip", ".json"]:
export_files.append(export_path)
elif export_path.is_dir():
# Look for zip and json files
export_files.extend(export_path.glob("*.zip"))
export_files.extend(export_path.glob("*.json"))
return export_files
async def load_data(self, args) -> list[str]:
"""Load Claude export data and convert to text chunks."""
export_path = Path(args.export_path)
if not export_path.exists():
print(f"Claude export path not found: {export_path}")
print(
"Please ensure you have exported your Claude data and placed it in the correct location."
)
print("\nTo export your Claude data:")
print("1. Open Claude in your browser")
print("2. Look for export/download options in settings or conversation menu")
print("3. Download the conversation data (usually in JSON format)")
print("4. Place the file/directory at the specified path")
print(
"\nNote: Claude export methods may vary. Check Claude's help documentation for current instructions."
)
return []
# Find export files
export_files = self._find_claude_exports(export_path)
if not export_files:
print(f"No Claude export files (.json or .zip) found in: {export_path}")
return []
print(f"Found {len(export_files)} Claude export files")
# Create reader with appropriate settings
concatenate = args.concatenate_conversations and not args.separate_messages
reader = ClaudeReader(concatenate_conversations=concatenate)
# Process each export file
all_documents = []
total_processed = 0
for i, export_file in enumerate(export_files):
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
try:
# Apply max_items limit per file
max_per_file = -1
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_file = remaining
# Load conversations
documents = reader.load_data(
claude_export_path=str(export_file),
max_count=max_per_file,
include_metadata=True,
)
if documents:
all_documents.extend(documents)
total_processed += len(documents)
print(f"Processed {len(documents)} conversations from this file")
else:
print(f"No conversations loaded from {export_file}")
except Exception as e:
print(f"Error processing {export_file}: {e}")
continue
if not all_documents:
print("No conversations found to process!")
print("\nTroubleshooting:")
print("- Ensure the export file is a valid Claude export")
print("- Check that the JSON file contains conversation data")
print("- Try using a different export format or method")
print("- Check Claude's documentation for current export procedures")
return []
print(f"\nTotal conversations processed: {len(all_documents)}")
print("Now starting to split into text chunks... this may take some time")
# Convert to text chunks
all_texts = create_text_chunks(
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for Claude RAG
print("\n🤖 Claude RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What did I ask Claude about Python programming?'")
print("- 'Show me conversations about machine learning'")
print("- 'Find discussions about code optimization'")
print("- 'What advice did Claude give me about software design?'")
print("- 'Search for conversations about debugging techniques'")
print("\nTo get started:")
print("1. Export your Claude conversation data")
print("2. Place the JSON/ZIP file in ./claude_export/")
print("3. Run this script to build your personal Claude knowledge base!")
print("\nOr run without --query for interactive mode\n")
rag = ClaudeRAG()
asyncio.run(rag.run())

View File

@@ -0,0 +1 @@
"""iMessage data processing module."""

View File

@@ -0,0 +1,342 @@
"""
iMessage data reader.
Reads and processes iMessage conversation data from the macOS Messages database.
"""
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
class IMessageReader(BaseReader):
"""
iMessage data reader.
Reads iMessage conversation data from the macOS Messages database (chat.db).
Processes conversations into structured documents with metadata.
"""
def __init__(self, concatenate_conversations: bool = True) -> None:
"""
Initialize.
Args:
concatenate_conversations: Whether to concatenate messages within conversations for better context
"""
self.concatenate_conversations = concatenate_conversations
def _get_default_chat_db_path(self) -> Path:
"""
Get the default path to the iMessage chat database.
Returns:
Path to the chat.db file
"""
home = Path.home()
return home / "Library" / "Messages" / "chat.db"
def _convert_cocoa_timestamp(self, cocoa_timestamp: int) -> str:
"""
Convert Cocoa timestamp to readable format.
Args:
cocoa_timestamp: Timestamp in Cocoa format (nanoseconds since 2001-01-01)
Returns:
Formatted timestamp string
"""
if cocoa_timestamp == 0:
return "Unknown"
try:
# Cocoa timestamp is nanoseconds since 2001-01-01 00:00:00 UTC
# Convert to seconds and add to Unix epoch
cocoa_epoch = datetime(2001, 1, 1)
unix_timestamp = cocoa_timestamp / 1_000_000_000 # Convert nanoseconds to seconds
message_time = cocoa_epoch.timestamp() + unix_timestamp
return datetime.fromtimestamp(message_time).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
return "Unknown"
def _get_contact_name(self, handle_id: str) -> str:
"""
Get a readable contact name from handle ID.
Args:
handle_id: The handle ID (phone number or email)
Returns:
Formatted contact name
"""
if not handle_id:
return "Unknown"
# Clean up phone numbers and emails for display
if "@" in handle_id:
return handle_id # Email address
elif handle_id.startswith("+"):
return handle_id # International phone number
else:
# Try to format as phone number
digits = "".join(filter(str.isdigit, handle_id))
if len(digits) == 10:
return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
elif len(digits) == 11 and digits[0] == "1":
return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
else:
return handle_id
def _read_messages_from_db(self, db_path: Path) -> list[dict]:
"""
Read messages from the iMessage database.
Args:
db_path: Path to the chat.db file
Returns:
List of message dictionaries
"""
if not db_path.exists():
print(f"iMessage database not found at: {db_path}")
return []
try:
# Connect to the database
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
# Query to get messages with chat and handle information
query = """
SELECT
m.ROWID as message_id,
m.text,
m.date,
m.is_from_me,
m.service,
c.chat_identifier,
c.display_name as chat_display_name,
h.id as handle_id,
c.ROWID as chat_id
FROM message m
LEFT JOIN chat_message_join cmj ON m.ROWID = cmj.message_id
LEFT JOIN chat c ON cmj.chat_id = c.ROWID
LEFT JOIN handle h ON m.handle_id = h.ROWID
WHERE m.text IS NOT NULL AND m.text != ''
ORDER BY c.ROWID, m.date
"""
cursor.execute(query)
rows = cursor.fetchall()
messages = []
for row in rows:
(
message_id,
text,
date,
is_from_me,
service,
chat_identifier,
chat_display_name,
handle_id,
chat_id,
) = row
message = {
"message_id": message_id,
"text": text,
"timestamp": self._convert_cocoa_timestamp(date),
"is_from_me": bool(is_from_me),
"service": service or "iMessage",
"chat_identifier": chat_identifier or "Unknown",
"chat_display_name": chat_display_name or "Unknown Chat",
"handle_id": handle_id or "Unknown",
"contact_name": self._get_contact_name(handle_id or ""),
"chat_id": chat_id,
}
messages.append(message)
conn.close()
print(f"Found {len(messages)} messages in database")
return messages
except sqlite3.Error as e:
print(f"Error reading iMessage database: {e}")
return []
except Exception as e:
print(f"Unexpected error reading iMessage database: {e}")
return []
def _group_messages_by_chat(self, messages: list[dict]) -> dict[int, list[dict]]:
"""
Group messages by chat ID.
Args:
messages: List of message dictionaries
Returns:
Dictionary mapping chat_id to list of messages
"""
chats = {}
for message in messages:
chat_id = message["chat_id"]
if chat_id not in chats:
chats[chat_id] = []
chats[chat_id].append(message)
return chats
def _create_concatenated_content(self, chat_id: int, messages: list[dict]) -> str:
"""
Create concatenated content from chat messages.
Args:
chat_id: The chat ID
messages: List of messages in the chat
Returns:
Concatenated text content
"""
if not messages:
return ""
# Get chat info from first message
first_msg = messages[0]
chat_name = first_msg["chat_display_name"]
chat_identifier = first_msg["chat_identifier"]
# Build message content
message_parts = []
for message in messages:
timestamp = message["timestamp"]
is_from_me = message["is_from_me"]
text = message["text"]
contact_name = message["contact_name"]
if is_from_me:
prefix = "[You]"
else:
prefix = f"[{contact_name}]"
if timestamp != "Unknown":
prefix += f" ({timestamp})"
message_parts.append(f"{prefix}: {text}")
concatenated_text = "\n\n".join(message_parts)
doc_content = f"""Chat: {chat_name}
Identifier: {chat_identifier}
Messages ({len(messages)} messages):
{concatenated_text}
"""
return doc_content
def _create_individual_content(self, message: dict) -> str:
"""
Create content for individual message.
Args:
message: Message dictionary
Returns:
Formatted message content
"""
timestamp = message["timestamp"]
is_from_me = message["is_from_me"]
text = message["text"]
contact_name = message["contact_name"]
chat_name = message["chat_display_name"]
sender = "You" if is_from_me else contact_name
return f"""Message from {sender} in chat "{chat_name}"
Time: {timestamp}
Content: {text}
"""
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
"""
Load iMessage data and return as documents.
Args:
input_dir: Optional path to directory containing chat.db file.
If not provided, uses default macOS location.
**load_kwargs: Additional arguments (unused)
Returns:
List of Document objects containing iMessage data
"""
docs = []
# Determine database path
if input_dir:
db_path = Path(input_dir) / "chat.db"
else:
db_path = self._get_default_chat_db_path()
print(f"Reading iMessage database from: {db_path}")
# Read messages from database
messages = self._read_messages_from_db(db_path)
if not messages:
return docs
if self.concatenate_conversations:
# Group messages by chat and create concatenated documents
chats = self._group_messages_by_chat(messages)
for chat_id, chat_messages in chats.items():
if not chat_messages:
continue
content = self._create_concatenated_content(chat_id, chat_messages)
# Create metadata
first_msg = chat_messages[0]
last_msg = chat_messages[-1]
metadata = {
"source": "iMessage",
"chat_id": chat_id,
"chat_name": first_msg["chat_display_name"],
"chat_identifier": first_msg["chat_identifier"],
"message_count": len(chat_messages),
"first_message_date": first_msg["timestamp"],
"last_message_date": last_msg["timestamp"],
"participants": list(
{msg["contact_name"] for msg in chat_messages if not msg["is_from_me"]}
),
}
doc = Document(text=content, metadata=metadata)
docs.append(doc)
else:
# Create individual documents for each message
for message in messages:
content = self._create_individual_content(message)
metadata = {
"source": "iMessage",
"message_id": message["message_id"],
"chat_id": message["chat_id"],
"chat_name": message["chat_display_name"],
"chat_identifier": message["chat_identifier"],
"timestamp": message["timestamp"],
"is_from_me": message["is_from_me"],
"contact_name": message["contact_name"],
"service": message["service"],
}
doc = Document(text=content, metadata=metadata)
docs.append(doc)
print(f"Created {len(docs)} documents from iMessage data")
return docs

125
apps/imessage_rag.py Normal file
View File

@@ -0,0 +1,125 @@
"""
iMessage RAG Example.
This example demonstrates how to build a RAG system on your iMessage conversation history.
"""
import asyncio
from pathlib import Path
from leann.chunking_utils import create_text_chunks
from apps.base_rag_example import BaseRAGExample
from apps.imessage_data.imessage_reader import IMessageReader
class IMessageRAG(BaseRAGExample):
"""RAG example for iMessage conversation history."""
def __init__(self):
super().__init__(
name="iMessage",
description="RAG on your iMessage conversation history",
default_index_name="imessage_index",
)
def _add_specific_arguments(self, parser):
"""Add iMessage-specific arguments."""
imessage_group = parser.add_argument_group("iMessage Parameters")
imessage_group.add_argument(
"--db-path",
type=str,
default=None,
help="Path to iMessage chat.db file (default: ~/Library/Messages/chat.db)",
)
imessage_group.add_argument(
"--concatenate-conversations",
action="store_true",
default=True,
help="Concatenate messages within conversations for better context (default: True)",
)
imessage_group.add_argument(
"--no-concatenate-conversations",
action="store_true",
help="Process each message individually instead of concatenating by conversation",
)
imessage_group.add_argument(
"--chunk-size",
type=int,
default=1000,
help="Maximum characters per text chunk (default: 1000)",
)
imessage_group.add_argument(
"--chunk-overlap",
type=int,
default=200,
help="Overlap between text chunks (default: 200)",
)
async def load_data(self, args) -> list[str]:
"""Load iMessage history and convert to text chunks."""
print("Loading iMessage conversation history...")
# Determine concatenation setting
concatenate = args.concatenate_conversations and not args.no_concatenate_conversations
# Initialize iMessage reader
reader = IMessageReader(concatenate_conversations=concatenate)
# Load documents
try:
if args.db_path:
# Use custom database path
db_dir = str(Path(args.db_path).parent)
documents = reader.load_data(input_dir=db_dir)
else:
# Use default macOS location
documents = reader.load_data()
except Exception as e:
print(f"Error loading iMessage data: {e}")
print("\nTroubleshooting tips:")
print("1. Make sure you have granted Full Disk Access to your terminal/IDE")
print("2. Check that the iMessage database exists at ~/Library/Messages/chat.db")
print("3. Try specifying a custom path with --db-path if you have a backup")
return []
if not documents:
print("No iMessage conversations found!")
return []
print(f"Loaded {len(documents)} iMessage documents")
# Show some statistics
total_messages = sum(doc.metadata.get("message_count", 1) for doc in documents)
print(f"Total messages: {total_messages}")
if concatenate:
# Show chat statistics
chat_names = [doc.metadata.get("chat_name", "Unknown") for doc in documents]
unique_chats = len(set(chat_names))
print(f"Unique conversations: {unique_chats}")
# Convert to text chunks
all_texts = create_text_chunks(
documents,
chunk_size=args.chunk_size,
chunk_overlap=args.chunk_overlap,
)
# Apply max_items limit if specified
if args.max_items > 0:
all_texts = all_texts[: args.max_items]
print(f"Limited to {len(all_texts)} text chunks (max_items={args.max_items})")
return all_texts
async def main():
"""Main entry point."""
app = IMessageRAG()
await app.run()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,113 @@
## Vision-based PDF Multi-Vector Demos (macOS/MPS)
This folder contains two demos to index PDF pages as images and run multi-vector retrieval with ColPali/ColQwen2, plus optional similarity map visualization and answer generation.
### What youll run
- `multi-vector-leann-paper-example.py`: local PDF → pages → embed → build HNSW index → search.
- `multi-vector-leann-similarity-map.py`: HF dataset (default) or local pages → embed → index → retrieve → similarity maps → optional Qwen-VL answer.
## Prerequisites (macOS)
### 1) Homebrew poppler (for pdf2image)
```bash
brew install poppler
which pdfinfo && pdfinfo -v
```
### 2) Python environment
Use uv (recommended) or pip. Python 3.9+.
Using uv:
```bash
uv pip install \
colpali_engine \
pdf2image \
pillow \
matplotlib qwen_vl_utils \
einops \
seaborn
```
Notes:
- On first run, models download from Hugging Face. Login/config if needed.
- The scripts auto-select device: CUDA > MPS > CPU. Verify MPS:
```bash
python -c "import torch; print('MPS available:', bool(getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available()))"
```
## Run the demos
### A) Local PDF example
Converts a local PDF into page images, embeds them, builds an index, and searches.
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
# If you don't have the sample PDF locally, download it (ignored by Git)
mkdir -p pdfs
curl -L -o pdfs/2004.12832v2.pdf https://arxiv.org/pdf/2004.12832.pdf
ls pdfs/2004.12832v2.pdf
# Ensure output dir exists
mkdir -p pages
python multi-vector-leann-paper-example.py
```
Expected:
- Page images in `pages/`.
- Console prints like `Using device=mps, dtype=...` and retrieved file paths for queries.
To use your own PDF: edit `pdf_path` near the top of the script.
### B) Similarity map + answer demo
Uses HF dataset `weaviate/arXiv-AI-papers-multi-vector` by default; can switch to local pages.
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
python multi-vector-leann-similarity-map.py
```
Artifacts (when enabled):
- Retrieved pages: `./figures/retrieved_page_rank{K}.png`
- Similarity maps: `./figures/similarity_map_rank{K}.png`
Key knobs in the script (top of file):
- `QUERY`: your question
- `MODEL`: `"colqwen2"` or `"colpali"`
- `USE_HF_DATASET`: set `False` to use local pages
- `PDF`, `PAGES_DIR`: for local mode
- `INDEX_PATH`, `TOPK`, `FIRST_STAGE_K`, `REBUILD_INDEX`
- `SIMILARITY_MAP`, `SIM_TOKEN_IDX`, `SIM_OUTPUT`
- `ANSWER`, `MAX_NEW_TOKENS` (Qwen-VL)
## Troubleshooting
- pdf2image errors on macOS: ensure `brew install poppler` and `pdfinfo` works in terminal.
- Slow or OOM on MPS: reduce dataset size (e.g., set `MAX_DOCS`) or switch to CPU.
- NaNs on MPS: keep fp32 on MPS (default in similarity-map script); avoid fp16 there.
- First-run model downloads can be large; ensure network access (HF mirrors if needed).
## Notes
- Index files are under `./indexes/`. Delete or set `REBUILD_INDEX=True` to rebuild.
- For local PDFs, page images go to `./pages/`.
### Retrieval and Visualization Example
Example settings in `multi-vector-leann-similarity-map.py`:
- `QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"`
- `SIMILARITY_MAP = True` (to generate heatmaps)
- `TOPK = 1` (save the top retrieved page and its similarity map)
Run:
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
python multi-vector-leann-similarity-map.py
```
Outputs (by default):
- Retrieved page: `./figures/retrieved_page_rank1.png`
- Similarity map: `./figures/similarity_map_rank1.png`
Sample visualization (example result, and the query is "QUERY = "How does Vim model performance and efficiency compared to other models?"
"):
![Similarity map example](fig/image.png)
Notes:
- Set `SIM_TOKEN_IDX` to visualize a specific token index; set `-1` to auto-select the most salient token.
- If you change `SIM_OUTPUT` to a file path (e.g., `./figures/my_map.png`), multiple ranks are saved as `my_map_rank{K}.png`.

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 166 KiB

View File

@@ -1,12 +1,18 @@
from __future__ import annotations
import concurrent.futures
import json
import os
import re
import sys
from pathlib import Path
from typing import Any, Optional, cast
import numpy as np
from PIL import Image
from tqdm import tqdm
def _ensure_repo_paths_importable(current_file: str) -> None:
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
_repo_root = Path(current_file).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
@@ -16,6 +22,380 @@ def _ensure_repo_paths_importable(current_file: str) -> None:
sys.path.append(str(_leann_hnsw_pkg))
def _find_backend_module_file() -> Optional[Path]:
"""Best-effort locate the backend leann_multi_vector.py file, avoiding this file."""
this_file = Path(__file__).resolve()
candidates: list[Path] = []
# Common in-repo location
repo_root = this_file.parents[3]
candidates.append(repo_root / "packages" / "leann-backend-hnsw" / "leann_multi_vector.py")
candidates.append(
repo_root / "packages" / "leann-backend-hnsw" / "src" / "leann_multi_vector.py"
)
for cand in candidates:
try:
if cand.exists() and cand.resolve() != this_file:
return cand.resolve()
except Exception:
pass
# Fallback: scan sys.path for another leann_multi_vector.py different from this file
for p in list(sys.path):
try:
cand = Path(p) / "leann_multi_vector.py"
if cand.exists() and cand.resolve() != this_file:
return cand.resolve()
except Exception:
continue
return None
_BACKEND_LEANN_CLASS: Optional[type] = None
def _get_backend_leann_multi_vector() -> type:
"""Load backend LeannMultiVector class even if this file shadows its module name."""
global _BACKEND_LEANN_CLASS
if _BACKEND_LEANN_CLASS is not None:
return _BACKEND_LEANN_CLASS
backend_path = _find_backend_module_file()
if backend_path is None:
# Fallback to local implementation in this module
try:
cls = LeannMultiVector # type: ignore[name-defined]
_BACKEND_LEANN_CLASS = cls
return cls
except Exception as e:
raise ImportError(
"Could not locate backend 'leann_multi_vector.py' and no local implementation found. "
"Ensure the leann backend is available under packages/leann-backend-hnsw or installed."
) from e
import importlib.util
module_name = "leann_hnsw_backend_module"
spec = importlib.util.spec_from_file_location(module_name, str(backend_path))
if spec is None or spec.loader is None:
raise ImportError(f"Failed to create spec for backend module at {backend_path}")
backend_module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = backend_module
spec.loader.exec_module(backend_module) # type: ignore[assignment]
if not hasattr(backend_module, "LeannMultiVector"):
raise ImportError(f"'LeannMultiVector' not found in backend module at {backend_path}")
_BACKEND_LEANN_CLASS = backend_module.LeannMultiVector
return _BACKEND_LEANN_CLASS
def _natural_sort_key(name: str) -> int:
m = re.search(r"\d+", name)
return int(m.group()) if m else 0
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
filenames = sorted(filenames, key=_natural_sort_key)
filepaths = [os.path.join(pages_dir, n) for n in filenames]
images = [Image.open(p) for p in filepaths]
return filepaths, images
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
if not pdf_path:
return
os.makedirs(pages_dir, exist_ok=True)
try:
from pdf2image import convert_from_path
except Exception as e:
raise RuntimeError(
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
) from e
images = convert_from_path(pdf_path, dpi=dpi)
for i, image in enumerate(images):
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
def _select_device_and_dtype():
import torch
from colpali_engine.utils.torch_utils import get_torch_device
device_str = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(device_str)
# Stable dtype selection to avoid NaNs:
# - CUDA: prefer bfloat16 if supported, else float16
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
# - CPU: float32
if device_str == "cuda":
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
try:
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
except Exception:
pass
elif device_str == "mps":
dtype = torch.float32
else:
dtype = torch.float32
return device_str, device, dtype
def _load_colvision(model_choice: str):
import torch
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
device_str, device, dtype = _select_device_and_dtype()
if model_choice == "colqwen2":
model_name = "vidore/colqwen2-v1.0"
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
attn_implementation = (
"flash_attention_2"
if (device_str == "cuda" and is_flash_attn_2_available())
else "eager"
)
model = ColQwen2.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
).eval()
processor = ColQwen2Processor.from_pretrained(model_name)
else:
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
).eval()
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
return model_name, model, processor, device_str, device, dtype
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
# Ensure deterministic eval and autocast for stability
model.eval()
dataloader = DataLoader(
dataset=ListDataset[Image.Image](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
doc_vecs: list[Any] = []
for batch_doc in tqdm(dataloader, desc="Embedding images"):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_doc = model(**batch_doc)
else:
embeddings_doc = model(**batch_doc)
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
return doc_vecs
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
model.eval()
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
q_vecs: list[Any] = []
for batch_query in tqdm(dataloader, desc="Embedding queries"):
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_query = model(**batch_query)
else:
embeddings_query = model(**batch_query)
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
return q_vecs
def _build_index(
index_path: str, doc_vecs: list[Any], filepaths: list[str], images: list[Image.Image]
) -> Any:
LeannMultiVector = _get_backend_leann_multi_vector()
dim = int(doc_vecs[0].shape[-1])
retriever = LeannMultiVector(index_path=index_path, dim=dim)
retriever.create_collection()
for i, vec in enumerate(doc_vecs):
data = {
"colbert_vecs": vec.float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
"image": images[i], # Include the original image
}
retriever.insert(data)
retriever.create_index()
return retriever
def _load_retriever_if_index_exists(index_path: str) -> Optional[Any]:
LeannMultiVector = _get_backend_leann_multi_vector()
index_base = Path(index_path)
# Check for the actual HNSW index file written by the backend + our sidecar files
index_file = index_base.parent / f"{index_base.stem}.index"
meta = index_base.parent / f"{index_base.name}.meta.json"
labels = index_base.parent / f"{index_base.name}.labels.json"
if index_file.exists() and meta.exists() and labels.exists():
try:
with open(meta, encoding="utf-8") as f:
meta_json = json.load(f)
dim = int(meta_json.get("dimensions", 128))
except Exception:
dim = 128
return LeannMultiVector(index_path=index_path, dim=dim)
return None
def _generate_similarity_map(
model,
processor,
image: Image.Image,
query: str,
token_idx: Optional[int] = None,
output_path: Optional[str] = None,
) -> tuple[int, float]:
import torch
from colpali_engine.interpretability import (
get_similarity_maps_from_embeddings,
plot_similarity_map,
)
batch_images = processor.process_images([image]).to(model.device)
batch_queries = processor.process_queries([query]).to(model.device)
with torch.no_grad():
image_embeddings = model.forward(**batch_images)
query_embeddings = model.forward(**batch_queries)
n_patches = processor.get_n_patches(
image_size=image.size,
spatial_merge_size=getattr(model, "spatial_merge_size", None),
)
image_mask = processor.get_image_mask(batch_images)
batched_similarity_maps = get_similarity_maps_from_embeddings(
image_embeddings=image_embeddings,
query_embeddings=query_embeddings,
n_patches=n_patches,
image_mask=image_mask,
)
similarity_maps = batched_similarity_maps[0]
# Determine token index if not provided: choose the token with highest max score
if token_idx is None:
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
token_idx = int(per_token_max.argmax().item())
max_sim_score = similarity_maps[token_idx, :, :].max().item()
if output_path:
import matplotlib.pyplot as plt
fig, ax = plot_similarity_map(
image=image,
similarity_map=similarity_maps[token_idx],
figsize=(14, 14),
show_colorbar=False,
)
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches="tight")
plt.close(fig)
return token_idx, float(max_sim_score)
class QwenVL:
def __init__(self, device: str):
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from transformers.utils.import_utils import is_flash_attn_2_available
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct",
torch_dtype="auto",
device_map=device,
attn_implementation=attn_implementation,
)
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
self.processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
import base64
from io import BytesIO
from qwen_vl_utils import process_vision_info
content = []
for img in images:
buffer = BytesIO()
img.save(buffer, format="jpeg")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
content.append({"type": "text", "text": query})
messages = [{"role": "user", "content": content}]
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
)
inputs = inputs.to(self.model.device)
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
return self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
# Ensure repo paths are importable for dynamic backend loading
_ensure_repo_paths_importable(__file__)
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402
@@ -45,6 +425,7 @@ class LeannMultiVector:
"is_recompute": is_recompute,
}
self._labels_meta: list[dict] = []
self._docid_to_indices: dict[int, list[int]] | None = None
def _meta_dict(self) -> dict:
return {
@@ -69,6 +450,7 @@ class LeannMultiVector:
"doc_id": int(data["doc_id"]),
"filepath": data.get("filepath", ""),
"colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
"image": data.get("image"), # PIL Image object (optional)
}
)
@@ -80,6 +462,15 @@ class LeannMultiVector:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
def _embeddings_path(self) -> Path:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.emb.npy"
def _images_dir_path(self) -> Path:
"""Directory where original images are stored."""
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.images"
def create_index(self) -> None:
if not self._pending_items:
return
@@ -87,10 +478,23 @@ class LeannMultiVector:
embeddings: list[np.ndarray] = []
labels_meta: list[dict] = []
# Create images directory if needed
images_dir = self._images_dir_path()
images_dir.mkdir(parents=True, exist_ok=True)
for item in self._pending_items:
doc_id = int(item["doc_id"])
filepath = item.get("filepath", "")
colbert_vecs = item["colbert_vecs"]
image = item.get("image")
# Save image if provided
image_path = ""
if image is not None and isinstance(image, Image.Image):
image_filename = f"doc_{doc_id}.png"
image_path = str(images_dir / image_filename)
image.save(image_path, "PNG")
for seq_id, vec in enumerate(colbert_vecs):
vec_np = np.asarray(vec, dtype=np.float32)
embeddings.append(vec_np)
@@ -100,6 +504,7 @@ class LeannMultiVector:
"doc_id": doc_id,
"seq_id": int(seq_id),
"filepath": filepath,
"image_path": image_path, # Store the path to the saved image
}
)
@@ -107,7 +512,6 @@ class LeannMultiVector:
return
embeddings_np = np.vstack(embeddings).astype(np.float32)
# print shape of embeddings_np
print(embeddings_np.shape)
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
@@ -121,6 +525,9 @@ class LeannMultiVector:
with open(self._labels_path(), "w", encoding="utf-8") as f:
_json.dump(labels_meta, f)
# Persist embeddings for exact reranking
np.save(self._embeddings_path(), embeddings_np)
self._labels_meta = labels_meta
def _load_labels_meta_if_needed(self) -> None:
@@ -133,6 +540,19 @@ class LeannMultiVector:
with open(labels_path, encoding="utf-8") as f:
self._labels_meta = _json.load(f)
def _build_docid_to_indices_if_needed(self) -> None:
if self._docid_to_indices is not None:
return
self._load_labels_meta_if_needed()
mapping: dict[int, list[int]] = {}
for idx, meta in enumerate(self._labels_meta):
try:
doc_id = int(meta["doc_id"]) # type: ignore[index]
except Exception:
continue
mapping.setdefault(doc_id, []).append(idx)
self._docid_to_indices = mapping
def search(
self, data: np.ndarray, topk: int, first_stage_k: int = 50
) -> list[tuple[float, int]]:
@@ -180,3 +600,181 @@ class LeannMultiVector:
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores
def search_exact(
self,
data: np.ndarray,
topk: int,
*,
first_stage_k: int = 200,
max_workers: int = 32,
) -> list[tuple[float, int]]:
"""
High-precision MaxSim reranking over candidate documents.
Steps:
1) Run a first-stage ANN to collect candidate doc_ids (using seq-level neighbors).
2) For each candidate doc, load all its token embeddings and compute
MaxSim(query_tokens, doc_tokens) exactly: sum(max(dot(q_i, d_j))).
Returns top-k list of (score, doc_id).
"""
# Normalize inputs
if data.ndim == 1:
data = data.reshape(1, -1)
if data.dtype != np.float32:
data = data.astype(np.float32)
self._load_labels_meta_if_needed()
self._build_docid_to_indices_if_needed()
emb_path = self._embeddings_path()
if not emb_path.exists():
# Fallback to approximate if we don't have persisted embeddings
return self.search(data, topk, first_stage_k=first_stage_k)
# Memory-map embeddings to avoid loading all into RAM
all_embeddings = np.load(emb_path, mmap_mode="r")
if all_embeddings.dtype != np.float32:
all_embeddings = all_embeddings.astype(np.float32)
# First-stage ANN to collect candidate doc_ids
searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
raw = searcher.search(
data,
first_stage_k,
recompute_embeddings=False,
complexity=128,
beam_width=1,
prune_ratio=0.0,
batch_size=0,
)
labels = raw.get("labels")
if labels is None:
return []
candidate_doc_ids: set[int] = set()
for batch in labels:
for sid in batch:
try:
idx = int(sid)
except Exception:
continue
if 0 <= idx < len(self._labels_meta):
candidate_doc_ids.add(int(self._labels_meta[idx]["doc_id"])) # type: ignore[index]
# Exact scoring per doc (parallelized)
assert self._docid_to_indices is not None
def _score_one(doc_id: int) -> tuple[float, int]:
token_indices = self._docid_to_indices.get(doc_id, [])
if not token_indices:
return (0.0, doc_id)
doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
# (Q, D) x (P, D)^T -> (Q, P) then MaxSim over P, sum over Q
sim = np.dot(data, doc_vecs.T)
# nan-safe
sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
return (float(score), doc_id)
scores: list[tuple[float, int]] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids]
for fut in concurrent.futures.as_completed(futures):
scores.append(fut.result())
scores.sort(key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores
def search_exact_all(
self,
data: np.ndarray,
topk: int,
*,
max_workers: int = 32,
) -> list[tuple[float, int]]:
"""
Exact MaxSim over ALL documents (no ANN pre-filtering).
This computes, for each document, sum_i max_j dot(q_i, d_j).
It memory-maps the persisted token-embedding matrix for scalability.
"""
if data.ndim == 1:
data = data.reshape(1, -1)
if data.dtype != np.float32:
data = data.astype(np.float32)
self._load_labels_meta_if_needed()
self._build_docid_to_indices_if_needed()
emb_path = self._embeddings_path()
if not emb_path.exists():
return self.search(data, topk)
all_embeddings = np.load(emb_path, mmap_mode="r")
if all_embeddings.dtype != np.float32:
all_embeddings = all_embeddings.astype(np.float32)
assert self._docid_to_indices is not None
candidate_doc_ids = list(self._docid_to_indices.keys())
def _score_one(doc_id: int) -> tuple[float, int]:
token_indices = self._docid_to_indices.get(doc_id, [])
if not token_indices:
return (0.0, doc_id)
doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
sim = np.dot(data, doc_vecs.T)
sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
return (float(score), doc_id)
scores: list[tuple[float, int]] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
futures = [ex.submit(_score_one, d) for d in candidate_doc_ids]
for fut in concurrent.futures.as_completed(futures):
scores.append(fut.result())
scores.sort(key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores
def get_image(self, doc_id: int) -> Optional[Image.Image]:
"""
Retrieve the original image for a given doc_id from the index.
Args:
doc_id: The document ID
Returns:
PIL Image object if found, None otherwise
"""
self._load_labels_meta_if_needed()
# Find the image_path for this doc_id (all seq_ids for same doc share the same image_path)
for meta in self._labels_meta:
if meta.get("doc_id") == doc_id:
image_path = meta.get("image_path", "")
if image_path and Path(image_path).exists():
return Image.open(image_path)
break
return None
def get_metadata(self, doc_id: int) -> Optional[dict]:
"""
Retrieve metadata for a given doc_id.
Args:
doc_id: The document ID
Returns:
Dictionary with metadata (filepath, image_path, etc.) if found, None otherwise
"""
self._load_labels_meta_if_needed()
for meta in self._labels_meta:
if meta.get("doc_id") == doc_id:
return {
"doc_id": doc_id,
"filepath": meta.get("filepath", ""),
"image_path": meta.get("image_path", ""),
}
return None

View File

@@ -4,39 +4,24 @@
# pip install tqdm
# pip install pillow
# %%
from pdf2image import convert_from_path
pdf_path = "pdfs/2004.12832v2.pdf"
images = convert_from_path(pdf_path)
for i, image in enumerate(images):
image.save(f"pages/page_{i + 1}.png", "PNG")
# %%
import os
import re
import sys
from pathlib import Path
from typing import cast
# Make local leann packages importable without installing
from PIL import Image
from tqdm import tqdm
# Ensure local leann packages are importable before importing them
_repo_root = Path(__file__).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
import sys
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
from leann_multi_vector import LeannMultiVector
class LeannRetriever(LeannMultiVector):
pass
# %%
from typing import cast
import torch
from colpali_engine.models import ColPali
@@ -88,13 +73,6 @@ for batch_query in dataloader:
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
print(qs[0].shape)
# %%
import re
from PIL import Image
from tqdm import tqdm
page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]

View File

@@ -2,34 +2,31 @@
# %%
# uv pip install matplotlib qwen_vl_utils
import os
import re
import sys
from pathlib import Path
from typing import Any, Optional, cast
from typing import Any, Optional
from PIL import Image
from tqdm import tqdm
def _ensure_repo_paths_importable(current_file: str) -> None:
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
_repo_root = Path(current_file).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
from leann_multi_vector import ( # utility functions/classes
_ensure_repo_paths_importable,
_load_images_from_dir,
_maybe_convert_pdf_to_images,
_load_colvision,
_embed_images,
_embed_queries,
_build_index,
_load_retriever_if_index_exists,
_generate_similarity_map,
QwenVL,
)
_ensure_repo_paths_importable(__file__)
from leann_multi_vector import LeannMultiVector # noqa: E402
# %%
# Config
os.environ["TOKENIZERS_PARALLELISM"] = "false"
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
QUERY = "The paper talk about the latent video generative model and data curation in the related work part?"
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
# Data source: set to True to use the Hugging Face dataset example (recommended)
@@ -44,7 +41,7 @@ PAGES_DIR: str = "./pages"
# Index + retrieval settings
INDEX_PATH: str = "./indexes/colvision.leann"
TOPK: int = 1
TOPK: int = 3
FIRST_STAGE_K: int = 500
REBUILD_INDEX: bool = False
@@ -54,332 +51,57 @@ SIMILARITY_MAP: bool = True
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
SIM_OUTPUT: str = "./figures/similarity_map.png"
ANSWER: bool = True
MAX_NEW_TOKENS: int = 128
# %%
# Helpers
def _natural_sort_key(name: str) -> int:
m = re.search(r"\d+", name)
return int(m.group()) if m else 0
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
filenames = sorted(filenames, key=_natural_sort_key)
filepaths = [os.path.join(pages_dir, n) for n in filenames]
images = [Image.open(p) for p in filepaths]
return filepaths, images
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
if not pdf_path:
return
os.makedirs(pages_dir, exist_ok=True)
try:
from pdf2image import convert_from_path
except Exception as e:
raise RuntimeError(
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
) from e
images = convert_from_path(pdf_path, dpi=dpi)
for i, image in enumerate(images):
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
def _select_device_and_dtype():
import torch
from colpali_engine.utils.torch_utils import get_torch_device
device_str = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(device_str)
# Stable dtype selection to avoid NaNs:
# - CUDA: prefer bfloat16 if supported, else float16
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
# - CPU: float32
if device_str == "cuda":
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
try:
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
except Exception:
pass
elif device_str == "mps":
dtype = torch.float32
else:
dtype = torch.float32
return device_str, device, dtype
def _load_colvision(model_choice: str):
import torch
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
device_str, device, dtype = _select_device_and_dtype()
if model_choice == "colqwen2":
model_name = "vidore/colqwen2-v1.0"
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
attn_implementation = (
"flash_attention_2"
if (device_str == "cuda" and is_flash_attn_2_available())
else "eager"
)
model = ColQwen2.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
).eval()
processor = ColQwen2Processor.from_pretrained(model_name)
else:
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
).eval()
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
return model_name, model, processor, device_str, device, dtype
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
# Ensure deterministic eval and autocast for stability
model.eval()
dataloader = DataLoader(
dataset=ListDataset[Image.Image](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
doc_vecs: list[Any] = []
for batch_doc in dataloader:
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_doc = model(**batch_doc)
else:
embeddings_doc = model(**batch_doc)
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
return doc_vecs
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
model.eval()
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
q_vecs: list[Any] = []
for batch_query in dataloader:
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_query = model(**batch_query)
else:
embeddings_query = model(**batch_query)
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
return q_vecs
def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) -> LeannMultiVector:
dim = int(doc_vecs[0].shape[-1])
retriever = LeannMultiVector(index_path=index_path, dim=dim)
retriever.create_collection()
for i, vec in enumerate(doc_vecs):
data = {
"colbert_vecs": vec.float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
}
retriever.insert(data)
retriever.create_index()
return retriever
def _load_retriever_if_index_exists(index_path: str, dim: int) -> Optional[LeannMultiVector]:
index_base = Path(index_path)
# Rough heuristic: index dir exists AND meta+labels files exist
meta = index_base.parent / f"{index_base.name}.meta.json"
labels = index_base.parent / f"{index_base.name}.labels.json"
if index_base.exists() and meta.exists() and labels.exists():
return LeannMultiVector(index_path=index_path, dim=dim)
return None
def _generate_similarity_map(
model,
processor,
image: Image.Image,
query: str,
token_idx: Optional[int] = None,
output_path: Optional[str] = None,
) -> tuple[int, float]:
import torch
from colpali_engine.interpretability import (
get_similarity_maps_from_embeddings,
plot_similarity_map,
)
batch_images = processor.process_images([image]).to(model.device)
batch_queries = processor.process_queries([query]).to(model.device)
with torch.no_grad():
image_embeddings = model.forward(**batch_images)
query_embeddings = model.forward(**batch_queries)
n_patches = processor.get_n_patches(
image_size=image.size,
spatial_merge_size=getattr(model, "spatial_merge_size", None),
)
image_mask = processor.get_image_mask(batch_images)
batched_similarity_maps = get_similarity_maps_from_embeddings(
image_embeddings=image_embeddings,
query_embeddings=query_embeddings,
n_patches=n_patches,
image_mask=image_mask,
)
similarity_maps = batched_similarity_maps[0]
# Determine token index if not provided: choose the token with highest max score
if token_idx is None:
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
token_idx = int(per_token_max.argmax().item())
max_sim_score = similarity_maps[token_idx, :, :].max().item()
if output_path:
import matplotlib.pyplot as plt
fig, ax = plot_similarity_map(
image=image,
similarity_map=similarity_maps[token_idx],
figsize=(14, 14),
show_colorbar=False,
)
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches="tight")
plt.close(fig)
return token_idx, float(max_sim_score)
class QwenVL:
def __init__(self, device: str):
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from transformers.utils.import_utils import is_flash_attn_2_available
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct",
torch_dtype="auto",
device_map=device,
attn_implementation=attn_implementation,
)
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
self.processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
import base64
from io import BytesIO
from qwen_vl_utils import process_vision_info
content = []
for img in images:
buffer = BytesIO()
img.save(buffer, format="jpeg")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
content.append({"type": "text", "text": query})
messages = [{"role": "user", "content": content}]
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
)
inputs = inputs.to(self.model.device)
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
return self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
MAX_NEW_TOKENS: int = 1024
# %%
# Step 1: Prepare data
if USE_HF_DATASET:
from datasets import load_dataset
# Step 1: Check if we can skip data loading (index already exists)
retriever: Optional[Any] = None
need_to_build_index = REBUILD_INDEX
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
filepaths: list[str] = []
images: list[Image.Image] = []
for i in tqdm(range(N), desc="Loading dataset"):
p = dataset[i]
# Compose a descriptive identifier for printing later
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
print(identifier)
filepaths.append(identifier)
images.append(p["page_image"]) # PIL Image
if not REBUILD_INDEX:
retriever = _load_retriever_if_index_exists(INDEX_PATH)
if retriever is not None:
print(f"✓ Index loaded from {INDEX_PATH}")
print(f"✓ Images available at: {retriever._images_dir_path()}")
need_to_build_index = False
else:
print(f"Index not found, will build new index")
need_to_build_index = True
# Step 2: Load data only if we need to build the index
if need_to_build_index:
print("Loading dataset...")
if USE_HF_DATASET:
from datasets import load_dataset
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
filepaths: list[str] = []
images: list[Image.Image] = []
for i in tqdm(range(N), desc="Loading dataset", total=N):
p = dataset[i]
# Compose a descriptive identifier for printing later
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
filepaths.append(identifier)
images.append(p["page_image"]) # PIL Image
else:
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
filepaths, images = _load_images_from_dir(PAGES_DIR)
if not images:
raise RuntimeError(
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
)
print(f"Loaded {len(images)} images")
else:
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
filepaths, images = _load_images_from_dir(PAGES_DIR)
if not images:
raise RuntimeError(
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
)
print("Skipping dataset loading (using existing index)")
filepaths = [] # Not needed when using existing index
images = [] # Not needed when using existing index
# %%
# Step 2: Load model and processor
# Step 3: Load model and processor (only if we need to build index or perform search)
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
@@ -387,34 +109,39 @@ print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
# %%
# %%
# Step 3: Build or load index
retriever: Optional[LeannMultiVector] = None
if not REBUILD_INDEX:
try:
one_vec = _embed_images(model, processor, [images[0]])[0]
retriever = _load_retriever_if_index_exists(INDEX_PATH, dim=int(one_vec.shape[-1]))
except Exception:
retriever = None
if retriever is None:
# Step 4: Build index if needed
if need_to_build_index and retriever is None:
print("Building index...")
doc_vecs = _embed_images(model, processor, images)
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
# Clear memory
del images, filepaths, doc_vecs
# Note: Images are now stored in the index, retriever will load them on-demand from disk
# %%
# Step 4: Embed query and search
# Step 5: Embed query and search
q_vec = _embed_queries(model, processor, [QUERY])[0]
results = retriever.search(q_vec.float().numpy(), topk=TOPK, first_stage_k=FIRST_STAGE_K)
results = retriever.search(q_vec.float().numpy(), topk=TOPK)
if not results:
print("No results found.")
else:
print(f'Top {len(results)} results for query: "{QUERY}"')
top_images: list[Image.Image] = []
for rank, (score, doc_id) in enumerate(results, start=1):
path = filepaths[doc_id]
# Retrieve image from index instead of memory
image = retriever.get_image(doc_id)
if image is None:
print(f"Warning: Could not retrieve image for doc_id {doc_id}")
continue
metadata = retriever.get_metadata(doc_id)
path = metadata.get("filepath", "unknown") if metadata else "unknown"
# For HF dataset, path is a descriptive identifier, not a real file path
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
top_images.append(images[doc_id])
top_images.append(image)
if SAVE_TOP_IMAGE:
from pathlib import Path as _Path
@@ -427,12 +154,17 @@ else:
else:
out_path = base / f"retrieved_page_rank{rank}.png"
img.save(str(out_path))
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
# Print the retrieval score (document-level MaxSim) alongside the saved path
try:
score, _doc_id = results[rank - 1]
print(f"Saved retrieved page (rank {rank}) [MaxSim={score:.4f}] to: {out_path}")
except Exception:
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
## TODO stange results of second page of DeepSeek-V2 rather than the first page
# %%
# Step 5: Similarity maps for top-K results
# Step 6: Similarity maps for top-K results
if results and SIMILARITY_MAP:
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
from pathlib import Path as _Path
@@ -469,7 +201,7 @@ if results and SIMILARITY_MAP:
# %%
# Step 6: Optional answer generation
# Step 7: Optional answer generation
if results and ANSWER:
qwen = QwenVL(device=device_str)
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)

View File

@@ -0,0 +1,183 @@
#!/usr/bin/env python3
import re
import sys
from datetime import datetime, timedelta
from pathlib import Path
from leann import LeannSearcher
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
class TimeParser:
def __init__(self):
# Main pattern: captures optional fuzzy modifier, number, unit, and optional "ago"
self.pattern = r"(?:(around|about|roughly|approximately)\s+)?(\d+)\s+(hour|day|week|month|year)s?(?:\s+ago)?"
# Compile for performance
self.regex = re.compile(self.pattern, re.IGNORECASE)
# Stop words to remove before regex parsing
self.stop_words = {
"in",
"at",
"of",
"by",
"as",
"me",
"the",
"a",
"an",
"and",
"any",
"find",
"search",
"list",
"ago",
"back",
"past",
"earlier",
}
def clean_text(self, text):
"""Remove stop words from text"""
words = text.split()
cleaned = " ".join(word for word in words if word.lower() not in self.stop_words)
return cleaned
def parse(self, text):
"""Extract all time expressions from text"""
# Clean text first
cleaned_text = self.clean_text(text)
matches = []
for match in self.regex.finditer(cleaned_text):
fuzzy = match.group(1) # "around", "about", etc.
number = int(match.group(2))
unit = match.group(3).lower()
matches.append(
{
"full_match": match.group(0),
"fuzzy": bool(fuzzy),
"number": number,
"unit": unit,
"range": self.calculate_range(number, unit, bool(fuzzy)),
}
)
return matches
def calculate_range(self, number, unit, is_fuzzy):
"""Convert to actual datetime range and return ISO format strings"""
units = {
"hour": timedelta(hours=number),
"day": timedelta(days=number),
"week": timedelta(weeks=number),
"month": timedelta(days=number * 30),
"year": timedelta(days=number * 365),
}
delta = units[unit]
now = datetime.now()
target = now - delta
if is_fuzzy:
buffer = delta * 0.2 # 20% buffer for fuzzy
start = (target - buffer).isoformat()
end = (target + buffer).isoformat()
else:
start = target.isoformat()
end = now.isoformat()
return (start, end)
def search_files(query, top_k=15):
"""Search the index and return results"""
# Parse time expressions
parser = TimeParser()
time_matches = parser.parse(query)
# Remove time expressions from query for semantic search
clean_query = query
if time_matches:
for match in time_matches:
clean_query = clean_query.replace(match["full_match"], "").strip()
# Check if clean_query is less than 4 characters
if len(clean_query) < 4:
print("Error: add more input for accurate results.")
return
# Single query to vector DB
searcher = LeannSearcher(INDEX_PATH)
results = searcher.search(
clean_query if clean_query else query, top_k=top_k, recompute_embeddings=False
)
# Filter by time if time expression found
if time_matches:
time_range = time_matches[0]["range"] # Use first time expression
start_time, end_time = time_range
filtered_results = []
for result in results:
# Access metadata attribute directly (not .get())
metadata = result.metadata if hasattr(result, "metadata") else {}
if metadata:
# Check modification date first, fall back to creation date
date_str = metadata.get("modification_date") or metadata.get("creation_date")
if date_str:
# Convert strings to datetime objects for proper comparison
try:
file_date = datetime.fromisoformat(date_str)
start_dt = datetime.fromisoformat(start_time)
end_dt = datetime.fromisoformat(end_time)
# Compare dates properly
if start_dt <= file_date <= end_dt:
filtered_results.append(result)
except (ValueError, TypeError):
# Handle invalid date formats
print(f"Warning: Invalid date format in metadata: {date_str}")
continue
results = filtered_results
# Print results
print(f"\nSearch results for: '{query}'")
if time_matches:
print(
f"Time filter: {time_matches[0]['number']} {time_matches[0]['unit']}(s) {'(fuzzy)' if time_matches[0]['fuzzy'] else ''}"
)
print(
f"Date range: {time_matches[0]['range'][0][:10]} to {time_matches[0]['range'][1][:10]}"
)
print("-" * 80)
for i, result in enumerate(results, 1):
print(f"\n[{i}] Score: {result.score:.4f}")
print(f"Content: {result.text}")
# Show metadata if present
metadata = result.metadata if hasattr(result, "metadata") else None
if metadata:
if "creation_date" in metadata:
print(f"Created: {metadata['creation_date']}")
if "modification_date" in metadata:
print(f"Modified: {metadata['modification_date']}")
print("-" * 80)
if __name__ == "__main__":
if len(sys.argv) < 2:
print('Usage: python search_index.py "<search query>" [top_k]')
sys.exit(1)
query = sys.argv[1]
top_k = int(sys.argv[2]) if len(sys.argv) > 2 else 15
search_files(query, top_k)

View File

@@ -0,0 +1,82 @@
#!/usr/bin/env python3
import json
import sys
from pathlib import Path
from leann import LeannBuilder
def process_json_items(json_file_path):
"""Load and process JSON file with metadata items"""
with open(json_file_path, encoding="utf-8") as f:
items = json.load(f)
# Guard against empty JSON
if not items:
print("⚠️ No items found in the JSON file. Exiting gracefully.")
return
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
builder = LeannBuilder(backend_name="hnsw", is_recompute=False)
total_items = len(items)
items_added = 0
print(f"Processing {total_items} items...")
for idx, item in enumerate(items):
try:
# Create embedding text sentence
embedding_text = f"{item.get('Name', 'unknown')} located at {item.get('Path', 'unknown')} and size {item.get('Size', 'unknown')} bytes with content type {item.get('ContentType', 'unknown')} and kind {item.get('Kind', 'unknown')}"
# Prepare metadata with dates
metadata = {}
if "CreationDate" in item:
metadata["creation_date"] = item["CreationDate"]
if "ContentChangeDate" in item:
metadata["modification_date"] = item["ContentChangeDate"]
# Add to builder
builder.add_text(embedding_text, metadata=metadata)
items_added += 1
except Exception as e:
print(f"\n⚠️ Warning: Failed to process item {idx}: {e}")
continue
# Show progress
progress = (idx + 1) / total_items * 100
sys.stdout.write(f"\rProgress: {idx + 1}/{total_items} ({progress:.1f}%)")
sys.stdout.flush()
print() # New line after progress
# Guard against no successfully added items
if items_added == 0:
print("⚠️ No items were successfully added to the index. Exiting gracefully.")
return
print(f"\n✅ Successfully processed {items_added}/{total_items} items")
print("Building index...")
try:
builder.build_index(INDEX_PATH)
print(f"✓ Index saved to {INDEX_PATH}")
except ValueError as e:
if "No chunks added" in str(e):
print("⚠️ No chunks were added to the builder. Index not created.")
else:
raise
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: python build_index.py <json_file>")
sys.exit(1)
json_file = sys.argv[1]
if not Path(json_file).exists():
print(f"Error: File {json_file} not found")
sys.exit(1)
process_json_items(json_file)

View File

@@ -0,0 +1,265 @@
#!/usr/bin/env python3
"""
Spotlight Metadata Dumper for Vector DB
Extracts only essential metadata for semantic search embeddings
Output is optimized for vector database storage with minimal fields
"""
import json
import sys
from datetime import datetime
# Check platform before importing macOS-specific modules
if sys.platform != "darwin":
print("This script requires macOS (uses Spotlight)")
sys.exit(1)
from Foundation import NSDate, NSMetadataQuery, NSPredicate, NSRunLoop
# EDIT THIS LIST: Add or remove folders to search
# Can be either:
# - Folder names relative to home directory (e.g., "Desktop", "Downloads")
# - Absolute paths (e.g., "/Applications", "/System/Library")
SEARCH_FOLDERS = [
"Desktop",
"Downloads",
"Documents",
"Music",
"Pictures",
"Movies",
# "Library", # Uncomment to include
# "/Applications", # Absolute path example
# "Code/Projects", # Subfolder example
# Add any other folders here
]
def convert_to_serializable(obj):
"""Convert NS objects to Python serializable types"""
if obj is None:
return None
# Handle NSDate
if hasattr(obj, "timeIntervalSince1970"):
return datetime.fromtimestamp(obj.timeIntervalSince1970()).isoformat()
# Handle NSArray
if hasattr(obj, "count") and hasattr(obj, "objectAtIndex_"):
return [convert_to_serializable(obj.objectAtIndex_(i)) for i in range(obj.count())]
# Convert to string
try:
return str(obj)
except Exception:
return repr(obj)
def dump_spotlight_data(max_items=10, output_file="spotlight_dump.json"):
"""
Dump Spotlight data using public.item predicate
"""
# Build full paths from SEARCH_FOLDERS
import os
home_dir = os.path.expanduser("~")
search_paths = []
print("Search locations:")
for folder in SEARCH_FOLDERS:
# Check if it's an absolute path or relative
if folder.startswith("/"):
full_path = folder
else:
full_path = os.path.join(home_dir, folder)
if os.path.exists(full_path):
search_paths.append(full_path)
print(f"{full_path}")
else:
print(f"{full_path} (not found)")
if not search_paths:
print("No valid search paths found!")
return []
print(f"\nDumping {max_items} items from Spotlight (public.item)...")
# Create query with public.item predicate
query = NSMetadataQuery.alloc().init()
predicate = NSPredicate.predicateWithFormat_("kMDItemContentTypeTree CONTAINS 'public.item'")
query.setPredicate_(predicate)
# Set search scopes to our specific folders
query.setSearchScopes_(search_paths)
print("Starting query...")
query.startQuery()
# Wait for gathering to complete
run_loop = NSRunLoop.currentRunLoop()
print("Gathering results...")
# Let it gather for a few seconds
for i in range(50): # 5 seconds max
run_loop.runMode_beforeDate_(
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
)
# Check gathering status periodically
if i % 10 == 0:
current_count = query.resultCount()
if current_count > 0:
print(f" Found {current_count} items so far...")
# Continue while still gathering (up to 2 more seconds)
timeout = NSDate.dateWithTimeIntervalSinceNow_(2.0)
while query.isGathering() and timeout.timeIntervalSinceNow() > 0:
run_loop.runMode_beforeDate_(
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
)
query.stopQuery()
total_results = query.resultCount()
print(f"Found {total_results} total items")
if total_results == 0:
print("No results found")
return []
# Process items
items_to_process = min(total_results, max_items)
results = []
# ONLY relevant attributes for vector embeddings
# These provide essential context for semantic search without bloat
attributes = [
"kMDItemPath", # Full path for file retrieval
"kMDItemFSName", # Filename for display & embedding
"kMDItemFSSize", # Size for filtering/ranking
"kMDItemContentType", # File type for categorization
"kMDItemKind", # Human-readable type for embedding
"kMDItemFSCreationDate", # Temporal context
"kMDItemFSContentChangeDate", # Recency for ranking
]
print(f"Processing {items_to_process} items...")
for i in range(items_to_process):
try:
item = query.resultAtIndex_(i)
metadata = {}
# Extract ONLY the relevant attributes
for attr in attributes:
try:
value = item.valueForAttribute_(attr)
if value is not None:
# Keep the attribute name clean (remove kMDItem prefix for cleaner JSON)
clean_key = attr.replace("kMDItem", "").replace("FS", "")
metadata[clean_key] = convert_to_serializable(value)
except (AttributeError, ValueError, TypeError):
continue
# Only add if we have at least a path
if metadata.get("Path"):
results.append(metadata)
except Exception as e:
print(f"Error processing item {i}: {e}")
continue
# Save to JSON
with open(output_file, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2, ensure_ascii=False)
print(f"\n✓ Saved {len(results)} items to {output_file}")
# Show summary
print("\nSample items:")
import os
home_dir = os.path.expanduser("~")
for i, item in enumerate(results[:3]):
print(f"\n[Item {i + 1}]")
print(f" Path: {item.get('Path', 'N/A')}")
print(f" Name: {item.get('Name', 'N/A')}")
print(f" Type: {item.get('ContentType', 'N/A')}")
print(f" Kind: {item.get('Kind', 'N/A')}")
# Handle size properly
size = item.get("Size")
if size:
try:
size_int = int(size)
if size_int > 1024 * 1024:
print(f" Size: {size_int / (1024 * 1024):.2f} MB")
elif size_int > 1024:
print(f" Size: {size_int / 1024:.2f} KB")
else:
print(f" Size: {size_int} bytes")
except (ValueError, TypeError):
print(f" Size: {size}")
# Show dates
if "CreationDate" in item:
print(f" Created: {item['CreationDate']}")
if "ContentChangeDate" in item:
print(f" Modified: {item['ContentChangeDate']}")
# Count by type
type_counts = {}
for item in results:
content_type = item.get("ContentType", "unknown")
type_counts[content_type] = type_counts.get(content_type, 0) + 1
print(f"\nTotal items saved: {len(results)}")
if type_counts:
print("\nTop content types:")
for ct, count in sorted(type_counts.items(), key=lambda x: x[1], reverse=True)[:5]:
print(f" {ct}: {count} items")
# Count by folder
folder_counts = {}
for item in results:
path = item.get("Path", "")
for folder in SEARCH_FOLDERS:
# Build the full folder path
if folder.startswith("/"):
folder_path = folder
else:
folder_path = os.path.join(home_dir, folder)
if path.startswith(folder_path):
folder_counts[folder] = folder_counts.get(folder, 0) + 1
break
if folder_counts:
print("\nItems by location:")
for folder, count in sorted(folder_counts.items(), key=lambda x: x[1], reverse=True):
print(f" {folder}: {count} items")
return results
def main():
# Parse arguments
if len(sys.argv) > 1:
try:
max_items = int(sys.argv[1])
except ValueError:
print("Usage: python spot.py [number_of_items]")
print("Default: 10 items")
sys.exit(1)
else:
max_items = 10
output_file = sys.argv[2] if len(sys.argv) > 2 else "spotlight_dump.json"
# Run dump
dump_spotlight_data(max_items=max_items, output_file=output_file)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1 @@
# Slack MCP data integration for LEANN

View File

@@ -0,0 +1,510 @@
#!/usr/bin/env python3
"""
Slack MCP Reader for LEANN
This module provides functionality to connect to Slack MCP servers and fetch message data
for indexing in LEANN. It supports various Slack MCP server implementations and provides
flexible message processing options.
"""
import asyncio
import json
import logging
from typing import Any, Optional
logger = logging.getLogger(__name__)
class SlackMCPReader:
"""
Reader for Slack data via MCP (Model Context Protocol) servers.
This class connects to Slack MCP servers to fetch message data and convert it
into a format suitable for LEANN indexing.
"""
def __init__(
self,
mcp_server_command: str,
workspace_name: Optional[str] = None,
concatenate_conversations: bool = True,
max_messages_per_conversation: int = 100,
max_retries: int = 5,
retry_delay: float = 2.0,
):
"""
Initialize the Slack MCP Reader.
Args:
mcp_server_command: Command to start the MCP server (e.g., 'slack-mcp-server')
workspace_name: Optional workspace name to filter messages
concatenate_conversations: Whether to group messages by channel/thread
max_messages_per_conversation: Maximum messages to include per conversation
max_retries: Maximum number of retries for failed operations
retry_delay: Initial delay between retries in seconds
"""
self.mcp_server_command = mcp_server_command
self.workspace_name = workspace_name
self.concatenate_conversations = concatenate_conversations
self.max_messages_per_conversation = max_messages_per_conversation
self.max_retries = max_retries
self.retry_delay = retry_delay
self.mcp_process = None
async def start_mcp_server(self):
"""Start the MCP server process."""
try:
self.mcp_process = await asyncio.create_subprocess_exec(
*self.mcp_server_command.split(),
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
logger.info(f"Started MCP server: {self.mcp_server_command}")
except Exception as e:
logger.error(f"Failed to start MCP server: {e}")
raise
async def stop_mcp_server(self):
"""Stop the MCP server process."""
if self.mcp_process:
self.mcp_process.terminate()
await self.mcp_process.wait()
logger.info("Stopped MCP server")
async def send_mcp_request(self, request: dict[str, Any]) -> dict[str, Any]:
"""Send a request to the MCP server and get response."""
if not self.mcp_process:
raise RuntimeError("MCP server not started")
request_json = json.dumps(request) + "\n"
self.mcp_process.stdin.write(request_json.encode())
await self.mcp_process.stdin.drain()
response_line = await self.mcp_process.stdout.readline()
if not response_line:
raise RuntimeError("No response from MCP server")
return json.loads(response_line.decode().strip())
async def initialize_mcp_connection(self):
"""Initialize the MCP connection."""
init_request = {
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "leann-slack-reader", "version": "1.0.0"},
},
}
response = await self.send_mcp_request(init_request)
if "error" in response:
raise RuntimeError(f"MCP initialization failed: {response['error']}")
logger.info("MCP connection initialized successfully")
async def list_available_tools(self) -> list[dict[str, Any]]:
"""List available tools from the MCP server."""
list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}
response = await self.send_mcp_request(list_request)
if "error" in response:
raise RuntimeError(f"Failed to list tools: {response['error']}")
return response.get("result", {}).get("tools", [])
def _is_cache_sync_error(self, error: dict) -> bool:
"""Check if the error is related to users cache not being ready."""
if isinstance(error, dict):
message = error.get("message", "").lower()
return (
"users cache is not ready" in message or "sync process is still running" in message
)
return False
async def _retry_with_backoff(self, func, *args, **kwargs):
"""Retry a function with exponential backoff, especially for cache sync issues."""
last_exception = None
for attempt in range(self.max_retries + 1):
try:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
# Check if this is a cache sync error
error_dict = {}
if hasattr(e, "args") and e.args and isinstance(e.args[0], dict):
error_dict = e.args[0]
elif "Failed to fetch messages" in str(e):
# Try to extract error from the exception message
import re
match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
if match:
try:
error_dict = eval(match.group(1))
except (ValueError, SyntaxError, NameError):
pass
else:
# Try alternative format
match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
if match:
try:
error_dict = eval(match.group(1))
except (ValueError, SyntaxError, NameError):
pass
if self._is_cache_sync_error(error_dict):
if attempt < self.max_retries:
delay = self.retry_delay * (2**attempt) # Exponential backoff
logger.info(
f"Cache sync not ready, waiting {delay:.1f}s before retry {attempt + 1}/{self.max_retries}"
)
await asyncio.sleep(delay)
continue
else:
logger.warning(
f"Cache sync still not ready after {self.max_retries} retries, giving up"
)
break
else:
# Not a cache sync error, don't retry
break
# If we get here, all retries failed or it's not a retryable error
raise last_exception
async def fetch_slack_messages(
self, channel: Optional[str] = None, limit: int = 100
) -> list[dict[str, Any]]:
"""
Fetch Slack messages using MCP tools with retry logic for cache sync issues.
Args:
channel: Optional channel name to filter messages
limit: Maximum number of messages to fetch
Returns:
List of message dictionaries
"""
return await self._retry_with_backoff(self._fetch_slack_messages_impl, channel, limit)
async def _fetch_slack_messages_impl(
self, channel: Optional[str] = None, limit: int = 100
) -> list[dict[str, Any]]:
"""
Internal implementation of fetch_slack_messages without retry logic.
"""
# This is a generic implementation - specific MCP servers may have different tool names
# Common tool names might be: 'get_messages', 'list_messages', 'fetch_channel_history'
tools = await self.list_available_tools()
logger.info(f"Available tools: {[tool.get('name') for tool in tools]}")
message_tool = None
# Look for a tool that can fetch messages - prioritize conversations_history
message_tool = None
# First, try to find conversations_history specifically
for tool in tools:
tool_name = tool.get("name", "").lower()
if "conversations_history" in tool_name:
message_tool = tool
logger.info(f"Found conversations_history tool: {tool}")
break
# If not found, look for other message-fetching tools
if not message_tool:
for tool in tools:
tool_name = tool.get("name", "").lower()
if any(
keyword in tool_name
for keyword in ["conversations_search", "message", "history"]
):
message_tool = tool
break
if not message_tool:
raise RuntimeError("No message fetching tool found in MCP server")
# Prepare tool call parameters
tool_params = {"limit": "180d"} # Use 180 days to get older messages
if channel:
# For conversations_history, use channel_id parameter
if message_tool["name"] == "conversations_history":
tool_params["channel_id"] = channel
else:
# Try common parameter names for channel specification
for param_name in ["channel", "channel_id", "channel_name"]:
tool_params[param_name] = channel
break
logger.info(f"Tool parameters: {tool_params}")
fetch_request = {
"jsonrpc": "2.0",
"id": 3,
"method": "tools/call",
"params": {"name": message_tool["name"], "arguments": tool_params},
}
response = await self.send_mcp_request(fetch_request)
if "error" in response:
raise RuntimeError(f"Failed to fetch messages: {response['error']}")
# Extract messages from response - format may vary by MCP server
result = response.get("result", {})
if "content" in result and isinstance(result["content"], list):
# Some MCP servers return content as a list
content = result["content"][0] if result["content"] else {}
if "text" in content:
try:
messages = json.loads(content["text"])
except json.JSONDecodeError:
# If not JSON, try to parse as CSV format (Slack MCP server format)
messages = self._parse_csv_messages(content["text"], channel)
else:
messages = result["content"]
else:
# Direct message format
messages = result.get("messages", [result])
return messages if isinstance(messages, list) else [messages]
def _parse_csv_messages(self, csv_text: str, channel: str) -> list[dict[str, Any]]:
"""Parse CSV format messages from Slack MCP server."""
import csv
import io
messages = []
try:
# Split by lines and process each line as a CSV row
lines = csv_text.strip().split("\n")
if not lines:
return messages
# Skip header line if it exists
start_idx = 0
if lines[0].startswith("MsgID,UserID,UserName"):
start_idx = 1
for line in lines[start_idx:]:
if not line.strip():
continue
# Parse CSV line
reader = csv.reader(io.StringIO(line))
try:
row = next(reader)
if len(row) >= 7: # Ensure we have enough columns
message = {
"ts": row[0],
"user": row[1],
"username": row[2],
"real_name": row[3],
"channel": row[4],
"thread_ts": row[5],
"text": row[6],
"time": row[7] if len(row) > 7 else "",
"reactions": row[8] if len(row) > 8 else "",
"cursor": row[9] if len(row) > 9 else "",
}
messages.append(message)
except Exception as e:
logger.warning(f"Failed to parse CSV line: {line[:100]}... Error: {e}")
continue
except Exception as e:
logger.warning(f"Failed to parse CSV messages: {e}")
# Fallback: treat entire text as one message
messages = [{"text": csv_text, "channel": channel or "unknown"}]
return messages
def _format_message(self, message: dict[str, Any]) -> str:
"""Format a single message for indexing."""
text = message.get("text", "")
user = message.get("user", message.get("username", "Unknown"))
channel = message.get("channel", message.get("channel_name", "Unknown"))
timestamp = message.get("ts", message.get("timestamp", ""))
# Format timestamp if available
formatted_time = ""
if timestamp:
try:
import datetime
if isinstance(timestamp, str) and "." in timestamp:
dt = datetime.datetime.fromtimestamp(float(timestamp))
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
elif isinstance(timestamp, (int, float)):
dt = datetime.datetime.fromtimestamp(timestamp)
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
else:
formatted_time = str(timestamp)
except (ValueError, TypeError):
formatted_time = str(timestamp)
# Build formatted message
parts = []
if channel:
parts.append(f"Channel: #{channel}")
if user:
parts.append(f"User: {user}")
if formatted_time:
parts.append(f"Time: {formatted_time}")
if text:
parts.append(f"Message: {text}")
return "\n".join(parts)
def _create_concatenated_content(self, messages: list[dict[str, Any]], channel: str) -> str:
"""Create concatenated content from multiple messages in a channel."""
if not messages:
return ""
# Sort messages by timestamp if available
try:
messages.sort(key=lambda x: float(x.get("ts", x.get("timestamp", 0))))
except (ValueError, TypeError):
pass # Keep original order if timestamps aren't numeric
# Limit messages per conversation
if len(messages) > self.max_messages_per_conversation:
messages = messages[-self.max_messages_per_conversation :]
# Create header
content_parts = [
f"Slack Channel: #{channel}",
f"Message Count: {len(messages)}",
f"Workspace: {self.workspace_name or 'Unknown'}",
"=" * 50,
"",
]
# Add messages
for message in messages:
formatted_msg = self._format_message(message)
if formatted_msg.strip():
content_parts.append(formatted_msg)
content_parts.append("-" * 30)
content_parts.append("")
return "\n".join(content_parts)
async def get_all_channels(self) -> list[str]:
"""Get list of all available channels."""
try:
channels_list_request = {
"jsonrpc": "2.0",
"id": 4,
"method": "tools/call",
"params": {"name": "channels_list", "arguments": {}},
}
channels_response = await self.send_mcp_request(channels_list_request)
if "result" in channels_response:
result = channels_response["result"]
if "content" in result and isinstance(result["content"], list):
content = result["content"][0] if result["content"] else {}
if "text" in content:
# Parse the channels from the response
channels = []
lines = content["text"].split("\n")
for line in lines:
if line.strip() and ("#" in line or "C" in line[:10]):
# Extract channel ID or name
parts = line.split()
for part in parts:
if part.startswith("C") and len(part) > 5:
channels.append(part)
elif part.startswith("#"):
channels.append(part[1:]) # Remove #
logger.info(f"Found {len(channels)} channels: {channels}")
return channels
return []
except Exception as e:
logger.warning(f"Failed to get channels list: {e}")
return []
async def read_slack_data(self, channels: Optional[list[str]] = None) -> list[str]:
"""
Read Slack data and return formatted text chunks.
Args:
channels: Optional list of channel names to fetch. If None, fetches from all available channels.
Returns:
List of formatted text chunks ready for LEANN indexing
"""
try:
await self.start_mcp_server()
await self.initialize_mcp_connection()
all_texts = []
if channels:
# Fetch specific channels
for channel in channels:
try:
messages = await self.fetch_slack_messages(channel=channel, limit=1000)
if messages:
if self.concatenate_conversations:
text_content = self._create_concatenated_content(messages, channel)
if text_content.strip():
all_texts.append(text_content)
else:
# Process individual messages
for message in messages:
formatted_msg = self._format_message(message)
if formatted_msg.strip():
all_texts.append(formatted_msg)
except Exception as e:
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
continue
else:
# Fetch from all available channels
logger.info("Fetching from all available channels...")
all_channels = await self.get_all_channels()
if not all_channels:
# Fallback to common channel names if we can't get the list
all_channels = ["general", "random", "announcements", "C0GN5BX0F"]
logger.info(f"Using fallback channels: {all_channels}")
for channel in all_channels:
try:
logger.info(f"Searching channel: {channel}")
messages = await self.fetch_slack_messages(channel=channel, limit=1000)
if messages:
if self.concatenate_conversations:
text_content = self._create_concatenated_content(messages, channel)
if text_content.strip():
all_texts.append(text_content)
else:
# Process individual messages
for message in messages:
formatted_msg = self._format_message(message)
if formatted_msg.strip():
all_texts.append(formatted_msg)
except Exception as e:
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
continue
return all_texts
finally:
await self.stop_mcp_server()
async def __aenter__(self):
"""Async context manager entry."""
await self.start_mcp_server()
await self.initialize_mcp_connection()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
await self.stop_mcp_server()

227
apps/slack_rag.py Normal file
View File

@@ -0,0 +1,227 @@
#!/usr/bin/env python3
"""
Slack RAG Application with MCP Support
This application enables RAG (Retrieval-Augmented Generation) on Slack messages
by connecting to Slack MCP servers to fetch live data and index it in LEANN.
Usage:
python -m apps.slack_rag --mcp-server "slack-mcp-server" --query "What did the team discuss about the project?"
"""
import argparse
import asyncio
from apps.base_rag_example import BaseRAGExample
from apps.slack_data.slack_mcp_reader import SlackMCPReader
class SlackMCPRAG(BaseRAGExample):
"""
RAG application for Slack messages via MCP servers.
This class provides a complete RAG pipeline for Slack data, including
MCP server connection, data fetching, indexing, and interactive chat.
"""
def __init__(self):
super().__init__(
name="Slack MCP RAG",
description="RAG application for Slack messages via MCP servers",
default_index_name="slack_messages",
)
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
"""Add Slack MCP-specific arguments."""
parser.add_argument(
"--mcp-server",
type=str,
required=True,
help="Command to start the Slack MCP server (e.g., 'slack-mcp-server' or 'npx slack-mcp-server')",
)
parser.add_argument(
"--workspace-name",
type=str,
help="Slack workspace name for better organization and filtering",
)
parser.add_argument(
"--channels",
nargs="+",
help="Specific Slack channels to index (e.g., general random). If not specified, fetches from all available channels",
)
parser.add_argument(
"--concatenate-conversations",
action="store_true",
default=True,
help="Group messages by channel/thread for better context (default: True)",
)
parser.add_argument(
"--no-concatenate-conversations",
action="store_true",
help="Process individual messages instead of grouping by channel",
)
parser.add_argument(
"--max-messages-per-channel",
type=int,
default=100,
help="Maximum number of messages to include per channel (default: 100)",
)
parser.add_argument(
"--test-connection",
action="store_true",
help="Test MCP server connection and list available tools without indexing",
)
parser.add_argument(
"--max-retries",
type=int,
default=5,
help="Maximum number of retries for failed operations (default: 5)",
)
parser.add_argument(
"--retry-delay",
type=float,
default=2.0,
help="Initial delay between retries in seconds (default: 2.0)",
)
async def test_mcp_connection(self, args) -> bool:
"""Test the MCP server connection and display available tools."""
print(f"Testing connection to MCP server: {args.mcp_server}")
try:
reader = SlackMCPReader(
mcp_server_command=args.mcp_server,
workspace_name=args.workspace_name,
concatenate_conversations=not args.no_concatenate_conversations,
max_messages_per_conversation=args.max_messages_per_channel,
max_retries=args.max_retries,
retry_delay=args.retry_delay,
)
async with reader:
tools = await reader.list_available_tools()
print("Successfully connected to MCP server!")
print(f"Available tools ({len(tools)}):")
for i, tool in enumerate(tools, 1):
name = tool.get("name", "Unknown")
description = tool.get("description", "No description available")
print(f"\n{i}. {name}")
print(
f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
)
# Show input schema if available
schema = tool.get("inputSchema", {})
if schema.get("properties"):
props = list(schema["properties"].keys())[:3] # Show first 3 properties
print(
f" Parameters: {', '.join(props)}{'...' if len(schema['properties']) > 3 else ''}"
)
return True
except Exception as e:
print(f"Failed to connect to MCP server: {e}")
print("\nTroubleshooting tips:")
print("1. Make sure the MCP server is installed and accessible")
print("2. Check if the server command is correct")
print("3. Ensure you have proper authentication/credentials configured")
print("4. Try running the MCP server command directly to test it")
return False
async def load_data(self, args) -> list[str]:
"""Load Slack messages via MCP server."""
print(f"Connecting to Slack MCP server: {args.mcp_server}")
if args.workspace_name:
print(f"Workspace: {args.workspace_name}")
# Filter out empty strings from channels
channels = [ch for ch in args.channels if ch.strip()] if args.channels else None
if channels:
print(f"Channels: {', '.join(channels)}")
else:
print("Fetching from all available channels")
concatenate = not args.no_concatenate_conversations
print(
f"Processing mode: {'Concatenated conversations' if concatenate else 'Individual messages'}"
)
try:
reader = SlackMCPReader(
mcp_server_command=args.mcp_server,
workspace_name=args.workspace_name,
concatenate_conversations=concatenate,
max_messages_per_conversation=args.max_messages_per_channel,
max_retries=args.max_retries,
retry_delay=args.retry_delay,
)
texts = await reader.read_slack_data(channels=channels)
if not texts:
print("No messages found! This could mean:")
print("- The MCP server couldn't fetch messages")
print("- The specified channels don't exist or are empty")
print("- Authentication issues with the Slack workspace")
return []
print(f"Successfully loaded {len(texts)} text chunks from Slack")
# Show sample of what was loaded
if texts:
sample_text = texts[0][:200] + "..." if len(texts[0]) > 200 else texts[0]
print("\nSample content:")
print("-" * 40)
print(sample_text)
print("-" * 40)
return texts
except Exception as e:
print(f"Error loading Slack data: {e}")
print("\nThis might be due to:")
print("- MCP server connection issues")
print("- Authentication problems")
print("- Network connectivity issues")
print("- Incorrect channel names")
raise
async def run(self):
"""Main entry point with MCP connection testing."""
args = self.parser.parse_args()
# Test connection if requested
if args.test_connection:
success = await self.test_mcp_connection(args)
if not success:
return
print(
"MCP server is working! You can now run without --test-connection to start indexing."
)
return
# Run the standard RAG pipeline
await super().run()
async def main():
"""Main entry point for the Slack MCP RAG application."""
app = SlackMCPRAG()
await app.run()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1 @@
# Twitter MCP data integration for LEANN

View File

@@ -0,0 +1,295 @@
#!/usr/bin/env python3
"""
Twitter MCP Reader for LEANN
This module provides functionality to connect to Twitter MCP servers and fetch bookmark data
for indexing in LEANN. It supports various Twitter MCP server implementations and provides
flexible bookmark processing options.
"""
import asyncio
import json
import logging
from typing import Any, Optional
logger = logging.getLogger(__name__)
class TwitterMCPReader:
"""
Reader for Twitter bookmark data via MCP (Model Context Protocol) servers.
This class connects to Twitter MCP servers to fetch bookmark data and convert it
into a format suitable for LEANN indexing.
"""
def __init__(
self,
mcp_server_command: str,
username: Optional[str] = None,
include_tweet_content: bool = True,
include_metadata: bool = True,
max_bookmarks: int = 1000,
):
"""
Initialize the Twitter MCP Reader.
Args:
mcp_server_command: Command to start the MCP server (e.g., 'twitter-mcp-server')
username: Optional Twitter username to filter bookmarks
include_tweet_content: Whether to include full tweet content
include_metadata: Whether to include tweet metadata (likes, retweets, etc.)
max_bookmarks: Maximum number of bookmarks to fetch
"""
self.mcp_server_command = mcp_server_command
self.username = username
self.include_tweet_content = include_tweet_content
self.include_metadata = include_metadata
self.max_bookmarks = max_bookmarks
self.mcp_process = None
async def start_mcp_server(self):
"""Start the MCP server process."""
try:
self.mcp_process = await asyncio.create_subprocess_exec(
*self.mcp_server_command.split(),
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
logger.info(f"Started MCP server: {self.mcp_server_command}")
except Exception as e:
logger.error(f"Failed to start MCP server: {e}")
raise
async def stop_mcp_server(self):
"""Stop the MCP server process."""
if self.mcp_process:
self.mcp_process.terminate()
await self.mcp_process.wait()
logger.info("Stopped MCP server")
async def send_mcp_request(self, request: dict[str, Any]) -> dict[str, Any]:
"""Send a request to the MCP server and get response."""
if not self.mcp_process:
raise RuntimeError("MCP server not started")
request_json = json.dumps(request) + "\n"
self.mcp_process.stdin.write(request_json.encode())
await self.mcp_process.stdin.drain()
response_line = await self.mcp_process.stdout.readline()
if not response_line:
raise RuntimeError("No response from MCP server")
return json.loads(response_line.decode().strip())
async def initialize_mcp_connection(self):
"""Initialize the MCP connection."""
init_request = {
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "leann-twitter-reader", "version": "1.0.0"},
},
}
response = await self.send_mcp_request(init_request)
if "error" in response:
raise RuntimeError(f"MCP initialization failed: {response['error']}")
logger.info("MCP connection initialized successfully")
async def list_available_tools(self) -> list[dict[str, Any]]:
"""List available tools from the MCP server."""
list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}
response = await self.send_mcp_request(list_request)
if "error" in response:
raise RuntimeError(f"Failed to list tools: {response['error']}")
return response.get("result", {}).get("tools", [])
async def fetch_twitter_bookmarks(self, limit: Optional[int] = None) -> list[dict[str, Any]]:
"""
Fetch Twitter bookmarks using MCP tools.
Args:
limit: Maximum number of bookmarks to fetch
Returns:
List of bookmark dictionaries
"""
tools = await self.list_available_tools()
bookmark_tool = None
# Look for a tool that can fetch bookmarks
for tool in tools:
tool_name = tool.get("name", "").lower()
if any(keyword in tool_name for keyword in ["bookmark", "saved", "favorite"]):
bookmark_tool = tool
break
if not bookmark_tool:
raise RuntimeError("No bookmark fetching tool found in MCP server")
# Prepare tool call parameters
tool_params = {}
if limit or self.max_bookmarks:
tool_params["limit"] = limit or self.max_bookmarks
if self.username:
tool_params["username"] = self.username
fetch_request = {
"jsonrpc": "2.0",
"id": 3,
"method": "tools/call",
"params": {"name": bookmark_tool["name"], "arguments": tool_params},
}
response = await self.send_mcp_request(fetch_request)
if "error" in response:
raise RuntimeError(f"Failed to fetch bookmarks: {response['error']}")
# Extract bookmarks from response
result = response.get("result", {})
if "content" in result and isinstance(result["content"], list):
content = result["content"][0] if result["content"] else {}
if "text" in content:
try:
bookmarks = json.loads(content["text"])
except json.JSONDecodeError:
# If not JSON, treat as plain text
bookmarks = [{"text": content["text"], "source": "twitter"}]
else:
bookmarks = result["content"]
else:
bookmarks = result.get("bookmarks", result.get("tweets", [result]))
return bookmarks if isinstance(bookmarks, list) else [bookmarks]
def _format_bookmark(self, bookmark: dict[str, Any]) -> str:
"""Format a single bookmark for indexing."""
# Extract tweet information
text = bookmark.get("text", bookmark.get("content", ""))
author = bookmark.get(
"author", bookmark.get("username", bookmark.get("user", {}).get("username", "Unknown"))
)
timestamp = bookmark.get("created_at", bookmark.get("timestamp", ""))
url = bookmark.get("url", bookmark.get("tweet_url", ""))
# Extract metadata if available
likes = bookmark.get("likes", bookmark.get("favorite_count", 0))
retweets = bookmark.get("retweets", bookmark.get("retweet_count", 0))
replies = bookmark.get("replies", bookmark.get("reply_count", 0))
# Build formatted bookmark
parts = []
# Header
parts.append("=== Twitter Bookmark ===")
if author:
parts.append(f"Author: @{author}")
if timestamp:
# Format timestamp if it's a standard format
try:
import datetime
if "T" in str(timestamp): # ISO format
dt = datetime.datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
else:
formatted_time = str(timestamp)
parts.append(f"Date: {formatted_time}")
except (ValueError, TypeError):
parts.append(f"Date: {timestamp}")
if url:
parts.append(f"URL: {url}")
# Tweet content
if text and self.include_tweet_content:
parts.append("")
parts.append("Content:")
parts.append(text)
# Metadata
if self.include_metadata and any([likes, retweets, replies]):
parts.append("")
parts.append("Engagement:")
if likes:
parts.append(f" Likes: {likes}")
if retweets:
parts.append(f" Retweets: {retweets}")
if replies:
parts.append(f" Replies: {replies}")
# Extract hashtags and mentions if available
hashtags = bookmark.get("hashtags", [])
mentions = bookmark.get("mentions", [])
if hashtags or mentions:
parts.append("")
if hashtags:
parts.append(f"Hashtags: {', '.join(hashtags)}")
if mentions:
parts.append(f"Mentions: {', '.join(mentions)}")
return "\n".join(parts)
async def read_twitter_bookmarks(self) -> list[str]:
"""
Read Twitter bookmark data and return formatted text chunks.
Returns:
List of formatted text chunks ready for LEANN indexing
"""
try:
await self.start_mcp_server()
await self.initialize_mcp_connection()
print(f"Fetching up to {self.max_bookmarks} bookmarks...")
if self.username:
print(f"Filtering for user: @{self.username}")
bookmarks = await self.fetch_twitter_bookmarks()
if not bookmarks:
print("No bookmarks found")
return []
print(f"Processing {len(bookmarks)} bookmarks...")
all_texts = []
processed_count = 0
for bookmark in bookmarks:
try:
formatted_bookmark = self._format_bookmark(bookmark)
if formatted_bookmark.strip():
all_texts.append(formatted_bookmark)
processed_count += 1
except Exception as e:
logger.warning(f"Failed to format bookmark: {e}")
continue
print(f"Successfully processed {processed_count} bookmarks")
return all_texts
finally:
await self.stop_mcp_server()
async def __aenter__(self):
"""Async context manager entry."""
await self.start_mcp_server()
await self.initialize_mcp_connection()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
await self.stop_mcp_server()

195
apps/twitter_rag.py Normal file
View File

@@ -0,0 +1,195 @@
#!/usr/bin/env python3
"""
Twitter RAG Application with MCP Support
This application enables RAG (Retrieval-Augmented Generation) on Twitter bookmarks
by connecting to Twitter MCP servers to fetch live data and index it in LEANN.
Usage:
python -m apps.twitter_rag --mcp-server "twitter-mcp-server" --query "What articles did I bookmark about AI?"
"""
import argparse
import asyncio
from apps.base_rag_example import BaseRAGExample
from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader
class TwitterMCPRAG(BaseRAGExample):
"""
RAG application for Twitter bookmarks via MCP servers.
This class provides a complete RAG pipeline for Twitter bookmark data, including
MCP server connection, data fetching, indexing, and interactive chat.
"""
def __init__(self):
super().__init__(
name="Twitter MCP RAG",
description="RAG application for Twitter bookmarks via MCP servers",
default_index_name="twitter_bookmarks",
)
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
"""Add Twitter MCP-specific arguments."""
parser.add_argument(
"--mcp-server",
type=str,
required=True,
help="Command to start the Twitter MCP server (e.g., 'twitter-mcp-server' or 'npx twitter-mcp-server')",
)
parser.add_argument(
"--username", type=str, help="Twitter username to filter bookmarks (without @)"
)
parser.add_argument(
"--max-bookmarks",
type=int,
default=1000,
help="Maximum number of bookmarks to fetch (default: 1000)",
)
parser.add_argument(
"--no-tweet-content",
action="store_true",
help="Exclude tweet content, only include metadata",
)
parser.add_argument(
"--no-metadata",
action="store_true",
help="Exclude engagement metadata (likes, retweets, etc.)",
)
parser.add_argument(
"--test-connection",
action="store_true",
help="Test MCP server connection and list available tools without indexing",
)
async def test_mcp_connection(self, args) -> bool:
"""Test the MCP server connection and display available tools."""
print(f"Testing connection to MCP server: {args.mcp_server}")
try:
reader = TwitterMCPReader(
mcp_server_command=args.mcp_server,
username=args.username,
include_tweet_content=not args.no_tweet_content,
include_metadata=not args.no_metadata,
max_bookmarks=args.max_bookmarks,
)
async with reader:
tools = await reader.list_available_tools()
print("\n✅ Successfully connected to MCP server!")
print(f"Available tools ({len(tools)}):")
for i, tool in enumerate(tools, 1):
name = tool.get("name", "Unknown")
description = tool.get("description", "No description available")
print(f"\n{i}. {name}")
print(
f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
)
# Show input schema if available
schema = tool.get("inputSchema", {})
if schema.get("properties"):
props = list(schema["properties"].keys())[:3] # Show first 3 properties
print(
f" Parameters: {', '.join(props)}{'...' if len(schema['properties']) > 3 else ''}"
)
return True
except Exception as e:
print(f"\n❌ Failed to connect to MCP server: {e}")
print("\nTroubleshooting tips:")
print("1. Make sure the Twitter MCP server is installed and accessible")
print("2. Check if the server command is correct")
print("3. Ensure you have proper Twitter API credentials configured")
print("4. Verify your Twitter account has bookmarks to fetch")
print("5. Try running the MCP server command directly to test it")
return False
async def load_data(self, args) -> list[str]:
"""Load Twitter bookmarks via MCP server."""
print(f"Connecting to Twitter MCP server: {args.mcp_server}")
if args.username:
print(f"Username filter: @{args.username}")
print(f"Max bookmarks: {args.max_bookmarks}")
print(f"Include tweet content: {not args.no_tweet_content}")
print(f"Include metadata: {not args.no_metadata}")
try:
reader = TwitterMCPReader(
mcp_server_command=args.mcp_server,
username=args.username,
include_tweet_content=not args.no_tweet_content,
include_metadata=not args.no_metadata,
max_bookmarks=args.max_bookmarks,
)
texts = await reader.read_twitter_bookmarks()
if not texts:
print("❌ No bookmarks found! This could mean:")
print("- You don't have any bookmarks on Twitter")
print("- The MCP server couldn't access your bookmarks")
print("- Authentication issues with Twitter API")
print("- The username filter didn't match any bookmarks")
return []
print(f"✅ Successfully loaded {len(texts)} bookmarks from Twitter")
# Show sample of what was loaded
if texts:
sample_text = texts[0][:300] + "..." if len(texts[0]) > 300 else texts[0]
print("\nSample bookmark:")
print("-" * 50)
print(sample_text)
print("-" * 50)
return texts
except Exception as e:
print(f"❌ Error loading Twitter bookmarks: {e}")
print("\nThis might be due to:")
print("- MCP server connection issues")
print("- Twitter API authentication problems")
print("- Network connectivity issues")
print("- Rate limiting from Twitter API")
raise
async def run(self):
"""Main entry point with MCP connection testing."""
args = self.parser.parse_args()
# Test connection if requested
if args.test_connection:
success = await self.test_mcp_connection(args)
if not success:
return
print(
"\n🎉 MCP server is working! You can now run without --test-connection to start indexing."
)
return
# Run the standard RAG pipeline
await super().run()
async def main():
"""Main entry point for the Twitter MCP RAG application."""
app = TwitterMCPRAG()
await app.run()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -54,29 +54,51 @@ def extract_thinking_answer(response):
return response.strip()
def load_hf_model(model_name="Qwen/Qwen3-8B"):
"""Load HuggingFace model"""
def load_hf_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
"""Load HuggingFace model
Args:
model_name (str): Name of the model to load
trust_remote_code (bool): Whether to allow execution of code from the model repository.
Defaults to False for security. Only enable for trusted models.
"""
if not HF_AVAILABLE:
raise ImportError("transformers not available")
if trust_remote_code:
print(
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
)
print(f"Loading HF: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto",
trust_remote_code=True,
trust_remote_code=trust_remote_code,
)
return tokenizer, model
def load_vllm_model(model_name="Qwen/Qwen3-8B"):
"""Load vLLM model"""
def load_vllm_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
"""Load vLLM model
Args:
model_name (str): Name of the model to load
trust_remote_code (bool): Whether to allow execution of code from the model repository.
Defaults to False for security. Only enable for trusted models.
"""
if not VLLM_AVAILABLE:
raise ImportError("vllm not available")
if trust_remote_code:
print(
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
)
print(f"Loading vLLM: {model_name}")
llm = LLM(model=model_name, trust_remote_code=True)
llm = LLM(model=model_name, trust_remote_code=trust_remote_code)
# Qwen3 specific config
if is_qwen3_model(model_name):
@@ -178,19 +200,33 @@ def evaluate_rag(searcher, llm_func, queries, domain="default", top_k=3, complex
}
def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct"):
"""Load Qwen2.5-VL multimodal model"""
def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=False):
"""Load Qwen2.5-VL multimodal model
Args:
model_name (str): Name of the model to load
trust_remote_code (bool): Whether to allow execution of code from the model repository.
Defaults to False for security. Only enable for trusted models.
"""
if not HF_AVAILABLE:
raise ImportError("transformers not available")
if trust_remote_code:
print(
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
)
print(f"Loading Qwen2.5-VL: {model_name}")
try:
from transformers import AutoModelForVision2Seq, AutoProcessor
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=trust_remote_code)
model = AutoModelForVision2Seq.from_pretrained(
model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=trust_remote_code,
)
return processor, model
@@ -202,9 +238,14 @@ def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct"):
try:
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(
model_name, trust_remote_code=trust_remote_code
)
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=trust_remote_code,
)
return processor, model

143
benchmarks/update/README.md Normal file
View File

@@ -0,0 +1,143 @@
# Update Benchmarks
This directory hosts two benchmark suites that exercise LEANNs HNSW “update +
search” pipeline under different assumptions:
1. **RNG recompute latency** measure how random-neighbour pruning and cache
settings influence incremental `add()` latency when embeddings are fetched
over the ZMQ embedding server.
2. **Update strategy comparison** compare a fully sequential update pipeline
against an offline approach that keeps the graph static and fuses results.
Both suites build a non-compact, `is_recompute=True` index so that new
embeddings are pulled from the embedding server. Benchmark outputs are written
under `.leann/bench/` by default and appended to CSV files for later plotting.
## Benchmarks
### 1. HNSW RNG Recompute Benchmark
`bench_hnsw_rng_recompute.py` evaluates incremental update latency under four
random-neighbour (RNG) configurations. Each scenario uses the same dataset but
changes the forward / reverse RNG pruning flags and whether the embedding cache
is enabled:
| Scenario name | Forward RNG | Reverse RNG | ZMQ embedding cache |
| ---------------------------------- | ----------- | ----------- | ------------------- |
| `baseline` | Enabled | Enabled | Enabled |
| `no_cache_baseline` | Enabled | Enabled | **Disabled** |
| `disable_forward_rng` | **Disabled**| Enabled | Enabled |
| `disable_forward_and_reverse_rng` | **Disabled**| **Disabled**| Enabled |
For each scenario the script:
1. (Re)builds a `is_recompute=True` index and writes it to `.leann/bench/`.
2. Starts `leann_backend_hnsw.hnsw_embedding_server` for remote embeddings.
3. Appends the requested updates using the scenarios RNG flags.
4. Records total time, latency per passage, ZMQ fetch counts, and stage-level
timings before appending a row to the CSV output.
**Run:**
```bash
LEANN_HNSW_LOG_PATH=.leann/bench/hnsw_server.log \
LEANN_LOG_LEVEL=INFO \
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
--runs 1 \
--index-path .leann/bench/test.leann \
--initial-files data/PrideandPrejudice.txt \
--update-files data/huawei_pangu.md \
--max-initial 300 \
--max-updates 1 \
--add-timeout 120
```
**Output:**
- `benchmarks/update/bench_results.csv` per-scenario timing statistics
(including ms/passage) for each run.
- `.leann/bench/hnsw_server.log` detailed ZMQ/server logs (path controlled by
`LEANN_HNSW_LOG_PATH`).
_The reference CSVs checked into this branch were generated on a workstation with an NVIDIA RTX 4090 GPU; throughput numbers will differ on other hardware._
### 2. Sequential vs. Offline Update Benchmark
`bench_update_vs_offline_search.py` compares two end-to-end strategies on the
same dataset:
- **Scenario A Sequential Update**
- Start an embedding server.
- Sequentially call `index.add()`; each call fetches embeddings via ZMQ and
mutates the HNSW graph.
- After all inserts, run a search on the updated graph.
- Metrics recorded: update time (`add_total_s`), post-update search time
(`search_time_s`), combined total (`total_time_s`), and per-passage
latency.
- **Scenario B Offline Embedding + Concurrent Search**
- Stop Scenario As server and start a fresh embedding server.
- Spawn two threads: one generates embeddings for the new passages offline
(graph unchanged); the other computes the query embedding and searches the
existing graph.
- Merge offline similarities with the graph search results to emulate late
fusion, then report the merged topk preview.
- Metrics recorded: embedding time (`emb_time_s`), search time
(`search_time_s`), concurrent makespan (`makespan_s`), and scenario total.
**Run (both scenarios):**
```bash
uv run -m benchmarks.update.bench_update_vs_offline_search \
--index-path .leann/bench/offline_vs_update.leann \
--max-initial 300 \
--num-updates 1
```
You can pass `--only A` or `--only B` to run a single scenario. The script will
print timing summaries to stdout and append the results to CSV.
**Output:**
- `benchmarks/update/offline_vs_update.csv` per-scenario timing statistics for
Scenario A and B.
- Console output includes Scenario Bs merged topk preview for quick sanity
checks.
_The sample results committed here come from runs on an RTX 4090-equipped machine; expect variations if you benchmark on different GPUs._
### 3. Visualisation
`plot_bench_results.py` combines the RNG benchmark and the update strategy
benchmark into a single two-panel plot.
**Run:**
```bash
uv run -m benchmarks.update.plot_bench_results \
--csv benchmarks/update/bench_results.csv \
--csv-right benchmarks/update/offline_vs_update.csv \
--out benchmarks/update/bench_latency_from_csv.png
```
**Options:**
- `--broken-y` Enable a broken Y-axis (default: true when appropriate).
- `--csv` RNG benchmark results CSV (left panel).
- `--csv-right` Update strategy results CSV (right panel).
- `--out` Output image path (PNG/PDF supported).
**Output:**
- `benchmarks/update/bench_latency_from_csv.png` visual comparison of the two
suites.
- `benchmarks/update/bench_latency_from_csv.pdf` PDF version, suitable for
slides/papers.
## Parameters & Environment
### Common CLI Flags
- `--max-initial` Number of initial passages used to seed the index.
- `--max-updates` / `--num-updates` Number of passages to treat as updates.
- `--index-path` Base path (without extension) where the LEANN index is stored.
- `--runs` Number of repetitions (RNG benchmark only).
### Environment Variables
- `LEANN_HNSW_LOG_PATH` File to receive embedding-server logs (optional).
- `LEANN_LOG_LEVEL` Logging verbosity (DEBUG/INFO/WARNING/ERROR).
- `CUDA_VISIBLE_DEVICES` Set to empty string if you want to force CPU
execution of the embedding model.
With these scripts you can easily replicate LEANNs update benchmarks, compare
multiple RNG strategies, and evaluate whether sequential updates or offline
fusion better match your latency/accuracy trade-offs.

View File

@@ -0,0 +1,16 @@
"""Benchmarks for LEANN update workflows."""
# Expose helper to locate repository root for other modules that need it.
from pathlib import Path
def find_repo_root() -> Path:
"""Return the project root containing pyproject.toml."""
current = Path(__file__).resolve()
for parent in current.parents:
if (parent / "pyproject.toml").exists():
return parent
return current.parents[1]
__all__ = ["find_repo_root"]

View File

@@ -0,0 +1,804 @@
"""Benchmark incremental HNSW add() under different RNG pruning modes with real
embedding recomputation.
This script clones the structure of ``examples/dynamic_update_no_recompute.py``
so that we build a non-compact ``is_recompute=True`` index, spin up the
standard HNSW embedding server, and measure how long incremental ``add`` takes
when RNG pruning is fully enabled vs. partially/fully disabled.
Example usage (run from the repo root; downloads the model on first run)::
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
--index-path .leann/bench/leann-demo.leann \
--runs 1
You can tweak the input documents with ``--initial-files`` / ``--update-files``
if you want a larger or different workload, and change the embedding model via
``--model-name``.
"""
import argparse
import json
import logging
import os
import pickle
import re
import sys
import time
from pathlib import Path
from typing import Any
import msgpack
import numpy as np
import zmq
from leann.api import LeannBuilder
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
from leann.embedding_compute import compute_embeddings
from leann.embedding_server_manager import EmbeddingServerManager
from leann.registry import register_project_directory
from leann_backend_hnsw import faiss # type: ignore
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
logger = logging.getLogger(__name__)
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO)
def _find_repo_root() -> Path:
"""Locate project root by walking up until pyproject.toml is found."""
current = Path(__file__).resolve()
for parent in current.parents:
if (parent / "pyproject.toml").exists():
return parent
# Fallback: assume repo is two levels up (../..)
return current.parents[2]
REPO_ROOT = _find_repo_root()
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from apps.chunking import create_text_chunks # noqa: E402
DEFAULT_INITIAL_FILES = [
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
REPO_ROOT / "data" / "huawei_pangu.md",
]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
DEFAULT_HNSW_LOG = Path(".leann/bench/hnsw_server.log")
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
from llama_index.core import SimpleDirectoryReader
documents = []
for path in paths:
p = path.expanduser().resolve()
if not p.exists():
raise FileNotFoundError(f"Input path not found: {p}")
if p.is_dir():
reader = SimpleDirectoryReader(str(p), recursive=False)
documents.extend(reader.load_data(show_progress=True))
else:
reader = SimpleDirectoryReader(input_files=[str(p)])
documents.extend(reader.load_data(show_progress=True))
if not documents:
return []
chunks = create_text_chunks(
documents,
chunk_size=512,
chunk_overlap=128,
use_ast_chunking=False,
)
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
if limit is not None:
cleaned = cleaned[:limit]
return cleaned
def ensure_index_dir(index_path: Path) -> None:
index_path.parent.mkdir(parents=True, exist_ok=True)
def cleanup_index_files(index_path: Path) -> None:
parent = index_path.parent
if not parent.exists():
return
stem = index_path.stem
for file in parent.glob(f"{stem}*"):
if file.is_file():
file.unlink()
def build_initial_index(
index_path: Path,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
distance_metric: str,
ef_construction: int,
) -> None:
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=True,
distance_metric=distance_metric,
backend_kwargs={
"distance_metric": distance_metric,
"is_compact": False,
"is_recompute": True,
"efConstruction": ef_construction,
},
)
for idx, passage in enumerate(paragraphs):
builder.add_text(passage, metadata={"id": str(idx)})
builder.build_index(str(index_path))
def prepare_new_chunks(paragraphs: list[str]) -> list[dict[str, Any]]:
return [{"text": text, "metadata": {}} for text in paragraphs]
def benchmark_update_with_mode(
index_path: Path,
new_chunks: list[dict[str, Any]],
model_name: str,
embedding_mode: str,
distance_metric: str,
disable_forward_rng: bool,
disable_reverse_rng: bool,
server_port: int,
add_timeout: int,
ef_construction: int,
) -> tuple[float, float]:
meta_path = index_path.parent / f"{index_path.name}.meta.json"
passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
offset_file = index_path.parent / f"{index_path.name}.passages.idx"
index_file = index_path.parent / f"{index_path.stem}.index"
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
with open(offset_file, "rb") as f:
offset_map: dict[str, int] = pickle.load(f)
existing_ids = set(offset_map.keys())
valid_chunks: list[dict[str, Any]] = []
for chunk in new_chunks:
text = chunk.get("text", "")
if not isinstance(text, str) or not text.strip():
continue
metadata = chunk.setdefault("metadata", {})
passage_id = chunk.get("id") or metadata.get("id")
if passage_id and passage_id in existing_ids:
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
valid_chunks.append(chunk)
if not valid_chunks:
raise ValueError("No valid chunks to append.")
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
embeddings = compute_embeddings(
texts_to_embed,
model_name,
mode=embedding_mode,
is_build=False,
batch_size=16,
)
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
if distance_metric == "cosine":
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norms[norms == 0] = 1
embeddings = embeddings / norms
index = faiss.read_index(str(index_file))
index.is_recompute = True
if getattr(index, "storage", None) is None:
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
storage_index = faiss.IndexFlatIP(index.d)
else:
storage_index = faiss.IndexFlatL2(index.d)
index.storage = storage_index
index.own_fields = True
try:
storage_index.ntotal = index.ntotal
except AttributeError:
pass
try:
index.hnsw.set_disable_rng_during_add(disable_forward_rng)
index.hnsw.set_disable_reverse_prune(disable_reverse_rng)
if ef_construction is not None:
index.hnsw.efConstruction = ef_construction
except AttributeError:
pass
applied_forward = getattr(index.hnsw, "disable_rng_during_add", None)
applied_reverse = getattr(index.hnsw, "disable_reverse_prune", None)
logger.info(
"HNSW RNG config -> requested forward=%s, reverse=%s | applied forward=%s, reverse=%s",
disable_forward_rng,
disable_reverse_rng,
applied_forward,
applied_reverse,
)
base_id = index.ntotal
for offset, chunk in enumerate(valid_chunks):
new_id = str(base_id + offset)
chunk.setdefault("metadata", {})["id"] = new_id
chunk["id"] = new_id
rollback_size = passages_file.stat().st_size if passages_file.exists() else 0
offset_map_backup = offset_map.copy()
try:
with open(passages_file, "a", encoding="utf-8") as f:
for chunk in valid_chunks:
offset = f.tell()
json.dump(
{
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk.get("metadata", {}),
},
f,
ensure_ascii=False,
)
f.write("\n")
offset_map[chunk["id"]] = offset
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
server_started, actual_port = server_manager.start_server(
port=server_port,
model_name=model_name,
embedding_mode=embedding_mode,
passages_file=str(meta_path),
distance_metric=distance_metric,
)
if not server_started:
raise RuntimeError("Failed to start embedding server.")
if hasattr(index.hnsw, "set_zmq_port"):
index.hnsw.set_zmq_port(actual_port)
elif hasattr(index, "set_zmq_port"):
index.set_zmq_port(actual_port)
_warmup_embedding_server(actual_port)
total_start = time.time()
add_elapsed = 0.0
try:
import signal
def _timeout_handler(signum, frame):
raise TimeoutError("incremental add timed out")
if add_timeout > 0:
signal.signal(signal.SIGALRM, _timeout_handler)
signal.alarm(add_timeout)
add_start = time.time()
for i in range(embeddings.shape[0]):
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
add_elapsed = time.time() - add_start
if add_timeout > 0:
signal.alarm(0)
faiss.write_index(index, str(index_file))
finally:
server_manager.stop_server()
except TimeoutError:
raise
except Exception:
if passages_file.exists():
with open(passages_file, "rb+") as f:
f.truncate(rollback_size)
with open(offset_file, "wb") as f:
pickle.dump(offset_map_backup, f)
raise
prune_hnsw_embeddings_inplace(str(index_file))
meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
# Reset toggles so the index on disk returns to baseline behaviour.
try:
index.hnsw.set_disable_rng_during_add(False)
index.hnsw.set_disable_reverse_prune(False)
except AttributeError:
pass
faiss.write_index(index, str(index_file))
total_elapsed = time.time() - total_start
return total_elapsed, add_elapsed
def _total_zmq_nodes(log_path: Path) -> int:
if not log_path.exists():
return 0
with log_path.open("r", encoding="utf-8") as log_file:
text = log_file.read()
return sum(int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", text))
def _warmup_embedding_server(port: int) -> None:
"""Send a dummy REQ so the embedding server loads its model."""
ctx = zmq.Context()
try:
sock = ctx.socket(zmq.REQ)
sock.setsockopt(zmq.LINGER, 0)
sock.setsockopt(zmq.RCVTIMEO, 5000)
sock.setsockopt(zmq.SNDTIMEO, 5000)
sock.connect(f"tcp://127.0.0.1:{port}")
payload = msgpack.packb(["__WARMUP__"], use_bin_type=True)
sock.send(payload)
try:
sock.recv()
except zmq.error.Again:
pass
finally:
sock.close()
ctx.term()
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--index-path",
type=Path,
default=Path(".leann/bench/leann-demo.leann"),
help="Output index base path (without extension).",
)
parser.add_argument(
"--initial-files",
nargs="*",
type=Path,
default=DEFAULT_INITIAL_FILES,
help="Files used to build the initial index.",
)
parser.add_argument(
"--update-files",
nargs="*",
type=Path,
default=DEFAULT_UPDATE_FILES,
help="Files appended during the benchmark.",
)
parser.add_argument(
"--runs", type=int, default=1, help="How many times to repeat each scenario."
)
parser.add_argument(
"--model-name",
default="sentence-transformers/all-MiniLM-L6-v2",
help="Embedding model used for build/update.",
)
parser.add_argument(
"--embedding-mode",
default="sentence-transformers",
help="Embedding mode passed to LeannBuilder/embedding server.",
)
parser.add_argument(
"--distance-metric",
default="mips",
choices=["mips", "l2", "cosine"],
help="Distance metric for HNSW backend.",
)
parser.add_argument(
"--ef-construction",
type=int,
default=200,
help="efConstruction setting for initial build.",
)
parser.add_argument(
"--server-port",
type=int,
default=5557,
help="Port for the real embedding server.",
)
parser.add_argument(
"--max-initial",
type=int,
default=300,
help="Optional cap on initial passages (after chunking).",
)
parser.add_argument(
"--max-updates",
type=int,
default=1,
help="Optional cap on update passages (after chunking).",
)
parser.add_argument(
"--add-timeout",
type=int,
default=900,
help="Timeout in seconds for the incremental add loop (0 = no timeout).",
)
parser.add_argument(
"--plot-path",
type=Path,
default=Path("bench_latency.png"),
help="Where to save the latency bar plot.",
)
parser.add_argument(
"--cap-y",
type=float,
default=None,
help="Cap Y-axis (ms). Bars above are hatched and annotated.",
)
parser.add_argument(
"--broken-y",
action="store_true",
help="Use broken Y-axis (two stacked axes with gap). Overrides --cap-y unless both provided.",
)
parser.add_argument(
"--lower-cap-y",
type=float,
default=None,
help="Lower axes upper bound for broken Y (ms). Default=1.1x second-highest.",
)
parser.add_argument(
"--upper-start-y",
type=float,
default=None,
help="Upper axes lower bound for broken Y (ms). Default=1.2x second-highest.",
)
parser.add_argument(
"--csv-path",
type=Path,
default=Path("benchmarks/update/bench_results.csv"),
help="Where to append per-scenario results as CSV.",
)
args = parser.parse_args()
register_project_directory(REPO_ROOT)
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
update_paragraphs = load_chunks_from_files(args.update_files, args.max_updates)
if not update_paragraphs:
raise ValueError("No update passages found; please provide --update-files with content.")
update_chunks = prepare_new_chunks(update_paragraphs)
ensure_index_dir(args.index_path)
scenarios = [
("baseline", False, False, True),
("no_cache_baseline", False, False, False),
("disable_forward_rng", True, False, True),
("disable_forward_and_reverse_rng", True, True, True),
]
log_path = Path(os.environ.get("LEANN_HNSW_LOG_PATH", DEFAULT_HNSW_LOG))
log_path.parent.mkdir(parents=True, exist_ok=True)
os.environ["LEANN_HNSW_LOG_PATH"] = str(log_path.resolve())
os.environ.setdefault("LEANN_LOG_LEVEL", "INFO")
results_total: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_add: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_zmq: dict[str, list[int]] = {name: [] for name, *_ in scenarios}
results_stageA: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_stageBC: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_ms_per_passage: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
# CSV setup
import csv
run_id = time.strftime("%Y%m%d-%H%M%S")
csv_fields = [
"run_id",
"scenario",
"cache_enabled",
"ef_construction",
"max_initial",
"max_updates",
"total_time_s",
"add_only_s",
"latency_ms_per_passage",
"zmq_nodes",
"stageA_time_s",
"stageBC_time_s",
"model_name",
"embedding_mode",
"distance_metric",
]
# Create CSV with header if missing
if args.csv_path:
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writeheader()
for run in range(args.runs):
print(f"\n=== Benchmark run {run + 1}/{args.runs} ===")
for name, disable_forward, disable_reverse, cache_enabled in scenarios:
print(f"\nScenario: {name}")
cleanup_index_files(args.index_path)
if log_path.exists():
try:
log_path.unlink()
except OSError:
pass
os.environ["LEANN_ZMQ_EMBED_CACHE"] = "1" if cache_enabled else "0"
build_initial_index(
args.index_path,
initial_paragraphs,
args.model_name,
args.embedding_mode,
args.distance_metric,
args.ef_construction,
)
prev_size = log_path.stat().st_size if log_path.exists() else 0
try:
total_elapsed, add_elapsed = benchmark_update_with_mode(
args.index_path,
update_chunks,
args.model_name,
args.embedding_mode,
args.distance_metric,
disable_forward,
disable_reverse,
args.server_port,
args.add_timeout,
args.ef_construction,
)
except TimeoutError as exc:
print(f"Scenario {name} timed out: {exc}")
continue
curr_size = log_path.stat().st_size if log_path.exists() else 0
if curr_size < prev_size:
prev_size = 0
zmq_count = 0
if log_path.exists():
with log_path.open("r", encoding="utf-8") as log_file:
log_file.seek(prev_size)
new_entries = log_file.read()
zmq_count = sum(
int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", new_entries)
)
stageA = sum(
float(x)
for x in re.findall(r"Distance calculation E2E time: ([0-9.]+)s", new_entries)
)
stageBC = sum(
float(x) for x in re.findall(r"ZMQ E2E time: ([0-9.]+)s", new_entries)
)
else:
stageA = 0.0
stageBC = 0.0
per_chunk = add_elapsed / len(update_chunks)
print(
f"Total time: {total_elapsed:.3f} s | add-only: {add_elapsed:.3f} s "
f"for {len(update_chunks)} passages => {per_chunk * 1e3:.3f} ms/passage"
)
print(f"ZMQ node fetch total: {zmq_count}")
results_total[name].append(total_elapsed)
results_add[name].append(add_elapsed)
results_zmq[name].append(zmq_count)
results_ms_per_passage[name].append(per_chunk * 1e3)
results_stageA[name].append(stageA)
results_stageBC[name].append(stageBC)
# Append row to CSV
if args.csv_path:
row = {
"run_id": run_id,
"scenario": name,
"cache_enabled": 1 if cache_enabled else 0,
"ef_construction": args.ef_construction,
"max_initial": args.max_initial,
"max_updates": args.max_updates,
"total_time_s": round(total_elapsed, 6),
"add_only_s": round(add_elapsed, 6),
"latency_ms_per_passage": round(per_chunk * 1e3, 6),
"zmq_nodes": int(zmq_count),
"stageA_time_s": round(stageA, 6),
"stageBC_time_s": round(stageBC, 6),
"model_name": args.model_name,
"embedding_mode": args.embedding_mode,
"distance_metric": args.distance_metric,
}
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writerow(row)
print("\n=== Summary ===")
for name in results_add:
add_values = results_add[name]
total_values = results_total[name]
zmq_values = results_zmq[name]
latency_values = results_ms_per_passage[name]
if not add_values:
print(f"{name}: no successful runs")
continue
avg_add = sum(add_values) / len(add_values)
avg_total = sum(total_values) / len(total_values)
avg_zmq = sum(zmq_values) / len(zmq_values) if zmq_values else 0.0
avg_latency = sum(latency_values) / len(latency_values) if latency_values else 0.0
runs = len(add_values)
print(
f"{name}: add-only avg {avg_add:.3f} s | total avg {avg_total:.3f} s "
f"| ZMQ avg {avg_zmq:.1f} node fetches | latency {avg_latency:.2f} ms/passage over {runs} run(s)"
)
if args.plot_path:
try:
import matplotlib.pyplot as plt
labels = [name for name, *_ in scenarios]
values = [
sum(results_ms_per_passage[name]) / len(results_ms_per_passage[name])
if results_ms_per_passage[name]
else 0.0
for name in labels
]
def _auto_cap(vals: list[float]) -> float | None:
s = sorted(vals, reverse=True)
if len(s) < 2:
return None
if s[1] > 0 and s[0] >= 2.5 * s[1]:
return s[1] * 1.1
return None
def _fmt_ms(v: float) -> str:
return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.1f}"
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
if args.broken_y:
s = sorted(values, reverse=True)
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
upper_start = (
args.upper_start_y
if args.upper_start_y is not None
else max(second * 1.2, lower_cap * 1.02)
)
ymax = max(values) * 1.10 if values else 1.0
fig, (ax_top, ax_bottom) = plt.subplots(
2,
1,
sharex=True,
figsize=(7.4, 5.0),
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.05},
)
x = list(range(len(labels)))
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_bottom.set_ylim(0, lower_cap)
ax_top.set_ylim(upper_start, ymax)
for i, v in enumerate(values):
if v <= lower_cap:
ax_bottom.text(
i,
v + lower_cap * 0.02,
_fmt_ms(v),
ha="center",
va="bottom",
fontsize=9,
)
else:
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
ax_top.spines["bottom"].set_visible(False)
ax_bottom.spines["top"].set_visible(False)
ax_top.tick_params(labeltop=False)
ax_bottom.xaxis.tick_bottom()
d = 0.015
kwargs = {"transform": ax_top.transAxes, "color": "k", "clip_on": False}
ax_top.plot((-d, +d), (-d, +d), **kwargs)
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
kwargs.update({"transform": ax_bottom.transAxes})
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
ax_bottom.set_xticks(range(len(labels)))
ax_bottom.set_xticklabels(labels)
ax = ax_bottom
else:
cap = args.cap_y or _auto_cap(values)
plt.figure(figsize=(7.2, 4.2))
ax = plt.gca()
if cap is not None:
show_vals = [min(v, cap) for v in values]
bars = []
for i, (v, show) in enumerate(zip(values, show_vals)):
b = ax.bar(i, show, color=colors[i], width=0.8)
bars.append(b[0])
if v > cap:
bars[-1].set_hatch("//")
ax.text(i, cap * 1.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
else:
ax.text(
i,
show + max(1.0, 0.01 * (cap or show)),
_fmt_ms(v),
ha="center",
va="bottom",
fontsize=9,
)
ax.set_ylim(0, cap * 1.10)
ax.plot(
[0.02 - 0.02, 0.02 + 0.02],
[0.98 + 0.02, 0.98 - 0.02],
transform=ax.transAxes,
color="k",
lw=1,
)
ax.plot(
[0.98 - 0.02, 0.98 + 0.02],
[0.98 + 0.02, 0.98 - 0.02],
transform=ax.transAxes,
color="k",
lw=1,
)
if any(v > cap for v in values):
ax.legend(
[bars[0]], ["capped"], fontsize=8, frameon=False, loc="upper right"
)
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels)
else:
ax.bar(labels, values, color=colors[: len(labels)])
for idx, val in enumerate(values):
ax.text(idx, val + 1.0, f"{val:.1f}", ha="center", va="bottom")
plt.ylabel("Average add latency (ms per passage)")
plt.title(f"Initial passages {args.max_initial}, updates {args.max_updates}")
plt.tight_layout()
plt.savefig(args.plot_path)
print(f"Saved latency bar plot to {args.plot_path}")
# ZMQ time split (Stage A vs B/C)
try:
plt.figure(figsize=(6, 4))
a_vals = [sum(results_stageA[n]) / max(1, len(results_stageA[n])) for n in labels]
bc_vals = [
sum(results_stageBC[n]) / max(1, len(results_stageBC[n])) for n in labels
]
ind = range(len(labels))
plt.bar(ind, a_vals, color="#4e79a7", label="Stage A distance (s)")
plt.bar(
ind, bc_vals, bottom=a_vals, color="#e15759", label="Stage B/C embed-by-id (s)"
)
plt.xticks(list(ind), labels, rotation=10)
plt.ylabel("Server ZMQ time (s)")
plt.title(
f"ZMQ time split (initial {args.max_initial}, updates {args.max_updates})"
)
plt.legend()
out2 = args.plot_path.with_name(
args.plot_path.stem + "_zmq_split" + args.plot_path.suffix
)
plt.tight_layout()
plt.savefig(out2)
print(f"Saved ZMQ time split plot to {out2}")
except Exception as e:
print("Failed to plot ZMQ split:", e)
except ImportError:
print("matplotlib not available; skipping plot generation")
# leave the last build on disk for inspection
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,5 @@
run_id,scenario,cache_enabled,ef_construction,max_initial,max_updates,total_time_s,add_only_s,latency_ms_per_passage,zmq_nodes,stageA_time_s,stageBC_time_s,model_name,embedding_mode,distance_metric
20251024-133101,baseline,1,200,300,1,3.391856,1.120359,1120.359421,126,0.507821,0.601608,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,no_cache_baseline,0,200,300,1,34.941514,32.91376,32913.760185,4033,0.506933,32.159928,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,disable_forward_rng,1,200,300,1,2.746756,0.8202,820.200443,66,0.474354,0.338454,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,disable_forward_and_reverse_rng,1,200,300,1,2.396566,0.521478,521.478415,1,0.508973,0.006938,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
1 run_id scenario cache_enabled ef_construction max_initial max_updates total_time_s add_only_s latency_ms_per_passage zmq_nodes stageA_time_s stageBC_time_s model_name embedding_mode distance_metric
2 20251024-133101 baseline 1 200 300 1 3.391856 1.120359 1120.359421 126 0.507821 0.601608 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
3 20251024-133101 no_cache_baseline 0 200 300 1 34.941514 32.91376 32913.760185 4033 0.506933 32.159928 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
4 20251024-133101 disable_forward_rng 1 200 300 1 2.746756 0.8202 820.200443 66 0.474354 0.338454 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
5 20251024-133101 disable_forward_and_reverse_rng 1 200 300 1 2.396566 0.521478 521.478415 1 0.508973 0.006938 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips

View File

@@ -0,0 +1,704 @@
"""
Compare two latency models for small incremental updates vs. search:
Scenario A (sequential update then search):
- Build initial HNSW (is_recompute=True)
- Start embedding server (ZMQ) for recompute
- Add N passages one-by-one (each triggers recompute over ZMQ)
- Then run a search query on the updated index
- Report total time = sum(add_i) + search_time, with breakdowns
Scenario B (offline embeds + concurrent search; no graph updates):
- Do NOT insert the N passages into the graph
- In parallel: (1) compute embeddings for the N passages; (2) compute query
embedding and run a search on the existing index
- After both finish, compute similarity between the query embedding and the N
new passage embeddings, merge with the index search results by score, and
report time = max(embed_time, search_time) (i.e., no blocking on updates)
This script reuses the model/data loading conventions of
examples/bench_hnsw_rng_recompute.py but focuses on end-to-end latency
comparison for the two execution strategies above.
Example (from the repository root):
uv run -m benchmarks.update.bench_update_vs_offline_search \
--index-path .leann/bench/offline_vs_update.leann \
--max-initial 300 --num-updates 5 --k 10
"""
import argparse
import csv
import json
import logging
import os
import pickle
import sys
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import psutil # type: ignore
from leann.api import LeannBuilder
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
from leann.embedding_compute import compute_embeddings
from leann.embedding_server_manager import EmbeddingServerManager
from leann.registry import register_project_directory
from leann_backend_hnsw import faiss # type: ignore
logger = logging.getLogger(__name__)
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO)
def _find_repo_root() -> Path:
"""Locate project root by walking up until pyproject.toml is found."""
current = Path(__file__).resolve()
for parent in current.parents:
if (parent / "pyproject.toml").exists():
return parent
# Fallback: assume repo is two levels up (../..)
return current.parents[2]
REPO_ROOT = _find_repo_root()
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from apps.chunking import create_text_chunks # noqa: E402
DEFAULT_INITIAL_FILES = [
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
REPO_ROOT / "data" / "huawei_pangu.md",
]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
from llama_index.core import SimpleDirectoryReader
documents = []
for path in paths:
p = path.expanduser().resolve()
if not p.exists():
raise FileNotFoundError(f"Input path not found: {p}")
if p.is_dir():
reader = SimpleDirectoryReader(str(p), recursive=False)
documents.extend(reader.load_data(show_progress=True))
else:
reader = SimpleDirectoryReader(input_files=[str(p)])
documents.extend(reader.load_data(show_progress=True))
if not documents:
return []
chunks = create_text_chunks(
documents,
chunk_size=512,
chunk_overlap=128,
use_ast_chunking=False,
)
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
if limit is not None:
cleaned = cleaned[:limit]
return cleaned
def ensure_index_dir(index_path: Path) -> None:
index_path.parent.mkdir(parents=True, exist_ok=True)
def cleanup_index_files(index_path: Path) -> None:
parent = index_path.parent
if not parent.exists():
return
stem = index_path.stem
for file in parent.glob(f"{stem}*"):
if file.is_file():
file.unlink()
def build_initial_index(
index_path: Path,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
distance_metric: str,
ef_construction: int,
) -> None:
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=True,
distance_metric=distance_metric,
backend_kwargs={
"distance_metric": distance_metric,
"is_compact": False,
"is_recompute": True,
"efConstruction": ef_construction,
},
)
for idx, passage in enumerate(paragraphs):
builder.add_text(passage, metadata={"id": str(idx)})
builder.build_index(str(index_path))
def _maybe_norm_cosine(vecs: np.ndarray, metric: str) -> np.ndarray:
if metric == "cosine":
vecs = np.ascontiguousarray(vecs, dtype=np.float32)
norms = np.linalg.norm(vecs, axis=1, keepdims=True)
norms[norms == 0] = 1
vecs = vecs / norms
return vecs
def _read_index_for_search(index_path: Path) -> Any:
index_file = index_path.parent / f"{index_path.stem}.index"
# Force-disable experimental disk cache when loading the index so that
# incremental benchmarks don't pick up stale top-degree bitmaps.
cfg = faiss.HNSWIndexConfig()
cfg.is_recompute = True
if hasattr(cfg, "disk_cache_ratio"):
cfg.disk_cache_ratio = 0.0
if hasattr(cfg, "external_storage_path"):
cfg.external_storage_path = None
io_flags = getattr(faiss, "IO_FLAG_MMAP", 0)
index = faiss.read_index(str(index_file), io_flags, cfg)
# ensure recompute mode persists after reload
try:
index.is_recompute = True
except AttributeError:
pass
try:
actual_ntotal = index.hnsw.levels.size()
except AttributeError:
actual_ntotal = index.ntotal
if actual_ntotal != index.ntotal:
print(
f"[bench_update_vs_offline_search] Correcting ntotal from {index.ntotal} to {actual_ntotal}",
flush=True,
)
index.ntotal = actual_ntotal
if getattr(index, "storage", None) is None:
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
storage_index = faiss.IndexFlatIP(index.d)
else:
storage_index = faiss.IndexFlatL2(index.d)
index.storage = storage_index
index.own_fields = True
return index
def _append_passages_for_updates(
meta_path: Path,
start_id: int,
texts: list[str],
) -> list[str]:
"""Append update passages so the embedding server can serve recompute fetches."""
if not texts:
return []
index_dir = meta_path.parent
meta_name = meta_path.name
if not meta_name.endswith(".meta.json"):
raise ValueError(f"Unexpected meta filename: {meta_path}")
index_base = meta_name[: -len(".meta.json")]
passages_file = index_dir / f"{index_base}.passages.jsonl"
offsets_file = index_dir / f"{index_base}.passages.idx"
if not passages_file.exists() or not offsets_file.exists():
raise FileNotFoundError(
"Passage store missing; cannot register update passages for recompute mode."
)
with open(offsets_file, "rb") as f:
offset_map: dict[str, int] = pickle.load(f)
assigned_ids: list[str] = []
with open(passages_file, "a", encoding="utf-8") as f:
for i, text in enumerate(texts):
passage_id = str(start_id + i)
offset = f.tell()
json.dump({"id": passage_id, "text": text, "metadata": {}}, f, ensure_ascii=False)
f.write("\n")
offset_map[passage_id] = offset
assigned_ids.append(passage_id)
with open(offsets_file, "wb") as f:
pickle.dump(offset_map, f)
try:
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
except json.JSONDecodeError:
meta = {}
meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
return assigned_ids
def _search(index: Any, q: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
q = np.ascontiguousarray(q, dtype=np.float32)
distances = np.zeros((1, k), dtype=np.float32)
indices = np.zeros((1, k), dtype=np.int64)
index.search(
1,
faiss.swig_ptr(q),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(indices),
)
return distances[0], indices[0]
def _score_for_metric(dist: float, metric: str) -> float:
# Convert FAISS distance to a "higher is better" score
if metric in ("mips", "cosine"):
return float(dist)
# l2 distance (smaller better) -> negative distance as score
return -float(dist)
def _merge_results(
index_results: tuple[np.ndarray, np.ndarray],
offline_scores: list[tuple[int, float]],
k: int,
metric: str,
) -> list[tuple[str, float]]:
distances, indices = index_results
merged: list[tuple[str, float]] = []
for distance, idx in zip(distances.tolist(), indices.tolist()):
merged.append((f"idx:{idx}", _score_for_metric(distance, metric)))
for j, s in offline_scores:
merged.append((f"offline:{j}", s))
merged.sort(key=lambda x: x[1], reverse=True)
return merged[:k]
@dataclass
class ScenarioResult:
name: str
update_total_s: float
search_s: float
overall_s: float
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--index-path",
type=Path,
default=Path(".leann/bench/offline-vs-update.leann"),
)
parser.add_argument(
"--initial-files",
nargs="*",
type=Path,
default=DEFAULT_INITIAL_FILES,
)
parser.add_argument(
"--update-files",
nargs="*",
type=Path,
default=DEFAULT_UPDATE_FILES,
)
parser.add_argument("--max-initial", type=int, default=300)
parser.add_argument("--num-updates", type=int, default=5)
parser.add_argument("--k", type=int, default=10, help="Top-k for search/merge")
parser.add_argument(
"--query",
type=str,
default="neural network",
help="Query text used for the search benchmark.",
)
parser.add_argument("--server-port", type=int, default=5557)
parser.add_argument("--add-timeout", type=int, default=600)
parser.add_argument("--model-name", default="sentence-transformers/all-MiniLM-L6-v2")
parser.add_argument("--embedding-mode", default="sentence-transformers")
parser.add_argument(
"--distance-metric",
default="mips",
choices=["mips", "l2", "cosine"],
)
parser.add_argument("--ef-construction", type=int, default=200)
parser.add_argument(
"--only",
choices=["A", "B", "both"],
default="both",
help="Run only Scenario A, Scenario B, or both",
)
parser.add_argument(
"--csv-path",
type=Path,
default=Path("benchmarks/update/offline_vs_update.csv"),
help="Where to append results (CSV).",
)
args = parser.parse_args()
register_project_directory(REPO_ROOT)
# Load data
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
update_paragraphs = load_chunks_from_files(args.update_files, None)
if not update_paragraphs:
raise ValueError("No update passages loaded from --update-files")
update_paragraphs = update_paragraphs[: args.num_updates]
if len(update_paragraphs) < args.num_updates:
raise ValueError(
f"Not enough update passages ({len(update_paragraphs)}) for --num-updates={args.num_updates}"
)
ensure_index_dir(args.index_path)
cleanup_index_files(args.index_path)
# Build initial index
build_initial_index(
args.index_path,
initial_paragraphs,
args.model_name,
args.embedding_mode,
args.distance_metric,
args.ef_construction,
)
# Prepare index object and meta
meta_path = args.index_path.parent / f"{args.index_path.name}.meta.json"
index = _read_index_for_search(args.index_path)
# CSV setup
run_id = time.strftime("%Y%m%d-%H%M%S")
if args.csv_path:
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
csv_fields = [
"run_id",
"scenario",
"max_initial",
"num_updates",
"k",
"total_time_s",
"add_total_s",
"search_time_s",
"emb_time_s",
"makespan_s",
"model_name",
"embedding_mode",
"distance_metric",
]
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writeheader()
# Debug: list existing HNSW server PIDs before starting
try:
existing = [
p
for p in psutil.process_iter(attrs=["pid", "cmdline"])
if any(
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
for arg in (p.info.get("cmdline") or [])
)
]
if existing:
print("[debug] Found existing hnsw_embedding_server processes before run:")
for p in existing:
print(f"[debug] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}")
except Exception as _e:
pass
add_total = 0.0
search_after_add = 0.0
total_seq = 0.0
port_a = None
if args.only in ("A", "both"):
# Scenario A: sequential update then search
start_id = index.ntotal
assigned_ids = _append_passages_for_updates(meta_path, start_id, update_paragraphs)
if assigned_ids:
logger.debug(
"Registered %d update passages starting at id %s",
len(assigned_ids),
assigned_ids[0],
)
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
ok, port = server_manager.start_server(
port=args.server_port,
model_name=args.model_name,
embedding_mode=args.embedding_mode,
passages_file=str(meta_path),
distance_metric=args.distance_metric,
)
if not ok:
raise RuntimeError("Failed to start embedding server")
try:
# Set ZMQ port for recompute mode
if hasattr(index.hnsw, "set_zmq_port"):
index.hnsw.set_zmq_port(port)
elif hasattr(index, "set_zmq_port"):
index.set_zmq_port(port)
# Start A overall timer BEFORE computing update embeddings
t0 = time.time()
# Compute embeddings for updates (counted into A's overall)
t_emb0 = time.time()
upd_embs = compute_embeddings(
update_paragraphs,
args.model_name,
mode=args.embedding_mode,
is_build=False,
batch_size=16,
)
emb_time_updates = time.time() - t_emb0
upd_embs = np.asarray(upd_embs, dtype=np.float32)
upd_embs = _maybe_norm_cosine(upd_embs, args.distance_metric)
# Perform sequential adds
for i in range(upd_embs.shape[0]):
t_add0 = time.time()
index.add(1, faiss.swig_ptr(upd_embs[i : i + 1]))
add_total += time.time() - t_add0
# Don't persist index after adds to avoid contaminating Scenario B
# index_file = args.index_path.parent / f"{args.index_path.stem}.index"
# faiss.write_index(index, str(index_file))
# Search after updates
q_emb = compute_embeddings(
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
)
q_emb = np.asarray(q_emb, dtype=np.float32)
q_emb = _maybe_norm_cosine(q_emb, args.distance_metric)
# Warm up search with a dummy query first
print("[DEBUG] Warming up search...")
_ = _search(index, q_emb, 1)
t_s0 = time.time()
D_upd, I_upd = _search(index, q_emb, args.k)
search_after_add = time.time() - t_s0
total_seq = time.time() - t0
finally:
server_manager.stop_server()
port_a = port
print("\n=== Scenario A: update->search (sequential) ===")
# emb_time_updates is defined only when A runs
try:
_emb_a = emb_time_updates
except NameError:
_emb_a = 0.0
print(
f"Adds: {args.num_updates} passages; embeds={_emb_a:.3f}s; add_total={add_total:.3f}s; "
f"search={search_after_add:.3f}s; overall={total_seq:.3f}s"
)
# CSV row for A
if args.csv_path:
row_a = {
"run_id": run_id,
"scenario": "A",
"max_initial": args.max_initial,
"num_updates": args.num_updates,
"k": args.k,
"total_time_s": round(total_seq, 6),
"add_total_s": round(add_total, 6),
"search_time_s": round(search_after_add, 6),
"emb_time_s": round(_emb_a, 6),
"makespan_s": 0.0,
"model_name": args.model_name,
"embedding_mode": args.embedding_mode,
"distance_metric": args.distance_metric,
}
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writerow(row_a)
# Verify server cleanup
try:
# short sleep to allow signal handling to finish
time.sleep(0.5)
leftovers = [
p
for p in psutil.process_iter(attrs=["pid", "cmdline"])
if any(
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
for arg in (p.info.get("cmdline") or [])
)
]
if leftovers:
print("[warn] hnsw_embedding_server process(es) still alive after A-stop:")
for p in leftovers:
print(
f"[warn] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}"
)
else:
print("[debug] server cleanup confirmed: no hnsw_embedding_server found")
except Exception:
pass
# Scenario B: offline embeds + concurrent search (no graph updates)
if args.only in ("B", "both"):
# ensure a server is available for recompute search
server_manager_b = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
requested_port = args.server_port if port_a is None else port_a
ok_b, port_b = server_manager_b.start_server(
port=requested_port,
model_name=args.model_name,
embedding_mode=args.embedding_mode,
passages_file=str(meta_path),
distance_metric=args.distance_metric,
)
if not ok_b:
raise RuntimeError("Failed to start embedding server for Scenario B")
# Wait for server to fully initialize
print("[DEBUG] Waiting 2s for embedding server to fully initialize...")
time.sleep(2)
try:
# Read the index first
index_no_update = _read_index_for_search(args.index_path) # unchanged index
# Then configure ZMQ port on the correct index object
if hasattr(index_no_update.hnsw, "set_zmq_port"):
index_no_update.hnsw.set_zmq_port(port_b)
elif hasattr(index_no_update, "set_zmq_port"):
index_no_update.set_zmq_port(port_b)
# Warmup the embedding model before benchmarking (do this for both --only B and --only both)
# This ensures fair comparison as Scenario A has warmed up the model during update embeddings
logger.info("Warming up embedding model for Scenario B...")
_ = compute_embeddings(
["warmup text"], args.model_name, mode=args.embedding_mode, is_build=False
)
# Prepare worker A: compute embeddings for the same N passages
emb_time = 0.0
updates_embs_offline: np.ndarray | None = None
def _worker_emb():
nonlocal emb_time, updates_embs_offline
t = time.time()
updates_embs_offline = compute_embeddings(
update_paragraphs,
args.model_name,
mode=args.embedding_mode,
is_build=False,
batch_size=16,
)
emb_time = time.time() - t
# Pre-compute query embedding and warm up search outside of timed section.
q_vec = compute_embeddings(
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
)
q_vec = np.asarray(q_vec, dtype=np.float32)
q_vec = _maybe_norm_cosine(q_vec, args.distance_metric)
print("[DEBUG B] Warming up search...")
_ = _search(index_no_update, q_vec, 1)
# Worker B: timed search on the warmed index
search_time = 0.0
offline_elapsed = 0.0
index_results: tuple[np.ndarray, np.ndarray] | None = None
def _worker_search():
nonlocal search_time, index_results
t = time.time()
distances, indices = _search(index_no_update, q_vec, args.k)
search_time = time.time() - t
index_results = (distances, indices)
# Run two workers concurrently
t0 = time.time()
th1 = threading.Thread(target=_worker_emb)
th2 = threading.Thread(target=_worker_search)
th1.start()
th2.start()
th1.join()
th2.join()
offline_elapsed = time.time() - t0
# For mixing: compute query vs. offline update similarities (pure client-side)
offline_scores: list[tuple[int, float]] = []
if updates_embs_offline is not None:
upd2 = np.asarray(updates_embs_offline, dtype=np.float32)
upd2 = _maybe_norm_cosine(upd2, args.distance_metric)
# For mips/cosine, score = dot; for l2, score = -||x-y||^2
for j in range(upd2.shape[0]):
if args.distance_metric in ("mips", "cosine"):
s = float(np.dot(q_vec[0], upd2[j]))
else:
diff = q_vec[0] - upd2[j]
s = -float(np.dot(diff, diff))
offline_scores.append((j, s))
merged_topk = (
_merge_results(index_results, offline_scores, args.k, args.distance_metric)
if index_results
else []
)
print("\n=== Scenario B: offline embeds + concurrent search (no add) ===")
print(
f"embeddings({args.num_updates})={emb_time:.3f}s; search={search_time:.3f}s; makespan≈{offline_elapsed:.3f}s (≈max)"
)
if merged_topk:
preview = ", ".join([f"{lab}:{score:.3f}" for lab, score in merged_topk[:5]])
print(f"Merged top-5 preview: {preview}")
# CSV row for B
if args.csv_path:
row_b = {
"run_id": run_id,
"scenario": "B",
"max_initial": args.max_initial,
"num_updates": args.num_updates,
"k": args.k,
"total_time_s": 0.0,
"add_total_s": 0.0,
"search_time_s": round(search_time, 6),
"emb_time_s": round(emb_time, 6),
"makespan_s": round(offline_elapsed, 6),
"model_name": args.model_name,
"embedding_mode": args.embedding_mode,
"distance_metric": args.distance_metric,
}
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writerow(row_b)
finally:
server_manager_b.stop_server()
# Summary
print("\n=== Summary ===")
msg_a = (
f"A: seq-add+search overall={total_seq:.3f}s (adds={add_total:.3f}s, search={search_after_add:.3f}s)"
if args.only in ("A", "both")
else "A: skipped"
)
msg_b = (
f"B: offline+concurrent overall≈{offline_elapsed:.3f}s (emb={emb_time:.3f}s, search={search_time:.3f}s)"
if args.only in ("B", "both")
else "B: skipped"
)
print(msg_a + "\n" + msg_b)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,5 @@
run_id,scenario,max_initial,num_updates,k,total_time_s,add_total_s,search_time_s,emb_time_s,makespan_s,model_name,embedding_mode,distance_metric
20251024-141607,A,300,1,10,3.273957,3.050168,0.097825,0.017339,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-141607,B,300,1,10,0.0,0.0,0.111892,0.007869,0.112635,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251025-160652,A,300,5,10,5.061945,4.805962,0.123271,0.015008,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251025-160652,B,300,5,10,0.0,0.0,0.101809,0.008817,0.102447,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
1 run_id scenario max_initial num_updates k total_time_s add_total_s search_time_s emb_time_s makespan_s model_name embedding_mode distance_metric
2 20251024-141607 A 300 1 10 3.273957 3.050168 0.097825 0.017339 0.0 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
3 20251024-141607 B 300 1 10 0.0 0.0 0.111892 0.007869 0.112635 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
4 20251025-160652 A 300 5 10 5.061945 4.805962 0.123271 0.015008 0.0 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
5 20251025-160652 B 300 5 10 0.0 0.0 0.101809 0.008817 0.102447 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips

View File

@@ -0,0 +1,645 @@
#!/usr/bin/env python3
"""
Plot latency bars from the benchmark CSV produced by
benchmarks/update/bench_hnsw_rng_recompute.py.
If you also provide an offline_vs_update.csv via --csv-right
(from benchmarks/update/bench_update_vs_offline_search.py), this script will
output a side-by-side figure:
- Left: ms/passage bars (four RNG scenarios).
- Right: seconds bars (Scenario A seq add+search vs Scenario B offline+search).
Usage:
uv run python benchmarks/update/plot_bench_results.py \
--csv benchmarks/update/bench_results.csv \
--out benchmarks/update/bench_latency_from_csv.png
The script selects the latest run_id in the CSV and plots four bars for
the default scenarios:
- baseline
- no_cache_baseline
- disable_forward_rng
- disable_forward_and_reverse_rng
If multiple rows exist per scenario for that run_id, the script averages
their latency_ms_per_passage values.
"""
import argparse
import csv
from collections import defaultdict
from pathlib import Path
DEFAULT_SCENARIOS = [
"no_cache_baseline",
"baseline",
"disable_forward_rng",
"disable_forward_and_reverse_rng",
]
SCENARIO_LABELS = {
"baseline": "+ Cache",
"no_cache_baseline": "Naive \n Recompute",
"disable_forward_rng": "+ w/o \n Fwd RNG",
"disable_forward_and_reverse_rng": "+ w/o \n Bwd RNG",
}
# Paper-style colors and hatches for scenarios
SCENARIO_STYLES = {
"no_cache_baseline": {"edgecolor": "dimgrey", "hatch": "/////"},
"baseline": {"edgecolor": "#63B8B6", "hatch": "xxxxx"},
"disable_forward_rng": {"edgecolor": "green", "hatch": "....."},
"disable_forward_and_reverse_rng": {"edgecolor": "tomato", "hatch": "\\\\\\\\\\"},
}
def load_latest_run(csv_path: Path):
rows = []
with csv_path.open("r", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
rows.append(row)
if not rows:
raise SystemExit("CSV is empty: no rows to plot")
# Choose latest run_id lexicographically (YYYYMMDD-HHMMSS)
run_ids = [r.get("run_id", "") for r in rows]
latest = max(run_ids)
latest_rows = [r for r in rows if r.get("run_id", "") == latest]
if not latest_rows:
# Fallback: take last 4 rows
latest_rows = rows[-4:]
latest = latest_rows[-1].get("run_id", "unknown")
return latest, latest_rows
def aggregate_latency(rows):
acc = defaultdict(list)
for r in rows:
sc = r.get("scenario", "")
try:
val = float(r.get("latency_ms_per_passage", "nan"))
except ValueError:
continue
acc[sc].append(val)
avg = {k: (sum(v) / len(v) if v else 0.0) for k, v in acc.items()}
return avg
def _auto_cap(values: list[float]) -> float | None:
if not values:
return None
sorted_vals = sorted(values, reverse=True)
if len(sorted_vals) < 2:
return None
max_v, second = sorted_vals[0], sorted_vals[1]
if second <= 0:
return None
# If the tallest bar dwarfs the second by 2.5x+, cap near the second
if max_v >= 2.5 * second:
return second * 1.1
return None
def _add_break_marker(ax, y, rel_x0=0.02, rel_x1=0.98, size=0.02):
# Draw small diagonal ticks near left/right to signal cap
x0, x1 = rel_x0, rel_x1
ax.plot([x0 - size, x0 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
ax.plot([x1 - size, x1 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
def _fmt_ms(v: float) -> str:
if v >= 1000:
return f"{v / 1000:.1f}k"
return f"{v:.1f}"
def main():
# Set LaTeX style for paper figures (matching paper_fig.py)
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.5
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
ap = argparse.ArgumentParser(description=__doc__)
ap.add_argument(
"--csv",
type=Path,
default=Path("benchmarks/update/bench_results.csv"),
help="Path to results CSV (defaults to bench_results.csv)",
)
ap.add_argument(
"--out",
type=Path,
default=Path("add_ablation.pdf"),
help="Output image path",
)
ap.add_argument(
"--csv-right",
type=Path,
default=Path("benchmarks/update/offline_vs_update.csv"),
help="Optional: offline_vs_update.csv to render right subplot (A vs B)",
)
ap.add_argument(
"--cap-y",
type=float,
default=None,
help="Cap Y-axis at this ms value; bars above are hatched and annotated.",
)
ap.add_argument(
"--no-auto-cap",
action="store_true",
help="Disable auto-cap heuristic when --cap-y is not provided.",
)
ap.add_argument(
"--broken-y",
action="store_true",
default=True,
help="Use a broken Y-axis (two stacked axes with a gap). Overrides --cap-y unless both provided.",
)
ap.add_argument(
"--lower-cap-y",
type=float,
default=None,
help="Lower axes upper bound for broken Y (ms). Default = 1.1x second-highest.",
)
ap.add_argument(
"--upper-start-y",
type=float,
default=None,
help="Upper axes lower bound for broken Y (ms). Default = 1.2x second-highest.",
)
args = ap.parse_args()
latest_run, latest_rows = load_latest_run(args.csv)
avg = aggregate_latency(latest_rows)
try:
import matplotlib.pyplot as plt
except Exception as e:
raise SystemExit(f"matplotlib not available: {e}")
scenarios = DEFAULT_SCENARIOS
values = [avg.get(name, 0.0) for name in scenarios]
labels = [SCENARIO_LABELS.get(name, name) for name in scenarios]
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
# If right CSV is provided, build side-by-side figure
if args.csv_right is not None:
try:
right_rows_all = []
with args.csv_right.open("r", encoding="utf-8") as f:
rreader = csv.DictReader(f)
right_rows_all = list(rreader)
if right_rows_all:
r_latest = max(r.get("run_id", "") for r in right_rows_all)
right_rows = [r for r in right_rows_all if r.get("run_id", "") == r_latest]
else:
r_latest = None
right_rows = []
except Exception:
r_latest = None
right_rows = []
a_total = 0.0
b_makespan = 0.0
for r in right_rows:
sc = (r.get("scenario", "") or "").strip().upper()
if sc == "A":
try:
a_total = float(r.get("total_time_s", 0.0))
except Exception:
pass
elif sc == "B":
try:
b_makespan = float(r.get("makespan_s", 0.0))
except Exception:
pass
import matplotlib.pyplot as plt
from matplotlib import gridspec
# Left subplot (reuse current style, with optional cap)
cap = args.cap_y
if cap is None and not args.no_auto_cap:
cap = _auto_cap(values)
x = list(range(len(labels)))
if args.broken_y:
# Use broken axis for left subplot
# Auto-adjust width ratios: left has 4 bars, right has 2 bars
fig = plt.figure(figsize=(4.8, 1.8)) # Scaled down to 80%
gs = gridspec.GridSpec(
2, 2, height_ratios=[1, 3], width_ratios=[1.5, 1], hspace=0.08, wspace=0.35
)
ax_left_top = fig.add_subplot(gs[0, 0])
ax_left_bottom = fig.add_subplot(gs[1, 0], sharex=ax_left_top)
ax_right = fig.add_subplot(gs[:, 1])
# Determine break points
s = sorted(values, reverse=True)
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
lower_cap = (
args.lower_cap_y if args.lower_cap_y is not None else second * 1.4
) # Increased to show more range
upper_start = (
args.upper_start_y
if args.upper_start_y is not None
else max(second * 1.5, lower_cap * 1.02)
)
ymax = (
max(values) * 1.90 if values else 1.0
) # Increase headroom to 1.90 for text label and tick range
# Draw bars on both axes
ax_left_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_left_top.bar(x, values, color=colors[: len(labels)], width=0.8)
# Set limits
ax_left_bottom.set_ylim(0, lower_cap)
ax_left_top.set_ylim(upper_start, ymax)
# Annotate values (convert ms to s)
values_s = [v / 1000.0 for v in values]
lower_cap_s = lower_cap / 1000.0
upper_start_s = upper_start / 1000.0
ymax_s = ymax / 1000.0
ax_left_bottom.set_ylim(0, lower_cap_s)
ax_left_top.set_ylim(upper_start_s, ymax_s)
# Redraw bars with s values (paper style: white fill + colored edge + hatch)
ax_left_bottom.clear()
ax_left_top.clear()
bar_width = 0.50 # Reduced for wider spacing between bars
for i, (scenario_name, v) in enumerate(zip(scenarios, values_s)):
style = SCENARIO_STYLES.get(scenario_name, {"edgecolor": "black", "hatch": ""})
# Draw in bottom axis for all bars
ax_left_bottom.bar(
i,
v,
width=bar_width,
color="white",
edgecolor=style["edgecolor"],
hatch=style["hatch"],
linewidth=1.2,
)
# Only draw in top axis if the bar is tall enough to reach the upper range
if v > upper_start_s:
ax_left_top.bar(
i,
v,
width=bar_width,
color="white",
edgecolor=style["edgecolor"],
hatch=style["hatch"],
linewidth=1.2,
)
ax_left_bottom.set_ylim(0, lower_cap_s)
ax_left_top.set_ylim(upper_start_s, ymax_s)
for i, v in enumerate(values_s):
if v <= lower_cap_s:
ax_left_bottom.text(
i,
v + lower_cap_s * 0.02,
f"{v:.2f}",
ha="center",
va="bottom",
fontsize=8,
fontweight="bold",
)
else:
ax_left_top.text(
i,
v + (ymax_s - upper_start_s) * 0.02,
f"{v:.2f}",
ha="center",
va="bottom",
fontsize=8,
fontweight="bold",
)
# Hide spines between axes
ax_left_top.spines["bottom"].set_visible(False)
ax_left_bottom.spines["top"].set_visible(False)
ax_left_top.tick_params(
labeltop=False, labelbottom=False, bottom=False
) # Hide tick marks
ax_left_bottom.xaxis.tick_bottom()
ax_left_bottom.tick_params(top=False) # Hide top tick marks
# Draw break marks (matching paper_fig.py style)
d = 0.015
kwargs = {
"transform": ax_left_top.transAxes,
"color": "k",
"clip_on": False,
"linewidth": 0.8,
"zorder": 10,
}
ax_left_top.plot((-d, +d), (-d, +d), **kwargs)
ax_left_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
kwargs.update({"transform": ax_left_bottom.transAxes})
ax_left_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
ax_left_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
ax_left_bottom.set_xticks(x)
ax_left_bottom.set_xticklabels(labels, rotation=0, fontsize=7)
# Don't set ylabel here - will use fig.text for alignment
ax_left_bottom.tick_params(axis="y", labelsize=10)
ax_left_top.tick_params(axis="y", labelsize=10)
# Add subtle grid for better readability
ax_left_bottom.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
ax_left_top.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
ax_left_top.set_title("Single Add Operation", fontsize=11, pad=10, fontweight="bold")
# Set x-axis limits to match bar width with right subplot
ax_left_bottom.set_xlim(-0.6, 3.6)
ax_left_top.set_xlim(-0.6, 3.6)
ax_left = ax_left_bottom # for compatibility
else:
# Regular side-by-side layout
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8.4, 3.15))
if cap is not None:
show_vals = [min(v, cap) for v in values]
bars = ax_left.bar(x, show_vals, color=colors[: len(labels)], width=0.8)
for i, (val, show) in enumerate(zip(values, show_vals)):
if val > cap:
bars[i].set_hatch("//")
ax_left.text(
i, cap * 1.02, _fmt_ms(val), ha="center", va="bottom", fontsize=9
)
else:
ax_left.text(
i,
show + max(1.0, 0.01 * (cap or show)),
_fmt_ms(val),
ha="center",
va="bottom",
fontsize=9,
)
ax_left.set_ylim(0, cap * 1.10)
_add_break_marker(ax_left, y=0.98)
ax_left.set_xticks(x)
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
else:
ax_left.bar(x, values, color=colors[: len(labels)], width=0.8)
for i, v in enumerate(values):
ax_left.text(i, v + 1.0, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
ax_left.set_xticks(x)
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
ax_left.set_ylabel("Latency (ms per passage)")
max_initial = latest_rows[0].get("max_initial", "?")
max_updates = latest_rows[0].get("max_updates", "?")
ax_left.set_title(
f"HNSW RNG (run {latest_run}) | init={max_initial}, upd={max_updates}"
)
# Right subplot (A vs B, seconds) - paper style
r_labels = ["Sequential", "Delayed \n Add+Search"]
r_values = [a_total or 0.0, b_makespan or 0.0]
r_styles = [
{"edgecolor": "#59a14f", "hatch": "xxxxx"},
{"edgecolor": "#edc948", "hatch": "/////"},
]
# 2 bars, centered with proper spacing
xr = [0, 1]
bar_width = 0.50 # Reduced for wider spacing between bars
for i, (v, style) in enumerate(zip(r_values, r_styles)):
ax_right.bar(
xr[i],
v,
width=bar_width,
color="white",
edgecolor=style["edgecolor"],
hatch=style["hatch"],
linewidth=1.2,
)
for i, v in enumerate(r_values):
max_v = max(r_values) if r_values else 1.0
offset = max(0.0002, 0.02 * max_v)
ax_right.text(
xr[i],
v + offset,
f"{v:.2f}",
ha="center",
va="bottom",
fontsize=8,
fontweight="bold",
)
ax_right.set_xticks(xr)
ax_right.set_xticklabels(r_labels, rotation=0, fontsize=7)
# Don't set ylabel here - will use fig.text for alignment
ax_right.tick_params(axis="y", labelsize=10)
# Add subtle grid for better readability
ax_right.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
ax_right.set_title("Batched Add Operation", fontsize=11, pad=10, fontweight="bold")
# Set x-axis limits to match left subplot's bar width visually
# Accounting for width_ratios=[1.5, 1]:
# Left: 4 bars, xlim(-0.6, 3.6), range=4.2, physical_width=1.5*unit
# bar_width_visual = 0.72 * (1.5*unit / 4.2)
# Right: 2 bars, need same visual width
# 0.72 * (1.0*unit / range_right) = 0.72 * (1.5*unit / 4.2)
# range_right = 4.2 / 1.5 = 2.8
# For bars at 0, 1: padding = (2.8 - 1) / 2 = 0.9
ax_right.set_xlim(-0.9, 1.9)
# Set y-axis limit with headroom for text labels
if r_values:
max_v = max(r_values)
ax_right.set_ylim(0, max_v * 1.15)
# Format y-axis to avoid scientific notation
ax_right.ticklabel_format(style="plain", axis="y")
plt.tight_layout()
# Add aligned ylabels using fig.text (after tight_layout)
# Get the vertical center of the entire figure
fig_center_y = 0.5
# Left ylabel - closer to left plot
left_x = 0.05
fig.text(
left_x,
fig_center_y,
"Latency (s)",
va="center",
rotation="vertical",
fontsize=11,
fontweight="bold",
)
# Right ylabel - closer to right plot
right_bbox = ax_right.get_position()
right_x = right_bbox.x0 - 0.07
fig.text(
right_x,
fig_center_y,
"Latency (s)",
va="center",
rotation="vertical",
fontsize=11,
fontweight="bold",
)
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
# Also save PDF for paper
pdf_out = args.out.with_suffix(".pdf")
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
print(f"Saved: {args.out}")
print(f"Saved: {pdf_out}")
return
# Broken-Y mode
if args.broken_y:
import matplotlib.pyplot as plt
fig, (ax_top, ax_bottom) = plt.subplots(
2,
1,
sharex=True,
figsize=(7.5, 6.75),
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.08},
)
# Determine default breaks from second-highest
s = sorted(values, reverse=True)
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
upper_start = (
args.upper_start_y
if args.upper_start_y is not None
else max(second * 1.2, lower_cap * 1.02)
)
ymax = max(values) * 1.10 if values else 1.0
x = list(range(len(labels)))
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
# Limits
ax_bottom.set_ylim(0, lower_cap)
ax_top.set_ylim(upper_start, ymax)
# Annotate values
for i, v in enumerate(values):
if v <= lower_cap:
ax_bottom.text(
i, v + lower_cap * 0.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9
)
else:
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
# Hide spines between axes and draw diagonal break marks
ax_top.spines["bottom"].set_visible(False)
ax_bottom.spines["top"].set_visible(False)
ax_top.tick_params(labeltop=False) # don't put tick labels at the top
ax_bottom.xaxis.tick_bottom()
# Diagonal lines at the break (matching paper_fig.py style)
d = 0.015
kwargs = {
"transform": ax_top.transAxes,
"color": "k",
"clip_on": False,
"linewidth": 0.8,
"zorder": 10,
}
ax_top.plot((-d, +d), (-d, +d), **kwargs) # top-left diagonal
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs) # top-right diagonal
kwargs.update({"transform": ax_bottom.transAxes})
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs) # bottom-left diagonal
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs) # bottom-right diagonal
ax_bottom.set_xticks(x)
ax_bottom.set_xticklabels(labels, rotation=0, fontsize=10)
ax = ax_bottom # for labeling below
else:
cap = args.cap_y
if cap is None and not args.no_auto_cap:
cap = _auto_cap(values)
plt.figure(figsize=(5.4, 3.15))
ax = plt.gca()
if cap is not None:
show_vals = [min(v, cap) for v in values]
bars = []
for i, (_label, val, show) in enumerate(zip(labels, values, show_vals)):
bar = ax.bar(i, show, color=colors[i], width=0.8)
bars.append(bar[0])
# Hatch and annotate when capped
if val > cap:
bars[-1].set_hatch("//")
ax.text(i, cap * 1.02, f"{_fmt_ms(val)}", ha="center", va="bottom", fontsize=9)
else:
ax.text(
i,
show + max(1.0, 0.01 * (cap or show)),
f"{_fmt_ms(val)}",
ha="center",
va="bottom",
fontsize=9,
)
ax.set_ylim(0, cap * 1.10)
_add_break_marker(ax, y=0.98)
ax.legend([bars[1]], ["capped"], fontsize=8, frameon=False, loc="upper right") if any(
v > cap for v in values
) else None
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
else:
ax.bar(labels, values, color=colors[: len(labels)])
for idx, val in enumerate(values):
ax.text(
idx,
val + 1.0,
f"{_fmt_ms(val)}",
ha="center",
va="bottom",
fontsize=10,
fontweight="bold",
)
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
# Try to extract some context for title
max_initial = latest_rows[0].get("max_initial", "?")
max_updates = latest_rows[0].get("max_updates", "?")
if args.broken_y:
fig.text(
0.02,
0.5,
"Latency (s)",
va="center",
rotation="vertical",
fontsize=11,
fontweight="bold",
)
fig.suptitle(
"Add Operation Latency",
fontsize=11,
y=0.98,
fontweight="bold",
)
plt.tight_layout(rect=(0.03, 0.04, 1, 0.96))
else:
plt.ylabel("Latency (s)", fontsize=11, fontweight="bold")
plt.title("Add Operation Latency", fontsize=11, fontweight="bold")
plt.tight_layout()
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
# Also save PDF for paper
pdf_out = args.out.with_suffix(".pdf")
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
print(f"Saved: {args.out}")
print(f"Saved: {pdf_out}")
if __name__ == "__main__":
main()

View File

@@ -455,5 +455,5 @@ Conclusion:
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
- [DiskANN Original Paper](https://suhasjs.github.io/files/diskann_neurips19.pdf)
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)

395
docs/slack-setup-guide.md Normal file
View File

@@ -0,0 +1,395 @@
# Slack Integration Setup Guide
This guide provides step-by-step instructions for setting up Slack integration with LEANN.
## Overview
LEANN's Slack integration uses MCP (Model Context Protocol) servers to fetch and index your Slack messages for RAG (Retrieval-Augmented Generation). This allows you to search through your Slack conversations using natural language queries.
## Prerequisites
1. **Slack Workspace Access**: You need admin or owner permissions in your Slack workspace to create apps and configure OAuth tokens.
2. **Slack MCP Server**: Install a Slack MCP server (e.g., `slack-mcp-server` via npm)
3. **LEANN**: Ensure you have LEANN installed and working
## Step 1: Create a Slack App
### 1.1 Go to Slack API Dashboard
1. Visit [https://api.slack.com/apps](https://api.slack.com/apps)
2. Click **"Create New App"**
3. Choose **"From scratch"**
4. Enter your app name (e.g., "LEANN Slack Integration")
5. Select your workspace
6. Click **"Create App"**
### 1.2 Configure App Permissions
#### Token Scopes
1. In your app dashboard, go to **"OAuth & Permissions"** in the left sidebar
2. Scroll down to **"Scopes"** section
3. Under **"Bot Token Scopes & OAuth Scope"**, click **"Add an OAuth Scope"**
4. Add the following scopes:
- `channels:read` - Read public channel information
- `channels:history` - Read messages in public channels
- `groups:read` - Read private channel information
- `groups:history` - Read messages in private channels
- `im:read` - Read direct message information
- `im:history` - Read direct messages
- `mpim:read` - Read group direct message information
- `mpim:history` - Read group direct messages
- `users:read` - Read user information
- `team:read` - Read workspace information
#### App-Level Tokens (Optional)
Some MCP servers may require app-level tokens:
1. Go to **"Basic Information"** in the left sidebar
2. Scroll down to **"App-Level Tokens"**
3. Click **"Generate Token and Scopes"**
4. Enter a name (e.g., "LEANN Integration")
5. Add the `connections:write` scope
6. Click **"Generate"**
7. Copy the token (starts with `xapp-`)
### 1.3 Install App to Workspace
1. Go to **"OAuth & Permissions"** in the left sidebar
2. Click **"Install to Workspace"**
3. Review the permissions and click **"Allow"**
4. Copy the **"Bot User OAuth Token"** (starts with `xoxb-`)
5. Copy the **"User OAuth Token"** (starts with `xoxp-`)
## Step 2: Install Slack MCP Server
### Option A: Using npm (Recommended)
```bash
# Install globally
npm install -g slack-mcp-server
# Or install locally
npm install slack-mcp-server
```
### Option B: Using npx (No installation required)
```bash
# Use directly without installation
npx slack-mcp-server
```
## Step 3: Install and Configure Ollama (for Real LLM Responses)
### 3.1 Install Ollama
```bash
# Install Ollama using Homebrew (macOS)
brew install ollama
# Or download from https://ollama.ai/
```
### 3.2 Start Ollama Service
```bash
# Start Ollama as a service
brew services start ollama
# Or start manually
ollama serve
```
### 3.3 Pull a Model
```bash
# Pull a lightweight model for testing
ollama pull llama3.2:1b
# Verify the model is available
ollama list
```
## Step 4: Configure Environment Variables
Create a `.env` file or set environment variables:
```bash
# Required: User OAuth Token
SLACK_OAUTH_TOKEN=xoxp-your-user-oauth-token-here
# Optional: App-Level Token (if your MCP server requires it)
SLACK_APP_TOKEN=xapp-your-app-token-here
# Optional: Workspace-specific settings
SLACK_WORKSPACE_ID=T1234567890 # Your workspace ID (optional)
```
## Step 5: Test the Setup
### 5.1 Test MCP Server Connection
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--test-connection \
--workspace-name "Your Workspace Name"
```
This will test the connection and list available tools without indexing any data.
### 5.2 Index a Specific Channel
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "Your Workspace Name" \
--channels general \
--query "What did we discuss about the project?"
```
### 5.3 Real RAG Query Examples
This section demonstrates successful Slack RAG integration queries against the Sky Lab Computing workspace's "random" channel. The system successfully retrieves actual conversation messages and performs semantic search with high relevance scores, including finding specific research paper announcements and technical discussions.
### Example 1: Advisor Models Query
**Query:** "train black-box models to adopt to your personal data"
This query demonstrates the system's ability to find specific research announcements about training black-box models for personal data adaptation.
![Advisor Models Query - Command Setup](videos/slack_integration_1.1.png)
![Advisor Models Query - Search Results](videos/slack_integration_1.2.png)
![Advisor Models Query - LLM Response](videos/slack_integration_1.3.png)
### Example 2: Barbarians at the Gate Query
**Query:** "AI-driven research systems ADRS"
This query demonstrates the system's ability to find specific research announcements about AI-driven research systems and algorithm discovery.
![Barbarians Query - Command Setup](videos/slack_integration_2.1.png)
![Barbarians Query - Search Results](videos/slack_integration_2.2.png)
![Barbarians Query - LLM Response](videos/slack_integration_2.3.png)
### Prerequisites
- Bot is installed in the Sky Lab Computing workspace and invited to the target channel (run `/invite @YourBotName` in the channel if needed)
- Bot token available and exported in the same terminal session
### Commands
1) Set the workspace token for this shell
```bash
export SLACK_MCP_XOXP_TOKEN="xoxp-***-redacted-***"
```
2) Run queries against the "random" channel by channel ID (C0GN5BX0F)
**Advisor Models Query:**
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "Sky Lab Computing" \
--channels C0GN5BX0F \
--max-messages-per-channel 100000 \
--query "train black-box models to adopt to your personal data" \
--llm ollama \
--llm-model "llama3.2:1b" \
--llm-host "http://localhost:11434" \
--no-concatenate-conversations
```
**Barbarians at the Gate Query:**
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "Sky Lab Computing" \
--channels C0GN5BX0F \
--max-messages-per-channel 100000 \
--query "AI-driven research systems ADRS" \
--llm ollama \
--llm-model "llama3.2:1b" \
--llm-host "http://localhost:11434" \
--no-concatenate-conversations
```
These examples demonstrate the system's ability to find and retrieve specific research announcements and technical discussions from the conversation history, showcasing the power of semantic search in Slack data.
3) Optional: Ask a broader question
```bash
python test_channel_by_id_or_name.py \
--channel-id C0GN5BX0F \
--workspace-name "Sky Lab Computing" \
--query "What is LEANN about?"
```
Notes:
- If you see `not_in_channel`, invite the bot to the channel and re-run.
- If you see `channel_not_found`, confirm the channel ID and workspace.
- Deep search via server-side “search” tools may require additional Slack scopes; the example above performs client-side filtering over retrieved history.
## Common Issues and Solutions
### Issue 1: "users cache is not ready yet" Error
**Problem**: You see this warning:
```
WARNING - Failed to fetch messages from channel random: Failed to fetch messages: {'code': -32603, 'message': 'users cache is not ready yet, sync process is still running... please wait'}
```
**Solution**: This is a common timing issue. The LEANN integration now includes automatic retry logic:
1. **Wait and Retry**: The system will automatically retry with exponential backoff (2s, 4s, 8s, etc.)
2. **Increase Retry Parameters**: If needed, you can customize retry behavior:
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--max-retries 10 \
--retry-delay 3.0 \
--channels general \
--query "Your query here"
```
3. **Keep MCP Server Running**: Start the MCP server separately and keep it running:
```bash
# Terminal 1: Start MCP server
slack-mcp-server
# Terminal 2: Run LEANN (it will connect to the running server)
python -m apps.slack_rag --mcp-server "slack-mcp-server" --channels general --query "test"
```
### Issue 2: "No message fetching tool found"
**Problem**: The MCP server doesn't have the expected tools.
**Solution**:
1. Check if your MCP server is properly installed and configured
2. Verify your Slack tokens are correct
3. Try a different MCP server implementation
4. Check the MCP server documentation for required configuration
### Issue 3: Permission Denied Errors
**Problem**: You get permission errors when trying to access channels.
**Solutions**:
1. **Check Bot Permissions**: Ensure your bot has been added to the channels you want to access
2. **Verify Token Scopes**: Make sure you have all required scopes configured
3. **Channel Access**: For private channels, the bot needs to be explicitly invited
4. **Workspace Permissions**: Ensure your Slack app has the necessary workspace permissions
### Issue 4: Empty Results
**Problem**: No messages are returned even though the channel has messages.
**Solutions**:
1. **Check Channel Names**: Ensure channel names are correct (without the # symbol)
2. **Verify Bot Access**: Make sure the bot can access the channels
3. **Check Date Ranges**: Some MCP servers have limitations on message history
4. **Increase Message Limits**: Try increasing the message limit:
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--channels general \
--max-messages-per-channel 1000 \
--query "test"
```
## Advanced Configuration
### Custom MCP Server Commands
If you need to pass additional parameters to your MCP server:
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server --token-file /path/to/tokens.json" \
--workspace-name "Your Workspace" \
--channels general \
--query "Your query"
```
### Multiple Workspaces
To work with multiple Slack workspaces, you can:
1. Create separate apps for each workspace
2. Use different environment variables
3. Run separate instances with different configurations
### Performance Optimization
For better performance with large workspaces:
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "Your Workspace" \
--max-messages-per-channel 500 \
--no-concatenate-conversations \
--query "Your query"
```
---
## Troubleshooting Checklist
- [ ] Slack app created with proper permissions
- [ ] Bot token (xoxb-) copied correctly
- [ ] App-level token (xapp-) created if needed
- [ ] MCP server installed and accessible
- [ ] Ollama installed and running (`brew services start ollama`)
- [ ] Ollama model pulled (`ollama pull llama3.2:1b`)
- [ ] Environment variables set correctly
- [ ] Bot invited to relevant channels
- [ ] Channel names specified without # symbol
- [ ] Sufficient retry attempts configured
- [ ] Network connectivity to Slack APIs
## Getting Help
If you continue to have issues:
1. **Check Logs**: Look for detailed error messages in the console output
2. **Test MCP Server**: Use `--test-connection` to verify the MCP server is working
3. **Verify Tokens**: Double-check that your Slack tokens are valid and have the right scopes
4. **Check Ollama**: Ensure Ollama is running (`ollama serve`) and the model is available (`ollama list`)
5. **Community Support**: Reach out to the LEANN community for help
## Example Commands
### Basic Usage
```bash
# Test connection
python -m apps.slack_rag --mcp-server "slack-mcp-server" --test-connection
# Index specific channels
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "My Company" \
--channels general random \
--query "What did we decide about the project timeline?"
```
### Advanced Usage
```bash
# With custom retry settings
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "My Company" \
--channels general \
--max-retries 10 \
--retry-delay 5.0 \
--max-messages-per-channel 2000 \
--query "Show me all decisions made in the last month"
```

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 445 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 508 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 437 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 474 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 501 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 454 KiB

View File

@@ -43,7 +43,11 @@ from apps.chunking import create_text_chunks
REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_QUERY = "What's LEANN?"
DEFAULT_INITIAL_FILES = [REPO_ROOT / "data" / "2501.14312v1 (1).pdf"]
DEFAULT_INITIAL_FILES = [
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
REPO_ROOT / "data" / "huawei_pangu.md",
REPO_ROOT / "data" / "PrideandPrejudice.txt",
]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
@@ -182,6 +186,7 @@ def run_workflow(
is_recompute: bool,
query: str,
top_k: int,
skip_search: bool,
) -> dict[str, Any]:
prefix = f"[{label}] " if label else ""
@@ -198,12 +203,15 @@ def run_workflow(
)
initial_size = index_file_size(index_path)
before_results = run_search(
index_path,
query,
top_k,
recompute_embeddings=is_recompute,
)
if not skip_search:
before_results = run_search(
index_path,
query,
top_k,
recompute_embeddings=is_recompute,
)
else:
before_results = None
print(f"\n{prefix}Updating index with additional passages...")
update_index(
@@ -215,20 +223,23 @@ def run_workflow(
is_recompute=is_recompute,
)
after_results = run_search(
index_path,
query,
top_k,
recompute_embeddings=is_recompute,
)
if not skip_search:
after_results = run_search(
index_path,
query,
top_k,
recompute_embeddings=is_recompute,
)
else:
after_results = None
updated_size = index_file_size(index_path)
return {
"initial_size": initial_size,
"updated_size": updated_size,
"delta": updated_size - initial_size,
"before_results": before_results,
"after_results": after_results,
"before_results": before_results if not skip_search else None,
"after_results": after_results if not skip_search else None,
"metadata": load_metadata_snapshot(index_path),
}
@@ -314,6 +325,12 @@ def main() -> None:
action="store_false",
help="Skip building the no-recompute baseline.",
)
parser.add_argument(
"--skip-search",
dest="skip_search",
action="store_true",
help="Skip the search step.",
)
parser.set_defaults(compare_no_recompute=True)
args = parser.parse_args()
@@ -350,10 +367,13 @@ def main() -> None:
is_recompute=True,
query=args.query,
top_k=args.top_k,
skip_search=args.skip_search,
)
print_results("initial search", recompute_stats["before_results"])
print_results("after update", recompute_stats["after_results"])
if not args.skip_search:
print_results("initial search", recompute_stats["before_results"])
if not args.skip_search:
print_results("after update", recompute_stats["after_results"])
print(
f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
f"{recompute_stats['delta']})"
@@ -378,6 +398,7 @@ def main() -> None:
is_recompute=False,
query=args.query,
top_k=args.top_k,
skip_search=args.skip_search,
)
print(
@@ -385,8 +406,12 @@ def main() -> None:
f"{baseline_stats['delta']})"
)
after_texts = [res.text for res in recompute_stats["after_results"]]
baseline_after_texts = [res.text for res in baseline_stats["after_results"]]
after_texts = (
[res.text for res in recompute_stats["after_results"]] if not args.skip_search else None
)
baseline_after_texts = (
[res.text for res in baseline_stats["after_results"]] if not args.skip_search else None
)
if after_texts == baseline_after_texts:
print(
"[no-recompute] Search results match recompute baseline; see above for the shared output."

View File

@@ -0,0 +1,178 @@
#!/usr/bin/env python3
"""
MCP Integration Examples for LEANN
This script demonstrates how to use LEANN with different MCP servers for
RAG on various platforms like Slack and Twitter.
Examples:
1. Slack message RAG via MCP
2. Twitter bookmark RAG via MCP
3. Testing MCP server connections
"""
import asyncio
import sys
from pathlib import Path
# Add the parent directory to the path so we can import from apps
sys.path.append(str(Path(__file__).parent.parent))
async def demo_slack_mcp():
"""Demonstrate Slack MCP integration."""
print("=" * 60)
print("🔥 Slack MCP RAG Demo")
print("=" * 60)
print("\n1. Testing Slack MCP server connection...")
# This would typically use a real MCP server command
# For demo purposes, we show what the command would look like
# slack_app = SlackMCPRAG() # Would be used for actual testing
# Simulate command line arguments for testing
class MockArgs:
mcp_server = "slack-mcp-server" # This would be the actual MCP server command
workspace_name = "my-workspace"
channels = ["general", "random", "dev-team"]
no_concatenate_conversations = False
max_messages_per_channel = 50
test_connection = True
print(f"MCP Server Command: {MockArgs.mcp_server}")
print(f"Workspace: {MockArgs.workspace_name}")
print(f"Channels: {', '.join(MockArgs.channels)}")
# In a real scenario, you would run:
# success = await slack_app.test_mcp_connection(MockArgs)
print("\n📝 Example usage:")
print("python -m apps.slack_rag \\")
print(" --mcp-server 'slack-mcp-server' \\")
print(" --workspace-name 'my-team' \\")
print(" --channels general dev-team \\")
print(" --test-connection")
print("\n🔍 After indexing, you could query:")
print("- 'What did the team discuss about the project deadline?'")
print("- 'Find messages about the new feature launch'")
print("- 'Show me conversations about budget planning'")
async def demo_twitter_mcp():
"""Demonstrate Twitter MCP integration."""
print("\n" + "=" * 60)
print("🐦 Twitter MCP RAG Demo")
print("=" * 60)
print("\n1. Testing Twitter MCP server connection...")
# twitter_app = TwitterMCPRAG() # Would be used for actual testing
class MockArgs:
mcp_server = "twitter-mcp-server"
username = None # Fetch all bookmarks
max_bookmarks = 500
no_tweet_content = False
no_metadata = False
test_connection = True
print(f"MCP Server Command: {MockArgs.mcp_server}")
print(f"Max Bookmarks: {MockArgs.max_bookmarks}")
print(f"Include Content: {not MockArgs.no_tweet_content}")
print(f"Include Metadata: {not MockArgs.no_metadata}")
print("\n📝 Example usage:")
print("python -m apps.twitter_rag \\")
print(" --mcp-server 'twitter-mcp-server' \\")
print(" --max-bookmarks 1000 \\")
print(" --test-connection")
print("\n🔍 After indexing, you could query:")
print("- 'What AI articles did I bookmark last month?'")
print("- 'Find tweets about machine learning techniques'")
print("- 'Show me bookmarked threads about startup advice'")
async def show_mcp_server_setup():
"""Show how to set up MCP servers."""
print("\n" + "=" * 60)
print("⚙️ MCP Server Setup Guide")
print("=" * 60)
print("\n🔧 Setting up Slack MCP Server:")
print("1. Install a Slack MCP server (example commands):")
print(" npm install -g slack-mcp-server")
print(" # OR")
print(" pip install slack-mcp-server")
print("\n2. Configure Slack credentials:")
print(" export SLACK_BOT_TOKEN='xoxb-your-bot-token'")
print(" export SLACK_APP_TOKEN='xapp-your-app-token'")
print("\n3. Test the server:")
print(" slack-mcp-server --help")
print("\n🔧 Setting up Twitter MCP Server:")
print("1. Install a Twitter MCP server:")
print(" npm install -g twitter-mcp-server")
print(" # OR")
print(" pip install twitter-mcp-server")
print("\n2. Configure Twitter API credentials:")
print(" export TWITTER_API_KEY='your-api-key'")
print(" export TWITTER_API_SECRET='your-api-secret'")
print(" export TWITTER_ACCESS_TOKEN='your-access-token'")
print(" export TWITTER_ACCESS_TOKEN_SECRET='your-access-token-secret'")
print("\n3. Test the server:")
print(" twitter-mcp-server --help")
async def show_integration_benefits():
"""Show the benefits of MCP integration."""
print("\n" + "=" * 60)
print("🌟 Benefits of MCP Integration")
print("=" * 60)
benefits = [
("🔄 Live Data Access", "Fetch real-time data from platforms without manual exports"),
("🔌 Standardized Protocol", "Use any MCP-compatible server with minimal code changes"),
("🚀 Easy Extension", "Add new platforms by implementing MCP readers"),
("🔒 Secure Access", "MCP servers handle authentication and API management"),
("📊 Rich Metadata", "Access full platform metadata (timestamps, engagement, etc.)"),
("⚡ Efficient Processing", "Stream data directly into LEANN without intermediate files"),
]
for title, description in benefits:
print(f"\n{title}")
print(f" {description}")
async def main():
"""Main demo function."""
print("🎯 LEANN MCP Integration Examples")
print("This demo shows how to integrate LEANN with MCP servers for various platforms.")
await demo_slack_mcp()
await demo_twitter_mcp()
await show_mcp_server_setup()
await show_integration_benefits()
print("\n" + "=" * 60)
print("✨ Next Steps")
print("=" * 60)
print("1. Install and configure MCP servers for your platforms")
print("2. Test connections using --test-connection flag")
print("3. Run indexing to build your RAG knowledge base")
print("4. Start querying your personal data!")
print("\n📚 For more information:")
print("- Check the README for detailed setup instructions")
print("- Look at the apps/slack_rag.py and apps/twitter_rag.py for implementation details")
print("- Explore other MCP servers for additional platforms")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-diskann"
version = "0.3.4"
dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
version = "0.3.5"
dependencies = ["leann-core==0.3.5", "numpy", "protobuf>=3.19.0"]
[tool.scikit-build]
# Key: simplified CMake path

View File

@@ -29,12 +29,25 @@ if(APPLE)
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
endif()
# Use system ZeroMQ instead of building from source
# Find ZMQ using pkg-config with IMPORTED_TARGET for automatic target creation
find_package(PkgConfig REQUIRED)
pkg_check_modules(ZMQ REQUIRED libzmq)
# On ARM64 macOS, ensure pkg-config finds ARM64 Homebrew packages first
if(APPLE AND CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
set(ENV{PKG_CONFIG_PATH} "/opt/homebrew/lib/pkgconfig:/opt/homebrew/share/pkgconfig:$ENV{PKG_CONFIG_PATH}")
endif()
pkg_check_modules(ZMQ REQUIRED IMPORTED_TARGET libzmq)
# This creates PkgConfig::ZMQ target automatically with correct properties
if(TARGET PkgConfig::ZMQ)
message(STATUS "Found and configured ZMQ target: PkgConfig::ZMQ")
else()
message(FATAL_ERROR "pkg_check_modules did not create IMPORTED target for ZMQ.")
endif()
# Add cppzmq headers
include_directories(third_party/cppzmq)
include_directories(SYSTEM third_party/cppzmq)
# Configure msgpack-c - disable boost dependency
set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)

View File

@@ -215,6 +215,8 @@ class HNSWSearcher(BaseSearcher):
if recompute_embeddings:
if zmq_port is None:
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
if hasattr(self._index, "set_zmq_port"):
self._index.set_zmq_port(zmq_port)
if query.dtype != np.float32:
query = query.astype(np.float32)

View File

@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-hnsw"
version = "0.3.4"
version = "0.3.5"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = [
"leann-core==0.3.4",
"leann-core==0.3.5",
"numpy",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann-core"
version = "0.3.4"
version = "0.3.5"
description = "Core API and plugin system for LEANN"
readme = "README.md"
requires-python = ">=3.9"
@@ -18,14 +18,16 @@ dependencies = [
"pyzmq>=23.0.0",
"msgpack>=1.0.0",
"torch>=2.0.0",
"sentence-transformers>=2.2.0",
"sentence-transformers>=3.0.0",
"llama-index-core>=0.12.0",
"llama-index-readers-file>=0.4.0", # Essential for document reading
"llama-index-embeddings-huggingface>=0.5.5", # For embeddings
"python-dotenv>=1.0.0",
"openai>=1.0.0",
"huggingface-hub>=0.20.0",
"transformers>=4.30.0",
# Keep transformers below 4.46: 4.46.0 adds Python 3.10-only return type syntax and
# breaks Python 3.9 environments.
"transformers>=4.30.0,<4.46",
"requests>=2.25.0",
"accelerate>=0.20.0",
"PyPDF2>=3.0.0",
@@ -40,7 +42,7 @@ dependencies = [
[project.optional-dependencies]
colab = [
"torch>=2.0.0,<3.0.0", # Limit torch version to avoid conflicts
"transformers>=4.30.0,<5.0.0", # Limit transformers version
"transformers>=4.30.0,<4.46", # 4.46.0 switches to PEP 604 typing (int | None), breaks Py3.9
"accelerate>=0.20.0,<1.0.0", # Limit accelerate version
]

View File

@@ -5,6 +5,7 @@ with the correct, original embedding logic from the user's reference code.
import json
import logging
import os
import pickle
import re
import subprocess
@@ -17,9 +18,11 @@ from typing import Any, Literal, Optional, Union
import numpy as np
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
from leann.interactive_utils import create_api_session
from leann.interface import LeannBackendSearcherInterface
from .chat import get_llm
from .embedding_server_manager import EmbeddingServerManager
from .interface import LeannBackendFactoryInterface
from .metadata_filter import MetadataFilterEngine
from .registry import BACKEND_REGISTRY
@@ -728,6 +731,7 @@ class LeannBuilder:
index = faiss.read_index(str(index_file))
if hasattr(index, "is_recompute"):
index.is_recompute = needs_recompute
print(f"index.is_recompute: {index.is_recompute}")
if getattr(index, "storage", None) is None:
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
storage_index = faiss.IndexFlatIP(index.d)
@@ -735,37 +739,112 @@ class LeannBuilder:
storage_index = faiss.IndexFlatL2(index.d)
index.storage = storage_index
index.own_fields = True
# Faiss expects storage.ntotal to reflect the existing graph's
# population (even if the vectors themselves were pruned from disk
# for recompute mode). When we attach a fresh IndexFlat here its
# ntotal starts at zero, which later causes IndexHNSW::add to
# believe new "preset" levels were provided and trips the
# `n0 + n == levels.size()` assertion. Seed the temporary storage
# with the current ntotal so Faiss maintains the proper offset for
# incoming vectors.
try:
storage_index.ntotal = index.ntotal
except AttributeError:
# Older Faiss builds may not expose ntotal as a writable
# attribute; in that case we fall back to the default behaviour.
pass
if index.d != embedding_dim:
raise ValueError(
f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
)
passage_meta_mode = meta.get("embedding_mode", self.embedding_mode)
passage_provider_options = meta.get("embedding_options", self.embedding_options)
base_id = index.ntotal
for offset, chunk in enumerate(valid_chunks):
new_id = str(base_id + offset)
chunk.setdefault("metadata", {})["id"] = new_id
chunk["id"] = new_id
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
faiss.write_index(index, str(index_file))
# Append passages/offsets before we attempt index.add so the ZMQ server
# can resolve newly assigned IDs during recompute. Keep rollback hooks
# so we can restore files if the update fails mid-way.
rollback_passages_size = passages_file.stat().st_size if passages_file.exists() else 0
offset_map_backup = offset_map.copy()
with open(passages_file, "a", encoding="utf-8") as f:
for chunk in valid_chunks:
offset = f.tell()
json.dump(
{
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk.get("metadata", {}),
},
f,
ensure_ascii=False,
)
f.write("\n")
offset_map[chunk["id"]] = offset
try:
with open(passages_file, "a", encoding="utf-8") as f:
for chunk in valid_chunks:
offset = f.tell()
json.dump(
{
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk.get("metadata", {}),
},
f,
ensure_ascii=False,
)
f.write("\n")
offset_map[chunk["id"]] = offset
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
server_manager: Optional[EmbeddingServerManager] = None
server_started = False
requested_zmq_port = int(os.getenv("LEANN_UPDATE_ZMQ_PORT", "5557"))
try:
if needs_recompute:
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
server_started, actual_port = server_manager.start_server(
port=requested_zmq_port,
model_name=self.embedding_model,
embedding_mode=passage_meta_mode,
passages_file=str(meta_path),
distance_metric=distance_metric,
provider_options=passage_provider_options,
)
if not server_started:
raise RuntimeError(
"Failed to start HNSW embedding server for recompute update."
)
if actual_port != requested_zmq_port:
logger.warning(
"Embedding server started on port %s instead of requested %s. "
"Using reassigned port.",
actual_port,
requested_zmq_port,
)
if hasattr(index.hnsw, "set_zmq_port"):
index.hnsw.set_zmq_port(actual_port)
elif hasattr(index, "set_zmq_port"):
index.set_zmq_port(actual_port)
if needs_recompute:
for i in range(embeddings.shape[0]):
print(f"add {i} embeddings")
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
else:
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
faiss.write_index(index, str(index_file))
finally:
if server_started and server_manager is not None:
server_manager.stop_server()
except Exception:
# Roll back appended passages/offset map to keep files consistent.
if passages_file.exists():
with open(passages_file, "rb+") as f:
f.truncate(rollback_passages_size)
offset_map = offset_map_backup
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
raise
meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f:
@@ -1157,6 +1236,17 @@ class LeannChat:
"Please provide the best answer you can based on this context and your knowledge."
)
print("The context provided to the LLM is:")
print(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
print("-" * 150)
for r in results:
chunk_relevance = f"{r.score:.3f}"
chunk_id = r.id
chunk_content = r.text[:60]
chunk_source = r.metadata.get("source", "")[:80]
print(
f"{chunk_relevance:<10} | {chunk_id:<10} | {chunk_content:<60} | {chunk_source:<80}"
)
ask_time = time.time()
ans = self.llm.ask(prompt, **llm_kwargs)
ask_time = time.time() - ask_time
@@ -1164,19 +1254,14 @@ class LeannChat:
return ans
def start_interactive(self):
print("\nLeann Chat started (type 'quit' to exit)")
while True:
try:
user_input = input("You: ").strip()
if user_input.lower() in ["quit", "exit"]:
break
if not user_input:
continue
response = self.ask(user_input)
print(f"Leann: {response}")
except (KeyboardInterrupt, EOFError):
print("\nGoodbye!")
break
"""Start interactive chat session."""
session = create_api_session()
def handle_query(user_input: str):
response = self.ask(user_input)
print(f"Leann: {response}")
session.run_interactive_loop(handle_query)
def cleanup(self):
"""Explicitly cleanup embedding server resources.

View File

@@ -546,11 +546,30 @@ class OllamaChat(LLMInterface):
class HFChat(LLMInterface):
"""LLM interface for local Hugging Face Transformers models with proper chat templates."""
"""LLM interface for local Hugging Face Transformers models with proper chat templates.
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
Args:
model_name (str): Name of the Hugging Face model to load.
trust_remote_code (bool): Whether to allow execution of code from the model repository.
Defaults to False for security. Only enable for trusted models as this can pose
a security risk if the model repository is compromised.
"""
def __init__(
self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat", trust_remote_code: bool = False
):
logger.info(f"Initializing HFChat with model='{model_name}'")
# Security warning when trust_remote_code is enabled
if trust_remote_code:
logger.warning(
"SECURITY WARNING: trust_remote_code=True allows execution of arbitrary code from the model repository. "
"Only enable this for models from trusted sources. This creates a potential security risk if the model "
"repository is compromised."
)
self.trust_remote_code = trust_remote_code
# Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model_name, "hf")
if model_error:
@@ -588,14 +607,16 @@ class HFChat(LLMInterface):
try:
logger.info(f"Loading tokenizer for {model_name}...")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(
model_name, trust_remote_code=self.trust_remote_code
)
logger.info(f"Loading model {model_name}...")
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
device_map="auto" if self.device != "cpu" else None,
trust_remote_code=True,
trust_remote_code=self.trust_remote_code,
)
logger.info(f"Successfully loaded {model_name}")
finally:
@@ -813,6 +834,11 @@ class OpenAIChat(LLMInterface):
try:
response = self.client.chat.completions.create(**params)
print(
f"Total tokens = {response.usage.total_tokens}, prompt tokens = {response.usage.prompt_tokens}, completion tokens = {response.usage.completion_tokens}"
)
if response.choices[0].finish_reason == "length":
print("The query is exceeding the maximum allowed number of tokens")
return response.choices[0].message.content.strip()
except Exception as e:
logger.error(f"Error communicating with OpenAI: {e}")
@@ -859,7 +885,10 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
host=llm_config.get("host"),
)
elif llm_type == "hf":
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
return HFChat(
model_name=model or "deepseek-ai/deepseek-llm-7b-chat",
trust_remote_code=llm_config.get("trust_remote_code", False),
)
elif llm_type == "openai":
return OpenAIChat(
model=model or "gpt-4o",

View File

@@ -5,12 +5,128 @@ Packaged within leann-core so installed wheels can import it reliably.
import logging
from pathlib import Path
from typing import Optional
from typing import Any, Optional
from llama_index.core.node_parser import SentenceSplitter
logger = logging.getLogger(__name__)
# Flag to ensure AST token warning only shown once per session
_ast_token_warning_shown = False
def estimate_token_count(text: str) -> int:
"""
Estimate token count for a text string.
Uses conservative estimation: ~4 characters per token for natural text,
~1.2 tokens per character for code (worse tokenization).
Args:
text: Input text to estimate tokens for
Returns:
Estimated token count
"""
try:
import tiktoken
encoder = tiktoken.get_encoding("cl100k_base")
return len(encoder.encode(text))
except ImportError:
# Fallback: Conservative character-based estimation
# Assume worst case for code: 1.2 tokens per character
return int(len(text) * 1.2)
def calculate_safe_chunk_size(
model_token_limit: int,
overlap_tokens: int,
chunking_mode: str = "traditional",
safety_factor: float = 0.9,
) -> int:
"""
Calculate safe chunk size accounting for overlap and safety margin.
Args:
model_token_limit: Maximum tokens supported by embedding model
overlap_tokens: Overlap size (tokens for traditional, chars for AST)
chunking_mode: "traditional" (tokens) or "ast" (characters)
safety_factor: Safety margin (0.9 = 10% safety margin)
Returns:
Safe chunk size: tokens for traditional, characters for AST
"""
safe_limit = int(model_token_limit * safety_factor)
if chunking_mode == "traditional":
# Traditional chunking uses tokens
# Max chunk = chunk_size + overlap, so chunk_size = limit - overlap
return max(1, safe_limit - overlap_tokens)
else: # AST chunking
# AST uses characters, need to convert
# Conservative estimate: 1.2 tokens per char for code
overlap_chars = int(overlap_tokens * 3) # ~3 chars per token for code
safe_chars = int(safe_limit / 1.2)
return max(1, safe_chars - overlap_chars)
def validate_chunk_token_limits(chunks: list[str], max_tokens: int = 512) -> tuple[list[str], int]:
"""
Validate that chunks don't exceed token limits and truncate if necessary.
Args:
chunks: List of text chunks to validate
max_tokens: Maximum tokens allowed per chunk
Returns:
Tuple of (validated_chunks, num_truncated)
"""
validated_chunks = []
num_truncated = 0
for i, chunk in enumerate(chunks):
estimated_tokens = estimate_token_count(chunk)
if estimated_tokens > max_tokens:
# Truncate chunk to fit token limit
try:
import tiktoken
encoder = tiktoken.get_encoding("cl100k_base")
tokens = encoder.encode(chunk)
if len(tokens) > max_tokens:
truncated_tokens = tokens[:max_tokens]
truncated_chunk = encoder.decode(truncated_tokens)
validated_chunks.append(truncated_chunk)
num_truncated += 1
logger.warning(
f"Truncated chunk {i} from {len(tokens)} to {max_tokens} tokens "
f"(from {len(chunk)} to {len(truncated_chunk)} characters)"
)
else:
validated_chunks.append(chunk)
except ImportError:
# Fallback: Conservative character truncation
char_limit = int(max_tokens / 1.2) # Conservative for code
if len(chunk) > char_limit:
truncated_chunk = chunk[:char_limit]
validated_chunks.append(truncated_chunk)
num_truncated += 1
logger.warning(
f"Truncated chunk {i} from {len(chunk)} to {char_limit} characters "
f"(conservative estimate for {max_tokens} tokens)"
)
else:
validated_chunks.append(chunk)
else:
validated_chunks.append(chunk)
if num_truncated > 0:
logger.warning(f"Truncated {num_truncated}/{len(chunks)} chunks to fit token limits")
return validated_chunks, num_truncated
# Code file extensions supported by astchunk
CODE_EXTENSIONS = {
".py": "python",
@@ -61,27 +177,45 @@ def create_ast_chunks(
max_chunk_size: int = 512,
chunk_overlap: int = 64,
metadata_template: str = "default",
) -> list[str]:
) -> list[dict[str, Any]]:
"""Create AST-aware chunks from code documents using astchunk.
Falls back to traditional chunking if astchunk is unavailable.
Returns:
List of dicts with {"text": str, "metadata": dict}
"""
try:
from astchunk import ASTChunkBuilder # optional dependency
except ImportError as e:
logger.error(f"astchunk not available: {e}")
logger.info("Falling back to traditional chunking for code files")
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
return _traditional_chunks_as_dicts(documents, max_chunk_size, chunk_overlap)
all_chunks = []
for doc in documents:
language = doc.metadata.get("language")
if not language:
logger.warning("No language detected; falling back to traditional chunking")
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
continue
try:
# Warn once if AST chunk size + overlap might exceed common token limits
# Note: Actual truncation happens at embedding time with dynamic model limits
global _ast_token_warning_shown
estimated_max_tokens = int(
(max_chunk_size + chunk_overlap) * 1.2
) # Conservative estimate
if estimated_max_tokens > 512 and not _ast_token_warning_shown:
logger.warning(
f"AST chunk size ({max_chunk_size}) + overlap ({chunk_overlap}) = {max_chunk_size + chunk_overlap} chars "
f"may exceed 512 token limit (~{estimated_max_tokens} tokens estimated). "
f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}. "
f"Note: Chunks will be auto-truncated at embedding time based on your model's actual token limit."
)
_ast_token_warning_shown = True
configs = {
"max_chunk_size": max_chunk_size,
"language": language,
@@ -105,17 +239,40 @@ def create_ast_chunks(
chunks = chunk_builder.chunkify(code_content)
for chunk in chunks:
chunk_text = None
astchunk_metadata = {}
if hasattr(chunk, "text"):
chunk_text = chunk.text
elif isinstance(chunk, dict) and "text" in chunk:
chunk_text = chunk["text"]
elif isinstance(chunk, str):
chunk_text = chunk
elif isinstance(chunk, dict):
# Handle astchunk format: {"content": "...", "metadata": {...}}
if "content" in chunk:
chunk_text = chunk["content"]
astchunk_metadata = chunk.get("metadata", {})
elif "text" in chunk:
chunk_text = chunk["text"]
else:
chunk_text = str(chunk) # Last resort
else:
chunk_text = str(chunk)
if chunk_text and chunk_text.strip():
all_chunks.append(chunk_text.strip())
# Extract document-level metadata
doc_metadata = {
"file_path": doc.metadata.get("file_path", ""),
"file_name": doc.metadata.get("file_name", ""),
}
if "creation_date" in doc.metadata:
doc_metadata["creation_date"] = doc.metadata["creation_date"]
if "last_modified_date" in doc.metadata:
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
# Merge document metadata + astchunk metadata
combined_metadata = {**doc_metadata, **astchunk_metadata}
all_chunks.append({"text": chunk_text.strip(), "metadata": combined_metadata})
logger.info(
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
@@ -123,15 +280,19 @@ def create_ast_chunks(
except Exception as e:
logger.warning(f"AST chunking failed for {language} file: {e}")
logger.info("Falling back to traditional chunking")
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
return all_chunks
def create_traditional_chunks(
documents, chunk_size: int = 256, chunk_overlap: int = 128
) -> list[str]:
"""Create traditional text chunks using LlamaIndex SentenceSplitter."""
) -> list[dict[str, Any]]:
"""Create traditional text chunks using LlamaIndex SentenceSplitter.
Returns:
List of dicts with {"text": str, "metadata": dict}
"""
if chunk_size <= 0:
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
chunk_size = 256
@@ -147,19 +308,40 @@ def create_traditional_chunks(
paragraph_separator="\n\n",
)
all_texts = []
result = []
for doc in documents:
# Extract document-level metadata
doc_metadata = {
"file_path": doc.metadata.get("file_path", ""),
"file_name": doc.metadata.get("file_name", ""),
}
if "creation_date" in doc.metadata:
doc_metadata["creation_date"] = doc.metadata["creation_date"]
if "last_modified_date" in doc.metadata:
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
try:
nodes = node_parser.get_nodes_from_documents([doc])
if nodes:
all_texts.extend(node.get_content() for node in nodes)
for node in nodes:
result.append({"text": node.get_content(), "metadata": doc_metadata})
except Exception as e:
logger.error(f"Traditional chunking failed for document: {e}")
content = doc.get_content()
if content and content.strip():
all_texts.append(content.strip())
result.append({"text": content.strip(), "metadata": doc_metadata})
return all_texts
return result
def _traditional_chunks_as_dicts(
documents, chunk_size: int = 256, chunk_overlap: int = 128
) -> list[dict[str, Any]]:
"""Helper: Traditional chunking that returns dict format for consistency.
This is now just an alias for create_traditional_chunks for backwards compatibility.
"""
return create_traditional_chunks(documents, chunk_size, chunk_overlap)
def create_text_chunks(
@@ -171,8 +353,12 @@ def create_text_chunks(
ast_chunk_overlap: int = 64,
code_file_extensions: Optional[list[str]] = None,
ast_fallback_traditional: bool = True,
) -> list[str]:
"""Create text chunks from documents with optional AST support for code files."""
) -> list[dict[str, Any]]:
"""Create text chunks from documents with optional AST support for code files.
Returns:
List of dicts with {"text": str, "metadata": dict}
"""
if not documents:
logger.warning("No documents provided for chunking")
return []
@@ -207,14 +393,17 @@ def create_text_chunks(
logger.error(f"AST chunking failed: {e}")
if ast_fallback_traditional:
all_chunks.extend(
create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
_traditional_chunks_as_dicts(code_docs, chunk_size, chunk_overlap)
)
else:
raise
if text_docs:
all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
all_chunks.extend(_traditional_chunks_as_dicts(text_docs, chunk_size, chunk_overlap))
else:
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
all_chunks = _traditional_chunks_as_dicts(documents, chunk_size, chunk_overlap)
logger.info(f"Total chunks created: {len(all_chunks)}")
# Note: Token truncation is now handled at embedding time with dynamic model limits
# See get_model_token_limit() and truncate_to_token_limit() in embedding_compute.py
return all_chunks

View File

@@ -1,5 +1,6 @@
import argparse
import asyncio
import time
from pathlib import Path
from typing import Any, Optional, Union
@@ -8,6 +9,7 @@ from llama_index.core.node_parser import SentenceSplitter
from tqdm import tqdm
from .api import LeannBuilder, LeannChat, LeannSearcher
from .interactive_utils import create_cli_session
from .registry import register_project_directory
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
@@ -105,7 +107,7 @@ Examples:
help="Documents directories and/or files (default: current directory)",
)
build_parser.add_argument(
"--backend",
"--backend-name",
type=str,
default="hnsw",
choices=["hnsw", "diskann"],
@@ -179,25 +181,25 @@ Examples:
"--doc-chunk-size",
type=int,
default=256,
help="Document chunk size in tokens/characters (default: 256)",
help="Document chunk size in TOKENS (default: 256). Final chunks may be larger due to overlap. For 512 token models: recommended 350 tokens (350 + 128 overlap = 478 max)",
)
build_parser.add_argument(
"--doc-chunk-overlap",
type=int,
default=128,
help="Document chunk overlap (default: 128)",
help="Document chunk overlap in TOKENS (default: 128). Added to chunk size, not included in it",
)
build_parser.add_argument(
"--code-chunk-size",
type=int,
default=512,
help="Code chunk size in tokens/lines (default: 512)",
help="Code chunk size in TOKENS (default: 512). Final chunks may be larger due to overlap. For 512 token models: recommended 400 tokens (400 + 50 overlap = 450 max)",
)
build_parser.add_argument(
"--code-chunk-overlap",
type=int,
default=50,
help="Code chunk overlap (default: 50)",
help="Code chunk overlap in TOKENS (default: 50). Added to chunk size, not included in it",
)
build_parser.add_argument(
"--use-ast-chunking",
@@ -207,14 +209,14 @@ Examples:
build_parser.add_argument(
"--ast-chunk-size",
type=int,
default=768,
help="AST chunk size in characters (default: 768)",
default=300,
help="AST chunk size in CHARACTERS (non-whitespace) (default: 300). Final chunks may be larger due to overlap and expansion. For 512 token models: recommended 300 chars (300 + 64 overlap ~= 480 tokens)",
)
build_parser.add_argument(
"--ast-chunk-overlap",
type=int,
default=96,
help="AST chunk overlap in characters (default: 96)",
default=64,
help="AST chunk overlap in CHARACTERS (default: 64). Added to chunk size, not included in it. ~1.2 tokens per character for code",
)
build_parser.add_argument(
"--ast-fallback-traditional",
@@ -253,6 +255,11 @@ Examples:
action="store_true",
help="Non-interactive mode: automatically select index without prompting",
)
search_parser.add_argument(
"--show-metadata",
action="store_true",
help="Display file paths and metadata in search results",
)
# Ask command
ask_parser = subparsers.add_parser("ask", help="Ask questions")
@@ -1185,6 +1192,7 @@ Examples:
for doc in other_docs:
file_path = doc.metadata.get("file_path", "")
if file_filter(file_path):
doc.metadata["source"] = file_path
filtered_docs.append(doc)
documents.extend(filtered_docs)
@@ -1260,7 +1268,7 @@ Examples:
from .chunking_utils import create_text_chunks
# Use enhanced chunking with AST support
all_texts = create_text_chunks(
chunk_texts = create_text_chunks(
documents,
chunk_size=self.node_parser.chunk_size,
chunk_overlap=self.node_parser.chunk_overlap,
@@ -1271,6 +1279,9 @@ Examples:
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
)
# create_text_chunks now returns list[dict] with metadata preserved
all_texts.extend(chunk_texts)
except ImportError as e:
print(
f"⚠️ AST chunking utilities not available in package ({e}), falling back to traditional chunking"
@@ -1282,14 +1293,27 @@ Examples:
for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
# Check if this is a code file based on source path
source_path = doc.metadata.get("source", "")
file_path = doc.metadata.get("file_path", "")
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
# Extract metadata to preserve with chunks
chunk_metadata = {
"file_path": file_path or source_path,
"file_name": doc.metadata.get("file_name", ""),
}
# Add optional metadata if available
if "creation_date" in doc.metadata:
chunk_metadata["creation_date"] = doc.metadata["creation_date"]
if "last_modified_date" in doc.metadata:
chunk_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
# Use appropriate parser based on file type
parser = self.code_parser if is_code_file else self.node_parser
nodes = parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
all_texts.append({"text": node.get_content(), "metadata": chunk_metadata})
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
return all_texts
@@ -1364,7 +1388,7 @@ Examples:
index_dir.mkdir(parents=True, exist_ok=True)
print(f"Building index '{index_name}' with {args.backend} backend...")
print(f"Building index '{index_name}' with {args.backend_name} backend...")
embedding_options: dict[str, Any] = {}
if args.embedding_mode == "ollama":
@@ -1376,7 +1400,7 @@ Examples:
embedding_options["api_key"] = resolved_embedding_key
builder = LeannBuilder(
backend_name=args.backend,
backend_name=args.backend_name,
embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
embedding_options=embedding_options or None,
@@ -1387,8 +1411,8 @@ Examples:
num_threads=args.num_threads,
)
for chunk_text in all_texts:
builder.add_text(chunk_text)
for chunk in all_texts:
builder.add_text(chunk["text"], metadata=chunk["metadata"])
builder.build_index(index_path)
print(f"Index built at {index_path}")
@@ -1509,7 +1533,25 @@ Examples:
print(f"Search results for '{query}' (top {len(results)}):")
for i, result in enumerate(results, 1):
print(f"{i}. Score: {result.score:.3f}")
# Display metadata if flag is set
if args.show_metadata and result.metadata:
file_path = result.metadata.get("file_path", "")
if file_path:
print(f" 📄 File: {file_path}")
file_name = result.metadata.get("file_name", "")
if file_name and file_name != file_path:
print(f" 📝 Name: {file_name}")
# Show timestamps if available
if "creation_date" in result.metadata:
print(f" 🕐 Created: {result.metadata['creation_date']}")
if "last_modified_date" in result.metadata:
print(f" 🕑 Modified: {result.metadata['last_modified_date']}")
print(f" {result.text[:200]}...")
print(f" Source: {result.metadata.get('source', '')}")
print()
async def ask_questions(self, args):
@@ -1541,6 +1583,7 @@ Examples:
llm_kwargs["thinking_budget"] = args.thinking_budget
def _ask_once(prompt: str) -> None:
query_start_time = time.time()
response = chat.ask(
prompt,
top_k=args.top_k,
@@ -1551,27 +1594,20 @@ Examples:
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
query_completion_time = time.time() - query_start_time
print(f"LEANN: {response}")
print(f"The query took {query_completion_time:.3f} seconds to finish")
initial_query = (args.query or "").strip()
if args.interactive:
# Create interactive session
session = create_cli_session(index_name)
if initial_query:
_ask_once(initial_query)
print("LEANN Assistant ready! Type 'quit' to exit")
print("=" * 40)
while True:
user_input = input("\nYou: ").strip()
if user_input.lower() in ["quit", "exit", "q"]:
print("Goodbye!")
break
if not user_input:
continue
_ask_once(user_input)
session.run_interactive_loop(_ask_once)
else:
query = initial_query or input("Enter your question: ").strip()
if not query:

View File

@@ -10,6 +10,7 @@ import time
from typing import Any, Optional
import numpy as np
import tiktoken
import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
@@ -20,6 +21,170 @@ LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Token limit registry for embedding models
# Used as fallback when dynamic discovery fails (e.g., LM Studio, OpenAI)
# Ollama models use dynamic discovery via /api/show
EMBEDDING_MODEL_LIMITS = {
# Nomic models (common across servers)
"nomic-embed-text": 2048, # Corrected from 512 - verified via /api/show
"nomic-embed-text-v1.5": 2048,
"nomic-embed-text-v2": 512,
# Other embedding models
"mxbai-embed-large": 512,
"all-minilm": 512,
"bge-m3": 8192,
"snowflake-arctic-embed": 512,
# OpenAI models
"text-embedding-3-small": 8192,
"text-embedding-3-large": 8192,
"text-embedding-ada-002": 8192,
}
def get_model_token_limit(
model_name: str,
base_url: Optional[str] = None,
default: int = 2048,
) -> int:
"""
Get token limit for a given embedding model.
Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others.
Args:
model_name: Name of the embedding model
base_url: Base URL of the embedding server (for dynamic discovery)
default: Default token limit if model not found
Returns:
Token limit for the model in tokens
"""
# Try Ollama dynamic discovery if base_url provided
if base_url:
# Detect Ollama servers by port or "ollama" in URL
if "11434" in base_url or "ollama" in base_url.lower():
limit = _query_ollama_context_limit(model_name, base_url)
if limit:
return limit
# Fallback to known model registry with version handling (from PR #154)
# Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
base_model_name = model_name.split(":")[0]
# Check exact match first
if model_name in EMBEDDING_MODEL_LIMITS:
return EMBEDDING_MODEL_LIMITS[model_name]
# Check base name match
if base_model_name in EMBEDDING_MODEL_LIMITS:
return EMBEDDING_MODEL_LIMITS[base_model_name]
# Check partial matches for common patterns
for known_model, limit in EMBEDDING_MODEL_LIMITS.items():
if known_model in base_model_name or base_model_name in known_model:
return limit
# Default fallback
logger.warning(f"Unknown model '{model_name}', using default {default} token limit")
return default
def truncate_to_token_limit(texts: list[str], token_limit: int) -> list[str]:
"""
Truncate texts to fit within token limit using tiktoken.
Args:
texts: List of text strings to truncate
token_limit: Maximum number of tokens allowed
Returns:
List of truncated texts (same length as input)
"""
if not texts:
return []
# Use tiktoken with cl100k_base encoding
enc = tiktoken.get_encoding("cl100k_base")
truncated_texts = []
truncation_count = 0
total_tokens_removed = 0
max_original_length = 0
for i, text in enumerate(texts):
tokens = enc.encode(text)
original_length = len(tokens)
if original_length <= token_limit:
# Text is within limit, keep as is
truncated_texts.append(text)
else:
# Truncate to token_limit
truncated_tokens = tokens[:token_limit]
truncated_text = enc.decode(truncated_tokens)
truncated_texts.append(truncated_text)
# Track truncation statistics
truncation_count += 1
tokens_removed = original_length - token_limit
total_tokens_removed += tokens_removed
max_original_length = max(max_original_length, original_length)
# Log individual truncation at WARNING level (first few only)
if truncation_count <= 3:
logger.warning(
f"Text {i + 1} truncated: {original_length}{token_limit} tokens "
f"({tokens_removed} tokens removed)"
)
elif truncation_count == 4:
logger.warning("Further truncation warnings suppressed...")
# Log summary at INFO level
if truncation_count > 0:
logger.warning(
f"Truncation summary: {truncation_count}/{len(texts)} texts truncated "
f"(removed {total_tokens_removed} tokens total, longest was {max_original_length} tokens)"
)
else:
logger.debug(
f"No truncation needed - all {len(texts)} texts within {token_limit} token limit"
)
return truncated_texts
def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int]:
"""
Query Ollama /api/show for model context limit.
Args:
model_name: Name of the Ollama model
base_url: Base URL of the Ollama server
Returns:
Context limit in tokens if found, None otherwise
"""
try:
import requests
response = requests.post(
f"{base_url}/api/show",
json={"name": model_name},
timeout=5,
)
if response.status_code == 200:
data = response.json()
if "model_info" in data:
# Look for *.context_length in model_info
for key, value in data["model_info"].items():
if "context_length" in key and isinstance(value, int):
logger.info(f"Detected {model_name} context limit: {value} tokens")
return value
except Exception as e:
logger.debug(f"Failed to query Ollama context limit: {e}")
return None
# Global model cache to avoid repeated loading
_model_cache: dict[str, Any] = {}
@@ -183,32 +348,73 @@ def compute_embeddings_sentence_transformers(
}
try:
# Try local loading first
model_kwargs["local_files_only"] = True
tokenizer_kwargs["local_files_only"] = True
# Try loading with advanced parameters first (newer versions)
local_model_kwargs = model_kwargs.copy()
local_tokenizer_kwargs = tokenizer_kwargs.copy()
local_model_kwargs["local_files_only"] = True
local_tokenizer_kwargs["local_files_only"] = True
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
model_kwargs=local_model_kwargs,
tokenizer_kwargs=local_tokenizer_kwargs,
local_files_only=True,
)
logger.info("Model loaded successfully! (local + optimized)")
except TypeError as e:
if "model_kwargs" in str(e) or "tokenizer_kwargs" in str(e):
logger.warning(
f"Advanced parameters not supported ({e}), using basic initialization..."
)
# Fallback to basic initialization for older versions
try:
model = SentenceTransformer(
model_name,
device=device,
local_files_only=True,
)
logger.info("Model loaded successfully! (local + basic)")
except Exception as e2:
logger.warning(f"Local loading failed ({e2}), trying network download...")
model = SentenceTransformer(
model_name,
device=device,
local_files_only=False,
)
logger.info("Model loaded successfully! (network + basic)")
else:
raise
except Exception as e:
logger.warning(f"Local loading failed ({e}), trying network download...")
# Fallback to network loading
model_kwargs["local_files_only"] = False
tokenizer_kwargs["local_files_only"] = False
# Fallback to network loading with advanced parameters
try:
network_model_kwargs = model_kwargs.copy()
network_tokenizer_kwargs = tokenizer_kwargs.copy()
network_model_kwargs["local_files_only"] = False
network_tokenizer_kwargs["local_files_only"] = False
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=False,
)
logger.info("Model loaded successfully! (network + optimized)")
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=network_model_kwargs,
tokenizer_kwargs=network_tokenizer_kwargs,
local_files_only=False,
)
logger.info("Model loaded successfully! (network + optimized)")
except TypeError as e2:
if "model_kwargs" in str(e2) or "tokenizer_kwargs" in str(e2):
logger.warning(
f"Advanced parameters not supported ({e2}), using basic network loading..."
)
model = SentenceTransformer(
model_name,
device=device,
local_files_only=False,
)
logger.info("Model loaded successfully! (network + basic)")
else:
raise
# Apply additional optimizations based on mode
if use_fp16 and device in ["cuda", "mps"]:
@@ -533,9 +739,10 @@ def compute_embeddings_ollama(
host: Optional[str] = None,
) -> np.ndarray:
"""
Compute embeddings using Ollama API with simplified batch processing.
Compute embeddings using Ollama API with true batch processing.
Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
Uses the /api/embed endpoint which supports batch inputs.
Batch size: 32 for MPS/CPU, 128 for CUDA to optimize performance.
Args:
texts: List of texts to compute embeddings for
@@ -640,11 +847,11 @@ def compute_embeddings_ollama(
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
model_name = resolved_model_name
# Verify the model supports embeddings by testing it
# Verify the model supports embeddings by testing it with /api/embed
try:
test_response = requests.post(
f"{resolved_host}/api/embeddings",
json={"model": model_name, "prompt": "test"},
f"{resolved_host}/api/embed",
json={"model": model_name, "input": "test"},
timeout=10,
)
if test_response.status_code != 200:
@@ -676,56 +883,71 @@ def compute_embeddings_ollama(
# If torch is not available, use conservative batch size
batch_size = 32
logger.info(f"Using batch size: {batch_size}")
logger.info(f"Using batch size: {batch_size} for true batch processing")
# Get model token limit and apply truncation before batching
token_limit = get_model_token_limit(model_name, base_url=resolved_host)
logger.info(f"Model '{model_name}' token limit: {token_limit}")
# Apply truncation to all texts before batch processing
# Function logs truncation details internally
texts = truncate_to_token_limit(texts, token_limit)
def get_batch_embeddings(batch_texts):
"""Get embeddings for a batch of texts."""
all_embeddings = []
failed_indices = []
"""Get embeddings for a batch of texts using /api/embed endpoint."""
max_retries = 3
retry_count = 0
for i, text in enumerate(batch_texts):
max_retries = 3
retry_count = 0
# Texts are already truncated to token limit by the outer function
while retry_count < max_retries:
try:
# Use /api/embed endpoint with "input" parameter for batch processing
response = requests.post(
f"{resolved_host}/api/embed",
json={"model": model_name, "input": batch_texts},
timeout=60, # Increased timeout for batch processing
)
response.raise_for_status()
# Truncate very long texts to avoid API issues
truncated_text = text[:8000] if len(text) > 8000 else text
while retry_count < max_retries:
try:
response = requests.post(
f"{resolved_host}/api/embeddings",
json={"model": model_name, "prompt": truncated_text},
timeout=30,
result = response.json()
batch_embeddings = result.get("embeddings")
if batch_embeddings is None:
raise ValueError("No embeddings returned from API")
if not isinstance(batch_embeddings, list):
raise ValueError(f"Invalid embeddings format: {type(batch_embeddings)}")
if len(batch_embeddings) != len(batch_texts):
raise ValueError(
f"Mismatch: requested {len(batch_texts)} embeddings, got {len(batch_embeddings)}"
)
response.raise_for_status()
result = response.json()
embedding = result.get("embedding")
return batch_embeddings, []
if embedding is None:
raise ValueError(f"No embedding returned for text {i}")
except requests.exceptions.Timeout:
retry_count += 1
if retry_count >= max_retries:
logger.warning(f"Timeout for batch after {max_retries} retries")
return None, list(range(len(batch_texts)))
if not isinstance(embedding, list) or len(embedding) == 0:
raise ValueError(f"Invalid embedding format for text {i}")
except Exception as e:
retry_count += 1
if retry_count >= max_retries:
# Enhanced error detection for token limit violations
error_msg = str(e).lower()
if "token" in error_msg and (
"limit" in error_msg or "exceed" in error_msg or "length" in error_msg
):
logger.error(
f"Token limit exceeded for batch. Error: {e}. "
f"Consider reducing chunk sizes or check token truncation."
)
else:
logger.error(f"Failed to get embeddings for batch: {e}")
return None, list(range(len(batch_texts)))
all_embeddings.append(embedding)
break
except requests.exceptions.Timeout:
retry_count += 1
if retry_count >= max_retries:
logger.warning(f"Timeout for text {i} after {max_retries} retries")
failed_indices.append(i)
all_embeddings.append(None)
break
except Exception as e:
retry_count += 1
if retry_count >= max_retries:
logger.error(f"Failed to get embedding for text {i}: {e}")
failed_indices.append(i)
all_embeddings.append(None)
break
return all_embeddings, failed_indices
return None, list(range(len(batch_texts)))
# Process texts in batches
all_embeddings = []
@@ -743,7 +965,7 @@ def compute_embeddings_ollama(
num_batches = (len(texts) + batch_size - 1) // batch_size
if show_progress:
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
else:
batch_iterator = range(num_batches)
@@ -754,10 +976,14 @@ def compute_embeddings_ollama(
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
# Adjust failed indices to global indices
global_failed = [start_idx + idx for idx in batch_failed]
all_failed_indices.extend(global_failed)
all_embeddings.extend(batch_embeddings)
if batch_embeddings is not None:
all_embeddings.extend(batch_embeddings)
else:
# Entire batch failed, add None placeholders
all_embeddings.extend([None] * len(batch_texts))
# Adjust failed indices to global indices
global_failed = [start_idx + idx for idx in batch_failed]
all_failed_indices.extend(global_failed)
# Handle failed embeddings
if all_failed_indices:

View File

@@ -1,4 +1,5 @@
import atexit
import json
import logging
import os
import socket
@@ -48,6 +49,85 @@ def _check_port(port: int) -> bool:
# Note: All cross-process scanning helpers removed for simplicity
def _safe_resolve(path: Path) -> str:
"""Resolve paths safely even if the target does not yet exist."""
try:
return str(path.resolve(strict=False))
except Exception:
return str(path)
def _safe_stat_signature(path: Path) -> dict:
"""Return a lightweight signature describing the current state of a path."""
signature: dict[str, object] = {"path": _safe_resolve(path)}
try:
stat = path.stat()
except FileNotFoundError:
signature["missing"] = True
except Exception as exc: # pragma: no cover - unexpected filesystem errors
signature["error"] = str(exc)
else:
signature["mtime_ns"] = stat.st_mtime_ns
signature["size"] = stat.st_size
return signature
def _build_passages_signature(passages_file: Optional[str]) -> Optional[dict]:
"""Collect modification signatures for metadata and referenced passage files."""
if not passages_file:
return None
meta_path = Path(passages_file)
signature: dict[str, object] = {"meta": _safe_stat_signature(meta_path)}
try:
with meta_path.open(encoding="utf-8") as fh:
meta = json.load(fh)
except FileNotFoundError:
signature["meta_missing"] = True
signature["sources"] = []
return signature
except json.JSONDecodeError as exc:
signature["meta_error"] = f"json_error:{exc}"
signature["sources"] = []
return signature
except Exception as exc: # pragma: no cover - unexpected errors
signature["meta_error"] = str(exc)
signature["sources"] = []
return signature
base_dir = meta_path.parent
seen_paths: set[str] = set()
source_signatures: list[dict[str, object]] = []
for source in meta.get("passage_sources", []):
for key, kind in (
("path", "passages"),
("path_relative", "passages"),
("index_path", "index"),
("index_path_relative", "index"),
):
raw_path = source.get(key)
if not raw_path:
continue
candidate = Path(raw_path)
if not candidate.is_absolute():
candidate = base_dir / candidate
resolved = _safe_resolve(candidate)
if resolved in seen_paths:
continue
seen_paths.add(resolved)
sig = _safe_stat_signature(candidate)
sig["kind"] = kind
source_signatures.append(sig)
signature["sources"] = source_signatures
return signature
# Note: All cross-process scanning helpers removed for simplicity
class EmbeddingServerManager:
"""
A simplified manager for embedding server processes that avoids complex update mechanisms.
@@ -85,13 +165,14 @@ class EmbeddingServerManager:
"""Start the embedding server."""
# passages_file may be present in kwargs for server CLI, but we don't need it here
provider_options = kwargs.pop("provider_options", None)
passages_file = kwargs.get("passages_file", "")
config_signature = {
"model_name": model_name,
"passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
}
config_signature = self._build_config_signature(
model_name=model_name,
embedding_mode=embedding_mode,
provider_options=provider_options,
passages_file=passages_file,
)
# If this manager already has a live server, just reuse it
if (
@@ -115,6 +196,7 @@ class EmbeddingServerManager:
port,
model_name,
embedding_mode,
config_signature=config_signature,
provider_options=provider_options,
**kwargs,
)
@@ -136,11 +218,30 @@ class EmbeddingServerManager:
**kwargs,
)
def _build_config_signature(
self,
*,
model_name: str,
embedding_mode: str,
provider_options: Optional[dict],
passages_file: Optional[str],
) -> dict:
"""Create a signature describing the current server configuration."""
return {
"model_name": model_name,
"passages_file": passages_file or "",
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
"passages_signature": _build_passages_signature(passages_file),
}
def _start_server_colab(
self,
port: int,
model_name: str,
embedding_mode: str = "sentence-transformers",
*,
config_signature: Optional[dict] = None,
provider_options: Optional[dict] = None,
**kwargs,
) -> tuple[bool, int]:
@@ -163,10 +264,11 @@ class EmbeddingServerManager:
command,
actual_port,
provider_options=provider_options,
config_signature=config_signature,
)
started, ready_port = self._wait_for_server_ready_colab(actual_port)
if started:
self._server_config = {
self._server_config = config_signature or {
"model_name": model_name,
"passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode,
@@ -198,6 +300,7 @@ class EmbeddingServerManager:
command,
port,
provider_options=provider_options,
config_signature=config_signature,
)
started, ready_port = self._wait_for_server_ready(port)
if started:
@@ -241,7 +344,9 @@ class EmbeddingServerManager:
self,
command: list,
port: int,
*,
provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
) -> None:
"""Launch the server process."""
project_root = Path(__file__).parent.parent.parent.parent.parent
@@ -276,26 +381,29 @@ class EmbeddingServerManager:
)
self.server_port = port
# Record config for in-process reuse (best effort; refined later when ready)
try:
self._server_config = {
"model_name": command[command.index("--model-name") + 1]
if "--model-name" in command
else "",
"passages_file": command[command.index("--passages-file") + 1]
if "--passages-file" in command
else "",
"embedding_mode": command[command.index("--embedding-mode") + 1]
if "--embedding-mode" in command
else "sentence-transformers",
"provider_options": provider_options or {},
}
except Exception:
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
if config_signature is not None:
self._server_config = config_signature
else: # Fallback for unexpected code paths
try:
self._server_config = {
"model_name": command[command.index("--model-name") + 1]
if "--model-name" in command
else "",
"passages_file": command[command.index("--passages-file") + 1]
if "--passages-file" in command
else "",
"embedding_mode": command[command.index("--embedding-mode") + 1]
if "--embedding-mode" in command
else "sentence-transformers",
"provider_options": provider_options or {},
}
except Exception:
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
logger.info(f"Server process started with PID: {self.server_process.pid}")
# Register atexit callback only when we actually start a process
@@ -403,7 +511,9 @@ class EmbeddingServerManager:
self,
command: list,
port: int,
*,
provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
) -> None:
"""Launch the server process with Colab-specific settings."""
logger.info(f"Colab Command: {' '.join(command)}")
@@ -429,12 +539,15 @@ class EmbeddingServerManager:
atexit.register(self._finalize_process)
self._atexit_registered = True
# Record config for in-process reuse is best-effort in Colab mode
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
if config_signature is not None:
self._server_config = config_signature
else:
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready with Colab-specific timeout."""

View File

@@ -0,0 +1,189 @@
"""
Interactive session utilities for LEANN applications.
Provides shared readline functionality and command handling across
CLI, API, and RAG example interactive modes.
"""
import atexit
import os
from pathlib import Path
from typing import Callable, Optional
# Try to import readline with fallback for Windows
try:
import readline
HAS_READLINE = True
except ImportError:
# Windows doesn't have readline by default
HAS_READLINE = False
readline = None
class InteractiveSession:
"""Manages interactive session with optional readline support and common commands."""
def __init__(
self,
history_name: str,
prompt: str = "You: ",
welcome_message: str = "",
):
"""
Initialize interactive session with optional readline support.
Args:
history_name: Name for history file (e.g., "cli", "api_chat")
(ignored if readline not available)
prompt: Input prompt to display
welcome_message: Message to show when starting session
Note:
On systems without readline (e.g., Windows), falls back to basic input()
with limited functionality (no history, no line editing).
"""
self.history_name = history_name
self.prompt = prompt
self.welcome_message = welcome_message
self._setup_complete = False
def setup_readline(self):
"""Setup readline with history support (if available)."""
if self._setup_complete:
return
if not HAS_READLINE:
# Readline not available (likely Windows), skip setup
self._setup_complete = True
return
# History file setup
history_dir = Path.home() / ".leann" / "history"
history_dir.mkdir(parents=True, exist_ok=True)
history_file = history_dir / f"{self.history_name}.history"
# Load history if exists
try:
readline.read_history_file(str(history_file))
readline.set_history_length(1000)
except FileNotFoundError:
pass
# Save history on exit
atexit.register(readline.write_history_file, str(history_file))
# Optional: Enable vi editing mode (commented out by default)
# readline.parse_and_bind("set editing-mode vi")
self._setup_complete = True
def _show_help(self):
"""Show available commands."""
print("Commands:")
print(" quit/exit/q - Exit the chat")
print(" help - Show this help message")
print(" clear - Clear screen")
print(" history - Show command history")
def _show_history(self):
"""Show command history."""
if not HAS_READLINE:
print(" History not available (readline not supported on this system)")
return
history_length = readline.get_current_history_length()
if history_length == 0:
print(" No history available")
return
for i in range(history_length):
item = readline.get_history_item(i + 1)
if item:
print(f" {i + 1}: {item}")
def get_user_input(self) -> Optional[str]:
"""
Get user input with readline support.
Returns:
User input string, or None if EOF (Ctrl+D)
"""
try:
return input(self.prompt).strip()
except KeyboardInterrupt:
print("\n(Use 'quit' to exit)")
return "" # Return empty string to continue
except EOFError:
print("\nGoodbye!")
return None
def run_interactive_loop(self, handler_func: Callable[[str], None]):
"""
Run the interactive loop with a custom handler function.
Args:
handler_func: Function to handle user input that's not a built-in command
Should accept a string and handle the user's query
"""
self.setup_readline()
if self.welcome_message:
print(self.welcome_message)
while True:
user_input = self.get_user_input()
if user_input is None: # EOF (Ctrl+D)
break
if not user_input: # Empty input or KeyboardInterrupt
continue
# Handle built-in commands
command = user_input.lower()
if command in ["quit", "exit", "q"]:
print("Goodbye!")
break
elif command == "help":
self._show_help()
elif command == "clear":
os.system("clear" if os.name != "nt" else "cls")
elif command == "history":
self._show_history()
else:
# Regular user input - pass to handler
try:
handler_func(user_input)
except Exception as e:
print(f"Error: {e}")
def create_cli_session(index_name: str) -> InteractiveSession:
"""Create an interactive session for CLI usage."""
return InteractiveSession(
history_name=index_name,
prompt="\nYou: ",
welcome_message="LEANN Assistant ready! Type 'quit' to exit, 'help' for commands\n"
+ "=" * 40,
)
def create_api_session() -> InteractiveSession:
"""Create an interactive session for API chat."""
return InteractiveSession(
history_name="api_chat",
prompt="You: ",
welcome_message="Leann Chat started (type 'quit' to exit, 'help' for commands)\n"
+ "=" * 40,
)
def create_rag_session(app_name: str, data_description: str) -> InteractiveSession:
"""Create an interactive session for RAG examples."""
return InteractiveSession(
history_name=f"{app_name}_rag",
prompt="You: ",
welcome_message=f"[Interactive Mode] Chat with your {data_description} data!\nType 'quit' or 'exit' to stop, 'help' for commands.\n"
+ "=" * 40,
)

View File

@@ -60,6 +60,11 @@ def handle_request(request):
"maximum": 128,
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
},
"show_metadata": {
"type": "boolean",
"default": False,
"description": "Include file paths and metadata in search results. Useful for understanding which files contain the results.",
},
},
"required": ["index_name", "query"],
},
@@ -104,6 +109,8 @@ def handle_request(request):
f"--complexity={args.get('complexity', 32)}",
"--non-interactive",
]
if args.get("show_metadata", False):
cmd.append("--show-metadata")
result = subprocess.run(cmd, capture_output=True, text=True)
elif tool_name == "leann_list":

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann"
version = "0.3.4"
version = "0.3.5"
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
readme = "README.md"
requires-python = ">=3.9"

View File

@@ -22,7 +22,10 @@ dependencies = [
"sglang",
"ollama",
"requests>=2.25.0",
"sentence-transformers>=2.2.0",
"sentence-transformers>=3.0.0",
# Pin transformers below 4.46: 4.46.0 introduced Python 3.10-only typing (PEP 604) and
# breaks our Python 3.9 test matrix when pulled in by sentence-transformers.
"transformers<4.46",
"openai>=1.0.0",
# PDF parsing dependencies - essential for document processing
"PyPDF2>=3.0.0",
@@ -54,6 +57,8 @@ dependencies = [
"tree-sitter-c-sharp>=0.20.0",
"tree-sitter-typescript>=0.20.0",
"torchvision>=0.23.0",
"einops",
"seaborn",
]
[project.optional-dependencies]
@@ -111,7 +116,7 @@ target-version = "py39"
line-length = 100
extend-exclude = [
"third_party",
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann.py",
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-paper-example.py",
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py"
]

View File

@@ -8,7 +8,7 @@ import subprocess
import sys
import tempfile
from pathlib import Path
from unittest.mock import patch
from unittest.mock import Mock, patch
import pytest
@@ -116,8 +116,10 @@ class TestChunkingFunctions:
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
assert len(chunks) > 0
assert all(isinstance(chunk, str) for chunk in chunks)
assert all(len(chunk.strip()) > 0 for chunk in chunks)
# Traditional chunks now return dict format for consistency
assert all(isinstance(chunk, dict) for chunk in chunks)
assert all("text" in chunk and "metadata" in chunk for chunk in chunks)
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks)
def test_create_traditional_chunks_empty_docs(self):
"""Test traditional chunking with empty documents."""
@@ -158,11 +160,22 @@ class Calculator:
# Should have multiple chunks due to different functions/classes
assert len(chunks) > 0
assert all(isinstance(chunk, str) for chunk in chunks)
assert all(len(chunk.strip()) > 0 for chunk in chunks)
# R3: Expect dict format with "text" and "metadata" keys
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
"Each chunk should have 'text' and 'metadata' keys"
)
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks), (
"Each chunk text should be non-empty"
)
# Check metadata is present
assert all("file_path" in chunk["metadata"] for chunk in chunks), (
"Each chunk should have file_path metadata"
)
# Check that code structure is somewhat preserved
combined_content = " ".join(chunks)
combined_content = " ".join([c["text"] for c in chunks])
assert "def hello_world" in combined_content
assert "class Calculator" in combined_content
@@ -194,7 +207,11 @@ class Calculator:
chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10)
assert len(chunks) > 0
assert all(isinstance(chunk, str) for chunk in chunks)
# R3: Traditional chunking should also return dict format for consistency
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
"Each chunk should have 'text' and 'metadata' keys"
)
def test_create_text_chunks_ast_mode(self):
"""Test text chunking in AST mode."""
@@ -213,7 +230,11 @@ class Calculator:
)
assert len(chunks) > 0
assert all(isinstance(chunk, str) for chunk in chunks)
# R3: AST mode should also return dict format
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
"Each chunk should have 'text' and 'metadata' keys"
)
def test_create_text_chunks_custom_extensions(self):
"""Test text chunking with custom code file extensions."""
@@ -353,6 +374,552 @@ class MathUtils:
pytest.skip("Test timed out - likely due to model download in CI")
class TestASTContentExtraction:
"""Test AST content extraction bug fix.
These tests verify that astchunk's dict format with 'content' key is handled correctly,
and that the extraction logic doesn't fall through to stringifying entire dicts.
"""
def test_extract_content_from_astchunk_dict(self):
"""Test that astchunk dict format with 'content' key is handled correctly.
Bug: Current code checks for chunk["text"] but astchunk returns chunk["content"].
This causes fallthrough to str(chunk), stringifying the entire dict.
This test will FAIL until the bug is fixed because:
- Current code will stringify the dict: "{'content': '...', 'metadata': {...}}"
- Fixed code should extract just the content value
"""
# Mock the ASTChunkBuilder class
mock_builder = Mock()
# Astchunk returns this format
astchunk_format_chunk = {
"content": "def hello():\n print('world')",
"metadata": {
"filepath": "test.py",
"line_count": 2,
"start_line_no": 0,
"end_line_no": 1,
"node_count": 1,
},
}
mock_builder.chunkify.return_value = [astchunk_format_chunk]
# Create mock document
doc = MockDocument(
"def hello():\n print('world')", "/test/test.py", {"language": "python"}
)
# Mock the astchunk module and its ASTChunkBuilder class
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
# Patch sys.modules to inject our mock before the import
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should return dict format with proper metadata
assert len(chunks) > 0, "Should return at least one chunk"
# R3: Each chunk should be a dict
chunk = chunks[0]
assert isinstance(chunk, dict), "Chunk should be a dict"
assert "text" in chunk, "Chunk should have 'text' key"
assert "metadata" in chunk, "Chunk should have 'metadata' key"
chunk_text = chunk["text"]
# CRITICAL: Should NOT contain stringified dict markers in the text field
# These assertions will FAIL with current buggy code
assert "'content':" not in chunk_text, (
f"Chunk text contains stringified dict - extraction failed! Got: {chunk_text[:100]}..."
)
assert "'metadata':" not in chunk_text, (
"Chunk text contains stringified metadata - extraction failed! "
f"Got: {chunk_text[:100]}..."
)
assert "{" not in chunk_text or "def hello" in chunk_text.split("{")[0], (
"Chunk text appears to be a stringified dict"
)
# Should contain actual content
assert "def hello()" in chunk_text, "Should extract actual code content"
assert "print('world')" in chunk_text, "Should extract complete code content"
# R3: Should preserve astchunk metadata
assert "filepath" in chunk["metadata"] or "file_path" in chunk["metadata"], (
"Should preserve file path metadata"
)
def test_extract_text_key_fallback(self):
"""Test that 'text' key still works for backward compatibility.
Some chunks might use 'text' instead of 'content' - ensure backward compatibility.
This test should PASS even with current code.
"""
mock_builder = Mock()
# Some chunks might use "text" key
text_key_chunk = {"text": "def legacy_function():\n return True"}
mock_builder.chunkify.return_value = [text_key_chunk]
# Create mock document
doc = MockDocument(
"def legacy_function():\n return True", "/test/legacy.py", {"language": "python"}
)
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should extract text correctly as dict format
assert len(chunks) > 0
chunk = chunks[0]
assert isinstance(chunk, dict), "Chunk should be a dict"
assert "text" in chunk, "Chunk should have 'text' key"
chunk_text = chunk["text"]
# Should NOT be stringified
assert "'text':" not in chunk_text, "Should not stringify dict with 'text' key"
# Should contain actual content
assert "def legacy_function()" in chunk_text
assert "return True" in chunk_text
def test_handles_string_chunks(self):
"""Test that plain string chunks still work.
Some chunkers might return plain strings - verify these are preserved.
This test should PASS with current code.
"""
mock_builder = Mock()
# Plain string chunk
plain_string_chunk = "def simple_function():\n pass"
mock_builder.chunkify.return_value = [plain_string_chunk]
# Create mock document
doc = MockDocument(
"def simple_function():\n pass", "/test/simple.py", {"language": "python"}
)
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should wrap string in dict format
assert len(chunks) > 0
chunk = chunks[0]
assert isinstance(chunk, dict), "Even string chunks should be wrapped in dict"
assert "text" in chunk, "Chunk should have 'text' key"
chunk_text = chunk["text"]
assert chunk_text == plain_string_chunk.strip(), (
"Should preserve plain string chunk content"
)
assert "def simple_function()" in chunk_text
assert "pass" in chunk_text
def test_multiple_chunks_with_mixed_formats(self):
"""Test handling of multiple chunks with different formats.
Real-world scenario: astchunk might return a mix of formats.
This test will FAIL if any chunk with 'content' key gets stringified.
"""
mock_builder = Mock()
# Mix of formats
mixed_chunks = [
{"content": "def first():\n return 1", "metadata": {"line_count": 2}},
"def second():\n return 2", # Plain string
{"text": "def third():\n return 3"}, # Old format
{"content": "class MyClass:\n pass", "metadata": {"node_count": 1}},
]
mock_builder.chunkify.return_value = mixed_chunks
# Create mock document
code = "def first():\n return 1\n\ndef second():\n return 2\n\ndef third():\n return 3\n\nclass MyClass:\n pass"
doc = MockDocument(code, "/test/mixed.py", {"language": "python"})
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should extract all chunks correctly as dicts
assert len(chunks) == 4, "Should extract all 4 chunks"
# Check each chunk
for i, chunk in enumerate(chunks):
assert isinstance(chunk, dict), f"Chunk {i} should be a dict"
assert "text" in chunk, f"Chunk {i} should have 'text' key"
assert "metadata" in chunk, f"Chunk {i} should have 'metadata' key"
chunk_text = chunk["text"]
# None should be stringified dicts
assert "'content':" not in chunk_text, f"Chunk {i} text is stringified (has 'content':)"
assert "'metadata':" not in chunk_text, (
f"Chunk {i} text is stringified (has 'metadata':)"
)
assert "'text':" not in chunk_text, f"Chunk {i} text is stringified (has 'text':)"
# Verify actual content is present
combined = "\n".join([c["text"] for c in chunks])
assert "def first()" in combined
assert "def second()" in combined
assert "def third()" in combined
assert "class MyClass:" in combined
def test_empty_content_value_handling(self):
"""Test handling of chunks with empty content values.
Edge case: chunk has 'content' key but value is empty.
Should skip these chunks, not stringify them.
"""
mock_builder = Mock()
chunks_with_empty = [
{"content": "", "metadata": {"line_count": 0}}, # Empty content
{"content": " ", "metadata": {"line_count": 1}}, # Whitespace only
{"content": "def valid():\n return True", "metadata": {"line_count": 2}}, # Valid
]
mock_builder.chunkify.return_value = chunks_with_empty
doc = MockDocument(
"def valid():\n return True", "/test/empty.py", {"language": "python"}
)
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
chunks = create_ast_chunks([doc])
# R3: Should only have the valid chunk (empty ones filtered out)
assert len(chunks) == 1, "Should filter out empty content chunks"
chunk = chunks[0]
assert isinstance(chunk, dict), "Chunk should be a dict"
assert "text" in chunk, "Chunk should have 'text' key"
assert "def valid()" in chunk["text"]
# Should not have stringified the empty dict
assert "'content': ''" not in chunk["text"]
class TestASTMetadataPreservation:
"""Test metadata preservation in AST chunk dictionaries.
R3: These tests define the contract for metadata preservation when returning
chunk dictionaries instead of plain strings. Each chunk dict should have:
- "text": str - the actual chunk content
- "metadata": dict - all metadata from document AND astchunk
These tests will FAIL until G3 implementation changes return type to list[dict].
"""
def test_ast_chunks_preserve_file_metadata(self):
"""Test that document metadata is preserved in chunk metadata.
This test verifies that all document-level metadata (file_path, file_name,
creation_date, last_modified_date) is included in each chunk's metadata dict.
This will FAIL because current code returns list[str], not list[dict].
"""
# Create mock document with rich metadata
python_code = '''
def calculate_sum(numbers):
"""Calculate sum of numbers."""
return sum(numbers)
class DataProcessor:
"""Process data records."""
def process(self, data):
return [x * 2 for x in data]
'''
doc = MockDocument(
python_code,
file_path="/project/src/utils.py",
metadata={
"language": "python",
"file_path": "/project/src/utils.py",
"file_name": "utils.py",
"creation_date": "2024-01-15T10:30:00",
"last_modified_date": "2024-10-31T15:45:00",
},
)
# Mock astchunk to return chunks with metadata
mock_builder = Mock()
astchunk_chunks = [
{
"content": "def calculate_sum(numbers):\n return sum(numbers)",
"metadata": {
"filepath": "/project/src/utils.py",
"line_count": 2,
"start_line_no": 1,
"end_line_no": 2,
"node_count": 1,
},
},
{
"content": "class DataProcessor:\n def process(self, data):\n return [x * 2 for x in data]",
"metadata": {
"filepath": "/project/src/utils.py",
"line_count": 3,
"start_line_no": 5,
"end_line_no": 7,
"node_count": 2,
},
},
]
mock_builder.chunkify.return_value = astchunk_chunks
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
chunks = create_ast_chunks([doc])
# CRITICAL: These assertions will FAIL with current list[str] return type
assert len(chunks) == 2, "Should return 2 chunks"
for i, chunk in enumerate(chunks):
# Structure assertions - WILL FAIL: current code returns strings
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
assert "text" in chunk, f"Chunk {i} must have 'text' key"
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
assert isinstance(chunk["metadata"], dict), f"Chunk {i} metadata should be dict"
# Document metadata preservation - WILL FAIL
metadata = chunk["metadata"]
assert "file_path" in metadata, f"Chunk {i} should preserve file_path"
assert metadata["file_path"] == "/project/src/utils.py", (
f"Chunk {i} file_path incorrect"
)
assert "file_name" in metadata, f"Chunk {i} should preserve file_name"
assert metadata["file_name"] == "utils.py", f"Chunk {i} file_name incorrect"
assert "creation_date" in metadata, f"Chunk {i} should preserve creation_date"
assert metadata["creation_date"] == "2024-01-15T10:30:00", (
f"Chunk {i} creation_date incorrect"
)
assert "last_modified_date" in metadata, f"Chunk {i} should preserve last_modified_date"
assert metadata["last_modified_date"] == "2024-10-31T15:45:00", (
f"Chunk {i} last_modified_date incorrect"
)
# Verify metadata is consistent across chunks from same document
assert chunks[0]["metadata"]["file_path"] == chunks[1]["metadata"]["file_path"], (
"All chunks from same document should have same file_path"
)
# Verify text content is present and not stringified
assert "def calculate_sum" in chunks[0]["text"]
assert "class DataProcessor" in chunks[1]["text"]
def test_ast_chunks_include_astchunk_metadata(self):
"""Test that astchunk-specific metadata is merged into chunk metadata.
This test verifies that astchunk's metadata (line_count, start_line_no,
end_line_no, node_count) is merged with document metadata.
This will FAIL because current code returns list[str], not list[dict].
"""
python_code = '''
def function_one():
"""First function."""
x = 1
y = 2
return x + y
def function_two():
"""Second function."""
return 42
'''
doc = MockDocument(
python_code,
file_path="/test/code.py",
metadata={
"language": "python",
"file_path": "/test/code.py",
"file_name": "code.py",
},
)
# Mock astchunk with detailed metadata
mock_builder = Mock()
astchunk_chunks = [
{
"content": "def function_one():\n x = 1\n y = 2\n return x + y",
"metadata": {
"filepath": "/test/code.py",
"line_count": 4,
"start_line_no": 1,
"end_line_no": 4,
"node_count": 5, # function, assignments, return
},
},
{
"content": "def function_two():\n return 42",
"metadata": {
"filepath": "/test/code.py",
"line_count": 2,
"start_line_no": 7,
"end_line_no": 8,
"node_count": 2, # function, return
},
},
]
mock_builder.chunkify.return_value = astchunk_chunks
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
chunks = create_ast_chunks([doc])
# CRITICAL: These will FAIL with current list[str] return
assert len(chunks) == 2
# First chunk - function_one
chunk1 = chunks[0]
assert isinstance(chunk1, dict), "Chunk should be dict"
assert "metadata" in chunk1
metadata1 = chunk1["metadata"]
# Check astchunk metadata is present
assert "line_count" in metadata1, "Should include astchunk line_count"
assert metadata1["line_count"] == 4, "line_count should be 4"
assert "start_line_no" in metadata1, "Should include astchunk start_line_no"
assert metadata1["start_line_no"] == 1, "start_line_no should be 1"
assert "end_line_no" in metadata1, "Should include astchunk end_line_no"
assert metadata1["end_line_no"] == 4, "end_line_no should be 4"
assert "node_count" in metadata1, "Should include astchunk node_count"
assert metadata1["node_count"] == 5, "node_count should be 5"
# Second chunk - function_two
chunk2 = chunks[1]
metadata2 = chunk2["metadata"]
assert metadata2["line_count"] == 2, "line_count should be 2"
assert metadata2["start_line_no"] == 7, "start_line_no should be 7"
assert metadata2["end_line_no"] == 8, "end_line_no should be 8"
assert metadata2["node_count"] == 2, "node_count should be 2"
# Verify document metadata is ALSO present (merged, not replaced)
assert metadata1["file_path"] == "/test/code.py"
assert metadata1["file_name"] == "code.py"
assert metadata2["file_path"] == "/test/code.py"
assert metadata2["file_name"] == "code.py"
# Verify text content is correct
assert "def function_one" in chunk1["text"]
assert "def function_two" in chunk2["text"]
def test_traditional_chunks_as_dicts_helper(self):
"""Test the helper function that wraps traditional chunks as dicts.
This test verifies that when create_traditional_chunks is called,
its plain string chunks are wrapped into dict format with metadata.
This will FAIL because the helper function _traditional_chunks_as_dicts()
doesn't exist yet, and create_traditional_chunks returns list[str].
"""
# Create documents with various metadata
docs = [
MockDocument(
"This is the first paragraph of text. It contains multiple sentences. "
"This should be split into chunks based on size.",
file_path="/docs/readme.txt",
metadata={
"file_path": "/docs/readme.txt",
"file_name": "readme.txt",
"creation_date": "2024-01-01",
},
),
MockDocument(
"Second document with different metadata. It also has content that needs chunking.",
file_path="/docs/guide.md",
metadata={
"file_path": "/docs/guide.md",
"file_name": "guide.md",
"last_modified_date": "2024-10-31",
},
),
]
# Call create_traditional_chunks (which should now return list[dict])
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
# CRITICAL: Will FAIL - current code returns list[str]
assert len(chunks) > 0, "Should return chunks"
for i, chunk in enumerate(chunks):
# Structure assertions - WILL FAIL
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
assert "text" in chunk, f"Chunk {i} must have 'text' key"
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
# Text should be non-empty
assert len(chunk["text"].strip()) > 0, f"Chunk {i} text should be non-empty"
# Metadata should include document info
metadata = chunk["metadata"]
assert "file_path" in metadata, f"Chunk {i} should have file_path in metadata"
assert "file_name" in metadata, f"Chunk {i} should have file_name in metadata"
# Verify metadata tracking works correctly
# At least one chunk should be from readme.txt
readme_chunks = [c for c in chunks if "readme.txt" in c["metadata"]["file_name"]]
assert len(readme_chunks) > 0, "Should have chunks from readme.txt"
# At least one chunk should be from guide.md
guide_chunks = [c for c in chunks if "guide.md" in c["metadata"]["file_name"]]
assert len(guide_chunks) > 0, "Should have chunks from guide.md"
# Verify creation_date is preserved for readme chunks
for chunk in readme_chunks:
assert chunk["metadata"].get("creation_date") == "2024-01-01", (
"readme.txt chunks should preserve creation_date"
)
# Verify last_modified_date is preserved for guide chunks
for chunk in guide_chunks:
assert chunk["metadata"].get("last_modified_date") == "2024-10-31", (
"guide.md chunks should preserve last_modified_date"
)
# Verify text content is present
all_text = " ".join([c["text"] for c in chunks])
assert "first paragraph" in all_text
assert "Second document" in all_text
class TestErrorHandling:
"""Test error handling and edge cases."""

View File

@@ -0,0 +1,137 @@
import json
import time
import pytest
from leann.embedding_server_manager import EmbeddingServerManager
class DummyProcess:
def __init__(self):
self.pid = 12345
self._terminated = False
def poll(self):
return 0 if self._terminated else None
def terminate(self):
self._terminated = True
def kill(self):
self._terminated = True
def wait(self, timeout=None):
self._terminated = True
return 0
@pytest.fixture
def embedding_manager(monkeypatch):
manager = EmbeddingServerManager("leann_backend_hnsw.hnsw_embedding_server")
def fake_get_available_port(start_port):
return start_port
monkeypatch.setattr(
"leann.embedding_server_manager._get_available_port",
fake_get_available_port,
)
start_calls = []
def fake_start_new_server(self, port, model_name, embedding_mode, **kwargs):
config_signature = kwargs.get("config_signature")
start_calls.append(config_signature)
self.server_process = DummyProcess()
self.server_port = port
self._server_config = config_signature
return True, port
monkeypatch.setattr(
EmbeddingServerManager,
"_start_new_server",
fake_start_new_server,
)
# Ensure stop_server doesn't try to operate on real subprocesses
def fake_stop_server(self):
self.server_process = None
self.server_port = None
self._server_config = None
monkeypatch.setattr(EmbeddingServerManager, "stop_server", fake_stop_server)
return manager, start_calls
def _write_meta(meta_path, passages_name, index_name, total):
meta_path.write_text(
json.dumps(
{
"backend_name": "hnsw",
"embedding_model": "test-model",
"embedding_mode": "sentence-transformers",
"dimensions": 3,
"backend_kwargs": {},
"passage_sources": [
{
"type": "jsonl",
"path": passages_name,
"index_path": index_name,
}
],
"total_passages": total,
}
),
encoding="utf-8",
)
def test_server_restarts_when_metadata_changes(tmp_path, embedding_manager):
manager, start_calls = embedding_manager
meta_path = tmp_path / "example.meta.json"
passages_path = tmp_path / "example.passages.jsonl"
index_path = tmp_path / "example.passages.idx"
passages_path.write_text("first\n", encoding="utf-8")
index_path.write_bytes(b"index")
_write_meta(meta_path, passages_path.name, index_path.name, total=1)
# Initial start populates signature
ok, port = manager.start_server(
port=6000,
model_name="test-model",
passages_file=str(meta_path),
)
assert ok
assert port == 6000
assert len(start_calls) == 1
initial_signature = start_calls[0]["passages_signature"]
# No metadata change => reuse existing server
ok, port_again = manager.start_server(
port=6000,
model_name="test-model",
passages_file=str(meta_path),
)
assert ok
assert port_again == 6000
assert len(start_calls) == 1
# Modify passage data and metadata to force signature change
time.sleep(0.01) # Ensure filesystem timestamps move forward
passages_path.write_text("second\n", encoding="utf-8")
_write_meta(meta_path, passages_path.name, index_path.name, total=2)
ok, port_third = manager.start_server(
port=6000,
model_name="test-model",
passages_file=str(meta_path),
)
assert ok
assert port_third == 6000
assert len(start_calls) == 2
updated_signature = start_calls[1]["passages_signature"]
assert updated_signature != initial_signature

View File

@@ -0,0 +1,208 @@
#!/usr/bin/env python3
"""
Test script for MCP integration implementations.
This script tests the basic functionality of the MCP readers and RAG applications
without requiring actual MCP servers to be running.
"""
import sys
from pathlib import Path
# Add the parent directory to the path so we can import from apps
sys.path.append(str(Path(__file__).parent.parent))
from apps.slack_data.slack_mcp_reader import SlackMCPReader
from apps.slack_rag import SlackMCPRAG
from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader
from apps.twitter_rag import TwitterMCPRAG
def test_slack_reader_initialization():
"""Test that SlackMCPReader can be initialized with various parameters."""
print("Testing SlackMCPReader initialization...")
# Test basic initialization
reader = SlackMCPReader("slack-mcp-server")
assert reader.mcp_server_command == "slack-mcp-server"
assert reader.concatenate_conversations
assert reader.max_messages_per_conversation == 100
# Test with custom parameters
reader = SlackMCPReader(
"custom-slack-server",
workspace_name="test-workspace",
concatenate_conversations=False,
max_messages_per_conversation=50,
)
assert reader.workspace_name == "test-workspace"
assert not reader.concatenate_conversations
assert reader.max_messages_per_conversation == 50
print("✅ SlackMCPReader initialization tests passed")
def test_twitter_reader_initialization():
"""Test that TwitterMCPReader can be initialized with various parameters."""
print("Testing TwitterMCPReader initialization...")
# Test basic initialization
reader = TwitterMCPReader("twitter-mcp-server")
assert reader.mcp_server_command == "twitter-mcp-server"
assert reader.include_tweet_content
assert reader.include_metadata
assert reader.max_bookmarks == 1000
# Test with custom parameters
reader = TwitterMCPReader(
"custom-twitter-server",
username="testuser",
include_tweet_content=False,
include_metadata=False,
max_bookmarks=500,
)
assert reader.username == "testuser"
assert not reader.include_tweet_content
assert not reader.include_metadata
assert reader.max_bookmarks == 500
print("✅ TwitterMCPReader initialization tests passed")
def test_slack_message_formatting():
"""Test Slack message formatting functionality."""
print("Testing Slack message formatting...")
reader = SlackMCPReader("slack-mcp-server")
# Test basic message formatting
message = {
"text": "Hello, world!",
"user": "john_doe",
"channel": "general",
"ts": "1234567890.123456",
}
formatted = reader._format_message(message)
assert "Channel: #general" in formatted
assert "User: john_doe" in formatted
assert "Message: Hello, world!" in formatted
assert "Time:" in formatted
# Test with missing fields
message = {"text": "Simple message"}
formatted = reader._format_message(message)
assert "Message: Simple message" in formatted
print("✅ Slack message formatting tests passed")
def test_twitter_bookmark_formatting():
"""Test Twitter bookmark formatting functionality."""
print("Testing Twitter bookmark formatting...")
reader = TwitterMCPReader("twitter-mcp-server")
# Test basic bookmark formatting
bookmark = {
"text": "This is a great article about AI!",
"author": "ai_researcher",
"created_at": "2024-01-01T12:00:00Z",
"url": "https://twitter.com/ai_researcher/status/123456789",
"likes": 42,
"retweets": 15,
}
formatted = reader._format_bookmark(bookmark)
assert "=== Twitter Bookmark ===" in formatted
assert "Author: @ai_researcher" in formatted
assert "Content:" in formatted
assert "This is a great article about AI!" in formatted
assert "URL: https://twitter.com" in formatted
assert "Likes: 42" in formatted
assert "Retweets: 15" in formatted
# Test with minimal data
bookmark = {"text": "Simple tweet"}
formatted = reader._format_bookmark(bookmark)
assert "=== Twitter Bookmark ===" in formatted
assert "Simple tweet" in formatted
print("✅ Twitter bookmark formatting tests passed")
def test_slack_rag_initialization():
"""Test that SlackMCPRAG can be initialized."""
print("Testing SlackMCPRAG initialization...")
app = SlackMCPRAG()
assert app.default_index_name == "slack_messages"
assert hasattr(app, "parser")
print("✅ SlackMCPRAG initialization tests passed")
def test_twitter_rag_initialization():
"""Test that TwitterMCPRAG can be initialized."""
print("Testing TwitterMCPRAG initialization...")
app = TwitterMCPRAG()
assert app.default_index_name == "twitter_bookmarks"
assert hasattr(app, "parser")
print("✅ TwitterMCPRAG initialization tests passed")
def test_concatenated_content_creation():
"""Test creation of concatenated content from multiple messages."""
print("Testing concatenated content creation...")
reader = SlackMCPReader("slack-mcp-server", workspace_name="test-workspace")
messages = [
{"text": "First message", "user": "alice", "ts": "1000"},
{"text": "Second message", "user": "bob", "ts": "2000"},
{"text": "Third message", "user": "charlie", "ts": "3000"},
]
content = reader._create_concatenated_content(messages, "general")
assert "Slack Channel: #general" in content
assert "Message Count: 3" in content
assert "Workspace: test-workspace" in content
assert "First message" in content
assert "Second message" in content
assert "Third message" in content
print("✅ Concatenated content creation tests passed")
def main():
"""Run all tests."""
print("🧪 Running MCP Integration Tests")
print("=" * 50)
try:
test_slack_reader_initialization()
test_twitter_reader_initialization()
test_slack_message_formatting()
test_twitter_bookmark_formatting()
test_slack_rag_initialization()
test_twitter_rag_initialization()
test_concatenated_content_creation()
print("\n" + "=" * 50)
print("🎉 All tests passed! MCP integration is working correctly.")
print("\nNext steps:")
print("1. Install actual MCP servers for Slack and Twitter")
print("2. Configure API credentials")
print("3. Test with --test-connection flag")
print("4. Start indexing your live data!")
except Exception as e:
print(f"\n❌ Test failed: {e}")
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,221 @@
#!/usr/bin/env python3
"""
Standalone test script for MCP integration implementations.
This script tests the basic functionality of the MCP readers
without requiring LEANN core dependencies.
"""
import json
import sys
from pathlib import Path
# Add the parent directory to the path so we can import from apps
sys.path.append(str(Path(__file__).parent.parent))
def test_slack_reader_basic():
"""Test basic SlackMCPReader functionality without async operations."""
print("Testing SlackMCPReader basic functionality...")
# Import and test initialization
from apps.slack_data.slack_mcp_reader import SlackMCPReader
reader = SlackMCPReader("slack-mcp-server")
assert reader.mcp_server_command == "slack-mcp-server"
assert reader.concatenate_conversations
# Test message formatting
message = {
"text": "Hello team! How's the project going?",
"user": "john_doe",
"channel": "general",
"ts": "1234567890.123456",
}
formatted = reader._format_message(message)
assert "Channel: #general" in formatted
assert "User: john_doe" in formatted
assert "Message: Hello team!" in formatted
# Test concatenated content creation
messages = [
{"text": "First message", "user": "alice", "ts": "1000"},
{"text": "Second message", "user": "bob", "ts": "2000"},
]
content = reader._create_concatenated_content(messages, "dev-team")
assert "Slack Channel: #dev-team" in content
assert "Message Count: 2" in content
assert "First message" in content
assert "Second message" in content
print("✅ SlackMCPReader basic tests passed")
def test_twitter_reader_basic():
"""Test basic TwitterMCPReader functionality."""
print("Testing TwitterMCPReader basic functionality...")
from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader
reader = TwitterMCPReader("twitter-mcp-server")
assert reader.mcp_server_command == "twitter-mcp-server"
assert reader.include_tweet_content
assert reader.max_bookmarks == 1000
# Test bookmark formatting
bookmark = {
"text": "Amazing article about the future of AI! Must read for everyone interested in tech.",
"author": "tech_guru",
"created_at": "2024-01-15T14:30:00Z",
"url": "https://twitter.com/tech_guru/status/123456789",
"likes": 156,
"retweets": 42,
"replies": 23,
"hashtags": ["AI", "tech", "future"],
"mentions": ["@openai", "@anthropic"],
}
formatted = reader._format_bookmark(bookmark)
assert "=== Twitter Bookmark ===" in formatted
assert "Author: @tech_guru" in formatted
assert "Amazing article about the future of AI!" in formatted
assert "Likes: 156" in formatted
assert "Retweets: 42" in formatted
assert "Hashtags: AI, tech, future" in formatted
assert "Mentions: @openai, @anthropic" in formatted
# Test with minimal data
simple_bookmark = {"text": "Short tweet", "author": "user123"}
formatted_simple = reader._format_bookmark(simple_bookmark)
assert "=== Twitter Bookmark ===" in formatted_simple
assert "Short tweet" in formatted_simple
assert "Author: @user123" in formatted_simple
print("✅ TwitterMCPReader basic tests passed")
def test_mcp_request_format():
"""Test MCP request formatting."""
print("Testing MCP request formatting...")
# Test initialization request format
init_request = {
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "leann-slack-reader", "version": "1.0.0"},
},
}
# Verify it's valid JSON
json_str = json.dumps(init_request)
parsed = json.loads(json_str)
assert parsed["jsonrpc"] == "2.0"
assert parsed["method"] == "initialize"
assert parsed["params"]["protocolVersion"] == "2024-11-05"
# Test tools/list request
list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}
json_str = json.dumps(list_request)
parsed = json.loads(json_str)
assert parsed["method"] == "tools/list"
print("✅ MCP request formatting tests passed")
def test_data_processing():
"""Test data processing capabilities."""
print("Testing data processing capabilities...")
from apps.slack_data.slack_mcp_reader import SlackMCPReader
from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader
# Test Slack message processing with various formats
slack_reader = SlackMCPReader("test-server")
messages_with_timestamps = [
{"text": "Meeting in 5 minutes", "user": "alice", "ts": "1000.123"},
{"text": "On my way!", "user": "bob", "ts": "1001.456"},
{"text": "Starting now", "user": "charlie", "ts": "1002.789"},
]
content = slack_reader._create_concatenated_content(messages_with_timestamps, "meetings")
assert "Meeting in 5 minutes" in content
assert "On my way!" in content
assert "Starting now" in content
# Test Twitter bookmark processing with engagement data
twitter_reader = TwitterMCPReader("test-server", include_metadata=True)
high_engagement_bookmark = {
"text": "Thread about startup lessons learned 🧵",
"author": "startup_founder",
"likes": 1250,
"retweets": 340,
"replies": 89,
}
formatted = twitter_reader._format_bookmark(high_engagement_bookmark)
assert "Thread about startup lessons learned" in formatted
assert "Likes: 1250" in formatted
assert "Retweets: 340" in formatted
assert "Replies: 89" in formatted
# Test with metadata disabled
twitter_reader_no_meta = TwitterMCPReader("test-server", include_metadata=False)
formatted_no_meta = twitter_reader_no_meta._format_bookmark(high_engagement_bookmark)
assert "Thread about startup lessons learned" in formatted_no_meta
assert "Likes:" not in formatted_no_meta
assert "Retweets:" not in formatted_no_meta
print("✅ Data processing tests passed")
def main():
"""Run all standalone tests."""
print("🧪 Running MCP Integration Standalone Tests")
print("=" * 60)
print("Testing core functionality without LEANN dependencies...")
print()
try:
test_slack_reader_basic()
test_twitter_reader_basic()
test_mcp_request_format()
test_data_processing()
print("\n" + "=" * 60)
print("🎉 All standalone tests passed!")
print("\n✨ MCP Integration Summary:")
print("- SlackMCPReader: Ready for Slack message processing")
print("- TwitterMCPReader: Ready for Twitter bookmark processing")
print("- MCP Protocol: Properly formatted JSON-RPC requests")
print("- Data Processing: Handles various message/bookmark formats")
print("\n🚀 Next Steps:")
print("1. Install MCP servers: npm install -g slack-mcp-server twitter-mcp-server")
print("2. Configure API credentials for Slack and Twitter")
print("3. Test connections: python -m apps.slack_rag --test-connection")
print("4. Start indexing live data from your platforms!")
print("\n📖 Documentation:")
print("- Check README.md for detailed setup instructions")
print("- Run examples/mcp_integration_demo.py for usage examples")
print("- Explore apps/slack_rag.py and apps/twitter_rag.py for implementation details")
except Exception as e:
print(f"\n❌ Test failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,268 @@
"""Unit tests for token-aware truncation functionality.
This test suite defines the contract for token truncation functions that prevent
500 errors from Ollama when text exceeds model token limits. These tests verify:
1. Model token limit retrieval (known and unknown models)
2. Text truncation behavior for single and multiple texts
3. Token counting and truncation accuracy using tiktoken
All tests are written in Red Phase - they should FAIL initially because the
implementation does not exist yet.
"""
import pytest
import tiktoken
from leann.embedding_compute import (
EMBEDDING_MODEL_LIMITS,
get_model_token_limit,
truncate_to_token_limit,
)
class TestModelTokenLimits:
"""Tests for retrieving model-specific token limits."""
def test_get_model_token_limit_known_model(self):
"""Verify correct token limit is returned for known models.
Known models should return their specific token limits from
EMBEDDING_MODEL_LIMITS dictionary.
"""
# Test nomic-embed-text (2048 tokens)
limit = get_model_token_limit("nomic-embed-text")
assert limit == 2048, "nomic-embed-text should have 2048 token limit"
# Test nomic-embed-text-v1.5 (2048 tokens)
limit = get_model_token_limit("nomic-embed-text-v1.5")
assert limit == 2048, "nomic-embed-text-v1.5 should have 2048 token limit"
# Test nomic-embed-text-v2 (512 tokens)
limit = get_model_token_limit("nomic-embed-text-v2")
assert limit == 512, "nomic-embed-text-v2 should have 512 token limit"
# Test OpenAI models (8192 tokens)
limit = get_model_token_limit("text-embedding-3-small")
assert limit == 8192, "text-embedding-3-small should have 8192 token limit"
def test_get_model_token_limit_unknown_model(self):
"""Verify default token limit is returned for unknown models.
Unknown models should return the default limit (2048) to allow
operation with reasonable safety margin.
"""
# Test with completely unknown model
limit = get_model_token_limit("unknown-model-xyz")
assert limit == 2048, "Unknown models should return default 2048"
# Test with empty string
limit = get_model_token_limit("")
assert limit == 2048, "Empty model name should return default 2048"
def test_get_model_token_limit_custom_default(self):
"""Verify custom default can be specified for unknown models.
Allow callers to specify their own default token limit when
model is not in the known models dictionary.
"""
limit = get_model_token_limit("unknown-model", default=4096)
assert limit == 4096, "Should return custom default for unknown models"
# Known model should ignore custom default
limit = get_model_token_limit("nomic-embed-text", default=4096)
assert limit == 2048, "Known model should ignore custom default"
def test_embedding_model_limits_dictionary_exists(self):
"""Verify EMBEDDING_MODEL_LIMITS dictionary contains expected models.
The dictionary should be importable and contain at least the
known nomic models with correct token limits.
"""
assert isinstance(EMBEDDING_MODEL_LIMITS, dict), "Should be a dictionary"
assert "nomic-embed-text" in EMBEDDING_MODEL_LIMITS, "Should contain nomic-embed-text"
assert "nomic-embed-text-v1.5" in EMBEDDING_MODEL_LIMITS, (
"Should contain nomic-embed-text-v1.5"
)
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text"] == 2048
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v1.5"] == 2048
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v2"] == 512
# OpenAI models
assert EMBEDDING_MODEL_LIMITS["text-embedding-3-small"] == 8192
class TestTokenTruncation:
"""Tests for truncating texts to token limits."""
@pytest.fixture
def tokenizer(self):
"""Provide tiktoken tokenizer for token counting verification."""
return tiktoken.get_encoding("cl100k_base")
def test_truncate_single_text_under_limit(self, tokenizer):
"""Verify text under token limit remains unchanged.
When text is already within the token limit, it should be
returned unchanged with no truncation.
"""
text = "This is a short text that is well under the token limit."
token_count = len(tokenizer.encode(text))
assert token_count < 100, f"Test setup: text should be short (has {token_count} tokens)"
# Truncate with generous limit
result = truncate_to_token_limit([text], token_limit=512)
assert len(result) == 1, "Should return same number of texts"
assert result[0] == text, "Text under limit should be unchanged"
def test_truncate_single_text_over_limit(self, tokenizer):
"""Verify text over token limit is truncated correctly.
When text exceeds the token limit, it should be truncated to
fit within the limit while maintaining valid token boundaries.
"""
# Create a text that definitely exceeds limit
text = "word " * 200 # ~200 tokens (each "word " is typically 1-2 tokens)
original_token_count = len(tokenizer.encode(text))
assert original_token_count > 50, (
f"Test setup: text should be long (has {original_token_count} tokens)"
)
# Truncate to 50 tokens
result = truncate_to_token_limit([text], token_limit=50)
assert len(result) == 1, "Should return same number of texts"
assert result[0] != text, "Text over limit should be truncated"
assert len(result[0]) < len(text), "Truncated text should be shorter"
# Verify truncated text is within token limit
truncated_token_count = len(tokenizer.encode(result[0]))
assert truncated_token_count <= 50, (
f"Truncated text should be ≤50 tokens, got {truncated_token_count}"
)
def test_truncate_multiple_texts_mixed_lengths(self, tokenizer):
"""Verify multiple texts with mixed lengths are handled correctly.
When processing multiple texts:
- Texts under limit should remain unchanged
- Texts over limit should be truncated independently
- Output list should maintain same order and length
"""
texts = [
"Short text.", # Under limit
"word " * 200, # Over limit
"Another short one.", # Under limit
"token " * 150, # Over limit
]
# Verify test setup
for i, text in enumerate(texts):
token_count = len(tokenizer.encode(text))
if i in [1, 3]:
assert token_count > 50, f"Text {i} should be over limit (has {token_count} tokens)"
else:
assert token_count < 50, (
f"Text {i} should be under limit (has {token_count} tokens)"
)
# Truncate with 50 token limit
result = truncate_to_token_limit(texts, token_limit=50)
assert len(result) == len(texts), "Should return same number of texts"
# Verify each text individually
for i, (original, truncated) in enumerate(zip(texts, result)):
token_count = len(tokenizer.encode(truncated))
assert token_count <= 50, f"Text {i} should be ≤50 tokens, got {token_count}"
# Short texts should be unchanged
if i in [0, 2]:
assert truncated == original, f"Short text {i} should be unchanged"
# Long texts should be truncated
else:
assert len(truncated) < len(original), f"Long text {i} should be truncated"
def test_truncate_empty_list(self):
"""Verify empty input list returns empty output list.
Edge case: empty list should return empty list without errors.
"""
result = truncate_to_token_limit([], token_limit=512)
assert result == [], "Empty input should return empty output"
def test_truncate_preserves_order(self, tokenizer):
"""Verify truncation preserves original text order.
Output list should maintain the same order as input list,
regardless of which texts were truncated.
"""
texts = [
"First text " * 50, # Will be truncated
"Second text.", # Won't be truncated
"Third text " * 50, # Will be truncated
]
result = truncate_to_token_limit(texts, token_limit=20)
assert len(result) == 3, "Should preserve list length"
# Check that order is maintained by looking for distinctive words
assert "First" in result[0], "First text should remain in first position"
assert "Second" in result[1], "Second text should remain in second position"
assert "Third" in result[2], "Third text should remain in third position"
def test_truncate_extremely_long_text(self, tokenizer):
"""Verify extremely long texts are truncated efficiently.
Test with text that far exceeds token limit to ensure
truncation handles extreme cases without performance issues.
"""
# Create very long text (simulate real-world scenario)
text = "token " * 5000 # ~5000+ tokens
original_token_count = len(tokenizer.encode(text))
assert original_token_count > 1000, "Test setup: text should be very long"
# Truncate to small limit
result = truncate_to_token_limit([text], token_limit=100)
assert len(result) == 1
truncated_token_count = len(tokenizer.encode(result[0]))
assert truncated_token_count <= 100, (
f"Should truncate to ≤100 tokens, got {truncated_token_count}"
)
assert len(result[0]) < len(text) // 10, "Should significantly reduce text length"
def test_truncate_exact_token_limit(self, tokenizer):
"""Verify text at exactly token limit is handled correctly.
Edge case: text with exactly the token limit should either
remain unchanged or be safely truncated by 1 token.
"""
# Create text with approximately 50 tokens
# We'll adjust to get exactly 50
target_tokens = 50
text = "word " * 50
tokens = tokenizer.encode(text)
# Adjust to get exactly target_tokens
if len(tokens) > target_tokens:
tokens = tokens[:target_tokens]
text = tokenizer.decode(tokens)
elif len(tokens) < target_tokens:
# Add more words
while len(tokenizer.encode(text)) < target_tokens:
text += "word "
tokens = tokenizer.encode(text)[:target_tokens]
text = tokenizer.decode(tokens)
# Verify we have exactly target_tokens
assert len(tokenizer.encode(text)) == target_tokens, (
"Test setup: should have exactly 50 tokens"
)
result = truncate_to_token_limit([text], token_limit=target_tokens)
assert len(result) == 1
result_tokens = len(tokenizer.encode(result[0]))
assert result_tokens <= target_tokens, (
f"Should be ≤{target_tokens} tokens, got {result_tokens}"
)

7764
uv.lock generated
View File

File diff suppressed because it is too large Load Diff