Compare commits

...

366 Commits

Author SHA1 Message Date
yichuan520030910320
00c44e3980 [cli] fix # 81 2025-09-01 16:53:09 -07:00
yichuan520030910320
e6a542bf4b [cli] better gitignore / better leann list 2025-09-01 16:42:11 -07:00
yichuan520030910320
7e84dae02e [chore] add slack to share use case 2025-08-30 00:32:13 -07:00
yichuan520030910320
2f05ed4535 chore(submodule): bump faiss to latest storage-efficient build 2025-08-23 18:29:11 -07:00
yichuan520030910320
4e5b73ce7b fix bug introduce in #58 2025-08-22 02:35:09 -07:00
Gabriel Dehan
31b4973141 Metadata filtering feature (#75)
* Metadata filtering initial version

* Metadata filtering initial version

* Fixes linter issues

* Cleanup code

* Clean up and readme

* Fix after review

* Use UV in example

* Merge main into feature/metadata-filtering
2025-08-20 19:57:56 -07:00
Yichuan Wang
dde2221513 [EXP] Update the benchmark code (#71)
* chore(hnsw): reorder imports to satisfy ruff I001

* chore: sync changes; fix Ruff import order; update examples, benchmarks, and dependencies

- Fix import order in packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py (Ruff I001)

- Update benchmarks/run_evaluation.py

- Update apps/base_rag_example.py and leann-core API usage

- Add benchmarks/data/README.md

- Update uv.lock

- Misc cleanup

- Note: added paru-bin as an embedded git repo; consider making it a submodule (git rm --cached paru-bin) if unintended

* chore: remove unintended embedded repo paru-bin and ignore it

Fix CI: avoid missing .gitmodules entry by removing gitlink and adding to .gitignore.

* ci: retrigger after removing unintended gitlink (paru-bin)

* feat(benchmarks): add --batch-size option and plumb through to HNSW search (default 0)

* feat(hnsw): add batch_size to LeannSearcher.search and LeannChat.ask; forward only for HNSW backend

* chore(logging): surface recompute and batching params; enable INFO logging in benchmark

* feat(embeddings): add optional manual tokenization path (HF tokenizer+model) with mean pooling; default remains SentenceTransformer.encode

* fix micro bench and fix pre commit

* update readme

---------

Co-authored-by: yichuan-w <yichuan-w@users.noreply.github.com>
2025-08-20 17:31:46 -07:00
Andy Lee
6d11e86e71 Run Evaluation RPJ Wiki on Arch Linux (#74)
* chore: ignore benchmark data

* perf: avoid merging offset dicts for lower mem usage

* style: format

* docs: rpj_wiki
2025-08-20 12:25:54 -07:00
Gabriel Dehan
13bb561aad Add AST-aware code chunking for better code understanding (#58)
* feat(core): Add AST-aware code chunking with astchunk integration

This PR introduces intelligent code chunking that preserves semantic boundaries
(functions, classes, methods) for better code understanding in RAG applications.

Key Features:
- AST-aware chunking for Python, Java, C#, TypeScript files
- Graceful fallback to traditional chunking for unsupported languages
- New specialized code RAG application for repositories
- Enhanced CLI with --use-ast-chunking flag
- Comprehensive test suite with integration tests

Technical Implementation:
- New chunking_utils.py module with enhanced chunking logic
- Extended base RAG framework with AST chunking arguments
- Updated document RAG with --enable-code-chunking flag
- CLI integration with proper error handling and fallback

Benefits:
- Better semantic understanding of code structure
- Improved search quality for code-related queries
- Maintains backward compatibility with existing workflows
- Supports mixed content (code + documentation) seamlessly

Dependencies:
- Added astchunk and tree-sitter parsers to pyproject.toml
- All dependencies are optional - fallback works without them

Testing:
- Comprehensive test suite in test_astchunk_integration.py
- Integration tests with document RAG
- Error handling and edge case coverage

Documentation:
- Updated README.md with AST chunking highlights
- Added ASTCHUNK_INTEGRATION.md with complete guide
- Updated features.md with new capabilities

* Refactored chunk utils

* Remove useless import

* Update README.md

* Update apps/chunking/utils.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update apps/code_rag.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Fix issue

* apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Fixes after pr review

* Fix tests not passing

* Fix linter error for documentation files

* Update .gitignore with unwanted files

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Andy Lee <andylizf@outlook.com>
2025-08-19 23:35:31 -07:00
GitHub Actions
0174ba5571 chore: release v0.3.2 2025-08-19 09:41:40 +00:00
Andy Lee
03af82d695 fix: leann mcp search cwd & interactive issues (#72) 2025-08-19 02:27:06 -07:00
GitHub Actions
738f1dbab8 chore: release v0.3.1 2025-08-19 05:56:45 +00:00
yichuan520030910320
37d990d51c [feature] fix cli 2025-08-18 22:55:43 -07:00
Andy Lee
a6f07a54f1 fix: Use uv venv for Arch Linux CI wheel installation (#69)
- Use astral-sh/setup-uv@v4 action for consistency with other jobs
- Create virtual environment with uv venv to bypass PEP 668 restrictions
- Install wheels using uv pip install for faster dependency resolution
- Maintain tool consistency across the entire CI pipeline

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-authored-by: Claude <noreply@anthropic.com>
2025-08-16 21:32:19 -07:00
Andy Lee
46905e0687 feat: Improve DiskANN cross-platform compatibility and add Arch Linux support (#66)
* feat: Enhance CLI with improved list and smart remove commands

##  New Features

### 🏠 Enhanced `leann list` command
- **Better UX**: Current project shown first with clear separation
- **Visual improvements**: Icons (🏠/📂), better formatting, size info
- **Smart guidance**: Context-aware usage examples and getting started tips

### 🛡️ Smart `leann remove` command
- **Safety first**: Always shows ALL matching indexes across projects
- **Intelligent handling**:
  - Single match: Clear location display with cross-project warnings
  - Multiple matches: Interactive selection with final confirmation
- **Prevents accidents**: No more deleting wrong indexes due to name conflicts
- **User-friendly**: 'c' to cancel, clear visual hierarchy, detailed info

### 🔧 Technical improvements
- **Clean logging**: Hide debug messages for better CLI experience
- **Comprehensive search**: Always scan all projects for transparency
- **Error handling**: Graceful handling of edge cases and user input

## 🎯 Impact
- **Safer**: Eliminates risk of accidental index deletion
- **Clearer**: Users always know what they're operating on
- **Smarter**: Automatic detection and handling of common scenarios

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* chore: vscode ruff, and format

* fix: Update DiskANN submodule with MKL linking improvements

Updates DiskANN submodule to include fix for MKL linking issues:
- Replaces global link_libraries() with target-specific linking
- Uses dynamic MKL linking (mkl_rt) for better cross-platform compatibility
- Prevents MKL contamination of unrelated targets (like zlib tests)
- Resolves build failures on strict linkers (Arch Linux) while maintaining Ubuntu compatibility

DiskANN commit: c593831 - fix: Replace global MKL linking with target-specific approach

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* chore: all linux deps

* fix: Update Intel MKL download link to avoid 403 error

- Replace problematic Intel download URL that returns 403 Forbidden
- Use general Intel oneAPI MKL page instead of specific download parameters
- This fixes the lychee link checker CI failure

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Configure lychee to use browser User-Agent for Intel links

- Replace domain exclusion with browser User-Agent to properly check Intel links
- Intel website blocks automated tools but allows browser-like requests
- This enables proper link validation while avoiding 403 Forbidden errors

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Use curl User-Agent for lychee link checking

Intel website has specific anti-bot logic:
- Blocks browser User-Agents (returns 403)
- Blocks lychee default User-Agent (returns 403)
- Allows curl User-Agent (returns 200)

This enables proper link validation for Intel documentation.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-08-16 14:42:20 -07:00
Andy Lee
838ade231e 🔗 Auto-register apps: Universal index discovery (#64)
* feat: Enhance CLI with improved list and smart remove commands

##  New Features

### 🏠 Enhanced `leann list` command
- **Better UX**: Current project shown first with clear separation
- **Visual improvements**: Icons (🏠/📂), better formatting, size info
- **Smart guidance**: Context-aware usage examples and getting started tips

### 🛡️ Smart `leann remove` command
- **Safety first**: Always shows ALL matching indexes across projects
- **Intelligent handling**:
  - Single match: Clear location display with cross-project warnings
  - Multiple matches: Interactive selection with final confirmation
- **Prevents accidents**: No more deleting wrong indexes due to name conflicts
- **User-friendly**: 'c' to cancel, clear visual hierarchy, detailed info

### 🔧 Technical improvements
- **Clean logging**: Hide debug messages for better CLI experience
- **Comprehensive search**: Always scan all projects for transparency
- **Error handling**: Graceful handling of edge cases and user input

## 🎯 Impact
- **Safer**: Eliminates risk of accidental index deletion
- **Clearer**: Users always know what they're operating on
- **Smarter**: Automatic detection and handling of common scenarios

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* chore: vscode ruff, and format

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-08-16 11:50:25 -07:00
Andy Lee
da6540decd feat: Enhance CLI with improved list and smart remove commands (#63)
- **Better UX**: Current project shown first with clear separation
- **Visual improvements**: Icons (🏠/📂), better formatting, size info
- **Smart guidance**: Context-aware usage examples and getting started tips

- **Safety first**: Always shows ALL matching indexes across projects
- **Intelligent handling**:
  - Single match: Clear location display with cross-project warnings
  - Multiple matches: Interactive selection with final confirmation
- **Prevents accidents**: No more deleting wrong indexes due to name conflicts
- **User-friendly**: 'c' to cancel, clear visual hierarchy, detailed info

- **Clean logging**: Hide debug messages for better CLI experience
- **Comprehensive search**: Always scan all projects for transparency
- **Error handling**: Graceful handling of edge cases and user input

- **Safer**: Eliminates risk of accidental index deletion
- **Clearer**: Users always know what they're operating on
- **Smarter**: Automatic detection and handling of common scenarios
2025-08-15 23:49:47 -07:00
yichuan520030910320
39e18a7c11 [chore] remove gitattribute 2025-08-15 23:12:24 -07:00
Andy Lee
6bde28584b feat: Add Google Gemini API support for chat and embeddings (#57)
- Add GeminiChat class with gemini-2.5-flash model support
- Add compute_embeddings_gemini function with text-embedding-004 model
- Update get_llm factory to support "gemini" type
- Update API documentation to include gemini embedding mode
- Support temperature, max_tokens, top_p parameters for Gemini chat
- Support batch embedding processing with progress bars
- Add proper error handling and API key validation
2025-08-15 21:54:11 -07:00
yichuan520030910320
f62632c41f [readme]update arch linux install 2025-08-15 21:41:34 -07:00
yichuan520030910320
27708243ca update system support 2025-08-15 21:32:53 -07:00
GitHub Actions
9a1e4652ca chore: release v0.3.0 2025-08-16 00:54:47 +00:00
Andy Lee
14e84d9e2d fix(core): skip empty/invalid chunks before embedding; guard OpenAI embeddings (#55)
Avoid 400 errors from OpenAI when chunker yields empty strings by filtering
invalid texts in LeannBuilder.build_index. Add validation fail-fast in
OpenAI embedding path to surface upstream issues earlier. Keeps passages and
embeddings aligned during build.

Refs #54
2025-08-15 17:53:53 -07:00
Yichuan Wang
2dcfca19ff style: apply ruff format (#56) 2025-08-15 17:48:33 -07:00
Yichuan Wang
bee2167ee3 docs: update READMEs (MCP docs + conclusion polish)
- Polish conclusion in packages/leann-mcp/README.md
- Sync root README wording and links
2025-08-15 17:21:23 -07:00
yichuan520030910320
ef980d70b3 [MCP]update MCP of claude code 2025-08-15 14:29:59 -07:00
Andy Lee
db3c63c441 Docs/Core: Low-Resource Setups, SkyPilot Option, and No-Recompute (#45)
* docs: add SkyPilot template and instructions for running embeddings/index build on cloud GPU

* docs: add low-resource note in README; point to config guide; suggest OpenAI embeddings, SkyPilot remote build, and --no-recompute

* docs: consolidate low-resource guidance into config guide; README points to it

* cli: add --no-recompute and --no-recompute-embeddings flags; docs: clarify HNSW requires --no-compact when disabling recompute

* docs: dedupe recomputation guidance; keep single Low-resource setups section

* sky: expand leann-build.yaml with configurable params and flags (backend, recompute, compact, embedding options)

* hnsw: auto-disable compact when --no-recompute is used; docs: expand SkyPilot with -e overrides and copy-back example

* docs+sky: simplify SkyPilot flow (auto-build on launch, rsync copy-back); clarify HNSW auto non-compact when no-recompute

* feat: auto compact for hnsw when recompute

* reader: non-destructive portability (relative hints + fallback); fix comments; sky: refine yaml

* cli: unify flags to --recompute/--no-recompute for build/search/ask; docs: update references

* chore: remove

* hnsw: move pruned/no-recompute assertion into backend; api: drop global assertion; docs: will adjust after benchmarking

* cli: use argparse.BooleanOptionalAction for paired flags (--recompute/--compact) across build/search/ask

* docs: a real example on recompute

* benchmarks: fix and extend HNSW+DiskANN recompute vs no-recompute; docs: add fresh numbers and DiskANN notes

* benchmarks: unify HNSW & DiskANN into one clean script; isolate groups, fixed ports, warm-up, param complexity

* docs: diskann recompute

* core: auto-cleanup for LeannSearcher/LeannChat (__enter__/__exit__/__del__); ensure server terminate/kill robustness; benchmarks: use searcher.cleanup(); docs: suggest uv run

* fix: hang on warnings

* docs: boolean flags

* docs: leann help
2025-08-15 12:03:19 -07:00
yichuan520030910320
00eeadb9dd upd pkg 2025-08-14 14:39:45 -07:00
yichuan520030910320
42c8370709 add chunk size in leann build& fix batch size in oai& docs 2025-08-14 13:14:14 -07:00
Andy Lee
fafdf8fcbe feat(core,diskann): robust embedding server (no-hang) + DiskANN fast mode (graph partition) (#29)
* feat: Add graph partition support for DiskANN backend

- Add GraphPartitioner class for advanced graph partitioning
- Add partition_graph_simple function for easy-to-use partitioning
- Add pybind11 dependency for C++ executable building
- Update __init__.py to export partition functions
- Include test scripts for partition functionality

The partition functionality allows optimizing disk-based indices
for better search performance and memory efficiency.

* chore: Update DiskANN submodule to latest with graph partition tools

- Update DiskANN submodule to commit b2dc4ea
- Includes graph partition tools and CMake integration
- Enables graph partitioning functionality in DiskANN backend

* merge

* ruff

* add a path related fix

* fix: always use relative path in metadata

* docs: tool cli install

* chore: more data

* fix: diskann building and partitioning

* tests: diskann and partition

* docs: highlight diskann readiness and add performance comparison

* docs: add ldg-times parameter for diskann graph locality optimization

* fix: update pre-commit ruff version and format compliance

* fix: format test files with latest ruff version for CI compatibility

* fix: pin ruff version to 0.12.7 across all environments

- Pin ruff==0.12.7 in pyproject.toml dev dependencies
- Update CI to use exact ruff version instead of latest
- Add comments explaining version pinning rationale
- Ensures consistent formatting across local, CI, and pre-commit

* fix: use uv tool install for ruff instead of uv pip install

- uv tool install is the correct way to install CLI tools like ruff
- uv pip install --system is for Python packages, not tools

* debug: add detailed logging for CI path resolution debugging

- Add logging in DiskANN embedding server to show metadata_file_path
- Add debug logging in PassageManager to trace path resolution
- This will help identify why CI fails to find passage files

* fix: force install local wheels in CI to prevent PyPI version conflicts

- Change from --find-links to direct wheel installation with --force-reinstall
- This ensures CI uses locally built packages with latest source code
- Prevents uv from using PyPI packages with same version number but old code
- Fixes CI test failures where old code (without metadata_file_path) was used

Root cause: CI was installing leann-backend-diskann v0.2.1 from PyPI
instead of the locally built wheel with same version number.

* debug: add more CI diagnostics for DiskANN module import issue

- Check wheel contents before and after auditwheel repair
- Verify _diskannpy module installation after pip install
- List installed package directory structure
- Add explicit platform tag for auditwheel repair

This helps diagnose why ImportError: cannot import name '_diskannpy' occurs

* fix: remove invalid --plat argument from auditwheel repair

- Remove '--plat linux_x86_64' which is not a valid platform tag
- Let auditwheel automatically determine the correct platform
- Based on CI output, it will use manylinux_2_35_x86_64

This was causing auditwheel repair to fail, preventing proper wheel repair

* fix: ensure CI installs correct Python version wheel packages

- Use --find-links with --no-index to let uv select correct wheel
- Prevents installing wrong Python version wheel (e.g., cp310 for Python 3.11)
- Fixes ImportError: _diskannpy.cpython-310-x86_64-linux-gnu.so in Python 3.11

The issue was that *.whl glob matched all Python versions, causing
uv to potentially install a cp310 wheel in a Python 3.11 environment.

* fix: ensure venv uses correct Python version from matrix

- Explicitly specify Python version when creating venv with uv
- Prevents mismatch between build Python (e.g., 3.10) and test Python
- Fixes: _diskannpy.cpython-310-x86_64-linux-gnu.so in Python 3.11 error

The issue: uv venv was defaulting to Python 3.11 regardless of matrix version

* fix: resolve dependency issues in CI package installation

- Ubuntu: Install all packages from local builds with --no-index
- macOS: Install core packages from PyPI, backends from local builds
- Remove --no-index for macOS backend installation to allow dependency resolution
- Pin versions when installing from PyPI to ensure consistency

Fixes error: 'leann-core was not found in the provided package locations'

* fix: Python 3.9 compatibility - replace Union type syntax

- Replace 'int | None' with 'Optional[int]' everywhere
- Replace 'subprocess.Popen | None' with 'Optional[subprocess.Popen]'
- Add Optional import to all affected files
- Update ruff target-version from py310 to py39
- The '|' syntax for Union types was introduced in Python 3.10 (PEP 604)

Fixes TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'

* ci: build all packages on all platforms; install from local wheels only

- Build leann-core and leann on macOS too
- Install all packages via --find-links and --no-index across platforms
- Lower macOS MACOSX_DEPLOYMENT_TARGET to 12.0 for wider compatibility

This ensures consistency and avoids PyPI drift while improving macOS compatibility.

* ci: allow resolving third-party deps from index; still prefer local wheels for our packages

- Remove --no-index so numpy/scipy/etc can be resolved on Python 3.13
- Keep --find-links to force our packages from local dist

Fixes: dependency resolution failure on Ubuntu Python 3.13 (numpy missing)

* ci(macOS): set MACOSX_DEPLOYMENT_TARGET back to 13.3

- Fix build failure: 'sgesdd_' only available on macOS 13.3+
- Keep other CI improvements (local builds, find-links installs)

* fix(py39): replace union type syntax in chat.py

- validate_model_and_suggest: str | None -> Optional[str]
- OpenAIChat.__init__: api_key: str | None -> Optional[str]
- get_llm: dict[str, Any] | None -> Optional[dict[str, Any]]

Ensures Python 3.9 compatibility for CI macOS 3.9.

* style: organize imports per ruff; finish py39 Optional changes

- Fix import ordering in embedding servers and graph_partition_simple
- Remove duplicate Optional import
- Complete Optional[...] replacements

* fix(py39): replace remaining '| None' in diskann graph_partition (module-level function)

* fix(py39): remove zip(strict=...) usage in api; Python 3.9 compatibility

* style: organize imports; fix process-group stop for embedding server

* chore: keep embedding server stdout/stderr visible; still use new session and pg-kill on stop

* fix: add timeout to final wait() in stop_server to prevent infinite hang

* fix: prevent hang in CI by flushing print statements and redirecting embedding server output

- Add flush=True to all print statements in convert_to_csr.py to prevent buffer deadlock
- Redirect embedding server stdout/stderr to DEVNULL in CI environment (CI=true)
- Fix timeout in embedding_server_manager.stop_server() final wait call

* fix: resolve CI hanging by removing problematic wait() in stop_server

* fix: remove hardcoded paths from MCP server and documentation

* feat: add CI timeout protection for tests

* fix: skip OpenAI test in CI to avoid failures and API costs

- Add CI skip for test_document_rag_openai
- Test was failing because it incorrectly used --llm simulated which isn't supported by document_rag.py

* feat: add simulated LLM option to document_rag.py

- Add 'simulated' to the LLM choices in base_rag_example.py
- Handle simulated case in get_llm_config() method
- This allows tests to use --llm simulated to avoid API costs

* feat: add comprehensive debugging capabilities with tmate integration

1. Tmate SSH Debugging:
   - Added manual workflow_dispatch trigger with debug_enabled option
   - Integrated mxschmitt/action-tmate@v3 for SSH access to CI runner
   - Can be triggered manually or by adding [debug] to commit message
   - Detached mode with 30min timeout, limited to actor only
   - Also triggers on test failure when debug is enabled

2. Enhanced Pytest Output:
   - Added --capture=no to see real-time output
   - Added --log-cli-level=DEBUG for maximum verbosity
   - Added --tb=short for cleaner tracebacks
   - Pipe output to tee for both display and logging
   - Show last 20 lines of output on completion

3. Environment Diagnostics:
   - Export PYTHONUNBUFFERED=1 for immediate output
   - Show Python/Pytest versions at start
   - Display relevant environment variables
   - Check network ports before/after tests

4. Diagnostic Script:
   - Created scripts/diagnose_hang.sh for comprehensive system checks
   - Shows processes, network, file descriptors, memory, ZMQ status
   - Automatically runs on timeout for detailed debugging info

This allows debugging CI hangs via SSH when needed while providing extensive logging by default.

* fix: add diagnostic script (force add to override .gitignore)

The diagnose_hang.sh script needs to be in git for CI to use it.
Using -f to override *.sh rule in .gitignore.

* test: investigate hanging [debug]

* fix: move tmate debug session inside pytest step to avoid hanging

The issue was that tmate was placed before pytest step, but the hang
occurs during pytest execution. Now tmate starts inside the test step
and provides connection info before running tests.

* debug: trigger tmate debug session [debug]

* fix: debug variable values and add commit message [debug] trigger

- Add debug output to show variable values
- Support both manual trigger and [debug] in commit message

* fix: force debug mode for investigation branch

- Auto-enable debug mode for debug/clean-state-investigation branch
- Add more debug info to troubleshoot trigger issues
- This ensures tmate will start regardless of trigger method

* fix: use github.head_ref for PR branch detection

For pull requests, github.ref is refs/pull/N/merge, but github.head_ref
contains the actual branch name. This should fix debug mode detection.

* fix: FORCE debug mode on - no more conditions

Just always enable debug mode on this branch.
We need tmate to work for investigation!

* fix: improve tmate connection info retrieval

- Add proper wait and retry logic for tmate initialization
- Tmate needs time to connect to servers before showing SSH info
- Try multiple times with delays to get connection details

* fix: ensure OpenMP is found during DiskANN build on macOS

- Add OpenMP environment variables directly in build step
- Should fix the libomp.dylib not found error on macOS-14

* fix: simplify macOS OpenMP configuration to match main branch

- Remove complex OpenMP environment variables
- Use simplified configuration from working main branch
- Remove redundant OpenMP setup in DiskANN build step
- Keep essential settings: OpenMP_ROOT, CMAKE_PREFIX_PATH, LDFLAGS, CPPFLAGS

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: revert DiskANN submodule to stable version

The debug branch had updated DiskANN submodule to a version with
hardcoded OpenMP paths that break macOS 13 builds. This reverts
to the stable version used in main branch.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: update faiss submodule to latest stable version

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* refactor: remove upterm/tmate debug code and clean CI workflow

- Remove all upterm/tmate SSH debugging infrastructure
- Restore clean CI workflow from main branch
- Remove diagnostic script that was only for SSH debugging
- Keep valuable DiskANN and HNSW backend improvements

This provides a clean base to add targeted pytest hang debugging
without the complexity of SSH sessions.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* debug: increase timeouts to 600s for comprehensive hang investigation

- Increase pytest timeout from 300s to 600s for thorough testing
- Increase import testing timeout from 60s to 120s
- Allow more time for C++ extension loading (faiss/diskann)
- Still provides timeout protection against infinite hangs

This gives the system more time to complete imports and tests
while still catching genuine hangs that exceed reasonable limits.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: remove debug_enabled parameter from build-and-publish workflow

- Remove debug_enabled input parameter that no longer exists in build-reusable.yml
- Keep workflow_dispatch trigger but without debug options
- Fixes workflow validation error: 'debug_enabled is not defined'

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* debug: fix YAML syntax and add post-pytest cleanup monitoring

- Fix Python code formatting in YAML (pre-commit fixed indentation issues)
- Add comprehensive post-pytest cleanup monitoring
- Monitor for hanging processes after test completion
- Focus on teardown phase based on previous hang analysis

This addresses the root cause identified: hang occurs after tests pass,
likely during cleanup/teardown of C++ extensions or embedding servers.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* debug: add external process monitoring and unbuffered output for precise hang detection

* fix

* feat: add comprehensive hang detection for pytest CI debugging

- Add Python faulthandler integration with signal-triggered stack dumps
- Implement periodic stack dumps at 5min and 10min intervals
- Add external process monitoring with SIGUSR1 signal on hang detection
- Use debug_pytest.py wrapper to capture exact hang location in C++ cleanup
- Enhance CPU stability monitoring to trigger precise stack traces

This addresses the persistent pytest hanging issue in Ubuntu 22.04 CI by
providing detailed stack traces to identify the exact code location where
the hang occurs during test cleanup phase.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* CI: move pytest hang-debug script into scripts/ci_debug_pytest.py; sort imports and apply ruff suggestion; update workflow to call the script

* fix: improve hang detection to monitor actual pytest process

* fix: implement comprehensive solution for CI pytest hangs

Key improvements:
1. Replace complex monitoring with simpler process group management
2. Add pytest conftest.py with per-test timeouts and aggressive cleanup
3. Skip problematic tests in CI that cause infinite loops
4. Enhanced cleanup at session start/end and after each test
5. Shorter timeouts (3min per test, 10min total) with better monitoring

This should resolve the hanging issues by:
- Preventing individual tests from running too long
- Automatically cleaning up hanging processes
- Skipping known problematic tests in CI
- Using process groups for more reliable cleanup

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: correct pytest_runtest_call hook parameter in conftest.py

- Change invalid 'puretest' parameter to proper pytest hooks
- Replace problematic pytest_runtest_call with pytest_runtest_setup/teardown
- This fixes PluginValidationError preventing pytest from starting
- Remove unused time import

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: prevent wrapper script from killing itself in cleanup

- Remove overly aggressive pattern 'python.*pytest' that matched wrapper itself
- Add current PID check to avoid killing wrapper process
- Add exclusion for wrapper and debug script names
- This fixes exit code 137 (SIGKILL) issue where wrapper killed itself

Root cause: cleanup function was killing the wrapper process itself,
causing immediate termination with no output in CI.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: prevent wrapper from detecting itself as remaining process

- Add PID and script name checks in post-test verification
- Avoid false positive detection of wrapper process as 'remaining'
- This prevents unnecessary cleanup calls that could cause hangs
- Root cause: wrapper was trying to clean up itself in verification phase

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: implement graceful shutdown for embedding servers

- Replace daemon threads with coordinated shutdown mechanism
- Add shutdown_event for thread synchronization
- Implement proper ZMQ resource cleanup
- Wait for threads to complete before exit
- Add ZMQ timeout to allow periodic shutdown checks
- Move signal handlers into server functions for proper scope access
- Fix protobuf class names and variable references
- Simplify resource cleanup to avoid variable scope issues

Root cause: Original servers used daemon threads + direct sys.exit(0)
which interrupted ZMQ operations and prevented proper resource cleanup,
causing hangs during process termination in CI environments.

This should resolve the core pytest hanging issue without complex wrappers.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: simplify embedding server process management

- Remove start_new_session=True to fix signal handling issues
- Simplify termination logic to use standard SIGTERM/SIGKILL
- Remove complex process group management that could cause hangs
- Add timeout-based cleanup to prevent CI hangs while ensuring proper resource cleanup
- Give graceful shutdown more time (5s) since we fixed the server shutdown logic
- Remove unused signal import

This addresses the remaining process management issues that could
cause startup failures and hanging during termination.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: increase CI test timeouts to accommodate model download

Analysis of recent CI failures shows:
- Model download takes ~12 seconds
- Embedding server startup + first search takes additional ~78 seconds
- Total time needed: ~90-100 seconds

Updated timeouts:
- test_readme_basic_example: 90s -> 180s
- test_backend_options: 60s -> 150s
- test_llm_config_simulated: 75s -> 150s

Root cause: Initial model download from huggingface.co in CI environment
is slower than local development, causing legitimate timeouts rather than
actual hanging processes.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* debug: preserve stderr in CI to debug embedding server startup failures

Previous fix revealed the real issue: embedding server fails to start within 120s,
not timeout issues. The error was hidden because both stdout and stderr were
redirected to DEVNULL in CI.

Changes:
- Keep stderr output in CI environment for debugging
- Only redirect stdout to DEVNULL to avoid buffer deadlock
- This will help us see why embedding server startup is failing

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix(embedding-server): ensure shutdown-capable ZMQ threads create/bind their own REP sockets and poll with timeouts; fix undefined socket causing startup crash and CI hangs on Ubuntu 22.04

* style(hnsw-server): apply ruff-format after robustness changes

* fix(hnsw-server): be lenient to nested [[ids]] for both distance and embedding requests to match client expectations; prevents missing ID lookup when wrapper nests the list

* refactor(hnsw-server): remove duplicate legacy ZMQ thread; keep single shutdown-capable server implementation to reduce surface and avoid hangs

* ci: simplify test step to run pytest uniformly across OS; drop ubuntu-22.04 wrapper special-casing

* chore(ci): remove unused pytest wrapper and debug runner

* refactor(diskann): remove redundant graph_partition_simple; keep single partition API (graph_partition)

* refactor(hnsw-convert): remove global print override; rely on default flushing in CI

* tests: drop custom ci_timeout decorator and helpers; rely on pytest defaults and simplified CI

* tests: remove conftest global timeouts/cleanup; keep test suite minimal and rely on simplified CI + robust servers

* tests: call searcher.cleanup()/chat.cleanup() to ensure background embedding servers terminate after tests

* tests: fix ruff warnings in minimal conftest

* core: add weakref.finalize and atexit-based cleanup in EmbeddingServerManager to ensure server stops on interpreter exit/GC

* tests: remove minimal conftest to validate atexit/weakref cleanup path

* core: adopt compatible running server (record PID) and ensure stop_server() can terminate adopted processes; clear server_port on stop

* ci/core: skip compatibility scanning in CI (LEANN_SKIP_COMPAT=1) to avoid slow/hanging process scans; always pick a fresh available port

* core: unify atexit to always call _finalize_process (covers both self-launched and adopted servers)

* zmq: set SNDTIMEO=1s and LINGER=0 for REP sockets to avoid send blocking during shutdown; reduces CI hang risk

* tests(ci): skip DiskANN branch of README basic example on CI to avoid core dump in constrained runners; HNSW still validated

* diskann(ci): avoid stdout/stderr FD redirection in CI to prevent aborts from low-level dup2; no-op contextmanager on CI

* core: purge dead helpers and comments from EmbeddingServerManager; keep only minimal in-process flow

* core: fix lint (remove unused passages_file); keep per-instance reuse only

* fix: keep backward-compat

---------

Co-authored-by: yichuan520030910320 <yichuan_wang@berkeley.edu>
Co-authored-by: Claude <noreply@anthropic.com>
2025-08-14 01:02:24 -07:00
yichuan520030910320
21f7d8e031 docs: update -h and config advice 2025-08-13 14:26:35 -07:00
Andy Lee
46565b9249 docs: follows #34, patch leann backends into tool environment 2025-08-12 17:56:02 -07:00
GitHub Actions
3dad76126a chore: release v0.2.9 2025-08-12 23:00:12 +00:00
Andy Lee
18e28bda32 feat: Add macOS 15 support for M4 Mac compatibility (#38)
* feat: add macOS 15 support for M4 Mac compatibility

- Add macos-15 CI builds for Python 3.9-3.13
- Update MACOSX_DEPLOYMENT_TARGET from 11.0/13.3 to 14.0 for broader compatibility
- Addresses issue #34 with Mac M4 wheel compatibility

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: ensure wheels are compatible with older macOS versions

- Set MACOSX_DEPLOYMENT_TARGET=11.0 for HNSW backend (broad compatibility)
- Set MACOSX_DEPLOYMENT_TARGET=13.0 for DiskANN backend (required for LAPACK)
- Add --require-target-macos-version to delocate-wheel commands
- This fixes CI failures on macos-13 runners while maintaining M4 Mac support

Fixes the issue where wheels built on macos-14 runners were incorrectly
tagged as macosx_14_0, preventing installation on macos-13 runners.

* fix: use macOS 13.3 for DiskANN backend as required by LAPACK

DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function, so we must
use 13.3 as the deployment target, not 13.0.

* fix: match deployment target with runner OS for library compatibility

The issue is that Homebrew libraries on macOS 14 runners are built for
macOS 14 and cannot be downgraded. We must use different deployment
targets based on the runner OS:

- macOS 13 runners: Can build for macOS 11.0 (HNSW) and 13.3 (DiskANN)
- macOS 14 runners: Must build for macOS 14.0 (due to system libraries)

This ensures delocate-wheel succeeds by matching the deployment target
with the actual minimum version required by bundled libraries.

* fix: add macOS 15 support to deployment target configuration

The issue extends to macOS 15 runners where Homebrew libraries are built
for macOS 15. We must handle all runner versions explicitly:

- macOS 13 runners: Can build for macOS 11.0 (HNSW) and 13.3 (DiskANN)
- macOS 14 runners: Must build for macOS 14.0 (system libraries)
- macOS 15 runners: Must build for macOS 15.0 (system libraries)

This ensures wheels are properly tagged for their actual minimum
supported macOS version, matching the bundled libraries.

* fix: correct macOS deployment targets based on Homebrew library requirements

The key insight is that Homebrew libraries on each macOS version are
compiled for that specific version:
- macOS 13: Libraries require macOS 13.0 minimum
- macOS 14: Libraries require macOS 14.0 minimum
- macOS 15: Libraries require macOS 15.0 minimum

We cannot build wheels for older macOS versions than what the bundled
Homebrew libraries require. This means:
- macOS 13 runners: Build for macOS 13.0+ (HNSW) and 13.3+ (DiskANN)
- macOS 14 runners: Build for macOS 14.0+
- macOS 15 runners: Build for macOS 15.0+

This ensures delocate-wheel succeeds by matching deployment targets
with the actual minimum versions required by system libraries.

* fix: restore macOS 15 build matrix and correct test path

- Add back macOS 15 configurations for Python 3.9-3.13
- Fix pytest path from test/ to tests/ (correct directory name)

The macOS 15 support was accidentally missing from the matrix, and
pytest was looking for the wrong directory name.

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-08-12 14:01:02 -07:00
GitHub Actions
609fa62fd5 chore: release v0.2.8 2025-08-12 19:04:51 +00:00
Yichuan Wang
eab13434ef feat: support multiple input formats for --docs argument (#39) 2025-08-12 10:30:31 -07:00
yichuan520030910320
b2390ccc14 [Ollama] fix ollama recompute 2025-08-12 00:24:20 -07:00
Andy Lee
e8fca2c84a fix: detect and report Ollama embedding dimension inconsistency (#37)
- Add validation for embedding dimension consistency in Ollama mode
- Provide clear error message with troubleshooting steps when dimensions mismatch
- Fail fast instead of silent fallback to prevent data corruption

Fixes #31
2025-08-11 17:41:52 -07:00
yichuan520030910320
790ae14f69 fix missing file 2025-08-11 17:35:45 -07:00
yichuan520030910320
ac363072e6 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-08-11 17:31:04 -07:00
yichuan520030910320
93465af46c docs: update README fix wrong data file 2025-08-11 17:29:54 -07:00
Andy Lee
792ece67dc ci: add Mac Intel (x86_64) build support (#26)
* ci: add Mac Intel (x86_64) build support

* fix: auto-detect Homebrew path for Intel vs Apple Silicon Macs

This fixes the hardcoded /opt/homebrew path which only works on Apple
Silicon Macs. Intel Macs use /usr/local as the Homebrew prefix.

* fix: auto-detect Homebrew paths for both DiskANN and HNSW backends

- Fix DiskANN CMakeLists.txt path reference
- Add macOS environment variable detection for OpenMP_ROOT
- Support both Intel (/usr/local) and Apple Silicon (/opt/homebrew) paths

* fix: improve macOS build reliability with proper OpenMP path detection

- Add proper CMAKE_PREFIX_PATH and OpenMP_ROOT detection for both Intel and Apple Silicon Macs
- Set LDFLAGS and CPPFLAGS for all Homebrew packages to ensure CMake can find them
- Apply CMAKE_ARGS to both HNSW and DiskANN backends for consistent builds
- Fix hardcoded paths that caused build failures on Intel Macs (macos-13)

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: add abseil library path for protobuf compilation on macOS

- Include abseil in CMAKE_PREFIX_PATH for both Intel and Apple Silicon Macs
- Add explicit absl_DIR CMake variable to help find abseil for protobuf
- Fixes 'absl/log/absl_log.h' file not found error during compilation

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: add abseil include path to CPPFLAGS for both Intel and Apple Silicon

- Add -I/opt/homebrew/opt/abseil/include to CPPFLAGS for Apple Silicon
- Add -I/usr/local/opt/abseil/include to CPPFLAGS for Intel
- Fixes 'absl/log/absl_log.h' file not found by ensuring abseil headers are in compiler include path

Root cause: CMAKE_PREFIX_PATH alone wasn't sufficient - compiler needs explicit -I flags

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: clean build system and Python 3.9 compatibility

Build system improvements:
- Simplify macOS environment detection using brew --prefix
- Remove complex hardcoded paths and CMAKE_ARGS
- Let CMake automatically find Homebrew packages via CMAKE_PREFIX_PATH
- Clean separation between Intel (/usr/local) and Apple Silicon (/opt/homebrew)

Python 3.9 compatibility:
- Set ruff target-version to py39 to match project requirements
- Replace str | None with Union[str, None] in type annotations
- Add Union imports where needed
- Fix core interface, CLI, chat, and embedding server files

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: type

* fix: ensure CMAKE_PREFIX_PATH is passed to backend builds

- Add CMAKE_ARGS with CMAKE_PREFIX_PATH and OpenMP_ROOT for both HNSW and DiskANN backends
- This ensures CMake can find Homebrew packages on both Intel (/usr/local) and Apple Silicon (/opt/homebrew)
- Fixes the issue where CMake was still looking for hardcoded paths instead of using detected ones

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: configure CMake paths in pyproject.toml for proper Homebrew detection

- Add CMAKE_PREFIX_PATH and OpenMP_ROOT environment variable mapping in both backends
- Remove CMAKE_ARGS from GitHub Actions workflow (cleaner separation)
- Ensure scikit-build-core correctly uses environment variables for CMake configuration
- This should fix the hardcoded /opt/homebrew paths on Intel Macs

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: remove hardcoded /opt/homebrew paths from DiskANN CMake

- Auto-detect Homebrew libomp path using OpenMP_ROOT environment variable
- Fallback to CMAKE_PREFIX_PATH/opt/libomp if OpenMP_ROOT not set
- Final fallback to brew --prefix libomp for auto-detection
- Maintains backwards compatibility with old hardcoded path
- Fixes Intel Mac builds that were failing due to hardcoded Apple Silicon paths

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: update DiskANN submodule with macOS Intel/Apple Silicon compatibility fixes

- Auto-detect Homebrew libomp path using OpenMP_ROOT environment variable
- Exclude mkl_set_num_threads on macOS (uses Accelerate framework instead of MKL)
- Fixes compilation on Intel Macs by using correct /usr/local paths

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: update DiskANN submodule with SIMD function name corrections

- Fix _mm128_loadu_ps to _mm_loadu_ps (and similar functions)
- This is a known issue in upstream DiskANN code where incorrect function names were used
- Resolves compilation errors on macOS Intel builds

References: Known DiskANN issue with SIMD intrinsics naming

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: update DiskANN submodule with type cast fix for signed char templates

- Add missing type casts (float*)a and (float*)b in SSE2 version
- This matches the existing type casts in the AVX version
- Fixes compilation error when instantiating DistanceInnerProduct<int8_t>
- Resolves "cannot initialize const float* with const signed char*" error

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: update Faiss submodule with override keyword fix

- Add missing override keyword to IDSelectorModulo::is_member function
- Fixes C++ compilation warning that was treated as error due to -Werror flag
- Resolves "warning: 'is_member' overrides a member function but is not marked 'override'"
- Improves code conformance to modern C++ best practices

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: update Faiss submodule with override keyword fix

* fix: update DiskANN submodule with additional type cast fix

- Add missing type cast in DistanceFastL2::norm function SSE2 version
- Fixes const float* = const signed char* compilation error
- Ensures consistent type casting across all SIMD code paths
- Resolves template instantiation error for DistanceFastL2<int8_t>

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* debug: simplify wheel compatibility checking

- Fix YAML syntax error in debug step
- Use simpler approach to show platform tags and wheel names
- This will help identify platform tag compatibility issues

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: use correct Python version for wheel builds

- Replace --python python with --python ${{ matrix.python }}
- This ensures wheels are built for the correct Python version in each matrix job
- Fixes Python version mismatch where cp39 wheels were used in cp311 environments

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: resolve wheel installation conflicts in CI matrix builds

Fix issue where multiple Python versions' wheels in the same dist directory
caused installation conflicts during CI testing. The problem occurred when
matrix builds for different Python versions accumulated wheels in shared
directories, and uv pip install would find incompatible wheels.

Changes:
- Add Python version detection using matrix.python variable
- Convert Python version to wheel tag format (e.g., 3.11 -> cp311)
- Use find with version-specific pattern matching to select correct wheels
- Add explicit error handling if no matching wheel is found

This ensures each CI job installs only wheels compatible with its specific
Python version, preventing "A path dependency is incompatible with the
current platform" errors.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: ensure virtual environment uses correct Python version in CI

Fix issue where uv venv was creating virtual environments with a different
Python version than specified in the matrix, causing wheel compatibility
errors. The problem occurred when the system had multiple Python versions
and uv venv defaulted to a different version than intended.

Changes:
- Add --python ${{ matrix.python }} flag to uv venv command
- Ensures virtual environment matches the matrix-specified Python version
- Fixes "The wheel is compatible with CPython 3.X but you're using CPython 3.Y" errors

This ensures wheel installation selects and installs the correctly built
wheels that match the runtime Python version.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: complete Python 3.9 type annotation compatibility fixes

Fix remaining Python 3.9 incompatible type annotations throughout the
leann-core package that were causing test failures in CI. The union operator
(|) syntax for type hints was introduced in Python 3.10 and causes
"TypeError: unsupported operand type(s) for |" errors in Python 3.9.

Changes:
- Convert dict[str, Any] | None to Optional[dict[str, Any]]
- Convert int | None to Optional[int]
- Convert subprocess.Popen | None to Optional[subprocess.Popen]
- Convert LeannBackendFactoryInterface | None to Optional[LeannBackendFactoryInterface]
- Add missing Optional imports to all affected files

This resolves all test failures related to type annotation syntax and ensures
compatibility with Python 3.9 as specified in pyproject.toml.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: complete Python 3.9 type annotation fixes in backend packages

Fix remaining Python 3.9 incompatible type annotations in backend packages
that were causing test failures. The union operator (|) syntax for type hints
was introduced in Python 3.10 and causes "TypeError: unsupported operand
type(s) for |" errors in Python 3.9.

Changes in leann-backend-diskann:
- Convert zmq_port: int | None to Optional[int] in diskann_backend.py
- Convert passages_file: str | None to Optional[str] in diskann_embedding_server.py
- Add Optional imports to both files

Changes in leann-backend-hnsw:
- Convert zmq_port: int | None to Optional[int] in hnsw_backend.py
- Add Optional import

This resolves the final test failures related to type annotation syntax and
ensures full Python 3.9 compatibility across all packages.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: remove Python 3.10+ zip strict parameter for Python 3.9 compatibility

Remove the strict=False parameter from zip() call in api.py as it was
introduced in Python 3.10 and causes "TypeError: zip() takes no keyword
arguments" in Python 3.9.

The strict parameter controls whether zip() raises an exception when the
iterables have different lengths. Since we're not relying on this behavior
and the code works correctly without it, removing it maintains the same
functionality while ensuring Python 3.9 compatibility.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: ensure leann-core package is built on all platforms, not just Ubuntu

This fixes the issue where CI was installing leann-core from PyPI instead of
using locally built package with Python 3.9 compatibility fixes.

* fix: build and install leann meta package on all platforms

The leann meta package is pure Python and platform-independent, so there's
no reason to restrict it to Ubuntu only. This ensures all platforms use
consistent local builds instead of falling back to PyPI versions.

* fix: restrict MLX dependencies to Apple Silicon Macs only

MLX framework only supports Apple Silicon (ARM64) Macs, not Intel x86_64.
Add platform_machine == 'arm64' condition to prevent installation failures
on Intel Macs (macos-13).

* cleanup: simplify CI configuration

- Remove debug step with non-existent 'uv pip debug' command
- Simplify wheel installation logic - let uv handle compatibility
- Use -e .[test] instead of manually listing all test dependencies

* fix: install backend wheels before meta packages

Install backend wheels first to ensure they're available when core/meta
packages are installed, preventing uv from trying to resolve backend
dependencies from PyPI.

* fix: use local leann-core when building backend packages

Add --find-links to backend builds to ensure they use the locally built
leann-core with fixed MLX dependencies instead of downloading from PyPI.

Also bump leann-core version to 0.2.8 to ensure clean dependency resolution.

* fix: use absolute path for find-links and upgrade backend version

- Use GITHUB_WORKSPACE for absolute path to ensure find-links works
- Upgrade leann-backend-hnsw to 0.2.8 to match leann-core version

* fix: use absolute path for find-links and upgrade backend version

- Use GITHUB_WORKSPACE for absolute path to ensure find-links works
- Upgrade leann-backend-hnsw to 0.2.8 to match leann-core version

* fix: correct version consistency for --find-links to work properly

- All packages now use version 0.2.7 consistently
- Backend packages can find exact leann-core==0.2.7 from local build
- This ensures --find-links works during CI builds instead of falling back to PyPI

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: revert all packages to consistent version 0.2.7

- This PR should not bump versions, only fix Intel Mac build
- Version bumps should be done in release_manual workflow
- All packages now use 0.2.7 consistently for --find-links to work

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: use --find-links during package installation to avoid PyPI MLX conflicts

- Backend wheels contain Requires-Dist: leann-core==0.2.7
- Without --find-links, uv resolves this from PyPI which has MLX for all Darwin
- With --find-links, uv uses local leann-core with proper platform restrictions
- Root cause: dependency resolution happens at install time, not just build time
- Local test confirms this fixes Intel Mac MLX dependency issues

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: restrict MLX dependencies to ARM64 Macs in workspace pyproject.toml

- Root pyproject.toml also had MLX dependencies without platform_machine restriction
- This caused test dependency installation to fail on Intel Macs
- Now consistent with packages/leann-core/pyproject.toml platform restrictions

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* chore: cleanup unused files and fix GitHub Actions warnings

- Remove unused packages/leann-backend-diskann/CMakeLists.txt
  (DiskANN uses cmake.source-dir=third_party/DiskANN instead)
- Replace macos-latest with macos-14 to avoid migration warnings
  (macos-latest will migrate to macOS 15 on August 4, 2025)
- Keep packages/leann-backend-hnsw/CMakeLists.txt (needed for Faiss config)

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: properly handle Python 3.13 support with PyTorch compatibility

- Support Python 3.13 on most platforms (Ubuntu, ARM64 Mac)
- Exclude Intel Mac + Python 3.13 combination due to PyTorch wheel availability
- PyTorch <2.5 supports Intel Mac but not Python 3.13
- PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64
- Document limitation in CI configuration comments
- Update README badges with detailed Python version support and CI status

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-11 16:39:58 -07:00
GitHub Actions
239e35e2e6 chore: release v0.2.7 2025-08-11 03:11:46 +00:00
Andy Lee
2fac0c6fbf fix: improve gitignore and Jupyter notebook support (#28)
- Add nbconvert dependency for .ipynb file support
- Replace manual gitignore parsing with gitignore-parser library
- Proper recursive .gitignore handling (all subdirectories)
- Fix compliance with Git gitignore behavior
- Simplify code and improve reliability

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-authored-by: Claude <noreply@anthropic.com>
2025-08-10 20:02:46 -07:00
yichuan520030910320
9801aa581b [Readme]update embedding model config according to reddit feedback 2025-08-09 21:33:33 -07:00
GitHub Actions
5e97916608 chore: release v0.2.6 2025-08-10 03:39:45 +00:00
Andy Lee
8b9c2be8c9 Feat/claude code refine (#24)
* feat: Add Ollama embedding support for local embedding models

* docs: Add clear documentation for Ollama embedding usage

* fix: remove leann_ask

* docs: remove ollama embedding extra instructions

* simplify MCP interface for Claude Code

- Remove unnecessary search parameters: search_mode, recompute_embeddings, file_types, min_score
- Remove leann_clear tool (not needed for Claude Code workflow)
- Streamline search to only use: query, index_name, top_k, complexity
- Keep core tools: leann_index, leann_search, leann_status, leann_list

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* remove leann_index from MCP interface

Users should use CLI command 'leann build' to create indexes first.
MCP now only provides search functionality:
- leann_search: search existing indexes
- leann_status: check index health
- leann_list: list available indexes

This separates index creation (CLI) from search (Claude Code).

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* improve CLI with auto project name and .gitignore support

- Make index_name optional, auto-use current directory name
- Read .gitignore patterns and respect them during indexing
- Add _read_gitignore_patterns() to parse .gitignore files
- Add _should_exclude_file() for pattern matching
- Apply exclusion patterns to both PDF and general file processing
- Show helpful messages about gitignore usage

Now users can simply run: leann build
And it will use project name + respect .gitignore patterns.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-08-09 20:37:17 -07:00
Andy Lee
3ff5aac8e0 Add Ollama embedding support to enable local embedding models (#22)
* feat: Add Ollama embedding support for local embedding models

* docs: Add clear documentation for Ollama embedding usage

* feat: Enhance Ollama embedding with better error handling and concurrent processing

- Add intelligent model validation and suggestions (inspired by OllamaChat)
- Implement concurrent processing for better performance
- Add retry mechanism with timeout handling
- Provide user-friendly error messages with emojis
- Auto-detect and recommend embedding models
- Add text truncation for long texts
- Improve progress bar display logic

* docs: don't mention it in README
2025-08-08 18:44:07 -07:00
yichuan520030910320
67fef60466 [Readme]More about claude code 2025-08-08 16:05:35 -07:00
GitHub Actions
b6ab6f1993 chore: release v0.2.5 2025-08-08 22:32:27 +00:00
joshuashaffer
9f2e82a838 Propagate hosts argument for ollama through chat.py (#21)
* Propigate hosts argument for ollama through chat.py

* Apply suggestions from code review

Good AI slop suggestions.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-08-08 15:31:15 -07:00
yichuan520030910320
0b2b799d5a [README]fix instructions in cli 2025-08-08 01:04:13 -07:00
yichuan520030910320
0f790fbbd9 docs: polish README and add optimized MCP integration image
- Improve grammar and sentence structure in MCP section
- Add proper markdown image formatting with relative paths
- Optimize mcp_leann.png size (1.3MB -> 224KB)
- Update data description to be more specific about Chinese content
2025-08-08 00:58:36 -07:00
GitHub Actions
387ae21eba chore: release v0.2.4 2025-08-08 07:14:51 +00:00
Andy Lee
3cc329c3e7 fix: remove hardcoded paths from MCP server and documentation 2025-08-08 00:08:56 -07:00
Andy Lee
5567302316 feat: promote Claude Code integration as primary RAG feature 2025-08-07 23:19:19 -07:00
GitHub Actions
075d4bd167 chore: release v0.2.2 2025-08-08 01:58:40 +00:00
yichuan520030910320
e4bcc76f88 fix cli & make recompute default true 2025-08-07 18:58:04 -07:00
yichuan520030910320
710e83b1fd fix cli if there is no other type of doc to make it robust 2025-08-07 18:46:05 -07:00
yichuan520030910320
c96d653072 more support for type of docs in cli 2025-08-07 18:14:03 -07:00
Andy Lee
8b22d2b5d3 Merge pull request #19 from yichuan-w/feature/claude-code-research
Feature/claude code research
2025-08-05 23:02:34 -07:00
Andy Lee
4cb544ee38 docs: Update co-contributors with GitHub usernames (#18)
* docs: Update co-contributors with GitHub usernames

* docs: Use GitHub links for co-contributors and improve order

* docs: Change to Contributors and use personal homepage

* docs: Specify core contributors and welcome new contributors
2025-08-05 17:43:59 -07:00
yichuan520030910320
f94ce63d51 add gpt oss! serve your RAG using ollama 2025-08-05 16:49:52 -07:00
GitHub Actions
4271ff9d84 chore: release v0.2.1 2025-08-05 05:50:56 +00:00
Andy Lee
0d448c4a41 docs: config guidance (#17)
* docs: config guidance

* feat: add comprehensive configuration guide and update README

- Create docs/configuration-guide.md with detailed guidance on:
  - Embedding model selection (small/medium/large)
  - Index selection (HNSW vs DiskANN)
  - LLM engine and model comparison
  - Parameter tuning (build/search complexity, top-k)
  - Performance optimization tips
  - Deep dive into LEANN's recomputation feature
- Update README.md to link to the configuration guide
- Include latest 2025 model recommendations (Qwen3, DeepSeek-R1, O3-mini)

* chore: move evaluation data .gitattributes to correct location

* docs: Weaken DiskANN emphasis in README

- Change backend description to emphasize HNSW as default
- DiskANN positioned as optional for billion-scale datasets
- Simplify evaluation commands to be more generic

* docs: Adjust DiskANN positioning in features and roadmap

- features.md: Put HNSW/FAISS first as default, DiskANN as optional
- roadmap.md: Reorder to show HNSW integration before DiskANN
- Consistent with positioning DiskANN as advanced option for large-scale use

* docs: Improve configuration guide based on feedback

- List specific files in default data/ directory (2 AI papers, literature, tech report)
- Update examples to use English and better RAG-suitable queries
- Change full dataset reference to use --max-items -1
- Adjust small model guidance about upgrading to larger models when time allows
- Update top-k defaults to reflect actual default of 20
- Ensure consistent use of full model name Qwen/Qwen3-Embedding-0.6B
- Reorder optimization steps, move MLX to third position
- Remove incorrect chunk size tuning guidance
- Change README from 'Having trouble' to 'Need best practices'

* docs: Address all configuration guide feedback

- Fix grammar: 'If time is not a constraint' instead of 'time expense is not large'
- Highlight Qwen3-Embedding-0.6B performance (nearly OpenAI API level)
- Add OpenAI quick start section with configuration example
- Fold Cloud vs Local trade-offs into collapsible section
- Update HNSW as 'default and recommended for extreme low storage'
- Add DiskANN beta warning and explain PQ+rerank architecture
- Expand Ollama models: add qwen3:0.6b, 4b, 7b variants
- Note OpenAI as current default but recommend Ollama switch
- Add 'need to install extra software' warning for Ollama
- Remove incorrect latency numbers from search-complexity recommendations

* docs: add a link
2025-08-04 22:50:32 -07:00
yichuan520030910320
af5599e33c fix data example name 2025-08-04 17:49:03 -07:00
yichuan520030910320
efdf6d917a fix diskann for faster mode 2025-08-04 17:46:46 -07:00
Andy Lee
dd71ac8d71 feat: implement smart memory configuration for DiskANN (#16)
- Add intelligent memory calculation based on data size and system specs
- search_memory_maximum: 1/10 of embedding size (controls PQ compression)
- build_memory_maximum: 50% of available RAM (controls sharding)
- Provides optimal balance between performance and memory usage
- Automatic fallback to default values if parameters are explicitly provided
2025-08-04 14:36:29 -07:00
GitHub Actions
8bee1d4100 chore: release v0.2.0 2025-08-04 21:34:31 +00:00
yichuan520030910320
33521d6d00 add logs 2025-08-04 14:15:52 -07:00
Andy Lee
8899734952 refactor: Unify examples interface with BaseRAGExample (#12)
* refactor: Unify examples interface with BaseRAGExample

- Create BaseRAGExample base class for all RAG examples
- Refactor 4 examples to use unified interface:
  - document_rag.py (replaces main_cli_example.py)
  - email_rag.py (replaces mail_reader_leann.py)
  - browser_rag.py (replaces google_history_reader_leann.py)
  - wechat_rag.py (replaces wechat_history_reader_leann.py)
- Maintain 100% parameter compatibility with original files
- Add interactive mode support for all examples
- Unify parameter names (--max-items replaces --max-emails/--max-entries)
- Update README.md with new examples usage
- Add PARAMETER_CONSISTENCY.md documenting all parameter mappings
- Keep main_cli_example.py for backward compatibility with migration notice

All default values, LeannBuilder parameters, and chunking settings
remain identical to ensure full compatibility with existing indexes.

* fix: Update CI tests for new unified examples interface

- Rename test_main_cli.py to test_document_rag.py
- Update all references from main_cli_example.py to document_rag.py
- Update tests/README.md documentation

The tests now properly test the new unified interface while maintaining
the same test coverage and functionality.

* fix: Fix pre-commit issues and update tests

- Fix import sorting and unused imports
- Update type annotations to use built-in types (list, dict) instead of typing.List/Dict
- Fix trailing whitespace and end-of-file issues
- Fix Chinese fullwidth comma to regular comma
- Update test_main_cli.py to test_document_rag.py
- Add backward compatibility test for main_cli_example.py
- Pass all pre-commit hooks (ruff, ruff-format, etc.)

* refactor: Remove old example scripts and migration references

- Delete old example scripts (mail_reader_leann.py, google_history_reader_leann.py, etc.)
- Remove migration hints and backward compatibility
- Update tests to use new unified examples directly
- Clean up all references to old script names
- Users now only see the new unified interface

* fix: Restore embedding-mode parameter to all examples

- All examples now have --embedding-mode parameter (unified interface benefit)
- Default is 'sentence-transformers' (consistent with original behavior)
- Users can now use OpenAI or MLX embeddings with any data source
- Maintains functional equivalence with original scripts

* docs: Improve parameter categorization in README

- Clearly separate core (shared) vs specific parameters
- Move LLM and embedding examples to 'Example Commands' section
- Add descriptive comments for all specific parameters
- Keep only truly data-source-specific parameters in specific sections

* docs: Make example commands more representative

- Add default values to parameter descriptions
- Replace generic examples with real-world use cases
- Focus on data-source-specific features in examples
- Remove redundant demonstrations of common parameters

* docs: Reorganize parameter documentation structure

- Move common parameters to a dedicated section before all examples
- Rename sections to 'X-Specific Arguments' for clarity
- Remove duplicate common parameters from individual examples
- Better information architecture for users

* docs: polish applications

* docs: Add CLI installation instructions

- Add two installation options: venv and global uv tool
- Clearly explain when to use each option
- Make CLI more accessible for daily use

* docs: Clarify CLI global installation process

- Explain the transition from venv to global installation
- Add upgrade command for global installation
- Make it clear that global install allows usage without venv activation

* docs: Add collapsible section for CLI installation

- Wrap CLI installation instructions in details/summary tags
- Keep consistent with other collapsible sections in README
- Improve document readability and navigation

* style: format

* docs: Fix collapsible sections

- Make Common Parameters collapsible (as it's lengthy reference material)
- Keep CLI Installation visible (important for users to see immediately)
- Better information hierarchy

* docs: Add introduction for Common Parameters section

- Add 'Flexible Configuration' heading with descriptive sentence
- Create parallel structure with 'Generation Model Setup' section
- Improve document flow and readability

* docs: nit

* fix: Fix issues in unified examples

- Add smart path detection for data directory
- Fix add_texts -> add_text method call
- Handle both running from project root and examples directory

* fix: Fix async/await and add_text issues in unified examples

- Remove incorrect await from chat.ask() calls (not async)
- Fix add_texts -> add_text method calls
- Verify search-complexity correctly maps to efSearch parameter
- All examples now run successfully

* feat: Address review comments

- Add complexity parameter to LeannChat initialization (default: search_complexity)
- Fix chunk-size default in README documentation (256, not 2048)
- Add more index building parameters as CLI arguments:
  - --backend-name (hnsw/diskann)
  - --graph-degree (default: 32)
  - --build-complexity (default: 64)
  - --no-compact (disable compact storage)
  - --no-recompute (disable embedding recomputation)
- Update README to document all new parameters

* feat: Add chunk-size parameters and improve file type filtering

- Add --chunk-size and --chunk-overlap parameters to all RAG examples
- Preserve original default values for each data source:
  - Document: 256/128 (optimized for general documents)
  - Email: 256/25 (smaller overlap for email threads)
  - Browser: 256/128 (standard for web content)
  - WeChat: 192/64 (smaller chunks for chat messages)
- Make --file-types optional filter instead of restriction in document_rag
- Update README to clarify interactive mode and parameter usage
- Fix LLM default model documentation (gpt-4o, not gpt-4o-mini)

* feat: Update documentation based on review feedback

- Add MLX embedding example to README
- Clarify examples/data content description (two papers, Pride and Prejudice, Chinese README)
- Move chunk parameters to common parameters section
- Remove duplicate chunk parameters from document-specific section

* docs: Emphasize diverse data sources in examples/data description

* fix: update default embedding models for better performance

- Change WeChat, Browser, and Email RAG examples to use all-MiniLM-L6-v2
- Previous Qwen/Qwen3-Embedding-0.6B was too slow for these use cases
- all-MiniLM-L6-v2 is a fast 384-dim model, ideal for large-scale personal data

* add response highlight

* change rebuild logic

* fix some example

* feat: check if k is larger than #docs

* fix: WeChat history reader bugs and refactor wechat_rag to use unified architecture

* fix email wrong -1 to process all file

* refactor: reorgnize all examples/ and test/

* refactor: reorganize examples and add link checker

* fix: add init.py

* fix: handle certificate errors in link checker

* fix wechat

* merge

* docs: update README to use proper module imports for apps

- Change from 'python apps/xxx.py' to 'python -m apps.xxx'
- More professional and pythonic module calling
- Ensures proper module resolution and imports
- Better separation between apps/ (production tools) and examples/ (demos)

---------

Co-authored-by: yichuan520030910320 <yichuan_wang@berkeley.edu>
2025-08-03 23:06:24 -07:00
Andy Lee
54df6310c5 fix: diskann build and prevent termination from hanging
- Fix OpenMP library linking in DiskANN CMake configuration
- Add timeout protection for HuggingFace model loading to prevent hangs
- Improve embedding server process termination with better timeouts
- Make DiskANN backend default enabled alongside HNSW
- Update documentation to reflect both backends included by default
2025-08-03 21:16:52 -07:00
yichuan520030910320
19bcc07814 change readme discription 2025-07-28 20:52:45 -07:00
yichuan520030910320
8356e3c668 changr to openai main cli 2025-07-28 17:39:14 -07:00
GitHub Actions
08eac5c821 chore: release v0.1.16 2025-07-29 00:15:18 +00:00
Andy Lee
4671ed9b36 Fix macos ABI by using system default clang (#11)
* fix: auto-detect normalized embeddings and use cosine distance

- Add automatic detection for normalized embedding models (OpenAI, Voyage AI, Cohere)
- Automatically set distance_metric='cosine' for normalized embeddings
- Add warnings when using non-optimal distance metrics
- Implement manual L2 normalization in HNSW backend (custom Faiss build lacks normalize_L2)
- Fix DiskANN zmq_port compatibility with lazy loading strategy
- Add documentation for normalized embeddings feature

This fixes the low accuracy issue when using OpenAI text-embedding-3-small model with default MIPS metric.

* style: format

* feat: add OpenAI embeddings support to google_history_reader_leann.py

- Add --embedding-model and --embedding-mode arguments
- Support automatic detection of normalized embeddings
- Works correctly with cosine distance for OpenAI embeddings

* feat: add --use-existing-index option to google_history_reader_leann.py

- Allow using existing index without rebuilding
- Useful for testing pre-built indices

* fix: Improve OpenAI embeddings handling in HNSW backend

* fix: improve macOS C++ compatibility and add CI tests

* refactor: improve test structure and fix main_cli example

- Move pytest configuration from pytest.ini to pyproject.toml
- Remove unnecessary run_tests.py script (use test extras instead)
- Fix main_cli_example.py to properly use command line arguments for LLM config
- Add test_readme_examples.py to test code examples from README
- Refactor tests to use pytest fixtures and parametrization
- Update test documentation to reflect new structure
- Set proper environment variables in CI for test execution

* fix: add --distance-metric support to DiskANN embedding server and remove obsolete macOS ABI test markers

- Add --distance-metric parameter to diskann_embedding_server.py for consistency with other backends
- Remove pytest.skip and pytest.xfail markers for macOS C++ ABI issues as they have been fixed
- Fix test assertions to handle SearchResult objects correctly
- All tests now pass on macOS with the C++ ABI compatibility fixes

* chore: update lock file with test dependencies

* docs: remove obsolete C++ ABI compatibility warnings

- Remove outdated macOS C++ compatibility warnings from README
- Simplify CI workflow by removing macOS-specific failure handling
- All tests now pass consistently on macOS after ABI fixes

* fix: update macOS deployment target for DiskANN to 13.3

- DiskANN uses sgesdd_ LAPACK function which is only available on macOS 13.3+
- Update MACOSX_DEPLOYMENT_TARGET from 11.0 to 13.3 for DiskANN builds
- This fixes the compilation error on GitHub Actions macOS runners

* fix: align Python version requirements to 3.9

- Update root project to support Python 3.9, matching subpackages
- Restore macOS Python 3.9 support in CI
- This fixes the CI failure for Python 3.9 environments

* fix: handle MPS memory issues in CI tests

- Use smaller MiniLM-L6-v2 model (384 dimensions) for README tests in CI
- Skip other memory-intensive tests in CI environment
- Add minimal CI tests that don't require model loading
- Set CI environment variable and disable MPS fallback
- Ensure README examples always run correctly in CI

* fix: remove Python 3.10+ dependencies for compatibility

- Comment out llama-index-readers-docling and llama-index-node-parser-docling
- These packages require Python >= 3.10 and were causing CI failures on Python 3.9
- Regenerate uv.lock file to resolve dependency conflicts

* fix: use virtual environment in CI instead of system packages

- uv-managed Python environments don't allow --system installs
- Create and activate virtual environment before installing packages
- Update all CI steps to use the virtual environment

* add some env in ci

* fix: use --find-links to install platform-specific wheels

- Let uv automatically select the correct wheel for the current platform
- Fixes error when trying to install macOS wheels on Linux
- Simplifies the installation logic

* fix: disable OpenMP parallelism in CI to avoid libomp crashes

- Set OMP_NUM_THREADS=1 to avoid OpenMP thread synchronization issues
- Set MKL_NUM_THREADS=1 for single-threaded MKL operations
- This prevents segfaults in LayerNorm on macOS CI runners
- Addresses the libomp compatibility issues with PyTorch on Apple Silicon

* skip several macos test because strange issue on ci

---------

Co-authored-by: yichuan520030910320 <yichuan_wang@berkeley.edu>
2025-07-28 17:14:42 -07:00
yichuan520030910320
055c086398 add ablation of embedding model compare 2025-07-28 14:43:42 -07:00
Andy Lee
d505dcc5e3 Fix/OpenAI embeddings cosine distance (#10)
* fix: auto-detect normalized embeddings and use cosine distance

- Add automatic detection for normalized embedding models (OpenAI, Voyage AI, Cohere)
- Automatically set distance_metric='cosine' for normalized embeddings
- Add warnings when using non-optimal distance metrics
- Implement manual L2 normalization in HNSW backend (custom Faiss build lacks normalize_L2)
- Fix DiskANN zmq_port compatibility with lazy loading strategy
- Add documentation for normalized embeddings feature

This fixes the low accuracy issue when using OpenAI text-embedding-3-small model with default MIPS metric.

* style: format

* feat: add OpenAI embeddings support to google_history_reader_leann.py

- Add --embedding-model and --embedding-mode arguments
- Support automatic detection of normalized embeddings
- Works correctly with cosine distance for OpenAI embeddings

* feat: add --use-existing-index option to google_history_reader_leann.py

- Allow using existing index without rebuilding
- Useful for testing pre-built indices

* fix: Improve OpenAI embeddings handling in HNSW backend
2025-07-28 14:35:49 -07:00
Andy Lee
261006c36a docs: revert 2025-07-27 22:07:36 -07:00
GitHub Actions
b2eba23e21 chore: release v0.1.15 2025-07-28 05:05:30 +00:00
yichuan520030910320
e9ee687472 nit: fix readme 2025-07-27 21:56:05 -07:00
yichuan520030910320
6f5d5e4a77 fix some readme 2025-07-27 21:50:09 -07:00
Andy Lee
5c8921673a fix: auto-detect normalized embeddings and use cosine distance (#8)
* fix: auto-detect normalized embeddings and use cosine distance

- Add automatic detection for normalized embedding models (OpenAI, Voyage AI, Cohere)
- Automatically set distance_metric='cosine' for normalized embeddings
- Add warnings when using non-optimal distance metrics
- Implement manual L2 normalization in HNSW backend (custom Faiss build lacks normalize_L2)
- Fix DiskANN zmq_port compatibility with lazy loading strategy
- Add documentation for normalized embeddings feature

This fixes the low accuracy issue when using OpenAI text-embedding-3-small model with default MIPS metric.

* style: format
2025-07-27 21:19:29 -07:00
yichuan520030910320
e9d2d420bd fix some readme 2025-07-27 20:48:23 -07:00
yichuan520030910320
ebabfad066 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-27 20:44:36 -07:00
yichuan520030910320
e6f612b5e8 fix install and readme 2025-07-27 20:44:28 -07:00
Andy Lee
51c41acd82 docs: add comprehensive CONTRIBUTING.md guide with pre-commit setup 2025-07-27 20:40:42 -07:00
yichuan520030910320
455f93fb7c fix emaple and add pypi example 2025-07-27 18:20:13 -07:00
yichuan520030910320
48207c3b69 add pypi example 2025-07-27 17:08:49 -07:00
yichuan520030910320
4de1caa40f fix redame install method 2025-07-27 17:00:28 -07:00
yichuan520030910320
60eaa8165c fix precommit and fix redame install method 2025-07-27 16:36:30 -07:00
yichuan520030910320
c1a5d0c624 fix readme 2025-07-27 02:24:28 -07:00
yichuan520030910320
af1790395a fix ruff errors and formatting 2025-07-27 02:22:54 -07:00
yichuan520030910320
383c6d8d7e add clear instructions 2025-07-27 02:19:27 -07:00
yichuan520030910320
bc0d839693 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-27 02:07:41 -07:00
yichuan520030910320
8596562de5 add pip install option to README 2025-07-27 02:06:40 -07:00
GitHub Actions
5d09586853 chore: release v0.1.14 2025-07-27 08:50:56 +00:00
Andy Lee
a7cba078dd chore: consolidate essential fixes and add pre-commit hooks
- Add pre-commit configuration with ruff and black
- Fix lint CI job to use uv tool install instead of sync
- Add essential LlamaIndex dependencies to leann-core

Co-Authored-By: Yichuan Wang <73766326+yichuan-w@users.noreply.github.com>
2025-07-27 01:24:24 -07:00
Andy Lee
b3e9ee96fa fix: resolve all ruff linting errors and add lint CI check
- Fix ambiguous fullwidth characters (commas, parentheses) in strings and comments
- Replace Chinese comments with English equivalents
- Fix unused imports with proper noqa annotations for intentional imports
- Fix bare except clauses with specific exception types
- Fix redefined variables and undefined names
- Add ruff noqa annotations for generated protobuf files
- Add lint and format check to GitHub Actions CI pipeline
2025-07-26 22:38:13 -07:00
yichuan520030910320
8537a6b17e default args change 2025-07-26 21:51:14 -07:00
yichuan520030910320
7c8d7dc5c2 tones down 2025-07-26 21:47:55 -07:00
yichuan520030910320
8e23d663e6 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-26 21:46:02 -07:00
yichuan520030910320
8a3994bf80 update colab now it works perfect 2025-07-26 21:45:56 -07:00
GitHub Actions
8375f601ba chore: release v0.1.13 2025-07-27 01:08:17 +00:00
yichuan520030910320
c87c0fe662 update colab install & fix colab path 2025-07-26 18:07:31 -07:00
yichuan520030910320
73927b68ef Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-26 17:09:55 -07:00
yichuan520030910320
cc1a62e5aa update pytoml version again 2025-07-26 17:09:45 -07:00
GitHub Actions
802020cb41 chore: release v0.1.12 2025-07-26 23:35:28 +00:00
yichuan520030910320
cdb92f7cf4 update pytoml version && fix colab env && fix pdf extract in pip 2025-07-26 16:33:13 -07:00
yichuan520030910320
dc69bdec00 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-25 17:54:43 -07:00
yichuan520030910320
98073e9868 update missing pkg 2025-07-25 17:54:21 -07:00
GitHub Actions
cf2ef48967 chore: release v0.1.11 2025-07-26 00:12:37 +00:00
yichuan520030910320
0692bbf7a2 change workflow 2025-07-25 17:11:56 -07:00
GitHub Actions
52584a171f chore: release v0.1.10 2025-07-25 23:12:16 +00:00
Andy Lee
efd6b5324b fix: add protobuf as a dependency for DiskANN backend
- Fixes 'No module named google' error when starting DiskANN embedding server
- Prevents users from having to manually install protobuf
2025-07-25 16:10:25 -07:00
Andy Lee
2baaa4549b fix: handle relative paths in HNSW embedding server metadata
- Convert relative paths to absolute paths based on metadata file location
- Fixes FileNotFoundError when starting embedding server
- Resolves issue with passages file not found in different working directories
2025-07-25 16:09:53 -07:00
Andy Lee
35310ddd52 fix: pure Python packages not building due to ubuntu-latest check
The build workflow was checking for matrix.os == 'ubuntu-latest',
but we changed the matrix to use 'ubuntu-22.04', causing the
pure Python packages (leann-core and leann) to never be built.

Changed to use pattern matching [[ == ubuntu-* ]] to match any
Ubuntu version.

This explains why v0.1.9 only published the C++ backend packages
but not the pure Python packages.
2025-07-25 15:14:21 -07:00
Andy Lee
fc9c5cb39d fix: make release workflow idempotent
- Check if version is already updated before trying to update
- Check if tag already exists before creating
- Check if GitHub release already exists before creating
- This allows re-running the workflow after partial failures

Previously, if the workflow failed after updating version but before
completing the release, it couldn't be re-run with the same version.
2025-07-25 14:47:35 -07:00
Andy Lee
8f2a1e87ea Merge pull request #7 from yichuan-w/fix/simple-ubuntu22-build
fix: simplify build system for Colab compatibility
2025-07-25 14:08:37 -07:00
Andy Lee
50caf65f28 fix: change ubuntu-latest to ubuntu-22.04 and add Python 3.13
- Explicitly use ubuntu-22.04 instead of ubuntu-latest
- Add Python 3.13 to the build matrix
- This ensures we build on the same OS version as Google Colab
2025-07-25 13:48:59 -07:00
Andy Lee
1b48794ca8 cleanup: remove cibuildwheel workflow files
- Remove ci-cibuildwheel.yml and build-cibuildwheel.yml
- These files were not present in v0.1.5
- Keep only the simple build system
2025-07-25 13:48:08 -07:00
Andy Lee
4aef1d814e revert: simplify build system by removing manylinux/cibuildwheel
- Revert to simple Ubuntu 22.04 builds that should work with Colab
- Remove all manylinux container complexity
- Colab runs on Ubuntu 22.04, so direct builds should be compatible
- Restore build-reusable.yml to v0.1.5 version
- Remove cibuildwheel option from release workflow

This should fix the overcomplicated build issues while maintaining
Colab compatibility through direct Ubuntu 22.04 builds.
2025-07-25 13:46:51 -07:00
GitHub Actions
75ddcd6158 chore: release v0.1.9 2025-07-25 20:04:42 +00:00
Andy Lee
2a4df11f5c fix: absolute path for passages 2025-07-25 11:59:30 -07:00
Andy Lee
5eb893c62b ci: add Python 3.13 support to build matrix 2025-07-25 09:53:36 -07:00
yichuan520030910320
d91ce2e94d readme 2025-07-25 02:19:54 -07:00
yichuan520030910320
5c2ff8a641 clean research stuff 2025-07-25 02:14:15 -07:00
yichuan520030910320
d4f474c9b7 update broken link 2025-07-25 02:13:22 -07:00
yichuan520030910320
170f7644e9 simplify readme 2025-07-25 02:11:02 -07:00
yichuan520030910320
cd8b970eff Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-25 01:45:57 -07:00
yichuan520030910320
52153bbb69 update faiss compare 2025-07-25 01:45:50 -07:00
GitHub Actions
e1ae087207 chore: release v0.1.8 2025-07-25 08:24:40 +00:00
Andy Lee
48c5e12ac1 fix: use absolute path for passages_file to prevent FileNotFoundError
When embedding server is launched as a subprocess, it may run in a different
working directory. Using absolute paths ensures the server can always find
the metadata file regardless of where it's launched from.
2025-07-25 01:23:47 -07:00
yichuan520030910320
f8b5c97190 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-25 00:37:33 -07:00
yichuan520030910320
d038c81b8b update benchmard section 2025-07-25 00:37:27 -07:00
Andy Lee
29cbbbd0d6 fix: resolve libzmq pkg-config issues in manylinux containers
- Add gcc-c++ and cmake to dependencies
- Create libzmq.pc file if missing (CentOS 7 issue)
- Set PKG_CONFIG_PATH through CIBW_ENVIRONMENT_LINUX
- Add protobuf-devel to ensure all headers are available
- Fix shell variable escaping in heredoc
2025-07-25 00:35:52 -07:00
Andy Lee
179f30bc36 fix: improve system dependency installation in manylinux containers
- Add yum cache cleaning and updating
- Make package installations more resilient with fallbacks
- Use pkgconfig instead of pkg-config (CentOS 7 naming)
- Handle optional packages that might not be available
- Add error handling for package installation failures
2025-07-25 00:30:29 -07:00
Andy Lee
c4a0a68581 fix: handle pure Python packages in cibuildwheel workflow
- Build pure Python packages (leann-core, leann) with standard build tool
- Only use cibuildwheel for C extension packages (leann-backend-hnsw, leann-backend-diskann)
- Build pure Python packages only once on ubuntu-latest
- Add Python setup for building pure packages
- Add package listing step for debugging
2025-07-25 00:26:15 -07:00
Andy Lee
5c836ad08e fix: handle git dubious ownership error in manylinux containers
- Add multiple safe.directory configurations to cover different possible paths
- This fixes 'detected dubious ownership in repository' error
- Ensures git works properly in manylinux2014 containers
2025-07-25 00:22:01 -07:00
Andy Lee
673fd9b7cd fix: upgrade to actions v4 and handle manylinux2014 compatibility
- Upgrade all GitHub Actions to v4 (v3 is deprecated)
- Use manual git checkout in manylinux2014 containers to avoid Node.js issues
- Update artifact naming to ensure uniqueness (required by v4)
- Add fail-fast: false to build strategies
- This maintains manylinux2014 compatibility while using latest actions
2025-07-25 00:20:21 -07:00
Andy Lee
84b24b233d feat: add cibuildwheel option to release workflow
- Add optional use_cibuildwheel parameter to release workflow
- Create separate CI workflow for testing cibuildwheel
- Support conditional build workflow selection in release process
- This allows building wheels compatible with Google Colab and older systems
- Maintains backward compatibility with existing build process
2025-07-25 00:16:08 -07:00
Andy Lee
499cdd7822 feat: add cibuildwheel workflow for better platform compatibility
- Use cibuildwheel for professional wheel building
- Specifically use manylinux2014 for Google Colab compatibility
- Supports Python 3.9-3.12 on Linux and macOS
- Handles monorepo structure with separate builds per package
- Includes basic import tests for each package
- This should resolve compatibility issues with older systems like Google Colab
2025-07-25 00:16:08 -07:00
yichuan520030910320
800d4cf111 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-25 00:12:47 -07:00
yichuan520030910320
b6d43f5fd9 add gif 2025-07-25 00:12:35 -07:00
Andy Lee
3603cd5034 fix: downgrade GitHub Actions versions for manylinux2014 compatibility
- Use actions/checkout@v3 instead of v4 (Node.js 16 vs 20)
- Use actions/setup-python@v4 instead of v5
- Use actions/upload-artifact@v3 and download-artifact@v3
- This fixes GLIBC version errors in manylinux2014 containers
- manylinux2014 (CentOS 7) has glibc 2.17 but Node.js 20 needs 2.25+
2025-07-25 00:12:05 -07:00
Andy Lee
6df7893173 feat: use manylinux2014 containers for better Linux compatibility
- Add manylinux2014 Docker containers for Linux builds
- This will generate wheels compatible with older Linux systems (CentOS 7+, Ubuntu 16.04+)
- Separate build logic for container vs regular environments
- Install appropriate system dependencies for yum-based manylinux environment
- Use pip instead of uv in containers for better compatibility
- Fix Python version format for manylinux container paths
2025-07-25 00:08:42 -07:00
GitHub Actions
e64b599276 chore: release v0.1.7 2025-07-25 04:47:57 +00:00
Andy Lee
2dd59c4ba1 fix: let auditwheel auto-detect manylinux platform tag
- Remove --plat manylinux2014_x86_64 flag that was causing build failures
- Let auditwheel automatically determine the appropriate manylinux tag
- Add auditwheel show command to display compatibility info
- This fixes the 'too-recent versioned symbols' error
2025-07-24 21:44:15 -07:00
GitHub Actions
166986d5e6 chore: release v0.1.6 2025-07-25 04:30:07 +00:00
Andy Lee
a6aec68f32 fix: use manylinux2014 for better Linux compatibility
- Change auditwheel --plat to manylinux2014_x86_64
- This ensures wheels work on Ubuntu 16.04+ instead of requiring 24.04+
- Fixes compatibility issues for users on Ubuntu 22.04 and similar systems
2025-07-24 21:26:28 -07:00
GitHub Actions
ed27a127d5 chore: release v0.1.5 2025-07-25 04:00:54 +00:00
Andy Lee
d8b4ea7564 fix: add write permissions for GitHub Actions to push commits 2025-07-24 20:55:24 -07:00
Andy Lee
f0a2ef96b4 fix: restore complete build configuration from working version 2025-07-24 19:49:38 -07:00
Andy Lee
7d73c2c803 fix: remove invalid --extra build flag from build commands 2025-07-24 19:43:23 -07:00
Andy Lee
e8d2ecab03 refactor: use reusable workflow to avoid code duplication 2025-07-24 19:35:12 -07:00
Andy Lee
32a374d094 feat: true one-click automated release with multi-platform support 2025-07-24 19:30:44 -07:00
Andy Lee
d45c013806 fix: handle workflow trigger permission gracefully 2025-07-24 19:25:29 -07:00
GitHub Actions
9000a7083d chore: release v0.1.4 2025-07-25 02:23:36 +00:00
Andy Lee
8307555d54 fix: manually trigger CI after version push in release workflow 2025-07-24 19:21:32 -07:00
GitHub Actions
20f2aece08 chore: release v0.1.3 2025-07-25 02:05:11 +00:00
yichuan520030910320
43eb4f9a1d Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-24 19:03:52 -07:00
yichuan520030910320
5461b71d8c colab dev 2025-07-24 19:03:46 -07:00
Andy Lee
374db0ebb8 fix: release workflow to build new version before publishing 2025-07-24 19:03:09 -07:00
GitHub Actions
cea1f6f87c chore: release v0.1.2 2025-07-25 01:53:29 +00:00
Andy Lee
6c0e39372b fix: download all artifacts in release workflow 2025-07-24 17:45:46 -07:00
Andy Lee
2bec67d2b6 feat: auto-update leann-core dependencies during release
- Enhanced bump_version.sh to automatically update leann-core dependency versions
- Script now updates both package versions and their leann-core dependencies
- This ensures version consistency across all packages during release

No more manual dependency version updates needed!
2025-07-24 17:22:41 -07:00
Andy Lee
133e715832 fix: resolve CI issues and consolidate workflows
- Fix version dependencies: update backend packages to depend on leann-core==0.1.1
- Remove duplicate ci.yml workflow (keeping build-and-publish.yml as main CI)
- Update release-manual.yml to reference correct CI workflow name

This fixes the dependency resolution error and eliminates duplicate builds.
2025-07-24 17:20:58 -07:00
Andy Lee
95cf2f16e2 refactor: consolidate release and publish into single workflow
- Manual Release workflow now directly publishes to PyPI after downloading CI artifacts
- No more duplicate builds - reuses artifacts from CI
- build-and-publish.yml renamed to 'CI - Build Multi-Platform Packages'
- Publishing in CI workflow only for emergency manual triggers
- Updated RELEASE.md to reflect the new streamlined process

This fixes the issue where releases would trigger redundant builds.
2025-07-24 17:04:47 -07:00
Andy Lee
47a4c153eb fix: enable PyPI publish on tag push
- Manual Release workflow creates tags but build-and-publish.yml only published on 'release' events
- Now build-and-publish.yml will also publish when v* tags are pushed
- This fixes the issue where manual releases didn't trigger PyPI uploads
2025-07-24 17:00:21 -07:00
GitHub Actions
faf5ae3533 chore: release v0.1.1 2025-07-24 23:36:23 +00:00
Andy Lee
a44dccecac fix: make TestPyPI upload optional and non-blocking
- Add continue-on-error to TestPyPI step
- Check if TEST_PYPI_API_TOKEN exists before attempting upload
- Add graceful failure handling with clear messages
- Update docs to explain TestPyPI token configuration
- Clarify that TestPyPI testing is optional

Now the release won't fail if TestPyPI is not configured or upload fails
2025-07-24 16:02:07 -07:00
yichuan520030910320
9cf9358b9c Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-24 14:40:39 -07:00
yichuan520030910320
de252fef31 [chat] update 30s example 2025-07-24 14:40:33 -07:00
Andy Lee
9076bc27b8 fix: resolve CI run detection issues in release workflow
- Add 'actions: read' permission to access workflow runs
- Use workflow name instead of filename for gh run list
- Look for CI run on HEAD~1 (before version bump commit)
- Improve error messages for better debugging

Fixes HTTP 403 error when trying to find successful CI runs
2025-07-24 14:27:26 -07:00
Andy Lee
50686c0819 refactor: use CI artifacts in release workflow instead of rebuilding
- Download pre-built wheels from successful CI runs
- Avoids duplicate builds and ensures consistency
- CI artifacts are already tested across all platforms
- Faster release process (no build time)
- Updates release documentation to reflect new flow

This ensures the released packages are exactly what was tested in CI.
2025-07-24 14:24:03 -07:00
Andy Lee
1614203786 fix: make bump_version.sh work on both macOS and Linux
- macOS uses sed -i '' while Linux uses sed -i
- Add OS detection to use correct syntax
- Ensures script works in CI (Linux) and local dev (macOS)
2025-07-24 14:13:31 -07:00
Andy Lee
3d4c75a56c fix: add missing scripts directory to git
- Remove scripts/ from .gitignore
- Add build_and_test.sh for local testing
- Add bump_version.sh for version updates (used by CI)
- Add release.sh and upload_to_pypi.sh for publishing
- Fixes CI error: ./scripts/bump_version.sh: No such file or directory
2025-07-24 14:13:14 -07:00
Andy Lee
2684ee71dc fix: ensure uv build uses correct Python version in CI
- Add --python python flag to uv build commands
- This ensures wheels are built with the correct Python version (cp313 for Python 3.13, etc)
- Fixes issue where Python 3.13 CI was building cp311 wheels
- Also adds Python version verification before build
2025-07-24 13:44:02 -07:00
Andy Lee
1d321953ba ci: update all GitHub Actions to latest versions
- Update actions/upload-artifact from v3 to v4 (v3 deprecated April 2024)
- Update actions/setup-python from v4 to v5 (latest version)
- Add Python 3.12 and 3.13 to CI test matrix
- Ensure compatibility with latest Python versions and GitHub Actions
2025-07-24 13:36:21 -07:00
Andy Lee
b3cb251369 ci: add Python 3.12 and 3.13 to test matrix
- Add Python 3.12 and 3.13 to CI test matrix
- Ensure compatibility with latest Python versions
- Python 3.12 is stable, 3.13 was released in Oct 2024
2025-07-24 13:32:29 -07:00
Andy Lee
0a17d2c9d8 feat: implement comprehensive CI/CD pipeline with two-stage release
- Add ci.yml for continuous integration on every commit
  - Test builds on Ubuntu/macOS with Python 3.9/3.10/3.11
  - Ensure code quality before any release

- Add release-manual.yml for controlled releases
  - Manual trigger prevents accidental releases
  - Version validation and tag creation
  - Optional TestPyPI testing before production
  - Only creates tag after validation passes

- Keep build-and-publish.yml for automated PyPI deployment
  - Triggered by new tags (separation of concerns)
  - Handles multi-platform wheel building
  - Allows retry if PyPI upload fails

- Update RELEASE.md with clear prerequisites and workflow

This setup ensures:
1. Every commit is tested (CI)
2. Releases are deliberate (manual trigger)
3. Failed CI won't create broken tags
4. PyPI publish can be retried independently
2025-07-24 13:29:21 -07:00
Andy Lee
e3defbca84 fix: add minimal CI dependencies for HNSW and DiskANN backends
- HNSW (Ubuntu): add libopenblas-dev for BLAS requirements
- DiskANN (Ubuntu): keep MKL, remove redundant pkg-config (HNSW already has it)
- DiskANN (macOS): add protobuf for build requirements
- Both: ensure patchelf for auditwheel on Linux

This avoids OpenBLAS/MKL conflicts by using them in separate jobs
2025-07-24 01:06:57 -07:00
Andy Lee
e407f63977 chore: fix uv build 2025-07-24 00:51:57 -07:00
Andy Lee
7add391b2c chore: build and package 2025-07-24 00:47:46 -07:00
yichuan520030910320
efd6373b32 [chat] update huggingface chat and make qwen no thinking 2025-07-24 00:11:42 -07:00
yichuan520030910320
d502fa24b0 [installation] update install for linux 2025-07-24 02:17:17 +00:00
yichuan520030910320
258a9a5c7f [misc]test link again 2025-07-23 18:29:32 -07:00
yichuan520030910320
5d41ac6115 test link 2025-07-23 18:28:22 -07:00
yichuan520030910320
2a0fdb49b8 test link 2025-07-23 18:27:08 -07:00
yichuan520030910320
9d1b7231b6 fix broken link 2025-07-23 18:25:22 -07:00
yichuan520030910320
ed3095b478 fix broken link 2025-07-23 18:24:17 -07:00
yichuan520030910320
88eca75917 fix readme 2025-07-23 18:22:10 -07:00
yichuan520030910320
42de27e16a Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-23 18:17:19 -07:00
yichuan520030910320
c083bda5b7 fix several bug 2025-07-23 18:17:11 -07:00
Andy Lee
e86da38726 fix: ollama hint for similar models 2025-07-23 15:45:10 -07:00
yichuan520030910320
99076e38bc update install 2025-07-23 14:55:34 -07:00
yichuan520030910320
9698c1a02c fix readme 2025-07-23 14:52:01 -07:00
yichuan520030910320
851f0f04c3 fix some para 2025-07-23 01:46:34 -07:00
yichuan520030910320
ae16d9d888 fix readme 2025-07-23 00:44:43 -07:00
yichuan520030910320
6e1af2eb0c fix readme 2025-07-23 00:43:46 -07:00
yichuan520030910320
7695dd0d50 fix readme 2025-07-23 00:42:17 -07:00
yichuan520030910320
c2065473ad fix readme 2025-07-23 00:30:42 -07:00
yichuan520030910320
5f3870564d Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-23 00:09:30 -07:00
yichuan520030910320
c214b2e33e fix readme 2025-07-23 00:09:24 -07:00
Andy Lee
2420c5fd35 chore: update sentence-transformer to prevent MixIn not found error 2025-07-22 23:27:25 -07:00
yichuan520030910320
f48f526f0a fix readme 2025-07-22 23:21:15 -07:00
yichuan520030910320
5dd74982ba fix readme 2025-07-22 23:14:31 -07:00
Andy Lee
e07aaf52a7 docs: align 2025-07-22 22:37:27 -07:00
Andy Lee
30e5f12616 docs: quick start 2025-07-22 22:33:04 -07:00
Andy Lee
594427bf87 docs: demo 2025-07-22 22:32:18 -07:00
yichuan520030910320
a97d3ada1c fix readme need to polish example 2025-07-22 22:09:55 -07:00
yichuan520030910320
6217bb5638 fix readme 2025-07-22 22:05:28 -07:00
yichuan520030910320
2760e99e18 fix readme 2025-07-22 22:03:19 -07:00
yichuan520030910320
0544f96b79 default main cli to openai add data dict as a args 2025-07-22 21:56:30 -07:00
yichuan520030910320
2ebb29de65 default main cli to openai 2025-07-22 21:55:18 -07:00
yichuan520030910320
43762d44c7 fix readme 2025-07-22 21:51:30 -07:00
yichuan520030910320
cdaf0c98be fix readme 2025-07-22 21:44:52 -07:00
yichuan520030910320
aa9a14a917 make the email wonderful format 2025-07-22 21:41:58 -07:00
yichuan520030910320
9efcc6d95c Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-22 20:44:02 -07:00
yichuan520030910320
f3f5d91207 make the google history wonderful format 2025-07-22 20:43:56 -07:00
Andy Lee
6070160959 chore: remove .vscode 2025-07-22 19:59:35 -07:00
Andy Lee
43155d2811 fix: supress resources leak logs 2025-07-22 19:53:45 -07:00
Andy Lee
d3f85678ec perf: much faster loading and embedding serving 2025-07-22 19:38:22 -07:00
yichuan520030910320
2a96d05b21 upd readme 2025-07-22 17:06:33 -07:00
yichuan520030910320
851e888535 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-22 17:01:04 -07:00
yichuan520030910320
90120d4dff upd the structure in the chat for better perf 2025-07-22 17:00:56 -07:00
Andy Lee
8513471573 feat: make diskann runnable 2025-07-22 14:26:03 -07:00
Andy Lee
71e5f1774c docs: cli 2025-07-21 23:48:40 -07:00
yichuan520030910320
870a443446 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-21 23:13:45 -07:00
yichuan520030910320
cefaa2a4cc upd readme 2025-07-21 23:13:38 -07:00
Andy Lee
ab72a2ab9d fix: more logs 2025-07-21 23:08:53 -07:00
yichuan520030910320
046d457d22 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-21 23:04:00 -07:00
yichuan520030910320
7fd0a30fee upd log 2025-07-21 23:03:52 -07:00
Andy Lee
c2f35c8e73 fix: logs 2025-07-21 23:02:13 -07:00
Andy Lee
573313f0b6 refactor: logs 2025-07-21 22:45:24 -07:00
yichuan520030910320
f7af6805fa readme 2025-07-21 22:33:03 -07:00
yichuan520030910320
966de3a399 readme 2025-07-21 22:32:02 -07:00
yichuan520030910320
8a75829f3a readme 2025-07-21 22:30:03 -07:00
yichuan520030910320
0f7e34b9e2 readme 2025-07-21 22:18:00 -07:00
yichuan520030910320
be0322b616 readme 2025-07-21 22:16:52 -07:00
yichuan520030910320
232a525a62 readme 2025-07-21 22:14:43 -07:00
yichuan520030910320
587ce65cf6 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-21 21:54:27 -07:00
yichuan520030910320
ccf6c8bfd7 fix flush print 2025-07-21 21:54:20 -07:00
Andy Lee
c112956d2d fix: mlx 2025-07-21 21:29:15 -07:00
Andy Lee
b3970793cf fix: cache the loaded model 2025-07-21 21:20:53 -07:00
yichuan520030910320
727724990e add todo 2025-07-21 20:59:09 -07:00
yichuan520030910320
530f6e4af5 add progress bar in build 2025-07-21 20:55:18 -07:00
Andy Lee
2f224f5793 fix: use server to emb query only when recompute 2025-07-21 20:40:21 -07:00
Andy Lee
1b6272ce0e Building, CLI tool & Embedding Server Fixed (#5)
* chore: shorter build time

* chore: update faiss

* fix: no longger do embedding server reuse

* fix: do not reuse emb_server and close it properly

* feat: cli tool

* feat: cli more args

* fix: same embedding logic
2025-07-21 20:17:25 -07:00
yichuan520030910320
5259ace111 [Readme] 2025-07-21 20:06:21 -07:00
yichuan520030910320
48ea5566e9 [Readme] detail number 2025-07-21 19:51:51 -07:00
yichuan520030910320
3f8b6c5bbd [Readme] 2025-07-21 18:15:00 -07:00
yichuan520030910320
725b32e74f [Readme] 2025-07-21 17:57:44 -07:00
yichuan520030910320
d0b71f393f [Readme] 2025-07-21 17:56:10 -07:00
yichuan520030910320
8a92efdae3 [Readme] 2025-07-21 17:53:59 -07:00
yichuan520030910320
019cdce2e8 [Readme] 2025-07-21 17:30:11 -07:00
yichuan520030910320
b64aa54fac fix break link 2025-07-21 17:29:35 -07:00
yichuan520030910320
c0d040f9d4 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-21 16:22:24 -07:00
yichuan520030910320
32364320f8 update wechat and we should fix the bug introduced in 1c5fec5 2025-07-21 16:22:16 -07:00
Andy Lee
34c71c072d chore: parallel compile fix 2025-07-19 22:51:47 -07:00
Andy Lee
6d2149c503 chore: parallel compile fix 2025-07-19 22:46:24 -07:00
Andy Lee
043b0bf69d chore: parallel compile fix 2025-07-19 22:34:19 -07:00
Andy Lee
9b07e392c6 chore: parallel compile 2025-07-19 22:32:13 -07:00
Andy Lee
e60fad8c73 chore: mark diskann as optional 2025-07-19 22:24:44 -07:00
Andy Lee
19c1b182c3 docs: effects figure 2025-07-19 22:07:04 -07:00
Andy Lee
49edea780c docs: figure 2025-07-19 21:59:58 -07:00
Andy Lee
12ef5a1900 docs: effects 2025-07-19 21:57:12 -07:00
Andy Lee
d21a134b2a docs: polish 2025-07-19 21:53:41 -07:00
Andy Lee
1cd809aa41 [Docs] README polished version (#4)
* docs: polish

* docs: logo

* docs: logo

* docs: logo with text

* docs: readme effects

* docs: polish

* docs: highlight applications

* docs: polish

* docs: how it works earlier

* docs: polish

* docs: polish

* docs: follow yichuan's suggestion

* docs: follow yichuan's suggestion

---------

Co-authored-by: Yichuan Wang <73766326+yichuan-w@users.noreply.github.com>
2025-07-19 21:47:25 -07:00
yichuan520030910320
e728449b8f change chinese 2025-07-19 19:54:02 -07:00
yichuan520030910320
d0c20b14d5 clear output pf ipynb 2025-07-19 19:48:56 -07:00
yichuan520030910320
83b7ea5a59 change wecaht app split logic& merge 2025-07-19 19:44:33 -07:00
yichuan520030910320
0796a52df1 change wecaht app split logic 2025-07-19 19:43:30 -07:00
Andy Lee
85b7ba0168 feat: allow build from existed embeddings 2025-07-19 01:27:37 -07:00
yichuan520030910320
e117743d24 Merge branch 'main' of github.com:yichuan520030910320/LEANN-RAG 2025-07-17 22:29:39 -07:00
yichuan520030910320
aec2291f04 add embedding api 2025-07-17 22:29:31 -07:00
yichuan520030910320
335ae003ac add data 2025-07-17 22:29:03 -07:00
Andy Lee
71c7de9c84 fix: build with direct embedding 2025-07-17 21:49:36 -07:00
Andy Lee
1c5fec5565 perf: make embedder loading faster by 6x, and embed queries through the server 2025-07-17 20:08:06 -07:00
yichuan520030910320
99d439577d Merge branch 'main' of github.com:yichuan520030910320/LEANN-RAG 2025-07-17 18:15:27 -07:00
yichuan520030910320
4f83086788 update readme and auto find email 2025-07-17 18:15:17 -07:00
Andy Lee
a13c527e39 feat: openai embeddings 2025-07-17 17:02:47 -07:00
yichuan520030910320
90d9f27383 update readme and main example 2025-07-17 15:03:22 -07:00
yichuan520030910320
0db81c16cd update readme and chrome example 2025-07-17 12:58:11 -07:00
yichuan520030910320
e115e186b7 update example and more stats on result 2025-07-16 22:07:15 -07:00
yichuan520030910320
6546b29ef7 update readme 2025-07-16 20:29:45 -07:00
yichuan520030910320
51255bdffa update readme and add timer 2025-07-16 17:15:51 -07:00
Andy Lee
f77c4e38cb perf: reuse embedding server for query embed 2025-07-16 16:12:15 -07:00
Andy Lee
2a1a152073 refactor: nits 2025-07-16 15:39:58 -07:00
Andy Lee
7b9406a3ea feat: different search_args and docstrings 2025-07-16 15:25:58 -07:00
Andy Lee
c3fb949693 docs: ollama 2025-07-16 15:12:37 -07:00
yichuan520030910320
ed3f8dbfd6 update readme 2025-07-15 23:32:25 -07:00
yichuan520030910320
42aa6db170 update readme 2025-07-15 23:23:04 -07:00
yichuan520030910320
a6591d20ca Merge branch 'main' of github.com:yichuan520030910320/LEANN-RAG 2025-07-15 23:18:08 -07:00
yichuan520030910320
c1bc2603a2 update readme and 30 seconds example 2025-07-15 23:18:01 -07:00
Andy Lee
e595bbb5fb feat: hint for users about wrong model name 2025-07-15 22:40:40 -07:00
yichuan520030910320
4a2cb914d7 clean dict 2025-07-15 22:30:52 -07:00
yichuan520030910320
b1c93fe178 Merge branch 'main' of github.com:yichuan520030910320/LEANN-RAG 2025-07-15 22:29:09 -07:00
yichuan520030910320
0719458775 upd readme stats 2025-07-15 22:28:59 -07:00
Andy Lee
6a1dc895fb feat: disable warmup by default 2025-07-15 22:16:02 -07:00
Andy Lee
125c1f6f25 fix: model name 2025-07-15 21:48:45 -07:00
yichuan520030910320
1ceaa7d709 Merge branch 'main' of github.com:yichuan520030910320/LEANN-RAG 2025-07-15 21:19:25 -07:00
yichuan520030910320
dec3ee85fd fix main cli 2025-07-15 21:19:16 -07:00
Andy Lee
d94a5176dc docs: storage reduction data 2025-07-15 15:37:17 -07:00
yichuan520030910320
326783f7f1 fix mem compare fix split 2025-07-14 23:07:46 -07:00
yichuan520030910320
e5a9ca8787 fix mem compare 2025-07-14 22:55:10 -07:00
Andy Lee
f2feccdbd0 fix: mem compare 2025-07-14 16:35:08 -07:00
yichuan520030910320
246a077d64 upd readme 2025-07-14 16:21:34 -07:00
yichuan520030910320
3ba100ff25 upd readme 2025-07-14 16:18:39 -07:00
yichuan520030910320
1e3b571e72 add readme bench 2025-07-14 16:13:21 -07:00
Andy Lee
b89e56e9c2 fix: file name 2025-07-14 15:34:56 -07:00
yichuan520030910320
ed8a02e721 update readme and mlx support 2025-07-14 15:23:56 -07:00
Andy Lee
baa60b40d1 fix: smaller warmup id 2025-07-14 15:20:45 -07:00
Andy Lee
ef01d6997a fix: faiss only 2025-07-14 13:15:55 -07:00
Andy Lee
3da5b44d7f fix: mlx when searching, added to embedding_server 2025-07-14 01:11:21 -07:00
Andy Lee
8b4654921b fix: run faiss in subprocess to prevent kmp 2025-07-14 00:29:21 -07:00
yichuan520030910320
cf1cbafa78 Merge branch 'main' of github.com:yichuan520030910320/LEANN-RAG 2025-07-13 23:19:54 -07:00
yichuan520030910320
c96091744b update readme 2025-07-13 23:19:44 -07:00
Andy Lee
711fb4a775 feat: compare faiss 2025-07-13 22:44:16 -07:00
Andy Lee
3b5a185e60 refactor: check if current emb_server has correct passages/embedder 2025-07-13 22:43:51 -07:00
yichuan520030910320
77ac013a74 update readem 2025-07-13 22:37:41 -07:00
yichuan520030910320
b8e5728e6a fix wechat application 2025-07-13 22:29:54 -07:00
yichuan520030910320
d038319d8b upd readme wechat application 2025-07-13 22:00:49 -07:00
yichuan520030910320
c611d0f30f upd readme mail application 2025-07-13 21:48:57 -07:00
yichuan520030910320
c17899662f upd readme mail application 2025-07-13 21:30:08 -07:00
yichuan520030910320
c51d5320fa upd test/mlx 2025-07-13 20:16:02 -07:00
yichuan520030910320
6fa9512a64 fix wechat path 2025-07-13 18:23:31 -07:00
Andy Lee
fddc61df5e chore: reset to latest version 2025-07-13 17:06:48 -07:00
Andy Lee
53c58fa755 perf: switch to tranditional pdf reader 2025-07-13 17:04:23 -07:00
yichuan520030910320
c69afb56e4 Resolve submodule conflict - update to af2a264 2025-07-13 17:03:42 -07:00
yichuan520030910320
0fa8a9191f add wechat history extract app 2025-07-13 16:52:08 -07:00
Andy Lee
48dda1cb5b feat: mlx 2025-07-13 02:13:04 -07:00
Andy Lee
71ef4b7d4c fix: reproducible dpr on mac 2025-07-12 18:13:22 -07:00
Andy Lee
ecab43e307 feat: dataset for evaluation 2025-07-12 23:43:10 +00:00
Fangzhou66
88ca09440d fix some hf problem 2025-07-12 16:13:15 -07:00
Andy Lee
8e0ab4a28d chore: update deps 2025-07-12 22:48:13 +00:00
yichuan520030910320
9b8c5041dc update readme 2025-07-12 13:01:11 -07:00
yichuan520030910320
74ffd7ec64 add email test code 2025-07-11 23:59:47 -07:00
Andy Lee
eb6f504789 Datastore reproduce (#3)
* fix: diskann zmq port and passages

* feat: auto discovery of packages and fix passage gen for diskann

* docs: embedding pruning

* refactor: passage structure

* feat: reproducible research datas, rpj_wiki & dpr

* refactor: chat and base searcher

* feat: chat on mps
2025-07-11 23:37:23 -07:00
yichuan520030910320
91a026f38b polish readme 2025-07-11 23:06:08 -07:00
yichuan520030910320
595138a0a3 upd readme 2025-07-11 22:43:48 -07:00
yichuan520030910320
19df04095f add readme 2025-07-11 22:34:54 -07:00
yichuan520030910320
8239bbb48f add google hostory api 2025-07-11 21:21:36 -07:00
yichuan520030910320
16ee9d0422 add traverse all dict interface 2025-07-10 15:59:16 -07:00
yichuan520030910320
8a961f8ab3 align the llamaindex result w leann& test attachment 2025-07-09 21:42:15 -07:00
yichuan520030910320
558126c46e add leann and llamaindex email infra, and need to align the results 2025-07-09 16:27:11 -07:00
yichuan520030910320
04c9684488 add email test code 2025-07-09 15:06:31 -07:00
Andy Lee
b744faa7e6 chore: all deps 2025-07-08 23:37:40 +00:00
Andy Lee
27b3a26e75 fix(deps): Update DiskANN with cleaned up CMake configuration 2025-07-08 23:27:05 +00:00
Andy Lee
41d872504e feat(deps): Update DiskANN to use system-installed Boost and Protobuf 2025-07-08 23:13:36 +00:00
Andy Lee
963cd05273 chore: diskann modules 2025-07-08 21:57:38 +00:00
Andy Lee
09b6e67baf chore: diskann upg boost 2025-07-08 21:44:44 +00:00
yichuan520030910320
dafb2aacab update macos env 2025-07-08 14:37:41 -07:00
Andy Lee
a6c400cd4f chroe: linux boost and protobuf 2025-07-08 21:25:43 +00:00
Andy Lee
c013e5ccce chore: linux deps 2025-07-08 13:55:39 -07:00
Andy Lee
f25a1a3840 chore: macos compatible 2025-07-08 13:32:00 -07:00
yichuan520030910320
6497e17671 add gpu chunk embedd and add complexity in hnsw 2025-07-08 18:40:52 +00:00
yichuan520030910320
44369a8138 update diskann module 2025-07-07 18:27:07 -07:00
yichuan520030910320
dfca00c21b add mac support in this repo 2025-07-07 18:22:24 -07:00
yichuan520030910320
637dab379e add workaround code 2025-07-07 23:13:47 +00:00
yichuan520030910320
6fc57eb48e add reuse code 2025-07-07 21:07:00 +00:00
yichuan520030910320
95a653993a rm useless 2025-07-06 06:47:20 +00:00
yichuan520030910320
af0959818d rm useless 2025-07-06 05:21:05 +00:00
Andy Lee
cf17c85607 Make DiskANN and HNSW work on main example (#2)
* fix: diskann zmq port and passages

* feat: auto discovery of packages and fix passage gen for diskann
2025-07-05 22:18:12 -07:00
Andy Lee
a38bc0a3fc refactor: embedding server manager 2025-07-06 01:54:46 +00:00
yichuan
449983c937 Merge pull request #1 from yichuan520030910320/debug_diskann_disable_pipe
debug_diskann_disable_pipe
2025-07-05 17:55:27 -07:00
217 changed files with 37711 additions and 27369 deletions

1
.gitattributes vendored
View File

@@ -1 +0,0 @@
paper_plot/data/big_graph_degree_data.npz filter=lfs diff=lfs merge=lfs -text

12
.github/workflows/build-and-publish.yml vendored Normal file
View File

@@ -0,0 +1,12 @@
name: CI
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
workflow_dispatch:
jobs:
build:
uses: ./.github/workflows/build-reusable.yml

358
.github/workflows/build-reusable.yml vendored Normal file
View File

@@ -0,0 +1,358 @@
name: Reusable Build
on:
workflow_call:
inputs:
ref:
description: 'Git ref to build'
required: false
type: string
default: ''
jobs:
lint:
name: Lint and Format Check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install ruff
run: |
uv tool install ruff
- name: Run ruff check
run: |
ruff check .
- name: Run ruff format check
run: |
ruff format --check .
build:
needs: lint
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
strategy:
matrix:
include:
- os: ubuntu-22.04
python: '3.9'
- os: ubuntu-22.04
python: '3.10'
- os: ubuntu-22.04
python: '3.11'
- os: ubuntu-22.04
python: '3.12'
- os: ubuntu-22.04
python: '3.13'
- os: macos-14
python: '3.9'
- os: macos-14
python: '3.10'
- os: macos-14
python: '3.11'
- os: macos-14
python: '3.12'
- os: macos-14
python: '3.13'
- os: macos-15
python: '3.9'
- os: macos-15
python: '3.10'
- os: macos-15
python: '3.11'
- os: macos-15
python: '3.12'
- os: macos-15
python: '3.13'
- os: macos-13
python: '3.9'
- os: macos-13
python: '3.10'
- os: macos-13
python: '3.11'
- os: macos-13
python: '3.12'
# Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
# (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v5
with:
ref: ${{ inputs.ref }}
submodules: recursive
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install uv
uses: astral-sh/setup-uv@v6
- name: Install system dependencies (Ubuntu)
if: runner.os == 'Linux'
run: |
sudo apt-get update
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
patchelf
# Install Intel MKL for DiskANN
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
source /opt/intel/oneapi/setvars.sh
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
- name: Install system dependencies (macOS)
if: runner.os == 'macOS'
run: |
# Don't install LLVM, use system clang for better compatibility
brew install libomp boost protobuf zeromq
- name: Install build dependencies
run: |
uv pip install --system scikit-build-core numpy swig Cython pybind11
if [[ "$RUNNER_OS" == "Linux" ]]; then
uv pip install --system auditwheel
else
uv pip install --system delocate
fi
- name: Set macOS environment variables
if: runner.os == 'macOS'
run: |
# Use brew --prefix to automatically detect Homebrew installation path
HOMEBREW_PREFIX=$(brew --prefix)
echo "HOMEBREW_PREFIX=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
echo "OpenMP_ROOT=${HOMEBREW_PREFIX}/opt/libomp" >> $GITHUB_ENV
# Set CMAKE_PREFIX_PATH to let CMake find all packages automatically
echo "CMAKE_PREFIX_PATH=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
# Set compiler flags for OpenMP (required for both backends)
echo "LDFLAGS=-L${HOMEBREW_PREFIX}/opt/libomp/lib" >> $GITHUB_ENV
echo "CPPFLAGS=-I${HOMEBREW_PREFIX}/opt/libomp/include" >> $GITHUB_ENV
- name: Build packages
run: |
# Build core (platform independent)
cd packages/leann-core
uv build
cd ../..
# Build HNSW backend
cd packages/leann-backend-hnsw
if [[ "${{ matrix.os }}" == macos-* ]]; then
# Use system clang for better compatibility
export CC=clang
export CXX=clang++
# Homebrew libraries on each macOS version require matching minimum version
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
export MACOSX_DEPLOYMENT_TARGET=13.0
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
export MACOSX_DEPLOYMENT_TARGET=14.0
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
export MACOSX_DEPLOYMENT_TARGET=15.0
fi
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
else
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
fi
cd ../..
# Build DiskANN backend
cd packages/leann-backend-diskann
if [[ "${{ matrix.os }}" == macos-* ]]; then
# Use system clang for better compatibility
export CC=clang
export CXX=clang++
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
# But Homebrew libraries on each macOS version require matching minimum version
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
export MACOSX_DEPLOYMENT_TARGET=13.3
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
export MACOSX_DEPLOYMENT_TARGET=14.0
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
export MACOSX_DEPLOYMENT_TARGET=15.0
fi
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
else
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
fi
cd ../..
# Build meta package (platform independent)
cd packages/leann
uv build
cd ../..
- name: Repair wheels (Linux)
if: runner.os == 'Linux'
run: |
# Repair HNSW wheel
cd packages/leann-backend-hnsw
if [ -d dist ]; then
auditwheel repair dist/*.whl -w dist_repaired
rm -rf dist
mv dist_repaired dist
fi
cd ../..
# Repair DiskANN wheel
cd packages/leann-backend-diskann
if [ -d dist ]; then
auditwheel repair dist/*.whl -w dist_repaired
rm -rf dist
mv dist_repaired dist
fi
cd ../..
- name: Repair wheels (macOS)
if: runner.os == 'macOS'
run: |
# Determine deployment target based on runner OS
# Must match the Homebrew libraries for each macOS version
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
HNSW_TARGET="13.0"
DISKANN_TARGET="13.3"
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
HNSW_TARGET="14.0"
DISKANN_TARGET="14.0"
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
HNSW_TARGET="15.0"
DISKANN_TARGET="15.0"
fi
# Repair HNSW wheel
cd packages/leann-backend-hnsw
if [ -d dist ]; then
export MACOSX_DEPLOYMENT_TARGET=$HNSW_TARGET
delocate-wheel -w dist_repaired -v --require-target-macos-version $HNSW_TARGET dist/*.whl
rm -rf dist
mv dist_repaired dist
fi
cd ../..
# Repair DiskANN wheel
cd packages/leann-backend-diskann
if [ -d dist ]; then
export MACOSX_DEPLOYMENT_TARGET=$DISKANN_TARGET
delocate-wheel -w dist_repaired -v --require-target-macos-version $DISKANN_TARGET dist/*.whl
rm -rf dist
mv dist_repaired dist
fi
cd ../..
- name: List built packages
run: |
echo "📦 Built packages:"
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
- name: Install built packages for testing
run: |
# Create a virtual environment with the correct Python version
uv venv --python ${{ matrix.python }}
source .venv/bin/activate || source .venv/Scripts/activate
# Install packages using --find-links to prioritize local builds
uv pip install --find-links packages/leann-core/dist --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist packages/leann-core/dist/*.whl || uv pip install --find-links packages/leann-core/dist packages/leann-core/dist/*.tar.gz
uv pip install --find-links packages/leann-core/dist packages/leann-backend-hnsw/dist/*.whl
uv pip install --find-links packages/leann-core/dist packages/leann-backend-diskann/dist/*.whl
uv pip install packages/leann/dist/*.whl || uv pip install packages/leann/dist/*.tar.gz
# Install test dependencies using extras
uv pip install -e ".[test]"
- name: Run tests with pytest
env:
CI: true
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
HF_HUB_DISABLE_SYMLINKS: 1
TOKENIZERS_PARALLELISM: false
PYTORCH_ENABLE_MPS_FALLBACK: 0
OMP_NUM_THREADS: 1
MKL_NUM_THREADS: 1
run: |
source .venv/bin/activate || source .venv/Scripts/activate
pytest tests/ -v --tb=short
- name: Run sanity checks (optional)
run: |
# Activate virtual environment
source .venv/bin/activate || source .venv/Scripts/activate
# Run distance function tests if available
if [ -f test/sanity_checks/test_distance_functions.py ]; then
echo "Running distance function sanity checks..."
python test/sanity_checks/test_distance_functions.py || echo "⚠️ Distance function test failed, continuing..."
fi
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: packages-${{ matrix.os }}-py${{ matrix.python }}
path: packages/*/dist/
arch-smoke:
name: Arch Linux smoke test (install & import)
needs: build
runs-on: ubuntu-latest
container:
image: archlinux:latest
steps:
- name: Prepare system
run: |
pacman -Syu --noconfirm
pacman -S --noconfirm python python-pip gcc git zlib openssl
- name: Download ALL wheel artifacts from this run
uses: actions/download-artifact@v5
with:
# Don't specify name, download all artifacts
path: ./wheels
- name: Install uv
uses: astral-sh/setup-uv@v6
- name: Create virtual environment and install wheels
run: |
uv venv
source .venv/bin/activate || source .venv/Scripts/activate
uv pip install --find-links wheels leann-core
uv pip install --find-links wheels leann-backend-hnsw
uv pip install --find-links wheels leann-backend-diskann
uv pip install --find-links wheels leann
- name: Import & tiny runtime check
env:
OMP_NUM_THREADS: 1
MKL_NUM_THREADS: 1
run: |
source .venv/bin/activate || source .venv/Scripts/activate
python - <<'PY'
import leann
import leann_backend_hnsw as h
import leann_backend_diskann as d
from leann import LeannBuilder, LeannSearcher
b = LeannBuilder(backend_name="hnsw")
b.add_text("hello arch")
b.build_index("arch_demo.leann")
s = LeannSearcher("arch_demo.leann")
print("search:", s.search("hello", top_k=1))
PY

19
.github/workflows/link-check.yml vendored Normal file
View File

@@ -0,0 +1,19 @@
name: Link Check
on:
push:
branches: [ main, master ]
pull_request:
schedule:
- cron: "0 3 * * 1"
jobs:
link-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: lycheeverse/lychee-action@v2
with:
args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

129
.github/workflows/release-manual.yml vendored Normal file
View File

@@ -0,0 +1,129 @@
name: Release
on:
workflow_dispatch:
inputs:
version:
description: 'Version to release (e.g., 0.1.2)'
required: true
type: string
jobs:
update-version:
name: Update Version
runs-on: ubuntu-latest
permissions:
contents: write
outputs:
commit-sha: ${{ steps.push.outputs.commit-sha }}
steps:
- uses: actions/checkout@v4
- name: Validate version
run: |
# Remove 'v' prefix if present for validation
VERSION_CLEAN="${{ inputs.version }}"
VERSION_CLEAN="${VERSION_CLEAN#v}"
if ! [[ "$VERSION_CLEAN" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo "❌ Invalid version format. Expected format: X.Y.Z or vX.Y.Z"
exit 1
fi
echo "✅ Version format valid: ${{ inputs.version }}"
- name: Update versions and push
id: push
run: |
# Check current version
CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
echo "Current version: $CURRENT_VERSION"
echo "Target version: ${{ inputs.version }}"
if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
COMMIT_SHA=$(git rev-parse HEAD)
else
./scripts/bump_version.sh ${{ inputs.version }}
git config user.name "GitHub Actions"
git config user.email "actions@github.com"
git add packages/*/pyproject.toml
git commit -m "chore: release v${{ inputs.version }}"
git push origin main
COMMIT_SHA=$(git rev-parse HEAD)
echo "✅ Pushed version update: $COMMIT_SHA"
fi
echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
build-packages:
name: Build packages
needs: update-version
uses: ./.github/workflows/build-reusable.yml
with:
ref: 'main'
publish:
name: Publish and Release
needs: [update-version, build-packages]
if: always() && needs.update-version.result == 'success' && needs.build-packages.result == 'success'
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- uses: actions/checkout@v4
with:
ref: 'main'
- name: Download all artifacts
uses: actions/download-artifact@v4
with:
path: dist-artifacts
- name: Collect packages
run: |
mkdir -p dist
find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;
echo "📦 Packages to publish:"
ls -la dist/
- name: Publish to PyPI
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
if [ -z "$TWINE_PASSWORD" ]; then
echo "❌ PYPI_API_TOKEN not configured!"
exit 1
fi
pip install twine
twine upload dist/* --skip-existing --verbose
echo "✅ Published to PyPI!"
- name: Create release
run: |
# Check if tag already exists
if git rev-parse "v${{ inputs.version }}" >/dev/null 2>&1; then
echo "⚠️ Tag v${{ inputs.version }} already exists, skipping tag creation"
else
git tag "v${{ inputs.version }}"
git push origin "v${{ inputs.version }}"
echo "✅ Created and pushed tag v${{ inputs.version }}"
fi
# Check if release already exists
if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
else
gh release create "v${{ inputs.version }}" \
--title "Release v${{ inputs.version }}" \
--notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
--latest
echo "✅ Created GitHub release v${{ inputs.version }}"
fi
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

33
.gitignore vendored
View File

@@ -8,11 +8,17 @@ demo/indices/
*pycache* *pycache*
outputs/ outputs/
*.pkl *.pkl
*.pdf
*.idx
*.map
.history/ .history/
scripts/
lm_eval.egg-info/ lm_eval.egg-info/
demo/experiment_results/**/*.json demo/experiment_results/**/*.json
*.jsonl *.jsonl
*.eml
*.emlx
*.json
!.vscode/*.json
*.sh *.sh
*.txt *.txt
!CMakeLists.txt !CMakeLists.txt
@@ -29,7 +35,15 @@ build/
nprobe_logs/ nprobe_logs/
micro/results micro/results
micro/contriever-INT8 micro/contriever-INT8
examples/data/ data/*
!data/2501.14312v1 (1).pdf
!data/2506.08276v1.pdf
!data/PrideandPrejudice.txt
!data/huawei_pangu.md
!data/ground_truth/
!data/indices/
!data/queries/
!data/.gitattributes
*.qdstrm *.qdstrm
benchmark_results/ benchmark_results/
results/ results/
@@ -42,6 +56,7 @@ embedding_comparison_results/
*.ivecs *.ivecs
*.index *.index
*.bin *.bin
*.old
read_graph read_graph
analyze_diskann_graph analyze_diskann_graph
@@ -71,3 +86,17 @@ test_indices*/
test_*.py test_*.py
!tests/** !tests/**
packages/leann-backend-diskann/third_party/DiskANN/_deps/ packages/leann-backend-diskann/third_party/DiskANN/_deps/
*.meta.json
*.passages.json
batchtest.py
tests/__pytest_cache__/
tests/__pycache__/
paru-bin/
CLAUDE.md
CLAUDE.local.md
.claude/*.local.*
.claude/local/*
benchmarks/data/

14
.gitmodules vendored
View File

@@ -1,6 +1,16 @@
[submodule "packages/leann-backend-diskann/third_party/DiskANN"] [submodule "packages/leann-backend-diskann/third_party/DiskANN"]
path = packages/leann-backend-diskann/third_party/DiskANN path = packages/leann-backend-diskann/third_party/DiskANN
url = https://github.com/yichuan520030910320/DiskANN.git url = https://github.com/yichuan-w/DiskANN.git
[submodule "packages/leann-backend-hnsw/third_party/faiss"] [submodule "packages/leann-backend-hnsw/third_party/faiss"]
path = packages/leann-backend-hnsw/third_party/faiss path = packages/leann-backend-hnsw/third_party/faiss
url = https://github.com/yichuan520030910320/faiss.git url = https://github.com/yichuan-w/faiss.git
[submodule "packages/leann-backend-hnsw/third_party/msgpack-c"]
path = packages/leann-backend-hnsw/third_party/msgpack-c
url = https://github.com/msgpack/msgpack-c.git
branch = cpp_master
[submodule "packages/leann-backend-hnsw/third_party/cppzmq"]
path = packages/leann-backend-hnsw/third_party/cppzmq
url = https://github.com/zeromq/cppzmq.git
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
path = packages/leann-backend-hnsw/third_party/libzmq
url = https://github.com/zeromq/libzmq.git

17
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,17 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- id: check-merge-conflict
- id: debug-statements
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.7 # Fixed version to match pyproject.toml
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format

View File

@@ -1,9 +1,5 @@
{ {
"recommendations": [ "recommendations": [
"llvm-vs-code-extensions.vscode-clangd", "charliermarsh.ruff",
"ms-python.python",
"ms-vscode.cmake-tools",
"vadimcn.vscode-lldb",
"eamodio.gitlens",
] ]
} }

283
.vscode/launch.json vendored
View File

@@ -1,283 +0,0 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
// new emdedder
{
"name": "New Embedder",
"type": "debugpy",
"request": "launch",
"program": "demo/main.py",
"console": "integratedTerminal",
"args": [
"--search",
"--use-original",
"--domain",
"dpr",
"--nprobe",
"5000",
"--load",
"flat",
"--embedder",
"intfloat/multilingual-e5-small"
]
}
//python /home/ubuntu/Power-RAG/faiss/demo/simple_build.py
{
"name": "main.py",
"type": "debugpy",
"request": "launch",
"program": "demo/main.py",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"--query",
"1000",
"--load",
"bm25"
]
},
{
"name": "Simple Build",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"faiss/demo/simple_build.py"
],
"env": {
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
}
},
//# Fix for Intel MKL error
//export LD_PRELOAD=/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so
//python faiss/demo/build_demo.py
{
"name": "Build Demo",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"faiss/demo/build_demo.py"
],
"env": {
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
}
},
{
"name": "DiskANN Serve",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"demo/main.py",
"--mode",
"serve",
"--engine",
"sglang",
"--load-indices",
"diskann",
"--domain",
"rpj_wiki",
"--lazy-load",
"--recompute-beighbor-embeddings",
"--port",
"8082",
"--diskann-search-memory-maximum",
"2",
"--diskann-graph",
"240",
"--search-only"
],
"env": {
"PYTHONPATH": "${workspaceFolder}/faiss_repo/build/faiss/python:$PYTHONPATH"
},
"preLaunchTask": "CMake: build",
},
{
"name": "DiskANN Serve MAC",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"demo/main.py",
"--mode",
"serve",
"--engine",
"ollama",
"--load-indices",
"diskann",
"--domain",
"rpj_wiki",
"--lazy-load",
"--recompute-beighbor-embeddings"
],
"preLaunchTask": "CMake: build",
"env": {
"KMP_DUPLICATE_LIB_OK": "TRUE",
"OMP_NUM_THREADS": "1",
"MKL_NUM_THREADS": "1",
"DYLD_INSERT_LIBRARIES": "/Users/ec2-user/Power-RAG/.venv/lib/python3.10/site-packages/torch/lib/libomp.dylib",
"KMP_BLOCKTIME": "0"
}
},
{
"name": "Python Debugger: Current File with Arguments",
"type": "debugpy",
"request": "launch",
"program": "ric/main_ric.py",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"--config-name",
"${input:configSelection}"
],
"justMyCode": false
},
//python ./demo/validate_equivalence.py sglang
{
"name": "Validate Equivalence",
"type": "debugpy",
"request": "launch",
"program": "demo/validate_equivalence.py",
"console": "integratedTerminal",
"args": [
"sglang"
],
},
//python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices flat ivf_flat
{
"name": "Retrieval Demo",
"type": "debugpy",
"request": "launch",
"program": "demo/retrieval_demo.py",
"console": "integratedTerminal",
"args": [
"--engine",
"vllm",
"--skip-embeddings",
"--domain",
"dpr",
"--load-indices",
// "flat",
"ivf_flat"
],
},
//python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices diskann --hnsw-M 64 --hnsw-efConstruction 150 --hnsw-efSearch 128 --hnsw-sq-bits 8
{
"name": "Retrieval Demo DiskANN",
"type": "debugpy",
"request": "launch",
"program": "demo/retrieval_demo.py",
"console": "integratedTerminal",
"args": [
"--engine",
"sglang",
"--skip-embeddings",
"--domain",
"dpr",
"--load-indices",
"diskann",
"--hnsw-M",
"64",
"--hnsw-efConstruction",
"150",
"--hnsw-efSearch",
"128",
"--hnsw-sq-bits",
"8"
],
},
{
"name": "Find Probe",
"type": "debugpy",
"request": "launch",
"program": "find_probe.py",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
},
{
"name": "Python: Attach",
"type": "debugpy",
"request": "attach",
"processId": "${command:pickProcess}",
"justMyCode": true
},
{
"name": "Edge RAG",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"edgerag_demo.py"
],
"env": {
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libiomp5.so /lib/x86_64-linux-gnu/libmkl_core.so /lib/x86_64-linux-gnu/libmkl_intel_lp64.so /lib/x86_64-linux-gnu/libmkl_intel_thread.so",
"MKL_NUM_THREADS": "1",
"OMP_NUM_THREADS": "1",
}
},
{
"name": "Launch Embedding Server",
"type": "debugpy",
"request": "launch",
"program": "demo/embedding_server.py",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"--domain",
"rpj_wiki",
"--zmq-port",
"5556",
]
},
{
"name": "HNSW Serve",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"demo/main.py",
"--domain",
"rpj_wiki",
"--load",
"hnsw",
"--mode",
"serve",
"--search",
"--skip-pa",
"--recompute",
"--hnsw-old"
],
"env": {
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
}
},
],
"inputs": [
{
"id": "configSelection",
"type": "pickString",
"description": "Select a configuration",
"options": [
"example_config",
"vllm_gritlm"
],
"default": "example_config"
}
],
}

57
.vscode/settings.json vendored Executable file → Normal file
View File

@@ -1,43 +1,22 @@
{ {
"python.analysis.extraPaths": [ "python.defaultInterpreterPath": ".venv/bin/python",
"./sglang_repo/python" "python.terminal.activateEnvironment": true,
], "[python]": {
"cmake.sourceDirectory": "${workspaceFolder}/DiskANN", "editor.defaultFormatter": "charliermarsh.ruff",
"cmake.configureArgs": [ "editor.formatOnSave": true,
"-DPYBIND=True", "editor.codeActionsOnSave": {
"-DUPDATE_EDITABLE_INSTALL=ON", "source.organizeImports": "explicit",
], "source.fixAll": "explicit"
"cmake.environment": {
"PATH": "/Users/ec2-user/Power-RAG/.venv/bin:${env:PATH}"
}, },
"cmake.buildDirectory": "${workspaceFolder}/build", "editor.insertSpaces": true,
"files.associations": { "editor.tabSize": 4
"*.tcc": "cpp",
"deque": "cpp",
"string": "cpp",
"unordered_map": "cpp",
"vector": "cpp",
"map": "cpp",
"unordered_set": "cpp",
"atomic": "cpp",
"inplace_vector": "cpp",
"*.ipp": "cpp",
"forward_list": "cpp",
"list": "cpp",
"any": "cpp",
"system_error": "cpp",
"__hash_table": "cpp",
"__split_buffer": "cpp",
"__tree": "cpp",
"ios": "cpp",
"set": "cpp",
"__string": "cpp",
"string_view": "cpp",
"ranges": "cpp",
"iosfwd": "cpp"
}, },
"lldb.displayFormat": "auto", "ruff.enable": true,
"lldb.showDisassembly": "auto", "files.watcherExclude": {
"lldb.dereferencePointers": true, "**/.venv/**": true,
"lldb.consoleMode": "commands", "**/__pycache__/**": true,
"**/*.egg-info/**": true,
"**/build/**": true,
"**/dist/**": true
}
} }

16
.vscode/tasks.json vendored
View File

@@ -1,16 +0,0 @@
{
"version": "2.0.0",
"tasks": [
{
"type": "cmake",
"label": "CMake: build",
"command": "build",
"targets": [
"all"
],
"group": "build",
"problemMatcher": [],
"detail": "CMake template build task"
}
]
}

View File

@@ -1,6 +1,6 @@
MIT License MIT License
Copyright (c) 2024 Rulin Shao Copyright (c) 2025 LEANN Contributors
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

924
README.md
View File

@@ -1,172 +1,704 @@
# 🚀 LEANN: A Low-Storage Vector Index <p align="center">
<img src="assets/logo-text.png" alt="LEANN Logo" width="400">
</p>
<p align="center"> <p align="center">
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+"> <img src="https://img.shields.io/badge/Python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12%20%7C%203.13-blue.svg" alt="Python Versions">
<img src="https://github.com/yichuan-w/LEANN/actions/workflows/build-and-publish.yml/badge.svg" alt="CI Status">
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License"> <img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
<img src="https://img.shields.io/badge/PRs-welcome-brightgreen.svg" alt="PRs Welcome"> <img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS%20%7C%20Windows-lightgrey" alt="Platform"> <a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q"><img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
<a href="assets/wechat_user_group.JPG" title="Join WeChat group"><img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group"></a>
</p> </p>
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
The smallest vector index in the world. RAG Everything with LEANN!
</h2>
LEANN is an innovative vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
## Why LEANN?
<p align="center"> <p align="center">
<strong>⚡ Real-time embedding computation for large-scale RAG on consumer hardware</strong> <img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
</p> </p>
> **The numbers speak for themselves:** Index 60 million text chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#-storage-comparison)
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!
📦 **Portable:** Transfer your entire knowledge base between devices (even with others) with minimal cost - your personal AI memory travels with you.
📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!
**No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
## Installation
### 📦 Prerequisites: Install uv
[Install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) first if you don't have it. Typically, you can install it with:
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```
### 🚀 Quick Install
Clone the repository to access all examples and try amazing applications,
```bash
git clone https://github.com/yichuan-w/LEANN.git leann
cd leann
```
and install LEANN from [PyPI](https://pypi.org/project/leann/) to run them immediately:
```bash
uv venv
source .venv/bin/activate
uv pip install leann
```
<!--
> Low-resource? See “Low-resource setups” in the [Configuration Guide](docs/configuration-guide.md#low-resource-setups). -->
<details>
<summary>
<strong>🔧 Build from Source (Recommended for development)</strong>
</summary>
```bash
git clone https://github.com/yichuan-w/LEANN.git leann
cd leann
git submodule update --init --recursive
```
**macOS:**
Note: DiskANN requires MacOS 13.3 or later.
```bash
brew install libomp boost protobuf zeromq pkgconf
uv sync --extra diskann
```
**Linux (Ubuntu/Debian):**
Note: On Ubuntu 20.04, you may need to build a newer Abseil and pin Protobuf (e.g., v3.20.x) for building DiskANN. See [Issue #30](https://github.com/yichuan-w/LEANN/issues/30) for a step-by-step note.
You can manually install [Intel oneAPI MKL](https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl.html) instead of `libmkl-full-dev` for DiskANN. You can also use `libopenblas-dev` for building HNSW only, by removing `--extra diskann` in the command below.
```bash
sudo apt-get update && sudo apt-get install -y \
libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
libmkl-full-dev
uv sync --extra diskann
```
**Linux (Arch Linux):**
```bash
sudo pacman -Syu && sudo pacman -S --needed base-devel cmake pkgconf git gcc \
boost boost-libs protobuf abseil-cpp libaio zeromq
# For MKL in DiskANN
sudo pacman -S --needed base-devel git
git clone https://aur.archlinux.org/paru-bin.git
cd paru-bin && makepkg -si
paru -S intel-oneapi-mkl intel-oneapi-compiler
source /opt/intel/oneapi/setvars.sh
uv sync --extra diskann
```
**Linux (RHEL / CentOS Stream / Oracle / Rocky / AlmaLinux):**
See [Issue #50](https://github.com/yichuan-w/LEANN/issues/50) for more details.
```bash
sudo dnf groupinstall -y "Development Tools"
sudo dnf install -y libomp-devel boost-devel protobuf-compiler protobuf-devel \
abseil-cpp-devel libaio-devel zeromq-devel pkgconf-pkg-config
# For MKL in DiskANN
sudo dnf install -y intel-oneapi-mkl intel-oneapi-mkl-devel \
intel-oneapi-openmp || sudo dnf install -y intel-oneapi-compiler
source /opt/intel/oneapi/setvars.sh
uv sync --extra diskann
```
</details>
## Quick Start
Our declarative API makes RAG as easy as writing a config file.
Check out [demo.ipynb](demo.ipynb) or [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
```python
from leann import LeannBuilder, LeannSearcher, LeannChat
from pathlib import Path
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
# Build an index
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
builder.add_text("Tung Tung Tung Sahur called—they need their bananacrocodile hybrid back")
builder.build_index(INDEX_PATH)
# Search
searcher = LeannSearcher(INDEX_PATH)
results = searcher.search("fantastical AI-generated creatures", top_k=1)
# Chat with your data
chat = LeannChat(INDEX_PATH, llm_config={"type": "hf", "model": "Qwen/Qwen3-0.6B"})
response = chat.ask("How much storage does LEANN save?", top_k=1)
```
## RAG on Everything!
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
### Generation Model Setup
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
<details>
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
Set your OpenAI API key as an environment variable:
```bash
export OPENAI_API_KEY="your-api-key-here"
```
</details>
<details>
<summary><strong>🔧 Ollama Setup (Recommended for full privacy)</strong></summary>
**macOS:**
First, [download Ollama for macOS](https://ollama.com/download/mac).
```bash
# Pull a lightweight model (recommended for consumer hardware)
ollama pull llama3.2:1b
```
**Linux:**
```bash
# Install Ollama
curl -fsSL https://ollama.ai/install.sh | sh
# Start Ollama service manually
ollama serve &
# Pull a lightweight model (recommended for consumer hardware)
ollama pull llama3.2:1b
```
</details>
## ⭐ Flexible Configuration
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
📚 **Need configuration best practices?** Check our [Configuration Guide](docs/configuration-guide.md) for detailed optimization tips, model selection advice, and solutions to common issues like slow embeddings or poor search quality.
<details>
<summary><strong>📋 Click to expand: Common Parameters (Available in All Examples)</strong></summary>
All RAG examples share these common parameters. **Interactive mode** is available in all examples - simply run without `--query` to start a continuous Q&A session where you can ask multiple questions. Type 'quit' to exit.
```bash
# Core Parameters (General preprocessing for all examples)
--index-dir DIR # Directory to store the index (default: current directory)
--query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively
--max-items N # Limit data preprocessing (default: -1, process all data)
--force-rebuild # Force rebuild index even if it exists
# Embedding Parameters
--embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small, mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
# LLM Parameters (Text generation models)
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
# Search Parameters
--top-k N # Number of results to retrieve (default: 20)
--search-complexity N # Search complexity for graph traversal (default: 32)
# Chunking Parameters
--chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat)
--chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source)
# Index Building Parameters
--backend-name NAME # Backend to use: hnsw or diskann (default: hnsw)
--graph-degree N # Graph degree for index construction (default: 32)
--build-complexity N # Build complexity for index construction (default: 64)
--compact / --no-compact # Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.
--recompute / --no-recompute # Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.
```
</details>
### 📄 Personal Data Manager: Process Any Documents (`.pdf`, `.txt`, `.md`)!
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
<p align="center"> <p align="center">
<a href="#-quick-start">Quick Start</a> • <img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
<a href="#-features">Features</a> •
<a href="#-benchmarks">Benchmarks</a> •
<a href="#-documentation">Documentation</a> •
<a href="#-paper">Paper</a>
</p> </p>
--- The example below asks a question about summarizing our paper (uses default data in `data/`, which is a directory with diverse data sources: two papers, Pride and Prejudice, and a Technical report about LLM in Huawei in Chinese), and this is the **easiest example** to run here:
## 🌟 What is Leann? ```bash
source .venv/bin/activate # Don't forget to activate the virtual environment
python -m apps.document_rag --query "What are the main techniques LEANN explores?"
```
**Leann** revolutionizes Retrieval-Augmented Generation (RAG) by eliminating the storage bottleneck of traditional vector databases. Instead of pre-computing and storing billions of embeddings, Leann dynamically computes embeddings at query time using highly optimized graph-based search algorithms. <details>
<summary><strong>📋 Click to expand: Document-Specific Arguments</strong></summary>
### 🎯 Why Leann? #### Parameters
```bash
--data-dir DIR # Directory containing documents to process (default: data)
--file-types .ext .ext # Filter by specific file types (optional - all LlamaIndex supported types if omitted)
```
Traditional RAG systems face a fundamental trade-off: #### Example Commands
- **💾 Storage**: Storing embeddings for millions of documents requires massive disk space ```bash
- **🔄 Freshness**: Pre-computed embeddings become stale when documents change # Process all documents with larger chunks for academic papers
- **💰 Cost**: Vector databases are expensive to scale python -m apps.document_rag --data-dir "~/Documents/Papers" --chunk-size 1024
**Leann solves this by:** # Filter only markdown and Python files with smaller chunks
-**Zero embedding storage** - Only graph structure is persisted python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
-**Real-time computation** - Embeddings computed on-demand with ms latency
-**Memory efficient** - Runs on consumer hardware (8GB RAM)
-**Always fresh** - No stale embeddings, ever
## 🚀 Quick Start # Enable AST-aware chunking for code files
python -m apps.document_rag --enable-code-chunking --data-dir "./my_project"
# Or use the specialized code RAG for better code understanding
python -m apps.code_rag --repo-dir "./my_codebase" --query "How does authentication work?"
```
</details>
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
> **Note:** The examples below currently support macOS only. Windows support coming soon.
<p align="center">
<img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
</p>
Before running the example below, you need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
```bash
python -m apps.email_rag --query "What's the food I ordered by DoorDash or Uber Eats mostly?"
```
**780K email chunks → 78MB storage.** Finally, search your email like you search Google.
<details>
<summary><strong>📋 Click to expand: Email-Specific Arguments</strong></summary>
#### Parameters
```bash
--mail-path PATH # Path to specific mail directory (auto-detects if omitted)
--include-html # Include HTML content in processing (useful for newsletters)
```
#### Example Commands
```bash
# Search work emails from a specific account
python -m apps.email_rag --mail-path "~/Library/Mail/V10/WORK_ACCOUNT"
# Find all receipts and order confirmations (includes HTML)
python -m apps.email_rag --query "receipt order confirmation invoice" --include-html
```
</details>
<details>
<summary><strong>📋 Click to expand: Example queries you can try</strong></summary>
Once the index is built, you can ask questions like:
- "Find emails from my boss about deadlines"
- "What did John say about the project timeline?"
- "Show me emails about travel expenses"
</details>
### 🔍 Time Machine for the Web: RAG Your Entire Chrome Browser History!
<p align="center">
<img src="videos/google_clear.gif" alt="LEANN Browser History Search Demo" width="600">
</p>
```bash
python -m apps.browser_rag --query "Tell me my browser history about machine learning?"
```
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
<details>
<summary><strong>📋 Click to expand: Browser-Specific Arguments</strong></summary>
#### Parameters
```bash
--chrome-profile PATH # Path to Chrome profile directory (auto-detects if omitted)
```
#### Example Commands
```bash
# Search academic research from your browsing history
python -m apps.browser_rag --query "arxiv papers machine learning transformer architecture"
# Track competitor analysis across work profile
python -m apps.browser_rag --chrome-profile "~/Library/Application Support/Google/Chrome/Work Profile" --max-items 5000
```
</details>
<details>
<summary><strong>📋 Click to expand: How to find your Chrome profile</strong></summary>
The default Chrome profile path is configured for a typical macOS setup. If you need to find your specific Chrome profile:
1. Open Terminal
2. Run: `ls ~/Library/Application\ Support/Google/Chrome/`
3. Look for folders like "Default", "Profile 1", "Profile 2", etc.
4. Use the full path as your `--chrome-profile` argument
**Common Chrome profile locations:**
- macOS: `~/Library/Application Support/Google/Chrome/Default`
- Linux: `~/.config/google-chrome/Default`
</details>
<details>
<summary><strong>💬 Click to expand: Example queries you can try</strong></summary>
Once the index is built, you can ask questions like:
- "What websites did I visit about machine learning?"
- "Find my search history about programming"
- "What YouTube videos did I watch recently?"
- "Show me websites I visited about travel planning"
</details>
### 💬 WeChat Detective: Unlock Your Golden Memories!
<p align="center">
<img src="videos/wechat_clear.gif" alt="LEANN WeChat Search Demo" width="600">
</p>
```bash
python -m apps.wechat_rag --query "Show me all group chats about weekend plans"
```
**400K messages → 64MB storage** Search years of chat history in any language.
<details>
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
First, you need to install the [WeChat exporter](https://github.com/sunnyyoung/WeChatTweak-CLI),
```bash
brew install sunnyyoung/repo/wechattweak-cli
```
or install it manually (if you have issues with Homebrew):
```bash
sudo packages/wechat-exporter/wechattweak-cli install
```
**Troubleshooting:**
- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
- **Export errors**: If you encounter the error below, try restarting WeChat
```bash
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
Failed to find or export WeChat data. Exiting.
```
</details>
<details>
<summary><strong>📋 Click to expand: WeChat-Specific Arguments</strong></summary>
#### Parameters
```bash
--export-dir DIR # Directory to store exported WeChat data (default: wechat_export_direct)
--force-export # Force re-export even if data exists
```
#### Example Commands
```bash
# Search for travel plans discussed in group chats
python -m apps.wechat_rag --query "travel plans" --max-items 10000
# Re-export and search recent chats (useful after new messages)
python -m apps.wechat_rag --force-export --query "work schedule"
```
</details>
<details>
<summary><strong>💬 Click to expand: Example queries you can try</strong></summary>
Once the index is built, you can ask questions like:
- "我想买魔术师约翰逊的球衣,给我一些对应聊天记录?" (Chinese: Show me chat records about buying Magic Johnson's jersey)
</details>
### 🚀 Claude Code Integration: Transform Your Development Workflow!
<details>
<summary><strong>NEW!! ASTAware Code Chunking</strong></summary>
LEANN features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript, improving code understanding compared to text-based chunking.
📖 Read the [AST Chunking Guide →](docs/ast_chunking_guide.md)
</details>
**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
**Key features:**
- 🔍 **Semantic code search** across your entire project, fully local index and lightweight
- 🧠 **AST-aware chunking** preserves code structure (functions, classes)
- 📚 **Context-aware assistance** for debugging and development
- 🚀 **Zero-config setup** with automatic language detection
```bash
# Install LEANN globally for MCP integration
uv tool install leann-core --with leann
claude mcp add --scope user leann-server -- leann_mcp
# Setup is automatic - just start using Claude Code!
```
Try our fully agentic pipeline with auto query rewriting, semantic search planning, and more:
![LEANN MCP Integration](assets/mcp_leann.png)
**🔥 Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
## 🖥️ Command Line Interface
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
### Installation ### Installation
If you followed the Quick Start, `leann` is already installed in your virtual environment:
```bash ```bash
git clone git@github.com:yichuan520030910320/LEANN-RAG.git leann source .venv/bin/activate
cd leann leann --help
git submodule update --init --recursive
uv sync
``` ```
### 30-Second Example **To make it globally available:**
```bash
# Install the LEANN CLI globally using uv tool
uv tool install leann-core --with leann
# Now you can use leann from anywhere without activating venv
leann --help
```
> **Note**: Global installation is required for Claude Code integration. The `leann_mcp` server depends on the globally available `leann` command.
### Usage Examples
```bash
# build from a specific directory, and my_docs is the index name(Here you can also build from multiple dict or multiple files)
leann build my-docs --docs ./your_documents
# Search your documents
leann search my-docs "machine learning concepts"
# Interactive chat with your documents
leann ask my-docs --interactive
# List all your indexes
leann list
# Remove an index
leann remove my-docs
```
**Key CLI features:**
- Auto-detects document formats (PDF, TXT, MD, DOCX, PPTX + code files)
- **🧠 AST-aware chunking** for Python, Java, C#, TypeScript files
- Smart text chunking with overlap for all other content
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
- Organized index storage in `.leann/indexes/` (project-local)
- Support for advanced search parameters
<details>
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
You can use `leann --help`, or `leann build --help`, `leann search --help`, `leann ask --help`, `leann list --help`, `leann remove --help` to get the complete CLI reference.
**Build Command:**
```bash
leann build INDEX_NAME --docs DIRECTORY|FILE [DIRECTORY|FILE ...] [OPTIONS]
Options:
--backend {hnsw,diskann} Backend to use (default: hnsw)
--embedding-model MODEL Embedding model (default: facebook/contriever)
--graph-degree N Graph degree (default: 32)
--complexity N Build complexity (default: 64)
--force Force rebuild existing index
--compact / --no-compact Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.
--recompute / --no-recompute Enable recomputation (default: true)
```
**Search Command:**
```bash
leann search INDEX_NAME QUERY [OPTIONS]
Options:
--top-k N Number of results (default: 5)
--complexity N Search complexity (default: 64)
--recompute / --no-recompute Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.
--pruning-strategy {global,local,proportional}
```
**Ask Command:**
```bash
leann ask INDEX_NAME [OPTIONS]
Options:
--llm {ollama,openai,hf} LLM provider (default: ollama)
--model MODEL Model name (default: qwen3:8b)
--interactive Interactive chat mode
--top-k N Retrieval count (default: 20)
```
**List Command:**
```bash
leann list
# Lists all indexes across all projects with status indicators:
# ✅ - Index is complete and ready to use
# ❌ - Index is incomplete or corrupted
# 📁 - CLI-created index (in .leann/indexes/)
# 📄 - App-created index (*.leann.meta.json files)
```
**Remove Command:**
```bash
leann remove INDEX_NAME [OPTIONS]
Options:
--force, -f Force removal without confirmation
# Smart removal: automatically finds and safely removes indexes
# - Shows all matching indexes across projects
# - Requires confirmation for cross-project removal
# - Interactive selection when multiple matches found
# - Supports both CLI and app-created indexes
```
</details>
## 🚀 Advanced Features
### 🎯 Metadata Filtering
LEANN supports a simple metadata filtering system to enable sophisticated use cases like document filtering by date/type, code search by file extension, and content management based on custom criteria.
```python ```python
from leann.api import LeannBuilder, LeannSearcher # Add metadata during indexing
builder.add_text(
"def authenticate_user(token): ...",
metadata={"file_extension": ".py", "lines_of_code": 25}
)
# 1. Build index (no embeddings stored!) # Search with filters
builder = LeannBuilder(backend_name="diskann") results = searcher.search(
builder.add_text("Python is a powerful programming language") query="authentication function",
builder.add_text("Machine learning transforms industries") metadata_filters={
builder.add_text("Neural networks process complex data") "file_extension": {"==": ".py"},
builder.build_index("knowledge.leann") "lines_of_code": {"<": 100}
}
# 2. Search with real-time embeddings )
searcher = LeannSearcher("knowledge.leann")
results = searcher.search("programming languages", top_k=2)
for result in results:
print(f"Score: {result['score']:.3f} - {result['text']}")
``` ```
### Run the Demo **Supported operators**: `==`, `!=`, `<`, `<=`, `>`, `>=`, `in`, `not_in`, `contains`, `starts_with`, `ends_with`, `is_true`, `is_false`
📖 **[Complete Metadata filtering guide →](docs/metadata_filtering.md)**
## 🏗️ Architecture & How It Works
<p align="center">
<img src="assets/arch.png" alt="LEANN Architecture" width="800">
</p>
**The magic:** Most vector DBs store every single embedding (expensive). LEANN stores a pruned graph structure (cheap) and recomputes embeddings only when needed (fast).
**Core techniques:**
- **Graph-based selective recomputation:** Only compute embeddings for nodes in the search path
- **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
**Backends:**
- **HNSW** (default): Ideal for most datasets with maximum storage savings through full recomputation
- **DiskANN**: Advanced option with superior search performance, using PQ-based graph traversal with real-time reranking for the best speed-accuracy trade-off
## Benchmarks
**[DiskANN vs HNSW Performance Comparison →](benchmarks/diskann_vs_hnsw_speed_comparison.py)** - Compare search performance between both backends
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)** - See storage savings in action
### 📊 Storage Comparison
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|--------|-------------|------------|-------------|--------------|---------------|
| Traditional vector database (e.g., FAISS) | 3.8 GB | 201 GB | 1.8 GB | 2.4 GB | 130 MB |
| LEANN | 324 MB | 6 GB | 64 MB | 79 MB | 6.4 MB |
| Savings| 91% | 97% | 97% | 97% | 95% |
## Reproduce Our Results
```bash ```bash
uv run examples/document_search.py uv pip install -e ".[dev]" # Install dev dependencies
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
python benchmarks/run_evaluation.py benchmarks/data/indices/rpj_wiki/rpj_wiki --num-queries 2000 # After downloading data, you can run the benchmark with our biggest index
``` ```
**PDF RAG Demo (using LlamaIndex for document parsing and Leann for indexing/search)** The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
This demo showcases how to build a RAG system for PDF documents using Leann.
1. Place your PDF files (and other supported formats like .docx, .pptx, .xlsx) into the `examples/data/` directory.
2. Ensure you have an `OPENAI_API_KEY` set in your environment variables or in a `.env` file for the LLM to function.
```bash
uv run examples/main_cli_example.py
```
## ✨ Features
### 🔥 Core Features
- **📊 Multiple Distance Functions**: L2, Cosine, MIPS (Maximum Inner Product Search)
- **🏗️ Pluggable Backends**: DiskANN, HNSW/FAISS with unified API
- **🔄 Real-time Embeddings**: Dynamic computation using optimized ZMQ servers
- **📈 Scalable Architecture**: Handles millions of documents on consumer hardware
- **🎯 Graph Pruning**: Advanced techniques for memory-efficient search
### 🛠️ Technical Highlights
- **Zero-copy operations** for maximum performance
- **SIMD-optimized** distance computations (AVX2/AVX512)
- **Async embedding pipeline** with batched processing
- **Memory-mapped indices** for fast startup
- **Recompute mode** for highest accuracy scenarios
### 🎨 Developer Experience
- **Simple Python API** - Get started in minutes
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment
- **Rich debugging tools** - Built-in performance profiling
## 📊 Benchmarks
### Memory Usage Comparison
| System | 1M Documents | 10M Documents | 100M Documents |
|--------|-------------|---------------|----------------|
| Traditional Vector DB | 3.1 GB | 31 GB | 310 GB |
| **Leann** | **180 MB** | **1.2 GB** | **8.4 GB** |
| **Reduction** | **94.2%** | **96.1%** | **97.3%** |
### Query Performance
| Backend | Index Size | Query Time | Recall@10 |
|---------|------------|------------|-----------|
| DiskANN | 1M docs | 12ms | 0.95 |
| DiskANN + Recompute | 1M docs | 145ms | 0.98 |
| HNSW | 1M docs | 8ms | 0.93 |
*Benchmarks run on AMD Ryzen 7 with 32GB RAM*
## 🏗️ Architecture
```
┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐
│ Query Text │───▶│ Embedding │───▶│ Graph-based │
│ │ │ Computation │ │ Search │
└─────────────────┘ └──────────────────┘ └─────────────────┘
│ │
▼ ▼
┌──────────────┐ ┌──────────────┐
│ ZMQ Server │ │ Pruned Graph │
│ (Cached) │ │ Index │
└──────────────┘ └──────────────┘
```
### Key Components
1. **🧠 Embedding Engine**: Real-time transformer inference with caching
2. **📊 Graph Index**: Memory-efficient navigation structures
3. **🔄 Search Coordinator**: Orchestrates embedding + graph search
4. **⚡ Backend Adapters**: Pluggable algorithm implementations
## 🎓 Supported Models & Backends
### 🤖 Embedding Models
- **sentence-transformers/all-mpnet-base-v2** (default)
- **sentence-transformers/all-MiniLM-L6-v2** (lightweight)
- Any HuggingFace sentence-transformer model
- Custom model support via API
### 🔧 Search Backends
- **DiskANN**: Microsoft's billion-scale ANN algorithm
- **HNSW**: Hierarchical Navigable Small World graphs
- **Coming soon**: ScaNN, Faiss-IVF, NGT
### 📏 Distance Functions
- **L2**: Euclidean distance for precise similarity
- **Cosine**: Angular similarity for normalized vectors
- **MIPS**: Maximum Inner Product Search for recommendation systems
## 🔬 Paper ## 🔬 Paper
If you find Leann useful, please cite: If you find Leann useful, please cite:
@@ -185,110 +717,15 @@ If you find Leann useful, please cite:
} }
``` ```
## 🌍 Use Cases ## ✨ [Detailed Features →](docs/features.md)
### 💼 Enterprise RAG ## 🤝 [CONTRIBUTING →](docs/CONTRIBUTING.md)
```python
# Handle millions of documents with limited resources
builder = LeannBuilder(
backend_name="diskann",
distance_metric="cosine",
graph_degree=64,
memory_budget="4GB"
)
```
### 🔬 Research & Experimentation
```python
# Quick prototyping with different algorithms
for backend in ["diskann", "hnsw"]:
searcher = LeannSearcher(index_path, backend=backend)
evaluate_recall(searcher, queries, ground_truth)
```
### 🚀 Real-time Applications
```python
# Sub-second response times
chat = LeannChat("knowledge.leann")
response = chat.ask("What is quantum computing?")
# Returns in <100ms with recompute mode
```
## 🤝 Contributing
We welcome contributions! Leann is built by the community, for the community.
### Ways to Contribute
- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results
### Development Setup
```bash
git clone https://github.com/yourname/leann
cd leann
uv sync --dev
uv run pytest tests/
```
### Quick Tests
```bash
# Sanity check all distance functions
uv run python tests/sanity_checks/test_distance_functions.py
# Verify L2 implementation
uv run python tests/sanity_checks/test_l2_verification.py
```
## ❓ FAQ
### Common Issues
#### NCCL Topology Error
**Problem**: You encounter `ncclTopoComputePaths` error during document processing:
```
ncclTopoComputePaths (system=<optimized out>, comm=comm@entry=0x5555a82fa3c0) at graph/paths.cc:688
```
**Solution**: Set these environment variables before running your script:
```bash
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=INIT,GRAPH
export NCCL_IB_DISABLE=1
export NCCL_NET_PLUGIN=none
export NCCL_SOCKET_IFNAME=ens5
## 📈 Roadmap ## ❓ [FAQ →](docs/faq.md)
### 🎯 Q1 2024
- [x] DiskANN backend with MIPS/L2/Cosine support
- [x] HNSW backend integration
- [x] Real-time embedding pipeline
- [x] Memory-efficient graph pruning
### 🚀 Q2 2024 ## 📈 [Roadmap →](docs/roadmap.md)
- [ ] Distributed search across multiple nodes
- [ ] ScaNN backend support
- [ ] Advanced caching strategies
- [ ] Kubernetes deployment guides
### 🌟 Q3 2024
- [ ] GPU-accelerated embedding computation
- [ ] Approximate distance functions
- [ ] Integration with LangChain/LlamaIndex
- [ ] Visual similarity search
## 💬 Community
Join our growing community of researchers and engineers!
- 🐦 **Twitter**: [@LeannAI](https://twitter.com/LeannAI)
- 💬 **Discord**: [Join our server](https://discord.gg/leann)
- 📧 **Email**: leann@yourcompany.com
- 🐙 **GitHub Discussions**: [Ask questions here](https://github.com/yourname/leann/discussions)
## 📄 License ## 📄 License
@@ -296,13 +733,18 @@ MIT License - see [LICENSE](LICENSE) for details.
## 🙏 Acknowledgments ## 🙏 Acknowledgments
- **Microsoft Research** for the DiskANN algorithm Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
- **Meta AI** for FAISS and optimization insights
- **HuggingFace** for the transformer ecosystem
- **Our amazing contributors** who make this possible
--- Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan)
We welcome more contributors! Feel free to open issues or submit PRs.
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=yichuan-w/LEANN&type=Date)](https://www.star-history.com/#yichuan-w/LEANN&Date)
<p align="center"> <p align="center">
<strong>⭐ Star us on GitHub if Leann is useful for your research or applications!</strong> <strong>⭐ Star us on GitHub if Leann is useful for your research or applications!</strong>
</p> </p>

342
apps/base_rag_example.py Normal file
View File

@@ -0,0 +1,342 @@
"""
Base class for unified RAG examples interface.
Provides common parameters and functionality for all RAG examples.
"""
import argparse
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
import dotenv
from leann.api import LeannBuilder, LeannChat
from leann.registry import register_project_directory
dotenv.load_dotenv()
class BaseRAGExample(ABC):
"""Base class for all RAG examples with unified interface."""
def __init__(
self,
name: str,
description: str,
default_index_name: str,
):
self.name = name
self.description = description
self.default_index_name = default_index_name
self.parser = self._create_parser()
def _create_parser(self) -> argparse.ArgumentParser:
"""Create argument parser with common parameters."""
parser = argparse.ArgumentParser(
description=self.description, formatter_class=argparse.RawDescriptionHelpFormatter
)
# Core parameters (all examples share these)
core_group = parser.add_argument_group("Core Parameters")
core_group.add_argument(
"--index-dir",
type=str,
default=f"./{self.default_index_name}",
help=f"Directory to store the index (default: ./{self.default_index_name})",
)
core_group.add_argument(
"--query",
type=str,
default=None,
help="Query to run (if not provided, will run in interactive mode)",
)
# Allow subclasses to override default max_items
max_items_default = getattr(self, "max_items_default", -1)
core_group.add_argument(
"--max-items",
type=int,
default=max_items_default,
help="Maximum number of items to process -1 for all, means index all documents, and you should set it to a reasonable number if you have a large dataset and try at the first time)",
)
core_group.add_argument(
"--force-rebuild", action="store_true", help="Force rebuild index even if it exists"
)
# Embedding parameters
embedding_group = parser.add_argument_group("Embedding Parameters")
# Allow subclasses to override default embedding_model
embedding_model_default = getattr(self, "embedding_model_default", "facebook/contriever")
embedding_group.add_argument(
"--embedding-model",
type=str,
default=embedding_model_default,
help=f"Embedding model to use (default: {embedding_model_default}), we provide facebook/contriever, text-embedding-3-small,mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text",
)
embedding_group.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
)
# LLM parameters
llm_group = parser.add_argument_group("LLM Parameters")
llm_group.add_argument(
"--llm",
type=str,
default="openai",
choices=["openai", "ollama", "hf", "simulated"],
help="LLM backend: openai, ollama, or hf (default: openai)",
)
llm_group.add_argument(
"--llm-model",
type=str,
default=None,
help="Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct",
)
llm_group.add_argument(
"--llm-host",
type=str,
default="http://localhost:11434",
help="Host for Ollama API (default: http://localhost:11434)",
)
llm_group.add_argument(
"--thinking-budget",
type=str,
choices=["low", "medium", "high"],
default=None,
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
)
# AST Chunking parameters
ast_group = parser.add_argument_group("AST Chunking Parameters")
ast_group.add_argument(
"--use-ast-chunking",
action="store_true",
help="Enable AST-aware chunking for code files (requires astchunk)",
)
ast_group.add_argument(
"--ast-chunk-size",
type=int,
default=512,
help="Maximum characters per AST chunk (default: 512)",
)
ast_group.add_argument(
"--ast-chunk-overlap",
type=int,
default=64,
help="Overlap between AST chunks (default: 64)",
)
ast_group.add_argument(
"--code-file-extensions",
nargs="+",
default=None,
help="Additional code file extensions to process with AST chunking (e.g., .py .java .cs .ts)",
)
ast_group.add_argument(
"--ast-fallback-traditional",
action="store_true",
default=True,
help="Fall back to traditional chunking if AST chunking fails (default: True)",
)
# Search parameters
search_group = parser.add_argument_group("Search Parameters")
search_group.add_argument(
"--top-k", type=int, default=20, help="Number of results to retrieve (default: 20)"
)
search_group.add_argument(
"--search-complexity",
type=int,
default=32,
help="Search complexity for graph traversal (default: 64)",
)
# Index building parameters
index_group = parser.add_argument_group("Index Building Parameters")
index_group.add_argument(
"--backend-name",
type=str,
default="hnsw",
choices=["hnsw", "diskann"],
help="Backend to use for index (default: hnsw)",
)
index_group.add_argument(
"--graph-degree",
type=int,
default=32,
help="Graph degree for index construction (default: 32)",
)
index_group.add_argument(
"--build-complexity",
type=int,
default=64,
help="Build complexity for index construction (default: 64)",
)
index_group.add_argument(
"--no-compact",
action="store_true",
help="Disable compact index storage",
)
index_group.add_argument(
"--no-recompute",
action="store_true",
help="Disable embedding recomputation",
)
# Add source-specific parameters
self._add_specific_arguments(parser)
return parser
@abstractmethod
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
"""Add source-specific arguments. Override in subclasses."""
pass
@abstractmethod
async def load_data(self, args) -> list[str]:
"""Load data from the source. Returns list of text chunks."""
pass
def get_llm_config(self, args) -> dict[str, Any]:
"""Get LLM configuration based on arguments."""
config = {"type": args.llm}
if args.llm == "openai":
config["model"] = args.llm_model or "gpt-4o"
elif args.llm == "ollama":
config["model"] = args.llm_model or "llama3.2:1b"
config["host"] = args.llm_host
elif args.llm == "hf":
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
elif args.llm == "simulated":
# Simulated LLM doesn't need additional configuration
pass
return config
async def build_index(self, args, texts: list[str]) -> str:
"""Build LEANN index from texts."""
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
print(f"\n[Building Index] Creating {self.name} index...")
print(f"Total text chunks: {len(texts)}")
builder = LeannBuilder(
backend_name=args.backend_name,
embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
graph_degree=args.graph_degree,
complexity=args.build_complexity,
is_compact=not args.no_compact,
is_recompute=not args.no_recompute,
num_threads=1, # Force single-threaded mode
)
# Add texts in batches for better progress tracking
batch_size = 1000
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
for text in batch:
builder.add_text(text)
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
print("Building index structure...")
builder.build_index(index_path)
print(f"Index saved to: {index_path}")
# Register project directory so leann list can discover this index
# The index is saved as args.index_dir/index_name.leann
# We want to register the current working directory where the app is run
register_project_directory(Path.cwd())
return index_path
async def run_interactive_chat(self, args, index_path: str):
"""Run interactive chat with the index."""
chat = LeannChat(
index_path,
llm_config=self.get_llm_config(args),
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
complexity=args.search_complexity,
)
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
print("Type 'quit' or 'exit' to stop.\n")
while True:
try:
query = input("You: ").strip()
if query.lower() in ["quit", "exit", "q"]:
print("Goodbye!")
break
if not query:
continue
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if hasattr(args, "thinking_budget") and args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask(
query,
top_k=args.top_k,
complexity=args.search_complexity,
llm_kwargs=llm_kwargs,
)
print(f"\nAssistant: {response}\n")
except KeyboardInterrupt:
print("\nGoodbye!")
break
except Exception as e:
print(f"Error: {e}")
async def run_single_query(self, args, index_path: str, query: str):
"""Run a single query against the index."""
chat = LeannChat(
index_path,
llm_config=self.get_llm_config(args),
complexity=args.search_complexity,
)
print(f"\n[Query]: \033[36m{query}\033[0m")
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if hasattr(args, "thinking_budget") and args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask(
query, top_k=args.top_k, complexity=args.search_complexity, llm_kwargs=llm_kwargs
)
print(f"\n[Response]: \033[36m{response}\033[0m")
async def run(self):
"""Main entry point for the example."""
args = self.parser.parse_args()
# Check if index exists
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
index_exists = Path(args.index_dir).exists()
if not index_exists or args.force_rebuild:
# Load data and build index
print(f"\n{'Rebuilding' if index_exists else 'Building'} index...")
texts = await self.load_data(args)
if not texts:
print("No data found to index!")
return
index_path = await self.build_index(args, texts)
else:
print(f"\nUsing existing index in {args.index_dir}")
# Run query or interactive mode
if args.query:
await self.run_single_query(args, index_path, args.query)
else:
await self.run_interactive_chat(args, index_path)

171
apps/browser_rag.py Normal file
View File

@@ -0,0 +1,171 @@
"""
Browser History RAG example using the unified interface.
Supports Chrome browser history.
"""
import os
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from .history_data.history import ChromeHistoryReader
class BrowserRAG(BaseRAGExample):
"""RAG example for Chrome browser history."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="Browser History",
description="Process and query Chrome browser history with LEANN",
default_index_name="google_history_index",
)
def _add_specific_arguments(self, parser):
"""Add browser-specific arguments."""
browser_group = parser.add_argument_group("Browser Parameters")
browser_group.add_argument(
"--chrome-profile",
type=str,
default=None,
help="Path to Chrome profile directory (auto-detected if not specified)",
)
browser_group.add_argument(
"--auto-find-profiles",
action="store_true",
default=True,
help="Automatically find all Chrome profiles (default: True)",
)
browser_group.add_argument(
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
)
browser_group.add_argument(
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
)
def _get_chrome_base_path(self) -> Path:
"""Get the base Chrome profile path based on OS."""
if sys.platform == "darwin":
return Path.home() / "Library" / "Application Support" / "Google" / "Chrome"
elif sys.platform.startswith("linux"):
return Path.home() / ".config" / "google-chrome"
elif sys.platform == "win32":
return Path(os.environ["LOCALAPPDATA"]) / "Google" / "Chrome" / "User Data"
else:
raise ValueError(f"Unsupported platform: {sys.platform}")
def _find_chrome_profiles(self) -> list[Path]:
"""Auto-detect all Chrome profiles."""
base_path = self._get_chrome_base_path()
if not base_path.exists():
return []
profiles = []
# Check Default profile
default_profile = base_path / "Default"
if default_profile.exists() and (default_profile / "History").exists():
profiles.append(default_profile)
# Check numbered profiles
for item in base_path.iterdir():
if item.is_dir() and item.name.startswith("Profile "):
if (item / "History").exists():
profiles.append(item)
return profiles
async def load_data(self, args) -> list[str]:
"""Load browser history and convert to text chunks."""
# Determine Chrome profiles
if args.chrome_profile and not args.auto_find_profiles:
profile_dirs = [Path(args.chrome_profile)]
else:
print("Auto-detecting Chrome profiles...")
profile_dirs = self._find_chrome_profiles()
# If specific profile given, filter to just that one
if args.chrome_profile:
profile_path = Path(args.chrome_profile)
profile_dirs = [p for p in profile_dirs if p == profile_path]
if not profile_dirs:
print("No Chrome profiles found!")
print("Please specify --chrome-profile manually")
return []
print(f"Found {len(profile_dirs)} Chrome profiles")
# Create reader
reader = ChromeHistoryReader()
# Process each profile
all_documents = []
total_processed = 0
for i, profile_dir in enumerate(profile_dirs):
print(f"\nProcessing profile {i + 1}/{len(profile_dirs)}: {profile_dir.name}")
try:
# Apply max_items limit per profile
max_per_profile = -1
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_profile = remaining
# Load history
documents = reader.load_data(
chrome_profile_path=str(profile_dir),
max_count=max_per_profile,
)
if documents:
all_documents.extend(documents)
total_processed += len(documents)
print(f"Processed {len(documents)} history entries from this profile")
except Exception as e:
print(f"Error processing {profile_dir}: {e}")
continue
if not all_documents:
print("No browser history found to process!")
return []
print(f"\nTotal history entries processed: {len(all_documents)}")
# Convert to text chunks
all_texts = create_text_chunks(
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for browser history RAG
print("\n🌐 Browser History RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What websites did I visit about machine learning?'")
print("- 'Find my search history about programming'")
print("- 'What YouTube videos did I watch recently?'")
print("- 'Show me websites about travel planning'")
print("\nNote: Make sure Chrome is closed before running\n")
rag = BrowserRAG()
asyncio.run(rag.run())

22
apps/chunking/__init__.py Normal file
View File

@@ -0,0 +1,22 @@
"""
Chunking utilities for LEANN RAG applications.
Provides AST-aware and traditional text chunking functionality.
"""
from .utils import (
CODE_EXTENSIONS,
create_ast_chunks,
create_text_chunks,
create_traditional_chunks,
detect_code_files,
get_language_from_extension,
)
__all__ = [
"CODE_EXTENSIONS",
"create_ast_chunks",
"create_text_chunks",
"create_traditional_chunks",
"detect_code_files",
"get_language_from_extension",
]

320
apps/chunking/utils.py Normal file
View File

@@ -0,0 +1,320 @@
"""
Enhanced chunking utilities with AST-aware code chunking support.
Provides unified interface for both traditional and AST-based text chunking.
"""
import logging
from pathlib import Path
from typing import Optional
from llama_index.core.node_parser import SentenceSplitter
logger = logging.getLogger(__name__)
# Code file extensions supported by astchunk
CODE_EXTENSIONS = {
".py": "python",
".java": "java",
".cs": "csharp",
".ts": "typescript",
".tsx": "typescript",
".js": "typescript",
".jsx": "typescript",
}
# Default chunk parameters for different content types
DEFAULT_CHUNK_PARAMS = {
"code": {
"max_chunk_size": 512,
"chunk_overlap": 64,
},
"text": {
"chunk_size": 256,
"chunk_overlap": 128,
},
}
def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
"""
Separate documents into code files and regular text files.
Args:
documents: List of LlamaIndex Document objects
code_extensions: Dict mapping file extensions to languages (defaults to CODE_EXTENSIONS)
Returns:
Tuple of (code_documents, text_documents)
"""
if code_extensions is None:
code_extensions = CODE_EXTENSIONS
code_docs = []
text_docs = []
for doc in documents:
# Get file path from metadata
file_path = doc.metadata.get("file_path", "")
if not file_path:
# Fallback to file_name
file_path = doc.metadata.get("file_name", "")
if file_path:
file_ext = Path(file_path).suffix.lower()
if file_ext in code_extensions:
# Add language info to metadata
doc.metadata["language"] = code_extensions[file_ext]
doc.metadata["is_code"] = True
code_docs.append(doc)
else:
doc.metadata["is_code"] = False
text_docs.append(doc)
else:
# If no file path, treat as text
doc.metadata["is_code"] = False
text_docs.append(doc)
logger.info(f"Detected {len(code_docs)} code files and {len(text_docs)} text files")
return code_docs, text_docs
def get_language_from_extension(file_path: str) -> Optional[str]:
"""Get the programming language from file extension."""
ext = Path(file_path).suffix.lower()
return CODE_EXTENSIONS.get(ext)
def create_ast_chunks(
documents,
max_chunk_size: int = 512,
chunk_overlap: int = 64,
metadata_template: str = "default",
) -> list[str]:
"""
Create AST-aware chunks from code documents using astchunk.
Args:
documents: List of code documents
max_chunk_size: Maximum characters per chunk
chunk_overlap: Number of AST nodes to overlap between chunks
metadata_template: Template for chunk metadata
Returns:
List of text chunks with preserved code structure
"""
try:
from astchunk import ASTChunkBuilder
except ImportError as e:
logger.error(f"astchunk not available: {e}")
logger.info("Falling back to traditional chunking for code files")
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
all_chunks = []
for doc in documents:
# Get language from metadata (set by detect_code_files)
language = doc.metadata.get("language")
if not language:
logger.warning(
"No language detected for document, falling back to traditional chunking"
)
traditional_chunks = create_traditional_chunks([doc], max_chunk_size, chunk_overlap)
all_chunks.extend(traditional_chunks)
continue
try:
# Configure astchunk
configs = {
"max_chunk_size": max_chunk_size,
"language": language,
"metadata_template": metadata_template,
"chunk_overlap": chunk_overlap if chunk_overlap > 0 else 0,
}
# Add repository-level metadata if available
repo_metadata = {
"file_path": doc.metadata.get("file_path", ""),
"file_name": doc.metadata.get("file_name", ""),
"creation_date": doc.metadata.get("creation_date", ""),
"last_modified_date": doc.metadata.get("last_modified_date", ""),
}
configs["repo_level_metadata"] = repo_metadata
# Create chunk builder and process
chunk_builder = ASTChunkBuilder(**configs)
code_content = doc.get_content()
if not code_content or not code_content.strip():
logger.warning("Empty code content, skipping")
continue
chunks = chunk_builder.chunkify(code_content)
# Extract text content from chunks
for chunk in chunks:
if hasattr(chunk, "text"):
chunk_text = chunk.text
elif isinstance(chunk, dict) and "text" in chunk:
chunk_text = chunk["text"]
elif isinstance(chunk, str):
chunk_text = chunk
else:
# Try to convert to string
chunk_text = str(chunk)
if chunk_text and chunk_text.strip():
all_chunks.append(chunk_text.strip())
logger.info(
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
)
except Exception as e:
logger.warning(f"AST chunking failed for {language} file: {e}")
logger.info("Falling back to traditional chunking")
traditional_chunks = create_traditional_chunks([doc], max_chunk_size, chunk_overlap)
all_chunks.extend(traditional_chunks)
return all_chunks
def create_traditional_chunks(
documents, chunk_size: int = 256, chunk_overlap: int = 128
) -> list[str]:
"""
Create traditional text chunks using LlamaIndex SentenceSplitter.
Args:
documents: List of documents to chunk
chunk_size: Size of each chunk in characters
chunk_overlap: Overlap between chunks
Returns:
List of text chunks
"""
# Handle invalid chunk_size values
if chunk_size <= 0:
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
chunk_size = 256
# Ensure chunk_overlap is not negative and not larger than chunk_size
if chunk_overlap < 0:
chunk_overlap = 0
if chunk_overlap >= chunk_size:
chunk_overlap = chunk_size // 2
node_parser = SentenceSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separator=" ",
paragraph_separator="\n\n",
)
all_texts = []
for doc in documents:
try:
nodes = node_parser.get_nodes_from_documents([doc])
if nodes:
chunk_texts = [node.get_content() for node in nodes]
all_texts.extend(chunk_texts)
logger.debug(f"Created {len(chunk_texts)} traditional chunks from document")
except Exception as e:
logger.error(f"Traditional chunking failed for document: {e}")
# As last resort, add the raw content
content = doc.get_content()
if content and content.strip():
all_texts.append(content.strip())
return all_texts
def create_text_chunks(
documents,
chunk_size: int = 256,
chunk_overlap: int = 128,
use_ast_chunking: bool = False,
ast_chunk_size: int = 512,
ast_chunk_overlap: int = 64,
code_file_extensions: Optional[list[str]] = None,
ast_fallback_traditional: bool = True,
) -> list[str]:
"""
Create text chunks from documents with optional AST support for code files.
Args:
documents: List of LlamaIndex Document objects
chunk_size: Size for traditional text chunks
chunk_overlap: Overlap for traditional text chunks
use_ast_chunking: Whether to use AST chunking for code files
ast_chunk_size: Size for AST chunks
ast_chunk_overlap: Overlap for AST chunks
code_file_extensions: Custom list of code file extensions
ast_fallback_traditional: Fall back to traditional chunking on AST errors
Returns:
List of text chunks
"""
if not documents:
logger.warning("No documents provided for chunking")
return []
# Create a local copy of supported extensions for this function call
local_code_extensions = CODE_EXTENSIONS.copy()
# Update supported extensions if provided
if code_file_extensions:
# Map extensions to languages (simplified mapping)
ext_mapping = {
".py": "python",
".java": "java",
".cs": "c_sharp",
".ts": "typescript",
".tsx": "typescript",
}
for ext in code_file_extensions:
if ext.lower() not in local_code_extensions:
# Try to guess language from extension
if ext.lower() in ext_mapping:
local_code_extensions[ext.lower()] = ext_mapping[ext.lower()]
else:
logger.warning(f"Unsupported extension {ext}, will use traditional chunking")
all_chunks = []
if use_ast_chunking:
# Separate code and text documents using local extensions
code_docs, text_docs = detect_code_files(documents, local_code_extensions)
# Process code files with AST chunking
if code_docs:
logger.info(f"Processing {len(code_docs)} code files with AST chunking")
try:
ast_chunks = create_ast_chunks(
code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap
)
all_chunks.extend(ast_chunks)
logger.info(f"Created {len(ast_chunks)} AST chunks from code files")
except Exception as e:
logger.error(f"AST chunking failed: {e}")
if ast_fallback_traditional:
logger.info("Falling back to traditional chunking for code files")
traditional_code_chunks = create_traditional_chunks(
code_docs, chunk_size, chunk_overlap
)
all_chunks.extend(traditional_code_chunks)
else:
raise
# Process text files with traditional chunking
if text_docs:
logger.info(f"Processing {len(text_docs)} text files with traditional chunking")
text_chunks = create_traditional_chunks(text_docs, chunk_size, chunk_overlap)
all_chunks.extend(text_chunks)
logger.info(f"Created {len(text_chunks)} traditional chunks from text files")
else:
# Use traditional chunking for all files
logger.info(f"Processing {len(documents)} documents with traditional chunking")
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
logger.info(f"Total chunks created: {len(all_chunks)}")
return all_chunks

211
apps/code_rag.py Normal file
View File

@@ -0,0 +1,211 @@
"""
Code RAG example using AST-aware chunking for optimal code understanding.
Specialized for code repositories with automatic language detection and
optimized chunking parameters.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import CODE_EXTENSIONS, create_text_chunks
from llama_index.core import SimpleDirectoryReader
class CodeRAG(BaseRAGExample):
"""Specialized RAG example for code repositories with AST-aware chunking."""
def __init__(self):
super().__init__(
name="Code",
description="Process and query code repositories with AST-aware chunking",
default_index_name="code_index",
)
# Override defaults for code-specific usage
self.embedding_model_default = "facebook/contriever" # Good for code
self.max_items_default = -1 # Process all code files by default
def _add_specific_arguments(self, parser):
"""Add code-specific arguments."""
code_group = parser.add_argument_group("Code Repository Parameters")
code_group.add_argument(
"--repo-dir",
type=str,
default=".",
help="Code repository directory to index (default: current directory)",
)
code_group.add_argument(
"--include-extensions",
nargs="+",
default=list(CODE_EXTENSIONS.keys()),
help="File extensions to include (default: supported code extensions)",
)
code_group.add_argument(
"--exclude-dirs",
nargs="+",
default=[
".git",
"__pycache__",
"node_modules",
"venv",
".venv",
"build",
"dist",
"target",
],
help="Directories to exclude from indexing",
)
code_group.add_argument(
"--max-file-size",
type=int,
default=1000000, # 1MB
help="Maximum file size in bytes to process (default: 1MB)",
)
code_group.add_argument(
"--include-comments",
action="store_true",
help="Include comments in chunking (useful for documentation)",
)
code_group.add_argument(
"--preserve-imports",
action="store_true",
default=True,
help="Try to preserve import statements in chunks (default: True)",
)
async def load_data(self, args) -> list[str]:
"""Load code files and convert to AST-aware chunks."""
print(f"🔍 Scanning code repository: {args.repo_dir}")
print(f"📁 Including extensions: {args.include_extensions}")
print(f"🚫 Excluding directories: {args.exclude_dirs}")
# Check if repository directory exists
repo_path = Path(args.repo_dir)
if not repo_path.exists():
raise ValueError(f"Repository directory not found: {args.repo_dir}")
# Load code files with filtering
reader_kwargs = {
"recursive": True,
"encoding": "utf-8",
"required_exts": args.include_extensions,
"exclude_hidden": True,
}
# Create exclusion filter
def file_filter(file_path: str) -> bool:
"""Filter out unwanted files and directories."""
path = Path(file_path)
# Check file size
try:
if path.stat().st_size > args.max_file_size:
print(f"⚠️ Skipping large file: {path.name} ({path.stat().st_size} bytes)")
return False
except Exception:
return False
# Check if in excluded directory
for exclude_dir in args.exclude_dirs:
if exclude_dir in path.parts:
return False
return True
try:
# Load documents with file filtering
documents = SimpleDirectoryReader(
args.repo_dir,
file_extractor=None, # Use default extractors
**reader_kwargs,
).load_data(show_progress=True)
# Apply custom filtering
filtered_docs = []
for doc in documents:
file_path = doc.metadata.get("file_path", "")
if file_filter(file_path):
filtered_docs.append(doc)
documents = filtered_docs
except Exception as e:
print(f"❌ Error loading code files: {e}")
return []
if not documents:
print(
f"❌ No code files found in {args.repo_dir} with extensions {args.include_extensions}"
)
return []
print(f"✅ Loaded {len(documents)} code files")
# Show breakdown by language/extension
ext_counts = {}
for doc in documents:
file_path = doc.metadata.get("file_path", "")
if file_path:
ext = Path(file_path).suffix.lower()
ext_counts[ext] = ext_counts.get(ext, 0) + 1
print("📊 Files by extension:")
for ext, count in sorted(ext_counts.items()):
print(f" {ext}: {count} files")
# Use AST-aware chunking by default for code
print(
f"🧠 Using AST-aware chunking (chunk_size: {args.ast_chunk_size}, overlap: {args.ast_chunk_overlap})"
)
all_texts = create_text_chunks(
documents,
chunk_size=256, # Fallback for non-code files
chunk_overlap=64,
use_ast_chunking=True, # Always use AST for code RAG
ast_chunk_size=args.ast_chunk_size,
ast_chunk_overlap=args.ast_chunk_overlap,
code_file_extensions=args.include_extensions,
ast_fallback_traditional=True,
)
# Apply max_items limit if specified
if args.max_items > 0 and len(all_texts) > args.max_items:
print(f"⏳ Limiting to {args.max_items} chunks (from {len(all_texts)})")
all_texts = all_texts[: args.max_items]
print(f"✅ Generated {len(all_texts)} code chunks")
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for code RAG
print("\n💻 Code RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'How does the embedding computation work?'")
print("- 'What are the main classes in this codebase?'")
print("- 'Show me the search implementation'")
print("- 'How is error handling implemented?'")
print("- 'What design patterns are used?'")
print("- 'Explain the chunking logic'")
print("\n🚀 Features:")
print("- ✅ AST-aware chunking preserves code structure")
print("- ✅ Automatic language detection")
print("- ✅ Smart filtering of large files and common excludes")
print("- ✅ Optimized for code understanding")
print("\nUsage examples:")
print(" python -m apps.code_rag --repo-dir ./my_project")
print(
" python -m apps.code_rag --include-extensions .py .js --query 'How does authentication work?'"
)
print("\nOr run without --query for interactive mode\n")
rag = CodeRAG()
asyncio.run(rag.run())

131
apps/document_rag.py Normal file
View File

@@ -0,0 +1,131 @@
"""
Document RAG example using the unified interface.
Supports PDF, TXT, MD, and other document formats.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from llama_index.core import SimpleDirectoryReader
class DocumentRAG(BaseRAGExample):
"""RAG example for document processing (PDF, TXT, MD, etc.)."""
def __init__(self):
super().__init__(
name="Document",
description="Process and query documents (PDF, TXT, MD, etc.) with LEANN",
default_index_name="test_doc_files",
)
def _add_specific_arguments(self, parser):
"""Add document-specific arguments."""
doc_group = parser.add_argument_group("Document Parameters")
doc_group.add_argument(
"--data-dir",
type=str,
default="data",
help="Directory containing documents to index (default: data)",
)
doc_group.add_argument(
"--file-types",
nargs="+",
default=None,
help="Filter by file types (e.g., .pdf .txt .md). If not specified, all supported types are processed",
)
doc_group.add_argument(
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
)
doc_group.add_argument(
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
)
doc_group.add_argument(
"--enable-code-chunking",
action="store_true",
help="Enable AST-aware chunking for code files in the data directory",
)
async def load_data(self, args) -> list[str]:
"""Load documents and convert to text chunks."""
print(f"Loading documents from: {args.data_dir}")
if args.file_types:
print(f"Filtering by file types: {args.file_types}")
else:
print("Processing all supported file types")
# Check if data directory exists
data_path = Path(args.data_dir)
if not data_path.exists():
raise ValueError(f"Data directory not found: {args.data_dir}")
# Load documents
reader_kwargs = {
"recursive": True,
"encoding": "utf-8",
}
if args.file_types:
reader_kwargs["required_exts"] = args.file_types
documents = SimpleDirectoryReader(args.data_dir, **reader_kwargs).load_data(
show_progress=True
)
if not documents:
print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
return []
print(f"Loaded {len(documents)} documents")
# Determine chunking strategy
use_ast = args.enable_code_chunking or getattr(args, "use_ast_chunking", False)
if use_ast:
print("Using AST-aware chunking for code files")
# Convert to text chunks with optional AST support
all_texts = create_text_chunks(
documents,
chunk_size=args.chunk_size,
chunk_overlap=args.chunk_overlap,
use_ast_chunking=use_ast,
ast_chunk_size=getattr(args, "ast_chunk_size", 512),
ast_chunk_overlap=getattr(args, "ast_chunk_overlap", 64),
code_file_extensions=getattr(args, "code_file_extensions", None),
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
)
# Apply max_items limit if specified
if args.max_items > 0 and len(all_texts) > args.max_items:
print(f"Limiting to {args.max_items} chunks (from {len(all_texts)})")
all_texts = all_texts[: args.max_items]
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for document RAG
print("\n📄 Document RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What are the main techniques LEANN uses?'")
print("- 'What is the technique DLPM?'")
print("- 'Who does Elizabeth Bennet marry?'")
print(
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
)
print("\n🚀 NEW: Code-aware chunking available!")
print("- Use --enable-code-chunking to enable AST-aware chunking for code files")
print("- Supports Python, Java, C#, TypeScript files")
print("- Better semantic understanding of code structure")
print("\nOr run without --query for interactive mode\n")
rag = DocumentRAG()
asyncio.run(rag.run())

View File

@@ -0,0 +1,167 @@
import email
import os
from pathlib import Path
from typing import Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
def find_all_messages_directories(root: str | None = None) -> list[Path]:
"""
Recursively find all 'Messages' directories under the given root.
Returns a list of Path objects.
"""
if root is None:
# Auto-detect user's mail path
home_dir = os.path.expanduser("~")
root = os.path.join(home_dir, "Library", "Mail")
messages_dirs = []
for dirpath, _dirnames, _filenames in os.walk(root):
if os.path.basename(dirpath) == "Messages":
messages_dirs.append(Path(dirpath))
return messages_dirs
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader with embedded metadata.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self, include_html: bool = False) -> None:
"""
Initialize.
Args:
include_html: Whether to include HTML content in the email body (default: False)
"""
self.include_html = include_html
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: list[Document] = []
max_count = load_kwargs.get("max_count", 1000)
count = 0
total_files = 0
successful_files = 0
failed_files = 0
print(f"Starting to process directory: {input_dir}")
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
# Check if we've reached the max count (skip if max_count == -1)
if max_count > 0 and count >= max_count:
break
if filename.endswith(".emlx"):
total_files += 1
filepath = os.path.join(dirpath, filename)
try:
# Read the .emlx file
with open(filepath, encoding="utf-8", errors="ignore") as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split("\n", 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get("Subject", "No Subject")
from_addr = msg.get("From", "Unknown")
to_addr = msg.get("To", "Unknown")
date = msg.get("Date", "Unknown")
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if (
part.get_content_type() == "text/plain"
or part.get_content_type() == "text/html"
):
if (
part.get_content_type() == "text/html"
and not self.include_html
):
continue
try:
payload = part.get_payload(decode=True)
if payload:
body += payload.decode("utf-8", errors="ignore")
except Exception as e:
print(f"Error decoding payload: {e}")
continue
else:
try:
payload = msg.get_payload(decode=True)
if payload:
body = payload.decode("utf-8", errors="ignore")
except Exception as e:
print(f"Error decoding single part payload: {e}")
body = ""
# Only create document if we have some content
if body.strip() or subject != "No Subject":
# Create document content with metadata embedded in text
doc_content = f"""
[File]: {filename}
[From]: {from_addr}
[To]: {to_addr}
[Subject]: {subject}
[Date]: {date}
[EMAIL BODY Start]:
{body}
"""
# No separate metadata - everything is in the text
doc = Document(text=doc_content, metadata={})
docs.append(doc)
count += 1
successful_files += 1
# Print first few successful files for debugging
if successful_files <= 3:
print(
f"Successfully loaded: {filename} - Subject: {subject[:50]}..."
)
except Exception as e:
failed_files += 1
if failed_files <= 5: # Only print first few errors
print(f"Error parsing email from {filepath}: {e}")
continue
except Exception as e:
failed_files += 1
if failed_files <= 5: # Only print first few errors
print(f"Error reading file {filepath}: {e}")
continue
print("Processing summary:")
print(f" Total .emlx files found: {total_files}")
print(f" Successfully loaded: {successful_files}")
print(f" Failed to load: {failed_files}")
print(f" Final documents: {len(docs)}")
return docs

186
apps/email_data/email.py Normal file
View File

@@ -0,0 +1,186 @@
"""
Mbox parser.
Contains simple parser for mbox files.
"""
import logging
from pathlib import Path
from typing import Any
from fsspec import AbstractFileSystem
from llama_index.core.readers.base import BaseReader
from llama_index.core.schema import Document
logger = logging.getLogger(__name__)
class MboxReader(BaseReader):
"""
Mbox parser.
Extract messages from mailbox files.
Returns string including date, subject, sender, receiver and
content for each message.
"""
DEFAULT_MESSAGE_FORMAT: str = (
"Date: {_date}\nFrom: {_from}\nTo: {_to}\nSubject: {_subject}\nContent: {_content}"
)
def __init__(
self,
*args: Any,
max_count: int = 0,
message_format: str = DEFAULT_MESSAGE_FORMAT,
**kwargs: Any,
) -> None:
"""Init params."""
try:
from bs4 import BeautifulSoup # noqa
except ImportError:
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
super().__init__(*args, **kwargs)
self.max_count = max_count
self.message_format = message_format
def load_data(
self,
file: Path,
extra_info: dict | None = None,
fs: AbstractFileSystem | None = None,
) -> list[Document]:
"""Parse file into string."""
# Import required libraries
import mailbox
from email.parser import BytesParser
from email.policy import default
from bs4 import BeautifulSoup
if fs:
logger.warning(
"fs was specified but MboxReader doesn't support loading "
"from fsspec filesystems. Will load from local filesystem instead."
)
i = 0
results: list[str] = []
# Load file using mailbox
bytes_parser = BytesParser(policy=default).parse
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
# Iterate through all messages
for _, _msg in enumerate(mbox):
try:
msg: mailbox.mboxMessage = _msg
# Parse multipart messages
if msg.is_multipart():
for part in msg.walk():
ctype = part.get_content_type()
cdispo = str(part.get("Content-Disposition"))
if "attachment" in cdispo:
print(f"Attachment found: {part.get_filename()}")
if ctype == "text/plain" and "attachment" not in cdispo:
content = part.get_payload(decode=True) # decode
break
# Get plain message payload for non-multipart messages
else:
content = msg.get_payload(decode=True)
# Parse message HTML content and remove unneeded whitespace
soup = BeautifulSoup(content)
stripped_content = " ".join(soup.get_text().split())
# Format message to include date, sender, receiver and subject
msg_string = self.message_format.format(
_date=msg["date"],
_from=msg["from"],
_to=msg["to"],
_subject=msg["subject"],
_content=stripped_content,
)
# Add message string to results
results.append(msg_string)
except Exception as e:
logger.warning(f"Failed to parse message:\n{_msg}\n with exception {e}")
# Increment counter and return if max count is met
i += 1
if self.max_count > 0 and i >= self.max_count:
break
return [Document(text=result, metadata=extra_info or {}) for result in results]
class EmlxMboxReader(MboxReader):
"""
EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.
Extends MboxReader to work with Apple Mail's .emlx format by:
1. Reading .emlx files from a directory
2. Converting them to mbox format in memory
3. Using the parent MboxReader's parsing logic
"""
def load_data(
self,
directory: Path,
extra_info: dict | None = None,
fs: AbstractFileSystem | None = None,
) -> list[Document]:
"""Parse .emlx files from directory into strings using MboxReader logic."""
import os
import tempfile
if fs:
logger.warning(
"fs was specified but EmlxMboxReader doesn't support loading "
"from fsspec filesystems. Will load from local filesystem instead."
)
# Find all .emlx files in the directory
emlx_files = list(directory.glob("*.emlx"))
logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")
if not emlx_files:
logger.warning(f"No .emlx files found in {directory}")
return []
# Create a temporary mbox file
with tempfile.NamedTemporaryFile(mode="w", suffix=".mbox", delete=False) as temp_mbox:
temp_mbox_path = temp_mbox.name
# Convert .emlx files to mbox format
for emlx_file in emlx_files:
try:
# Read the .emlx file
with open(emlx_file, encoding="utf-8", errors="ignore") as f:
content = f.read()
# .emlx format: first line is length, rest is email content
lines = content.split("\n", 1)
if len(lines) >= 2:
email_content = lines[1] # Skip the length line
# Write to mbox format (each message starts with "From " and ends with blank line)
temp_mbox.write(f"From {emlx_file.name} {email_content}\n\n")
except Exception as e:
logger.warning(f"Failed to process {emlx_file}: {e}")
continue
# Close the temporary file so MboxReader can read it
temp_mbox.close()
try:
# Use the parent MboxReader's logic to parse the mbox file
return super().load_data(Path(temp_mbox_path), extra_info, fs)
finally:
# Clean up temporary file
try:
os.unlink(temp_mbox_path)
except OSError:
pass

157
apps/email_rag.py Normal file
View File

@@ -0,0 +1,157 @@
"""
Email RAG example using the unified interface.
Supports Apple Mail on macOS.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from .email_data.LEANN_email_reader import EmlxReader
class EmailRAG(BaseRAGExample):
"""RAG example for Apple Mail processing."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.max_items_default = -1 # Process all emails by default
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="Email",
description="Process and query Apple Mail emails with LEANN",
default_index_name="mail_index",
)
def _add_specific_arguments(self, parser):
"""Add email-specific arguments."""
email_group = parser.add_argument_group("Email Parameters")
email_group.add_argument(
"--mail-path",
type=str,
default=None,
help="Path to Apple Mail directory (auto-detected if not specified)",
)
email_group.add_argument(
"--include-html", action="store_true", help="Include HTML content in email processing"
)
email_group.add_argument(
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
)
email_group.add_argument(
"--chunk-overlap", type=int, default=25, help="Text chunk overlap (default: 25)"
)
def _find_mail_directories(self) -> list[Path]:
"""Auto-detect all Apple Mail directories."""
mail_base = Path.home() / "Library" / "Mail"
if not mail_base.exists():
return []
# Find all Messages directories
messages_dirs = []
for item in mail_base.rglob("Messages"):
if item.is_dir():
messages_dirs.append(item)
return messages_dirs
async def load_data(self, args) -> list[str]:
"""Load emails and convert to text chunks."""
# Determine mail directories
if args.mail_path:
messages_dirs = [Path(args.mail_path)]
else:
print("Auto-detecting Apple Mail directories...")
messages_dirs = self._find_mail_directories()
if not messages_dirs:
print("No Apple Mail directories found!")
print("Please specify --mail-path manually")
return []
print(f"Found {len(messages_dirs)} mail directories")
# Create reader
reader = EmlxReader(include_html=args.include_html)
# Process each directory
all_documents = []
total_processed = 0
for i, messages_dir in enumerate(messages_dirs):
print(f"\nProcessing directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
try:
# Count emlx files
emlx_files = list(messages_dir.glob("*.emlx"))
print(f"Found {len(emlx_files)} email files")
# Apply max_items limit per directory
max_per_dir = -1 # Default to process all
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_dir = remaining
# If args.max_items == -1, max_per_dir stays -1 (process all)
# Load emails - fix the parameter passing
documents = reader.load_data(
input_dir=str(messages_dir),
max_count=max_per_dir,
)
if documents:
all_documents.extend(documents)
total_processed += len(documents)
print(f"Processed {len(documents)} emails from this directory")
except Exception as e:
print(f"Error processing {messages_dir}: {e}")
continue
if not all_documents:
print("No emails found to process!")
return []
print(f"\nTotal emails processed: {len(all_documents)}")
print("now starting to split into text chunks ... take some time")
# Convert to text chunks
# Email reader uses chunk_overlap=25 as in original
all_texts = create_text_chunks(
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
return all_texts
if __name__ == "__main__":
import asyncio
# Check platform
if sys.platform != "darwin":
print("\n⚠️ Warning: This example is designed for macOS (Apple Mail)")
print(" Windows/Linux support coming soon!\n")
# Example queries for email RAG
print("\n📧 Email RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What did my boss say about deadlines?'")
print("- 'Find emails about travel expenses'")
print("- 'Show me emails from last month about the project'")
print("- 'What food did I order from DoorDash?'")
print("\nNote: You may need to grant Full Disk Access to your terminal\n")
rag = EmailRAG()
asyncio.run(rag.run())

View File

@@ -0,0 +1,3 @@
from .history import ChromeHistoryReader
__all__ = ["ChromeHistoryReader"]

View File

@@ -0,0 +1,186 @@
import os
import sqlite3
from pathlib import Path
from typing import Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
class ChromeHistoryReader(BaseReader):
"""
Chrome browser history reader that extracts browsing data from SQLite database.
Reads Chrome history from the default Chrome profile location and creates documents
with embedded metadata similar to the email reader structure.
"""
def __init__(self) -> None:
"""Initialize."""
pass
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
"""
Load Chrome history data from the default Chrome profile location.
Args:
input_dir: Not used for Chrome history (kept for compatibility)
**load_kwargs:
max_count (int): Maximum amount of history entries to read.
chrome_profile_path (str): Custom path to Chrome profile directory.
"""
docs: list[Document] = []
max_count = load_kwargs.get("max_count", 1000)
chrome_profile_path = load_kwargs.get("chrome_profile_path", None)
# Default Chrome profile path on macOS
if chrome_profile_path is None:
chrome_profile_path = os.path.expanduser(
"~/Library/Application Support/Google/Chrome/Default"
)
history_db_path = os.path.join(chrome_profile_path, "History")
if not os.path.exists(history_db_path):
print(f"Chrome history database not found at: {history_db_path}")
return docs
try:
# Connect to the Chrome history database
print(f"Connecting to database: {history_db_path}")
conn = sqlite3.connect(history_db_path)
cursor = conn.cursor()
# Query to get browsing history with metadata (removed created_time column)
query = """
SELECT
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
url,
title,
visit_count,
typed_count,
hidden
FROM urls
ORDER BY last_visit_time DESC
"""
print(f"Executing query on database: {history_db_path}")
cursor.execute(query)
rows = cursor.fetchall()
print(f"Query returned {len(rows)} rows")
count = 0
for row in rows:
if count >= max_count and max_count > 0:
break
last_visit, url, title, visit_count, typed_count, hidden = row
# Create document content with metadata embedded in text
doc_content = f"""
[Title]: {title}
[URL of the page]: {url}
[Last visited time]: {last_visit}
[Visit times]: {visit_count}
[Typed times]: {typed_count}
"""
# Create document with embedded metadata
doc = Document(text=doc_content, metadata={"title": title[0:150]})
# if len(title) > 150:
# print(f"Title is too long: {title}")
docs.append(doc)
count += 1
conn.close()
print(f"Loaded {len(docs)} Chrome history documents")
except Exception as e:
print(f"Error reading Chrome history: {e}")
# add you may need to close your browser to make the database file available
# also highlight in red
print(
"\033[91mYou may need to close your browser to make the database file available\033[0m"
)
return docs
return docs
@staticmethod
def find_chrome_profiles() -> list[Path]:
"""
Find all Chrome profile directories.
Returns:
List of Path objects pointing to Chrome profile directories
"""
chrome_base_path = Path(os.path.expanduser("~/Library/Application Support/Google/Chrome"))
profile_dirs = []
if not chrome_base_path.exists():
print(f"Chrome directory not found at: {chrome_base_path}")
return profile_dirs
# Find all profile directories
for profile_dir in chrome_base_path.iterdir():
if profile_dir.is_dir() and profile_dir.name != "System Profile":
history_path = profile_dir / "History"
if history_path.exists():
profile_dirs.append(profile_dir)
print(f"Found Chrome profile: {profile_dir}")
print(f"Found {len(profile_dirs)} Chrome profiles")
return profile_dirs
@staticmethod
def export_history_to_file(
output_file: str = "chrome_history_export.txt", max_count: int = 1000
):
"""
Export Chrome history to a text file using the same SQL query format.
Args:
output_file: Path to the output file
max_count: Maximum number of entries to export
"""
chrome_profile_path = os.path.expanduser(
"~/Library/Application Support/Google/Chrome/Default"
)
history_db_path = os.path.join(chrome_profile_path, "History")
if not os.path.exists(history_db_path):
print(f"Chrome history database not found at: {history_db_path}")
return
try:
conn = sqlite3.connect(history_db_path)
cursor = conn.cursor()
query = """
SELECT
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
url,
title,
visit_count,
typed_count,
hidden
FROM urls
ORDER BY last_visit_time DESC
LIMIT ?
"""
cursor.execute(query, (max_count,))
rows = cursor.fetchall()
with open(output_file, "w", encoding="utf-8") as f:
for row in rows:
last_visit, url, title, visit_count, typed_count, hidden = row
f.write(
f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n"
)
conn.close()
print(f"Exported {len(rows)} history entries to {output_file}")
except Exception as e:
print(f"Error exporting Chrome history: {e}")

View File

@@ -0,0 +1,774 @@
import json
import os
import re
import subprocess
import time
from datetime import datetime
from pathlib import Path
from typing import Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
class WeChatHistoryReader(BaseReader):
"""
WeChat chat history reader that extracts chat data from exported JSON files.
Reads WeChat chat history from exported JSON files (from wechat-exporter tool)
and creates documents with embedded metadata similar to the Chrome history reader structure.
Also includes utilities for automatic WeChat chat history export.
"""
def __init__(self) -> None:
"""Initialize."""
self.packages_dir = Path(__file__).parent.parent.parent / "packages"
self.wechat_exporter_dir = self.packages_dir / "wechat-exporter"
self.wechat_decipher_dir = self.packages_dir / "wechat-decipher-macos"
def check_wechat_running(self) -> bool:
"""Check if WeChat is currently running."""
try:
result = subprocess.run(["pgrep", "-f", "WeChat"], capture_output=True, text=True)
return result.returncode == 0
except Exception:
return False
def install_wechattweak(self) -> bool:
"""Install WeChatTweak CLI tool."""
try:
# Create wechat-exporter directory if it doesn't exist
self.wechat_exporter_dir.mkdir(parents=True, exist_ok=True)
wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
if not wechattweak_path.exists():
print("Downloading WeChatTweak CLI...")
subprocess.run(
[
"curl",
"-L",
"-o",
str(wechattweak_path),
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli",
],
check=True,
)
# Make executable
wechattweak_path.chmod(0o755)
# Install WeChatTweak
print("Installing WeChatTweak...")
subprocess.run(["sudo", str(wechattweak_path), "install"], check=True)
return True
except Exception as e:
print(f"Error installing WeChatTweak: {e}")
return False
def restart_wechat(self):
"""Restart WeChat to apply WeChatTweak."""
try:
print("Restarting WeChat...")
subprocess.run(["pkill", "-f", "WeChat"], check=False)
time.sleep(2)
subprocess.run(["open", "-a", "WeChat"], check=True)
time.sleep(5) # Wait for WeChat to start
except Exception as e:
print(f"Error restarting WeChat: {e}")
def check_api_available(self) -> bool:
"""Check if WeChatTweak API is available."""
try:
result = subprocess.run(
["curl", "-s", "http://localhost:48065/wechat/allcontacts"],
capture_output=True,
text=True,
timeout=5,
)
return result.returncode == 0 and result.stdout.strip()
except Exception:
return False
def _extract_readable_text(self, content: str) -> str:
"""
Extract readable text from message content, removing XML and system messages.
Args:
content: The raw message content (can be string or dict)
Returns:
Cleaned, readable text
"""
if not content:
return ""
# Handle dictionary content (like quoted messages)
if isinstance(content, dict):
# Extract text from dictionary structure
text_parts = []
if "title" in content:
text_parts.append(str(content["title"]))
if "quoted" in content:
text_parts.append(str(content["quoted"]))
if "content" in content:
text_parts.append(str(content["content"]))
if "text" in content:
text_parts.append(str(content["text"]))
if text_parts:
return " | ".join(text_parts)
else:
# If we can't extract meaningful text from dict, return empty
return ""
# Handle string content
if not isinstance(content, str):
return ""
# Remove common prefixes like "wxid_xxx:\n"
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
# If it's just XML or system message, return empty
if clean_content.strip().startswith("<") or "recalled a message" in clean_content:
return ""
return clean_content.strip()
def _is_text_message(self, content: str) -> bool:
"""
Check if a message contains readable text content.
Args:
content: The message content (can be string or dict)
Returns:
True if the message contains readable text, False otherwise
"""
if not content:
return False
# Handle dictionary content
if isinstance(content, dict):
# Check if dict has any readable text fields
text_fields = ["title", "quoted", "content", "text"]
for field in text_fields:
if content.get(field):
return True
return False
# Handle string content
if not isinstance(content, str):
return False
# Skip image messages (contain XML with img tags)
if "<img" in content and "cdnurl" in content:
return False
# Skip emoji messages (contain emoji XML tags)
if "<emoji" in content and "productid" in content:
return False
# Skip voice messages
if "<voice" in content:
return False
# Skip video messages
if "<video" in content:
return False
# Skip file messages
if "<appmsg" in content and "appid" in content:
return False
# Skip system messages (like "recalled a message")
if "recalled a message" in content:
return False
# Check if there's actual readable text (not just XML or system messages)
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
# If after cleaning we have meaningful text, consider it readable
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith("<"):
return True
return False
def _concatenate_messages(
self,
messages: list[dict],
max_length: int = 128,
time_window_minutes: int = 30,
overlap_messages: int = 0,
) -> list[dict]:
"""
Concatenate messages based on length and time rules.
Args:
messages: List of message dictionaries
max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
overlap_messages: Number of messages to overlap between consecutive groups
Returns:
List of concatenated message groups
"""
if not messages:
return []
concatenated_groups = []
current_group = []
current_length = 0
last_timestamp = None
for message in messages:
# Extract message info
content = message.get("content", "")
message_text = message.get("message", "")
create_time = message.get("createTime", 0)
message.get("fromUser", "")
message.get("toUser", "")
message.get("isSentFromSelf", False)
# Extract readable text
readable_text = self._extract_readable_text(content)
if not readable_text:
readable_text = message_text
# Skip empty messages
if not readable_text.strip():
continue
# Check time window constraint (only if time_window_minutes != -1)
if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
time_diff_minutes = (create_time - last_timestamp) / 60
if time_diff_minutes > time_window_minutes:
# Time gap too large, start new group
if current_group:
concatenated_groups.append(
{
"messages": current_group,
"total_length": current_length,
"start_time": current_group[0].get("createTime", 0),
"end_time": current_group[-1].get("createTime", 0),
}
)
# Keep last few messages for overlap
if overlap_messages > 0 and len(current_group) > overlap_messages:
current_group = current_group[-overlap_messages:]
current_length = sum(
len(
self._extract_readable_text(msg.get("content", ""))
or msg.get("message", "")
)
for msg in current_group
)
else:
current_group = []
current_length = 0
# Check length constraint (only if max_length != -1)
message_length = len(readable_text)
if max_length != -1 and current_length + message_length > max_length and current_group:
# Current group would exceed max length, save it and start new
concatenated_groups.append(
{
"messages": current_group,
"total_length": current_length,
"start_time": current_group[0].get("createTime", 0),
"end_time": current_group[-1].get("createTime", 0),
}
)
# Keep last few messages for overlap
if overlap_messages > 0 and len(current_group) > overlap_messages:
current_group = current_group[-overlap_messages:]
current_length = sum(
len(
self._extract_readable_text(msg.get("content", ""))
or msg.get("message", "")
)
for msg in current_group
)
else:
current_group = []
current_length = 0
# Add message to current group
current_group.append(message)
current_length += message_length
last_timestamp = create_time
# Add the last group if it exists
if current_group:
concatenated_groups.append(
{
"messages": current_group,
"total_length": current_length,
"start_time": current_group[0].get("createTime", 0),
"end_time": current_group[-1].get("createTime", 0),
}
)
return concatenated_groups
def _create_concatenated_content(self, message_group: dict, contact_name: str) -> str:
"""
Create concatenated content from a group of messages.
Args:
message_group: Dictionary containing messages and metadata
contact_name: Name of the contact
Returns:
Formatted concatenated content
"""
messages = message_group["messages"]
start_time = message_group["start_time"]
end_time = message_group["end_time"]
# Format timestamps
if start_time:
try:
start_timestamp = datetime.fromtimestamp(start_time)
start_time_str = start_timestamp.strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
start_time_str = str(start_time)
else:
start_time_str = "Unknown"
if end_time:
try:
end_timestamp = datetime.fromtimestamp(end_time)
end_time_str = end_timestamp.strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
end_time_str = str(end_time)
else:
end_time_str = "Unknown"
# Build concatenated message content
message_parts = []
for message in messages:
content = message.get("content", "")
message_text = message.get("message", "")
create_time = message.get("createTime", 0)
is_sent_from_self = message.get("isSentFromSelf", False)
# Extract readable text
readable_text = self._extract_readable_text(content)
if not readable_text:
readable_text = message_text
# Format individual message
if create_time:
try:
timestamp = datetime.fromtimestamp(create_time)
# change to YYYY-MM-DD HH:MM:SS
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
time_str = str(create_time)
else:
time_str = "Unknown"
sender = "[Me]" if is_sent_from_self else "[Contact]"
message_parts.append(f"({time_str}) {sender}: {readable_text}")
concatenated_text = "\n".join(message_parts)
# Create final document content
doc_content = f"""
Contact: {contact_name}
Time Range: {start_time_str} - {end_time_str}
Messages ({len(messages)} messages, {message_group["total_length"]} chars):
{concatenated_text}
"""
# TODO @yichuan give better format and rich info here!
doc_content = f"""
{concatenated_text}
"""
return doc_content, contact_name
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
"""
Load WeChat chat history data from exported JSON files.
Args:
input_dir: Directory containing exported WeChat JSON files
**load_kwargs:
max_count (int): Maximum amount of chat entries to read.
wechat_export_dir (str): Custom path to WeChat export directory.
include_non_text (bool): Whether to include non-text messages (images, emojis, etc.)
concatenate_messages (bool): Whether to concatenate messages based on length rules.
max_length (int): Maximum length for concatenated message groups (default: 1000).
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
"""
docs: list[Document] = []
max_count = load_kwargs.get("max_count", 1000)
wechat_export_dir = load_kwargs.get("wechat_export_dir", None)
include_non_text = load_kwargs.get("include_non_text", False)
concatenate_messages = load_kwargs.get("concatenate_messages", False)
max_length = load_kwargs.get("max_length", 1000)
time_window_minutes = load_kwargs.get("time_window_minutes", 30)
# Default WeChat export path
if wechat_export_dir is None:
wechat_export_dir = "./wechat_export_test"
if not os.path.exists(wechat_export_dir):
print(f"WeChat export directory not found at: {wechat_export_dir}")
return docs
try:
# Find all JSON files in the export directory
json_files = list(Path(wechat_export_dir).glob("*.json"))
print(f"Found {len(json_files)} WeChat chat history files")
count = 0
for json_file in json_files:
if count >= max_count and max_count > 0:
break
try:
with open(json_file, encoding="utf-8") as f:
chat_data = json.load(f)
# Extract contact name from filename
contact_name = json_file.stem
if concatenate_messages:
# Filter messages to only include readable text messages
readable_messages = []
for message in chat_data:
try:
content = message.get("content", "")
if not include_non_text and not self._is_text_message(content):
continue
readable_text = self._extract_readable_text(content)
if not readable_text and not include_non_text:
continue
readable_messages.append(message)
except Exception as e:
print(f"Error processing message in {json_file}: {e}")
continue
# Concatenate messages based on rules
message_groups = self._concatenate_messages(
readable_messages,
max_length=max_length,
time_window_minutes=time_window_minutes,
overlap_messages=0, # No overlap between groups
)
# Create documents from concatenated groups
for message_group in message_groups:
if count >= max_count and max_count > 0:
break
doc_content, contact_name = self._create_concatenated_content(
message_group, contact_name
)
doc = Document(
text=doc_content,
metadata={"contact_name": contact_name},
)
docs.append(doc)
count += 1
print(
f"Created {len(message_groups)} concatenated message groups for {contact_name}"
)
else:
# Original single-message processing
for message in chat_data:
if count >= max_count and max_count > 0:
break
# Extract message information
message.get("fromUser", "")
message.get("toUser", "")
content = message.get("content", "")
message_text = message.get("message", "")
create_time = message.get("createTime", 0)
is_sent_from_self = message.get("isSentFromSelf", False)
# Handle content that might be dict or string
try:
# Check if this is a readable text message
if not include_non_text and not self._is_text_message(content):
continue
# Extract readable text
readable_text = self._extract_readable_text(content)
if not readable_text and not include_non_text:
continue
except Exception as e:
# Skip messages that cause processing errors
print(f"Error processing message in {json_file}: {e}")
continue
# Convert timestamp to readable format
if create_time:
try:
timestamp = datetime.fromtimestamp(create_time)
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
time_str = str(create_time)
else:
time_str = "Unknown"
# Create document content with metadata header and contact info
doc_content = f"""
Contact: {contact_name}
Is sent from self: {is_sent_from_self}
Time: {time_str}
Message: {readable_text if readable_text else message_text}
"""
# Create document with embedded metadata
doc = Document(
text=doc_content, metadata={"contact_name": contact_name}
)
docs.append(doc)
count += 1
except Exception as e:
print(f"Error reading {json_file}: {e}")
continue
print(f"Loaded {len(docs)} WeChat chat documents")
except Exception as e:
print(f"Error reading WeChat history: {e}")
return docs
return docs
@staticmethod
def find_wechat_export_dirs() -> list[Path]:
"""
Find all WeChat export directories.
Returns:
List of Path objects pointing to WeChat export directories
"""
export_dirs = []
# Look for common export directory names
possible_dirs = [
Path("./wechat_export"),
Path("./wechat_export_direct"),
Path("./wechat_chat_history"),
Path("./chat_export"),
]
for export_dir in possible_dirs:
if export_dir.exists() and export_dir.is_dir():
json_files = list(export_dir.glob("*.json"))
if json_files:
export_dirs.append(export_dir)
print(
f"Found WeChat export directory: {export_dir} with {len(json_files)} files"
)
print(f"Found {len(export_dirs)} WeChat export directories")
return export_dirs
@staticmethod
def export_chat_to_file(
output_file: str = "wechat_chat_export.txt",
max_count: int = 1000,
export_dir: str | None = None,
include_non_text: bool = False,
):
"""
Export WeChat chat history to a text file.
Args:
output_file: Path to the output file
max_count: Maximum number of entries to export
export_dir: Directory containing WeChat JSON files
include_non_text: Whether to include non-text messages
"""
if export_dir is None:
export_dir = "./wechat_export_test"
if not os.path.exists(export_dir):
print(f"WeChat export directory not found at: {export_dir}")
return
try:
json_files = list(Path(export_dir).glob("*.json"))
with open(output_file, "w", encoding="utf-8") as f:
count = 0
for json_file in json_files:
if count >= max_count and max_count > 0:
break
try:
with open(json_file, encoding="utf-8") as json_f:
chat_data = json.load(json_f)
contact_name = json_file.stem
f.write(f"\n=== Chat with {contact_name} ===\n")
for message in chat_data:
if count >= max_count and max_count > 0:
break
from_user = message.get("fromUser", "")
content = message.get("content", "")
message_text = message.get("message", "")
create_time = message.get("createTime", 0)
# Skip non-text messages unless requested
if not include_non_text:
reader = WeChatHistoryReader()
if not reader._is_text_message(content):
continue
readable_text = reader._extract_readable_text(content)
if not readable_text:
continue
message_text = readable_text
if create_time:
try:
timestamp = datetime.fromtimestamp(create_time)
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
time_str = str(create_time)
else:
time_str = "Unknown"
f.write(f"[{time_str}] {from_user}: {message_text}\n")
count += 1
except Exception as e:
print(f"Error processing {json_file}: {e}")
continue
print(f"Exported {count} chat entries to {output_file}")
except Exception as e:
print(f"Error exporting WeChat chat history: {e}")
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Path | None:
"""
Export WeChat chat history using wechat-exporter tool.
Args:
export_dir: Directory to save exported chat history
Returns:
Path to export directory if successful, None otherwise
"""
try:
import subprocess
import sys
# Create export directory
export_path = Path(export_dir)
export_path.mkdir(exist_ok=True)
print(f"Exporting WeChat chat history to {export_path}...")
# Check if wechat-exporter directory exists
if not self.wechat_exporter_dir.exists():
print(f"wechat-exporter directory not found at: {self.wechat_exporter_dir}")
return None
# Install requirements if needed
requirements_file = self.wechat_exporter_dir / "requirements.txt"
if requirements_file.exists():
print("Installing wechat-exporter requirements...")
subprocess.run(["uv", "pip", "install", "-r", str(requirements_file)], check=True)
# Run the export command
print("Running wechat-exporter...")
result = subprocess.run(
[
sys.executable,
str(self.wechat_exporter_dir / "main.py"),
"export-all",
str(export_path),
],
capture_output=True,
text=True,
check=True,
)
print("Export command output:")
print(result.stdout)
if result.stderr:
print("Export errors:")
print(result.stderr)
# Check if export was successful
if export_path.exists() and any(export_path.glob("*.json")):
json_files = list(export_path.glob("*.json"))
print(
f"Successfully exported {len(json_files)} chat history files to {export_path}"
)
return export_path
else:
print("Export completed but no JSON files found")
return None
except subprocess.CalledProcessError as e:
print(f"Export command failed: {e}")
print(f"Command output: {e.stdout}")
print(f"Command errors: {e.stderr}")
return None
except Exception as e:
print(f"Export failed: {e}")
print("Please ensure WeChat is running and WeChatTweak is installed.")
return None
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> list[Path]:
"""
Find existing WeChat exports or create new ones.
Args:
export_dir: Directory to save exported chat history if needed
Returns:
List of Path objects pointing to WeChat export directories
"""
export_dirs = []
# Look for existing exports in common locations
possible_export_dirs = [
Path("./wechat_database_export"),
Path("./wechat_export_test"),
Path("./wechat_export"),
Path("./wechat_export_direct"),
Path("./wechat_chat_history"),
Path("./chat_export"),
]
for export_dir_path in possible_export_dirs:
if export_dir_path.exists() and any(export_dir_path.glob("*.json")):
export_dirs.append(export_dir_path)
print(f"Found existing export: {export_dir_path}")
# If no existing exports, try to export automatically
if not export_dirs:
print("No existing WeChat exports found. Starting direct export...")
# Try to export using wechat-exporter
exported_path = self.export_wechat_chat_history(export_dir)
if exported_path:
export_dirs = [exported_path]
else:
print(
"Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed."
)
return export_dirs

189
apps/wechat_rag.py Normal file
View File

@@ -0,0 +1,189 @@
"""
WeChat History RAG example using the unified interface.
Supports WeChat chat history export and search.
"""
import subprocess
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from .history_data.wechat_history import WeChatHistoryReader
class WeChatRAG(BaseRAGExample):
"""RAG example for WeChat chat history."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.max_items_default = -1 # Match original default
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="WeChat History",
description="Process and query WeChat chat history with LEANN",
default_index_name="wechat_history_magic_test_11Debug_new",
)
def _add_specific_arguments(self, parser):
"""Add WeChat-specific arguments."""
wechat_group = parser.add_argument_group("WeChat Parameters")
wechat_group.add_argument(
"--export-dir",
type=str,
default="./wechat_export",
help="Directory to store WeChat exports (default: ./wechat_export)",
)
wechat_group.add_argument(
"--force-export",
action="store_true",
help="Force re-export of WeChat data even if exports exist",
)
wechat_group.add_argument(
"--chunk-size", type=int, default=192, help="Text chunk size (default: 192)"
)
wechat_group.add_argument(
"--chunk-overlap", type=int, default=64, help="Text chunk overlap (default: 64)"
)
def _export_wechat_data(self, export_dir: Path) -> bool:
"""Export WeChat data using wechattweak-cli."""
print("Exporting WeChat data...")
# Check if WeChat is running
try:
result = subprocess.run(["pgrep", "WeChat"], capture_output=True, text=True)
if result.returncode != 0:
print("WeChat is not running. Please start WeChat first.")
return False
except Exception:
pass # pgrep might not be available on all systems
# Create export directory
export_dir.mkdir(parents=True, exist_ok=True)
# Run export command
cmd = ["packages/wechat-exporter/wechattweak-cli", "export", str(export_dir)]
try:
print(f"Running: {' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode == 0:
print("WeChat data exported successfully!")
return True
else:
print(f"Export failed: {result.stderr}")
return False
except FileNotFoundError:
print("\nError: wechattweak-cli not found!")
print("Please install it first:")
print(" sudo packages/wechat-exporter/wechattweak-cli install")
return False
except Exception as e:
print(f"Export error: {e}")
return False
async def load_data(self, args) -> list[str]:
"""Load WeChat history and convert to text chunks."""
# Initialize WeChat reader with export capabilities
reader = WeChatHistoryReader()
# Find existing exports or create new ones using the centralized method
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
if not export_dirs:
print("Failed to find or export WeChat data. Trying to find any existing exports...")
# Try to find any existing exports in common locations
export_dirs = reader.find_wechat_export_dirs()
if not export_dirs:
print("No WeChat data found. Please ensure WeChat exports exist.")
return []
# Load documents from all found export directories
all_documents = []
total_processed = 0
for i, export_dir in enumerate(export_dirs):
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
try:
# Apply max_items limit per export
max_per_export = -1
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_export = remaining
documents = reader.load_data(
wechat_export_dir=str(export_dir),
max_count=max_per_export,
concatenate_messages=True, # Enable message concatenation for better context
)
if documents:
print(f"Loaded {len(documents)} chat documents from {export_dir}")
all_documents.extend(documents)
total_processed += len(documents)
else:
print(f"No documents loaded from {export_dir}")
except Exception as e:
print(f"Error processing {export_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
return []
print(f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports")
print("now starting to split into text chunks ... take some time")
# Convert to text chunks with contact information
all_texts = []
for doc in all_documents:
# Split the document into chunks
from llama_index.core.node_parser import SentenceSplitter
text_splitter = SentenceSplitter(
chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
# Add contact information to each chunk
contact_name = doc.metadata.get("contact_name", "Unknown")
text = f"[Contact] means the message is from: {contact_name}\n" + node.get_content()
all_texts.append(text)
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
return all_texts
if __name__ == "__main__":
import asyncio
# Check platform
if sys.platform != "darwin":
print("\n⚠️ Warning: WeChat export is only supported on macOS")
print(" You can still query existing exports on other platforms\n")
# Example queries for WeChat RAG
print("\n💬 WeChat History RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'Show me conversations about travel plans'")
print("- 'Find group chats about weekend activities'")
print("- '我想买魔术师约翰逊的球衣,给我一些对应聊天记录?'")
print("- 'What did we discuss about the project last month?'")
print("\nNote: WeChat must be running for export to work\n")
rag = WeChatRAG()
asyncio.run(rag.run())

BIN
assets/arch.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 78 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 73 KiB

BIN
assets/effects.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 339 KiB

BIN
assets/logo-text.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 818 KiB

BIN
assets/logo.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 276 KiB

BIN
assets/mcp_leann.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 224 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 152 KiB

View File

@@ -1,9 +1,24 @@
# 🧪 Leann Sanity Checks # 🧪 LEANN Benchmarks & Testing
This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations. This directory contains performance benchmarks and comprehensive tests for the LEANN system, including backend comparisons and sanity checks across different configurations.
## 📁 Test Files ## 📁 Test Files
### `diskann_vs_hnsw_speed_comparison.py`
Performance comparison between DiskANN and HNSW backends:
- ✅ **Search latency** comparison with both backends using recompute
- ✅ **Index size** and **build time** measurements
- ✅ **Score validity** testing (ensures no -inf scores)
- ✅ **Configurable dataset sizes** for different scales
```bash
# Quick comparison with 500 docs, 10 queries
python benchmarks/diskann_vs_hnsw_speed_comparison.py
# Large-scale comparison with 2000 docs, 20 queries
python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20
```
### `test_distance_functions.py` ### `test_distance_functions.py`
Tests all supported distance functions across DiskANN backend: Tests all supported distance functions across DiskANN backend:
- ✅ **MIPS** (Maximum Inner Product Search) - ✅ **MIPS** (Maximum Inner Product Search)

View File

@@ -0,0 +1,141 @@
import time
import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
from mlx_lm import load
from sentence_transformers import SentenceTransformer
# --- Configuration ---
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
BATCH_SIZES = [1, 8, 16, 32, 64, 128]
NUM_RUNS = 10 # Number of runs to average for each batch size
WARMUP_RUNS = 2 # Number of warm-up runs
# --- Generate Dummy Data ---
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
# --- Benchmark Functions ---b
def benchmark_torch(model, sentences):
start_time = time.time()
model.encode(sentences, convert_to_numpy=True)
end_time = time.time()
return (end_time - start_time) * 1000 # Return time in ms
def benchmark_mlx(model, tokenizer, sentences):
start_time = time.time()
# Tokenize sentences using MLX tokenizer
tokens = []
for sentence in sentences:
token_ids = tokenizer.encode(sentence)
tokens.append(token_ids)
# Pad sequences to the same length
max_len = max(len(t) for t in tokens)
input_ids = []
attention_mask = []
for token_seq in tokens:
# Pad sequence
padded = token_seq + [tokenizer.eos_token_id] * (max_len - len(token_seq))
input_ids.append(padded)
# Create attention mask (1 for real tokens, 0 for padding)
mask = [1] * len(token_seq) + [0] * (max_len - len(token_seq))
attention_mask.append(mask)
# Convert to MLX arrays
input_ids = mx.array(input_ids)
attention_mask = mx.array(attention_mask)
# Get embeddings
embeddings = model(input_ids)
# Mean pooling
mask = mx.expand_dims(attention_mask, -1)
sum_embeddings = (embeddings * mask).sum(axis=1)
sum_mask = mask.sum(axis=1)
_ = sum_embeddings / sum_mask
mx.eval() # Ensure computation is finished
end_time = time.time()
return (end_time - start_time) * 1000 # Return time in ms
# --- Main Execution ---
def main():
print("--- Initializing Models ---")
# Load PyTorch model
print(f"Loading PyTorch model: {MODEL_NAME_TORCH}")
device = "mps" if torch.backends.mps.is_available() else "cpu"
model_torch = SentenceTransformer(MODEL_NAME_TORCH, device=device)
print(f"PyTorch model loaded on: {device}")
# Load MLX model
print(f"Loading MLX model: {MODEL_NAME_MLX}")
model_mlx, tokenizer_mlx = load(MODEL_NAME_MLX)
print("MLX model loaded.")
# --- Warm-up ---
print("\n--- Performing Warm-up Runs ---")
for _ in range(WARMUP_RUNS):
benchmark_torch(model_torch, DUMMY_SENTENCES[:1])
benchmark_mlx(model_mlx, tokenizer_mlx, DUMMY_SENTENCES[:1])
print("Warm-up complete.")
# --- Benchmarking ---
print("\n--- Starting Benchmark ---")
results_torch = []
results_mlx = []
for batch_size in BATCH_SIZES:
print(f"Benchmarking batch size: {batch_size}")
sentences_batch = DUMMY_SENTENCES[:batch_size]
# Benchmark PyTorch
torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
results_torch.append(np.mean(torch_times))
# Benchmark MLX
mlx_times = [
benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)
]
results_mlx.append(np.mean(mlx_times))
print("\n--- Benchmark Results (Average time per batch in ms) ---")
print(f"Batch Sizes: {BATCH_SIZES}")
print(f"PyTorch (mps): {[f'{t:.2f}' for t in results_torch]}")
print(f"MLX: {[f'{t:.2f}' for t in results_mlx]}")
# --- Plotting ---
print("\n--- Generating Plot ---")
plt.figure(figsize=(10, 6))
plt.plot(
BATCH_SIZES,
results_torch,
marker="o",
linestyle="-",
label=f"PyTorch ({device})",
)
plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")
plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")
plt.xlabel("Batch Size")
plt.ylabel("Average Time per Batch (ms)")
plt.xticks(BATCH_SIZES)
plt.grid(True)
plt.legend()
# Save the plot
output_filename = "embedding_benchmark.png"
plt.savefig(output_filename)
print(f"Plot saved to {output_filename}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,148 @@
import argparse
import os
import time
from pathlib import Path
from leann import LeannBuilder, LeannSearcher
def _meta_exists(index_path: str) -> bool:
p = Path(index_path)
return (p.parent / f"{p.stem}.meta.json").exists()
def ensure_index(index_path: str, backend_name: str, num_docs: int, is_recompute: bool) -> None:
# if _meta_exists(index_path):
# return
kwargs = {}
if backend_name == "hnsw":
kwargs["is_compact"] = is_recompute
builder = LeannBuilder(
backend_name=backend_name,
embedding_model=os.getenv("LEANN_EMBED_MODEL", "facebook/contriever"),
embedding_mode=os.getenv("LEANN_EMBED_MODE", "sentence-transformers"),
graph_degree=32,
complexity=64,
is_recompute=is_recompute,
num_threads=4,
**kwargs,
)
for i in range(num_docs):
builder.add_text(
f"This is a test document number {i}. It contains some repeated text for benchmarking."
)
builder.build_index(index_path)
def _bench_group(
index_path: str,
recompute: bool,
query: str,
repeats: int,
complexity: int = 32,
top_k: int = 10,
) -> float:
# Independent searcher per group; fixed port when recompute
searcher = LeannSearcher(index_path=index_path)
# Warm-up once
_ = searcher.search(
query,
top_k=top_k,
complexity=complexity,
recompute_embeddings=recompute,
)
def _once() -> float:
t0 = time.time()
_ = searcher.search(
query,
top_k=top_k,
complexity=complexity,
recompute_embeddings=recompute,
)
return time.time() - t0
if repeats <= 1:
t = _once()
else:
vals = [_once() for _ in range(repeats)]
vals.sort()
t = vals[len(vals) // 2]
searcher.cleanup()
return t
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--num-docs", type=int, default=5000)
parser.add_argument("--repeats", type=int, default=3)
parser.add_argument("--complexity", type=int, default=32)
args = parser.parse_args()
base = Path.cwd() / ".leann" / "indexes" / f"bench_n{args.num_docs}"
base.parent.mkdir(parents=True, exist_ok=True)
# ---------- Build HNSW variants ----------
hnsw_r = str(base / f"hnsw_recompute_n{args.num_docs}.leann")
hnsw_nr = str(base / f"hnsw_norecompute_n{args.num_docs}.leann")
ensure_index(hnsw_r, "hnsw", args.num_docs, True)
ensure_index(hnsw_nr, "hnsw", args.num_docs, False)
# ---------- Build DiskANN variants ----------
diskann_r = str(base / "diskann_r.leann")
diskann_nr = str(base / "diskann_nr.leann")
ensure_index(diskann_r, "diskann", args.num_docs, True)
ensure_index(diskann_nr, "diskann", args.num_docs, False)
# ---------- Helpers ----------
def _size_for(prefix: str) -> int:
p = Path(prefix)
base_dir = p.parent
stem = p.stem
total = 0
for f in base_dir.iterdir():
if f.is_file() and f.name.startswith(stem):
total += f.stat().st_size
return total
# ---------- HNSW benchmark ----------
t_hnsw_r = _bench_group(
hnsw_r, True, "test document number 42", repeats=args.repeats, complexity=args.complexity
)
t_hnsw_nr = _bench_group(
hnsw_nr, False, "test document number 42", repeats=args.repeats, complexity=args.complexity
)
size_hnsw_r = _size_for(hnsw_r)
size_hnsw_nr = _size_for(hnsw_nr)
print("Benchmark results (HNSW):")
print(f" recompute=True: search_time={t_hnsw_r:.3f}s, size={size_hnsw_r / 1024 / 1024:.1f}MB")
print(
f" recompute=False: search_time={t_hnsw_nr:.3f}s, size={size_hnsw_nr / 1024 / 1024:.1f}MB"
)
print(" Expectation: no-recompute should be faster but larger on disk.")
# ---------- DiskANN benchmark ----------
t_diskann_r = _bench_group(
diskann_r, True, "DiskANN R test doc 123", repeats=args.repeats, complexity=args.complexity
)
t_diskann_nr = _bench_group(
diskann_nr,
False,
"DiskANN NR test doc 123",
repeats=args.repeats,
complexity=args.complexity,
)
size_diskann_r = _size_for(diskann_r)
size_diskann_nr = _size_for(diskann_nr)
print("\nBenchmark results (DiskANN):")
print(f" build(recompute=True, partition): size={size_diskann_r / 1024 / 1024:.1f}MB")
print(f" build(recompute=False): size={size_diskann_nr / 1024 / 1024:.1f}MB")
print(f" search recompute=True (final rerank): {t_diskann_r:.3f}s")
print(f" search recompute=False (PQ only): {t_diskann_nr:.3f}s")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,326 @@
#!/usr/bin/env python3
"""
Memory comparison between Faiss HNSW and LEANN HNSW backend
"""
import gc
import logging
import os
import subprocess
import sys
import time
from pathlib import Path
import psutil
from llama_index.core.node_parser import SentenceSplitter
# Setup logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
def get_memory_usage():
"""Get current memory usage in MB"""
process = psutil.Process()
return process.memory_info().rss / 1024 / 1024
def print_memory_stats(stage: str, start_mem: float):
"""Print memory statistics"""
current_mem = get_memory_usage()
diff = current_mem - start_mem
print(f"[{stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
return current_mem
class MemoryTracker:
def __init__(self, name: str):
self.name = name
self.start_mem = get_memory_usage()
self.stages = []
def checkpoint(self, stage: str):
current_mem = print_memory_stats(f"{self.name} - {stage}", self.start_mem)
self.stages.append((stage, current_mem))
return current_mem
def summary(self):
print(f"\n=== {self.name} Memory Summary ===")
for stage, mem in self.stages:
print(f"{stage}: {mem:.1f} MB")
peak_mem = max(mem for _, mem in self.stages)
print(f"Peak Memory: {peak_mem:.1f} MB")
print(f"Total Memory Increase: {peak_mem - self.start_mem:.1f} MB")
return peak_mem
def test_faiss_hnsw():
"""Test Faiss HNSW Vector Store in subprocess"""
print("\n" + "=" * 50)
print("TESTING FAISS HNSW VECTOR STORE")
print("=" * 50)
try:
result = subprocess.run(
[sys.executable, "benchmarks/faiss_only.py"],
capture_output=True,
text=True,
timeout=300,
)
print(result.stdout)
if result.stderr:
print("Stderr:", result.stderr)
if result.returncode != 0:
return {
"peak_memory": float("inf"),
"error": f"Process failed with code {result.returncode}",
}
# Parse peak memory from output
lines = result.stdout.split("\n")
peak_memory = 0.0
for line in lines:
if "Peak Memory:" in line:
peak_memory = float(line.split("Peak Memory:")[1].split("MB")[0].strip())
return {"peak_memory": peak_memory}
except Exception as e:
return {
"peak_memory": float("inf"),
"error": str(e),
}
def test_leann_hnsw():
"""Test LEANN HNSW Search Memory (load existing index)"""
print("\n" + "=" * 50)
print("TESTING LEANN HNSW SEARCH MEMORY")
print("=" * 50)
tracker = MemoryTracker("LEANN HNSW Search")
# Import and setup
tracker.checkpoint("Initial")
from leann.api import LeannSearcher
tracker.checkpoint("After imports")
from leann.api import LeannBuilder
from llama_index.core import SimpleDirectoryReader
# Load and parse documents
documents = SimpleDirectoryReader(
"data",
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data()
tracker.checkpoint("After document loading")
# Parse into chunks
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
)
all_texts = []
for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Total number of chunks: {len(all_texts)}")
tracker.checkpoint("After text chunking")
# Build LEANN index
INDEX_DIR = Path("./test_leann_comparison")
INDEX_PATH = str(INDEX_DIR / "comparison.leann")
# Check if index already exists
if os.path.exists(INDEX_PATH + ".meta.json"):
print("Loading existing LEANN HNSW index...")
tracker.checkpoint("After loading existing index")
else:
print("Building new LEANN HNSW index...")
# Clean up previous index
import shutil
if INDEX_DIR.exists():
shutil.rmtree(INDEX_DIR)
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1,
)
tracker.checkpoint("After builder setup")
print("Building LEANN HNSW index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(INDEX_PATH)
del builder
gc.collect()
tracker.checkpoint("After index building")
# Find existing LEANN index
index_paths = [
"./test_leann_comparison/comparison.leann",
]
index_path = None
for path in index_paths:
if os.path.exists(path + ".meta.json"):
index_path = path
break
if not index_path:
print("❌ LEANN index not found. Please build it first")
return {"peak_memory": float("inf"), "error": "Index not found"}
# Measure runtime memory overhead
print("\nMeasuring runtime memory overhead...")
runtime_start_mem = get_memory_usage()
print(f"Before load memory: {runtime_start_mem:.1f} MB")
tracker.checkpoint("Before load memory")
# Load searcher
searcher = LeannSearcher(index_path)
tracker.checkpoint("After searcher loading")
print("Running search queries...")
queries = [
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
"What is LEANN and how does it work?",
"华为诺亚方舟实验室的主要研究内容",
]
for i, query in enumerate(queries):
start_time = time.time()
# Use same parameters as Faiss: top_k=20, ef=120 (complexity parameter)
_ = searcher.search(query, top_k=20, ef=120)
query_time = time.time() - start_time
print(f"Query {i + 1} time: {query_time:.3f}s")
tracker.checkpoint(f"After query {i + 1}")
runtime_end_mem = get_memory_usage()
runtime_overhead = runtime_end_mem - runtime_start_mem
peak_memory = tracker.summary()
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
# Get storage size before cleanup
storage_size = 0
INDEX_DIR = Path(index_path).parent
if INDEX_DIR.exists():
total_size = 0
for dirpath, _, filenames in os.walk(str(INDEX_DIR)):
for filename in filenames:
# Only count actual index files, skip text data and backups
if filename.endswith((".old", ".tmp", ".bak", ".jsonl", ".json")):
continue
# Count .index, .idx, .map files (actual index structures)
if filename.endswith((".index", ".idx", ".map")):
filepath = os.path.join(dirpath, filename)
total_size += os.path.getsize(filepath)
storage_size = total_size / (1024 * 1024) # Convert to MB
# Clean up
del searcher
gc.collect()
return {
"peak_memory": peak_memory,
"storage_size": storage_size,
}
def main():
"""Run comparison tests"""
print("Storage + Search Memory Comparison: Faiss HNSW vs LEANN HNSW")
print("=" * 60)
# Test Faiss HNSW
faiss_results = test_faiss_hnsw()
# Force garbage collection
gc.collect()
time.sleep(2)
# Test LEANN HNSW
leann_results = test_leann_hnsw()
# Final comparison
print("\n" + "=" * 60)
print("STORAGE + SEARCH MEMORY COMPARISON")
print("=" * 60)
# Get storage sizes
faiss_storage_size = 0
leann_storage_size = leann_results.get("storage_size", 0)
# Get Faiss storage size using Python
if os.path.exists("./storage_faiss"):
total_size = 0
for dirpath, _, filenames in os.walk("./storage_faiss"):
for filename in filenames:
filepath = os.path.join(dirpath, filename)
total_size += os.path.getsize(filepath)
faiss_storage_size = total_size / (1024 * 1024) # Convert to MB
print("Faiss HNSW:")
if "error" in faiss_results:
print(f" ❌ Failed: {faiss_results['error']}")
else:
print(f" Search Memory: {faiss_results['peak_memory']:.1f} MB")
print(f" Storage Size: {faiss_storage_size:.1f} MB")
print("\nLEANN HNSW:")
if "error" in leann_results:
print(f" ❌ Failed: {leann_results['error']}")
else:
print(f" Search Memory: {leann_results['peak_memory']:.1f} MB")
print(f" Storage Size: {leann_storage_size:.1f} MB")
# Calculate improvements only if both tests succeeded
if "error" not in faiss_results and "error" not in leann_results:
memory_ratio = faiss_results["peak_memory"] / leann_results["peak_memory"]
print("\nLEANN vs Faiss Performance:")
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
print(f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)")
# Storage comparison
if leann_storage_size > faiss_storage_size:
storage_ratio = leann_storage_size / faiss_storage_size
print(f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)")
elif faiss_storage_size > leann_storage_size:
storage_ratio = faiss_storage_size / leann_storage_size
print(f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)")
else:
print(" Storage Size: similar")
else:
if "error" not in leann_results:
print("\n✅ LEANN HNSW completed successfully!")
print(f"📊 Search Memory: {leann_results['peak_memory']:.1f} MB")
print(f"📊 Storage Size: {leann_storage_size:.1f} MB")
if "error" not in faiss_results:
print("\n✅ Faiss HNSW completed successfully!")
print(f"📊 Search Memory: {faiss_results['peak_memory']:.1f} MB")
print(f"📊 Storage Size: {faiss_storage_size:.1f} MB")
if __name__ == "__main__":
main()

44
benchmarks/data/README.md Executable file
View File

@@ -0,0 +1,44 @@
---
license: mit
---
# LEANN-RAG Evaluation Data
This repository contains the necessary data to run the recall evaluation scripts for the [LEANN-RAG](https://huggingface.co/LEANN-RAG) project.
## Dataset Components
This dataset is structured into three main parts:
1. **Pre-built LEANN Indices**:
* `dpr/`: A pre-built index for the DPR dataset.
* `rpj_wiki/`: A pre-built index for the RPJ-Wiki dataset.
These indices were created using the `leann-core` library and are required by the `LeannSearcher`.
2. **Ground Truth Data**:
* `ground_truth/`: Contains the ground truth files (`flat_results_nq_k3.json`) for both the DPR and RPJ-Wiki datasets. These files map queries to the original passage IDs from the Natural Questions benchmark, evaluated using the Contriever model.
3. **Queries**:
* `queries/`: Contains the `nq_open.jsonl` file with the Natural Questions queries used for the evaluation.
## Usage
To use this data, you can download it locally using the `huggingface-hub` library. First, install the library:
```bash
pip install huggingface-hub
```
Then, you can download the entire dataset to a local directory (e.g., `data/`) with the following Python script:
```python
from huggingface_hub import snapshot_download
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir="data"
)
```
This will download all the necessary files into a local `data` folder, preserving the repository structure. The evaluation scripts in the main [LEANN-RAG Space](https://huggingface.co/LEANN-RAG) are configured to work with this data structure.

View File

@@ -0,0 +1,286 @@
#!/usr/bin/env python3
"""
DiskANN vs HNSW Search Performance Comparison
This benchmark compares search performance between DiskANN and HNSW backends:
- DiskANN: With graph partitioning enabled (is_recompute=True)
- HNSW: With recompute enabled (is_recompute=True)
- Tests performance across different dataset sizes
- Measures search latency, recall, and index size
"""
import gc
import multiprocessing as mp
import tempfile
import time
from pathlib import Path
from typing import Any
import numpy as np
# Prefer 'fork' start method to avoid POSIX semaphore leaks on macOS
try:
mp.set_start_method("fork", force=True)
except Exception:
pass
def create_test_texts(n_docs: int) -> list[str]:
"""Create synthetic test documents for benchmarking."""
np.random.seed(42)
topics = [
"machine learning and artificial intelligence",
"natural language processing and text analysis",
"computer vision and image recognition",
"data science and statistical analysis",
"deep learning and neural networks",
"information retrieval and search engines",
"database systems and data management",
"software engineering and programming",
"cybersecurity and network protection",
"cloud computing and distributed systems",
]
texts = []
for i in range(n_docs):
topic = topics[i % len(topics)]
variation = np.random.randint(1, 100)
text = (
f"This is document {i} about {topic}. Content variation {variation}. "
f"Additional information about {topic} with details and examples. "
f"Technical discussion of {topic} including implementation aspects."
)
texts.append(text)
return texts
def benchmark_backend(
backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
) -> dict[str, float]:
"""Benchmark a specific backend with the given configuration."""
from leann.api import LeannBuilder, LeannSearcher
print(f"\n🔧 Testing {backend_name.upper()} backend...")
with tempfile.TemporaryDirectory() as temp_dir:
index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")
# Build index
print(f"📦 Building {backend_name} index with {len(texts)} documents...")
start_time = time.time()
builder = LeannBuilder(
backend_name=backend_name,
embedding_model="facebook/contriever",
embedding_mode="sentence-transformers",
**backend_kwargs,
)
for text in texts:
builder.add_text(text)
builder.build_index(index_path)
build_time = time.time() - start_time
# Measure index size
index_dir = Path(index_path).parent
index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
total_size = sum(f.stat().st_size for f in index_files if f.is_file())
size_mb = total_size / (1024 * 1024)
print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")
# Search benchmark
print("🔍 Running search benchmark...")
searcher = LeannSearcher(index_path)
search_times = []
all_results = []
for query in test_queries:
start_time = time.time()
results = searcher.search(query, top_k=5)
search_time = time.time() - start_time
search_times.append(search_time)
all_results.append(results)
avg_search_time = np.mean(search_times) * 1000 # Convert to ms
print(f" ✅ Average search time: {avg_search_time:.1f}ms")
# Check for valid scores (detect -inf issues)
all_scores = [
result.score
for results in all_results
for result in results
if result.score is not None
]
valid_scores = [
score for score in all_scores if score != float("-inf") and score != float("inf")
]
score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0
# Clean up (ensure embedding server shutdown and object GC)
try:
if hasattr(searcher, "cleanup"):
searcher.cleanup()
del searcher
del builder
gc.collect()
except Exception as e:
print(f"⚠️ Warning: Resource cleanup error: {e}")
return {
"build_time": build_time,
"avg_search_time_ms": avg_search_time,
"index_size_mb": size_mb,
"score_validity_rate": score_validity_rate,
}
def run_comparison(n_docs: int = 500, n_queries: int = 10):
"""Run performance comparison between DiskANN and HNSW."""
print("🚀 Starting DiskANN vs HNSW Performance Comparison")
print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")
# Create test data
texts = create_test_texts(n_docs)
test_queries = [
"machine learning algorithms",
"natural language processing",
"computer vision techniques",
"data analysis methods",
"neural network architectures",
"database query optimization",
"software development practices",
"security vulnerabilities",
"cloud infrastructure",
"distributed computing",
][:n_queries]
# HNSW benchmark
hnsw_results = benchmark_backend(
backend_name="hnsw",
texts=texts,
test_queries=test_queries,
backend_kwargs={
"is_recompute": True, # Enable recompute for fair comparison
"M": 16,
"efConstruction": 200,
},
)
# DiskANN benchmark
diskann_results = benchmark_backend(
backend_name="diskann",
texts=texts,
test_queries=test_queries,
backend_kwargs={
"is_recompute": True, # Enable graph partitioning
"num_neighbors": 32,
"search_list_size": 50,
},
)
# Performance comparison
print("\n📈 Performance Comparison Results")
print(f"{'=' * 60}")
print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
print(f"{'-' * 60}")
# Build time comparison
build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
print(
f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
)
# Search time comparison
search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
print(
f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
)
# Index size comparison
size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
print(
f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
)
# Score validity
print(
f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
)
print(f"{'=' * 60}")
print("\n🎯 Summary:")
if search_speedup > 1:
print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
else:
print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")
if size_ratio > 1:
print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
else:
print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")
print(
f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
)
if __name__ == "__main__":
import sys
try:
# Handle help request
if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
print("DiskANN vs HNSW Performance Comparison")
print("=" * 50)
print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
print()
print("Arguments:")
print(" n_docs Number of documents to index (default: 500)")
print(" n_queries Number of test queries to run (default: 10)")
print()
print("Examples:")
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py")
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
sys.exit(0)
# Parse command line arguments
n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10
print("DiskANN vs HNSW Performance Comparison")
print("=" * 50)
print(f"Dataset: {n_docs} documents, {n_queries} queries")
print()
run_comparison(n_docs=n_docs, n_queries=n_queries)
except KeyboardInterrupt:
print("\n⚠️ Benchmark interrupted by user")
sys.exit(130)
except Exception as e:
print(f"\n❌ Benchmark failed: {e}")
sys.exit(1)
finally:
# Ensure clean exit (forceful to prevent rare hangs from atexit/threads)
try:
gc.collect()
print("\n🧹 Cleanup completed")
# Flush stdio to ensure message is visible before hard-exit
try:
import sys as _sys
_sys.stdout.flush()
_sys.stderr.flush()
except Exception:
pass
except Exception:
pass
# Use os._exit to bypass atexit handlers that may hang in rare cases
import os as _os
_os._exit(0)

151
benchmarks/faiss_only.py Normal file
View File

@@ -0,0 +1,151 @@
#!/usr/bin/env python3
"""Test only Faiss HNSW"""
import os
import sys
import time
import psutil
def get_memory_usage():
process = psutil.Process()
return process.memory_info().rss / 1024 / 1024
class MemoryTracker:
def __init__(self, name: str):
self.name = name
self.start_mem = get_memory_usage()
self.stages = []
def checkpoint(self, stage: str):
current_mem = get_memory_usage()
diff = current_mem - self.start_mem
print(f"[{self.name} - {stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
self.stages.append((stage, current_mem))
return current_mem
def summary(self):
peak_mem = max(mem for _, mem in self.stages)
print(f"Peak Memory: {peak_mem:.1f} MB")
return peak_mem
def main():
try:
import faiss
except ImportError:
print("Faiss is not installed.")
print(
"Please install it with `uv pip install faiss-cpu` and you can then run this script again"
)
sys.exit(1)
from llama_index.core import (
Settings,
SimpleDirectoryReader,
StorageContext,
VectorStoreIndex,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore
tracker = MemoryTracker("Faiss HNSW")
tracker.checkpoint("Initial")
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
Settings.embed_model = embed_model
tracker.checkpoint("After embedding model setup")
d = 768
faiss_index = faiss.IndexHNSWFlat(d, 32)
faiss_index.hnsw.efConstruction = 64
tracker.checkpoint("After Faiss index creation")
documents = SimpleDirectoryReader(
"data",
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data()
tracker.checkpoint("After document loading")
# Parse into chunks using the same splitter as LEANN
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
)
tracker.checkpoint("After text splitter setup")
# Check if index already exists and try to load it
index_loaded = False
if os.path.exists("./storage_faiss"):
print("Loading existing Faiss HNSW index...")
try:
# Use the correct Faiss loading pattern from the example
vector_store = FaissVectorStore.from_persist_dir("./storage_faiss")
storage_context = StorageContext.from_defaults(
vector_store=vector_store, persist_dir="./storage_faiss"
)
from llama_index.core import load_index_from_storage
index = load_index_from_storage(storage_context=storage_context)
print("Index loaded from ./storage_faiss")
tracker.checkpoint("After loading existing index")
index_loaded = True
except Exception as e:
print(f"Failed to load existing index: {e}")
print("Cleaning up corrupted index and building new one...")
# Clean up corrupted index
import shutil
if os.path.exists("./storage_faiss"):
shutil.rmtree("./storage_faiss")
if not index_loaded:
print("Building new Faiss HNSW index...")
# Use the correct Faiss building pattern from the example
vector_store = FaissVectorStore(faiss_index=faiss_index)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(
documents, storage_context=storage_context, transformations=[node_parser]
)
tracker.checkpoint("After index building")
# Save index to disk using the correct pattern
index.storage_context.persist(persist_dir="./storage_faiss")
tracker.checkpoint("After index saving")
# Measure runtime memory overhead
print("\nMeasuring runtime memory overhead...")
runtime_start_mem = get_memory_usage()
print(f"Before load memory: {runtime_start_mem:.1f} MB")
tracker.checkpoint("Before load memory")
query_engine = index.as_query_engine(similarity_top_k=20)
queries = [
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
"What is LEANN and how does it work?",
"华为诺亚方舟实验室的主要研究内容",
]
for i, query in enumerate(queries):
start_time = time.time()
_ = query_engine.query(query)
query_time = time.time() - start_time
print(f"Query {i + 1} time: {query_time:.3f}s")
tracker.checkpoint(f"After query {i + 1}")
runtime_end_mem = get_memory_usage()
runtime_overhead = runtime_end_mem - runtime_start_mem
peak_memory = tracker.summary()
print(f"Peak Memory: {peak_memory:.1f} MB")
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
if __name__ == "__main__":
main()

View File

@@ -2,21 +2,20 @@
import argparse import argparse
import time import time
from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
from torchao import quantize_
from transformers import AutoModel, BitsAndBytesConfig
from tqdm import tqdm from tqdm import tqdm
from contextlib import contextmanager from transformers import AutoModel, BitsAndBytesConfig
@dataclass @dataclass
class BenchmarkConfig: class BenchmarkConfig:
model_path: str model_path: str
batch_sizes: List[int] batch_sizes: list[int]
seq_length: int seq_length: int
num_runs: int num_runs: int
use_fp16: bool = True use_fp16: bool = True
@@ -27,46 +26,58 @@ class BenchmarkConfig:
use_linear8bitlt: bool = False use_linear8bitlt: bool = False
class CUDAGraphContainer: class GraphContainer:
"""Container for managing CUDA graphs for different batch sizes.""" """Container for managing graphs for different batch sizes (CUDA graphs on NVIDIA, regular on others)."""
def __init__(self, model: nn.Module, seq_length: int): def __init__(self, model: nn.Module, seq_length: int):
self.model = model self.model = model
self.seq_length = seq_length self.seq_length = seq_length
self.graphs: Dict[int, CUDAGraphWrapper] = {} self.graphs: dict[int, GraphWrapper] = {}
def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper': def get_or_create(self, batch_size: int) -> "GraphWrapper":
if batch_size not in self.graphs: if batch_size not in self.graphs:
self.graphs[batch_size] = CUDAGraphWrapper( self.graphs[batch_size] = GraphWrapper(self.model, batch_size, self.seq_length)
self.model, batch_size, self.seq_length
)
return self.graphs[batch_size] return self.graphs[batch_size]
class CUDAGraphWrapper: class GraphWrapper:
"""Wrapper for CUDA graph capture and replay.""" """Wrapper for graph capture and replay (CUDA graphs on NVIDIA, regular on others)."""
def __init__(self, model: nn.Module, batch_size: int, seq_length: int): def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
self.model = model self.model = model
self.device = self._get_device()
self.static_input = self._create_random_batch(batch_size, seq_length) self.static_input = self._create_random_batch(batch_size, seq_length)
self.static_attention_mask = torch.ones_like(self.static_input) self.static_attention_mask = torch.ones_like(self.static_input)
# Warm up # Warm up
self._warmup() self._warmup()
# Only use CUDA graphs on NVIDIA GPUs
if torch.cuda.is_available() and hasattr(torch.cuda, "CUDAGraph"):
# Capture graph # Capture graph
self.graph = torch.cuda.CUDAGraph() self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph): with torch.cuda.graph(self.graph):
self.static_output = self.model( self.static_output = self.model(
input_ids=self.static_input, input_ids=self.static_input,
attention_mask=self.static_attention_mask attention_mask=self.static_attention_mask,
) )
self.use_cuda_graph = True
else:
# For MPS or CPU, just store the model
self.use_cuda_graph = False
self.static_output = None
def _get_device(self) -> str:
if torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_available():
return "mps"
else:
return "cpu"
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor: def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
return torch.randint( return torch.randint(
0, 1000, (batch_size, seq_length), 0, 1000, (batch_size, seq_length), device=self.device, dtype=torch.long
device="cuda",
dtype=torch.long
) )
def _warmup(self, num_warmup: int = 3): def _warmup(self, num_warmup: int = 3):
@@ -74,14 +85,18 @@ class CUDAGraphWrapper:
for _ in range(num_warmup): for _ in range(num_warmup):
self.model( self.model(
input_ids=self.static_input, input_ids=self.static_input,
attention_mask=self.static_attention_mask attention_mask=self.static_attention_mask,
) )
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
if self.use_cuda_graph:
self.static_input.copy_(input_ids) self.static_input.copy_(input_ids)
self.static_attention_mask.copy_(attention_mask) self.static_attention_mask.copy_(attention_mask)
self.graph.replay() self.graph.replay()
return self.static_output return self.static_output
else:
# For MPS/CPU, just run normally
return self.model(input_ids=input_ids, attention_mask=attention_mask)
class ModelOptimizer: class ModelOptimizer:
@@ -95,8 +110,16 @@ class ModelOptimizer:
raise ValueError("Cannot optimize None model") raise ValueError("Cannot optimize None model")
# Move to GPU # Move to GPU
if torch.cuda.is_available():
model = model.cuda() model = model.cuda()
print("- Model moved to GPU") device = "cuda"
elif torch.backends.mps.is_available():
model = model.to("mps")
device = "mps"
else:
model = model.cpu()
device = "cpu"
print(f"- Model moved to {device}")
# FP16 # FP16
if config.use_fp16 and not config.use_int4: if config.use_fp16 and not config.use_int4:
@@ -105,17 +128,22 @@ class ModelOptimizer:
model = torch.compile(model) model = torch.compile(model)
print("- Using FP16 precision") print("- Using FP16 precision")
# Check if using SDPA # Check if using SDPA (only on CUDA)
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6: if (
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): torch.cuda.is_available()
and torch.version.cuda
and float(torch.version.cuda[:3]) >= 11.6
):
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
print("- Using PyTorch SDPA (scaled_dot_product_attention)") print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else: else:
print("- PyTorch SDPA not available") print("- PyTorch SDPA not available")
# Flash Attention # Flash Attention (only on CUDA)
if config.use_flash_attention: if config.use_flash_attention and torch.cuda.is_available():
try: try:
from flash_attn.flash_attention import FlashAttention from flash_attn.flash_attention import FlashAttention # noqa: F401
print("- Flash Attention 2 available") print("- Flash Attention 2 available")
if hasattr(model.config, "attention_mode"): if hasattr(model.config, "attention_mode"):
model.config.attention_mode = "flash_attention_2" model.config.attention_mode = "flash_attention_2"
@@ -123,10 +151,12 @@ class ModelOptimizer:
except ImportError: except ImportError:
print("- Flash Attention not available") print("- Flash Attention not available")
# Memory efficient attention # Memory efficient attention (only on CUDA)
if torch.cuda.is_available():
try: try:
from xformers.ops import memory_efficient_attention from xformers.ops import memory_efficient_attention # noqa: F401
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
if hasattr(model, "enable_xformers_memory_efficient_attention"):
model.enable_xformers_memory_efficient_attention() model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention") print("- Enabled xformers memory efficient attention")
else: else:
@@ -141,21 +171,38 @@ class ModelOptimizer:
class Timer: class Timer:
"""Handles accurate GPU timing using CUDA events.""" """Handles accurate GPU timing using GPU events or CPU timing."""
def __init__(self): def __init__(self):
if torch.cuda.is_available():
self.start_event = torch.cuda.Event(enable_timing=True) self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True) self.end_event = torch.cuda.Event(enable_timing=True)
self.use_gpu_timing = True
elif torch.backends.mps.is_available():
# MPS doesn't have events, use CPU timing
self.use_gpu_timing = False
else:
# CPU timing
self.use_gpu_timing = False
@contextmanager @contextmanager
def timing(self): def timing(self):
if self.use_gpu_timing:
self.start_event.record() self.start_event.record()
yield yield
self.end_event.record() self.end_event.record()
self.end_event.synchronize() self.end_event.synchronize()
else:
# Use CPU timing for MPS/CPU
start_time = time.time()
yield
self.cpu_elapsed = time.time() - start_time
def elapsed_time(self) -> float: def elapsed_time(self) -> float:
if self.use_gpu_timing:
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
else:
return self.cpu_elapsed
class Benchmark: class Benchmark:
@@ -168,14 +215,14 @@ class Benchmark:
if self.model is None: if self.model is None:
raise ValueError("Model initialization failed - model is None") raise ValueError("Model initialization failed - model is None")
self.cuda_graphs = ( # Only use CUDA graphs on NVIDIA GPUs
CUDAGraphContainer(self.model, config.seq_length) if config.use_cuda_graphs and torch.cuda.is_available():
if config.use_cuda_graphs self.graphs = GraphContainer(self.model, config.seq_length)
else None else:
) self.graphs = None
self.timer = Timer() self.timer = Timer()
except Exception as e: except Exception as e:
print(f"ERROR in benchmark initialization: {str(e)}") print(f"ERROR in benchmark initialization: {e!s}")
raise raise
def _load_model(self) -> nn.Module: def _load_model(self) -> nn.Module:
@@ -185,15 +232,17 @@ class Benchmark:
# Int4 quantization using HuggingFace integration # Int4 quantization using HuggingFace integration
if self.config.use_int4: if self.config.use_int4:
import bitsandbytes as bnb import bitsandbytes as bnb
print(f"- bitsandbytes version: {bnb.__version__}") print(f"- bitsandbytes version: {bnb.__version__}")
# 检查是否使用自定义的8bit量化 # Check if using custom 8bit quantization
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt: if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
print("- Using custom Linear8bitLt replacement for all linear layers") print("- Using custom Linear8bitLt replacement for all linear layers")
# 加载原始模型(不使用量化配置) # Load original model (without quantization config)
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
# set default to half # set default to half
torch.set_default_dtype(torch.float16) torch.set_default_dtype(torch.float16)
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32 compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
@@ -202,52 +251,58 @@ class Benchmark:
torch_dtype=compute_dtype, torch_dtype=compute_dtype,
) )
# 定义替换函数 # Define replacement function
def replace_linear_with_linear8bitlt(model): def replace_linear_with_linear8bitlt(model):
"""递归地将模型中的所有nn.Linear层替换为Linear8bitLt""" """Recursively replace all nn.Linear layers with Linear8bitLt"""
for name, module in list(model.named_children()): for name, module in list(model.named_children()):
if isinstance(module, nn.Linear): if isinstance(module, nn.Linear):
# 获取原始线性层的参数 # Get original linear layer parameters
in_features = module.in_features in_features = module.in_features
out_features = module.out_features out_features = module.out_features
bias = module.bias is not None bias = module.bias is not None
# 创建8bit线性层 # Create 8bit linear layer
# print size # print size
print(f"in_features: {in_features}, out_features: {out_features}") print(f"in_features: {in_features}, out_features: {out_features}")
new_module = bnb.nn.Linear8bitLt( new_module = bnb.nn.Linear8bitLt(
in_features, in_features,
out_features, out_features,
bias=bias, bias=bias,
has_fp16_weights=False has_fp16_weights=False,
) )
# 复制权重和偏置 # Copy weights and bias
new_module.weight.data = module.weight.data new_module.weight.data = module.weight.data
if bias: if bias:
new_module.bias.data = module.bias.data new_module.bias.data = module.bias.data
# 替换模块 # Replace module
setattr(model, name, new_module) setattr(model, name, new_module)
else: else:
# 递归处理子模块 # Process child modules recursively
replace_linear_with_linear8bitlt(module) replace_linear_with_linear8bitlt(module)
return model return model
# 替换所有线性层 # Replace all linear layers
model = replace_linear_with_linear8bitlt(model) model = replace_linear_with_linear8bitlt(model)
# add torch compile # add torch compile
model = torch.compile(model) model = torch.compile(model)
# 将模型移到GPU量化发生在这里 # Move model to GPU (quantization happens here)
device = "cuda" if torch.cuda.is_available() else "cpu" device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
model = model.to(device) model = model.to(device)
print("- All linear layers replaced with Linear8bitLt") print("- All linear layers replaced with Linear8bitLt")
else: else:
# 使用原来的Int4量化方法 # Use original Int4 quantization method
print("- Using bitsandbytes for Int4 quantization") print("- Using bitsandbytes for Int4 quantization")
# Create quantization config # Create quantization config
@@ -257,7 +312,7 @@ class Benchmark:
load_in_4bit=True, load_in_4bit=True,
bnb_4bit_compute_dtype=compute_dtype, bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=True, bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4" bnb_4bit_quant_type="nf4",
) )
print("- Quantization config:", quantization_config) print("- Quantization config:", quantization_config)
@@ -267,7 +322,7 @@ class Benchmark:
self.config.model_path, self.config.model_path,
quantization_config=quantization_config, quantization_config=quantization_config,
torch_dtype=compute_dtype, torch_dtype=compute_dtype,
device_map="auto" # Let HF decide on device mapping device_map="auto", # Let HF decide on device mapping
) )
# Check if model loaded successfully # Check if model loaded successfully
@@ -279,7 +334,7 @@ class Benchmark:
# Apply optimizations directly here # Apply optimizations directly here
print("\nApplying model optimizations:") print("\nApplying model optimizations:")
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt: if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
print("- Model moved to GPU with Linear8bitLt quantization") print("- Model moved to GPU with Linear8bitLt quantization")
else: else:
# Skip moving to GPU since device_map="auto" already did that # Skip moving to GPU since device_map="auto" already did that
@@ -289,16 +344,20 @@ class Benchmark:
print(f"- Using {compute_dtype} for compute dtype") print(f"- Using {compute_dtype} for compute dtype")
# Check CUDA and SDPA # Check CUDA and SDPA
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6: if (
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): torch.cuda.is_available()
and torch.version.cuda
and float(torch.version.cuda[:3]) >= 11.6
):
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
print("- Using PyTorch SDPA (scaled_dot_product_attention)") print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else: else:
print("- PyTorch SDPA not available") print("- PyTorch SDPA not available")
# Try xformers if available # Try xformers if available (only on CUDA)
if torch.cuda.is_available():
try: try:
from xformers.ops import memory_efficient_attention if hasattr(model, "enable_xformers_memory_efficient_attention"):
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
model.enable_xformers_memory_efficient_attention() model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention") print("- Enabled xformers memory efficient attention")
else: else:
@@ -310,58 +369,30 @@ class Benchmark:
model.eval() model.eval()
print("- Model set to eval mode") print("- Model set to eval mode")
# Int8 quantization using HuggingFace integration # Int8 quantization using HuggingFace integration
# Int8 quantization using TorchAO
elif self.config.use_int8: elif self.config.use_int8:
print("- Using TorchAO for Int8 dynamic activation and Int8 weight quantization") print("- Using INT8 quantization")
# For now, just use standard loading with INT8 config
# Import the quantize_ function and the quantization config compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight quantization_config = BitsAndBytesConfig(
print("- Successfully imported TorchAO") load_in_8bit=True,
llm_int8_threshold=6.0,
# Load model normally first llm_int8_has_fp16_weight=False,
# set default to half
import torch
torch.set_default_dtype(torch.bfloat16)
model = AutoModel.from_pretrained(
self.config.model_path,
device_map="auto"
) )
print("- Model loaded in full precision") model = AutoModel.from_pretrained(
self.config.model_path,
quantization_config=quantization_config,
torch_dtype=compute_dtype,
device_map="auto",
)
if model is None:
raise ValueError("Model loading returned None")
print(f"- Model type: {type(model)}") print(f"- Model type: {type(model)}")
# Apply quantization - call the function to get the config, then apply it
# quantize_(model, int8_dynamic_activation_int8_weight())
# from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig,int8_dynamic_activation_int8_semi_sparse_weight,int4_weight_only,Int8DynActInt4WeightGPTQQuantizer,int8_dynamic_activation_int4_weight,Int8DynamicActivationInt4WeightConfig,Int4DynamicActivationInt4WeightConfig
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
quantize_(model, Int8DynamicActivationInt8WeightConfig())
print("- Model successfully quantized with int8 weights and int8 activations")
# add torch compile
model = torch.compile(model)
# For older PyTorch versions that have issues with tensor subclasses
from torchao.utils import unwrap_tensor_subclass
import torch
if hasattr(torch, '_version') and not torch.version >= "2.5.0":
print("- Unwrapping tensor subclasses for compatibility with older PyTorch")
unwrap_tensor_subclass(model)
# Apply optimizations
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else:
print("- PyTorch SDPA not available")
# Set to eval mode
model.eval() model.eval()
print("- Model set to eval mode") print("- Model set to eval mode")
# For better performance with int8 dynamic quantization
torch._inductor.config.force_fuse_int_mm_with_mul = True
print("- Enabled fusion of int matmul with mul operations")
else: else:
# Standard loading for FP16/FP32 # Standard loading for FP16/FP32
model = AutoModel.from_pretrained(self.config.model_path) model = AutoModel.from_pretrained(self.config.model_path)
@@ -371,6 +402,7 @@ class Benchmark:
# Apply standard optimizations # Apply standard optimizations
# set default to half # set default to half
import torch import torch
torch.set_default_dtype(torch.bfloat16) torch.set_default_dtype(torch.bfloat16)
model = ModelOptimizer.optimize(model, self.config) model = ModelOptimizer.optimize(model, self.config)
model = model.half() model = model.half()
@@ -385,49 +417,60 @@ class Benchmark:
return model return model
except Exception as e: except Exception as e:
print(f"ERROR loading model: {str(e)}") print(f"ERROR loading model: {e!s}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
raise raise
def _create_random_batch(self, batch_size: int) -> torch.Tensor: def _create_random_batch(self, batch_size: int) -> torch.Tensor:
device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
return torch.randint( return torch.randint(
0, 1000, 0,
1000,
(batch_size, self.config.seq_length), (batch_size, self.config.seq_length),
device="cuda", device=device,
dtype=torch.long dtype=torch.long,
) )
def _run_inference( def _run_inference(
self, self, input_ids: torch.Tensor, graph_wrapper: GraphWrapper | None = None
input_ids: torch.Tensor, ) -> tuple[float, torch.Tensor]:
cuda_graph_wrapper: Optional[CUDAGraphWrapper] = None
) -> Tuple[float, torch.Tensor]:
attention_mask = torch.ones_like(input_ids) attention_mask = torch.ones_like(input_ids)
with torch.no_grad(), self.timer.timing(): with torch.no_grad(), self.timer.timing():
if cuda_graph_wrapper is not None: if graph_wrapper is not None:
output = cuda_graph_wrapper(input_ids, attention_mask) output = graph_wrapper(input_ids, attention_mask)
else: else:
output = self.model(input_ids=input_ids, attention_mask=attention_mask) output = self.model(input_ids=input_ids, attention_mask=attention_mask)
return self.timer.elapsed_time(), output return self.timer.elapsed_time(), output
def run(self) -> Dict[int, Dict[str, float]]: def run(self) -> dict[int, dict[str, float]]:
results = {} results = {}
# Reset peak memory stats # Reset peak memory stats
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats() torch.cuda.reset_peak_memory_stats()
elif torch.backends.mps.is_available():
# MPS doesn't have reset_peak_memory_stats, skip it
pass
else:
print("- No GPU memory stats available")
for batch_size in self.config.batch_sizes: for batch_size in self.config.batch_sizes:
print(f"\nTesting batch size: {batch_size}") print(f"\nTesting batch size: {batch_size}")
times = [] times = []
# Get or create CUDA graph for this batch size # Get or create graph for this batch size
cuda_graph_wrapper = ( graph_wrapper = (
self.cuda_graphs.get_or_create(batch_size) self.graphs.get_or_create(batch_size) if self.graphs is not None else None
if self.cuda_graphs is not None
else None
) )
# Pre-allocate input tensor # Pre-allocate input tensor
@@ -437,7 +480,7 @@ class Benchmark:
# Run benchmark # Run benchmark
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"): for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
try: try:
elapsed_time, output = self._run_inference(input_ids, cuda_graph_wrapper) elapsed_time, output = self._run_inference(input_ids, graph_wrapper)
if i == 0: # Only print on first run if i == 0: # Only print on first run
print(f"Output shape: {output.last_hidden_state.shape}") print(f"Output shape: {output.last_hidden_state.shape}")
times.append(elapsed_time) times.append(elapsed_time)
@@ -464,8 +507,19 @@ class Benchmark:
print(f"Throughput: {throughput:.2f} sequences/second") print(f"Throughput: {throughput:.2f} sequences/second")
# Log memory usage # Log memory usage
if torch.cuda.is_available():
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3) peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
elif torch.backends.mps.is_available():
# MPS doesn't have max_memory_allocated, use 0
peak_memory_gb = 0.0
else:
peak_memory_gb = 0.0
print("- No GPU memory usage available")
if peak_memory_gb > 0:
print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB") print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
else:
print("\n- GPU memory usage not available")
# Add memory info to results # Add memory info to results
for batch_size in results: for batch_size in results:
@@ -485,7 +539,7 @@ def main():
parser.add_argument( parser.add_argument(
"--batch_sizes", "--batch_sizes",
type=str, type=str,
default="1,2,4,8,10,16,20,32,40,64,128,256,512,1024,2048,4096,8192", default="1,2,4,8,16,32",
help="Comma-separated list of batch sizes", help="Comma-separated list of batch sizes",
) )
parser.add_argument( parser.add_argument(
@@ -518,12 +572,12 @@ def main():
parser.add_argument( parser.add_argument(
"--use_cuda_graphs", "--use_cuda_graphs",
action="store_true", action="store_true",
help="Enable CUDA Graphs optimization", help="Enable CUDA Graphs optimization (only on NVIDIA GPUs)",
) )
parser.add_argument( parser.add_argument(
"--use_flash_attention", "--use_flash_attention",
action="store_true", action="store_true",
help="Enable Flash Attention 2 if available", help="Enable Flash Attention 2 if available (only on NVIDIA GPUs)",
) )
parser.add_argument( parser.add_argument(
"--use_linear8bitlt", "--use_linear8bitlt",
@@ -568,7 +622,15 @@ def main():
os.makedirs("results", exist_ok=True) os.makedirs("results", exist_ok=True)
# Generate filename based on configuration # Generate filename based on configuration
precision_type = "int4" if config.use_int4 else "fp16" if config.use_fp16 else "fp32" precision_type = (
"int4"
if config.use_int4
else "int8"
if config.use_int8
else "fp16"
if config.use_fp16
else "fp32"
)
model_name = os.path.basename(config.model_path) model_name = os.path.basename(config.model_path)
output_file = f"results/benchmark_{model_name}_{precision_type}.json" output_file = f"results/benchmark_{model_name}_{precision_type}.json"
@@ -576,17 +638,20 @@ def main():
with open(output_file, "w") as f: with open(output_file, "w") as f:
json.dump( json.dump(
{ {
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()}, "config": {
"results": {str(k): v for k, v in results.items()} k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()
},
"results": {str(k): v for k, v in results.items()},
}, },
f, f,
indent=2 indent=2,
) )
print(f"Results saved to {output_file}") print(f"Results saved to {output_file}")
except Exception as e: except Exception as e:
print(f"Benchmark failed: {e}") print(f"Benchmark failed: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()

View File

@@ -0,0 +1,393 @@
#!/usr/bin/env python3
"""
This script runs a recall evaluation on a given LEANN index.
It correctly compares results by fetching the text content for both the new search
results and the golden standard results, making the comparison robust to ID changes.
"""
import argparse
import json
import sys
import time
from pathlib import Path
import numpy as np
from leann.api import LeannBuilder, LeannChat, LeannSearcher
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
if not data_root.exists():
print(f"Data directory '{data_root}' not found.")
print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
try:
from huggingface_hub import snapshot_download
if download_embeddings:
# Download everything including embeddings (large files)
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
)
print("Data download complete (including embeddings)!")
else:
# Download only specific folders, excluding embeddings
allow_patterns = [
"ground_truth/**",
"indices/**",
"queries/**",
"*.md",
"*.txt",
]
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
allow_patterns=allow_patterns,
)
print("Data download complete (excluding embeddings)!")
except ImportError:
print(
"Error: huggingface_hub is not installed. Please install it to download the data:"
)
print("uv pip install -e '.[dev]'")
sys.exit(1)
except Exception as e:
print(f"An error occurred during data download: {e}")
sys.exit(1)
def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = None):
"""Download embeddings files specifically."""
embeddings_dir = data_root / "embeddings"
if dataset_type:
# Check if specific dataset embeddings exist
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
if target_file.exists():
print(f"Embeddings for {dataset_type} already exist")
return str(target_file)
print("Downloading embeddings from HuggingFace Hub...")
try:
from huggingface_hub import snapshot_download
# Download only embeddings folder
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
allow_patterns=["embeddings/**/*.pkl"],
)
print("Embeddings download complete!")
if dataset_type:
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
if target_file.exists():
return str(target_file)
return str(embeddings_dir)
except Exception as e:
print(f"Error downloading embeddings: {e}")
sys.exit(1)
# --- Helper Function to get Golden Passages ---
def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set:
"""
Retrieves the text for golden passage IDs directly from the LeannSearcher's
passage manager.
"""
golden_texts = set()
for gid in golden_ids:
try:
# PassageManager uses string IDs
passage_data = searcher.passage_manager.get_passage(str(gid))
golden_texts.add(passage_data["text"])
except KeyError:
print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
return golden_texts
def load_queries(file_path: Path) -> list[str]:
queries = []
with open(file_path, encoding="utf-8") as f:
for line in f:
data = json.loads(line)
queries.append(data["query"])
return queries
def build_index_from_embeddings(embeddings_file: str, output_path: str, backend: str = "hnsw"):
"""
Build a LEANN index from pre-computed embeddings.
Args:
embeddings_file: Path to pickle file with (ids, embeddings) tuple
output_path: Path where to save the index
backend: Backend to use ("hnsw" or "diskann")
"""
print(f"Building {backend} index from embeddings: {embeddings_file}")
# Create builder with appropriate parameters
if backend == "hnsw":
builder_kwargs = {
"M": 32, # Graph degree
"efConstruction": 256, # Construction complexity
"is_compact": True, # Use compact storage
"is_recompute": True, # Enable pruning for better recall
}
elif backend == "diskann":
builder_kwargs = {
"complexity": 64,
"graph_degree": 32,
"search_memory_maximum": 8.0, # GB
"build_memory_maximum": 16.0, # GB
}
else:
builder_kwargs = {}
builder = LeannBuilder(
backend_name=backend,
embedding_model="facebook/contriever-msmarco", # Model used to create embeddings
dimensions=768, # Will be auto-detected from embeddings
**builder_kwargs,
)
# Build index from precomputed embeddings
builder.build_index_from_embeddings(output_path, embeddings_file)
print(f"Index saved to: {output_path}")
return output_path
def main():
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
parser.add_argument(
"index_path",
type=str,
nargs="?",
help="Path to the LEANN index to evaluate or build (optional).",
)
parser.add_argument(
"--mode",
choices=["evaluate", "build"],
default="evaluate",
help="Mode: 'evaluate' existing index or 'build' from embeddings",
)
parser.add_argument(
"--embeddings-file",
type=str,
help="Path to embeddings pickle file (optional for build mode)",
)
parser.add_argument(
"--backend",
choices=["hnsw", "diskann"],
default="hnsw",
help="Backend to use for building index (default: hnsw)",
)
parser.add_argument(
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
)
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
parser.add_argument(
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
)
parser.add_argument(
"--batch-size",
type=int,
default=0,
help="Batch size for HNSW batched search (0 disables batching)",
)
parser.add_argument(
"--llm-type",
type=str,
choices=["ollama", "hf", "openai", "gemini", "simulated"],
default="ollama",
help="LLM backend type to optionally query during evaluation (default: ollama)",
)
parser.add_argument(
"--llm-model",
type=str,
default="qwen3:1.7b",
help="LLM model identifier for the chosen backend (default: qwen3:1.7b)",
)
args = parser.parse_args()
# --- Path Configuration ---
# Assumes a project structure where the script is in 'benchmarks/'
# and evaluation data is in 'benchmarks/data/'.
script_dir = Path(__file__).resolve().parent
data_root = script_dir / "data"
# Download data based on mode
if args.mode == "build":
# For building mode, we need embeddings
download_data_if_needed(data_root, download_embeddings=False) # Basic data first
# Auto-detect dataset type and download embeddings
if args.embeddings_file:
embeddings_file = args.embeddings_file
# Try to detect dataset type from embeddings file path
if "rpj_wiki" in str(embeddings_file):
dataset_type = "rpj_wiki"
elif "dpr" in str(embeddings_file):
dataset_type = "dpr"
else:
dataset_type = "dpr" # Default
else:
# Auto-detect from index path if provided, otherwise default to DPR
if args.index_path:
index_path_str = str(args.index_path)
if "rpj_wiki" in index_path_str:
dataset_type = "rpj_wiki"
elif "dpr" in index_path_str:
dataset_type = "dpr"
else:
dataset_type = "dpr" # Default to DPR
else:
dataset_type = "dpr" # Default to DPR
embeddings_file = download_embeddings_if_needed(data_root, dataset_type)
# Auto-generate index path if not provided
if not args.index_path:
indices_dir = data_root / "indices" / dataset_type
indices_dir.mkdir(parents=True, exist_ok=True)
args.index_path = str(indices_dir / f"{dataset_type}_from_embeddings")
print(f"Auto-generated index path: {args.index_path}")
print(f"Building index from embeddings: {embeddings_file}")
built_index_path = build_index_from_embeddings(
embeddings_file, args.index_path, args.backend
)
print(f"Index built successfully: {built_index_path}")
# Ask if user wants to run evaluation
eval_response = input("Run evaluation on the built index? (y/n): ").strip().lower()
if eval_response != "y":
print("Index building complete. Exiting.")
return
else:
# For evaluation mode, don't need embeddings
download_data_if_needed(data_root, download_embeddings=False)
# Auto-detect index path if not provided
if not args.index_path:
# Default to using downloaded indices
indices_dir = data_root / "indices"
# Try common datasets in order of preference
for dataset in ["dpr", "rpj_wiki"]:
dataset_dir = indices_dir / dataset
if dataset_dir.exists():
# Look for index files
index_files = list(dataset_dir.glob("*.index")) + list(
dataset_dir.glob("*_disk.index")
)
if index_files:
args.index_path = str(
index_files[0].with_suffix("")
) # Remove .index extension
print(f"Using index: {args.index_path}")
break
if not args.index_path:
print("No indices found. The data download should have included pre-built indices.")
print(
"Please check the benchmarks/data/indices/ directory or provide --index-path manually."
)
sys.exit(1)
# Detect dataset type from index path to select the correct ground truth
index_path_str = str(args.index_path)
if "rpj_wiki" in index_path_str:
dataset_type = "rpj_wiki"
elif "dpr" in index_path_str:
dataset_type = "dpr"
else:
# Fallback: try to infer from the index directory name
dataset_type = Path(args.index_path).name
print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")
queries_file = data_root / "queries" / "nq_open.jsonl"
golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
print(f"INFO: Detected dataset type: {dataset_type}")
print(f"INFO: Using queries file: {queries_file}")
print(f"INFO: Using ground truth file: {golden_results_file}")
try:
searcher = LeannSearcher(args.index_path)
queries = load_queries(queries_file)
with open(golden_results_file) as f:
golden_results_data = json.load(f)
num_eval_queries = min(args.num_queries, len(queries))
queries = queries[:num_eval_queries]
print(f"\nRunning evaluation on {num_eval_queries} queries...")
recall_scores = []
search_times = []
for i in range(num_eval_queries):
start_time = time.time()
new_results = searcher.search(
queries[i],
top_k=args.top_k,
complexity=args.ef_search,
batch_size=args.batch_size,
)
search_times.append(time.time() - start_time)
# Optional: also call the LLM with configurable backend/model (does not affect recall)
llm_config = {"type": args.llm_type, "model": args.llm_model}
chat = LeannChat(args.index_path, llm_config=llm_config, searcher=searcher)
answer = chat.ask(
queries[i],
top_k=args.top_k,
complexity=args.ef_search,
batch_size=args.batch_size,
)
print(f"Answer: {answer}")
# Correct Recall Calculation: Based on TEXT content
new_texts = {result.text for result in new_results}
# Get golden texts directly from the searcher's passage manager
golden_ids = golden_results_data["indices"][i][: args.top_k]
golden_texts = get_golden_texts(searcher, golden_ids)
overlap = len(new_texts & golden_texts)
recall = overlap / len(golden_texts) if golden_texts else 0
recall_scores.append(recall)
print("\n--- EVALUATION RESULTS ---")
print(f"Query: {queries[i]}")
print(f"New Results: {new_texts}")
print(f"Golden Results: {golden_texts}")
print(f"Overlap: {overlap}")
print(f"Recall: {recall}")
print(f"Search Time: {search_times[-1]:.4f}s")
print("--------------------------------")
avg_recall = np.mean(recall_scores) if recall_scores else 0
avg_time = np.mean(search_times) if search_times else 0
print("\n🎉 --- Evaluation Complete ---")
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
print(f"Avg. Search Time: {avg_time:.4f}s")
except Exception as e:
print(f"\n❌ An error occurred during evaluation: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,317 @@
import time
from dataclasses import dataclass
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from transformers import AutoModel
# Add MLX imports
try:
import mlx.core as mx
from mlx_lm.utils import load
MLX_AVAILABLE = True
except ImportError:
print("MLX not available. Install with: uv pip install mlx mlx-lm")
MLX_AVAILABLE = False
@dataclass
class BenchmarkConfig:
model_path: str = "facebook/contriever-msmarco"
batch_sizes: list[int] = None
seq_length: int = 256
num_runs: int = 5
use_fp16: bool = True
use_int4: bool = False
use_int8: bool = False
use_cuda_graphs: bool = False
use_flash_attention: bool = False
use_linear8bitlt: bool = False
use_mlx: bool = False # New flag for MLX testing
def __post_init__(self):
if self.batch_sizes is None:
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
class MLXBenchmark:
"""MLX-specific benchmark for embedding models"""
def __init__(self, config: BenchmarkConfig):
self.config = config
self.model, self.tokenizer = self._load_model()
def _load_model(self):
"""Load MLX model and tokenizer following the API pattern"""
print(f"Loading MLX model from {self.config.model_path}...")
try:
model, tokenizer = load(self.config.model_path)
print("MLX model loaded successfully")
return model, tokenizer
except Exception as e:
print(f"Error loading MLX model: {e}")
raise
def _create_random_batch(self, batch_size: int):
"""Create random input batches for MLX testing - same as PyTorch"""
return torch.randint(0, 1000, (batch_size, self.config.seq_length), dtype=torch.long)
def _run_inference(self, input_ids: torch.Tensor) -> float:
"""Run MLX inference with same input as PyTorch"""
start_time = time.time()
try:
# Convert PyTorch tensor to MLX array
input_ids_mlx = mx.array(input_ids.numpy())
# Get embeddings
embeddings = self.model(input_ids_mlx)
# Mean pooling (following the API pattern)
pooled = embeddings.mean(axis=1)
# Convert to numpy (following the API pattern)
pooled_numpy = np.array(pooled.tolist(), dtype=np.float32)
# Force computation
_ = pooled_numpy.shape
except Exception as e:
print(f"MLX inference error: {e}")
return float("inf")
end_time = time.time()
return end_time - start_time
def run(self) -> dict[int, dict[str, float]]:
"""Run the MLX benchmark across all batch sizes"""
results = {}
print(f"Starting MLX benchmark with model: {self.config.model_path}")
print(f"Testing batch sizes: {self.config.batch_sizes}")
for batch_size in self.config.batch_sizes:
print(f"\n=== Testing MLX batch size: {batch_size} ===")
times = []
# Create input batch (same as PyTorch)
input_ids = self._create_random_batch(batch_size)
# Warm up
print("Warming up...")
for _ in range(3):
try:
self._run_inference(input_ids[:2]) # Warm up with smaller batch
except Exception as e:
print(f"Warmup error: {e}")
break
# Run benchmark
for _i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
try:
elapsed_time = self._run_inference(input_ids)
if elapsed_time != float("inf"):
times.append(elapsed_time)
except Exception as e:
print(f"Error during MLX inference: {e}")
break
if not times:
print(f"Skipping batch size {batch_size} due to errors")
continue
# Calculate statistics
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
"throughput": throughput,
"min_time": np.min(times),
"max_time": np.max(times),
}
print(f"MLX Results for batch size {batch_size}:")
print(f" Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f" Min Time: {np.min(times):.4f}s")
print(f" Max Time: {np.max(times):.4f}s")
print(f" Throughput: {throughput:.2f} sequences/second")
return results
class Benchmark:
def __init__(self, config: BenchmarkConfig):
self.config = config
self.device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
self.model = self._load_model()
def _load_model(self) -> nn.Module:
print(f"Loading model from {self.config.model_path}...")
model = AutoModel.from_pretrained(self.config.model_path)
if self.config.use_fp16:
model = model.half()
model = torch.compile(model)
model = model.to(self.device)
model.eval()
return model
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
return torch.randint(
0,
1000,
(batch_size, self.config.seq_length),
device=self.device,
dtype=torch.long,
)
def _run_inference(self, input_ids: torch.Tensor) -> float:
attention_mask = torch.ones_like(input_ids)
# print shape of input_ids and attention_mask
print(f"input_ids shape: {input_ids.shape}")
print(f"attention_mask shape: {attention_mask.shape}")
start_time = time.time()
with torch.no_grad():
self.model(input_ids=input_ids, attention_mask=attention_mask)
if torch.cuda.is_available():
torch.cuda.synchronize()
if torch.backends.mps.is_available():
torch.mps.synchronize()
end_time = time.time()
return end_time - start_time
def run(self) -> dict[int, dict[str, float]]:
results = {}
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
for batch_size in self.config.batch_sizes:
print(f"\nTesting batch size: {batch_size}")
times = []
input_ids = self._create_random_batch(batch_size)
for _i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
try:
elapsed_time = self._run_inference(input_ids)
times.append(elapsed_time)
except Exception as e:
print(f"Error during inference: {e}")
break
if not times:
continue
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
"throughput": throughput,
}
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f"Throughput: {throughput:.2f} sequences/second")
if torch.cuda.is_available():
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
else:
peak_memory_gb = 0.0
for batch_size in results:
results[batch_size]["peak_memory_gb"] = peak_memory_gb
return results
def run_benchmark():
"""Main function to run the benchmark with optimized parameters."""
config = BenchmarkConfig()
try:
benchmark = Benchmark(config)
results = benchmark.run()
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
return {
"max_throughput": max_throughput,
"avg_throughput": avg_throughput,
"results": results,
}
except Exception as e:
print(f"Benchmark failed: {e}")
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
def run_mlx_benchmark():
"""Run MLX-specific benchmark"""
if not MLX_AVAILABLE:
print("MLX not available, skipping MLX benchmark")
return {
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": "MLX not available",
}
config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)
try:
benchmark = MLXBenchmark(config)
results = benchmark.run()
if not results:
return {
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": "No valid results",
}
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
return {
"max_throughput": max_throughput,
"avg_throughput": avg_throughput,
"results": results,
}
except Exception as e:
print(f"MLX benchmark failed: {e}")
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
if __name__ == "__main__":
print("=== PyTorch Benchmark ===")
pytorch_result = run_benchmark()
print(f"PyTorch Max throughput: {pytorch_result['max_throughput']:.2f} sequences/second")
print(f"PyTorch Average throughput: {pytorch_result['avg_throughput']:.2f} sequences/second")
print("\n=== MLX Benchmark ===")
mlx_result = run_mlx_benchmark()
print(f"MLX Max throughput: {mlx_result['max_throughput']:.2f} sequences/second")
print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")
# Compare results
if pytorch_result["max_throughput"] > 0 and mlx_result["max_throughput"] > 0:
speedup = mlx_result["max_throughput"] / pytorch_result["max_throughput"]
print("\n=== Comparison ===")
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")

BIN
data/2501.14312v1 (1).pdf Normal file
View File

Binary file not shown.

14905
data/PrideandPrejudice.txt Normal file
View File

File diff suppressed because it is too large Load Diff

82
data/huawei_pangu.md Normal file
View File

@@ -0,0 +1,82 @@
# 盘古之殇:华为诺亚盘古大模型研发历程的心酸与黑暗
各位好,
我是一名盘古大模型团队,华为诺亚方舟实验室的员工。
首先为自证身份,列举一些细节:
1. 现诺亚主任,前算法应用部部长,后改名为小模型实验室的主任王云鹤。前诺亚主任:姚骏(大家称姚老师)。几个实验室主任:唐睿明(明哥,明队,已离职),尚利峰,张维(维哥),郝建业(郝老师),刘武龙(称呼为武龙所)等。其他骨干成员和专家陆续有很多人离职。
2. 我们隶属于“四野”这个组织。四野下属有许多纵队,基础语言大模型是四纵。王云鹤的小模型是十六纵队。我们参加过苏州的集结,有各种月份的时间节点。在苏州攻关会颁发任务令,需要在节点前达成目标。苏州集结会把各地的人员都集中在苏州研究所,平常住宾馆,比如在甪直的酒店,与家人孩子天各一方。
3. 在苏州集结的时候周六默认上班,非常辛苦,不过周六有下午茶,有一次还有小龙虾。在苏州研究所的工位搬迁过一次,从一栋楼换到了另一栋。苏州研究所楼栋都是欧式装修,门口有大坡,里面景色很不错。去苏州集结一般至少要去一周,甚至更久,多的人甚至一两个月都回不了家。
4. 诺亚曾经传说是研究型的但是来了之后因为在四野做大模型项目项目成员完全变成了交付型的且充满了例会评审汇报。很多时候做实验都要申请。团队需要对接终端小艺华为云ICT等诸多业务线交付压力不小。
5. 诺亚研发的盘古模型早期内部代号叫做“盘古智子”一开始只有内部需要申请试用的网页版到后续迫于压力在welink上接入和公测开放。
这些天发生关于质疑盘古大模型抄袭千问的事情闹的沸沸扬扬。作为一个盘古团队的成员,我最近夜夜辗转反侧,难以入眠。盘古的品牌受到如此大的影响,一方面,我自私的为我的职业发展担忧,也为自己过去的努力工作感到不值。另一方面,由于有人开始揭露这些事情我内心又感到大快人心。在多少个日日夜夜,我们对内部某些人一次次靠着造假而又获得了无数利益的行为咬牙切齿而又无能为力。这种压抑和羞辱也逐渐消磨了我对华为的感情,让我在这里的时日逐渐浑浑噩噩,迷茫无措,时常怀疑自己的人生和自我价值。
我承认我是一个懦弱的人,作为一个小小的打工人,我不仅不敢和王云鹤等内部手眼通天的人做对,更不敢和华为这样的庞然大物做对。我很怕失去我的工作,毕竟我也有家人和孩子,所以我打心眼里很佩服揭露者。但是,看到内部还在试图洗地掩盖事实,蒙蔽公众的时候,我实在不能容忍了。我也希望勇敢一次,顺从自己本心。就算自损八百,我也希望能伤敌一千。我决定把我在这里的所见所闻(部分来自于同事口述)公布出来,关于盘古大模型的“传奇故事”:
华为确实主要在昇腾卡上训练大模型小模型实验室有不少英伟达的卡他们之前也会用来训练后面转移到昇腾。曾经我被华为“打造世界第二选择”的决心而折服我本身也曾经对华为有深厚的感情。我们陪着昇腾一步步摸爬滚打从充满bug到现在能训出模型付出了巨大的心血和代价。
最初我们的算力非常有限在910A上训练模型。那会只支持fp16训练的稳定性远不如bf16。盘古的moe开始很早23年就主要是训练38Bmoe模型和后续的71B dense模型。71B的dense模型通过扩增变成了第一代的135Bdense模型后面主力模型也逐渐在910B上训练。
71B和135B模型都有一个巨大的硬伤就是tokenizer。当时使用的tokenizer编码效率极低每个单个的符号数字空格乃至汉字都会占用一个token。可想而知这会非常浪费算力且使得模型的效果很差。这时候小模型实验室正好有个自己训的词表。姚老师当时怀疑是不是模型的tokenizer不好虽然事后来看他的怀疑是无疑正确的于是就决定让71B和135B换tokenizer因为小模型实验室曾经尝试过。团队缝合了两个tokenizer开始了tokenizer的更换。71B模型的更换失败了而135B因为采用了更精细的embedding初始化策略续训了至少1T的数据后词表总算更换成功但可想而知效果并不会变好。
于此同期阿里和智谱等国内其他公司在GPU上训练且已经摸索出了正确的方法盘古和竞品的差距越来越大。内部一个230B从头训练的dense模型又因为各种原因训练失败导致项目的状况几乎陷入绝境。面临几个节点的压力以及内部对盘古的强烈质疑时团队的士气低迷到了极点。团队在算力极其有限的时候做出了很多努力和挣扎。比如团队偶然发现当时的38B moe并没有预期moe的效果。于是去掉了moe参数还原为了13B的dense模型。由于38B的moe源自很早的pangu alpha 13B架构相对落后团队进行了一系列的操作比如切换绝对位置编码到rope去掉bias切换为rmsnorm。同时鉴于tokenizer的一些失败和换词表的经验这个模型的词表也更换为了王云鹤的小模型实验室7B模型所使用的词表。后面这个13B模型进行了扩增续训变成了第二代38B dense模型在几个月内这个模型都是主要的盘古中档位模型曾经具有一定的竞争力。但是由于更大的135B模型架构落后且更换词表模型损伤巨大后续分析发现当时更换的缝合词表有更严重的bug续训后也与千问等当时国内领先模型存在很大差距。这时由于内部的质疑声和领导的压力也越来越大。团队的状态几乎陷入了绝境。
在这种情况下王云鹤和他的小模型实验室出手了。他们声称是从旧的135B参数继承改造而来通过训练短短的几百B数据各项指标平均提升了十个点左右。实际上这就是他们套壳应用到大模型的第一次杰作。华为的外行领导内行使得领导完全对于这种扯淡的事情没有概念他们只会觉得肯定是有什么算法创新。经过内部的分析他们实际上是使用Qwen 1.5 110B续训而来通过加层扩增ffn维度添加盘古pi论文的一些机制得来凑够了大概135B的参数。实际上旧的135B有107层而这个模型只有82层各种配置也都不一样。新的来路不明的135B训练完很多参数的分布也和Qwen 110B几乎一模一样。连模型代码的类名当时都是Qwen甚至懒得改名。后续这个模型就是所谓的135B V2。而这个模型当时也提供给了很多下游甚至包括外部客户。
这件事对于我们这些认真诚实做事的同事们带来了巨大的冲击内部很多人其实都知道这件事甚至包括终端和华为云。我们都戏称以后别叫盘古模型了叫千古吧。当时团队成员就想向bcg举报了毕竟这已经是重大的业务造假了。但是后面据说被领导拦了下来因为更高级别的领导比如姚老师以及可能熊总和查老其实后面也知道了但是并不管因为通过套壳拿出好的结果对他们也是有利的。这件事使得当时团队几位最强的同事开始心灰意冷离职跑路也逐渐成为挂在嘴边的事。
此时盘古似乎迎来了转机。由于前面所述的这些盘古模型基本都是续训和改造而来当时诺亚完全没有掌握从头训练的技术何况还是在昇腾的NPU上进行训练。在当时团队的核心成员的极力争取下盘古开始了第三代模型的训练付出了巨大的努力后在数据架构和训练算法方面都与业界逐渐接轨而这其中的艰辛和小模型实验室的人一点关系都没有。
一开始团队成员毫无信心只从一个13B的模型开始训练但是后面发现效果还不错于是这个模型后续再次进行了一次参数扩增变成了第三代的38B代号38B V3。想必很多产品线的兄弟都对这个模型很熟悉。当时这个模型的tokenizer是基于llama的词表进行扩展的也是业界常见的做法。而当时王云鹤的实验室做出来了另一个词表也就是后续pangu系列的词表。当时两个词表还被迫进行了一次赛马最终没有明显的好坏结论。于是领导当即决定应该统一词表使用王云鹤他们的。于是在后续从头训练的135B V3也就是对外的Pangu Ultra便是采用了这个tokenizer。这也解释了很多使用我们模型的兄弟的疑惑为什么当时同为V3代的两个不同档位的模型会使用不同的tokenizer。
我们打心眼里觉得135B V3是我们四纵团队当时的骄傲。这是第一个真正意义上的华为全栈自研正经从头训练的千亿级别的模型且效果与24年同期竞品可比的。写到这里我已经热泪盈眶太不容易了。当时为了稳定训练团队做了大量实验对比并且多次在模型梯度出现异常的时候进行及时回退重启。这个模型真正做到了后面技术报告所说的训练全程没有一个loss spike。我们克服了不知道多少困难我们做到了我们愿用生命和荣誉保证这个模型训练的真实性。多少个凌晨我们为了它的训练而不眠。在被内部心声骂的一文不值的时候我们有多么不甘有多少的委屈我们挺住了。
我们这帮人是真的在为打磨国产算力底座燃烧自己的青春啊……客居他乡,我们放弃了家庭,放弃了假期,放弃了健康,放弃了娱乐,抛头颅洒热血,其中的艰辛与困苦,寥寥数笔不足以概括其万一。在各种动员大会上,当时口号中喊出的盘古必胜,华为必胜,我们心里是真的深深被感动。
然而我们的所有辛苦的成果经常被小模型实验室轻飘飘的拿走了。数据直接要走。代码直接要走还要求我们配合适配到能一键运行。我们当时戏称小模型实验室为点鼠标实验室。我们付出辛苦他们取得荣耀。果然应了那句话你在负重前行是因为有人替你岁月静好。在这种情况下越来越多的战友再也坚持不下去了选择了离开。看到身边那些优秀的同事一个个离职我的内心又感叹又难过。在这种作战一样的环境下我们比起同事来说更像是战友。他们在技术上也有无数值得我学习的地方堪称良师。看到他们去了诸如字节SeedDeepseek月之暗面腾讯和快手等等很多出色的团队我打心眼里为他们高兴和祝福脱离了这个辛苦却肮脏的地方。我至今还对一位离职同事的话记忆犹新ta说“来这里是我技术生涯中的耻辱在这里再呆每一天都是浪费生命”。话虽难听却让我无言以对。我担心我自己技术方面的积累不足以及没法适应互联网公司高淘汰的环境让我多次想离职的心始终没有迈出这一步。
盘古除了dense模型后续也启动了moe的探索。一开始训练的是一个224B的moe模型。而与之平行的小模型实验室也开启了第二次主要的套壳行动次要的插曲可能还包括一些别的模型比如math模型即这次流传甚广的pangu pro moe 72B。这个模型内部自称是从小模型实验室的7B扩增上来的就算如此这也与技术报告不符何况是套壳qwen 2.5的14b续训。还记得他们训了没几天内部的评测就立刻追上了当时的38B V3。AI系统实验室很多兄弟因为需要适配模型都知道他们的套壳行动只是迫于各种原因无法伸张正义。实际上对于后续训了很久很久的这个模型Honestagi能够分析出这个量级的相似性我已经很诧异了因为这个模型为了续训洗参数所付出的算力甚至早就足够从头训一个同档位的模型了。听同事说他们为了洗掉千问的水印采取了不少办法甚至包括故意训了脏数据。这也为学术界研究模型血缘提供了一个前所未有的特殊模范吧。以后新的血缘方法提出可以拿出来溜溜。
24年底和25年初在Deepseek v3和r1发布之后由于其惊艳的技术水平团队受到了巨大的冲击也受到了更大的质疑。于是为了紧跟潮流盘古模仿Deepseek的模型尺寸开启了718B moe的训练。这个时候小模型实验室再次出手了。他们选择了套壳Deepseekv3续训。他们通过冻住Deepseek加载的参数进行训练。连任务加载ckpt的目录都是deepseekv3改都不改何其嚣张与之相反一些有真正技术信仰的同事在从头训练另一个718B的moe。但其中出现了各种各样的问题。但是很显然这个模型怎么可能比直接套壳的好呢如果不是团队leader坚持早就被叫停了。
华为的流程管理之繁重,严重拖累了大模型的研发节奏,例如版本管理,模型血缘,各种流程化,各种可追溯。讽刺的是,小模型实验室的模型似乎从来不受这些流程的约束,想套壳就套壳,想续训就续训,算力源源不断的伸手拿走。这种强烈到近乎魔幻的对比,说明了当前流程管理的情况:只许州官放火,不许百姓点灯。何其可笑?何其可悲?何其可恶?何其可耻!
HonestAGI的事情出来后内部让大家不停的研讨分析如何公关和“回应”。诚然这个原文的分析也许不够有力给了王云鹤与小模型实验室他们狡辩和颠倒黑白的机会。为此这两天我内心感到作呕时时怀疑自己的人生意义以及苍天无眼。我不奉陪了我要离职了同时我也在申请从盘古部分技术报告的作者名单中移除。曾经在这些技术报告上署名是我一生都无法抹除的污点。当时我没想到他们竟然猖狂到敢开源。我没想到他们敢如此愚弄世人大肆宣发。当时我也许是存了侥幸心理没有拒绝署名。我相信很多扎实做事的战友也只是被迫上了贼船或者不知情。但这件事已经无法挽回我希望我的余生能够坚持扎实做真正有意义的事为我当时的软弱和不坚定赎罪。
深夜写到这里,我已经泪流满面,泣不成声。还记得一些出色的同事离职时,我苦笑问他们要不要发个长长的心声惯例帖,揭露一下现状。对方说:不了,浪费时间,而且我也怕揭露出来你们过的更糟。我当时一下黯然神伤,因为曾经共同为了理想奋斗过的战友已经彻底对华为彻底灰心了。当时大家调侃,我们用着当年共产党的小米加步枪,组织却有着堪比当年国民党的作风。
曾几何时,我为我们用着小米加步枪打败洋枪洋炮而自豪。
现在,我累了,我想投降。
其实时至今日我还是真心希望华为能认真吸取教训能做好盘古把盘古做到世界一流把昇腾变成英伟达的水平。内部的劣币驱逐良币使得诺亚乃至华为在短时间内急剧流失了大量出色的大模型人才。相信他们也正在如Deepseek等各个团队闪耀着施展着他们的抱负才华为中美在AI的激烈竞赛中奉献力量。我时常感叹华为不是没有人才而是根本不知道怎么留住人才。如果给这些人合适的环境合适的资源更少的枷锁更少的政治斗争盘古何愁不成
最后:我以生命,人格和荣誉发誓,我写的以上所有内容均为真实(至少在我有限的认知范围内)。我没有那么高的技术水平以及机会去做详尽扎实的分析,也不敢直接用内部记录举证,怕因为信息安全抓到。但是我相信我很多曾经的战友,会为我作证。在华为内部的兄弟,包括我们曾经服务过的产品线兄弟们,相信本文的无数细节能和你们的印象对照,印证我的说法。你们可能也曾经被蒙骗,但这些残酷的真相不会被尘封。我们奋战过的痕迹,也不应该被扭曲和埋葬。
写了这么多,某些人肯定想把我找出来,抹杀掉。公司搞不好也想让我噤声乃至追责。如果真的这样,我,乃至我的家人的人身乃至生命安全可能都会受到威胁。为了自我保护,我近期每天会跟大家报平安。
如果我消失了就当是我为了真理和理想为了华为乃至中国能够更好地发展算力和AI而牺牲了吧我愿埋葬于那片曾经奋斗过的地方。
诺亚,再见
2025年7月6日凌晨 写于深圳
---
各位好,
感谢大家的关心与祝福。我目前暂时安全,但公司应该在进行排查与某些名单收集,后续情况未知。
我补充一些细节,以免某些人继续颠倒黑白。
关于135B V2小模型实验室在迅速地完成套壳并拿完所有套壳带来的好处后比如任务令表彰和及时激励因为不想继续支撑下游应用和模型迭代又把这个烫手山芋甩给了四纵。确实技高一筹直接把四纵的兄弟们拉下水。同事提供过去一个老旧的模型最终拿回了一个当时一个魔改的先进的千问。做大模型的人自己做的模型就像自己孩子一样熟悉不要把别人都当傻子。就像自家儿子出门一趟回来个别人家孩子。
盘古report的署名是不符合学术规范的。例如135B V3有不少有技术贡献的人因为作者名额数量限制劳动成果没有得到应有的回报团队内曾经有不小的意见。这个模型当时是大家智慧和汗水的结晶甚至是团队当时的精神支柱支撑着不少兄弟们继续留在诺亚。所谓的名额限制以及挂名了一些毫无技术贡献的人如一些小模型实验室的人让兄弟们何其心寒。
---
暂时平安。另外,支持我勇于说出真相的战友们 https://github.com/HW-whistleblower/True-Story-of-Pangu/issues/317

View File

@@ -1,362 +1,116 @@
{ {
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "markdown",
"execution_count": 1,
"metadata": {}, "metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Initializing leann-backend-diskann...\n",
"INFO: Registering backend 'diskann'\n",
"INFO: DiskANN backend loaded successfully\n",
"INFO: LeannBuilder initialized with 'diskann' backend.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ubuntu/LEANN_clean/leann/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO: Computing embeddings for 6 chunks using 'sentence-transformers/all-mpnet-base-v2'...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Batches: 100%|██████████| 1/1 [00:00<00:00, 2.91it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO: Building DiskANN index for 6 vectors with metric Metric.INNER_PRODUCT...\n",
"Using Inner Product search, so need to pre-process base data into temp file. Please ensure there is additional (n*(d+1)*4) bytes for storing pre-processed base vectors, apart from the interim indices created by DiskANN and the final index.\n",
"Pre-processing base file by adding extra coordinate\n",
"✅ DiskANN index built successfully at 'knowledge'\n",
"Writing bin: knowledge_disk.index_max_base_norm.bin\n",
"bin: #pts = 1, #dims = 1, size = 12B\n",
"Finished writing bin.\n",
"Time for preprocessing data for inner product: 0.000172 seconds\n",
"Reading max_norm_of_base from knowledge_disk.index_max_base_norm.bin\n",
"Reading bin file knowledge_disk.index_max_base_norm.bin ...\n",
"Opening bin file knowledge_disk.index_max_base_norm.bin... \n",
"Metadata: #pts = 1, #dims = 1...\n",
"done.\n",
"max_norm_of_base: 1\n",
"! Using prepped_base file at knowledge_prepped_base.bin\n",
"Starting index build: R=32 L=64 Query RAM budget: 4.02653e+09 Indexing ram budget: 8 T: 8\n",
"getting bin metadata\n",
"Time for getting bin metadata: 0.000019 seconds\n",
"Compressing 769-dimensional data into 512 bytes per vector.\n",
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
"Training data with 6 samples loaded.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 256, #dims = 769...\n",
"done.\n",
"PQ pivot file exists. Not generating again\n",
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 4, #dims = 1...\n",
"done.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 256, #dims = 769...\n",
"done.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 769, #dims = 1...\n",
"done.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 513, #dims = 1...\n",
"done.\n",
"Loaded PQ pivot information\n",
"Processing points [0, 6)...done.\n",
"Time for generating quantized data: 0.055587 seconds\n",
"Full index fits in RAM budget, should consume at most 2.03973e-05GiBs, so building in one shot\n",
"L2: Using AVX2 distance computation DistanceL2Float\n",
"Passed, empty search_params while creating index config\n",
"Using only first 6 from file.. \n",
"Starting index build with 6 points... \n",
"0% of index build completed.Starting final cleanup..done. Link time: 0.00011s\n",
"Index built with degree: max:5 avg:5 min:5 count(deg<2):0\n",
"Not saving tags as they are not enabled.\n",
"Time taken for save: 0.000148s.\n",
"Time for building merged vamana index: 0.000836 seconds\n",
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
"Vamana index file size=168\n",
"Opened: knowledge_disk.index, cache_size: 67108864\n",
"medoid: 0B\n",
"max_node_len: 3100B\n",
"nnodes_per_sector: 1B\n",
"# sectors: 6\n",
"Sector #0written\n",
"Finished writing 28672B\n",
"Writing bin: knowledge_disk.index\n",
"bin: #pts = 9, #dims = 1, size = 80B\n",
"Finished writing bin.\n",
"Output disk index file written to knowledge_disk.index\n",
"Finished writing 28672B\n",
"Time for generating disk layout: 0.040268 seconds\n",
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
"Loading base knowledge_prepped_base.bin. #points: 6. #dim: 769.\n",
"Wrote 1 points to sample file: knowledge_sample_data.bin\n",
"Indexing time: 0.0970594\n",
"INFO: Leann metadata saved to knowledge.leann.meta.json\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Opened file : knowledge_disk.index\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ DiskANN index loaded successfully.\n",
"INFO: LeannSearcher initialized with 'diskann' backend using index 'knowledge.leann'.\n",
"Since data is floating point, we assume that it has been appropriately pre-processed (normalization for cosine, and convert-to-l2 by adding extra dimension for MIPS). So we shall invoke an l2 distance function.\n",
"L2: Using AVX2 distance computation DistanceL2Float\n",
"L2: Using AVX2 distance computation DistanceL2Float\n",
"Before index load\n",
"Reading bin file knowledge_pq_compressed.bin ...\n",
"Opening bin file knowledge_pq_compressed.bin... \n",
"Metadata: #pts = 6, #dims = 512...\n",
"done.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 4, #dims = 1...\n",
"done.\n",
"Offsets: 4096 791560 794644 796704\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 256, #dims = 769...\n",
"done.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 769, #dims = 1...\n",
"done.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 513, #dims = 1...\n",
"done.\n",
"Loaded PQ Pivots: #ctrs: 256, #dims: 769, #chunks: 512\n",
"Loaded PQ centroids and in-memory compressed vectors. #points: 6 #dim: 769 #aligned_dim: 776 #chunks: 512\n",
"Loading index metadata from knowledge_disk.index\n",
"Disk-Index File Meta-data: # nodes per sector: 1, max node len (bytes): 3100, max node degree: 5\n",
"Disk-Index Meta: nodes per sector: 1, max node len: 3100, max node degree: 5\n",
"Setting up thread-specific contexts for nthreads: 8\n",
"allocating ctx: 0x7a33f7204000 to thread-id:134367072315200\n",
"allocating ctx: 0x7a33f6805000 to thread-id:134355206802368\n",
"allocating ctx: 0x7a33f5e72000 to thread-id:134355217288000\n",
"allocating ctx: 0x7a33f5e61000 to thread-id:134355227773632\n",
"allocating ctx: 0x7a33f5e50000 to thread-id:134355196316736\n",
"allocating ctx: 0x7a33f5e3f000 to thread-id:134355164859840\n",
"allocating ctx: 0x7a33f5e2e000 to thread-id:134355175345472\n",
"allocating ctx: 0x7a33f5e1d000 to thread-id:134355185831104\n",
"Loading centroid data from medoids vector data of 1 medoid(s)\n",
"Reading bin file knowledge_disk.index_max_base_norm.bin ...\n",
"Opening bin file knowledge_disk.index_max_base_norm.bin... \n",
"Metadata: #pts = 1, #dims = 1...\n",
"done.\n",
"Setting re-scaling factor of base vectors to 1\n",
"load_from_separate_paths done.\n",
"Reading (with alignment) bin file knowledge_sample_data.bin ...Metadata: #pts = 1, #dims = 769, aligned_dim = 776... allocating aligned memory of 3104 bytes... done. Copying data to mem_aligned buffer... done.\n",
"reserve ratio: 1\n",
"Graph traversal completed, hops: 3\n",
"Loading the cache list into memory....done.\n",
"After index load\n",
"INFO: Computing embeddings for 1 chunks using 'sentence-transformers/all-mpnet-base-v2'...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Batches: 100%|██████████| 1/1 [00:00<00:00, 60.54it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running\n",
"INFO: Starting session-level embedding server as a background process...\n",
"INFO: Running command from project root: /home/ubuntu/LEANN_clean/leann\n",
"INFO: Server process started with PID: 424761\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ Embedding server is up and ready for this session.\n",
"[EmbeddingServer LOG]: Initializing leann-backend-diskann...\n",
"[EmbeddingServer LOG]: WARNING: Could not import DiskANN backend: cannot import name '_diskannpy' from partially initialized module 'packages.leann-backend-diskann.leann_backend_diskann' (most likely due to a circular import) (/home/ubuntu/LEANN_clean/leann/packages/leann-backend-diskann/leann_backend_diskann/__init__.py)\n",
"[EmbeddingServer LOG]: INFO: Initializing embedding server thread on port 5555\n",
"[EmbeddingServer LOG]: INFO: Using CUDA device\n",
"[EmbeddingServer LOG]: INFO: Loading model sentence-transformers/all-mpnet-base-v2\n",
"[EmbeddingServer LOG]: INFO: Using FP16 precision with model: sentence-transformers/all-mpnet-base-v2\n",
"[EmbeddingServer LOG]: INFO: Loaded 6 demo documents\n",
"[EmbeddingServer LOG]: INFO: ZMQ ROUTER server listening on port 5555\n",
"[EmbeddingServer LOG]: INFO: Embedding server ready to serve requests\n",
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 3 bytes\n",
"[EmbeddingServer LOG]: INFO: Request for 1 node embeddings: [0]\n",
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 0\n",
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000028 seconds\n",
"[EmbeddingServer LOG]: INFO: Total batch size: 1, max_batch_size: 128\n",
"[EmbeddingServer LOG]: INFO: Processing batch of size 1\n",
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.019294 seconds\n",
"[EmbeddingServer LOG]: Batch size: 1, Sequence length: 256\n",
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000210 seconds\n",
"[EmbeddingServer LOG]: Time taken for embedding (batch): 3.065444 seconds\n",
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.041810 seconds\n",
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000194 seconds\n",
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 3.128073 seconds\n",
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [1, 2, 3, 4, 5]\n",
"[EmbeddingServer LOG]: DEBUG: Node ID range: 1 to 5\n",
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000042 seconds\n",
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001791 seconds\n",
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000112 seconds\n",
"[EmbeddingServer LOG]: Time taken for embedding (batch): 3.674183 seconds\n",
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000372 seconds\n",
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000177 seconds\n",
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 3.677425 seconds\n",
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [3, 4, 2, 1, 0]\n",
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 4\n",
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000030 seconds\n",
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001550 seconds\n",
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000097 seconds\n",
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.009335 seconds\n",
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000154 seconds\n",
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000073 seconds\n",
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.011773 seconds\n",
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [0, 1, 2, 4, 5]\n",
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 5\n",
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000020 seconds\n",
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001041 seconds\n",
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000125 seconds\n",
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.008972 seconds\n",
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000151 seconds\n",
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000048 seconds\n",
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.010853 seconds\n",
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [3, 1, 0, 2, 5]\n",
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 5\n",
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000020 seconds\n",
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001350 seconds\n",
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000088 seconds\n",
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.008869 seconds\n",
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000146 seconds\n",
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000063 seconds\n",
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.011054 seconds\n",
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [0, 2, 3, 4, 5]\n",
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 5\n",
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000022 seconds\n",
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001195 seconds\n",
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000087 seconds\n",
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.008903 seconds\n",
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000145 seconds\n",
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000060 seconds\n",
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.010921 seconds\n",
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [1, 0, 3, 4, 5]\n",
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 5\n",
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000020 seconds\n",
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001188 seconds\n",
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000087 seconds\n",
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.008858 seconds\n",
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000153 seconds\n",
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000052 seconds\n",
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.010886 seconds\n",
"reserve ratio: Score: -0.481 - C++ is a powerful programming language1\n",
"Graph traversal completed, hops: 3\n",
"\n",
"Score: -1.049 - Java is a powerful programming language\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n"
]
}
],
"source": [ "source": [
"from leann.api import LeannBuilder, LeannSearcher\n", "# Quick Start \n",
"import leann_backend_diskann\n", "\n",
"# 1. Build index (no embeddings stored!)\n", "**Home GitHub Repository:** [LEANN on GitHub](https://github.com/yichuan-w/LEANN)\n",
"builder = LeannBuilder(backend_name=\"diskann\")\n", "\n",
"builder.add_text(\"Python is a powerful programming language\")\n", "**Important for Colab users:** Set your runtime type to T4 GPU for optimal performance. Go to Runtime → Change runtime type → Hardware accelerator → T4 GPU."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# install this if you are using colab\n",
"! uv pip install leann-core leann-backend-hnsw --no-deps\n",
"! uv pip install leann --no-deps\n",
"# For Colab environment, we need to set some environment variables\n",
"import os\n",
"\n",
"os.environ[\"LEANN_LOG_LEVEL\"] = \"INFO\" # Enable more detailed logging"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"INDEX_DIR = Path(\"./\").resolve()\n",
"INDEX_PATH = str(INDEX_DIR / \"demo.leann\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build the index"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from leann.api import LeannBuilder\n",
"\n",
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
"builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
"builder.add_text(\n",
" \"Python is a powerful programming language and it is good at machine learning tasks\"\n",
")\n",
"builder.add_text(\"Machine learning transforms industries\")\n", "builder.add_text(\"Machine learning transforms industries\")\n",
"builder.add_text(\"Neural networks process complex data\")\n", "builder.add_text(\"Neural networks process complex data\")\n",
"builder.add_text(\"Java is a powerful programming language\")\n", "builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
"builder.add_text(\"C++ is a powerful programming language\")\n", "builder.build_index(INDEX_PATH)"
"builder.add_text(\"C# is a powerful programming language\")\n", ]
"builder.build_index(\"knowledge.leann\")\n", },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Search with real-time embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from leann.api import LeannSearcher\n",
"\n", "\n",
"# 2. Search with real-time embeddings\n", "searcher = LeannSearcher(INDEX_PATH)\n",
"searcher = LeannSearcher(\"knowledge.leann\")\n", "results = searcher.search(\"programming languages\", top_k=2)\n",
"results = searcher.search(\"C++ programming languages\", top_k=2,recompute_beighbor_embeddings=True)\n", "results"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chat with LEANN using retrieved results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from leann.api import LeannChat\n",
"\n", "\n",
"for result in results:\n", "llm_config = {\n",
" print(f\"Score: {result['score']:.3f} - {result['text']}\")" " \"type\": \"hf\",\n",
" \"model\": \"Qwen/Qwen3-0.6B\",\n",
"}\n",
"\n",
"chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)\n",
"response = chat.ask(\n",
" \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
" top_k=2,\n",
" llm_kwargs={\"max_tokens\": 128},\n",
")\n",
"response"
] ]
} }
], ],
@@ -376,7 +130,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.11.11" "version": "3.11.12"
} }
}, },
"nbformat": 4, "nbformat": 4,

220
docs/CONTRIBUTING.md Normal file
View File

@@ -0,0 +1,220 @@
# 🤝 Contributing
We welcome contributions! Leann is built by the community, for the community.
## Ways to Contribute
- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results
## 🚀 Development Setup
### Prerequisites
1. **Install uv** (fast Python package installer):
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```
2. **Clone the repository**:
```bash
git clone https://github.com/LEANN-RAG/LEANN-RAG.git
cd LEANN-RAG
```
3. **Install system dependencies**:
**macOS:**
```bash
brew install llvm libomp boost protobuf zeromq pkgconf
```
**Ubuntu/Debian:**
```bash
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler \
libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
```
4. **Build from source**:
```bash
# macOS
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
# Ubuntu/Debian
uv sync
```
## 🔨 Pre-commit Hooks
We use pre-commit hooks to ensure code quality and consistency. This runs automatically before each commit.
### Setup Pre-commit
1. **Install pre-commit** (already included when you run `uv sync`):
```bash
uv pip install pre-commit
```
2. **Install the git hooks**:
```bash
pre-commit install
```
3. **Run pre-commit manually** (optional):
```bash
pre-commit run --all-files
```
### Pre-commit Checks
Our pre-commit configuration includes:
- **Trailing whitespace removal**
- **End-of-file fixing**
- **YAML validation**
- **Large file prevention**
- **Merge conflict detection**
- **Debug statement detection**
- **Code formatting with ruff**
- **Code linting with ruff**
## 🧪 Testing
### Running Tests
```bash
# Run all tests
uv run pytest
# Run specific test file
uv run pytest test/test_filename.py
# Run with coverage
uv run pytest --cov=leann
```
### Writing Tests
- Place tests in the `test/` directory
- Follow the naming convention `test_*.py`
- Use descriptive test names that explain what's being tested
- Include both positive and negative test cases
## 📝 Code Style
We use `ruff` for both linting and formatting to ensure consistent code style.
### Format Your Code
```bash
# Format all files
ruff format
# Check formatting without changing files
ruff format --check
```
### Lint Your Code
```bash
# Run linter with auto-fix
ruff check --fix
# Just check without fixing
ruff check
```
### Style Guidelines
- Follow PEP 8 conventions
- Use descriptive variable names
- Add type hints where appropriate
- Write docstrings for all public functions and classes
- Keep functions focused and single-purpose
## 🚦 CI/CD
Our CI pipeline runs automatically on all pull requests. It includes:
1. **Linting and Formatting**: Ensures code follows our style guidelines
2. **Multi-platform builds**: Tests on Ubuntu and macOS
3. **Python version matrix**: Tests on Python 3.9-3.13
4. **Wheel building**: Ensures packages can be built and distributed
### CI Commands
The CI uses the same commands as pre-commit to ensure consistency:
```bash
# Linting
ruff check .
# Format checking
ruff format --check .
```
Make sure your code passes these checks locally before pushing!
## 🔄 Pull Request Process
1. **Fork the repository** and create your branch from `main`:
```bash
git checkout -b feature/your-feature-name
```
2. **Make your changes**:
- Write clean, documented code
- Add tests for new functionality
- Update documentation as needed
3. **Run pre-commit checks**:
```bash
pre-commit run --all-files
```
4. **Test your changes**:
```bash
uv run pytest
```
5. **Commit with descriptive messages**:
```bash
git commit -m "feat: add new search algorithm"
```
Follow [Conventional Commits](https://www.conventionalcommits.org/):
- `feat:` for new features
- `fix:` for bug fixes
- `docs:` for documentation changes
- `test:` for test additions/changes
- `refactor:` for code refactoring
- `perf:` for performance improvements
6. **Push and create a pull request**:
- Provide a clear description of your changes
- Reference any related issues
- Include examples or screenshots if applicable
## 📚 Documentation
When adding new features or making significant changes:
1. Update relevant documentation in `/docs`
2. Add docstrings to new functions/classes
3. Update README.md if needed
4. Include usage examples
## 🤔 Getting Help
- **Discord**: Join our community for discussions
- **Issues**: Check existing issues or create a new one
- **Discussions**: For general questions and ideas
## 📄 License
By contributing, you agree that your contributions will be licensed under the same license as the project (MIT).
---
Thank you for contributing to LEANN! Every contribution, no matter how small, helps make the project better for everyone. 🌟

22
docs/RELEASE.md Normal file
View File

@@ -0,0 +1,22 @@
# Release Guide
## Setup (One-time)
Add `PYPI_API_TOKEN` to GitHub Secrets:
1. Get token: https://pypi.org/manage/account/token/
2. Add to secrets: Settings → Secrets → Actions → `PYPI_API_TOKEN`
## Release (One-click)
1. Go to: https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml
2. Click "Run workflow"
3. Enter version: `0.1.2`
4. Click green "Run workflow" button
That's it! The workflow will automatically:
- ✅ Update version in all packages
- ✅ Build all packages
- ✅ Publish to PyPI
- ✅ Create GitHub tag and release
Check progress: https://github.com/yichuan-w/LEANN/actions

View File

@@ -0,0 +1,123 @@
# Thinking Budget Feature Implementation
## Overview
This document describes the implementation of the **thinking budget** feature for LEANN, which allows users to control the computational effort for reasoning models like GPT-Oss:20b.
## Feature Description
The thinking budget feature provides three levels of computational effort for reasoning models:
- **`low`**: Fast responses, basic reasoning (default for simple queries)
- **`medium`**: Balanced speed and reasoning depth
- **`high`**: Maximum reasoning effort, best for complex analytical questions
## Implementation Details
### 1. Command Line Interface
Added `--thinking-budget` parameter to both CLI and RAG examples:
```bash
# LEANN CLI
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
# RAG Examples
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
python apps/document_rag.py --llm openai --llm-model o3 --thinking-budget medium
```
### 2. LLM Backend Support
#### Ollama Backend (`packages/leann-core/src/leann/chat.py`)
```python
def ask(self, prompt: str, **kwargs) -> str:
# Handle thinking budget for reasoning models
options = kwargs.copy()
thinking_budget = kwargs.get("thinking_budget")
if thinking_budget:
options.pop("thinking_budget", None)
if thinking_budget in ["low", "medium", "high"]:
options["reasoning"] = {"effort": thinking_budget, "exclude": False}
```
**API Format**: Uses Ollama's `reasoning` parameter with `effort` and `exclude` fields.
#### OpenAI Backend (`packages/leann-core/src/leann/chat.py`)
```python
def ask(self, prompt: str, **kwargs) -> str:
# Handle thinking budget for reasoning models
thinking_budget = kwargs.get("thinking_budget")
if thinking_budget and thinking_budget in ["low", "medium", "high"]:
# Check if this is an o-series model
o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
if any(model in self.model for model in o_series_models):
params["reasoning_effort"] = thinking_budget
```
**API Format**: Uses OpenAI's `reasoning_effort` parameter for o-series models.
### 3. Parameter Propagation
The thinking budget parameter is properly propagated through the LEANN architecture:
1. **CLI** (`packages/leann-core/src/leann/cli.py`): Captures `--thinking-budget` argument
2. **Base RAG** (`apps/base_rag_example.py`): Adds parameter to argument parser
3. **LeannChat** (`packages/leann-core/src/leann/api.py`): Passes `llm_kwargs` to LLM
4. **LLM Interface**: Handles the parameter in backend-specific implementations
## Files Modified
### Core Implementation
- `packages/leann-core/src/leann/chat.py`: Added thinking budget support to OllamaChat and OpenAIChat
- `packages/leann-core/src/leann/cli.py`: Added `--thinking-budget` argument
- `apps/base_rag_example.py`: Added thinking budget parameter to RAG examples
### Documentation
- `README.md`: Added thinking budget parameter to usage examples
- `docs/configuration-guide.md`: Added detailed documentation and usage guidelines
### Examples
- `examples/thinking_budget_demo.py`: Comprehensive demo script with usage examples
## Usage Examples
### Basic Usage
```bash
# High reasoning effort for complex questions
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
# Medium reasoning for balanced performance
leann ask my-index --llm openai --model gpt-4o --thinking-budget medium
# Low reasoning for fast responses
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget low
```
### RAG Examples
```bash
# Email RAG with high reasoning
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
# Document RAG with medium reasoning
python apps/document_rag.py --llm openai --llm-model gpt-4o --thinking-budget medium
```
## Supported Models
### Ollama Models
- **GPT-Oss:20b**: Primary target model with reasoning capabilities
- **Other reasoning models**: Any Ollama model that supports the `reasoning` parameter
### OpenAI Models
- **o3, o3-mini, o4-mini, o1**: o-series reasoning models with `reasoning_effort` parameter
- **GPT-OSS models**: Models that support reasoning capabilities
## Testing
The implementation includes comprehensive testing:
- Parameter handling verification
- Backend-specific API format validation
- CLI argument parsing tests
- Integration with existing LEANN architecture

128
docs/ast_chunking_guide.md Normal file
View File

@@ -0,0 +1,128 @@
# AST-Aware Code chunking guide
## Overview
This guide covers best practices for using AST-aware code chunking in LEANN. AST chunking provides better semantic understanding of code structure compared to traditional text-based chunking.
## Quick Start
### Basic Usage
```bash
# Enable AST chunking for mixed content (code + docs)
python -m apps.document_rag --enable-code-chunking --data-dir ./my_project
# Specialized code repository indexing
python -m apps.code_rag --repo-dir ./my_codebase
# Global CLI with AST support
leann build my-code-index --docs ./src --use-ast-chunking
```
### Installation
```bash
# Install LEANN with AST chunking support
uv pip install -e "."
```
## Best Practices
### When to Use AST Chunking
**Recommended for:**
- Code repositories with multiple languages
- Mixed documentation and code content
- Complex codebases with deep function/class hierarchies
- When working with Claude Code for code assistance
**Not recommended for:**
- Pure text documents
- Very large files (>1MB)
- Languages not supported by tree-sitter
### Optimal Configuration
```bash
# Recommended settings for most codebases
python -m apps.code_rag \
--repo-dir ./src \
--ast-chunk-size 768 \
--ast-chunk-overlap 96 \
--exclude-dirs .git __pycache__ node_modules build dist
```
### Supported Languages
| Extension | Language | Status |
|-----------|----------|--------|
| `.py` | Python | ✅ Full support |
| `.java` | Java | ✅ Full support |
| `.cs` | C# | ✅ Full support |
| `.ts`, `.tsx` | TypeScript | ✅ Full support |
| `.js`, `.jsx` | JavaScript | ✅ Via TypeScript parser |
## Integration Examples
### Document RAG with Code Support
```python
# Enable code chunking in document RAG
python -m apps.document_rag \
--enable-code-chunking \
--data-dir ./project \
--query "How does authentication work in the codebase?"
```
### Claude Code Integration
When using with Claude Code MCP server, AST chunking provides better context for:
- Code completion and suggestions
- Bug analysis and debugging
- Architecture understanding
- Refactoring assistance
## Troubleshooting
### Common Issues
1. **Fallback to Traditional Chunking**
- Normal behavior for unsupported languages
- Check logs for specific language support
2. **Performance with Large Files**
- Adjust `--max-file-size` parameter
- Use `--exclude-dirs` to skip unnecessary directories
3. **Quality Issues**
- Try different `--ast-chunk-size` values (512, 768, 1024)
- Adjust overlap for better context preservation
### Debug Mode
```bash
export LEANN_LOG_LEVEL=DEBUG
python -m apps.code_rag --repo-dir ./my_code
```
## Migration from Traditional Chunking
Existing workflows continue to work without changes. To enable AST chunking:
```bash
# Before
python -m apps.document_rag --chunk-size 256
# After (maintains traditional chunking for non-code files)
python -m apps.document_rag --enable-code-chunking --chunk-size 256 --ast-chunk-size 768
```
## References
- [astchunk GitHub Repository](https://github.com/yilinjz/astchunk)
- [LEANN MCP Integration](../packages/leann-mcp/README.md)
- [Research Paper](https://arxiv.org/html/2506.15655v1)
---
**Note**: AST chunking maintains full backward compatibility while enhancing code understanding capabilities.

View File

@@ -0,0 +1,98 @@
"""
Comparison between Sentence Transformers and OpenAI embeddings
This example shows how different embedding models handle complex queries
and demonstrates the differences between local and API-based embeddings.
"""
import numpy as np
from leann.embedding_compute import compute_embeddings
# OpenAI API key should be set as environment variable
# export OPENAI_API_KEY="your-api-key-here"
# Test data
conference_text = "[Title]: COLING 2025 Conference\n[URL]: https://coling2025.org/"
browser_text = "[Title]: Browser Use Tool\n[URL]: https://github.com/browser-use"
# Two queries with same intent but different wording
query1 = "Tell me my browser history about some conference i often visit"
query2 = "browser history about conference I often visit"
texts = [query1, query2, conference_text, browser_text]
def cosine_similarity(a, b):
return np.dot(a, b) # Already normalized
def analyze_embeddings(embeddings, model_name):
print(f"\n=== {model_name} Results ===")
# Results for Query 1
sim1_conf = cosine_similarity(embeddings[0], embeddings[2])
sim1_browser = cosine_similarity(embeddings[0], embeddings[3])
print(f"Query 1: '{query1}'")
print(f" → Conference similarity: {sim1_conf:.4f} {'' if sim1_conf > sim1_browser else ''}")
print(
f" → Browser similarity: {sim1_browser:.4f} {'' if sim1_browser > sim1_conf else ''}"
)
print(f" Winner: {'Conference' if sim1_conf > sim1_browser else 'Browser'}")
# Results for Query 2
sim2_conf = cosine_similarity(embeddings[1], embeddings[2])
sim2_browser = cosine_similarity(embeddings[1], embeddings[3])
print(f"\nQuery 2: '{query2}'")
print(f" → Conference similarity: {sim2_conf:.4f} {'' if sim2_conf > sim2_browser else ''}")
print(
f" → Browser similarity: {sim2_browser:.4f} {'' if sim2_browser > sim2_conf else ''}"
)
print(f" Winner: {'Conference' if sim2_conf > sim2_browser else 'Browser'}")
# Show the impact
print("\n=== Impact Analysis ===")
print(f"Conference similarity change: {sim2_conf - sim1_conf:+.4f}")
print(f"Browser similarity change: {sim2_browser - sim1_browser:+.4f}")
if sim1_conf > sim1_browser and sim2_browser > sim2_conf:
print("❌ FLIP: Adding 'browser history' flips winner from Conference to Browser!")
elif sim1_conf > sim1_browser and sim2_conf > sim2_browser:
print("✅ STABLE: Conference remains winner in both queries")
elif sim1_browser > sim1_conf and sim2_browser > sim2_conf:
print("✅ STABLE: Browser remains winner in both queries")
else:
print("🔄 MIXED: Results vary between queries")
return {
"query1_conf": sim1_conf,
"query1_browser": sim1_browser,
"query2_conf": sim2_conf,
"query2_browser": sim2_browser,
}
# Test Sentence Transformers
print("Testing Sentence Transformers (facebook/contriever)...")
try:
st_embeddings = compute_embeddings(texts, "facebook/contriever", mode="sentence-transformers")
st_results = analyze_embeddings(st_embeddings, "Sentence Transformers (facebook/contriever)")
except Exception as e:
print(f"❌ Sentence Transformers failed: {e}")
st_results = None
# Test OpenAI
print("\n" + "=" * 60)
print("Testing OpenAI (text-embedding-3-small)...")
try:
openai_embeddings = compute_embeddings(texts, "text-embedding-3-small", mode="openai")
openai_results = analyze_embeddings(openai_embeddings, "OpenAI (text-embedding-3-small)")
except Exception as e:
print(f"❌ OpenAI failed: {e}")
openai_results = None
# Compare results
if st_results and openai_results:
print("\n" + "=" * 60)
print("=== COMPARISON SUMMARY ===")

384
docs/configuration-guide.md Normal file
View File

@@ -0,0 +1,384 @@
# LEANN Configuration Guide
This guide helps you optimize LEANN for different use cases and understand the trade-offs between various configuration options.
## Getting Started: Simple is Better
When first trying LEANN, start with a small dataset to quickly validate your approach:
**For document RAG**: The default `data/` directory works perfectly - includes 2 AI research papers, Pride and Prejudice literature, and a technical report
```bash
python -m apps.document_rag --query "What techniques does LEANN use?"
```
**For other data sources**: Limit the dataset size for quick testing
```bash
# WeChat: Test with recent messages only
python -m apps.wechat_rag --max-items 100 --query "What did we discuss about the project timeline?"
# Browser history: Last few days
python -m apps.browser_rag --max-items 500 --query "Find documentation about vector databases"
# Email: Recent inbox
python -m apps.email_rag --max-items 200 --query "Who sent updates about the deployment status?"
```
Once validated, scale up gradually:
- 100 documents → 1,000 → 10,000 → full dataset (`--max-items -1`)
- This helps identify issues early before committing to long processing times
## Embedding Model Selection: Understanding the Trade-offs
Based on our experience developing LEANN, embedding models fall into three categories:
### Small Models (< 100M parameters)
**Example**: `sentence-transformers/all-MiniLM-L6-v2` (22M params)
- **Pros**: Lightweight, fast for both indexing and inference
- **Cons**: Lower semantic understanding, may miss nuanced relationships
- **Use when**: Speed is critical, handling simple queries, interactive mode, or just experimenting with LEANN. If time is not a constraint, consider using a larger/better embedding model
### Medium Models (100M-500M parameters)
**Example**: `facebook/contriever` (110M params), `BAAI/bge-base-en-v1.5` (110M params)
- **Pros**: Balanced performance, good multilingual support, reasonable speed
- **Cons**: Requires more compute than small models
- **Use when**: Need quality results without extreme compute requirements, general-purpose RAG applications
### Large Models (500M+ parameters)
**Example**: `Qwen/Qwen3-Embedding-0.6B` (600M params), `intfloat/multilingual-e5-large` (560M params)
- **Pros**: Best semantic understanding, captures complex relationships, excellent multilingual support. **Qwen3-Embedding-0.6B achieves nearly OpenAI API performance!**
- **Cons**: Slower inference, longer index build times
- **Use when**: Quality is paramount and you have sufficient compute resources. **Highly recommended** for production use
### Quick Start: Cloud and Local Embedding Options
**OpenAI Embeddings (Fastest Setup)**
For immediate testing without local model downloads(also if you [do not have GPU](https://github.com/yichuan-w/LEANN/issues/43) and do not care that much about your document leak, you should use this, we compute the embedding and recompute using openai API):
```bash
# Set OpenAI embeddings (requires OPENAI_API_KEY)
--embedding-mode openai --embedding-model text-embedding-3-small
```
**Ollama Embeddings (Privacy-Focused)**
For local embeddings with complete privacy:
```bash
# First, pull an embedding model
ollama pull nomic-embed-text
# Use Ollama embeddings
--embedding-mode ollama --embedding-model nomic-embed-text
```
<details>
<summary><strong>Cloud vs Local Trade-offs</strong></summary>
**OpenAI Embeddings** (`text-embedding-3-small/large`)
- **Pros**: No local compute needed, consistently fast, high quality
- **Cons**: Requires API key, costs money, data leaves your system, [known limitations with certain languages](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
- **When to use**: Prototyping, non-sensitive data, need immediate results
**Local Embeddings**
- **Pros**: Complete privacy, no ongoing costs, full control, can sometimes outperform OpenAI embeddings
- **Cons**: Slower than cloud APIs, requires local compute resources
- **When to use**: Production systems, sensitive data, cost-sensitive applications
</details>
## Index Selection: Matching Your Scale
### HNSW (Hierarchical Navigable Small World)
**Best for**: Small to medium datasets (< 10M vectors) - **Default and recommended for extreme low storage**
- Full recomputation required
- High memory usage during build phase
- Excellent recall (95%+)
```bash
# Optimal for most use cases
--backend-name hnsw --graph-degree 32 --build-complexity 64
```
### DiskANN
**Best for**: Large datasets, especially when you want `recompute=True`.
**Key advantages:**
- **Faster search** on large datasets (3x+ speedup vs HNSW in many cases)
- **Smart storage**: `recompute=True` enables automatic graph partitioning for smaller indexes
- **Better scaling**: Designed for 100k+ documents
**Recompute behavior:**
- `recompute=True` (recommended): Pure PQ traversal + final reranking - faster and enables partitioning
- `recompute=False`: PQ + partial real distances during traversal - slower but higher accuracy
```bash
# Recommended for most use cases
--backend-name diskann --graph-degree 32 --build-complexity 64
```
**Performance Benchmark**: Run `uv run benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.
## LLM Selection: Engine and Model Comparison
### LLM Engines
**OpenAI** (`--llm openai`)
- **Pros**: Best quality, consistent performance, no local resources needed
- **Cons**: Costs money ($0.15-2.5 per million tokens), requires internet, data privacy concerns
- **Models**: `gpt-4o-mini` (fast, cheap), `gpt-4o` (best quality), `o3` (reasoning), `o3-mini` (reasoning, cheaper)
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for o-series reasoning models (o3, o3-mini, o4-mini)
- **Note**: Our current default, but we recommend switching to Ollama for most use cases
**Ollama** (`--llm ollama`)
- **Pros**: Fully local, free, privacy-preserving, good model variety
- **Cons**: Requires local GPU/CPU resources, slower than cloud APIs, need to install extra [ollama app](https://github.com/ollama/ollama?tab=readme-ov-file#ollama) and pre-download models by `ollama pull`
- **Models**: `qwen3:0.6b` (ultra-fast), `qwen3:1.7b` (balanced), `qwen3:4b` (good quality), `qwen3:7b` (high quality), `deepseek-r1:1.5b` (reasoning)
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for reasoning models like GPT-Oss:20b
**HuggingFace** (`--llm hf`)
- **Pros**: Free tier available, huge model selection, direct model loading (vs Ollama's server-based approach)
- **Cons**: More complex initial setup
- **Models**: `Qwen/Qwen3-1.7B-FP8`
## Parameter Tuning Guide
### Search Complexity Parameters
**`--build-complexity`** (index building)
- Controls thoroughness during index construction
- Higher = better recall but slower build
- Recommendations:
- 32: Quick prototyping
- 64: Balanced (default)
- 128: Production systems
- 256: Maximum quality
**`--search-complexity`** (query time)
- Controls search thoroughness
- Higher = better results but slower
- Recommendations:
- 16: Fast/Interactive search
- 32: High quality with diversity
- 64+: Maximum accuracy
### Top-K Selection
**`--top-k`** (number of retrieved chunks)
- More chunks = better context but slower LLM processing
- Should be always smaller than `--search-complexity`
- Guidelines:
- 10-20: General questions (default: 20)
- 30+: Complex multi-hop reasoning requiring comprehensive context
**Trade-off formula**:
- Retrieval time ∝ log(n) × search_complexity
- LLM processing time ∝ top_k × chunk_size
- Total context = top_k × chunk_size tokens
### Thinking Budget for Reasoning Models
**`--thinking-budget`** (reasoning effort level)
- Controls the computational effort for reasoning models
- Options: `low`, `medium`, `high`
- Guidelines:
- `low`: Fast responses, basic reasoning (default for simple queries)
- `medium`: Balanced speed and reasoning depth
- `high`: Maximum reasoning effort, best for complex analytical questions
- **Supported Models**:
- **Ollama**: `gpt-oss:20b`, `gpt-oss:120b`
- **OpenAI**: `o3`, `o3-mini`, `o4-mini`, `o1` (o-series reasoning models)
- **Note**: Models without reasoning support will show a warning and proceed without reasoning parameters
- **Example**: `--thinking-budget high` for complex analytical questions
**📖 For detailed usage examples and implementation details, check out [Thinking Budget Documentation](THINKING_BUDGET_FEATURE.md)**
**💡 Quick Examples:**
```bash
# OpenAI o-series reasoning model
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
--index-dir hnswbuild --backend hnsw \
--llm openai --llm-model o3 --thinking-budget medium
# Ollama reasoning model
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
--index-dir hnswbuild --backend hnsw \
--llm ollama --llm-model gpt-oss:20b --thinking-budget high
```
### Graph Degree (HNSW/DiskANN)
**`--graph-degree`**
- Number of connections per node in the graph
- Higher = better recall but more memory
- HNSW: 16-32 (default: 32)
- DiskANN: 32-128 (default: 64)
## Performance Optimization Checklist
### If Embedding is Too Slow
1. **Switch to smaller model**:
```bash
# From large model
--embedding-model Qwen/Qwen3-Embedding-0.6B
# To small model
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
2. **Limit dataset size for testing**:
```bash
--max-items 1000 # Process first 1k items only
```
3. **Use MLX on Apple Silicon** (optional optimization):
```bash
--embedding-mode mlx --embedding-model mlx-community/Qwen3-Embedding-0.6B-8bit
```
MLX might not be the best choice, as we tested and found that it only offers 1.3x acceleration compared to HF, so maybe using ollama is a better choice for embedding generation
4. **Use Ollama**
```bash
--embedding-mode ollama --embedding-model nomic-embed-text
```
To discover additional embedding models in ollama, check out https://ollama.com/search?c=embedding or read more about embedding models at https://ollama.com/blog/embedding-models, please do check the model size that works best for you
### If Search Quality is Poor
1. **Increase retrieval count**:
```bash
--top-k 30 # Retrieve more candidates
```
2. **Upgrade embedding model**:
```bash
# For English
--embedding-model BAAI/bge-base-en-v1.5
# For multilingual
--embedding-model intfloat/multilingual-e5-large
```
## Understanding the Trade-offs
Every configuration choice involves trade-offs:
| Factor | Small/Fast | Large/Quality |
|--------|------------|---------------|
| Embedding Model | `all-MiniLM-L6-v2` | `Qwen/Qwen3-Embedding-0.6B` |
| Chunk Size | 512 tokens | 128 tokens |
| Index Type | HNSW | DiskANN |
| LLM | `qwen3:1.7b` | `gpt-4o` |
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
## Low-resource setups
If you dont have a local GPU or builds/searches are too slow, use one or more of the options below.
### 1) Use OpenAI embeddings (no local compute)
Fastest path with zero local GPU requirements. Set your API key and use OpenAI embeddings during build and search:
```bash
export OPENAI_API_KEY=sk-...
# Build with OpenAI embeddings
leann build my-index \
--embedding-mode openai \
--embedding-model text-embedding-3-small
# Search with OpenAI embeddings (recompute at query time)
leann search my-index "your query" \
--recompute
```
### 2) Run remote builds with SkyPilot (cloud GPU)
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
```bash
# One-time: install and configure SkyPilot
pip install skypilot
# Launch with defaults (L4:1) and mount ./data to ~/leann-data; the build runs automatically
sky launch -c leann-gpu sky/leann-build.yaml
# Override parameters via -e key=value (optional)
sky launch -c leann-gpu sky/leann-build.yaml \
-e index_name=my-index \
-e backend=hnsw \
-e embedding_mode=sentence-transformers \
-e embedding_model=Qwen/Qwen3-Embedding-0.6B
# Copy the built index back to your local .leann (use rsync)
rsync -Pavz leann-gpu:~/.leann/indexes/my-index ./.leann/indexes/
```
### 3) Disable recomputation to trade storage for speed
If you need lower latency and have more storage/memory, disable recomputation. This stores full embeddings and avoids recomputing at search time.
```bash
# Build without recomputation (HNSW requires non-compact in this mode)
leann build my-index --no-recompute --no-compact
# Search without recomputation
leann search my-index "your query" --no-recompute
```
When to use:
- Extreme low latency requirements (high QPS, interactive assistants)
- Read-heavy workloads where storage is cheaper than latency
- No always-available GPU
Constraints:
- HNSW: when `--no-recompute` is set, LEANN automatically disables compact mode during build
- DiskANN: supported; `--no-recompute` skips selective recompute during search
Storage impact:
- Storing N embeddings of dimension D with float32 requires approximately N × D × 4 bytes
- Example: 1,000,000 chunks × 768 dims × 4 bytes ≈ 2.86 GB (plus graph/metadata)
Converting an existing index (rebuild required):
```bash
# Rebuild in-place (ensure you still have original docs or can regenerate chunks)
leann build my-index --force --no-recompute --no-compact
```
Python API usage:
```python
from leann import LeannSearcher
searcher = LeannSearcher("/path/to/my-index.leann")
results = searcher.search("your query", top_k=10, recompute_embeddings=False)
```
Trade-offs:
- Lower latency and fewer network hops at query time
- Significantly higher storage (10100× vs selective recomputation)
- Slightly larger memory footprint during build and search
Quick benchmark results (`benchmarks/benchmark_no_recompute.py` with 5k texts, complexity=32):
- HNSW
```text
recompute=True: search_time=0.818s, size=1.1MB
recompute=False: search_time=0.012s, size=16.6MB
```
- DiskANN
```text
recompute=True: search_time=0.041s, size=5.9MB
recompute=False: search_time=0.013s, size=24.6MB
```
Conclusion:
- **HNSW**: `no-recompute` is significantly faster (no embedding recomputation) but requires much more storage (stores all embeddings)
- **DiskANN**: `no-recompute` uses PQ + partial real distances during traversal (slower but higher accuracy), while `recompute=True` uses pure PQ traversal + final reranking (faster traversal, enables build-time partitioning for smaller storage)
## Further Reading
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)

10
docs/faq.md Normal file
View File

@@ -0,0 +1,10 @@
# FAQ
## 1. My building time seems long
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
```bash
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)

23
docs/features.md Normal file
View File

@@ -0,0 +1,23 @@
# ✨ Detailed Features
## 🔥 Core Features
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
- **🧠 AST-Aware Code Chunking** - Intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript files
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
## 🛠️ Technical Highlights
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](../examples/mlx_demo.py))
## 🎨 Developer Experience
- **Simple Python API** - Get started in minutes
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment

300
docs/metadata_filtering.md Normal file
View File

@@ -0,0 +1,300 @@
# LEANN Metadata Filtering Usage Guide
## Overview
Leann possesses metadata filtering capabilities that allow you to filter search results based on arbitrary metadata fields set during chunking. This feature enables use cases like spoiler-free book search, document filtering by date/type, code search by file type, and potentially much more.
## Basic Usage
### Adding Metadata to Your Documents
When building your index, add metadata to each text chunk:
```python
from leann.api import LeannBuilder
builder = LeannBuilder("hnsw")
# Add text with metadata
builder.add_text(
text="Chapter 1: Alice falls down the rabbit hole",
metadata={
"chapter": 1,
"character": "Alice",
"themes": ["adventure", "curiosity"],
"word_count": 150
}
)
builder.build_index("alice_in_wonderland_index")
```
### Searching with Metadata Filters
Use the `metadata_filters` parameter in search calls:
```python
from leann.api import LeannSearcher
searcher = LeannSearcher("alice_in_wonderland_index")
# Search with filters
results = searcher.search(
query="What happens to Alice?",
top_k=10,
metadata_filters={
"chapter": {"<=": 5}, # Only chapters 1-5
"spoiler_level": {"!=": "high"} # No high spoilers
}
)
```
## Filter Syntax
### Basic Structure
```python
metadata_filters = {
"field_name": {"operator": value},
"another_field": {"operator": value}
}
```
### Supported Operators
#### Comparison Operators
- `"=="`: Equal to
- `"!="`: Not equal to
- `"<"`: Less than
- `"<="`: Less than or equal
- `">"`: Greater than
- `">="`: Greater than or equal
```python
# Examples
{"chapter": {"==": 1}} # Exactly chapter 1
{"page": {">": 100}} # Pages after 100
{"rating": {">=": 4.0}} # Rating 4.0 or higher
{"word_count": {"<": 500}} # Short passages
```
#### Membership Operators
- `"in"`: Value is in list
- `"not_in"`: Value is not in list
```python
# Examples
{"character": {"in": ["Alice", "Bob"]}} # Alice OR Bob
{"genre": {"not_in": ["horror", "thriller"]}} # Exclude genres
{"tags": {"in": ["fiction", "adventure"]}} # Any of these tags
```
#### String Operators
- `"contains"`: String contains substring
- `"starts_with"`: String starts with prefix
- `"ends_with"`: String ends with suffix
```python
# Examples
{"title": {"contains": "alice"}} # Title contains "alice"
{"filename": {"ends_with": ".py"}} # Python files
{"author": {"starts_with": "Dr."}} # Authors with "Dr." prefix
```
#### Boolean Operators
- `"is_true"`: Field is truthy
- `"is_false"`: Field is falsy
```python
# Examples
{"is_published": {"is_true": True}} # Published content
{"is_draft": {"is_false": False}} # Not drafts
```
### Multiple Operators on Same Field
You can apply multiple operators to the same field (AND logic):
```python
metadata_filters = {
"word_count": {
">=": 100, # At least 100 words
"<=": 500 # At most 500 words
}
}
```
### Compound Filters
Multiple fields are combined with AND logic:
```python
metadata_filters = {
"chapter": {"<=": 10}, # Up to chapter 10
"character": {"==": "Alice"}, # About Alice
"spoiler_level": {"!=": "high"} # No major spoilers
}
```
## Use Case Examples
### 1. Spoiler-Free Book Search
```python
# Reader has only read up to chapter 5
def search_spoiler_free(query, max_chapter):
return searcher.search(
query=query,
metadata_filters={
"chapter": {"<=": max_chapter},
"spoiler_level": {"in": ["none", "low"]}
}
)
results = search_spoiler_free("What happens to Alice?", max_chapter=5)
```
### 2. Document Management by Date
```python
# Find recent documents
recent_docs = searcher.search(
query="project updates",
metadata_filters={
"date": {">=": "2024-01-01"},
"document_type": {"==": "report"}
}
)
```
### 3. Code Search by File Type
```python
# Search only Python files
python_code = searcher.search(
query="authentication function",
metadata_filters={
"file_extension": {"==": ".py"},
"lines_of_code": {"<": 100}
}
)
```
### 4. Content Filtering by Audience
```python
# Age-appropriate content
family_content = searcher.search(
query="adventure stories",
metadata_filters={
"age_rating": {"in": ["G", "PG"]},
"content_warnings": {"not_in": ["violence", "adult_themes"]}
}
)
```
### 5. Multi-Book Series Management
```python
# Search across first 3 books only
early_series = searcher.search(
query="character development",
metadata_filters={
"series": {"==": "Harry Potter"},
"book_number": {"<=": 3}
}
)
```
## Running the Example
You can see metadata filtering in action with our spoiler-free book RAG example:
```bash
# Don't forget to set up the environment
uv venv
source .venv/bin/activate
# Set your OpenAI API key (required for embeddings, but you can update the example locally and use ollama instead)
export OPENAI_API_KEY="your-api-key-here"
# Run the spoiler-free book RAG example
uv run examples/spoiler_free_book_rag.py
```
This example demonstrates:
- Building an index with metadata (chapter numbers, characters, themes, locations)
- Searching with filters to avoid spoilers (e.g., only show results up to chapter 5)
- Different scenarios for readers at various points in the book
The example uses Alice's Adventures in Wonderland as sample data and shows how you can search for information without revealing plot points from later chapters.
## Advanced Patterns
### Custom Chunking with metadata
```python
def chunk_book_with_metadata(book_text, book_info):
chunks = []
for chapter_num, chapter_text in parse_chapters(book_text):
# Extract entities, themes, etc.
characters = extract_characters(chapter_text)
themes = classify_themes(chapter_text)
spoiler_level = assess_spoiler_level(chapter_text, chapter_num)
# Create chunks with rich metadata
for paragraph in split_paragraphs(chapter_text):
chunks.append({
"text": paragraph,
"metadata": {
"book_title": book_info["title"],
"chapter": chapter_num,
"characters": characters,
"themes": themes,
"spoiler_level": spoiler_level,
"word_count": len(paragraph.split()),
"reading_level": calculate_reading_level(paragraph)
}
})
return chunks
```
## Performance Considerations
### Efficient Filtering Strategies
1. **Post-search filtering**: Applies filters after vector search, which should be efficient for typical result sets (10-100 results).
2. **Metadata design**: Keep metadata fields simple and avoid deeply nested structures.
### Best Practices
1. **Consistent metadata schema**: Use consistent field names and value types across your documents.
2. **Reasonable metadata size**: Keep metadata reasonably sized to avoid storage overhead.
3. **Type consistency**: Use consistent data types for the same fields (e.g., always integers for chapter numbers).
4. **Index multiple granularities**: Consider chunking at different levels (paragraph, section, chapter) with appropriate metadata.
### Adding Metadata to Existing Indices
To add metadata filtering to existing indices, you'll need to rebuild them with metadata:
```python
# Read existing passages and add metadata
def add_metadata_to_existing_chunks(chunks):
for chunk in chunks:
# Extract or assign metadata based on content
chunk["metadata"] = extract_metadata(chunk["text"])
return chunks
# Rebuild index with metadata
enhanced_chunks = add_metadata_to_existing_chunks(existing_chunks)
builder = LeannBuilder("hnsw")
for chunk in enhanced_chunks:
builder.add_text(chunk["text"], chunk["metadata"])
builder.build_index("enhanced_index")
```

View File

@@ -0,0 +1,75 @@
# Normalized Embeddings Support in LEANN
LEANN now automatically detects normalized embedding models and sets the appropriate distance metric for optimal performance.
## What are Normalized Embeddings?
Normalized embeddings are vectors with L2 norm = 1 (unit vectors). These embeddings are optimized for cosine similarity rather than Maximum Inner Product Search (MIPS).
## Automatic Detection
When you create a `LeannBuilder` instance with a normalized embedding model, LEANN will:
1. **Automatically set `distance_metric="cosine"`** if not specified
2. **Show a warning** if you manually specify a different distance metric
3. **Provide optimal search performance** with the correct metric
## Supported Normalized Embedding Models
### OpenAI
All OpenAI text embedding models are normalized:
- `text-embedding-ada-002`
- `text-embedding-3-small`
- `text-embedding-3-large`
### Voyage AI
All Voyage AI embedding models are normalized:
- `voyage-2`
- `voyage-3`
- `voyage-large-2`
- `voyage-multilingual-2`
- `voyage-code-2`
### Cohere
All Cohere embedding models are normalized:
- `embed-english-v3.0`
- `embed-multilingual-v3.0`
- `embed-english-light-v3.0`
- `embed-multilingual-light-v3.0`
## Example Usage
```python
from leann.api import LeannBuilder
# Automatic detection - will use cosine distance
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai"
)
# Warning: Detected normalized embeddings model 'text-embedding-3-small'...
# Automatically setting distance_metric='cosine'
# Manual override (not recommended)
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
distance_metric="mips" # Will show warning
)
# Warning: Using 'mips' distance metric with normalized embeddings...
```
## Non-Normalized Embeddings
Models like `facebook/contriever` and other sentence-transformers models that are not normalized will continue to use MIPS by default, which is optimal for them.
## Why This Matters
Using the wrong distance metric with normalized embeddings can lead to:
- **Poor search quality** due to HNSW's early termination with narrow score ranges
- **Incorrect ranking** of search results
- **Suboptimal performance** compared to using the correct metric
For more details on why this happens, see our analysis in the [embedding detection code](../packages/leann-core/src/leann/api.py) which automatically handles normalized embeddings and MIPS distance metric issues.

21
docs/roadmap.md Normal file
View File

@@ -0,0 +1,21 @@
# 📈 Roadmap
## 🎯 Q2 2025
- [X] HNSW backend integration
- [X] DiskANN backend with MIPS/L2/Cosine support
- [X] Real-time embedding pipeline
- [X] Memory-efficient graph pruning
## 🚀 Q3 2025
- [ ] Advanced caching strategies
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
- [ ] Add sleep-time-compute and summarize agent! to summarilze the file on computer!
- [ ] Add OpenAI recompute API
## 🌟 Q4 2025
- [ ] Integration with LangChain/LlamaIndex
- [ ] Visual similarity search
- [ ] Query rewrtiting, rerank and expansion

View File

@@ -1,16 +1,23 @@
""" """
Simple demo showing basic leann usage Simple demo showing basic leann usage
Run: uv run python examples/simple_demo.py Run: uv run python examples/basic_demo.py
""" """
import argparse import argparse
from leann import LeannBuilder, LeannSearcher, LeannChat
from leann import LeannBuilder, LeannChat, LeannSearcher
def main(): def main():
parser = argparse.ArgumentParser(description="Simple demo of Leann with selectable embedding models.") parser = argparse.ArgumentParser(
parser.add_argument("--embedding_model", type=str, default="sentence-transformers/all-mpnet-base-v2", description="Simple demo of Leann with selectable embedding models."
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.") )
parser.add_argument(
"--embedding_model",
type=str,
default="sentence-transformers/all-mpnet-base-v2",
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.",
)
args = parser.parse_args() args = parser.parse_args()
print(f"=== Leann Simple Demo with {args.embedding_model} ===") print(f"=== Leann Simple Demo with {args.embedding_model} ===")
@@ -74,7 +81,7 @@ def main():
print() print()
print("Demo completed! Try running:") print("Demo completed! Try running:")
print(" uv run python examples/document_search.py") print(" uv run python apps/document_rag.py")
if __name__ == "__main__": if __name__ == "__main__":

View File

Binary file not shown.

View File

@@ -1,146 +0,0 @@
#!/usr/bin/env python3
"""
Document search demo with recompute mode
"""
import os
from pathlib import Path
import shutil
import time
# Import backend packages to trigger plugin registration
try:
import leann_backend_diskann
import leann_backend_hnsw
print("INFO: Backend packages imported successfully.")
except ImportError as e:
print(f"WARNING: Could not import backend packages. Error: {e}")
# Import upper-level API from leann-core
from leann.api import LeannBuilder, LeannSearcher, LeannChat
def load_sample_documents():
"""Create sample documents for demonstration"""
docs = [
{"title": "Intro to Python", "content": "Python is a high-level, interpreted language known for simplicity."},
{"title": "ML Basics", "content": "Machine learning builds systems that learn from data."},
{"title": "Data Structures", "content": "Data structures like arrays, lists, and graphs organize data."},
]
return docs
def main():
print("==========================================================")
print("=== Leann Document Search Demo (DiskANN + Recompute) ===")
print("==========================================================")
INDEX_DIR = Path("./test_indices")
INDEX_PATH = str(INDEX_DIR / "documents.diskann")
BACKEND_TO_TEST = "diskann"
if INDEX_DIR.exists():
print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
shutil.rmtree(INDEX_DIR)
# --- 1. Build index ---
print(f"\n[PHASE 1] Building index using '{BACKEND_TO_TEST}' backend...")
builder = LeannBuilder(
backend_name=BACKEND_TO_TEST,
graph_degree=32,
complexity=64
)
documents = load_sample_documents()
print(f"Loaded {len(documents)} sample documents.")
for doc in documents:
builder.add_text(doc["content"], metadata={"title": doc["title"]})
builder.build_index(INDEX_PATH)
print(f"\nIndex built!")
# --- 2. Basic search demo ---
print(f"\n[PHASE 2] Basic search using '{BACKEND_TO_TEST}' backend...")
searcher = LeannSearcher(index_path=INDEX_PATH)
query = "What is machine learning?"
print(f"\nQuery: '{query}'")
print("\n--- Basic search mode (PQ computation) ---")
start_time = time.time()
results = searcher.search(query, top_k=2)
basic_time = time.time() - start_time
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
print(">>> Basic search results <<<")
for i, res in enumerate(results, 1):
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
# --- 3. Recompute search demo ---
print(f"\n[PHASE 3] Recompute search using embedding server...")
print("\n--- Recompute search mode (get real embeddings via network) ---")
# Configure recompute parameters
recompute_params = {
"recompute_beighbor_embeddings": True, # Enable network recomputation
"USE_DEFERRED_FETCH": False, # Don't use deferred fetch
"skip_search_reorder": True, # Skip search reordering
"dedup_node_dis": True, # Enable node distance deduplication
"prune_ratio": 0.1, # Pruning ratio 10%
"batch_recompute": False, # Don't use batch recomputation
"global_pruning": False, # Don't use global pruning
"zmq_port": 5555, # ZMQ port
"embedding_model": "sentence-transformers/all-mpnet-base-v2"
}
print("Recompute parameter configuration:")
for key, value in recompute_params.items():
print(f" {key}: {value}")
print(f"\n🔄 Executing Recompute search...")
try:
start_time = time.time()
recompute_results = searcher.search(query, top_k=2, **recompute_params)
recompute_time = time.time() - start_time
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
print(">>> Recompute search results <<<")
for i, res in enumerate(recompute_results, 1):
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
# Compare results
print(f"\n--- Result comparison ---")
print(f"Basic search time: {basic_time:.3f} seconds")
print(f"Recompute time: {recompute_time:.3f} seconds")
print("\nBasic search vs Recompute results:")
for i in range(min(len(results), len(recompute_results))):
basic_score = results[i]['score']
recompute_score = recompute_results[i]['score']
score_diff = abs(basic_score - recompute_score)
print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")
if recompute_time > basic_time:
print(f"✅ Recompute mode working correctly (more accurate but slower)")
else:
print(f" Recompute time is unusually fast, network recomputation may not be enabled")
except Exception as e:
print(f"❌ Recompute search failed: {e}")
print("This usually indicates an embedding server connection issue")
# --- 4. Chat demo ---
print(f"\n[PHASE 4] Starting chat session...")
chat = LeannChat(index_path=INDEX_PATH)
chat_response = chat.ask(query)
print(f"You: {query}")
print(f"Leann: {chat_response}")
print("\n==========================================================")
print("✅ Demo finished successfully!")
print("==========================================================")
if __name__ == "__main__":
main()

View File

@@ -1,81 +0,0 @@
import faulthandler
faulthandler.enable()
from llama_index.core import SimpleDirectoryReader, Settings
from llama_index.core.readers.base import BaseReader
from llama_index.node_parser.docling import DoclingNodeParser
from llama_index.readers.docling import DoclingReader
from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
import asyncio
import os
import dotenv
from leann.api import LeannBuilder, LeannSearcher, LeannChat
import leann_backend_hnsw # Import to ensure backend registration
import shutil
from pathlib import Path
dotenv.load_dotenv()
reader = DoclingReader(export_type=DoclingReader.ExportType.JSON)
file_extractor: dict[str, BaseReader] = {
".docx": reader,
".pptx": reader,
".pdf": reader,
".xlsx": reader,
}
node_parser = DoclingNodeParser(
chunker=HybridChunker(tokenizer="Qwen/Qwen3-Embedding-4B", max_tokens=64)
)
print("Loading documents...")
documents = SimpleDirectoryReader(
"examples/data",
recursive=True,
file_extractor=file_extractor,
encoding="utf-8",
required_exts=[".pdf", ".docx", ".pptx", ".xlsx"]
).load_data(show_progress=True)
print("Documents loaded.")
all_texts = []
for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.text)
INDEX_DIR = Path("./test_pdf_index")
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
print(f"\n[PHASE 1] Building Leann index...")
# CSR compact mode with recompute
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True
)
print(f"Loaded {len(all_texts)} text chunks from documents.")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(INDEX_PATH)
print(f"\nLeann index built at {INDEX_PATH}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
async def main():
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=INDEX_PATH)
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
print(f"You: {query}")
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True,embedding_model="facebook/contriever")
print(f"Leann: {chat_response}")
if __name__ == "__main__":
asyncio.run(main())

43
examples/mlx_demo.py Normal file
View File

@@ -0,0 +1,43 @@
import os
from leann.api import LeannBuilder, LeannChat
# Define the path for our new MLX-based index
INDEX_PATH = "./mlx_diskann_index/leann"
if os.path.exists(INDEX_PATH + ".meta.json"):
print(f"Index already exists at {INDEX_PATH}. Skipping build.")
else:
print("Initializing LeannBuilder with MLX support...")
# 1. Configure LeannBuilder to use MLX
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",
embedding_mode="mlx",
)
# 2. Add documents
print("Adding documents...")
docs = [
"MLX is an array framework for machine learning on Apple silicon.",
"It was designed by Apple's machine learning research team.",
"The mlx-community organization provides pre-trained models in MLX format.",
"It supports operations on multi-dimensional arrays.",
"Leann can now use MLX for its embedding models.",
]
for doc in docs:
builder.add_text(doc)
# 3. Build the index
print(f"Building the MLX-based index at: {INDEX_PATH}")
builder.build_index(INDEX_PATH)
print("\nSuccessfully built the index with MLX embeddings!")
print(f"Check the metadata file: {INDEX_PATH}.meta.json")
chat = LeannChat(index_path=INDEX_PATH)
# add query
query = "MLX is an array framework for machine learning on Apple silicon."
print(f"Query: {query}")
response = chat.ask(query, top_k=3, recompute_beighbor_embeddings=True, complexity=3, beam_width=1)
print(f"Response: {response}")

View File

@@ -0,0 +1,250 @@
#!/usr/bin/env python3
"""
Spoiler-Free Book RAG Example using LEANN Metadata Filtering
This example demonstrates how to use LEANN's metadata filtering to create
a spoiler-free book RAG system where users can search for information
up to a specific chapter they've read.
Usage:
python spoiler_free_book_rag.py
"""
import os
import sys
from typing import Any, Optional
# Add LEANN to path (adjust path as needed)
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../packages/leann-core/src"))
from leann.api import LeannBuilder, LeannSearcher
def chunk_book_with_metadata(book_title: str = "Sample Book") -> list[dict[str, Any]]:
"""
Create sample book chunks with metadata for demonstration.
In a real implementation, this would parse actual book files (epub, txt, etc.)
and extract chapter boundaries, character mentions, etc.
Args:
book_title: Title of the book
Returns:
List of chunk dictionaries with text and metadata
"""
# Sample book chunks with metadata
# In practice, you'd use proper text processing libraries
sample_chunks = [
{
"text": "Alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do.",
"metadata": {
"book": book_title,
"chapter": 1,
"page": 1,
"characters": ["Alice", "Sister"],
"themes": ["boredom", "curiosity"],
"location": "riverbank",
},
},
{
"text": "So she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid), whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies, when suddenly a White Rabbit with pink eyes ran close by her.",
"metadata": {
"book": book_title,
"chapter": 1,
"page": 2,
"characters": ["Alice", "White Rabbit"],
"themes": ["decision", "surprise", "magic"],
"location": "riverbank",
},
},
{
"text": "Alice found herself falling down a very deep well. Either the well was very deep, or she fell very slowly, for she had plenty of time as she fell to look about her and to wonder what was going to happen next.",
"metadata": {
"book": book_title,
"chapter": 2,
"page": 15,
"characters": ["Alice"],
"themes": ["falling", "wonder", "transformation"],
"location": "rabbit hole",
},
},
{
"text": "Alice meets the Cheshire Cat, who tells her that everyone in Wonderland is mad, including Alice herself.",
"metadata": {
"book": book_title,
"chapter": 6,
"page": 85,
"characters": ["Alice", "Cheshire Cat"],
"themes": ["madness", "philosophy", "identity"],
"location": "Duchess's house",
},
},
{
"text": "At the Queen's croquet ground, Alice witnesses the absurd trial that reveals the arbitrary nature of Wonderland's justice system.",
"metadata": {
"book": book_title,
"chapter": 8,
"page": 120,
"characters": ["Alice", "Queen of Hearts", "King of Hearts"],
"themes": ["justice", "absurdity", "authority"],
"location": "Queen's court",
},
},
{
"text": "Alice realizes that Wonderland was all a dream, even the Rabbit, as she wakes up on the riverbank next to her sister.",
"metadata": {
"book": book_title,
"chapter": 12,
"page": 180,
"characters": ["Alice", "Sister", "Rabbit"],
"themes": ["revelation", "reality", "growth"],
"location": "riverbank",
},
},
]
return sample_chunks
def build_spoiler_free_index(book_chunks: list[dict[str, Any]], index_name: str) -> str:
"""
Build a LEANN index with book chunks that include spoiler metadata.
Args:
book_chunks: List of book chunks with metadata
index_name: Name for the index
Returns:
Path to the built index
"""
print(f"📚 Building spoiler-free book index: {index_name}")
# Initialize LEANN builder
builder = LeannBuilder(
backend_name="hnsw", embedding_model="text-embedding-3-small", embedding_mode="openai"
)
# Add each chunk with its metadata
for chunk in book_chunks:
builder.add_text(text=chunk["text"], metadata=chunk["metadata"])
# Build the index
index_path = f"{index_name}_book_index"
builder.build_index(index_path)
print(f"✅ Index built successfully: {index_path}")
return index_path
def spoiler_free_search(
index_path: str,
query: str,
max_chapter: int,
character_filter: Optional[list[str]] = None,
) -> list[dict[str, Any]]:
"""
Perform a spoiler-free search on the book index.
Args:
index_path: Path to the LEANN index
query: Search query
max_chapter: Maximum chapter number to include
character_filter: Optional list of characters to focus on
Returns:
List of search results safe for the reader
"""
print(f"🔍 Searching: '{query}' (up to chapter {max_chapter})")
searcher = LeannSearcher(index_path)
metadata_filters = {"chapter": {"<=": max_chapter}}
if character_filter:
metadata_filters["characters"] = {"contains": character_filter[0]}
results = searcher.search(query=query, top_k=10, metadata_filters=metadata_filters)
return results
def demo_spoiler_free_rag():
"""
Demonstrate the spoiler-free book RAG system.
"""
print("🎭 Spoiler-Free Book RAG Demo")
print("=" * 40)
# Step 1: Prepare book data
book_title = "Alice's Adventures in Wonderland"
book_chunks = chunk_book_with_metadata(book_title)
print(f"📖 Loaded {len(book_chunks)} chunks from '{book_title}'")
# Step 2: Build the index (in practice, this would be done once)
try:
index_path = build_spoiler_free_index(book_chunks, "alice_wonderland")
except Exception as e:
print(f"❌ Failed to build index (likely missing dependencies): {e}")
print(
"💡 This demo shows the filtering logic - actual indexing requires LEANN dependencies"
)
return
# Step 3: Demonstrate various spoiler-free searches
search_scenarios = [
{
"description": "Reader who has only read Chapter 1",
"query": "What can you tell me about the rabbit?",
"max_chapter": 1,
},
{
"description": "Reader who has read up to Chapter 5",
"query": "Tell me about Alice's adventures",
"max_chapter": 5,
},
{
"description": "Reader who has read most of the book",
"query": "What does the Cheshire Cat represent?",
"max_chapter": 10,
},
{
"description": "Reader who has read the whole book",
"query": "What can you tell me about the rabbit?",
"max_chapter": 12,
},
]
for scenario in search_scenarios:
print(f"\n📚 Scenario: {scenario['description']}")
print(f" Query: {scenario['query']}")
try:
results = spoiler_free_search(
index_path=index_path,
query=scenario["query"],
max_chapter=scenario["max_chapter"],
)
print(f" 📄 Found {len(results)} results:")
for i, result in enumerate(results[:3], 1): # Show top 3
chapter = result.metadata.get("chapter", "?")
location = result.metadata.get("location", "?")
print(f" {i}. Chapter {chapter} ({location}): {result.text[:80]}...")
except Exception as e:
print(f" ❌ Search failed: {e}")
if __name__ == "__main__":
print("📚 LEANN Spoiler-Free Book RAG Example")
print("=====================================")
try:
demo_spoiler_free_rag()
except ImportError as e:
print(f"❌ Cannot run demo due to missing dependencies: {e}")
except Exception as e:
print(f"❌ Error running demo: {e}")

View File

@@ -1,32 +0,0 @@
{
"version": "0.1.0",
"backend_name": "diskann",
"embedding_model": "sentence-transformers/all-mpnet-base-v2",
"num_chunks": 6,
"chunks": [
{
"text": "Python is a powerful programming language",
"metadata": {}
},
{
"text": "Machine learning transforms industries",
"metadata": {}
},
{
"text": "Neural networks process complex data",
"metadata": {}
},
{
"text": "Java is a powerful programming language",
"metadata": {}
},
{
"text": "C++ is a powerful programming language",
"metadata": {}
},
{
"text": "C# is a powerful programming language",
"metadata": {}
}
]
}

0
packages/__init__.py Normal file
View File

View File

@@ -1,8 +0,0 @@
# packages/leann-backend-diskann/CMakeLists.txt (最终简化版)
cmake_minimum_required(VERSION 3.20)
project(leann_backend_diskann_wrapper)
# 告诉 CMake 直接进入 DiskANN 子模块并执行它自己的 CMakeLists.txt
# DiskANN 会自己处理所有事情,包括编译 Python 绑定
add_subdirectory(src/third_party/DiskANN)

View File

@@ -0,0 +1 @@
# This file makes the directory a Python package

View File

@@ -0,0 +1,7 @@
from . import diskann_backend as diskann_backend
from . import graph_partition
# Export main classes and functions
from .graph_partition import GraphPartitioner, partition_graph
__all__ = ["GraphPartitioner", "diskann_backend", "graph_partition", "partition_graph"]

View File

@@ -1,31 +1,77 @@
import numpy as np
import os
import json
import struct
from pathlib import Path
from typing import Dict
import contextlib import contextlib
import threading import logging
import time import os
import atexit import struct
import socket
import subprocess
import sys import sys
from pathlib import Path
from typing import Any, Literal, Optional
from leann.registry import register_backend import numpy as np
import psutil
from leann.interface import ( from leann.interface import (
LeannBackendFactoryInterface,
LeannBackendBuilderInterface, LeannBackendBuilderInterface,
LeannBackendSearcherInterface LeannBackendFactoryInterface,
LeannBackendSearcherInterface,
) )
from . import _diskannpy as diskannpy from leann.registry import register_backend
from leann.searcher_base import BaseSearcher
METRIC_MAP = { logger = logging.getLogger(__name__)
@contextlib.contextmanager
def suppress_cpp_output_if_needed():
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
# In CI we avoid fiddling with low-level file descriptors to prevent aborts
if os.getenv("CI") == "true":
yield
return
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
should_suppress = log_level in ["WARNING", "ERROR", "CRITICAL"]
if not should_suppress:
# Don't suppress, just yield
yield
return
# Save original file descriptors
stdout_fd = sys.stdout.fileno()
stderr_fd = sys.stderr.fileno()
# Save original stdout/stderr
stdout_dup = os.dup(stdout_fd)
stderr_dup = os.dup(stderr_fd)
try:
# Redirect to /dev/null
devnull = os.open(os.devnull, os.O_WRONLY)
os.dup2(devnull, stdout_fd)
os.dup2(devnull, stderr_fd)
os.close(devnull)
yield
finally:
# Restore original file descriptors
os.dup2(stdout_dup, stdout_fd)
os.dup2(stderr_dup, stderr_fd)
os.close(stdout_dup)
os.close(stderr_dup)
def _get_diskann_metrics():
from . import _diskannpy as diskannpy # type: ignore
return {
"mips": diskannpy.Metric.INNER_PRODUCT, "mips": diskannpy.Metric.INNER_PRODUCT,
"l2": diskannpy.Metric.L2, "l2": diskannpy.Metric.L2,
"cosine": diskannpy.Metric.COSINE, "cosine": diskannpy.Metric.COSINE,
} }
@contextlib.contextmanager @contextlib.contextmanager
def chdir(path): def chdir(path):
original_dir = os.getcwd() original_dir = os.getcwd()
@@ -35,102 +81,51 @@ def chdir(path):
finally: finally:
os.chdir(original_dir) os.chdir(original_dir)
def _write_vectors_to_bin(data: np.ndarray, file_path: str):
def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
num_vectors, dim = data.shape num_vectors, dim = data.shape
with open(file_path, 'wb') as f: with open(file_path, "wb") as f:
f.write(struct.pack('I', num_vectors)) f.write(struct.pack("I", num_vectors))
f.write(struct.pack('I', dim)) f.write(struct.pack("I", dim))
f.write(data.tobytes()) f.write(data.tobytes())
def _check_port(port: int) -> bool:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
class EmbeddingServerManager: def _calculate_smart_memory_config(data: np.ndarray) -> tuple[float, float]:
def __init__(self): """
self.server_process = None Calculate smart memory configuration for DiskANN based on data size and system specs.
self.server_port = None
atexit.register(self.stop_server)
def start_server(self, port=5555, model_name="sentence-transformers/all-mpnet-base-v2"): Args:
if self.server_process and self.server_process.poll() is None: data: The embedding data array
print(f"INFO: Reusing existing server process for this session (PID {self.server_process.pid})")
return True
# 检查端口是否已被其他无关进程占用 Returns:
if _check_port(port): tuple: (search_memory_maximum, build_memory_maximum) in GB
print(f"WARNING: Port {port} is already in use. Assuming an external server is running and connecting to it.") """
return True num_vectors, dim = data.shape
print(f"INFO: Starting session-level embedding server as a background process...") # Calculate embedding storage size
embedding_size_bytes = num_vectors * dim * 4 # float32 = 4 bytes
embedding_size_gb = embedding_size_bytes / (1024**3)
try: # search_memory_maximum: 1/10 of embedding size for optimal PQ compression
command = [ # This controls Product Quantization size - smaller means more compression
sys.executable, search_memory_gb = max(0.1, embedding_size_gb / 10) # At least 100MB
"-m", "packages.leann-backend-diskann.leann_backend_diskann.embedding_server",
"--zmq-port", str(port), # build_memory_maximum: Based on available system RAM for sharding control
"--model-name", model_name # This controls how much memory DiskANN uses during index construction
] available_memory_gb = psutil.virtual_memory().available / (1024**3)
project_root = Path(__file__).parent.parent.parent.parent total_memory_gb = psutil.virtual_memory().total / (1024**3)
print(f"INFO: Running command from project root: {project_root}")
self.server_process = subprocess.Popen( # Use 50% of available memory, but at least 2GB and at most 75% of total
command, build_memory_gb = max(2.0, min(available_memory_gb * 0.5, total_memory_gb * 0.75))
cwd=project_root,
# stdout=subprocess.PIPE, logger.info(
# stderr=subprocess.PIPE, f"Smart memory config - Data: {embedding_size_gb:.2f}GB, "
text=True, f"Search mem: {search_memory_gb:.2f}GB (PQ control), "
encoding='utf-8' f"Build mem: {build_memory_gb:.2f}GB (sharding control)"
) )
self.server_port = port
print(f"INFO: Server process started with PID: {self.server_process.pid}")
max_wait, wait_interval = 30, 0.5 return search_memory_gb, build_memory_gb
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
print(f"✅ Embedding server is up and ready for this session.")
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
log_thread.start()
return True
if self.server_process.poll() is not None:
print("❌ ERROR: Server process terminated unexpectedly during startup.")
self._log_monitor()
return False
time.sleep(wait_interval)
print(f"❌ ERROR: Server process failed to start listening within {max_wait} seconds.")
self.stop_server()
return False
except Exception as e:
print(f"❌ ERROR: Failed to start embedding server process: {e}")
return False
def _log_monitor(self):
if not self.server_process:
return
try:
if self.server_process.stdout:
for line in iter(self.server_process.stdout.readline, ''):
print(f"[EmbeddingServer LOG]: {line.strip()}")
self.server_process.stdout.close()
if self.server_process.stderr:
for line in iter(self.server_process.stderr.readline, ''):
print(f"[EmbeddingServer ERROR]: {line.strip()}")
self.server_process.stderr.close()
except Exception as e:
print(f"Log monitor error: {e}")
def stop_server(self):
if self.server_process and self.server_process.poll() is None:
print(f"INFO: Terminating session server process (PID: {self.server_process.pid})...")
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
print("INFO: Server process terminated.")
except subprocess.TimeoutExpired:
print("WARNING: Server process did not terminate gracefully, killing it.")
self.server_process.kill()
self.server_process = None
@register_backend("diskann") @register_backend("diskann")
class DiskannBackend(LeannBackendFactoryInterface): class DiskannBackend(LeannBackendFactoryInterface):
@@ -140,134 +135,321 @@ class DiskannBackend(LeannBackendFactoryInterface):
@staticmethod @staticmethod
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface: def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
path = Path(index_path)
meta_path = path.parent / f"{path.name}.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
with open(meta_path, 'r') as f:
meta = json.load(f)
dimensions = meta.get("dimensions")
if not dimensions:
raise ValueError("Dimensions not found in Leann metadata. Please rebuild the index with a newer version of Leann.")
kwargs['dimensions'] = dimensions
return DiskannSearcher(index_path, **kwargs) return DiskannSearcher(index_path, **kwargs)
class DiskannBuilder(LeannBackendBuilderInterface): class DiskannBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.build_params = kwargs self.build_params = kwargs
def build(self, data: np.ndarray, index_path: str, **kwargs): def _safe_cleanup_after_partition(self, index_dir: Path, index_prefix: str):
"""
Safely cleanup files after partition.
In partition mode, C++ doesn't read _disk.index content,
so we can delete it if all derived files exist.
"""
disk_index_file = index_dir / f"{index_prefix}_disk.index"
beam_search_file = index_dir / f"{index_prefix}_disk_beam_search.index"
# Required files that C++ partition mode needs
# Note: C++ generates these with _disk.index suffix
disk_suffix = "_disk.index"
required_files = [
f"{index_prefix}{disk_suffix}_medoids.bin", # Critical: assert fails if missing
# Note: _centroids.bin is not created in single-shot build - C++ handles this automatically
f"{index_prefix}_pq_pivots.bin", # PQ table
f"{index_prefix}_pq_compressed.bin", # PQ compressed vectors
]
# Check if all required files exist
missing_files = []
for filename in required_files:
file_path = index_dir / filename
if not file_path.exists():
missing_files.append(filename)
if missing_files:
logger.warning(
f"Cannot safely delete _disk.index - missing required files: {missing_files}"
)
logger.info("Keeping all original files for safety")
return
# Calculate space savings
space_saved = 0
files_to_delete = []
if disk_index_file.exists():
space_saved += disk_index_file.stat().st_size
files_to_delete.append(disk_index_file)
if beam_search_file.exists():
space_saved += beam_search_file.stat().st_size
files_to_delete.append(beam_search_file)
# Safe to delete!
for file_to_delete in files_to_delete:
try:
os.remove(file_to_delete)
logger.info(f"✅ Safely deleted: {file_to_delete.name}")
except Exception as e:
logger.warning(f"Failed to delete {file_to_delete.name}: {e}")
if space_saved > 0:
space_saved_mb = space_saved / (1024 * 1024)
logger.info(f"💾 Space saved: {space_saved_mb:.1f} MB")
# Show what files are kept
logger.info("📁 Kept essential files for partition mode:")
for filename in required_files:
file_path = index_dir / filename
if file_path.exists():
size_mb = file_path.stat().st_size / (1024 * 1024)
logger.info(f" - {filename} ({size_mb:.1f} MB)")
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
path = Path(index_path) path = Path(index_path)
index_dir = path.parent index_dir = path.parent
index_prefix = path.stem index_prefix = path.stem
index_dir.mkdir(parents=True, exist_ok=True) index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32: if data.dtype != np.float32:
logger.warning(f"Converting data to float32, shape: {data.shape}")
data = data.astype(np.float32) data = data.astype(np.float32)
if not data.flags['C_CONTIGUOUS']:
data = np.ascontiguousarray(data)
data_filename = f"{index_prefix}_data.bin" data_filename = f"{index_prefix}_data.bin"
_write_vectors_to_bin(data, index_dir / data_filename) _write_vectors_to_bin(data, index_dir / data_filename)
build_kwargs = {**self.build_params, **kwargs} build_kwargs = {**self.build_params, **kwargs}
metric_str = build_kwargs.get("distance_metric", "mips").lower()
metric_enum = METRIC_MAP.get(metric_str) # Extract is_recompute from nested backend_kwargs if needed
is_recompute = build_kwargs.get("is_recompute", False)
if not is_recompute and "backend_kwargs" in build_kwargs:
is_recompute = build_kwargs["backend_kwargs"].get("is_recompute", False)
# Flatten all backend_kwargs parameters to top level for compatibility
if "backend_kwargs" in build_kwargs:
nested_params = build_kwargs.pop("backend_kwargs")
build_kwargs.update(nested_params)
metric_enum = _get_diskann_metrics().get(
build_kwargs.get("distance_metric", "mips").lower()
)
if metric_enum is None: if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.") raise ValueError(
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
)
complexity = build_kwargs.get("complexity", 64) # Calculate smart memory configuration if not explicitly provided
graph_degree = build_kwargs.get("graph_degree", 32) if (
final_index_ram_limit = build_kwargs.get("search_memory_maximum", 4.0) "search_memory_maximum" not in build_kwargs
indexing_ram_budget = build_kwargs.get("build_memory_maximum", 8.0) or "build_memory_maximum" not in build_kwargs
num_threads = build_kwargs.get("num_threads", 8) ):
pq_disk_bytes = build_kwargs.get("pq_disk_bytes", 0) smart_search_mem, smart_build_mem = _calculate_smart_memory_config(data)
codebook_prefix = "" else:
smart_search_mem = build_kwargs.get("search_memory_maximum", 4.0)
print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...") smart_build_mem = build_kwargs.get("build_memory_maximum", 8.0)
try: try:
from . import _diskannpy as diskannpy # type: ignore
with chdir(index_dir): with chdir(index_dir):
diskannpy.build_disk_float_index( diskannpy.build_disk_float_index(
metric_enum, metric_enum,
data_filename, data_filename,
index_prefix, index_prefix,
complexity, build_kwargs.get("complexity", 64),
graph_degree, build_kwargs.get("graph_degree", 32),
final_index_ram_limit, build_kwargs.get("search_memory_maximum", smart_search_mem),
indexing_ram_budget, build_kwargs.get("build_memory_maximum", smart_build_mem),
num_threads, build_kwargs.get("num_threads", 8),
pq_disk_bytes, build_kwargs.get("pq_disk_bytes", 0),
codebook_prefix "",
) )
print(f"✅ DiskANN index built successfully at '{index_dir / index_prefix}'")
except Exception as e: # Auto-partition if is_recompute is enabled
print(f"💥 ERROR: DiskANN index build failed. Exception: {e}") if build_kwargs.get("is_recompute", False):
raise logger.info("is_recompute=True, starting automatic graph partitioning...")
from .graph_partition import partition_graph
# Partition the index using absolute paths
# Convert to absolute paths to avoid issues with working directory changes
absolute_index_dir = Path(index_dir).resolve()
absolute_index_prefix_path = str(absolute_index_dir / index_prefix)
disk_graph_path, partition_bin_path = partition_graph(
index_prefix_path=absolute_index_prefix_path,
output_dir=str(absolute_index_dir),
partition_prefix=index_prefix,
)
# Safe cleanup: In partition mode, C++ doesn't read _disk.index content
# but still needs the derived files (_medoids.bin, _centroids.bin, etc.)
self._safe_cleanup_after_partition(index_dir, index_prefix)
logger.info("✅ Graph partitioning completed successfully!")
logger.info(f" - Disk graph: {disk_graph_path}")
logger.info(f" - Partition file: {partition_bin_path}")
finally: finally:
temp_data_file = index_dir / data_filename temp_data_file = index_dir / data_filename
if temp_data_file.exists(): if temp_data_file.exists():
os.remove(temp_data_file) os.remove(temp_data_file)
logger.debug(f"Cleaned up temporary data file: {temp_data_file}")
class DiskannSearcher(LeannBackendSearcherInterface):
class DiskannSearcher(BaseSearcher):
def __init__(self, index_path: str, **kwargs): def __init__(self, index_path: str, **kwargs):
path = Path(index_path) super().__init__(
index_dir = path.parent index_path,
index_prefix = path.stem backend_module_name="leann_backend_diskann.diskann_embedding_server",
metric_str = kwargs.get("distance_metric", "mips").lower() **kwargs,
metric_enum = METRIC_MAP.get(metric_str)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
num_threads = kwargs.get("num_threads", 8)
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
dimensions = kwargs.get("dimensions")
if not dimensions:
raise ValueError("Vector dimension not provided to DiskannSearcher.")
try:
full_index_prefix = str(index_dir / index_prefix)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, "", ""
) )
self.num_threads = num_threads
self.embedding_server_manager = EmbeddingServerManager()
print("✅ DiskANN index loaded successfully.")
except Exception as e:
print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
raise
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]: # Initialize DiskANN index with suppressed C++ output based on log level
complexity = kwargs.get("complexity", 256) with suppress_cpp_output_if_needed():
beam_width = kwargs.get("beam_width", 4) from . import _diskannpy as diskannpy # type: ignore
USE_DEFERRED_FETCH = kwargs.get("USE_DEFERRED_FETCH", False) distance_metric = kwargs.get("distance_metric", "mips").lower()
skip_search_reorder = kwargs.get("skip_search_reorder", False) metric_enum = _get_diskann_metrics().get(distance_metric)
recompute_beighbor_embeddings = kwargs.get("recompute_beighbor_embeddings", False) if metric_enum is None:
dedup_node_dis = kwargs.get("dedup_node_dis", False) raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
prune_ratio = kwargs.get("prune_ratio", 0.0)
batch_recompute = kwargs.get("batch_recompute", False)
global_pruning = kwargs.get("global_pruning", False)
if recompute_beighbor_embeddings: self.num_threads = kwargs.get("num_threads", 8)
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
zmq_port = kwargs.get("zmq_port", 6666)
embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
if not self.embedding_server_manager.start_server(zmq_port, embedding_model): # For DiskANN, we need to reinitialize the index when zmq_port changes
print(f"WARNING: Failed to start embedding server, falling back to PQ computation") # Store the initialization parameters for later use
kwargs['recompute_beighbor_embeddings'] = False # Note: C++ load method expects the BASE path (without _disk.index suffix)
# C++ internally constructs: index_prefix + "_disk.index"
index_name = self.index_path.stem # "simple_test.leann" -> "simple_test"
diskann_index_prefix = str(self.index_dir / index_name) # /path/to/simple_test
full_index_prefix = diskann_index_prefix # /path/to/simple_test (base path)
# Auto-detect partition files and set partition_prefix
partition_graph_file = self.index_dir / f"{index_name}_disk_graph.index"
partition_bin_file = self.index_dir / f"{index_name}_partition.bin"
partition_prefix = ""
if partition_graph_file.exists() and partition_bin_file.exists():
# C++ expects full path prefix, not just filename
partition_prefix = str(self.index_dir / index_name) # /path/to/simple_test
logger.info(
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
)
else:
logger.debug("No partition files detected, using standard index files")
self._init_params = {
"metric_enum": metric_enum,
"full_index_prefix": full_index_prefix,
"num_threads": self.num_threads,
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
"cache_mechanism": 1,
"pq_prefix": "",
"partition_prefix": partition_prefix,
}
# Log partition configuration for debugging
if partition_prefix:
logger.info(
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
)
self._diskannpy = diskannpy
self._current_zmq_port = None
self._index = None
logger.debug("DiskANN searcher initialized (index will be loaded on first search)")
def _ensure_index_loaded(self, zmq_port: int):
"""Ensure the index is loaded with the correct zmq_port."""
if self._index is None or self._current_zmq_port != zmq_port:
# Need to (re)load the index with the correct zmq_port
with suppress_cpp_output_if_needed():
if self._index is not None:
logger.debug(f"Reloading DiskANN index with new zmq_port: {zmq_port}")
else:
logger.debug(f"Loading DiskANN index with zmq_port: {zmq_port}")
self._index = self._diskannpy.StaticDiskFloatIndex(
self._init_params["metric_enum"],
self._init_params["full_index_prefix"],
self._init_params["num_threads"],
self._init_params["num_nodes_to_cache"],
self._init_params["cache_mechanism"],
zmq_port,
self._init_params["pq_prefix"],
self._init_params["partition_prefix"],
)
self._current_zmq_port = zmq_port
def search(
self,
query: np.ndarray,
top_k: int,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: Optional[int] = None,
batch_recompute: bool = False,
dedup_node_dis: bool = False,
**kwargs,
) -> dict[str, Any]:
"""
Search for nearest neighbors using DiskANN index.
Args:
query: Query vectors (B, D) where B is batch size, D is dimension
top_k: Number of nearest neighbors to return
complexity: Search complexity/candidate list size, higher = more accurate but slower
beam_width: Number of parallel IO requests per iteration
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server
pruning_strategy: PQ candidate selection strategy:
- "global": Use global pruning strategy (default)
- "local": Use local pruning strategy
- "proportional": Not supported in DiskANN, falls back to global
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
**kwargs: Additional DiskANN-specific parameters (for legacy compatibility)
Returns:
Dict with 'labels' (list of lists) and 'distances' (ndarray)
"""
# Handle zmq_port compatibility: Ensure index is loaded with correct port
if recompute_embeddings:
if zmq_port is None:
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
self._ensure_index_loaded(zmq_port)
else:
# If not recomputing, we still need an index, use a default port
if self._index is None:
self._ensure_index_loaded(6666) # Default port when not recomputing
# DiskANN doesn't support "proportional" strategy
if pruning_strategy == "proportional":
raise NotImplementedError(
"DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead."
)
if query.dtype != np.float32: if query.dtype != np.float32:
query = query.astype(np.float32) query = query.astype(np.float32)
if query.ndim == 1:
query = np.expand_dims(query, axis=0)
try: # Map pruning_strategy to DiskANN's global_pruning parameter
if pruning_strategy == "local":
use_global_pruning = False
else: # "global"
use_global_pruning = True
# Strategy:
# - Traversal always uses PQ distances
# - If recompute_embeddings=True, do a single final rerank via deferred fetch
# (fetch embeddings for the final candidate set only)
# - Do not recompute neighbor distances along the path
use_deferred_fetch = True if recompute_embeddings else False
recompute_neighors = False # Expected typo. For backward compatibility.
with suppress_cpp_output_if_needed():
labels, distances = self._index.batch_search( labels, distances = self._index.batch_search(
query, query,
query.shape[0], query.shape[0],
@@ -275,21 +457,15 @@ class DiskannSearcher(LeannBackendSearcherInterface):
complexity, complexity,
beam_width, beam_width,
self.num_threads, self.num_threads,
USE_DEFERRED_FETCH, use_deferred_fetch,
skip_search_reorder, kwargs.get("skip_search_reorder", False),
recompute_beighbor_embeddings, recompute_neighors,
dedup_node_dis, dedup_node_dis,
prune_ratio, prune_ratio,
batch_recompute, batch_recompute,
global_pruning use_global_pruning,
) )
return {"labels": labels, "distances": distances}
except Exception as e:
print(f"💥 ERROR: DiskANN search failed. Exception: {e}")
batch_size = query.shape[0]
return {"labels": np.full((batch_size, top_k), -1, dtype=np.int64),
"distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}
def __del__(self): string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
if hasattr(self, 'embedding_server_manager'):
self.embedding_server_manager.stop_server() return {"labels": string_labels, "distances": distances}

View File

@@ -0,0 +1,472 @@
"""
DiskANN-specific embedding server
"""
import argparse
import json
import logging
import os
import sys
import threading
import time
from pathlib import Path
from typing import Optional
import numpy as np
import zmq
# Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
logger = logging.getLogger(__name__)
# Force set logger level (don't rely on basicConfig in subprocess)
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Ensure we have a handler if none exists
if not logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False
def create_diskann_embedding_server(
passages_file: Optional[str] = None,
zmq_port: int = 5555,
model_name: str = "sentence-transformers/all-mpnet-base-v2",
embedding_mode: str = "sentence-transformers",
distance_metric: str = "l2",
):
"""
Create and start a ZMQ-based embedding server for DiskANN backend.
Uses ROUTER socket and protobuf communication as required by DiskANN C++ implementation.
"""
logger.info(f"Starting DiskANN server on port {zmq_port} with model {model_name}")
logger.info(f"Using embedding mode: {embedding_mode}")
# Add leann-core to path for unified embedding computation
current_dir = Path(__file__).parent
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
sys.path.insert(0, str(leann_core_path))
try:
from leann.api import PassageManager
from leann.embedding_compute import compute_embeddings
logger.info("Successfully imported unified embedding computation module")
except ImportError as e:
logger.error(f"Failed to import embedding computation module: {e}")
return
finally:
sys.path.pop(0)
# Check port availability
import socket
def check_port(port):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("localhost", port)) == 0
if check_port(zmq_port):
logger.error(f"Port {zmq_port} is already in use")
return
# Only support metadata file, fail fast for everything else
if not passages_file or not passages_file.endswith(".meta.json"):
raise ValueError("Only metadata files (.meta.json) are supported")
# Load metadata to get passage sources
with open(passages_file) as f:
meta = json.load(f)
logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}")
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
# Import protobuf after ensuring the path is correct
try:
from . import embedding_pb2
except ImportError as e:
logger.error(f"Failed to import protobuf module: {e}")
return
def zmq_server_thread():
"""ZMQ server thread using REP socket for universal compatibility"""
context = zmq.Context()
socket = context.socket(
zmq.REP
) # REP socket for both BaseSearcher and DiskANN C++ REQ clients
socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 1000)
socket.setsockopt(zmq.SNDTIMEO, 1000)
socket.setsockopt(zmq.LINGER, 0)
while True:
try:
# REP socket receives single-part messages
message = socket.recv()
# Check for empty messages - REP socket requires response to every request
if len(message) == 0:
logger.debug("Received empty message, sending empty response")
socket.send(b"") # REP socket must respond to every request
continue
logger.debug(f"Received ZMQ request of size {len(message)} bytes")
logger.debug(f"Message preview: {message[:50]}") # Show first 50 bytes
e2e_start = time.time()
# Try protobuf first (for DiskANN C++ node_ids requests - primary use case)
texts = []
node_ids = []
is_text_request = False
try:
req_proto = embedding_pb2.NodeEmbeddingRequest()
req_proto.ParseFromString(message)
node_ids = list(req_proto.node_ids)
if not node_ids:
raise RuntimeError(
f"PROTOBUF: Received empty node_ids! Message size: {len(message)}"
)
logger.info(
f"✅ PROTOBUF: Node ID request for {len(node_ids)} node embeddings: {node_ids[:10]}"
)
except Exception as protobuf_error:
logger.debug(f"Protobuf parsing failed: {protobuf_error}")
# Fallback to msgpack (for BaseSearcher direct text requests)
try:
import msgpack
request = msgpack.unpackb(message)
# For BaseSearcher compatibility, request is a list of texts directly
if isinstance(request, list) and all(
isinstance(item, str) for item in request
):
texts = request
is_text_request = True
logger.info(f"✅ MSGPACK: Direct text request for {len(texts)} texts")
else:
raise ValueError("Not a valid msgpack text request")
except Exception as msgpack_error:
raise RuntimeError(
f"Both protobuf and msgpack parsing failed! Protobuf: {protobuf_error}, Msgpack: {msgpack_error}"
)
# Look up texts by node IDs (only if not direct text request)
if not is_text_request:
for nid in node_ids:
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
texts.append(txt)
except KeyError as e:
logger.error(f"Passage ID {nid} not found: {e}")
raise e
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
raise
# Debug logging
logger.debug(f"Processing {len(texts)} texts")
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
# Process embeddings using unified computation
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
# Prepare response based on request type
if is_text_request:
# For BaseSearcher compatibility: return msgpack format
import msgpack
response_data = msgpack.packb(embeddings.tolist())
else:
# For DiskANN C++ compatibility: return protobuf format
resp_proto = embedding_pb2.NodeEmbeddingResponse()
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
# Serialize embeddings data
resp_proto.embeddings_data = hidden_contiguous.tobytes()
resp_proto.dimensions.append(hidden_contiguous.shape[0])
resp_proto.dimensions.append(hidden_contiguous.shape[1])
response_data = resp_proto.SerializeToString()
# Send response back to the client
socket.send(response_data)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
logger.debug("ZMQ socket timeout, continuing to listen")
continue
except Exception as e:
logger.error(f"Error in ZMQ server loop: {e}")
import traceback
traceback.print_exc()
raise
def zmq_server_thread_with_shutdown(shutdown_event):
"""ZMQ server thread that respects shutdown signal.
This creates its own REP socket, binds to zmq_port, and periodically
checks shutdown_event using recv timeouts to exit cleanly.
"""
logger.info("DiskANN ZMQ server thread started with shutdown support")
context = zmq.Context()
rep_socket = context.socket(zmq.REP)
rep_socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
# Set receive timeout so we can check shutdown_event periodically
rep_socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
rep_socket.setsockopt(zmq.LINGER, 0)
try:
while not shutdown_event.is_set():
try:
e2e_start = time.time()
# REP socket receives single-part messages
message = rep_socket.recv()
# Check for empty messages - REP socket requires response to every request
if not message:
logger.warning("Received empty message, sending empty response")
rep_socket.send(b"")
continue
# Try protobuf first (same logic as original)
texts = []
is_text_request = False
try:
req_proto = embedding_pb2.NodeEmbeddingRequest()
req_proto.ParseFromString(message)
node_ids = list(req_proto.node_ids)
# Look up texts by node IDs
for nid in node_ids:
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
texts.append(txt)
except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
logger.info(f"ZMQ received protobuf request for {len(node_ids)} node IDs")
except Exception:
# Fallback to msgpack for text requests
try:
import msgpack
request = msgpack.unpackb(message)
if isinstance(request, list) and all(
isinstance(item, str) for item in request
):
texts = request
is_text_request = True
logger.info(
f"ZMQ received msgpack text request for {len(texts)} texts"
)
else:
raise ValueError("Not a valid msgpack text request")
except Exception:
logger.error("Both protobuf and msgpack parsing failed!")
# Send error response
resp_proto = embedding_pb2.NodeEmbeddingResponse()
rep_socket.send(resp_proto.SerializeToString())
continue
# Process the request
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(f"Computed embeddings shape: {embeddings.shape}")
# Validation
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error("NaN or Inf detected in embeddings!")
# Send error response
if is_text_request:
import msgpack
response_data = msgpack.packb([])
else:
resp_proto = embedding_pb2.NodeEmbeddingResponse()
response_data = resp_proto.SerializeToString()
rep_socket.send(response_data)
continue
# Prepare response based on request type
if is_text_request:
# For direct text requests, return msgpack
import msgpack
response_data = msgpack.packb(embeddings.tolist())
else:
# For protobuf requests, return protobuf
resp_proto = embedding_pb2.NodeEmbeddingResponse()
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
resp_proto.embeddings_data = hidden_contiguous.tobytes()
resp_proto.dimensions.append(hidden_contiguous.shape[0])
resp_proto.dimensions.append(hidden_contiguous.shape[1])
response_data = resp_proto.SerializeToString()
# Send response back to the client
rep_socket.send(response_data)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
# Timeout - check shutdown_event and continue
continue
except Exception as e:
if not shutdown_event.is_set():
logger.error(f"Error in ZMQ server loop: {e}")
try:
# Send error response for REP socket
resp_proto = embedding_pb2.NodeEmbeddingResponse()
rep_socket.send(resp_proto.SerializeToString())
except Exception:
pass
else:
logger.info("Shutdown in progress, ignoring ZMQ error")
break
finally:
try:
rep_socket.close(0)
except Exception:
pass
try:
context.term()
except Exception:
pass
logger.info("DiskANN ZMQ server thread exiting gracefully")
# Add shutdown coordination
shutdown_event = threading.Event()
def shutdown_zmq_server():
"""Gracefully shutdown ZMQ server."""
logger.info("Initiating graceful shutdown...")
shutdown_event.set()
if zmq_thread.is_alive():
logger.info("Waiting for ZMQ thread to finish...")
zmq_thread.join(timeout=5)
if zmq_thread.is_alive():
logger.warning("ZMQ thread did not finish in time")
# Clean up ZMQ resources
try:
# Note: socket and context are cleaned up by thread exit
logger.info("ZMQ resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning ZMQ resources: {e}")
# Clean up other resources
try:
import gc
gc.collect()
logger.info("Additional resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning additional resources: {e}")
logger.info("Graceful shutdown completed")
sys.exit(0)
# Register signal handlers within this function scope
import signal
def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
shutdown_zmq_server()
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Start ZMQ thread (NOT daemon!)
zmq_thread = threading.Thread(
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
daemon=False, # Not daemon - we want to wait for it
)
zmq_thread.start()
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
# Keep the main thread alive
try:
while not shutdown_event.is_set():
time.sleep(0.1) # Check shutdown more frequently
except KeyboardInterrupt:
logger.info("DiskANN Server shutting down...")
shutdown_zmq_server()
return
# If we reach here, shutdown was triggered by signal
logger.info("Main loop exited, process should be shutting down")
if __name__ == "__main__":
import sys
# Signal handlers are now registered within create_diskann_embedding_server
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
parser.add_argument(
"--passages-file",
type=str,
help="Metadata JSON file containing passage sources",
)
parser.add_argument(
"--model-name",
type=str,
default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model name",
)
parser.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode",
)
parser.add_argument(
"--distance-metric",
type=str,
default="l2",
choices=["l2", "mips", "cosine"],
help="Distance metric for similarity computation",
)
args = parser.parse_args()
# Create and start the DiskANN embedding server
create_diskann_embedding_server(
passages_file=args.passages_file,
zmq_port=args.zmq_port,
model_name=args.model_name,
embedding_mode=args.embedding_mode,
distance_metric=args.distance_metric,
)

View File

@@ -1,24 +1,25 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# source: embedding.proto # source: embedding.proto
# ruff: noqa
"""Generated protocol buffer code.""" """Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3'
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding\"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r\"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3') )
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'embedding_pb2', globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "embedding_pb2", globals())
if _descriptor._USE_C_DESCRIPTORS == False: if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._options = None DESCRIPTOR._options = None
_NODEEMBEDDINGREQUEST._serialized_start = 35 _NODEEMBEDDINGREQUEST._serialized_start = 35
_NODEEMBEDDINGREQUEST._serialized_end = 75 _NODEEMBEDDINGREQUEST._serialized_end = 75

View File

@@ -1,397 +0,0 @@
#!/usr/bin/env python3
"""
Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern
"""
import pickle
import argparse
import threading
import time
from transformers import AutoTokenizer, AutoModel
import os
from contextlib import contextmanager
import zmq
import numpy as np
RED = "\033[91m"
RESET = "\033[0m"
# 简化的文档存储 - 替代 LazyPassages
class SimpleDocumentStore:
"""简化的文档存储支持任意ID"""
def __init__(self, documents: dict = None):
self.documents = documents or {}
# 默认演示文档
self.default_docs = {
0: "Python is a high-level, interpreted language known for simplicity.",
1: "Machine learning builds systems that learn from data.",
2: "Data structures like arrays, lists, and graphs organize data.",
}
def __getitem__(self, doc_id):
doc_id = int(doc_id)
# 优先使用指定的文档
if doc_id in self.documents:
return {"text": self.documents[doc_id]}
# 其次使用默认演示文档
if doc_id in self.default_docs:
return {"text": self.default_docs[doc_id]}
# 对于任意其他ID返回通用文档
fallback_docs = [
"This is a general document about technology and programming concepts.",
"This document discusses machine learning and artificial intelligence topics.",
"This content covers data structures, algorithms, and computer science fundamentals.",
"This is a document about software engineering and development practices.",
"This content focuses on databases, data management, and information systems."
]
# 根据ID选择一个fallback文档
fallback_text = fallback_docs[doc_id % len(fallback_docs)]
return {"text": f"[ID:{doc_id}] {fallback_text}"}
def __len__(self):
return len(self.documents) + len(self.default_docs)
def create_embedding_server_thread(
zmq_port=5555,
model_name="sentence-transformers/all-mpnet-base-v2",
max_batch_size=128,
):
"""
在当前线程中创建并运行 embedding server
这个函数设计为在单独的线程中调用
"""
print(f"INFO: Initializing embedding server thread on port {zmq_port}")
try:
# 检查端口是否已被占用
import socket
def check_port(port):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
if check_port(zmq_port):
print(f"{RED}Port {zmq_port} is already in use{RESET}")
return
# 初始化模型
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
import torch
# 选择设备
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
cuda_available = torch.cuda.is_available()
if cuda_available:
device = torch.device("cuda")
print("INFO: Using CUDA device")
elif mps_available:
device = torch.device("mps")
print("INFO: Using MPS device (Apple Silicon)")
else:
device = torch.device("cpu")
print("INFO: Using CPU device")
# 加载模型
print(f"INFO: Loading model {model_name}")
model = AutoModel.from_pretrained(model_name).to(device).eval()
# 优化模型
if cuda_available or mps_available:
try:
model = model.half()
model = torch.compile(model)
print(f"INFO: Using FP16 precision with model: {model_name}")
except Exception as e:
print(f"WARNING: Model optimization failed: {e}")
# 默认演示文档
demo_documents = {
0: "Python is a high-level, interpreted language known for simplicity.",
1: "Machine learning builds systems that learn from data.",
2: "Data structures like arrays, lists, and graphs organize data.",
}
passages = SimpleDocumentStore(demo_documents)
print(f"INFO: Loaded {len(passages)} demo documents")
class DeviceTimer:
"""设备计时器"""
def __init__(self, name="", device=device):
self.name = name
self.device = device
self.start_time = 0
self.end_time = 0
if cuda_available:
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
else:
self.start_event = None
self.end_event = None
@contextmanager
def timing(self):
self.start()
yield
self.end()
def start(self):
if cuda_available:
torch.cuda.synchronize()
self.start_event.record()
else:
if self.device.type == "mps":
torch.mps.synchronize()
self.start_time = time.time()
def end(self):
if cuda_available:
self.end_event.record()
torch.cuda.synchronize()
else:
if self.device.type == "mps":
torch.mps.synchronize()
self.end_time = time.time()
def elapsed_time(self):
if cuda_available:
return self.start_event.elapsed_time(self.end_event) / 1000.0
else:
return self.end_time - self.start_time
def print_elapsed(self):
print(f"Time taken for {self.name}: {self.elapsed_time():.6f} seconds")
def process_batch(texts_batch, ids_batch, missing_ids):
"""处理文本批次"""
batch_size = len(texts_batch)
print(f"INFO: Processing batch of size {batch_size}")
tokenize_timer = DeviceTimer("tokenization (batch)", device)
to_device_timer = DeviceTimer("transfer to device (batch)", device)
embed_timer = DeviceTimer("embedding (batch)", device)
pool_timer = DeviceTimer("mean pooling (batch)", device)
with tokenize_timer.timing():
encoded_batch = tokenizer.batch_encode_plus(
texts_batch,
padding="max_length",
truncation=True,
max_length=256,
return_tensors="pt",
return_token_type_ids=False,
)
tokenize_timer.print_elapsed()
seq_length = encoded_batch["input_ids"].size(1)
print(f"Batch size: {batch_size}, Sequence length: {seq_length}")
with to_device_timer.timing():
enc = {k: v.to(device) for k, v in encoded_batch.items()}
to_device_timer.print_elapsed()
with torch.no_grad():
with embed_timer.timing():
out = model(enc["input_ids"], enc["attention_mask"])
embed_timer.print_elapsed()
with pool_timer.timing():
hidden_states = out.last_hidden_state if hasattr(out, "last_hidden_state") else out
mask_expanded = enc["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float()
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
batch_embeddings = sum_embeddings / sum_mask
pool_timer.print_elapsed()
return batch_embeddings.cpu().numpy()
# ZMQ server 主循环 - 修改为REP套接字
context = zmq.Context()
socket = context.socket(zmq.ROUTER) # 改为REP套接字
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
print(f"INFO: ZMQ ROUTER server listening on port {zmq_port}")
# 设置超时
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5秒接收超时
socket.setsockopt(zmq.SNDTIMEO, 300000) # 300秒发送超时
from . import embedding_pb2
print(f"INFO: Embedding server ready to serve requests")
while True:
try:
parts = socket.recv_multipart()
# --- 恢复稳健的消息格式判断 ---
# 必须检查 parts 的长度,避免 IndexError
if len(parts) >= 3:
identity = parts[0]
# empty = parts[1] # 中间的空帧我们通常不关心
message = parts[2]
elif len(parts) == 2:
# 也能处理没有空帧的情况
identity = parts[0]
message = parts[1]
else:
# 如果收到格式错误的消息,打印警告并忽略它,而不是崩溃
print(f"WARNING: Received unexpected message format with {len(parts)} parts. Ignoring.")
continue
print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")
e2e_start = time.time()
lookup_timer = DeviceTimer("text lookup", device)
# 解析请求
req_proto = embedding_pb2.NodeEmbeddingRequest()
req_proto.ParseFromString(message)
node_ids = req_proto.node_ids
print(f"INFO: Request for {len(node_ids)} node embeddings: {list(node_ids)}")
# 添加调试信息
if len(node_ids) > 0:
print(f"DEBUG: Node ID range: {min(node_ids)} to {max(node_ids)}")
# 查找文本
texts = []
missing_ids = []
with lookup_timer.timing():
for nid in node_ids:
txtinfo = passages[nid]
txt = txtinfo["text"]
texts.append(txt)
lookup_timer.print_elapsed()
if missing_ids:
print(f"WARNING: Missing passages for IDs: {missing_ids}")
# 处理批次
total_size = len(texts)
print(f"INFO: Total batch size: {total_size}, max_batch_size: {max_batch_size}")
all_embeddings = []
if total_size > max_batch_size:
print(f"INFO: Splitting batch of size {total_size} into chunks of {max_batch_size}")
for i in range(0, total_size, max_batch_size):
end_idx = min(i + max_batch_size, total_size)
print(f"INFO: Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")
chunk_texts = texts[i:end_idx]
chunk_ids = node_ids[i:end_idx]
embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
all_embeddings.append(embeddings_chunk)
if cuda_available:
torch.cuda.empty_cache()
elif device.type == "mps":
torch.mps.empty_cache()
hidden = np.vstack(all_embeddings)
print(f"INFO: Combined embeddings shape: {hidden.shape}")
else:
hidden = process_batch(texts, node_ids, missing_ids)
# 序列化响应
ser_start = time.time()
resp_proto = embedding_pb2.NodeEmbeddingResponse()
hidden_contiguous = np.ascontiguousarray(hidden, dtype=np.float32)
resp_proto.embeddings_data = hidden_contiguous.tobytes()
resp_proto.dimensions.append(hidden_contiguous.shape[0])
resp_proto.dimensions.append(hidden_contiguous.shape[1])
resp_proto.missing_ids.extend(missing_ids)
response_data = resp_proto.SerializeToString()
# REP 套接字发送单个响应
socket.send_multipart([identity, b'', response_data])
ser_end = time.time()
print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
e2e_end = time.time()
print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
except zmq.Again:
print("INFO: ZMQ socket timeout, continuing to listen")
# REP套接字不需要重新创建只需要继续监听
continue
except Exception as e:
print(f"ERROR: Error in ZMQ server: {e}")
try:
# 发送空响应以维持REQ-REP状态
empty_resp = embedding_pb2.NodeEmbeddingResponse()
socket.send(empty_resp.SerializeToString())
except:
# 如果发送失败重新创建socket
socket.close()
socket = context.socket(zmq.REP)
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 5000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
print("INFO: ZMQ socket recreated after error")
except Exception as e:
print(f"ERROR: Failed to start embedding server: {e}")
raise
# 保持原有的 create_embedding_server 函数不变,只添加线程化版本
def create_embedding_server(
domain="demo",
load_passages=True,
load_embeddings=False,
use_fp16=True,
use_int8=False,
use_cuda_graphs=False,
zmq_port=5555,
max_batch_size=128,
lazy_load_passages=False,
model_name="sentence-transformers/all-mpnet-base-v2",
):
"""
原有的 create_embedding_server 函数保持不变
这个是阻塞版本,用于直接运行
"""
create_embedding_server_thread(zmq_port, model_name, max_batch_size)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
parser.add_argument("--domain", type=str, default="demo", help="Domain name")
parser.add_argument("--load-passages", action="store_true", default=True)
parser.add_argument("--load-embeddings", action="store_true", default=False)
parser.add_argument("--use-fp16", action="store_true", default=False)
parser.add_argument("--use-int8", action="store_true", default=False)
parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting")
parser.add_argument("--lazy-load-passages", action="store_true", default=True)
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model name")
args = parser.parse_args()
create_embedding_server(
domain=args.domain,
load_passages=args.load_passages,
load_embeddings=args.load_embeddings,
use_fp16=args.use_fp16,
use_int8=args.use_int8,
use_cuda_graphs=args.use_cuda_graphs,
zmq_port=args.zmq_port,
max_batch_size=args.max_batch_size,
lazy_load_passages=args.lazy_load_passages,
model_name=args.model_name,
)

View File

@@ -0,0 +1,299 @@
#!/usr/bin/env python3
"""
Graph Partition Module for LEANN DiskANN Backend
This module provides Python bindings for the graph partition functionality
of DiskANN, allowing users to partition disk-based indices for better
performance.
"""
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Optional
class GraphPartitioner:
"""
A Python interface for DiskANN's graph partition functionality.
This class provides methods to partition disk-based indices for improved
search performance and memory efficiency.
"""
def __init__(self, build_type: str = "release"):
"""
Initialize the GraphPartitioner.
Args:
build_type: Build type for the executables ("debug" or "release")
"""
self.build_type = build_type
self._ensure_executables()
def _get_executable_path(self, name: str) -> str:
"""Get the path to a graph partition executable."""
# Get the directory where this Python module is located
module_dir = Path(__file__).parent
# Navigate to the graph_partition directory
graph_partition_dir = module_dir.parent / "third_party" / "DiskANN" / "graph_partition"
executable_path = graph_partition_dir / "build" / self.build_type / "graph_partition" / name
if not executable_path.exists():
raise FileNotFoundError(f"Executable {name} not found at {executable_path}")
return str(executable_path)
def _ensure_executables(self):
"""Ensure that the required executables are built."""
try:
self._get_executable_path("partitioner")
self._get_executable_path("index_relayout")
except FileNotFoundError:
# Try to build the executables automatically
print("Executables not found, attempting to build them...")
self._build_executables()
def _build_executables(self):
"""Build the required executables."""
graph_partition_dir = (
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
)
original_dir = os.getcwd()
try:
os.chdir(graph_partition_dir)
# Clean any existing build
if (graph_partition_dir / "build").exists():
shutil.rmtree(graph_partition_dir / "build")
# Run the build script
cmd = ["./build.sh", self.build_type, "split_graph", "/tmp/dummy"]
subprocess.run(cmd, capture_output=True, text=True, cwd=graph_partition_dir)
# Check if executables were created
partitioner_path = self._get_executable_path("partitioner")
relayout_path = self._get_executable_path("index_relayout")
print(f"✅ Built partitioner: {partitioner_path}")
print(f"✅ Built index_relayout: {relayout_path}")
except Exception as e:
raise RuntimeError(f"Failed to build executables: {e}")
finally:
os.chdir(original_dir)
def partition_graph(
self,
index_prefix_path: str,
output_dir: Optional[str] = None,
partition_prefix: Optional[str] = None,
**kwargs,
) -> tuple[str, str]:
"""
Partition a disk-based index for improved performance.
Args:
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
output_dir: Output directory for results (defaults to parent of index_prefix_path)
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
**kwargs: Additional parameters for graph partitioning:
- gp_times: Number of LDG partition iterations (default: 10)
- lock_nums: Number of lock nodes (default: 10)
- cut: Cut adjacency list degree (default: 100)
- scale_factor: Scale factor (default: 1)
- data_type: Data type (default: "float")
- thread_nums: Number of threads (default: 10)
Returns:
Tuple of (disk_graph_index_path, partition_bin_path)
Raises:
RuntimeError: If the partitioning process fails
"""
# Set default parameters
params = {
"gp_times": 10,
"lock_nums": 10,
"cut": 100,
"scale_factor": 1,
"data_type": "float",
"thread_nums": 10,
**kwargs,
}
# Determine output directory
if output_dir is None:
output_dir = str(Path(index_prefix_path).parent)
# Create output directory if it doesn't exist
Path(output_dir).mkdir(parents=True, exist_ok=True)
# Determine partition prefix
if partition_prefix is None:
partition_prefix = Path(index_prefix_path).name
# Get executable paths
partitioner_path = self._get_executable_path("partitioner")
relayout_path = self._get_executable_path("index_relayout")
# Create temporary directory for processing
with tempfile.TemporaryDirectory() as temp_dir:
# Change to the graph_partition directory for temporary files
graph_partition_dir = (
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
)
original_dir = os.getcwd()
try:
os.chdir(graph_partition_dir)
# Create temporary data directory
temp_data_dir = Path(temp_dir) / "data"
temp_data_dir.mkdir(parents=True, exist_ok=True)
# Set up paths for temporary files
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
graph_gp_path = (
graph_path
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
)
graph_gp_path.mkdir(parents=True, exist_ok=True)
# Find input index file
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
if not os.path.exists(old_index_file):
old_index_file = f"{index_prefix_path}_disk.index"
if not os.path.exists(old_index_file):
raise RuntimeError(f"Index file not found: {old_index_file}")
# Run partitioner
gp_file_path = graph_gp_path / "_part.bin"
partitioner_cmd = [
partitioner_path,
"--index_file",
old_index_file,
"--data_type",
params["data_type"],
"--gp_file",
str(gp_file_path),
"-T",
str(params["thread_nums"]),
"--ldg_times",
str(params["gp_times"]),
"--scale",
str(params["scale_factor"]),
"--mode",
"1",
]
print(f"Running partitioner: {' '.join(partitioner_cmd)}")
result = subprocess.run(
partitioner_cmd, capture_output=True, text=True, cwd=graph_partition_dir
)
if result.returncode != 0:
raise RuntimeError(
f"Partitioner failed with return code {result.returncode}.\n"
f"stdout: {result.stdout}\n"
f"stderr: {result.stderr}"
)
# Run relayout
part_tmp_index = graph_gp_path / "_part_tmp.index"
relayout_cmd = [
relayout_path,
old_index_file,
str(gp_file_path),
params["data_type"],
"1",
]
print(f"Running relayout: {' '.join(relayout_cmd)}")
result = subprocess.run(
relayout_cmd, capture_output=True, text=True, cwd=graph_partition_dir
)
if result.returncode != 0:
raise RuntimeError(
f"Relayout failed with return code {result.returncode}.\n"
f"stdout: {result.stdout}\n"
f"stderr: {result.stderr}"
)
# Copy results to output directory
disk_graph_path = Path(output_dir) / f"{partition_prefix}_disk_graph.index"
partition_bin_path = Path(output_dir) / f"{partition_prefix}_partition.bin"
shutil.copy2(part_tmp_index, disk_graph_path)
shutil.copy2(gp_file_path, partition_bin_path)
print(f"Results copied to: {output_dir}")
return str(disk_graph_path), str(partition_bin_path)
finally:
os.chdir(original_dir)
def get_partition_info(self, partition_bin_path: str) -> dict:
"""
Get information about a partition file.
Args:
partition_bin_path: Path to the partition binary file
Returns:
Dictionary containing partition information
"""
if not os.path.exists(partition_bin_path):
raise FileNotFoundError(f"Partition file not found: {partition_bin_path}")
# For now, return basic file information
# In the future, this could parse the binary file for detailed info
stat = os.stat(partition_bin_path)
return {
"file_size": stat.st_size,
"file_path": partition_bin_path,
"modified_time": stat.st_mtime,
}
def partition_graph(
index_prefix_path: str,
output_dir: Optional[str] = None,
partition_prefix: Optional[str] = None,
build_type: str = "release",
**kwargs,
) -> tuple[str, str]:
"""
Convenience function to partition a graph index.
Args:
index_prefix_path: Path to the index prefix
output_dir: Output directory (defaults to parent of index_prefix_path)
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
build_type: Build type for executables ("debug" or "release")
**kwargs: Additional parameters for graph partitioning
Returns:
Tuple of (disk_graph_index_path, partition_bin_path)
"""
partitioner = GraphPartitioner(build_type=build_type)
return partitioner.partition_graph(index_prefix_path, output_dir, partition_prefix, **kwargs)
# Example usage:
if __name__ == "__main__":
# Example: partition an index
try:
disk_graph_path, partition_bin_path = partition_graph(
"/path/to/your/index_prefix", gp_times=10, lock_nums=10, cut=100
)
print("Partitioning completed successfully!")
print(f"Disk graph index: {disk_graph_path}")
print(f"Partition binary: {partition_bin_path}")
except Exception as e:
print(f"Partitioning failed: {e}")

View File

@@ -4,13 +4,18 @@ build-backend = "scikit_build_core.build"
[project] [project]
name = "leann-backend-diskann" name = "leann-backend-diskann"
version = "0.1.0" version = "0.3.2"
dependencies = ["leann-core==0.1.0", "numpy"] dependencies = ["leann-core==0.3.2", "numpy", "protobuf>=3.19.0"]
[tool.scikit-build] [tool.scikit-build]
# 关键:简化的 CMake 路径 # Key: simplified CMake path
cmake.source-dir = "third_party/DiskANN" cmake.source-dir = "third_party/DiskANN"
# 关键:Python 包在根目录,路径完全匹配 # Key: Python package in root directory, paths match exactly
wheel.packages = ["leann_backend_diskann"] wheel.packages = ["leann_backend_diskann"]
# 使用默认的 redirect 模式 # Use default redirect mode
editable.mode = "redirect" editable.mode = "redirect"
cmake.build-type = "Release"
build.verbose = true
build.tool-args = ["-j8"]
# Let CMake find packages via Homebrew prefix
cmake.define = {CMAKE_PREFIX_PATH = {env = "CMAKE_PREFIX_PATH"}, OpenMP_ROOT = {env = "OpenMP_ROOT"}}

View File

@@ -1,7 +1,47 @@
# 最终简化版
cmake_minimum_required(VERSION 3.24) cmake_minimum_required(VERSION 3.24)
project(leann_backend_hnsw_wrapper) project(leann_backend_hnsw_wrapper)
set(CMAKE_C_COMPILER_WORKS 1)
set(CMAKE_CXX_COMPILER_WORKS 1)
# Set OpenMP path for macOS
if(APPLE)
# Detect Homebrew installation path (Apple Silicon vs Intel)
if(EXISTS "/opt/homebrew/opt/libomp")
set(HOMEBREW_PREFIX "/opt/homebrew")
elseif(EXISTS "/usr/local/opt/libomp")
set(HOMEBREW_PREFIX "/usr/local")
else()
message(FATAL_ERROR "Could not find libomp installation. Please install with: brew install libomp")
endif()
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
set(OpenMP_C_LIB_NAMES "omp")
set(OpenMP_CXX_LIB_NAMES "omp")
set(OpenMP_omp_LIBRARY "${HOMEBREW_PREFIX}/opt/libomp/lib/libomp.dylib")
# Force use of system libc++ to avoid version mismatch
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -stdlib=libc++")
# Set minimum macOS version for better compatibility
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
endif()
# Use system ZeroMQ instead of building from source
find_package(PkgConfig REQUIRED)
pkg_check_modules(ZMQ REQUIRED libzmq)
# Add cppzmq headers
include_directories(third_party/cppzmq)
# Configure msgpack-c - disable boost dependency
set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
add_compile_definitions(MSGPACK_NO_BOOST)
include_directories(third_party/msgpack-c/include)
# Faiss configuration - streamlined build
set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE) set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE) set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE) set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE)
@@ -9,4 +49,24 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE) set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE) set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
# Disable additional SIMD versions to speed up compilation
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
# Additional optimization options from INSTALL.md
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) # Static library is faster to build
# Avoid building demos and benchmarks
set(BUILD_DEMOS OFF CACHE BOOL "" FORCE)
set(BUILD_BENCHS OFF CACHE BOOL "" FORCE)
# NEW: Tell Faiss to only build the generic version
set(FAISS_BUILD_GENERIC ON CACHE BOOL "" FORCE)
set(FAISS_BUILD_AVX2 OFF CACHE BOOL "" FORCE)
set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE)
# IMPORTANT: Disable building AVX versions to speed up compilation
set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)
add_subdirectory(third_party/faiss) add_subdirectory(third_party/faiss)

View File

@@ -1 +1 @@
from . import hnsw_backend from . import hnsw_backend as hnsw_backend

View File

@@ -1,69 +1,98 @@
import struct
import sys
import numpy as np
import os
import argparse import argparse
import gc # Import garbage collector interface import gc # Import garbage collector interface
import logging
import os
import struct
import sys
import time import time
import numpy as np
# Set up logging to avoid print buffer issues
logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# --- FourCCs (add more if needed) --- # --- FourCCs (add more if needed) ---
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b'IHNf', 'little') INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
# Add other HNSW fourccs if you expect different storage types inside HNSW # Add other HNSW fourccs if you expect different storage types inside HNSW
# INDEX_HNSW_PQ_FOURCC = int.from_bytes(b'IHNp', 'little') # INDEX_HNSW_PQ_FOURCC = int.from_bytes(b'IHNp', 'little')
# INDEX_HNSW_SQ_FOURCC = int.from_bytes(b'IHNs', 'little') # INDEX_HNSW_SQ_FOURCC = int.from_bytes(b'IHNs', 'little')
# INDEX_HNSW_CAGRA_FOURCC = int.from_bytes(b'IHNc', 'little') # Example # INDEX_HNSW_CAGRA_FOURCC = int.from_bytes(b'IHNc', 'little') # Example
EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC} # Modify if needed EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC} # Modify if needed
NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little') NULL_INDEX_FOURCC = int.from_bytes(b"null", "little")
# --- Helper functions for reading/writing binary data --- # --- Helper functions for reading/writing binary data ---
def read_struct(f, fmt): def read_struct(f, fmt):
"""Reads data according to the struct format.""" """Reads data according to the struct format."""
size = struct.calcsize(fmt) size = struct.calcsize(fmt)
data = f.read(size) data = f.read(size)
if len(data) != size: if len(data) != size:
raise EOFError(f"File ended unexpectedly reading struct fmt '{fmt}'. Expected {size} bytes, got {len(data)}.") raise EOFError(
f"File ended unexpectedly reading struct fmt '{fmt}'. Expected {size} bytes, got {len(data)}."
)
return struct.unpack(fmt, data)[0] return struct.unpack(fmt, data)[0]
def read_vector_raw(f, element_fmt_char): def read_vector_raw(f, element_fmt_char):
"""Reads a vector (size followed by data), returns count and raw bytes.""" """Reads a vector (size followed by data), returns count and raw bytes."""
count = -1 # Initialize count count = -1 # Initialize count
total_bytes = -1 # Initialize total_bytes total_bytes = -1 # Initialize total_bytes
try: try:
count = read_struct(f, '<Q') # size_t usually 64-bit unsigned count = read_struct(f, "<Q") # size_t usually 64-bit unsigned
element_size = struct.calcsize(element_fmt_char) element_size = struct.calcsize(element_fmt_char)
# --- FIX for MemoryError: Check for unreasonably large count --- # --- FIX for MemoryError: Check for unreasonably large count ---
max_reasonable_count = 10 * (10**9) # ~10 billion elements limit max_reasonable_count = 10 * (10**9) # ~10 billion elements limit
if count > max_reasonable_count or count < 0: if count > max_reasonable_count or count < 0:
raise MemoryError(f"Vector count {count} seems unreasonably large, possibly due to file corruption or incorrect format read.") raise MemoryError(
f"Vector count {count} seems unreasonably large, possibly due to file corruption or incorrect format read."
)
total_bytes = count * element_size total_bytes = count * element_size
# --- FIX for MemoryError: Check for huge byte size before allocation --- # --- FIX for MemoryError: Check for huge byte size before allocation ---
max_reasonable_bytes = 50 * (1024**3) # ~50 GB limit max_reasonable_bytes = 50 * (1024**3) # ~50 GB limit
if total_bytes > max_reasonable_bytes or total_bytes < 0: # Check for overflow if total_bytes > max_reasonable_bytes or total_bytes < 0: # Check for overflow
raise MemoryError(f"Attempting to read {total_bytes} bytes ({count} elements * {element_size} bytes/element), which exceeds the safety limit. File might be corrupted or format mismatch.") raise MemoryError(
f"Attempting to read {total_bytes} bytes ({count} elements * {element_size} bytes/element), which exceeds the safety limit. File might be corrupted or format mismatch."
)
data_bytes = f.read(total_bytes) data_bytes = f.read(total_bytes)
if len(data_bytes) != total_bytes: if len(data_bytes) != total_bytes:
raise EOFError(f"File ended unexpectedly reading vector data. Expected {total_bytes} bytes, got {len(data_bytes)}.") raise EOFError(
f"File ended unexpectedly reading vector data. Expected {total_bytes} bytes, got {len(data_bytes)}."
)
return count, data_bytes return count, data_bytes
except (MemoryError, OverflowError) as e: except (MemoryError, OverflowError) as e:
# Add context to the error message # Add context to the error message
print(f"\nError during raw vector read (element_fmt='{element_fmt_char}', count={count}, total_bytes={total_bytes}): {e}", file=sys.stderr) print(
f"\nError during raw vector read (element_fmt='{element_fmt_char}', count={count}, total_bytes={total_bytes}): {e}",
file=sys.stderr,
)
raise e # Re-raise the original error type raise e # Re-raise the original error type
def read_numpy_vector(f, np_dtype, struct_fmt_char): def read_numpy_vector(f, np_dtype, struct_fmt_char):
"""Reads a vector into a NumPy array.""" """Reads a vector into a NumPy array."""
count = -1 # Initialize count for robust error handling count = -1 # Initialize count for robust error handling
print(f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ", end='', flush=True) print(
f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ",
end="",
flush=True,
)
try: try:
count, data_bytes = read_vector_raw(f, struct_fmt_char) count, data_bytes = read_vector_raw(f, struct_fmt_char)
print(f"Count={count}, Bytes={len(data_bytes)}") print(f"Count={count}, Bytes={len(data_bytes)}")
if count > 0 and len(data_bytes) > 0: if count > 0 and len(data_bytes) > 0:
arr = np.frombuffer(data_bytes, dtype=np_dtype) arr = np.frombuffer(data_bytes, dtype=np_dtype)
if arr.size != count: if arr.size != count:
raise ValueError(f"Inconsistent array size after reading. Expected {count}, got {arr.size}") raise ValueError(
f"Inconsistent array size after reading. Expected {count}, got {arr.size}"
)
return arr return arr
elif count == 0: elif count == 0:
return np.array([], dtype=np_dtype) return np.array([], dtype=np_dtype)
@@ -71,17 +100,23 @@ def read_numpy_vector(f, np_dtype, struct_fmt_char):
raise ValueError("Read zero bytes but count > 0.") raise ValueError("Read zero bytes but count > 0.")
except MemoryError as e: except MemoryError as e:
# Now count should be defined (or -1 if error was in read_struct) # Now count should be defined (or -1 if error was in read_struct)
print(f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}", file=sys.stderr) print(
f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}",
file=sys.stderr,
)
raise e raise e
except Exception as e: # Catch other potential errors like ValueError except Exception as e: # Catch other potential errors like ValueError
print(f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}", file=sys.stderr) print(
f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}",
file=sys.stderr,
)
raise e raise e
def write_numpy_vector(f, arr, struct_fmt_char): def write_numpy_vector(f, arr, struct_fmt_char):
"""Writes a NumPy array as a vector (size followed by data).""" """Writes a NumPy array as a vector (size followed by data)."""
count = arr.size count = arr.size
f.write(struct.pack('<Q', count)) f.write(struct.pack("<Q", count))
try: try:
expected_dtype = np.dtype(struct_fmt_char) expected_dtype = np.dtype(struct_fmt_char)
if arr.dtype != expected_dtype: if arr.dtype != expected_dtype:
@@ -91,21 +126,28 @@ def write_numpy_vector(f, arr, struct_fmt_char):
f.write(data_to_write) f.write(data_to_write)
del data_to_write # Hint GC del data_to_write # Hint GC
except MemoryError as e: except MemoryError as e:
print(f"\nMemoryError converting NumPy array to bytes for writing (size={count}, dtype={arr.dtype}). {e}", file=sys.stderr) print(
f"\nMemoryError converting NumPy array to bytes for writing (size={count}, dtype={arr.dtype}). {e}",
file=sys.stderr,
)
raise e raise e
def write_list_vector(f, lst, struct_fmt_char): def write_list_vector(f, lst, struct_fmt_char):
"""Writes a Python list as a vector iteratively.""" """Writes a Python list as a vector iteratively."""
count = len(lst) count = len(lst)
f.write(struct.pack('<Q', count)) f.write(struct.pack("<Q", count))
fmt = '<' + struct_fmt_char fmt = "<" + struct_fmt_char
chunk_size = 1024 * 1024 chunk_size = 1024 * 1024
element_size = struct.calcsize(fmt) element_size = struct.calcsize(fmt)
# Allocate buffer outside the loop if possible, or handle MemoryError during allocation # Allocate buffer outside the loop if possible, or handle MemoryError during allocation
try: try:
buffer = bytearray(chunk_size * element_size) buffer = bytearray(chunk_size * element_size)
except MemoryError: except MemoryError:
print(f"MemoryError: Cannot allocate buffer for writing list vector chunk (size {chunk_size * element_size} bytes).", file=sys.stderr) print(
f"MemoryError: Cannot allocate buffer for writing list vector chunk (size {chunk_size * element_size} bytes).",
file=sys.stderr,
)
raise raise
buffer_count = 0 buffer_count = 0
@@ -120,61 +162,75 @@ def write_list_vector(f, lst, struct_fmt_char):
buffer_count = 0 buffer_count = 0
except struct.error as e: except struct.error as e:
print(f"\nStruct packing error for item {item} at index {i} with format '{fmt}'. {e}", file=sys.stderr) print(
f"\nStruct packing error for item {item} at index {i} with format '{fmt}'. {e}",
file=sys.stderr,
)
raise e raise e
def get_cum_neighbors(cum_nneighbor_per_level_np, level): def get_cum_neighbors(cum_nneighbor_per_level_np, level):
"""Helper to get cumulative neighbors count, matching C++ logic.""" """Helper to get cumulative neighbors count, matching C++ logic."""
if level < 0: return 0 if level < 0:
return 0
if level < len(cum_nneighbor_per_level_np): if level < len(cum_nneighbor_per_level_np):
return cum_nneighbor_per_level_np[level] return cum_nneighbor_per_level_np[level]
else: else:
return cum_nneighbor_per_level_np[-1] if len(cum_nneighbor_per_level_np) > 0 else 0 return cum_nneighbor_per_level_np[-1] if len(cum_nneighbor_per_level_np) > 0 else 0
def write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
levels_np, compact_level_ptr, compact_node_offsets_np, def write_compact_format(
compact_neighbors_data, storage_fourcc, storage_data): f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
compact_level_ptr,
compact_node_offsets_np,
compact_neighbors_data,
storage_fourcc,
storage_data,
):
"""Write HNSW data in compact format following C++ read order exactly.""" """Write HNSW data in compact format following C++ read order exactly."""
# Write IndexHNSW Header # Write IndexHNSW Header
f_out.write(struct.pack('<I', original_hnsw_data['index_fourcc'])) f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
f_out.write(struct.pack('<i', original_hnsw_data['d'])) f_out.write(struct.pack("<i", original_hnsw_data["d"]))
f_out.write(struct.pack('<q', original_hnsw_data['ntotal'])) f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
f_out.write(struct.pack('<q', original_hnsw_data['dummy1'])) f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
f_out.write(struct.pack('<q', original_hnsw_data['dummy2'])) f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
f_out.write(struct.pack('<?', original_hnsw_data['is_trained'])) f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
f_out.write(struct.pack('<i', original_hnsw_data['metric_type'])) f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
if original_hnsw_data['metric_type'] > 1: if original_hnsw_data["metric_type"] > 1:
f_out.write(struct.pack('<f', original_hnsw_data['metric_arg'])) f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
# Write HNSW struct parts (standard order) # Write HNSW struct parts (standard order)
write_numpy_vector(f_out, assign_probas_np, 'd') write_numpy_vector(f_out, assign_probas_np, "d")
write_numpy_vector(f_out, cum_nneighbor_per_level_np, 'i') write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
write_numpy_vector(f_out, levels_np, 'i') write_numpy_vector(f_out, levels_np, "i")
# Write compact format flag # Write compact format flag
f_out.write(struct.pack('<?', True)) # storage_is_compact = True f_out.write(struct.pack("<?", True)) # storage_is_compact = True
# Write compact data in CORRECT C++ read order: level_ptr, node_offsets FIRST # Write compact data in CORRECT C++ read order: level_ptr, node_offsets FIRST
if isinstance(compact_level_ptr, np.ndarray): if isinstance(compact_level_ptr, np.ndarray):
write_numpy_vector(f_out, compact_level_ptr, 'Q') write_numpy_vector(f_out, compact_level_ptr, "Q")
else: else:
write_list_vector(f_out, compact_level_ptr, 'Q') write_list_vector(f_out, compact_level_ptr, "Q")
write_numpy_vector(f_out, compact_node_offsets_np, 'Q') write_numpy_vector(f_out, compact_node_offsets_np, "Q")
# Write HNSW scalar parameters # Write HNSW scalar parameters
f_out.write(struct.pack('<i', original_hnsw_data['entry_point'])) f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
f_out.write(struct.pack('<i', original_hnsw_data['max_level'])) f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
f_out.write(struct.pack('<i', original_hnsw_data['efConstruction'])) f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
f_out.write(struct.pack('<i', original_hnsw_data['efSearch'])) f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
f_out.write(struct.pack('<i', original_hnsw_data['dummy_upper_beam'])) f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
# Write storage fourcc (this determines how to read what follows) # Write storage fourcc (this determines how to read what follows)
f_out.write(struct.pack('<I', storage_fourcc)) f_out.write(struct.pack("<I", storage_fourcc))
# Write compact neighbors data AFTER storage fourcc # Write compact neighbors data AFTER storage fourcc
write_list_vector(f_out, compact_neighbors_data, 'i') write_list_vector(f_out, compact_neighbors_data, "i")
# Write storage data if not NULL (only after neighbors) # Write storage data if not NULL (only after neighbors)
if storage_fourcc != NULL_INDEX_FOURCC and storage_data: if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
@@ -183,6 +239,7 @@ def write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneigh
# --- Main Conversion Logic --- # --- Main Conversion Logic ---
def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=True): def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=True):
""" """
Converts an HNSW graph file to the CSR format. Converts an HNSW graph file to the CSR format.
@@ -193,94 +250,120 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
output_filename: Output CSR index file output_filename: Output CSR index file
prune_embeddings: Whether to prune embedding storage (write NULL storage marker) prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
""" """
# Keep prints simple; rely on CI runner to flush output as needed
print(f"Starting conversion: {input_filename} -> {output_filename}") print(f"Starting conversion: {input_filename} -> {output_filename}")
start_time = time.time() start_time = time.time()
original_hnsw_data = {} original_hnsw_data = {}
neighbors_np = None # Initialize to allow check in finally block neighbors_np = None # Initialize to allow check in finally block
try: try:
with open(input_filename, 'rb') as f_in, open(output_filename, 'wb') as f_out: with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
# --- Read IndexHNSW FourCC and Header --- # --- Read IndexHNSW FourCC and Header ---
print(f"[{time.time() - start_time:.2f}s] Reading Index HNSW header...") print(f"[{time.time() - start_time:.2f}s] Reading Index HNSW header...")
# ... (Keep the header reading logic as before) ... # ... (Keep the header reading logic as before) ...
hnsw_index_fourcc = read_struct(f_in, '<I') hnsw_index_fourcc = read_struct(f_in, "<I")
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS: if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
print(f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.", file=sys.stderr) print(
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
file=sys.stderr,
)
return False return False
original_hnsw_data['index_fourcc'] = hnsw_index_fourcc original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
original_hnsw_data['d'] = read_struct(f_in, '<i') original_hnsw_data["d"] = read_struct(f_in, "<i")
original_hnsw_data['ntotal'] = read_struct(f_in, '<q') original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
original_hnsw_data['dummy1'] = read_struct(f_in, '<q') original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
original_hnsw_data['dummy2'] = read_struct(f_in, '<q') original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
original_hnsw_data['is_trained'] = read_struct(f_in, '?') original_hnsw_data["is_trained"] = read_struct(f_in, "?")
original_hnsw_data['metric_type'] = read_struct(f_in, '<i') original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
original_hnsw_data['metric_arg'] = 0.0 original_hnsw_data["metric_arg"] = 0.0
if original_hnsw_data['metric_type'] > 1: if original_hnsw_data["metric_type"] > 1:
original_hnsw_data['metric_arg'] = read_struct(f_in, '<f') original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
print(f"[{time.time() - start_time:.2f}s] Header read: d={original_hnsw_data['d']}, ntotal={original_hnsw_data['ntotal']}") print(
f"[{time.time() - start_time:.2f}s] Header read: d={original_hnsw_data['d']}, ntotal={original_hnsw_data['ntotal']}"
)
# --- Read original HNSW struct data --- # --- Read original HNSW struct data ---
print(f"[{time.time() - start_time:.2f}s] Reading HNSW struct vectors...") print(f"[{time.time() - start_time:.2f}s] Reading HNSW struct vectors...")
assign_probas_np = read_numpy_vector(f_in, np.float64, 'd') assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
print(f"[{time.time() - start_time:.2f}s] Read assign_probas ({assign_probas_np.size})") print(
f"[{time.time() - start_time:.2f}s] Read assign_probas ({assign_probas_np.size})"
)
gc.collect() gc.collect()
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, 'i') cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
print(f"[{time.time() - start_time:.2f}s] Read cum_nneighbor_per_level ({cum_nneighbor_per_level_np.size})") print(
f"[{time.time() - start_time:.2f}s] Read cum_nneighbor_per_level ({cum_nneighbor_per_level_np.size})"
)
gc.collect() gc.collect()
levels_np = read_numpy_vector(f_in, np.int32, 'i') levels_np = read_numpy_vector(f_in, np.int32, "i")
print(f"[{time.time() - start_time:.2f}s] Read levels ({levels_np.size})") print(f"[{time.time() - start_time:.2f}s] Read levels ({levels_np.size})")
gc.collect() gc.collect()
ntotal = len(levels_np) ntotal = len(levels_np)
if ntotal != original_hnsw_data['ntotal']: if ntotal != original_hnsw_data["ntotal"]:
print(f"Warning: ntotal mismatch! Header says {original_hnsw_data['ntotal']}, levels vector size is {ntotal}. Using levels vector size.", file=sys.stderr) print(
original_hnsw_data['ntotal'] = ntotal f"Warning: ntotal mismatch! Header says {original_hnsw_data['ntotal']}, levels vector size is {ntotal}. Using levels vector size.",
file=sys.stderr,
)
original_hnsw_data["ntotal"] = ntotal
# --- Check for compact format flag --- # --- Check for compact format flag ---
print(f"[{time.time() - start_time:.2f}s] Probing for compact storage flag...") print(f"[{time.time() - start_time:.2f}s] Probing for compact storage flag...")
pos_before_compact = f_in.tell() pos_before_compact = f_in.tell()
try: try:
is_compact_flag = read_struct(f_in, '<?') is_compact_flag = read_struct(f_in, "<?")
print(f"[{time.time() - start_time:.2f}s] Found compact flag: {is_compact_flag}") print(f"[{time.time() - start_time:.2f}s] Found compact flag: {is_compact_flag}")
if is_compact_flag: if is_compact_flag:
# Input is already in compact format - read compact data # Input is already in compact format - read compact data
print(f"[{time.time() - start_time:.2f}s] Input is already in compact format, reading compact data...") print(
f"[{time.time() - start_time:.2f}s] Input is already in compact format, reading compact data..."
)
compact_level_ptr = read_numpy_vector(f_in, np.uint64, 'Q') compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
print(f"[{time.time() - start_time:.2f}s] Read compact_level_ptr ({compact_level_ptr.size})") print(
f"[{time.time() - start_time:.2f}s] Read compact_level_ptr ({compact_level_ptr.size})"
)
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, 'Q') compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
print(f"[{time.time() - start_time:.2f}s] Read compact_node_offsets ({compact_node_offsets_np.size})") print(
f"[{time.time() - start_time:.2f}s] Read compact_node_offsets ({compact_node_offsets_np.size})"
)
# Read scalar parameters # Read scalar parameters
original_hnsw_data['entry_point'] = read_struct(f_in, '<i') original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data['max_level'] = read_struct(f_in, '<i') original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data['efConstruction'] = read_struct(f_in, '<i') original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data['efSearch'] = read_struct(f_in, '<i') original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i') original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})") print(
f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})"
)
# Read storage fourcc # Read storage fourcc
storage_fourcc = read_struct(f_in, '<I') storage_fourcc = read_struct(f_in, "<I")
print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}") print(
f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}"
)
if prune_embeddings and storage_fourcc != NULL_INDEX_FOURCC: if prune_embeddings and storage_fourcc != NULL_INDEX_FOURCC:
# Read compact neighbors data # Read compact neighbors data
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i') compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
print(f"[{time.time() - start_time:.2f}s] Read compact neighbors data ({compact_neighbors_data_np.size})") print(
f"[{time.time() - start_time:.2f}s] Read compact neighbors data ({compact_neighbors_data_np.size})"
)
compact_neighbors_data = compact_neighbors_data_np.tolist() compact_neighbors_data = compact_neighbors_data_np.tolist()
del compact_neighbors_data_np del compact_neighbors_data_np
# Skip storage data and write with NULL marker # Skip storage data and write with NULL marker
print(f"[{time.time() - start_time:.2f}s] Pruning embeddings: Writing NULL storage marker.") print(
f"[{time.time() - start_time:.2f}s] Pruning embeddings: Writing NULL storage marker."
)
storage_fourcc = NULL_INDEX_FOURCC storage_fourcc = NULL_INDEX_FOURCC
elif not prune_embeddings: elif not prune_embeddings:
# Read and preserve compact neighbors and storage # Read and preserve compact neighbors and storage
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i') compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
compact_neighbors_data = compact_neighbors_data_np.tolist() compact_neighbors_data = compact_neighbors_data_np.tolist()
del compact_neighbors_data_np del compact_neighbors_data_np
@@ -288,16 +371,25 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
storage_data = f_in.read() storage_data = f_in.read()
else: else:
# Already pruned (NULL storage) # Already pruned (NULL storage)
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i') compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
compact_neighbors_data = compact_neighbors_data_np.tolist() compact_neighbors_data = compact_neighbors_data_np.tolist()
del compact_neighbors_data_np del compact_neighbors_data_np
storage_data = b'' storage_data = b""
# Write the updated compact format # Write the updated compact format
print(f"[{time.time() - start_time:.2f}s] Writing updated compact format...") print(f"[{time.time() - start_time:.2f}s] Writing updated compact format...")
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np, write_compact_format(
levels_np, compact_level_ptr, compact_node_offsets_np, f_out,
compact_neighbors_data, storage_fourcc, storage_data if not prune_embeddings else b'') original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
compact_level_ptr,
compact_node_offsets_np,
compact_neighbors_data,
storage_fourcc,
storage_data if not prune_embeddings else b"",
)
print(f"[{time.time() - start_time:.2f}s] Conversion complete.") print(f"[{time.time() - start_time:.2f}s] Conversion complete.")
return True return True
@@ -305,63 +397,86 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
else: else:
# is_compact=False, rewind and read original format # is_compact=False, rewind and read original format
f_in.seek(pos_before_compact) f_in.seek(pos_before_compact)
print(f"[{time.time() - start_time:.2f}s] Compact flag is False, reading original format...") print(
f"[{time.time() - start_time:.2f}s] Compact flag is False, reading original format..."
)
except EOFError: except EOFError:
# No compact flag found, assume original format # No compact flag found, assume original format
f_in.seek(pos_before_compact) f_in.seek(pos_before_compact)
print(f"[{time.time() - start_time:.2f}s] No compact flag found, assuming original format...") print(
f"[{time.time() - start_time:.2f}s] No compact flag found, assuming original format..."
)
# --- Handle potential extra byte in original format (like C++ code) --- # --- Handle potential extra byte in original format (like C++ code) ---
print(f"[{time.time() - start_time:.2f}s] Probing for potential extra byte before non-compact offsets...") print(
f"[{time.time() - start_time:.2f}s] Probing for potential extra byte before non-compact offsets..."
)
pos_before_probe = f_in.tell() pos_before_probe = f_in.tell()
try: try:
suspected_flag = read_struct(f_in, '<B') # Read 1 byte suspected_flag = read_struct(f_in, "<B") # Read 1 byte
if suspected_flag == 0x00: if suspected_flag == 0x00:
print(f"[{time.time() - start_time:.2f}s] Found and consumed an unexpected 0x00 byte.") print(
f"[{time.time() - start_time:.2f}s] Found and consumed an unexpected 0x00 byte."
)
elif suspected_flag == 0x01: elif suspected_flag == 0x01:
print(f"[{time.time() - start_time:.2f}s] ERROR: Found 0x01 but is_compact should be False") print(
f"[{time.time() - start_time:.2f}s] ERROR: Found 0x01 but is_compact should be False"
)
raise ValueError("Inconsistent compact flag state") raise ValueError("Inconsistent compact flag state")
else: else:
# Rewind - this byte is part of offsets data # Rewind - this byte is part of offsets data
f_in.seek(pos_before_probe) f_in.seek(pos_before_probe)
print(f"[{time.time() - start_time:.2f}s] Rewound to original position (byte was 0x{suspected_flag:02x})") print(
f"[{time.time() - start_time:.2f}s] Rewound to original position (byte was 0x{suspected_flag:02x})"
)
except EOFError: except EOFError:
f_in.seek(pos_before_probe) f_in.seek(pos_before_probe)
print(f"[{time.time() - start_time:.2f}s] No extra byte found (EOF), proceeding with offsets read") print(
f"[{time.time() - start_time:.2f}s] No extra byte found (EOF), proceeding with offsets read"
)
# --- Read original format data --- # --- Read original format data ---
offsets_np = read_numpy_vector(f_in, np.uint64, 'Q') offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
print(f"[{time.time() - start_time:.2f}s] Read offsets ({offsets_np.size})") print(f"[{time.time() - start_time:.2f}s] Read offsets ({offsets_np.size})")
if len(offsets_np) != ntotal + 1: if len(offsets_np) != ntotal + 1:
raise ValueError(f"Inconsistent offsets size: len(levels)={ntotal} but len(offsets)={len(offsets_np)}") raise ValueError(
f"Inconsistent offsets size: len(levels)={ntotal} but len(offsets)={len(offsets_np)}"
)
gc.collect() gc.collect()
print(f"[{time.time() - start_time:.2f}s] Attempting to read neighbors vector...") print(f"[{time.time() - start_time:.2f}s] Attempting to read neighbors vector...")
neighbors_np = read_numpy_vector(f_in, np.int32, 'i') neighbors_np = read_numpy_vector(f_in, np.int32, "i")
print(f"[{time.time() - start_time:.2f}s] Read neighbors ({neighbors_np.size})") print(f"[{time.time() - start_time:.2f}s] Read neighbors ({neighbors_np.size})")
expected_neighbors_size = offsets_np[-1] if ntotal > 0 else 0 expected_neighbors_size = offsets_np[-1] if ntotal > 0 else 0
if neighbors_np.size != expected_neighbors_size: if neighbors_np.size != expected_neighbors_size:
print(f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}.") print(
f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}."
)
gc.collect() gc.collect()
original_hnsw_data['entry_point'] = read_struct(f_in, '<i') original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data['max_level'] = read_struct(f_in, '<i') original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data['efConstruction'] = read_struct(f_in, '<i') original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data['efSearch'] = read_struct(f_in, '<i') original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i') original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})") print(
f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})"
)
print(f"[{time.time() - start_time:.2f}s] Checking for storage data...") print(f"[{time.time() - start_time:.2f}s] Checking for storage data...")
storage_fourcc = None storage_fourcc = None
try: try:
storage_fourcc = read_struct(f_in, '<I') storage_fourcc = read_struct(f_in, "<I")
print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}.") print(
f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}."
)
except EOFError: except EOFError:
print(f"[{time.time() - start_time:.2f}s] No storage data found (EOF).") print(f"[{time.time() - start_time:.2f}s] No storage data found (EOF).")
except Exception as e: except Exception as e:
print(f"[{time.time() - start_time:.2f}s] Error reading potential storage data: {e}") print(
f"[{time.time() - start_time:.2f}s] Error reading potential storage data: {e}"
)
# --- Perform Conversion --- # --- Perform Conversion ---
print(f"[{time.time() - start_time:.2f}s] Converting to CSR format...") print(f"[{time.time() - start_time:.2f}s] Converting to CSR format...")
@@ -380,10 +495,14 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
if i > 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1% if i > 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1%
progress = (i / ntotal) * 100 progress = (i / ntotal) * 100
elapsed = time.time() - start_time elapsed = time.time() - start_time
print(f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...", end="") print(
f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...",
end="",
)
node_max_level = levels_np[i] - 1 node_max_level = levels_np[i] - 1
if node_max_level < -1: node_max_level = -1 if node_max_level < -1:
node_max_level = -1
node_ptr_start_index = current_level_ptr_idx node_ptr_start_index = current_level_ptr_idx
compact_node_offsets_np[i] = node_ptr_start_index compact_node_offsets_np[i] = node_ptr_start_index
@@ -394,8 +513,12 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
for level in range(node_max_level + 1): for level in range(node_max_level + 1):
compact_level_ptr.append(current_data_idx) compact_level_ptr.append(current_data_idx)
begin_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level) begin_orig_np = original_offset_start + get_cum_neighbors(
end_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level + 1) cum_nneighbor_per_level_np, level
)
end_orig_np = original_offset_start + get_cum_neighbors(
cum_nneighbor_per_level_np, level + 1
)
begin_orig = int(begin_orig_np) begin_orig = int(begin_orig_np)
end_orig = int(end_orig_np) end_orig = int(end_orig_np)
@@ -413,71 +536,116 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
if num_valid > 0: if num_valid > 0:
# Append valid neighbors # Append valid neighbors
compact_neighbors_data.extend(level_neighbors_slice[valid_neighbors_mask]) compact_neighbors_data.extend(
level_neighbors_slice[valid_neighbors_mask]
)
current_data_idx += num_valid current_data_idx += num_valid
total_valid_neighbors_counted += num_valid total_valid_neighbors_counted += num_valid
compact_level_ptr.append(current_data_idx) compact_level_ptr.append(current_data_idx)
current_level_ptr_idx += num_pointers_expected current_level_ptr_idx += num_pointers_expected
compact_node_offsets_np[ntotal] = current_level_ptr_idx compact_node_offsets_np[ntotal] = current_level_ptr_idx
print(f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. ") # Clear progress line print(
f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. "
) # Clear progress line
# --- Validation Checks --- # --- Validation Checks ---
print(f"[{time.time() - start_time:.2f}s] Running validation checks...") print(f"[{time.time() - start_time:.2f}s] Running validation checks...")
valid_check_passed = True valid_check_passed = True
# Check 1: Total valid neighbors count # Check 1: Total valid neighbors count
print(f" Checking total valid neighbor count...") print(" Checking total valid neighbor count...")
expected_valid_count = np.sum(neighbors_np >= 0) expected_valid_count = np.sum(neighbors_np >= 0)
if total_valid_neighbors_counted != len(compact_neighbors_data): if total_valid_neighbors_counted != len(compact_neighbors_data):
print(f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr) print(
f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!",
file=sys.stderr,
)
valid_check_passed = False valid_check_passed = False
if expected_valid_count != len(compact_neighbors_data): if expected_valid_count != len(compact_neighbors_data):
print(f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr) print(
f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!",
file=sys.stderr,
)
valid_check_passed = False valid_check_passed = False
else: else:
print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}") print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}")
# Check 2: Final pointer indices consistency # Check 2: Final pointer indices consistency
print(f" Checking final pointer indices...") print(" Checking final pointer indices...")
if compact_node_offsets_np[ntotal] != len(compact_level_ptr): if compact_node_offsets_np[ntotal] != len(compact_level_ptr):
print(f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!", file=sys.stderr) print(
f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!",
file=sys.stderr,
)
valid_check_passed = False valid_check_passed = False
if (len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data)) or \ if (
(len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0): len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data)
) or (len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0):
last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1 last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1
print(f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr) print(
f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!",
file=sys.stderr,
)
valid_check_passed = False valid_check_passed = False
else: else:
print(f" OK: Final pointers match data size.") print(" OK: Final pointers match data size.")
if not valid_check_passed: if not valid_check_passed:
print("Error: Validation checks failed. Output file might be incorrect.", file=sys.stderr) print(
"Error: Validation checks failed. Output file might be incorrect.",
file=sys.stderr,
)
# Optional: Exit here if validation fails # Optional: Exit here if validation fails
# return False # return False
# --- Explicitly delete large intermediate arrays --- # --- Explicitly delete large intermediate arrays ---
print(f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays...") print(
f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays..."
)
del neighbors_np del neighbors_np
del offsets_np del offsets_np
gc.collect() gc.collect()
print(f" CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}") print(
f" CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}"
)
# --- Write CSR HNSW graph data using unified function --- # --- Write CSR HNSW graph data using unified function ---
print(f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order...") print(
f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order..."
)
# Determine storage fourcc based on prune_embeddings # Determine storage fourcc and data based on prune_embeddings
output_storage_fourcc = NULL_INDEX_FOURCC if prune_embeddings else (storage_fourcc if 'storage_fourcc' in locals() else NULL_INDEX_FOURCC)
if prune_embeddings: if prune_embeddings:
print(f" Pruning embeddings: Writing NULL storage marker.") print(" Pruning embeddings: Writing NULL storage marker.")
storage_data = b'' output_storage_fourcc = NULL_INDEX_FOURCC
storage_data = b""
else:
# Keep embeddings - read and preserve original storage data
if storage_fourcc and storage_fourcc != NULL_INDEX_FOURCC:
print(" Preserving embeddings: Reading original storage data...")
storage_data = f_in.read() # Read remaining storage data
output_storage_fourcc = storage_fourcc
print(f" Read {len(storage_data)} bytes of storage data")
else:
print(" No embeddings found in original file (NULL storage)")
output_storage_fourcc = NULL_INDEX_FOURCC
storage_data = b""
# Use the unified write function # Use the unified write function
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np, write_compact_format(
levels_np, compact_level_ptr, compact_node_offsets_np, f_out,
compact_neighbors_data, output_storage_fourcc, storage_data if not prune_embeddings else b'') original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
compact_level_ptr,
compact_node_offsets_np,
compact_neighbors_data,
output_storage_fourcc,
storage_data,
)
# Clean up memory # Clean up memory
del assign_probas_np, cum_nneighbor_per_level_np, levels_np del assign_probas_np, cum_nneighbor_per_level_np, levels_np
@@ -492,40 +660,66 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
print(f"Error: Input file not found: {input_filename}", file=sys.stderr) print(f"Error: Input file not found: {input_filename}", file=sys.stderr)
return False return False
except MemoryError as e: except MemoryError as e:
print(f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.", file=sys.stderr) print(
f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.",
file=sys.stderr,
)
# Clean up potentially partially written output file? # Clean up potentially partially written output file?
try: os.remove(output_filename) try:
except OSError: pass os.remove(output_filename)
except OSError:
pass
return False return False
except EOFError as e: except EOFError as e:
print(f"Error: Reached end of file unexpectedly reading {input_filename}. {e}", file=sys.stderr) print(
try: os.remove(output_filename) f"Error: Reached end of file unexpectedly reading {input_filename}. {e}",
except OSError: pass file=sys.stderr,
)
try:
os.remove(output_filename)
except OSError:
pass
return False return False
except Exception as e: except Exception as e:
print(f"An unexpected error occurred during conversion: {e}", file=sys.stderr) print(f"An unexpected error occurred during conversion: {e}", file=sys.stderr)
import traceback import traceback
traceback.print_exc() traceback.print_exc()
try: try:
os.remove(output_filename) os.remove(output_filename)
except OSError: pass except OSError:
pass
return False return False
# Ensure neighbors_np is deleted even if an error occurs after its allocation # Ensure neighbors_np is deleted even if an error occurs after its allocation
finally: finally:
if 'neighbors_np' in locals() and neighbors_np is not None: try:
if "neighbors_np" in locals() and neighbors_np is not None:
del neighbors_np del neighbors_np
gc.collect() gc.collect()
except NameError:
pass
# --- Script Execution --- # --- Script Execution ---
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file.") parser = argparse.ArgumentParser(
description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file."
)
parser.add_argument("input_index_file", help="Path to the input IndexHNSWFlat file") parser.add_argument("input_index_file", help="Path to the input IndexHNSWFlat file")
parser.add_argument("output_csr_graph_file", help="Path to write the output CSR HNSW graph file") parser.add_argument(
parser.add_argument("--prune-embeddings", action="store_true", default=True, "output_csr_graph_file", help="Path to write the output CSR HNSW graph file"
help="Prune embedding storage (write NULL storage marker)") )
parser.add_argument("--keep-embeddings", action="store_true", parser.add_argument(
help="Keep embedding storage (overrides --prune-embeddings)") "--prune-embeddings",
action="store_true",
default=True,
help="Prune embedding storage (write NULL storage marker)",
)
parser.add_argument(
"--keep-embeddings",
action="store_true",
help="Keep embedding storage (overrides --prune-embeddings)",
)
args = parser.parse_args() args = parser.parse_args()
@@ -534,10 +728,12 @@ if __name__ == "__main__":
sys.exit(1) sys.exit(1)
if os.path.abspath(args.input_index_file) == os.path.abspath(args.output_csr_graph_file): if os.path.abspath(args.input_index_file) == os.path.abspath(args.output_csr_graph_file):
print(f"Error: Input and output filenames cannot be the same.", file=sys.stderr) print("Error: Input and output filenames cannot be the same.", file=sys.stderr)
sys.exit(1) sys.exit(1)
prune_embeddings = args.prune_embeddings and not args.keep_embeddings prune_embeddings = args.prune_embeddings and not args.keep_embeddings
success = convert_hnsw_graph_to_csr(args.input_index_file, args.output_csr_graph_file, prune_embeddings) success = convert_hnsw_graph_to_csr(
args.input_index_file, args.output_csr_graph_file, prune_embeddings
)
if not success: if not success:
sys.exit(1) sys.exit(1)

View File

@@ -1,145 +1,39 @@
import numpy as np import logging
import os import os
import json import shutil
import struct
from pathlib import Path
from typing import Dict, Any
import contextlib
import threading
import time import time
import atexit from pathlib import Path
import socket from typing import Any, Literal, Optional
import subprocess
import sys import numpy as np
from leann.interface import (
LeannBackendBuilderInterface,
LeannBackendFactoryInterface,
LeannBackendSearcherInterface,
)
from leann.registry import register_backend
from leann.searcher_base import BaseSearcher
from .convert_to_csr import convert_hnsw_graph_to_csr from .convert_to_csr import convert_hnsw_graph_to_csr
from leann.registry import register_backend logger = logging.getLogger(__name__)
from leann.interface import (
LeannBackendFactoryInterface,
LeannBackendBuilderInterface,
LeannBackendSearcherInterface
)
def get_metric_map(): def get_metric_map():
from . import faiss from . import faiss # type: ignore
return { return {
"mips": faiss.METRIC_INNER_PRODUCT, "mips": faiss.METRIC_INNER_PRODUCT,
"l2": faiss.METRIC_L2, "l2": faiss.METRIC_L2,
"cosine": faiss.METRIC_INNER_PRODUCT, "cosine": faiss.METRIC_INNER_PRODUCT,
} }
def _check_port(port: int) -> bool:
"""Check if a port is in use"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
class HNSWEmbeddingServerManager: def normalize_l2(data: np.ndarray) -> np.ndarray:
""" norms = np.linalg.norm(data, axis=1, keepdims=True)
HNSW-specific embedding server manager that handles the lifecycle of the embedding server process. norms[norms == 0] = 1 # Avoid division by zero
Mirrors the DiskANN EmbeddingServerManager architecture. return data / norms
"""
def __init__(self):
self.server_process = None
self.server_port = None
atexit.register(self.stop_server)
def start_server(self, port=5556, model_name="sentence-transformers/all-mpnet-base-v2", passages_file=None, distance_metric="mips"):
"""
Start the HNSW embedding server process.
Args:
port: ZMQ port for the server
model_name: Name of the embedding model to use
passages_file: Optional path to passages JSON file
distance_metric: The distance metric to use
"""
if self.server_process and self.server_process.poll() is None:
print(f"INFO: Reusing existing HNSW server process for this session (PID {self.server_process.pid})")
return True
# Check if port is already in use
if _check_port(port):
print(f"WARNING: Port {port} is already in use. Assuming an external HNSW server is running and connecting to it.")
return True
print(f"INFO: Starting session-level HNSW embedding server as a background process...")
try:
command = [
sys.executable,
"-m", "leann_backend_hnsw.hnsw_embedding_server",
"--zmq-port", str(port),
"--model-name", model_name,
"--distance-metric", distance_metric
]
if passages_file:
command.extend(["--passages-file", str(passages_file)])
project_root = Path(__file__).parent.parent.parent.parent
print(f"INFO: Running HNSW command from project root: {project_root}")
self.server_process = subprocess.Popen(
command,
cwd=project_root,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
encoding='utf-8'
)
self.server_port = port
print(f"INFO: HNSW server process started with PID: {self.server_process.pid}")
max_wait, wait_interval = 30, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
print(f"✅ HNSW embedding server is up and ready for this session.")
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
log_thread.start()
return True
if self.server_process.poll() is not None:
print("❌ ERROR: HNSW server process terminated unexpectedly during startup.")
self._log_monitor()
return False
time.sleep(wait_interval)
print(f"❌ ERROR: HNSW server process failed to start listening within {max_wait} seconds.")
self.stop_server()
return False
except Exception as e:
print(f"❌ ERROR: Failed to start HNSW embedding server process: {e}")
return False
def _log_monitor(self):
"""Monitor server logs"""
if not self.server_process:
return
try:
if self.server_process.stdout:
for line in iter(self.server_process.stdout.readline, ''):
print(f"[HNSWEmbeddingServer LOG]: {line.strip()}")
self.server_process.stdout.close()
if self.server_process.stderr:
for line in iter(self.server_process.stderr.readline, ''):
print(f"[HNSWEmbeddingServer ERROR]: {line.strip()}")
self.server_process.stderr.close()
except Exception as e:
print(f"HNSW Log monitor error: {e}")
def stop_server(self):
"""Stop the HNSW embedding server process"""
if self.server_process and self.server_process.poll() is None:
print(f"INFO: Terminating HNSW session server process (PID: {self.server_process.pid})...")
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
print("INFO: HNSW server process terminated.")
except subprocess.TimeoutExpired:
print("WARNING: HNSW server process did not terminate gracefully, killing it.")
self.server_process.kill()
self.server_process = None
@register_backend("hnsw") @register_backend("hnsw")
class HNSWBackend(LeannBackendFactoryInterface): class HNSWBackend(LeannBackendFactoryInterface):
@@ -149,372 +43,211 @@ class HNSWBackend(LeannBackendFactoryInterface):
@staticmethod @staticmethod
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface: def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
path = Path(index_path)
meta_path = path.parent / f"{path.name}.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
with open(meta_path, 'r') as f:
meta = json.load(f)
dimensions = meta.get("dimensions")
if not dimensions:
raise ValueError("Dimensions not found in Leann metadata. Please rebuild the index with a newer version of Leann.")
kwargs['dimensions'] = dimensions
return HNSWSearcher(index_path, **kwargs) return HNSWSearcher(index_path, **kwargs)
class HNSWBuilder(LeannBackendBuilderInterface): class HNSWBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.build_params = kwargs.copy() self.build_params = kwargs.copy()
# --- Configuration defaults with standardized names ---
self.is_compact = self.build_params.setdefault("is_compact", True) self.is_compact = self.build_params.setdefault("is_compact", True)
self.is_recompute = self.build_params.setdefault("is_recompute", True) self.is_recompute = self.build_params.setdefault("is_recompute", True)
# --- Additional Options ---
self.is_skip_neighbors = self.build_params.setdefault("is_skip_neighbors", False)
self.disk_cache_ratio = self.build_params.setdefault("disk_cache_ratio", 0.0)
self.external_storage_path = self.build_params.get("external_storage_path", None)
# --- Standard HNSW parameters ---
self.M = self.build_params.setdefault("M", 32) self.M = self.build_params.setdefault("M", 32)
self.efConstruction = self.build_params.setdefault("efConstruction", 200) self.efConstruction = self.build_params.setdefault("efConstruction", 200)
self.distance_metric = self.build_params.setdefault("distance_metric", "mips") self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
self.dimensions = self.build_params.get("dimensions") self.dimensions = self.build_params.get("dimensions")
if not self.is_recompute and self.is_compact:
# Auto-correct: non-recompute requires non-compact storage for HNSW
logger.warning(
"is_recompute=False requires non-compact HNSW. Forcing is_compact=False."
)
self.is_compact = False
self.build_params["is_compact"] = False
if self.is_skip_neighbors and not self.is_compact: def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
raise ValueError("is_skip_neighbors can only be used with is_compact=True") from . import faiss # type: ignore
if self.is_recompute and not self.is_compact:
raise ValueError("is_recompute requires is_compact=True for efficiency")
def build(self, data: np.ndarray, index_path: str, **kwargs):
"""Build HNSW index using FAISS"""
from . import faiss
path = Path(index_path) path = Path(index_path)
index_dir = path.parent index_dir = path.parent
index_prefix = path.stem index_prefix = path.stem
index_dir.mkdir(parents=True, exist_ok=True) index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32: if data.dtype != np.float32:
logger.warning(f"Converting data to float32, shape: {data.shape}")
data = data.astype(np.float32) data = data.astype(np.float32)
if not data.flags['C_CONTIGUOUS']:
data = np.ascontiguousarray(data)
metric_str = self.distance_metric.lower() metric_enum = get_metric_map().get(self.distance_metric.lower())
metric_enum = get_metric_map().get(metric_str)
if metric_enum is None: if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.") raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
M = self.M dim = self.dimensions or data.shape[1]
efConstruction = self.efConstruction index = faiss.IndexHNSWFlat(dim, self.M, metric_enum)
dim = self.dimensions index.hnsw.efConstruction = self.efConstruction
if not dim:
dim = data.shape[1]
print(f"INFO: Building HNSW index for {data.shape[0]} vectors with metric {metric_enum}...") if self.distance_metric.lower() == "cosine":
data = normalize_l2(data)
try:
index = faiss.IndexHNSWFlat(dim, M, metric_enum)
index.hnsw.efConstruction = efConstruction
if metric_str == "cosine":
faiss.normalize_L2(data)
index.add(data.shape[0], faiss.swig_ptr(data)) index.add(data.shape[0], faiss.swig_ptr(data))
index_file = index_dir / f"{index_prefix}.index" index_file = index_dir / f"{index_prefix}.index"
faiss.write_index(index, str(index_file)) faiss.write_index(index, str(index_file))
print(f"✅ HNSW index built successfully at '{index_file}'")
if self.is_compact: if self.is_compact:
self._convert_to_csr(index_file) self._convert_to_csr(index_file)
if self.is_recompute:
self._generate_passages_file(index_dir, index_prefix, **kwargs)
except Exception as e:
print(f"💥 ERROR: HNSW index build failed. Exception: {e}")
raise
def _convert_to_csr(self, index_file: Path): def _convert_to_csr(self, index_file: Path):
"""Convert built index to CSR format""" """Convert built index to CSR format"""
try:
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard" mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
print(f"INFO: Converting HNSW index to {mode_str} format...") logger.info(f"INFO: Converting HNSW index to {mode_str} format...")
csr_temp_file = index_file.with_suffix(".csr.tmp") csr_temp_file = index_file.with_suffix(".csr.tmp")
success = convert_hnsw_graph_to_csr( success = convert_hnsw_graph_to_csr(
str(index_file), str(index_file), str(csr_temp_file), prune_embeddings=self.is_recompute
str(csr_temp_file),
prune_embeddings=self.is_recompute
) )
if success: if success:
print("✅ CSR conversion successful.") logger.info("✅ CSR conversion successful.")
import shutil # index_file_old = index_file.with_suffix(".old")
# shutil.move(str(index_file), str(index_file_old))
shutil.move(str(csr_temp_file), str(index_file)) shutil.move(str(csr_temp_file), str(index_file))
print(f"INFO: Replaced original index with {mode_str} version at '{index_file}'") logger.info(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
else: else:
# Clean up and fail fast # Clean up and fail fast
if csr_temp_file.exists(): if csr_temp_file.exists():
os.remove(csr_temp_file) os.remove(csr_temp_file)
raise RuntimeError("CSR conversion failed - cannot proceed with compact format") raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
except Exception as e:
print(f"💥 ERROR: CSR conversion failed. Exception: {e}")
raise
def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs):
"""Generate passages file for recompute mode"""
try:
chunks = kwargs.get('chunks', [])
if not chunks:
print("INFO: No chunks data provided, skipping passages file generation")
return
# Generate node_id to text mapping
passages_data = {}
for node_id, chunk in enumerate(chunks):
passages_data[str(node_id)] = chunk["text"]
# Save passages file
passages_file = index_dir / f"{index_prefix}.passages.json"
with open(passages_file, 'w', encoding='utf-8') as f:
json.dump(passages_data, f, ensure_ascii=False, indent=2)
print(f"✅ Generated passages file for recompute mode at '{passages_file}' ({len(passages_data)} passages)")
except Exception as e:
print(f"💥 ERROR: Failed to generate passages file. Exception: {e}")
# Don't raise - this is not critical for index building
pass
class HNSWSearcher(LeannBackendSearcherInterface):
def _get_index_storage_status(self, index_file: Path) -> tuple[bool, bool]:
"""
Robustly determines the index's storage status by parsing the file.
Returns:
A tuple (is_compact, is_pruned).
"""
if not index_file.exists():
return False, False
with open(index_file, 'rb') as f:
try:
def read_struct(fmt):
size = struct.calcsize(fmt)
data = f.read(size)
if len(data) != size:
raise EOFError(f"File ended unexpectedly reading struct fmt '{fmt}'.")
return struct.unpack(fmt, data)[0]
def skip_vector(element_size):
count = read_struct('<Q')
f.seek(count * element_size, 1)
# 1. Read up to the compact flag
read_struct('<I'); read_struct('<i'); read_struct('<q');
read_struct('<q'); read_struct('<q'); read_struct('<?')
metric_type = read_struct('<i')
if metric_type > 1: read_struct('<f')
skip_vector(8); skip_vector(4); skip_vector(4)
# 2. Check if there's a compact flag byte
# Try to read the compact flag, but handle both old and new formats
pos_before_compact = f.tell()
try:
is_compact = read_struct('<?')
print(f"INFO: Detected is_compact flag as: {is_compact}")
except (EOFError, struct.error):
# Old format without compact flag - assume non-compact
f.seek(pos_before_compact)
is_compact = False
print(f"INFO: No compact flag found, assuming is_compact=False")
# 3. Read storage FourCC to determine if pruned
is_pruned = False
try:
if is_compact:
# For compact, we need to skip pointers and scalars to get to the storage FourCC
skip_vector(8) # level_ptr
skip_vector(8) # node_offsets
read_struct('<i'); read_struct('<i'); read_struct('<i');
read_struct('<i'); read_struct('<i')
storage_fourcc = read_struct('<I')
else:
# For non-compact, we need to read the flag probe, then skip offsets and neighbors
pos_before_probe = f.tell()
flag_byte = f.read(1)
if not (flag_byte and flag_byte == b'\x00'):
f.seek(pos_before_probe)
skip_vector(8); skip_vector(4) # offsets, neighbors
read_struct('<i'); read_struct('<i'); read_struct('<i');
read_struct('<i'); read_struct('<i')
# Now we are at the storage. The entire rest is storage blob.
storage_fourcc = struct.unpack('<I', f.read(4))[0]
NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little')
if storage_fourcc == NULL_INDEX_FOURCC:
is_pruned = True
except (EOFError, struct.error):
# Cannot determine pruning status, assume not pruned
pass
print(f"INFO: Detected is_pruned as: {is_pruned}")
return is_compact, is_pruned
except (EOFError, struct.error) as e:
print(f"WARNING: Could not parse index file to detect format: {e}. Assuming standard, not pruned.")
return False, False
class HNSWSearcher(BaseSearcher):
def __init__(self, index_path: str, **kwargs): def __init__(self, index_path: str, **kwargs):
from . import faiss super().__init__(
path = Path(index_path) index_path,
index_dir = path.parent backend_module_name="leann_backend_hnsw.hnsw_embedding_server",
index_prefix = path.stem **kwargs,
)
from . import faiss # type: ignore
# Store configuration and paths for later use self.distance_metric = (
self.config = kwargs.copy() self.meta.get("backend_kwargs", {}).get("distance_metric", "mips").lower()
self.config["index_path"] = index_path )
self.index_dir = index_dir metric_enum = get_metric_map().get(self.distance_metric)
self.index_prefix = index_prefix
metric_str = self.config.get("distance_metric", "mips").lower()
metric_enum = get_metric_map().get(metric_str)
if metric_enum is None: if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.") raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
dimensions = self.config.get("dimensions") self.is_compact, self.is_pruned = (
if not dimensions: self.meta.get("is_compact", True),
raise ValueError("Vector dimension not provided to HNSWSearcher.") self.meta.get("is_pruned", True),
)
index_file = index_dir / f"{index_prefix}.index" index_file = self.index_dir / f"{self.index_path.stem}.index"
if not index_file.exists(): if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found at {index_file}") raise FileNotFoundError(f"HNSW index file not found at {index_file}")
self.is_compact, self.is_pruned = self._get_index_storage_status(index_file)
# Validate configuration constraints
if not self.is_compact and self.config.get("is_skip_neighbors", False):
raise ValueError("is_skip_neighbors can only be used with is_compact=True")
if self.config.get("is_recompute", False) and self.config.get("external_storage_path"):
raise ValueError("Cannot use both is_recompute and external_storage_path simultaneously")
hnsw_config = faiss.HNSWIndexConfig() hnsw_config = faiss.HNSWIndexConfig()
hnsw_config.is_compact = self.is_compact hnsw_config.is_compact = self.is_compact
hnsw_config.is_recompute = (
# Apply additional configuration options with strict validation self.is_pruned
hnsw_config.is_skip_neighbors = self.config.get("is_skip_neighbors", False) ) # In C++ code, it's called is_recompute, but it's only for loading IIUC.
hnsw_config.is_recompute = self.is_pruned or self.config.get("is_recompute", False)
hnsw_config.disk_cache_ratio = self.config.get("disk_cache_ratio", 0.0)
hnsw_config.external_storage_path = self.config.get("external_storage_path")
hnsw_config.zmq_port = self.config.get("zmq_port", 5557)
if self.is_pruned and not hnsw_config.is_recompute:
raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
print(f"INFO: Loading index with is_compact={self.is_compact}, is_pruned={self.is_pruned}")
print(f"INFO: Config - skip_neighbors={hnsw_config.is_skip_neighbors}, recompute={hnsw_config.is_recompute}")
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config) self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
if self.is_compact: def search(
print("✅ Compact CSR format HNSW index loaded successfully.") self,
else: query: np.ndarray,
print("✅ Standard HNSW index loaded successfully.") top_k: int,
zmq_port: Optional[int] = None,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
batch_size: int = 0,
**kwargs,
) -> dict[str, Any]:
"""
Search for nearest neighbors using HNSW index.
self.metric_str = metric_str Args:
self.embedding_server_manager = HNSWEmbeddingServerManager() query: Query vectors (B, D) where B is batch size, D is dimension
top_k: Number of nearest neighbors to return
complexity: Search complexity/efSearch, higher = more accurate but slower
beam_width: Number of parallel search paths/beam_size
prune_ratio: Ratio of neighbors to prune via PQ (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server
pruning_strategy: PQ candidate selection strategy:
- "global": Use global PQ queue size for selection (default)
- "local": Local pruning, sort and select best candidates
- "proportional": Base selection on new neighbor count ratio
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
**kwargs: Additional HNSW-specific parameters (for legacy compatibility)
def _get_index_file(self, index_dir: Path, index_prefix: str) -> Path: Returns:
"""Get the appropriate index file path based on format""" Dict with 'labels' (list of lists) and 'distances' (ndarray)
# We always use the same filename now, format is detected internally """
return index_dir / f"{index_prefix}.index" from . import faiss # type: ignore
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]: if not recompute_embeddings and self.is_pruned:
"""Search using HNSW index with optional recompute functionality""" raise RuntimeError(
from . import faiss "Recompute is required for pruned/compact HNSW index. "
# Merge config with search-time kwargs "Re-run search with --recompute, or rebuild with --no-recompute and --no-compact."
search_config = self.config.copy() )
search_config.update(kwargs) if recompute_embeddings:
if zmq_port is None:
ef = search_config.get("ef", 200) # Size of the dynamic candidate list for search raise ValueError("zmq_port must be provided if recompute_embeddings is True")
# Recompute parameters
zmq_port = search_config.get("zmq_port", 5557)
embedding_model = search_config.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
passages_file = search_config.get("passages_file", None)
# For recompute mode, try to find the passages file automatically
if self.is_pruned and not passages_file:
potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
print(f"DEBUG: Checking for passages file at: {potential_passages_file}")
if potential_passages_file.exists():
passages_file = str(potential_passages_file)
print(f"INFO: Found passages file for recompute mode: {passages_file}")
else:
print(f"WARNING: No passages file found for recompute mode at {potential_passages_file}")
# If index is pruned (embeddings removed), we MUST start embedding server for recompute
if self.is_pruned:
print(f"INFO: Index is pruned - starting embedding server for recompute")
# CRITICAL: Check passages file exists - fail fast if not
if not passages_file:
raise RuntimeError(f"FATAL: Index is pruned but no passages file found. Cannot proceed with recompute mode.")
# Check if server is already running first
if _check_port(zmq_port):
print(f"INFO: Embedding server already running on port {zmq_port}")
else:
if not self.embedding_server_manager.start_server(zmq_port, embedding_model, passages_file, self.metric_str):
raise RuntimeError(f"Failed to start HNSW embedding server on port {zmq_port}")
# Give server extra time to fully initialize
print(f"INFO: Waiting for embedding server to fully initialize...")
time.sleep(3)
# Final verification
if not _check_port(zmq_port):
raise RuntimeError(f"Embedding server failed to start listening on port {zmq_port}")
else:
print(f"INFO: Index has embeddings stored - no recompute needed")
if query.dtype != np.float32: if query.dtype != np.float32:
query = query.astype(np.float32) query = query.astype(np.float32)
if query.ndim == 1: if self.distance_metric == "cosine":
query = np.expand_dims(query, axis=0) query = normalize_l2(query)
# Normalize query if using cosine similarity params = faiss.SearchParametersHNSW()
if self.metric_str == "cosine": if zmq_port is not None:
faiss.normalize_L2(query) params.zmq_port = zmq_port # C++ code won't use this if recompute_embeddings is False
params.efSearch = complexity
params.beam_size = beam_width
try: # For OpenAI embeddings with cosine distance, disable relative distance check
# Set search parameter # This prevents early termination when all scores are in a narrow range
self._index.hnsw.efSearch = ef embedding_model = self.meta.get("embedding_model", "").lower()
if self.distance_metric == "cosine" and any(
openai_model in embedding_model for openai_model in ["text-embedding", "openai"]
):
params.check_relative_distance = False
else:
params.check_relative_distance = True
# Prepare output arrays for the older FAISS SWIG API # PQ pruning: direct mapping to HNSW's pq_pruning_ratio
batch_size = query.shape[0] params.pq_pruning_ratio = prune_ratio
distances = np.empty((batch_size, top_k), dtype=np.float32)
labels = np.empty((batch_size, top_k), dtype=np.int64)
# Use standard FAISS search - recompute is handled internally by FAISS # Map pruning_strategy to HNSW parameters
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels)) if pruning_strategy == "local":
params.local_prune = True
params.send_neigh_times_ratio = 0.0
elif pruning_strategy == "proportional":
params.local_prune = False
params.send_neigh_times_ratio = 1.0 # Any value > 1e-6 triggers proportional mode
else: # "global"
params.local_prune = False
params.send_neigh_times_ratio = 0.0
return {"labels": labels, "distances": distances} # HNSW-specific batch processing parameter
params.batch_size = batch_size
except Exception as e: batch_size_query = query.shape[0]
print(f"💥 ERROR: HNSW search failed. Exception: {e}") distances = np.empty((batch_size_query, top_k), dtype=np.float32)
raise labels = np.empty((batch_size_query, top_k), dtype=np.int64)
def __del__(self): search_time = time.time()
if hasattr(self, 'embedding_server_manager'): self._index.search(
self.embedding_server_manager.stop_server() query.shape[0],
faiss.swig_ptr(query),
top_k,
faiss.swig_ptr(distances),
faiss.swig_ptr(labels),
params,
)
search_time = time.time() - search_time
logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
return {"labels": string_labels, "distances": distances}

View File

@@ -1,596 +1,428 @@
#!/usr/bin/env python3
""" """
HNSW-specific embedding server with removed config.py dependencies HNSW-specific embedding server
Based on DiskANN embedding server architecture
""" """
import pickle
import argparse import argparse
import json
import logging
import os
import sys
import threading import threading
import time import time
from transformers import AutoTokenizer, AutoModel
import os
from contextlib import contextmanager
import zmq
import numpy as np
import msgpack
import json
from pathlib import Path from pathlib import Path
from typing import Dict, Any, Optional, Union from typing import Optional
RED = "\033[91m" import msgpack
RESET = "\033[0m" import numpy as np
import zmq
def is_similarity_metric(): # Set up logging based on environment variable
""" LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
Check if the metric type is similarity-based (like inner product). logger = logging.getLogger(__name__)
0 = L2 (distance metric), 1 = Inner Product (similarity metric)
"""
return True # 1 is METRIC_INNER_PRODUCT in FAISS
# Function for E5-style average pooling # Force set logger level (don't rely on basicConfig in subprocess)
import torch log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
from torch import Tensor logger.setLevel(log_level)
import torch.nn.functional as F
def e5_average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor: # Ensure we have a handler if none exists
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) if not logger.handlers:
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False
class SimplePassageLoader:
"""
Simple passage loader that replaces config.py dependencies
"""
def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
self.passages_data = passages_data or {}
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID"""
str_id = str(passage_id)
if str_id in self.passages_data:
return {"text": self.passages_data[str_id]}
else:
# Return empty text for missing passages
return {"text": ""}
def __len__(self) -> int:
return len(self.passages_data)
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
"""
Load passages from a JSON file
Expected format: {"passage_id": "passage_text", ...}
"""
if not os.path.exists(passages_file):
print(f"Warning: Passages file {passages_file} not found. Using empty loader.")
return SimplePassageLoader()
try:
with open(passages_file, 'r', encoding='utf-8') as f:
passages_data = json.load(f)
print(f"Loaded {len(passages_data)} passages from {passages_file}")
return SimplePassageLoader(passages_data)
except Exception as e:
print(f"Error loading passages from {passages_file}: {e}")
return SimplePassageLoader()
def create_hnsw_embedding_server( def create_hnsw_embedding_server(
passages_file: Optional[str] = None, passages_file: Optional[str] = None,
passages_data: Optional[Dict[str, str]] = None,
embeddings_file: Optional[str] = None,
use_fp16: bool = True,
use_int8: bool = False,
use_cuda_graphs: bool = False,
zmq_port: int = 5555, zmq_port: int = 5555,
max_batch_size: int = 128,
model_name: str = "sentence-transformers/all-mpnet-base-v2", model_name: str = "sentence-transformers/all-mpnet-base-v2",
custom_max_length_param: Optional[int] = None,
distance_metric: str = "mips", distance_metric: str = "mips",
embedding_mode: str = "sentence-transformers",
): ):
""" """
Create and start a ZMQ-based embedding server for HNSW backend. Create and start a ZMQ-based embedding server for HNSW backend.
Simplified version using unified embedding computation module.
Args:
passages_file: Path to JSON file containing passage ID -> text mapping
passages_data: Direct passage data dict (alternative to passages_file)
embeddings_file: Path to pre-computed embeddings file (optional)
use_fp16: Whether to use FP16 precision
use_int8: Whether to use INT8 quantization
use_cuda_graphs: Whether to use CUDA graphs
zmq_port: ZMQ port to bind to
max_batch_size: Maximum batch size for processing
model_name: Transformer model name
custom_max_length_param: Custom max sequence length
distance_metric: The distance metric to use
""" """
print(f"Loading tokenizer for {model_name}...") logger.info(f"Starting HNSW server on port {zmq_port} with model {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) logger.info(f"Using embedding mode: {embedding_mode}")
print(f"Tokenizer loaded successfully!")
# Device setup # Add leann-core to path for unified embedding computation
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() current_dir = Path(__file__).parent
cuda_available = torch.cuda.is_available() leann_core_path = current_dir.parent.parent / "leann-core" / "src"
sys.path.insert(0, str(leann_core_path))
print(f"MPS available: {mps_available}") try:
print(f"CUDA available: {cuda_available}") from leann.api import PassageManager
from leann.embedding_compute import compute_embeddings
if cuda_available: logger.info("Successfully imported unified embedding computation module")
device = torch.device("cuda") except ImportError as e:
print("Using CUDA device") logger.error(f"Failed to import embedding computation module: {e}")
elif mps_available: return
device = torch.device("mps") finally:
print("Using MPS device (Apple Silicon)") sys.path.pop(0)
else:
device = torch.device("cpu")
print("Using CPU device (no GPU acceleration available)")
# Load model to the appropriate device
print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
print(f"Loading model {model_name}... (this may take a while if downloading)")
model = AutoModel.from_pretrained(model_name).to(device).eval()
print(f"Model {model_name} loaded successfully!")
# Check port availability # Check port availability
import socket import socket
def check_port(port): def check_port(port):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0 return s.connect_ex(("localhost", port)) == 0
if check_port(zmq_port): if check_port(zmq_port):
print(f"{RED}Port {zmq_port} is already in use{RESET}") logger.error(f"Port {zmq_port} is already in use")
return return
# Apply model optimizations (similar to DiskANN version) # Only support metadata file, fail fast for everything else
if use_fp16 and (cuda_available or mps_available): if not passages_file or not passages_file.endswith(".meta.json"):
model = model.half() raise ValueError("Only metadata files (.meta.json) are supported")
model = torch.compile(model)
print(f"Using FP16 precision with model: {model_name}")
elif use_int8:
print("- Using TorchAO for Int8 dynamic activation and Int8 weight quantization")
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
quantize_(model, Int8DynamicActivationInt8WeightConfig())
model = torch.compile(model)
model.eval()
print("- Model successfully quantized and compiled")
# Load passages # Load metadata to get passage sources
if passages_data: with open(passages_file) as f:
passages = SimplePassageLoader(passages_data) meta = json.load(f)
print(f"Using provided passages data: {len(passages)} passages")
elif passages_file:
passages = load_passages_from_file(passages_file)
else:
passages = SimplePassageLoader()
print("No passages provided, using empty loader")
# Load embeddings if provided # Let PassageManager handle path resolution uniformly. It supports fallback order:
_embeddings = None # 1) path/index_path; 2) *_relative; 3) standard siblings next to meta
if embeddings_file and os.path.exists(embeddings_file): passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
# Dimension from metadata for shaping responses
try: try:
with open(embeddings_file, "rb") as f: embedding_dim: int = int(meta.get("dimensions", 0))
_embeddings = pickle.load(f) except Exception:
print(f"Loaded embeddings from {embeddings_file}") embedding_dim = 0
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
# (legacy ZMQ thread removed; using shutdown-capable server only)
def zmq_server_thread_with_shutdown(shutdown_event):
"""ZMQ server thread that respects shutdown signal.
Creates its own REP socket bound to zmq_port and polls with timeouts
to allow graceful shutdown.
"""
logger.info("ZMQ server thread started with shutdown support")
context = zmq.Context()
rep_socket = context.socket(zmq.REP)
rep_socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
# Keep sends from blocking during shutdown; fail fast and drop on close
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
rep_socket.setsockopt(zmq.LINGER, 0)
# Track last request type/length for shape-correct fallbacks
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
last_request_length = 0
try:
while not shutdown_event.is_set():
try:
e2e_start = time.time()
logger.debug("🔍 Waiting for ZMQ message...")
request_bytes = rep_socket.recv()
# Rest of the processing logic (same as original)
request = msgpack.unpackb(request_bytes)
if len(request) == 1 and request[0] == "__QUERY_MODEL__":
response_bytes = msgpack.packb([model_name])
rep_socket.send(response_bytes)
continue
# Handle direct text embedding request
if (
isinstance(request, list)
and request
and all(isinstance(item, str) for item in request)
):
last_request_type = "text"
last_request_length = len(request)
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
continue
# Handle distance calculation request: [[ids], [query_vector]]
if (
isinstance(request, list)
and len(request) == 2
and isinstance(request[0], list)
and isinstance(request[1], list)
):
node_ids = request[0]
# Handle nested [[ids]] shape defensively
if len(node_ids) == 1 and isinstance(node_ids[0], list):
node_ids = node_ids[0]
query_vector = np.array(request[1], dtype=np.float32)
last_request_type = "distance"
last_request_length = len(node_ids)
logger.debug("Distance calculation request received")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
# Gather texts for found ids
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {nid}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
except Exception as e: except Exception as e:
print(f"Error loading embeddings: {e}") logger.error(f"Exception looking up passage ID {nid}: {e}")
class DeviceTimer: # Prepare full-length response with large sentinel values
"""Device event-based timer for accurate timing.""" large_distance = 1e9
def __init__(self, name="", device=device): response_distances = [large_distance] * len(node_ids)
self.name = name
self.device = device
self.start_time = 0
self.end_time = 0
if cuda_available: if texts:
self.start_event = torch.cuda.Event(enable_timing=True) try:
self.end_event = torch.cuda.Event(enable_timing=True) embeddings = compute_embeddings(
texts, model_name, mode=embedding_mode
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if distance_metric == "l2":
partial = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else: # mips or cosine
partial = -np.dot(embeddings, query_vector)
for pos, dval in zip(found_indices, partial.flatten().tolist()):
response_distances[pos] = float(dval)
except Exception as e:
logger.error(f"Distance computation error, using sentinels: {e}")
# Send response in expected shape [[distances]]
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
continue
# Fallback: treat as embedding-by-id request
if (
isinstance(request, list)
and len(request) == 1
and isinstance(request[0], list)
):
node_ids = request[0]
elif isinstance(request, list):
node_ids = request
else: else:
self.start_event = None node_ids = []
self.end_event = None last_request_type = "embedding"
last_request_length = len(node_ids)
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
@contextmanager # Preallocate zero-filled flat data for robustness
def timing(self): if embedding_dim <= 0:
self.start() dims = [0, 0]
yield flat_data: list[float] = []
self.end()
def start(self):
if cuda_available:
torch.cuda.synchronize()
self.start_event.record()
else: else:
if self.device.type == "mps": dims = [len(node_ids), embedding_dim]
torch.mps.synchronize() flat_data = [0.0] * (dims[0] * dims[1])
self.start_time = time.time()
def end(self): # Collect texts for found ids
if cuda_available: texts: list[str] = []
self.end_event.record() found_indices: list[int] = []
torch.cuda.synchronize() for idx, nid in enumerate(node_ids):
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else: else:
if self.device.type == "mps": logger.error(f"Empty text for passage ID {nid}")
torch.mps.synchronize() except KeyError:
self.end_time = time.time() logger.error(f"Passage with ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
def elapsed_time(self): if texts:
if cuda_available: try:
return self.start_event.elapsed_time(self.end_event) / 1000.0 embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
else: logger.info(
return self.end_time - self.start_time f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
def print_elapsed(self):
return # Disabled for now
def process_batch(texts_batch, ids_batch, missing_ids):
"""Process a batch of texts and return embeddings"""
_is_e5_model = "e5" in model_name.lower()
_is_bge_model = "bge" in model_name.lower()
batch_size = len(texts_batch)
# E5 model preprocessing
if _is_e5_model:
processed_texts_batch = [f"passage: {text}" for text in texts_batch]
else:
processed_texts_batch = texts_batch
# Set max length
if _is_e5_model:
current_max_length = custom_max_length_param if custom_max_length_param is not None else 512
else:
current_max_length = custom_max_length_param if custom_max_length_param is not None else 256
tokenize_timer = DeviceTimer("tokenization (batch)", device)
to_device_timer = DeviceTimer("transfer to device (batch)", device)
embed_timer = DeviceTimer("embedding (batch)", device)
pool_timer = DeviceTimer("pooling (batch)", device)
norm_timer = DeviceTimer("normalization (batch)", device)
with tokenize_timer.timing():
encoded_batch = tokenizer(
processed_texts_batch,
padding="max_length",
truncation=True,
max_length=current_max_length,
return_tensors="pt",
return_token_type_ids=False,
) )
seq_length = encoded_batch["input_ids"].size(1) if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error(
with to_device_timer.timing(): f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
enc = {k: v.to(device) for k, v in encoded_batch.items()} )
dims = [0, embedding_dim]
with torch.no_grad(): flat_data = []
with embed_timer.timing():
out = model(enc["input_ids"], enc["attention_mask"])
with pool_timer.timing():
if _is_bge_model:
pooled_embeddings = out.last_hidden_state[:, 0]
elif not hasattr(out, 'last_hidden_state'):
if isinstance(out, torch.Tensor) and len(out.shape) == 2:
pooled_embeddings = out
else: else:
print(f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}") emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
hidden_dim = getattr(model.config, 'hidden_size', 384 if _is_e5_model else 768) flat = emb_f32.flatten().tolist()
pooled_embeddings = torch.zeros((batch_size, hidden_dim), device=device, dtype=enc["input_ids"].dtype if hasattr(enc["input_ids"], "dtype") else torch.float32) for j, pos in enumerate(found_indices):
elif _is_e5_model: start = pos * embedding_dim
pooled_embeddings = e5_average_pool(out.last_hidden_state, enc['attention_mask']) end = start + embedding_dim
else: if end <= len(flat_data):
hidden_states = out.last_hidden_state flat_data[start:end] = flat[
mask_expanded = enc["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float() j * embedding_dim : (j + 1) * embedding_dim
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
pooled_embeddings = sum_embeddings / sum_mask
final_embeddings = pooled_embeddings
if _is_e5_model or _is_bge_model:
with norm_timer.timing():
final_embeddings = F.normalize(pooled_embeddings, p=2, dim=1)
if torch.isnan(final_embeddings).any() or torch.isinf(final_embeddings).any():
print(f"{RED}!!! In process_batch: NaN or Inf detected in final_embeddings! "
f"Model: {model_name}, E5: {_is_e5_model}. IDs (sample): {ids_batch[:5]}...{RESET}")
dim_size = final_embeddings.shape[-1]
error_output = torch.zeros((batch_size, dim_size), device='cpu', dtype=torch.float32).numpy()
print(f"{RED}Returning zero embeddings of shape ({batch_size}, {dim_size}) due to NaN/Inf.{RESET}")
return error_output
return final_embeddings.cpu().numpy()
def client_warmup(zmq_port):
"""Perform client-side warmup"""
time.sleep(2)
print(f"Performing client-side warmup with model {model_name}...")
sample_ids = ["1", "2", "3", "4", "5"]
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(f"tcp://localhost:{zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 30000)
socket.setsockopt(zmq.SNDTIMEO, 30000)
try:
ids_to_send = [int(x) for x in sample_ids]
except ValueError:
ids_to_send = []
if not ids_to_send:
print("Skipping warmup send.")
return
request_payload = [ids_to_send]
request_bytes = msgpack.packb(request_payload)
for i in range(3):
print(f"Sending warmup request {i+1}/3 via ZMQ (MessagePack)...")
socket.send(request_bytes)
response_bytes = socket.recv()
response_payload = msgpack.unpackb(response_bytes)
dimensions = response_payload[0]
embeddings_count = dimensions[0] if dimensions and len(dimensions) > 0 else 0
print(f"Warmup request {i+1}/3 successful, received {embeddings_count} embeddings")
time.sleep(0.1)
print("Client-side MessagePack ZMQ warmup complete")
socket.close()
context.term()
except Exception as e:
print(f"Error during MessagePack ZMQ warmup: {e}")
def zmq_server_thread():
"""ZMQ server thread"""
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind(f"tcp://*:{zmq_port}")
print(f"HNSW ZMQ server listening on port {zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 300000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
while True:
try:
message_bytes = socket.recv()
print(f"Received ZMQ request of size {len(message_bytes)} bytes")
e2e_start = time.time()
lookup_timer = DeviceTimer("text lookup", device)
try:
request_payload = msgpack.unpackb(message_bytes)
# Handle distance calculation requests
if isinstance(request_payload, list) and len(request_payload) == 2 and isinstance(request_payload[0], list) and isinstance(request_payload[1], list):
node_ids = request_payload[0]
query_vector = np.array(request_payload[1], dtype=np.float32)
print(f"Request for distance calculation: {len(node_ids)} nodes, query vector dim: {len(query_vector)}")
# Get embeddings for node IDs
texts = []
missing_ids = []
with lookup_timer.timing():
for nid in node_ids:
try:
txtinfo = passages[nid]
if txtinfo is None or txtinfo["text"] == "":
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
else:
txt = txtinfo["text"]
except (KeyError, IndexError):
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
texts.append(txt)
lookup_timer.print_elapsed()
# Process embeddings in chunks if needed
all_node_embeddings = []
total_size = len(texts)
if total_size > max_batch_size:
for i in range(0, total_size, max_batch_size):
end_idx = min(i + max_batch_size, total_size)
chunk_texts = texts[i:end_idx]
chunk_ids = node_ids[i:end_idx]
embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
all_node_embeddings.append(embeddings_chunk)
if cuda_available:
torch.cuda.empty_cache()
elif device.type == "mps":
torch.mps.empty_cache()
node_embeddings = np.vstack(all_node_embeddings)
else:
node_embeddings = process_batch(texts, node_ids, missing_ids)
# Calculate distances
query_tensor = torch.tensor(query_vector, device=device).float()
node_embeddings_tensor = torch.tensor(node_embeddings, device=device).float()
calc_timer = DeviceTimer("distance calculation", device)
with calc_timer.timing():
with torch.no_grad():
if distance_metric == "l2":
node_embeddings_np = node_embeddings_tensor.cpu().numpy().astype(np.float32)
query_np = query_tensor.cpu().numpy().astype(np.float32)
distances = np.sum(np.square(node_embeddings_np - query_np.reshape(1, -1)), axis=1)
else: # mips or cosine
node_embeddings_np = node_embeddings_tensor.cpu().numpy()
query_np = query_tensor.cpu().numpy()
distances = -np.dot(node_embeddings_np, query_np)
calc_timer.print_elapsed()
try:
response_payload = distances.flatten().tolist()
response_bytes = msgpack.packb([response_payload], use_single_float=True)
print(f"Sending distance response with {len(distances)} distances")
except Exception as pack_error:
print(f"Error packing MessagePack distance response: {pack_error}")
response_bytes = msgpack.packb([[]])
socket.send(response_bytes)
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
e2e_end = time.time()
print(f"Distance calculation E2E time: {e2e_end - e2e_start:.6f} seconds")
continue
# Standard embedding request
if not isinstance(request_payload, list) or len(request_payload) != 1 or not isinstance(request_payload[0], list):
print(f"Error: Invalid MessagePack request format. Expected [[ids...]], got: {type(request_payload)}")
socket.send(msgpack.packb([[], []]))
continue
node_ids = request_payload[0]
print(f"Request for {len(node_ids)} node embeddings")
except Exception as unpack_error:
print(f"Error unpacking MessagePack request: {unpack_error}")
socket.send(msgpack.packb([[], []]))
continue
# Look up texts by node IDs
texts = []
missing_ids = []
with lookup_timer.timing():
for nid in node_ids:
try:
txtinfo = passages[nid]
if txtinfo is None or txtinfo["text"] == "":
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
else:
txt = txtinfo["text"]
except (KeyError, IndexError):
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
texts.append(txt)
lookup_timer.print_elapsed()
if missing_ids:
print(f"Missing passages for IDs: {missing_ids}")
# Process in chunks
total_size = len(texts)
print(f"Total batch size: {total_size}, max_batch_size: {max_batch_size}")
all_embeddings = []
if total_size > max_batch_size:
print(f"Splitting batch of size {total_size} into chunks of {max_batch_size}")
for i in range(0, total_size, max_batch_size):
end_idx = min(i + max_batch_size, total_size)
print(f"Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")
chunk_texts = texts[i:end_idx]
chunk_ids = node_ids[i:end_idx]
embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
all_embeddings.append(embeddings_chunk)
if cuda_available:
torch.cuda.empty_cache()
elif device.type == "mps":
torch.mps.empty_cache()
hidden = np.vstack(all_embeddings)
print(f"Combined embeddings shape: {hidden.shape}")
else:
hidden = process_batch(texts, node_ids, missing_ids)
# Serialization and response
ser_start = time.time()
print(f"DEBUG zmq_server_thread: Final 'hidden' array | Shape: {hidden.shape} | Dtype: {hidden.dtype} | Has NaN/Inf: {np.isnan(hidden).any() or np.isinf(hidden).any()}")
if np.isnan(hidden).any() or np.isinf(hidden).any():
print(f"{RED}!!! ERROR: NaN or Inf detected in final 'hidden' numpy array BEFORE sending! "
f"Requested IDs (sample): {node_ids[:5]}...{RESET}")
assert False
try:
hidden_contiguous_f32 = np.ascontiguousarray(hidden, dtype=np.float32)
response_payload = [
list(hidden_contiguous_f32.shape),
hidden_contiguous_f32.flatten().tolist()
] ]
except Exception as e:
logger.error(f"Embedding computation error, returning zeros: {e}")
response_payload = [dims, flat_data]
response_bytes = msgpack.packb(response_payload, use_single_float=True) response_bytes = msgpack.packb(response_payload, use_single_float=True)
except Exception as pack_error:
print(f"Error packing MessagePack response: {pack_error}")
response_bytes = msgpack.packb([[], []])
socket.send(response_bytes) rep_socket.send(response_bytes)
ser_end = time.time()
print(f"Serialize time: {ser_end - ser_start:.6f} seconds")
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
e2e_end = time.time() e2e_end = time.time()
print(f"ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds") logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again: except zmq.Again:
print("ZMQ socket timeout, continuing to listen") # Timeout - check shutdown_event and continue
continue continue
except Exception as e: except Exception as e:
print(f"Error in ZMQ server loop: {e}") if not shutdown_event.is_set():
import traceback logger.error(f"Error in ZMQ server loop: {e}")
traceback.print_exc() # Shape-correct fallback
try: try:
socket.send(msgpack.packb([[], []])) if last_request_type == "distance":
except: large_distance = 1e9
fallback_len = max(0, int(last_request_length))
safe = [[large_distance] * fallback_len]
elif last_request_type == "embedding":
bsz = max(0, int(last_request_length))
dim = max(0, int(embedding_dim))
safe = (
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
)
elif last_request_type == "text":
safe = [] # direct text embeddings expectation is a flat list
else:
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
else:
logger.info("Shutdown in progress, ignoring ZMQ error")
break
finally:
try:
rep_socket.close(0)
except Exception:
pass
try:
context.term()
except Exception:
pass pass
# Start warmup and server threads logger.info("ZMQ server thread exiting gracefully")
if len(passages) > 0:
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
warmup_thread.daemon = True
warmup_thread.start()
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True) # Add shutdown coordination
shutdown_event = threading.Event()
def shutdown_zmq_server():
"""Gracefully shutdown ZMQ server."""
logger.info("Initiating graceful shutdown...")
shutdown_event.set()
if zmq_thread.is_alive():
logger.info("Waiting for ZMQ thread to finish...")
zmq_thread.join(timeout=5)
if zmq_thread.is_alive():
logger.warning("ZMQ thread did not finish in time")
# Clean up ZMQ resources
try:
# Note: socket and context are cleaned up by thread exit
logger.info("ZMQ resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning ZMQ resources: {e}")
# Clean up other resources
try:
import gc
gc.collect()
logger.info("Additional resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning additional resources: {e}")
logger.info("Graceful shutdown completed")
sys.exit(0)
# Register signal handlers within this function scope
import signal
def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
shutdown_zmq_server()
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Pass shutdown_event to ZMQ thread
zmq_thread = threading.Thread(
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
daemon=False, # Not daemon - we want to wait for it
)
zmq_thread.start() zmq_thread.start()
print(f"Started HNSW ZMQ server thread on port {zmq_port}") logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
# Keep the main thread alive # Keep the main thread alive
try: try:
while True: while not shutdown_event.is_set():
time.sleep(1) time.sleep(0.1) # Check shutdown more frequently
except KeyboardInterrupt: except KeyboardInterrupt:
print("HNSW Server shutting down...") logger.info("HNSW Server shutting down...")
shutdown_zmq_server()
return return
# If we reach here, shutdown was triggered by signal
logger.info("Main loop exited, process should be shutting down")
if __name__ == "__main__": if __name__ == "__main__":
import sys
# Signal handlers are now registered within create_hnsw_embedding_server
parser = argparse.ArgumentParser(description="HNSW Embedding service") parser = argparse.ArgumentParser(description="HNSW Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on") parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping") parser.add_argument(
parser.add_argument("--embeddings-file", type=str, help="Pickle file containing pre-computed embeddings") "--passages-file",
parser.add_argument("--use-fp16", action="store_true", default=False) type=str,
parser.add_argument("--use-int8", action="store_true", default=False) help="JSON file containing passage ID to text mapping",
parser.add_argument("--use-cuda-graphs", action="store_true", default=False) )
parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting") parser.add_argument(
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2", "--model-name",
help="Embedding model name") type=str,
parser.add_argument("--custom-max-length", type=int, default=None, help="Override model's default max sequence length") default="sentence-transformers/all-mpnet-base-v2",
parser.add_argument("--distance-metric", type=str, default="mips", help="Distance metric to use") help="Embedding model name",
)
parser.add_argument(
"--distance-metric", type=str, default="mips", help="Distance metric to use"
)
parser.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode",
)
args = parser.parse_args() args = parser.parse_args()
# Create and start the HNSW embedding server # Create and start the HNSW embedding server
create_hnsw_embedding_server( create_hnsw_embedding_server(
passages_file=args.passages_file, passages_file=args.passages_file,
embeddings_file=args.embeddings_file,
use_fp16=args.use_fp16,
use_int8=args.use_int8,
use_cuda_graphs=args.use_cuda_graphs,
zmq_port=args.zmq_port, zmq_port=args.zmq_port,
max_batch_size=args.max_batch_size,
model_name=args.model_name, model_name=args.model_name,
custom_max_length_param=args.custom_max_length,
distance_metric=args.distance_metric, distance_metric=args.distance_metric,
embedding_mode=args.embedding_mode,
) )

View File

@@ -1,4 +1,4 @@
# 文件: packages/leann-backend-hnsw/pyproject.toml # packages/leann-backend-hnsw/pyproject.toml
[build-system] [build-system]
requires = ["scikit-build-core>=0.10", "numpy", "swig"] requires = ["scikit-build-core>=0.10", "numpy", "swig"]
@@ -6,13 +6,24 @@ build-backend = "scikit_build_core.build"
[project] [project]
name = "leann-backend-hnsw" name = "leann-backend-hnsw"
version = "0.1.0" version = "0.3.2"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit." description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = ["leann-core==0.1.0", "numpy"] dependencies = [
"leann-core==0.3.2",
"numpy",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",
]
# 回归到最标准的 scikit-build-core 配置
[tool.scikit-build] [tool.scikit-build]
wheel.packages = ["leann_backend_hnsw"] wheel.packages = ["leann_backend_hnsw"]
editable.mode = "redirect" editable.mode = "redirect"
cmake.build-type = "Debug" cmake.build-type = "Release"
build.verbose = true build.verbose = true
build.tool-args = ["-j8"]
# CMake definitions to optimize compilation and find Homebrew packages
[tool.scikit-build.cmake.define]
CMAKE_BUILD_PARALLEL_LEVEL = "8"
CMAKE_PREFIX_PATH = {env = "CMAKE_PREFIX_PATH"}
OpenMP_ROOT = {env = "OpenMP_ROOT"}

View File

@@ -4,15 +4,49 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "leann-core" name = "leann-core"
version = "0.1.0" version = "0.3.2"
description = "Core API and plugin system for Leann." description = "Core API and plugin system for LEANN"
readme = "README.md" readme = "README.md"
requires-python = ">=3.9" requires-python = ">=3.9"
license = { text = "MIT" } license = { text = "MIT" }
# All required dependencies included
dependencies = [ dependencies = [
"numpy>=1.20.0" "numpy>=1.20.0",
"tqdm>=4.60.0",
"psutil>=5.8.0",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",
"torch>=2.0.0",
"sentence-transformers>=2.2.0",
"llama-index-core>=0.12.0",
"llama-index-readers-file>=0.4.0", # Essential for document reading
"llama-index-embeddings-huggingface>=0.5.5", # For embeddings
"python-dotenv>=1.0.0",
"openai>=1.0.0",
"huggingface-hub>=0.20.0",
"transformers>=4.30.0",
"requests>=2.25.0",
"accelerate>=0.20.0",
"PyPDF2>=3.0.0",
"pymupdf>=1.23.0",
"pdfplumber>=0.10.0",
"nbconvert>=7.0.0", # For .ipynb file support
"gitignore-parser>=0.1.12", # For proper .gitignore handling
"mlx>=0.26.3; sys_platform == 'darwin' and platform_machine == 'arm64'",
"mlx-lm>=0.26.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
] ]
[project.optional-dependencies]
colab = [
"torch>=2.0.0,<3.0.0", # Limit torch version to avoid conflicts
"transformers>=4.30.0,<5.0.0", # Limit transformers version
"accelerate>=0.20.0,<1.0.0", # Limit accelerate version
]
[project.scripts]
leann = "leann.cli:main"
leann_mcp = "leann.mcp:main"
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
where = ["src"] where = ["src"]

View File

@@ -1,17 +1,21 @@
# This file makes the 'leann' directory a Python package. # packages/leann-core/src/leann/__init__.py
import os
import platform
from .api import LeannBuilder, LeannSearcher, LeannChat, SearchResult # Fix OpenMP threading issues on macOS ARM64
if platform.system() == "Darwin":
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["KMP_BLOCKTIME"] = "0"
# Additional fixes for PyTorch/sentence-transformers on macOS ARM64 only in CI
if os.environ.get("CI") == "true":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Import backends to ensure they are registered from .api import LeannBuilder, LeannChat, LeannSearcher
try: from .registry import BACKEND_REGISTRY, autodiscover_backends
import leann_backend_hnsw
except ImportError:
pass
try: autodiscover_backends()
import leann_backend_diskann
except ImportError:
pass
__all__ = ["BACKEND_REGISTRY", "LeannBuilder", "LeannChat", "LeannSearcher"]
__all__ = ['LeannBuilder', 'LeannSearcher', 'LeannChat', 'SearchResult']

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,855 @@
#!/usr/bin/env python3
"""
This file contains the chat generation logic for the LEANN project,
supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode.
"""
import difflib
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, Optional
import torch
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def check_ollama_models(host: str) -> list[str]:
"""Check available Ollama models and return a list"""
try:
import requests
response = requests.get(f"{host}/api/tags", timeout=5)
if response.status_code == 200:
data = response.json()
return [model["name"] for model in data.get("models", [])]
return []
except Exception:
return []
def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
"""Check if a model exists in Ollama's remote library and return available tags
Returns:
(model_exists, available_tags): bool and list of matching tags
"""
try:
import re
import requests
# Split model name and tag
if ":" in model_name:
base_model, requested_tag = model_name.split(":", 1)
else:
base_model, requested_tag = model_name, None
# First check if base model exists in library
library_response = requests.get("https://ollama.com/library", timeout=8)
if library_response.status_code != 200:
return True, [] # Assume exists if can't check
# Extract model names from library page
models_in_library = re.findall(r'href="/library/([^"]+)"', library_response.text)
if base_model not in models_in_library:
return False, [] # Base model doesn't exist
# If base model exists, get available tags
tags_response = requests.get(f"https://ollama.com/library/{base_model}/tags", timeout=8)
if tags_response.status_code != 200:
return True, [] # Base model exists but can't get tags
# Extract tags for this model - be more specific to avoid HTML artifacts
tag_pattern = rf"{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+"
raw_tags = re.findall(tag_pattern, tags_response.text)
# Clean up tags - remove HTML artifacts and duplicates
available_tags = []
seen = set()
for tag in raw_tags:
# Skip if it looks like HTML (contains < or >)
if "<" in tag or ">" in tag:
continue
if tag not in seen:
seen.add(tag)
available_tags.append(tag)
# Check if exact model exists
if requested_tag is None:
# User just requested base model, suggest tags
return True, available_tags[:10] # Return up to 10 tags
else:
exact_match = model_name in available_tags
return exact_match, available_tags[:10]
except Exception:
pass
# If scraping fails, assume model might exist (don't block user)
return True, []
def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[str]:
"""Use intelligent fuzzy search for Ollama models"""
if not available_models:
return []
query_lower = query.lower()
suggestions = []
# 1. Exact matches first
exact_matches = [m for m in available_models if query_lower == m.lower()]
suggestions.extend(exact_matches)
# 2. Starts with query
starts_with = [
m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions
]
suggestions.extend(starts_with)
# 3. Contains query
contains = [m for m in available_models if query_lower in m.lower() and m not in suggestions]
suggestions.extend(contains)
# 4. Base model name matching (remove version numbers)
def get_base_name(model_name: str) -> str:
"""Extract base name without version (e.g., 'llama3:8b' -> 'llama3')"""
return model_name.split(":")[0].split("-")[0]
query_base = get_base_name(query_lower)
base_matches = [
m
for m in available_models
if get_base_name(m.lower()) == query_base and m not in suggestions
]
suggestions.extend(base_matches)
# 5. Family/variant matching
model_families = {
"llama": ["llama2", "llama3", "alpaca", "vicuna", "codellama"],
"qwen": ["qwen", "qwen2", "qwen3"],
"gemma": ["gemma", "gemma2"],
"phi": ["phi", "phi2", "phi3"],
"mistral": ["mistral", "mixtral", "openhermes"],
"dolphin": ["dolphin", "openchat"],
"deepseek": ["deepseek", "deepseek-coder"],
}
query_family = None
for family, variants in model_families.items():
if any(variant in query_lower for variant in variants):
query_family = family
break
if query_family:
family_variants = model_families[query_family]
family_matches = [
m
for m in available_models
if any(variant in m.lower() for variant in family_variants) and m not in suggestions
]
suggestions.extend(family_matches)
# 6. Use difflib for remaining fuzzy matches
remaining_models = [m for m in available_models if m not in suggestions]
difflib_matches = difflib.get_close_matches(query_lower, remaining_models, n=3, cutoff=0.4)
suggestions.extend(difflib_matches)
return suggestions[:8] # Return top 8 suggestions
# Remove this function entirely - we don't need external API calls for Ollama
# Remove this too - no need for fallback
def suggest_similar_models(invalid_model: str, available_models: list[str]) -> list[str]:
"""Use difflib to find similar model names"""
if not available_models:
return []
# Get close matches using fuzzy matching
suggestions = difflib.get_close_matches(invalid_model, available_models, n=3, cutoff=0.3)
return suggestions
def check_hf_model_exists(model_name: str) -> bool:
"""Quick check if HuggingFace model exists without downloading"""
try:
from huggingface_hub import model_info
model_info(model_name)
return True
except Exception:
return False
def get_popular_hf_models() -> list[str]:
"""Return a list of popular HuggingFace models for suggestions"""
try:
from huggingface_hub import list_models
# Get popular text-generation models, sorted by downloads
models = list_models(
filter="text-generation",
sort="downloads",
direction=-1,
limit=20, # Get top 20 most downloaded
)
# Extract model names and filter for chat/conversation models
model_names = []
chat_keywords = ["chat", "instruct", "dialog", "conversation", "assistant"]
for model in models:
model_name = model.id if hasattr(model, "id") else str(model)
# Prioritize models with chat-related keywords
if any(keyword in model_name.lower() for keyword in chat_keywords):
model_names.append(model_name)
elif len(model_names) < 10: # Fill up with other popular models
model_names.append(model_name)
return model_names[:10] if model_names else _get_fallback_hf_models()
except Exception:
# Fallback to static list if API call fails
return _get_fallback_hf_models()
def _get_fallback_hf_models() -> list[str]:
"""Fallback list of popular HuggingFace models"""
return [
"microsoft/DialoGPT-medium",
"microsoft/DialoGPT-large",
"facebook/blenderbot-400M-distill",
"microsoft/phi-2",
"deepseek-ai/deepseek-llm-7b-chat",
"microsoft/DialoGPT-small",
"facebook/blenderbot_small-90M",
"microsoft/phi-1_5",
"facebook/opt-350m",
"EleutherAI/gpt-neo-1.3B",
]
def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
"""Use HuggingFace Hub's native fuzzy search for model suggestions"""
try:
from huggingface_hub import list_models
# HF Hub's search is already fuzzy! It handles typos and partial matches
models = list_models(
search=query,
filter="text-generation",
sort="downloads",
direction=-1,
limit=limit,
)
model_names = [model.id if hasattr(model, "id") else str(model) for model in models]
# If direct search doesn't return enough results, try some variations
if len(model_names) < 3:
# Try searching for partial matches or common variations
variations = []
# Extract base name (e.g., "gpt3" from "gpt-3.5")
base_query = query.lower().replace("-", "").replace(".", "").replace("_", "")
if base_query != query.lower():
variations.append(base_query)
# Try common model name patterns
if "gpt" in query.lower():
variations.extend(["gpt2", "gpt-neo", "gpt-j", "dialoGPT"])
elif "llama" in query.lower():
variations.extend(["llama2", "alpaca", "vicuna"])
elif "bert" in query.lower():
variations.extend(["roberta", "distilbert", "albert"])
# Search with variations
for var in variations[:2]: # Limit to 2 variations to avoid too many API calls
try:
var_models = list_models(
search=var,
filter="text-generation",
sort="downloads",
direction=-1,
limit=3,
)
var_names = [
model.id if hasattr(model, "id") else str(model) for model in var_models
]
model_names.extend(var_names)
except Exception:
continue
# Remove duplicates while preserving order
seen = set()
unique_models = []
for model in model_names:
if model not in seen:
seen.add(model)
unique_models.append(model)
return unique_models[:limit]
except Exception:
# If search fails, return empty list
return []
def search_hf_models(query: str, limit: int = 10) -> list[str]:
"""Simple search for HuggingFace models based on query (kept for backward compatibility)"""
return search_hf_models_fuzzy(query, limit)
def validate_model_and_suggest(
model_name: str, llm_type: str, host: str = "http://localhost:11434"
) -> Optional[str]:
"""Validate model name and provide suggestions if invalid"""
if llm_type == "ollama":
available_models = check_ollama_models(host)
if available_models and model_name not in available_models:
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
# Check if the model exists remotely and get available tags
model_exists_remotely, available_tags = check_ollama_model_exists_remotely(model_name)
if model_exists_remotely and model_name in available_tags:
# Exact model exists remotely - suggest pulling it
error_msg += "\n\nTo install the requested model:\n"
error_msg += f" ollama pull {model_name}\n"
# Show local alternatives
suggestions = search_ollama_models_fuzzy(model_name, available_models)
if suggestions:
error_msg += "\nOr use one of these similar installed models:\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
elif model_exists_remotely and available_tags:
# Base model exists but requested tag doesn't - suggest correct tags
base_model = model_name.split(":")[0]
requested_tag = model_name.split(":", 1)[1] if ":" in model_name else None
error_msg += (
f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
)
error_msg += f"\n\nAvailable {base_model} models you can install:\n"
for i, tag in enumerate(available_tags[:8], 1):
error_msg += f" {i}. ollama pull {tag}\n"
if len(available_tags) > 8:
error_msg += f" ... and {len(available_tags) - 8} more variants\n"
# Also show local alternatives
suggestions = search_ollama_models_fuzzy(model_name, available_models)
if suggestions:
error_msg += "\nOr use one of these similar installed models:\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
else:
# Model doesn't exist remotely - show fuzzy suggestions
suggestions = search_ollama_models_fuzzy(model_name, available_models)
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
if suggestions:
error_msg += (
"\n\nDid you mean one of these installed models?\n"
+ "\nTry to use ollama pull to install the model you need\n"
)
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
else:
error_msg += "\n\nYour installed models:\n"
for i, model in enumerate(available_models[:8], 1):
error_msg += f" {i}. {model}\n"
if len(available_models) > 8:
error_msg += f" ... and {len(available_models) - 8} more\n"
error_msg += "\n\nCommands:"
error_msg += "\n ollama list # List installed models"
if model_exists_remotely and available_tags:
if model_name in available_tags:
error_msg += f"\n ollama pull {model_name} # Install requested model"
else:
error_msg += (
f"\n ollama pull {available_tags[0]} # Install recommended variant"
)
error_msg += "\n https://ollama.com/library # Browse available models"
return error_msg
elif llm_type == "hf":
# For HF models, we can do a quick existence check
if not check_hf_model_exists(model_name):
# Use HF Hub's native fuzzy search directly
search_suggestions = search_hf_models_fuzzy(model_name, limit=8)
error_msg = f"Model '{model_name}' not found on HuggingFace Hub."
if search_suggestions:
error_msg += "\n\nDid you mean one of these?\n"
for i, suggestion in enumerate(search_suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
else:
# Fallback to popular models if search returns nothing
popular_models = get_popular_hf_models()
error_msg += "\n\nPopular chat models:\n"
for i, model in enumerate(popular_models[:5], 1):
error_msg += f" {i}. {model}\n"
error_msg += f"\nSearch more: https://huggingface.co/models?search={model_name}&pipeline_tag=text-generation"
return error_msg
return None # Model is valid or we can't check
class LLMInterface(ABC):
"""Abstract base class for a generic Language Model (LLM) interface."""
@abstractmethod
def ask(self, prompt: str, **kwargs) -> str:
"""
Additional keyword arguments (kwargs) for advanced search customization. Example usage:
chat.ask(
"What is ANN?",
top_k=10,
complexity=64,
beam_width=8,
skip_search_reorder=True,
recompute_beighbor_embeddings=True,
dedup_node_dis=True,
prune_ratio=0.1,
batch_recompute=True,
global_pruning=True
)
Supported kwargs:
- complexity (int): Search complexity parameter (default: 32)
- beam_width (int): Beam width for search (default: 4)
- skip_search_reorder (bool): Skip search reorder step (default: False)
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
- prune_ratio (float): Pruning ratio for search (default: 0.0)
- batch_recompute (bool): Enable batch recomputation (default: False)
- global_pruning (bool): Enable global pruning (default: False)
"""
# """
# Sends a prompt to the LLM and returns the generated text.
# Args:
# prompt: The input prompt for the LLM.
# **kwargs: Additional keyword arguments for the LLM backend.
# Returns:
# The response string from the LLM.
# """
pass
class OllamaChat(LLMInterface):
"""LLM interface for Ollama models."""
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
self.model = model
self.host = host
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
try:
import requests
# Check if the Ollama server is responsive
if host:
requests.get(host)
# Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model, "ollama", host)
if model_error:
raise ValueError(model_error)
except ImportError:
raise ImportError(
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
)
except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
raise ConnectionError(
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
)
def ask(self, prompt: str, **kwargs) -> str:
import json
import requests
full_url = f"{self.host}/api/generate"
# Handle thinking budget for reasoning models
options = kwargs.copy()
thinking_budget = kwargs.get("thinking_budget")
if thinking_budget:
# Remove thinking_budget from options as it's not a standard Ollama option
options.pop("thinking_budget", None)
# Only apply reasoning parameters to models that support it
reasoning_supported_models = [
"gpt-oss:20b",
"gpt-oss:120b",
"deepseek-r1",
"deepseek-coder",
]
if thinking_budget in ["low", "medium", "high"]:
if any(model in self.model.lower() for model in reasoning_supported_models):
options["reasoning"] = {"effort": thinking_budget, "exclude": False}
logger.info(f"Applied reasoning effort={thinking_budget} to model {self.model}")
else:
logger.warning(
f"Thinking budget '{thinking_budget}' requested but model '{self.model}' may not support reasoning parameters. Proceeding without reasoning."
)
payload = {
"model": self.model,
"prompt": prompt,
"stream": False, # Keep it simple for now
"options": options,
}
logger.debug(f"Sending request to Ollama: {payload}")
try:
logger.info("Sending request to Ollama and waiting for response...")
response = requests.post(full_url, data=json.dumps(payload))
response.raise_for_status()
# The response from Ollama can be a stream of JSON objects, handle this
response_parts = response.text.strip().split("\n")
full_response = ""
for part in response_parts:
if part:
json_part = json.loads(part)
full_response += json_part.get("response", "")
if json_part.get("done"):
break
return full_response
except requests.exceptions.RequestException as e:
logger.error(f"Error communicating with Ollama: {e}")
return f"Error: Could not get a response from Ollama. Details: {e}"
class HFChat(LLMInterface):
"""LLM interface for local Hugging Face Transformers models with proper chat templates."""
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
logger.info(f"Initializing HFChat with model='{model_name}'")
# Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model_name, "hf")
if model_error:
raise ValueError(model_error)
try:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError:
raise ImportError(
"The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'."
)
# Auto-detect device
if torch.cuda.is_available():
self.device = "cuda"
logger.info("CUDA is available. Using GPU.")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
self.device = "mps"
logger.info("MPS is available. Using Apple Silicon GPU.")
else:
self.device = "cpu"
logger.info("No GPU detected. Using CPU.")
# Load tokenizer and model with timeout protection
try:
import signal
def timeout_handler(signum, frame):
raise TimeoutError("Model download/loading timed out")
# Set timeout for model loading (60 seconds)
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(60)
try:
logger.info(f"Loading tokenizer for {model_name}...")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info(f"Loading model {model_name}...")
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
device_map="auto" if self.device != "cpu" else None,
trust_remote_code=True,
)
logger.info(f"Successfully loaded {model_name}")
finally:
signal.alarm(0) # Cancel the alarm
signal.signal(signal.SIGALRM, old_handler) # Restore old handler
except TimeoutError:
logger.error(f"Model loading timed out for {model_name}")
raise RuntimeError(
f"Model loading timed out for {model_name}. Please check your internet connection or try a smaller model."
)
except Exception as e:
logger.error(f"Failed to load model {model_name}: {e}")
raise
# Move model to device if not using device_map
if self.device != "cpu" and "device_map" not in str(self.model):
self.model = self.model.to(self.device)
# Set pad token if not present
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
def ask(self, prompt: str, **kwargs) -> str:
print("kwargs in HF: ", kwargs)
# Check if this is a Qwen model and add /no_think by default
is_qwen_model = "qwen" in self.model.config._name_or_path.lower()
# For Qwen models, automatically add /no_think to the prompt
if is_qwen_model and "/no_think" not in prompt and "/think" not in prompt:
prompt = prompt + " /no_think"
# Prepare chat template
messages = [{"role": "user", "content": prompt}]
# Apply chat template if available
if hasattr(self.tokenizer, "apply_chat_template"):
try:
formatted_prompt = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception as e:
logger.warning(f"Chat template failed, using raw prompt: {e}")
formatted_prompt = prompt
else:
# Fallback for models without chat template
formatted_prompt = prompt
# Tokenize input
inputs = self.tokenizer(
formatted_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=2048,
)
# Move inputs to device
if self.device != "cpu":
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Set generation parameters
generation_config = {
"max_new_tokens": kwargs.get("max_tokens", kwargs.get("max_new_tokens", 512)),
"temperature": kwargs.get("temperature", 0.7),
"top_p": kwargs.get("top_p", 0.9),
"do_sample": kwargs.get("temperature", 0.7) > 0,
"pad_token_id": self.tokenizer.eos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
}
# Handle temperature=0 for greedy decoding
if generation_config["temperature"] == 0.0:
generation_config["do_sample"] = False
generation_config.pop("temperature")
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
# Generate
with torch.no_grad():
outputs = self.model.generate(**inputs, **generation_config)
# Decode response
generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
return response.strip()
class GeminiChat(LLMInterface):
"""LLM interface for Google Gemini models."""
def __init__(self, model: str = "gemini-2.5-flash", api_key: Optional[str] = None):
self.model = model
self.api_key = api_key or os.getenv("GEMINI_API_KEY")
if not self.api_key:
raise ValueError(
"Gemini API key is required. Set GEMINI_API_KEY environment variable or pass api_key parameter."
)
logger.info(f"Initializing Gemini Chat with model='{model}'")
try:
import google.genai as genai
self.client = genai.Client(api_key=self.api_key)
except ImportError:
raise ImportError(
"The 'google-genai' library is required for Gemini models. Please install it with 'uv pip install google-genai'."
)
def ask(self, prompt: str, **kwargs) -> str:
logger.info(f"Sending request to Gemini with model {self.model}")
try:
from google.genai.types import GenerateContentConfig
generation_config = GenerateContentConfig(
temperature=kwargs.get("temperature", 0.7),
max_output_tokens=kwargs.get("max_tokens", 1000),
)
# Handle top_p parameter
if "top_p" in kwargs:
generation_config.top_p = kwargs["top_p"]
response = self.client.models.generate_content(
model=self.model,
contents=prompt,
config=generation_config,
)
# Handle potential None response text
response_text = response.text
if response_text is None:
logger.warning("Gemini returned None response text")
return ""
return response_text.strip()
except Exception as e:
logger.error(f"Error communicating with Gemini: {e}")
return f"Error: Could not get a response from Gemini. Details: {e}"
class OpenAIChat(LLMInterface):
"""LLM interface for OpenAI models."""
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
self.model = model
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
if not self.api_key:
raise ValueError(
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
)
logger.info(f"Initializing OpenAI Chat with model='{model}'")
try:
import openai
self.client = openai.OpenAI(api_key=self.api_key)
except ImportError:
raise ImportError(
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
)
def ask(self, prompt: str, **kwargs) -> str:
# Default parameters for OpenAI
params = {
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
"temperature": kwargs.get("temperature", 0.7),
}
# Handle max_tokens vs max_completion_tokens based on model
max_tokens = kwargs.get("max_tokens", 1000)
if "o3" in self.model or "o4" in self.model or "o1" in self.model:
# o-series models use max_completion_tokens
params["max_completion_tokens"] = max_tokens
params["temperature"] = 1.0
else:
# Other models use max_tokens
params["max_tokens"] = max_tokens
# Handle thinking budget for reasoning models
thinking_budget = kwargs.get("thinking_budget")
if thinking_budget and thinking_budget in ["low", "medium", "high"]:
# Check if this is an o-series model (partial match for model names)
o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
if any(model in self.model for model in o_series_models):
# Use the correct OpenAI reasoning parameter format
params["reasoning_effort"] = thinking_budget
logger.info(f"Applied reasoning_effort={thinking_budget} to model {self.model}")
else:
logger.warning(
f"Thinking budget '{thinking_budget}' requested but model '{self.model}' may not support reasoning parameters. Proceeding without reasoning."
)
# Add other kwargs (excluding thinking_budget as it's handled above)
for k, v in kwargs.items():
if k not in ["max_tokens", "temperature", "thinking_budget"]:
params[k] = v
logger.info(f"Sending request to OpenAI with model {self.model}")
try:
response = self.client.chat.completions.create(**params)
return response.choices[0].message.content.strip()
except Exception as e:
logger.error(f"Error communicating with OpenAI: {e}")
return f"Error: Could not get a response from OpenAI. Details: {e}"
class SimulatedChat(LLMInterface):
"""A simple simulated chat for testing and development."""
def ask(self, prompt: str, **kwargs) -> str:
logger.info("Simulating LLM call...")
print("Prompt sent to LLM (simulation):", prompt[:500] + "...")
return "This is a simulated answer from the LLM based on the retrieved context."
def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
"""
Factory function to get an LLM interface based on configuration.
Args:
llm_config: A dictionary specifying the LLM type and its parameters.
Example: {"type": "ollama", "model": "llama3"}
{"type": "hf", "model": "distilgpt2"}
None (for simulation mode)
Returns:
An instance of an LLMInterface subclass.
"""
if llm_config is None:
llm_config = {
"type": "openai",
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
}
llm_type = llm_config.get("type", "openai")
model = llm_config.get("model")
logger.info(f"Attempting to create LLM of type='{llm_type}' with model='{model}'")
if llm_type == "ollama":
return OllamaChat(
model=model or "llama3:8b",
host=llm_config.get("host", "http://localhost:11434"),
)
elif llm_type == "hf":
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
elif llm_type == "openai":
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
elif llm_type == "gemini":
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
elif llm_type == "simulated":
return SimulatedChat()
else:
raise ValueError(f"Unknown LLM type: '{llm_type}'")

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,869 @@
"""
Unified embedding computation module
Consolidates all embedding computation logic using SentenceTransformer
Preserves all optimization parameters to ensure performance
"""
import logging
import os
import time
from typing import Any
import numpy as np
import torch
# Set up logger with proper level
logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Global model cache to avoid repeated loading
_model_cache: dict[str, Any] = {}
def compute_embeddings(
texts: list[str],
model_name: str,
mode: str = "sentence-transformers",
is_build: bool = False,
batch_size: int = 32,
adaptive_optimization: bool = True,
manual_tokenize: bool = False,
max_length: int = 512,
) -> np.ndarray:
"""
Unified embedding computation entry point
Args:
texts: List of texts to compute embeddings for
model_name: Model name
mode: Computation mode ('sentence-transformers', 'openai', 'mlx', 'ollama')
is_build: Whether this is a build operation (shows progress bar)
batch_size: Batch size for processing
adaptive_optimization: Whether to use adaptive optimization based on batch size
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
"""
if mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(
texts,
model_name,
is_build=is_build,
batch_size=batch_size,
adaptive_optimization=adaptive_optimization,
manual_tokenize=manual_tokenize,
max_length=max_length,
)
elif mode == "openai":
return compute_embeddings_openai(texts, model_name)
elif mode == "mlx":
return compute_embeddings_mlx(texts, model_name)
elif mode == "ollama":
return compute_embeddings_ollama(texts, model_name, is_build=is_build)
elif mode == "gemini":
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
else:
raise ValueError(f"Unsupported embedding mode: {mode}")
def compute_embeddings_sentence_transformers(
texts: list[str],
model_name: str,
use_fp16: bool = True,
device: str = "auto",
batch_size: int = 32,
is_build: bool = False,
adaptive_optimization: bool = True,
manual_tokenize: bool = False,
max_length: int = 512,
) -> np.ndarray:
"""
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
Args:
texts: List of texts to compute embeddings for
model_name: Model name
use_fp16: Whether to use FP16 precision
device: Device to use ('auto', 'cuda', 'mps', 'cpu')
batch_size: Batch size for processing
is_build: Whether this is a build operation (shows progress bar)
adaptive_optimization: Whether to use adaptive optimization based on batch size
"""
# Handle empty input
if not texts:
raise ValueError("Cannot compute embeddings for empty text list")
logger.info(
f"Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
)
# Auto-detect device
if device == "auto":
if torch.cuda.is_available():
device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
# Apply optimizations based on benchmark results
if adaptive_optimization:
# Use optimal batch_size constants for different devices based on benchmark results
if device == "mps":
batch_size = 128 # MPS optimal batch size from benchmark
if model_name == "Qwen/Qwen3-Embedding-0.6B":
batch_size = 32
elif device == "cuda":
batch_size = 256 # CUDA optimal batch size
# Keep original batch_size for CPU
# Create cache key
cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"
# Check if model is already cached
if cache_key in _model_cache:
logger.info(f"Using cached optimized model: {model_name}")
model = _model_cache[cache_key]
else:
logger.info(f"Loading and caching optimized SentenceTransformer model: {model_name}")
from sentence_transformers import SentenceTransformer
logger.info(f"Using device: {device}")
# Apply hardware optimizations
if device == "cuda":
# TODO: Haven't tested this yet
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.cuda.set_per_process_memory_fraction(0.9)
elif device == "mps":
try:
if hasattr(torch.mps, "set_per_process_memory_fraction"):
torch.mps.set_per_process_memory_fraction(0.9)
except AttributeError:
logger.warning("Some MPS optimizations not available in this PyTorch version")
elif device == "cpu":
# TODO: Haven't tested this yet
torch.set_num_threads(min(8, os.cpu_count() or 4))
try:
torch.backends.mkldnn.enabled = True
except AttributeError:
pass
# Prepare optimized model and tokenizer parameters
model_kwargs = {
"torch_dtype": torch.float16 if use_fp16 else torch.float32,
"low_cpu_mem_usage": True,
"_fast_init": True,
"attn_implementation": "eager", # Use eager attention for speed
}
tokenizer_kwargs = {
"use_fast": True,
"padding": True,
"truncation": True,
}
try:
# Try local loading first
model_kwargs["local_files_only"] = True
tokenizer_kwargs["local_files_only"] = True
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=True,
)
logger.info("Model loaded successfully! (local + optimized)")
except Exception as e:
logger.warning(f"Local loading failed ({e}), trying network download...")
# Fallback to network loading
model_kwargs["local_files_only"] = False
tokenizer_kwargs["local_files_only"] = False
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=False,
)
logger.info("Model loaded successfully! (network + optimized)")
# Apply additional optimizations based on mode
if use_fp16 and device in ["cuda", "mps"]:
try:
model = model.half()
logger.info(f"Applied FP16 precision: {model_name}")
except Exception as e:
logger.warning(f"FP16 optimization failed: {e}")
# Apply torch.compile optimization
if device in ["cuda", "mps"]:
try:
model = torch.compile(model, mode="reduce-overhead", dynamic=True)
logger.info(f"Applied torch.compile optimization: {model_name}")
except Exception as e:
logger.warning(f"torch.compile optimization failed: {e}")
# Set model to eval mode and disable gradients for inference
model.eval()
for param in model.parameters():
param.requires_grad_(False)
# Cache the model
_model_cache[cache_key] = model
logger.info(f"Model cached: {cache_key}")
# Compute embeddings with optimized inference mode
logger.info(
f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
)
start_time = time.time()
if not manual_tokenize:
# Use SentenceTransformer's optimized encode path (default)
with torch.inference_mode():
embeddings = model.encode(
texts,
batch_size=batch_size,
show_progress_bar=is_build, # Don't show progress bar in server environment
convert_to_numpy=True,
normalize_embeddings=False,
device=device,
)
# Synchronize if CUDA to measure accurate wall time
try:
if torch.cuda.is_available():
torch.cuda.synchronize()
except Exception:
pass
else:
# Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
try:
from transformers import AutoModel, AutoTokenizer # type: ignore
except Exception as e:
raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
# Cache tokenizer and model
tok_cache_key = f"hf_tokenizer_{model_name}"
mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
hf_tokenizer = _model_cache[tok_cache_key]
hf_model = _model_cache[mdl_cache_key]
logger.info("Using cached HF tokenizer/model for manual path")
else:
logger.info("Loading HF tokenizer/model for manual tokenization path")
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
hf_model.to(device)
hf_model.eval()
# Optional compile on supported devices
if device in ["cuda", "mps"]:
try:
hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True) # type: ignore
except Exception:
pass
_model_cache[tok_cache_key] = hf_tokenizer
_model_cache[mdl_cache_key] = hf_model
all_embeddings: list[np.ndarray] = []
# Progress bar when building or for large inputs
show_progress = is_build or len(texts) > 32
try:
if show_progress:
from tqdm import tqdm # type: ignore
batch_iter = tqdm(
range(0, len(texts), batch_size),
desc="Embedding (manual)",
unit="batch",
)
else:
batch_iter = range(0, len(texts), batch_size)
except Exception:
batch_iter = range(0, len(texts), batch_size)
start_time_manual = time.time()
with torch.inference_mode():
for start_index in batch_iter:
end_index = min(start_index + batch_size, len(texts))
batch_texts = texts[start_index:end_index]
tokenize_start_time = time.time()
inputs = hf_tokenizer(
batch_texts,
padding=True,
truncation=True,
max_length=max_length,
return_tensors="pt",
)
tokenize_end_time = time.time()
logger.info(
f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
)
# Print shapes of all input tensors for debugging
for k, v in inputs.items():
print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
to_device_start_time = time.time()
inputs = {k: v.to(device) for k, v in inputs.items()}
to_device_end_time = time.time()
logger.info(
f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
)
forward_start_time = time.time()
outputs = hf_model(**inputs)
forward_end_time = time.time()
logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
last_hidden_state = outputs.last_hidden_state # (B, L, H)
attention_mask = inputs.get("attention_mask")
if attention_mask is None:
# Fallback: assume all tokens are valid
pooled = last_hidden_state.mean(dim=1)
else:
mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
masked = last_hidden_state * mask
lengths = mask.sum(dim=1).clamp(min=1)
pooled = masked.sum(dim=1) / lengths
# Move to CPU float32
batch_embeddings = pooled.detach().to("cpu").float().numpy()
all_embeddings.append(batch_embeddings)
embeddings = np.vstack(all_embeddings).astype(np.float32, copy=False)
try:
if torch.cuda.is_available():
torch.cuda.synchronize()
except Exception:
pass
end_time = time.time()
logger.info(f"Manual tokenize time taken: {end_time - start_time_manual} seconds")
end_time = time.time()
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
logger.info(f"Time taken: {end_time - start_time} seconds")
# Validate results
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}")
return embeddings
def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Compute embeddings using OpenAI API"""
try:
import os
import openai
except ImportError as e:
raise ImportError(f"OpenAI package not installed: {e}")
# Validate input list
if not texts:
raise ValueError("Cannot compute embeddings for empty text list")
# Extra validation: abort early if any item is empty/whitespace
invalid_count = sum(1 for t in texts if not isinstance(t, str) or not t.strip())
if invalid_count > 0:
raise ValueError(
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
)
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
# Cache OpenAI client
cache_key = "openai_client"
if cache_key in _model_cache:
client = _model_cache[cache_key]
else:
client = openai.OpenAI(api_key=api_key)
_model_cache[cache_key] = client
logger.info("OpenAI client cached")
logger.info(
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
)
print(f"len of texts: {len(texts)}")
# OpenAI has limits on batch size and input length
max_batch_size = 800 # Conservative batch size because the token limit is 300K
all_embeddings = []
# get the avg len of texts
avg_len = sum(len(text) for text in texts) / len(texts)
print(f"avg len of texts: {avg_len}")
# if avg len is less than 1000, use the max batch size
if avg_len > 300:
max_batch_size = 500
# if avg len is less than 1000, use the max batch size
try:
from tqdm import tqdm
total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
batch_range = range(0, len(texts), max_batch_size)
batch_iterator = tqdm(
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
)
except ImportError:
# Fallback when tqdm is not available
batch_iterator = range(0, len(texts), max_batch_size)
for i in batch_iterator:
batch_texts = texts[i : i + max_batch_size]
try:
response = client.embeddings.create(model=model_name, input=batch_texts)
batch_embeddings = [embedding.embedding for embedding in response.data]
all_embeddings.extend(batch_embeddings)
except Exception as e:
logger.error(f"Batch {i} failed: {e}")
raise
embeddings = np.array(all_embeddings, dtype=np.float32)
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
print(f"len of embeddings: {len(embeddings)}")
return embeddings
def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int = 16) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Computes embeddings using an MLX model."""
try:
import mlx.core as mx
from mlx_lm.utils import load
except ImportError as e:
raise RuntimeError(
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
) from e
logger.info(
f"Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
)
# Cache MLX model and tokenizer
cache_key = f"mlx_{model_name}"
if cache_key in _model_cache:
logger.info(f"Using cached MLX model: {model_name}")
model, tokenizer = _model_cache[cache_key]
else:
logger.info(f"Loading and caching MLX model: {model_name}")
model, tokenizer = load(model_name)
_model_cache[cache_key] = (model, tokenizer)
logger.info(f"MLX model cached: {cache_key}")
# Process chunks in batches with progress bar
all_embeddings = []
try:
from tqdm import tqdm
batch_iterator = tqdm(
range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch"
)
except ImportError:
batch_iterator = range(0, len(chunks), batch_size)
for i in batch_iterator:
batch_chunks = chunks[i : i + batch_size]
# Tokenize all chunks in the batch
batch_token_ids = []
for chunk in batch_chunks:
token_ids = tokenizer.encode(chunk) # type: ignore
batch_token_ids.append(token_ids)
# Pad sequences to the same length for batch processing
max_length = max(len(ids) for ids in batch_token_ids)
padded_token_ids = []
for token_ids in batch_token_ids:
# Pad with tokenizer.pad_token_id or 0
padded = token_ids + [0] * (max_length - len(token_ids))
padded_token_ids.append(padded)
# Convert to MLX array with batch dimension
input_ids = mx.array(padded_token_ids)
# Get embeddings for the batch
embeddings = model(input_ids)
# Mean pooling for each sequence in the batch
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
# Convert batch embeddings to numpy
for j in range(len(batch_chunks)):
pooled_list = pooled[j].tolist() # Convert to list
pooled_numpy = np.array(pooled_list, dtype=np.float32)
all_embeddings.append(pooled_numpy)
# Stack numpy arrays
return np.stack(all_embeddings)
def compute_embeddings_ollama(
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
) -> np.ndarray:
"""
Compute embeddings using Ollama API with simplified batch processing.
Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
Args:
texts: List of texts to compute embeddings for
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
is_build: Whether this is a build operation (shows progress bar)
host: Ollama host URL (default: http://localhost:11434)
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
"""
try:
import requests
except ImportError:
raise ImportError(
"The 'requests' library is required for Ollama embeddings. Install with: uv pip install requests"
)
if not texts:
raise ValueError("Cannot compute embeddings for empty text list")
logger.info(
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
)
# Check if Ollama is running
try:
response = requests.get(f"{host}/api/version", timeout=5)
response.raise_for_status()
except requests.exceptions.ConnectionError:
error_msg = (
f"❌ Could not connect to Ollama at {host}.\n\n"
"Please ensure Ollama is running:\n"
" • macOS/Linux: ollama serve\n"
" • Windows: Make sure Ollama is running in the system tray\n\n"
"Installation: https://ollama.com/download"
)
raise RuntimeError(error_msg)
except Exception as e:
raise RuntimeError(f"Unexpected error connecting to Ollama: {e}")
# Check if model exists and provide helpful suggestions
try:
response = requests.get(f"{host}/api/tags", timeout=5)
response.raise_for_status()
models = response.json()
model_names = [model["name"] for model in models.get("models", [])]
# Filter for embedding models (models that support embeddings)
embedding_models = []
suggested_embedding_models = [
"nomic-embed-text",
"mxbai-embed-large",
"bge-m3",
"all-minilm",
"snowflake-arctic-embed",
]
for model in model_names:
# Check if it's an embedding model (by name patterns or known models)
base_name = model.split(":")[0]
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5"]):
embedding_models.append(model)
# Check if model exists (handle versioned names) and resolve to full name
resolved_model_name = None
for name in model_names:
# Exact match
if model_name == name:
resolved_model_name = name
break
# Match without version tag (use the versioned name)
elif model_name == name.split(":")[0]:
resolved_model_name = name
break
if not resolved_model_name:
error_msg = f"❌ Model '{model_name}' not found in local Ollama.\n\n"
# Suggest pulling the model
error_msg += "📦 To install this embedding model:\n"
error_msg += f" ollama pull {model_name}\n\n"
# Show available embedding models
if embedding_models:
error_msg += "✅ Available embedding models:\n"
for model in embedding_models[:5]:
error_msg += f"{model}\n"
if len(embedding_models) > 5:
error_msg += f" ... and {len(embedding_models) - 5} more\n"
else:
error_msg += "💡 Popular embedding models to install:\n"
for model in suggested_embedding_models[:3]:
error_msg += f" • ollama pull {model}\n"
error_msg += "\n📚 Browse more: https://ollama.com/library"
raise ValueError(error_msg)
# Use the resolved model name for all subsequent operations
if resolved_model_name != model_name:
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
model_name = resolved_model_name
# Verify the model supports embeddings by testing it
try:
test_response = requests.post(
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
)
if test_response.status_code != 200:
error_msg = (
f"⚠️ Model '{model_name}' exists but may not support embeddings.\n\n"
f"Please use an embedding model like:\n"
)
for model in suggested_embedding_models[:3]:
error_msg += f"{model}\n"
raise ValueError(error_msg)
except requests.exceptions.RequestException:
# If test fails, continue anyway - model might still work
pass
except requests.exceptions.RequestException as e:
logger.warning(f"Could not verify model existence: {e}")
# Determine batch size based on device availability
# Check for CUDA/MPS availability using torch if available
batch_size = 32 # Default for MPS/CPU
try:
import torch
if torch.cuda.is_available():
batch_size = 128 # CUDA gets larger batch size
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
batch_size = 32 # MPS gets smaller batch size
except ImportError:
# If torch is not available, use conservative batch size
batch_size = 32
logger.info(f"Using batch size: {batch_size}")
def get_batch_embeddings(batch_texts):
"""Get embeddings for a batch of texts."""
all_embeddings = []
failed_indices = []
for i, text in enumerate(batch_texts):
max_retries = 3
retry_count = 0
# Truncate very long texts to avoid API issues
truncated_text = text[:8000] if len(text) > 8000 else text
while retry_count < max_retries:
try:
response = requests.post(
f"{host}/api/embeddings",
json={"model": model_name, "prompt": truncated_text},
timeout=30,
)
response.raise_for_status()
result = response.json()
embedding = result.get("embedding")
if embedding is None:
raise ValueError(f"No embedding returned for text {i}")
if not isinstance(embedding, list) or len(embedding) == 0:
raise ValueError(f"Invalid embedding format for text {i}")
all_embeddings.append(embedding)
break
except requests.exceptions.Timeout:
retry_count += 1
if retry_count >= max_retries:
logger.warning(f"Timeout for text {i} after {max_retries} retries")
failed_indices.append(i)
all_embeddings.append(None)
break
except Exception as e:
retry_count += 1
if retry_count >= max_retries:
logger.error(f"Failed to get embedding for text {i}: {e}")
failed_indices.append(i)
all_embeddings.append(None)
break
return all_embeddings, failed_indices
# Process texts in batches
all_embeddings = []
all_failed_indices = []
# Setup progress bar if needed
show_progress = is_build or len(texts) > 10
try:
if show_progress:
from tqdm import tqdm
except ImportError:
show_progress = False
# Process batches
num_batches = (len(texts) + batch_size - 1) // batch_size
if show_progress:
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
else:
batch_iterator = range(num_batches)
for batch_idx in batch_iterator:
start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, len(texts))
batch_texts = texts[start_idx:end_idx]
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
# Adjust failed indices to global indices
global_failed = [start_idx + idx for idx in batch_failed]
all_failed_indices.extend(global_failed)
all_embeddings.extend(batch_embeddings)
# Handle failed embeddings
if all_failed_indices:
if len(all_failed_indices) == len(texts):
raise RuntimeError("Failed to compute any embeddings")
logger.warning(
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts"
)
# Use zero embeddings as fallback for failed ones
valid_embedding = next((e for e in all_embeddings if e is not None), None)
if valid_embedding:
embedding_dim = len(valid_embedding)
for i, embedding in enumerate(all_embeddings):
if embedding is None:
all_embeddings[i] = [0.0] * embedding_dim
# Remove None values
all_embeddings = [e for e in all_embeddings if e is not None]
if not all_embeddings:
raise RuntimeError("No valid embeddings were computed")
# Validate embedding dimensions
expected_dim = len(all_embeddings[0])
inconsistent_dims = []
for i, embedding in enumerate(all_embeddings):
if len(embedding) != expected_dim:
inconsistent_dims.append((i, len(embedding)))
if inconsistent_dims:
error_msg = f"Ollama returned inconsistent embedding dimensions. Expected {expected_dim}, but got:\n"
for idx, dim in inconsistent_dims[:10]: # Show first 10 inconsistent ones
error_msg += f" - Text {idx}: {dim} dimensions\n"
if len(inconsistent_dims) > 10:
error_msg += f" ... and {len(inconsistent_dims) - 10} more\n"
error_msg += f"\nThis is likely an Ollama API bug with model '{model_name}'. Please try:\n"
error_msg += "1. Restart Ollama service: 'ollama serve'\n"
error_msg += f"2. Re-pull the model: 'ollama pull {model_name}'\n"
error_msg += (
"3. Use sentence-transformers instead: --embedding-mode sentence-transformers\n"
)
error_msg += "4. Report this issue to Ollama: https://github.com/ollama/ollama/issues"
raise ValueError(error_msg)
# Convert to numpy array and normalize
embeddings = np.array(all_embeddings, dtype=np.float32)
# Normalize embeddings (L2 normalization)
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
embeddings = embeddings / (norms + 1e-8) # Add small epsilon to avoid division by zero
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
return embeddings
def compute_embeddings_gemini(
texts: list[str], model_name: str = "text-embedding-004", is_build: bool = False
) -> np.ndarray:
"""
Compute embeddings using Google Gemini API.
Args:
texts: List of texts to compute embeddings for
model_name: Gemini model name (default: "text-embedding-004")
is_build: Whether this is a build operation (shows progress bar)
Returns:
Embeddings array, shape: (len(texts), embedding_dim)
"""
try:
import os
import google.genai as genai
except ImportError as e:
raise ImportError(f"Google GenAI package not installed: {e}")
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
raise RuntimeError("GEMINI_API_KEY environment variable not set")
# Cache Gemini client
cache_key = "gemini_client"
if cache_key in _model_cache:
client = _model_cache[cache_key]
else:
client = genai.Client(api_key=api_key)
_model_cache[cache_key] = client
logger.info("Gemini client cached")
logger.info(
f"Computing embeddings for {len(texts)} texts using Gemini API, model: '{model_name}'"
)
# Gemini supports batch embedding
max_batch_size = 100 # Conservative batch size for Gemini
all_embeddings = []
try:
from tqdm import tqdm
total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
batch_range = range(0, len(texts), max_batch_size)
batch_iterator = tqdm(
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
)
except ImportError:
# Fallback when tqdm is not available
batch_iterator = range(0, len(texts), max_batch_size)
for i in batch_iterator:
batch_texts = texts[i : i + max_batch_size]
try:
# Use the embed_content method from the new Google GenAI SDK
response = client.models.embed_content(
model=model_name,
contents=batch_texts,
config=genai.types.EmbedContentConfig(
task_type="RETRIEVAL_DOCUMENT" # For document embedding
),
)
# Extract embeddings from response
for embedding_data in response.embeddings:
all_embeddings.append(embedding_data.values)
except Exception as e:
logger.error(f"Batch {i} failed: {e}")
raise
embeddings = np.array(all_embeddings, dtype=np.float32)
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
return embeddings

View File

@@ -0,0 +1,371 @@
import atexit
import logging
import os
import socket
import subprocess
import sys
import time
from pathlib import Path
from typing import Optional
# Lightweight, self-contained server manager with no cross-process inspection
# Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
logging.basicConfig(
level=getattr(logging, LOG_LEVEL, logging.INFO),
format="%(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)
def _is_colab_environment() -> bool:
"""Check if we're running in Google Colab environment."""
return "COLAB_GPU" in os.environ or "COLAB_TPU" in os.environ
def _get_available_port(start_port: int = 5557) -> int:
"""Get an available port starting from start_port."""
port = start_port
while port < start_port + 100: # Try up to 100 ports
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("localhost", port))
return port
except OSError:
port += 1
raise RuntimeError(f"No available ports found in range {start_port}-{start_port + 100}")
def _check_port(port: int) -> bool:
"""Check if a port is in use"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("localhost", port)) == 0
# Note: All cross-process scanning helpers removed for simplicity
class EmbeddingServerManager:
"""
A simplified manager for embedding server processes that avoids complex update mechanisms.
"""
def __init__(self, backend_module_name: str):
"""
Initializes the manager for a specific backend.
Args:
backend_module_name (str): The full module name of the backend's server script.
e.g., "leann_backend_diskann.embedding_server"
"""
self.backend_module_name = backend_module_name
self.server_process: Optional[subprocess.Popen] = None
self.server_port: Optional[int] = None
# Track last-started config for in-process reuse only
self._server_config: Optional[dict] = None
self._atexit_registered = False
# Also register a weakref finalizer to ensure cleanup when manager is GC'ed
try:
import weakref
self._finalizer = weakref.finalize(self, self._finalize_process)
except Exception:
self._finalizer = None
def start_server(
self,
port: int,
model_name: str,
embedding_mode: str = "sentence-transformers",
**kwargs,
) -> tuple[bool, int]:
"""Start the embedding server."""
# passages_file may be present in kwargs for server CLI, but we don't need it here
# If this manager already has a live server, just reuse it
if self.server_process and self.server_process.poll() is None and self.server_port:
logger.info("Reusing in-process server")
return True, self.server_port
# For Colab environment, use a different strategy
if _is_colab_environment():
logger.info("Detected Colab environment, using alternative startup strategy")
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
# Always pick a fresh available port
try:
actual_port = _get_available_port(port)
except RuntimeError:
logger.error("No available ports found")
return False, port
# Start a new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
def _start_server_colab(
self,
port: int,
model_name: str,
embedding_mode: str = "sentence-transformers",
**kwargs,
) -> tuple[bool, int]:
"""Start server with Colab-specific configuration."""
# Try to find an available port
try:
actual_port = _get_available_port(port)
except RuntimeError:
logger.error("No available ports found")
return False, port
logger.info(f"Starting server on port {actual_port} for Colab environment")
# Use a simpler startup strategy for Colab
command = self._build_server_command(actual_port, model_name, embedding_mode, **kwargs)
try:
# In Colab, we'll use a more direct approach
self._launch_server_process_colab(command, actual_port)
return self._wait_for_server_ready_colab(actual_port)
except Exception as e:
logger.error(f"Failed to start embedding server in Colab: {e}")
return False, actual_port
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
def _start_new_server(
self, port: int, model_name: str, embedding_mode: str, **kwargs
) -> tuple[bool, int]:
"""Start a new embedding server on the given port."""
logger.info(f"Starting embedding server on port {port}...")
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
try:
self._launch_server_process(command, port)
return self._wait_for_server_ready(port)
except Exception as e:
logger.error(f"Failed to start embedding server: {e}")
return False, port
def _build_server_command(
self, port: int, model_name: str, embedding_mode: str, **kwargs
) -> list:
"""Build the command to start the embedding server."""
command = [
sys.executable,
"-m",
self.backend_module_name,
"--zmq-port",
str(port),
"--model-name",
model_name,
]
if kwargs.get("passages_file"):
# Convert to absolute path to ensure subprocess can find the file
passages_file = Path(kwargs["passages_file"]).resolve()
command.extend(["--passages-file", str(passages_file)])
if embedding_mode != "sentence-transformers":
command.extend(["--embedding-mode", embedding_mode])
if kwargs.get("distance_metric"):
command.extend(["--distance-metric", kwargs["distance_metric"]])
return command
def _launch_server_process(self, command: list, port: int) -> None:
"""Launch the server process."""
project_root = Path(__file__).parent.parent.parent.parent.parent
logger.info(f"Command: {' '.join(command)}")
# In CI environment, redirect stdout to avoid buffer deadlock but keep stderr for debugging
# Embedding servers use many print statements that can fill stdout buffers
is_ci = os.environ.get("CI") == "true"
if is_ci:
stdout_target = subprocess.DEVNULL
stderr_target = None # Keep stderr for error debugging in CI
logger.info(
"CI environment detected, redirecting embedding server stdout to DEVNULL, keeping stderr"
)
else:
stdout_target = None # Direct to console for visible logs
stderr_target = None # Direct to console for visible logs
# Start embedding server subprocess
logger.info(f"Starting server process with command: {' '.join(command)}")
self.server_process = subprocess.Popen(
command,
cwd=project_root,
stdout=stdout_target,
stderr=stderr_target,
)
self.server_port = port
# Record config for in-process reuse
try:
self._server_config = {
"model_name": command[command.index("--model-name") + 1]
if "--model-name" in command
else "",
"passages_file": command[command.index("--passages-file") + 1]
if "--passages-file" in command
else "",
"embedding_mode": command[command.index("--embedding-mode") + 1]
if "--embedding-mode" in command
else "sentence-transformers",
}
except Exception:
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
}
logger.info(f"Server process started with PID: {self.server_process.pid}")
# Register atexit callback only when we actually start a process
if not self._atexit_registered:
# Always attempt best-effort finalize at interpreter exit
atexit.register(self._finalize_process)
self._atexit_registered = True
# Touch finalizer so it knows there is a live process
if getattr(self, "_finalizer", None) is not None and not self._finalizer.alive:
try:
import weakref
self._finalizer = weakref.finalize(self, self._finalize_process)
except Exception:
pass
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready."""
max_wait, wait_interval = 120, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
logger.info("Embedding server is ready!")
return True, port
if self.server_process and self.server_process.poll() is not None:
logger.error("Server terminated during startup.")
return False, port
time.sleep(wait_interval)
logger.error(f"Server failed to start within {max_wait} seconds.")
self.stop_server()
return False, port
def stop_server(self):
"""Stops the embedding server process if it's running."""
if not self.server_process:
return
if self.server_process and self.server_process.poll() is not None:
# Process already terminated
self.server_process = None
self.server_port = None
self._server_config = None
return
logger.info(
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
)
# Use simple termination first; if the server installed signal handlers,
# it will exit cleanly. Otherwise escalate to kill after a short wait.
try:
self.server_process.terminate()
except Exception:
pass
try:
self.server_process.wait(timeout=5) # Give more time for graceful shutdown
logger.info(f"Server process {self.server_process.pid} terminated gracefully.")
except subprocess.TimeoutExpired:
logger.warning(
f"Server process {self.server_process.pid} did not terminate within 5 seconds, force killing..."
)
try:
self.server_process.kill()
except Exception:
pass
try:
self.server_process.wait(timeout=2)
logger.info(f"Server process {self.server_process.pid} killed successfully.")
except subprocess.TimeoutExpired:
logger.error(
f"Failed to kill server process {self.server_process.pid} - it may be hung"
)
# Clean up process resources with timeout to avoid CI hang
try:
# Use shorter timeout in CI environments
is_ci = os.environ.get("CI") == "true"
timeout = 3 if is_ci else 10
self.server_process.wait(timeout=timeout)
logger.info(f"Server process {self.server_process.pid} cleanup completed")
except subprocess.TimeoutExpired:
logger.warning(f"Process cleanup timeout after {timeout}s, proceeding anyway")
except Exception as e:
logger.warning(f"Error during process cleanup: {e}")
finally:
self.server_process = None
self.server_port = None
self._server_config = None
def _finalize_process(self) -> None:
"""Best-effort cleanup used by weakref.finalize/atexit."""
try:
self.stop_server()
except Exception:
pass
def _adopt_existing_server(self, *args, **kwargs) -> None:
# Removed: cross-process adoption no longer supported
return
def _launch_server_process_colab(self, command: list, port: int) -> None:
"""Launch the server process with Colab-specific settings."""
logger.info(f"Colab Command: {' '.join(command)}")
# In Colab, we need to be more careful about process management
self.server_process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
self.server_port = port
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
# Register atexit callback (unified)
if not self._atexit_registered:
atexit.register(self._finalize_process)
self._atexit_registered = True
# Record config for in-process reuse is best-effort in Colab mode
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
}
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready with Colab-specific timeout."""
max_wait, wait_interval = 30, 0.5 # Shorter timeout for Colab
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
logger.info("Colab embedding server is ready!")
return True, port
if self.server_process and self.server_process.poll() is not None:
# Check for error output
stdout, stderr = self.server_process.communicate()
logger.error("Colab server terminated during startup.")
logger.error(f"stdout: {stdout}")
logger.error(f"stderr: {stderr}")
return False, port
time.sleep(wait_interval)
logger.error(f"Colab server failed to start within {max_wait} seconds.")
self.stop_server()
return False, port

View File

@@ -1,59 +1,107 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Literal, Optional
import numpy as np import numpy as np
from typing import Dict, Any
class LeannBackendBuilderInterface(ABC): class LeannBackendBuilderInterface(ABC):
"""用于构建索引的后端接口""" """Backend interface for building indexes"""
@abstractmethod @abstractmethod
def build(self, data: np.ndarray, index_path: str, **kwargs) -> None: def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs) -> None:
"""构建索引 """Build index
Args: Args:
data: 向量数据 (N, D) data: Vector data (N, D)
index_path: 索引保存路径 ids: List of string IDs for each vector
**kwargs: 后端特定的构建参数 index_path: Path to save index
**kwargs: Backend-specific build parameters
""" """
pass pass
class LeannBackendSearcherInterface(ABC): class LeannBackendSearcherInterface(ABC):
"""用于搜索的后端接口""" """Backend interface for searching"""
@abstractmethod @abstractmethod
def __init__(self, index_path: str, **kwargs): def __init__(self, index_path: str, **kwargs):
"""初始化搜索器 """Initialize searcher
Args: Args:
index_path: 索引文件路径 index_path: Path to index file
**kwargs: 后端特定的加载参数 **kwargs: Backend-specific loading parameters
""" """
pass pass
@abstractmethod @abstractmethod
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]: def _ensure_server_running(
"""搜索最近邻 self, passages_source_file: str, port: Optional[int], **kwargs
) -> int:
"""Ensure server is running"""
pass
@abstractmethod
def search(
self,
query: np.ndarray,
top_k: int,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: Optional[int] = None,
**kwargs,
) -> dict[str, Any]:
"""Search for nearest neighbors
Args: Args:
query: 查询向量 (1, D) 或 (B, D) query: Query vectors (B, D) where B is batch size, D is dimension
top_k: 返回的最近邻数量 top_k: Number of nearest neighbors to return
**kwargs: 搜索参数 complexity: Search complexity/candidate list size, higher = more accurate but slower
beam_width: Number of parallel search paths/IO requests per iteration
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
**kwargs: Backend-specific parameters
Returns: Returns:
{"labels": [...], "distances": [...]} {"labels": [...], "distances": [...]}
""" """
pass pass
@abstractmethod
def compute_query_embedding(
self,
query: str,
use_server_if_available: bool = True,
zmq_port: Optional[int] = None,
) -> np.ndarray:
"""Compute embedding for a query string
Args:
query: The query string to embed
zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first
Returns:
Query embedding as numpy array with shape (1, D)
"""
pass
class LeannBackendFactoryInterface(ABC): class LeannBackendFactoryInterface(ABC):
"""后端工厂接口""" """Backend factory interface"""
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def builder(**kwargs) -> LeannBackendBuilderInterface: def builder(**kwargs) -> LeannBackendBuilderInterface:
"""创建 Builder 实例""" """Create Builder instance"""
pass pass
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface: def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
"""创建 Searcher 实例""" """Create Searcher instance"""
pass pass

View File

@@ -0,0 +1,154 @@
#!/usr/bin/env python3
import json
import subprocess
import sys
def handle_request(request):
if request.get("method") == "initialize":
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"result": {
"capabilities": {"tools": {}},
"protocolVersion": "2024-11-05",
"serverInfo": {"name": "leann-mcp", "version": "1.0.0"},
},
}
elif request.get("method") == "tools/list":
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"result": {
"tools": [
{
"name": "leann_search",
"description": """🔍 Search code using natural language - like having a coding assistant who knows your entire codebase!
🎯 **Perfect for**:
- "How does authentication work?" → finds auth-related code
- "Error handling patterns" → locates try-catch blocks and error logic
- "Database connection setup" → finds DB initialization code
- "API endpoint definitions" → locates route handlers
- "Configuration management" → finds config files and usage
💡 **Pro tip**: Use this before making any changes to understand existing patterns and conventions.""",
"inputSchema": {
"type": "object",
"properties": {
"index_name": {
"type": "string",
"description": "Name of the LEANN index to search. Use 'leann_list' first to see available indexes.",
},
"query": {
"type": "string",
"description": "Search query - can be natural language (e.g., 'how to handle errors') or technical terms (e.g., 'async function definition')",
},
"top_k": {
"type": "integer",
"default": 5,
"minimum": 1,
"maximum": 20,
"description": "Number of search results to return. Use 5-10 for focused results, 15-20 for comprehensive exploration.",
},
"complexity": {
"type": "integer",
"default": 32,
"minimum": 16,
"maximum": 128,
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
},
},
"required": ["index_name", "query"],
},
},
{
"name": "leann_list",
"description": "📋 Show all your indexed codebases - your personal code library! Use this to see what's available for search.",
"inputSchema": {"type": "object", "properties": {}},
},
]
},
}
elif request.get("method") == "tools/call":
tool_name = request["params"]["name"]
args = request["params"].get("arguments", {})
try:
if tool_name == "leann_search":
# Validate required parameters
if not args.get("index_name") or not args.get("query"):
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"result": {
"content": [
{
"type": "text",
"text": "Error: Both index_name and query are required",
}
]
},
}
# Build simplified command with non-interactive flag for MCP compatibility
cmd = [
"leann",
"search",
args["index_name"],
args["query"],
f"--top-k={args.get('top_k', 5)}",
f"--complexity={args.get('complexity', 32)}",
"--non-interactive",
]
result = subprocess.run(cmd, capture_output=True, text=True)
elif tool_name == "leann_list":
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"result": {
"content": [
{
"type": "text",
"text": result.stdout
if result.returncode == 0
else f"Error: {result.stderr}",
}
]
},
}
except Exception as e:
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"error": {"code": -1, "message": str(e)},
}
def main():
for line in sys.stdin:
try:
request = json.loads(line.strip())
response = handle_request(request)
if response:
print(json.dumps(response))
sys.stdout.flush()
except Exception as e:
error_response = {
"jsonrpc": "2.0",
"id": None,
"error": {"code": -1, "message": str(e)},
}
print(json.dumps(error_response))
sys.stdout.flush()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,240 @@
"""
Metadata filtering engine for LEANN search results.
This module provides generic metadata filtering capabilities that can be applied
to search results from any LEANN backend. The filtering supports various
operators for different data types including numbers, strings, booleans, and lists.
"""
import logging
from typing import Any, Union
logger = logging.getLogger(__name__)
# Type alias for filter specifications
FilterValue = Union[str, int, float, bool, list]
FilterSpec = dict[str, FilterValue]
MetadataFilters = dict[str, FilterSpec]
class MetadataFilterEngine:
"""
Engine for evaluating metadata filters against search results.
Supports various operators for filtering based on metadata fields:
- Comparison: ==, !=, <, <=, >, >=
- Membership: in, not_in
- String operations: contains, starts_with, ends_with
- Boolean operations: is_true, is_false
"""
def __init__(self):
"""Initialize the filter engine with supported operators."""
self.operators = {
"==": self._equals,
"!=": self._not_equals,
"<": self._less_than,
"<=": self._less_than_or_equal,
">": self._greater_than,
">=": self._greater_than_or_equal,
"in": self._in,
"not_in": self._not_in,
"contains": self._contains,
"starts_with": self._starts_with,
"ends_with": self._ends_with,
"is_true": self._is_true,
"is_false": self._is_false,
}
def apply_filters(
self, search_results: list[dict[str, Any]], metadata_filters: MetadataFilters
) -> list[dict[str, Any]]:
"""
Apply metadata filters to a list of search results.
Args:
search_results: List of result dictionaries, each containing 'metadata' field
metadata_filters: Dictionary of filter specifications
Format: {"field_name": {"operator": value}}
Returns:
Filtered list of search results
"""
if not metadata_filters:
return search_results
logger.debug(f"Applying filters: {metadata_filters}")
logger.debug(f"Input results count: {len(search_results)}")
filtered_results = []
for result in search_results:
if self._evaluate_filters(result, metadata_filters):
filtered_results.append(result)
logger.debug(f"Filtered results count: {len(filtered_results)}")
return filtered_results
def _evaluate_filters(self, result: dict[str, Any], filters: MetadataFilters) -> bool:
"""
Evaluate all filters against a single search result.
All filters must pass (AND logic) for the result to be included.
Args:
result: Full search result dictionary (including metadata, text, etc.)
filters: Filter specifications to evaluate
Returns:
True if all filters pass, False otherwise
"""
for field_name, filter_spec in filters.items():
if not self._evaluate_field_filter(result, field_name, filter_spec):
return False
return True
def _evaluate_field_filter(
self, result: dict[str, Any], field_name: str, filter_spec: FilterSpec
) -> bool:
"""
Evaluate a single field filter against a search result.
Args:
result: Full search result dictionary
field_name: Name of the field to filter on
filter_spec: Filter specification for this field
Returns:
True if the filter passes, False otherwise
"""
# First check top-level fields, then check metadata
field_value = result.get(field_name)
if field_value is None:
# Try to get from metadata if not found at top level
metadata = result.get("metadata", {})
field_value = metadata.get(field_name)
# Handle missing fields - they fail all filters except existence checks
if field_value is None:
logger.debug(f"Field '{field_name}' not found in result or metadata")
return False
# Evaluate each operator in the filter spec
for operator, expected_value in filter_spec.items():
if operator not in self.operators:
logger.warning(f"Unsupported operator: {operator}")
return False
try:
if not self.operators[operator](field_value, expected_value):
logger.debug(
f"Filter failed: {field_name} {operator} {expected_value} "
f"(actual: {field_value})"
)
return False
except Exception as e:
logger.warning(
f"Error evaluating filter {field_name} {operator} {expected_value}: {e}"
)
return False
return True
# Comparison operators
def _equals(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value equals expected value."""
return field_value == expected_value
def _not_equals(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value does not equal expected value."""
return field_value != expected_value
def _less_than(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is less than expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a < b)
def _less_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is less than or equal to expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a <= b)
def _greater_than(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is greater than expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a > b)
def _greater_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is greater than or equal to expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a >= b)
# Membership operators
def _in(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is in the expected list/collection."""
if not isinstance(expected_value, (list, tuple, set)):
raise ValueError("'in' operator requires a list, tuple, or set")
return field_value in expected_value
def _not_in(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is not in the expected list/collection."""
if not isinstance(expected_value, (list, tuple, set)):
raise ValueError("'not_in' operator requires a list, tuple, or set")
return field_value not in expected_value
# String operators
def _contains(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value contains the expected substring."""
field_str = str(field_value)
expected_str = str(expected_value)
return expected_str in field_str
def _starts_with(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value starts with the expected prefix."""
field_str = str(field_value)
expected_str = str(expected_value)
return field_str.startswith(expected_str)
def _ends_with(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value ends with the expected suffix."""
field_str = str(field_value)
expected_str = str(expected_value)
return field_str.endswith(expected_str)
# Boolean operators
def _is_true(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is truthy."""
return bool(field_value)
def _is_false(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is falsy."""
return not bool(field_value)
# Helper methods
def _numeric_compare(self, field_value: Any, expected_value: Any, compare_func) -> bool:
"""
Helper for numeric comparisons with type coercion.
Args:
field_value: Value from metadata
expected_value: Value to compare against
compare_func: Comparison function to apply
Returns:
Result of comparison
"""
try:
# Try to convert both values to numbers for comparison
if isinstance(field_value, str) and isinstance(expected_value, str):
# String comparison if both are strings
return compare_func(field_value, expected_value)
# Numeric comparison - attempt to convert to float
field_num = (
float(field_value) if not isinstance(field_value, (int, float)) else field_value
)
expected_num = (
float(expected_value)
if not isinstance(expected_value, (int, float))
else expected_value
)
return compare_func(field_num, expected_num)
except (ValueError, TypeError):
# Fall back to string comparison if numeric conversion fails
return compare_func(str(field_value), str(expected_value))

View File

@@ -1,15 +1,98 @@
# packages/leann-core/src/leann/registry.py # packages/leann-core/src/leann/registry.py
from typing import Dict, TYPE_CHECKING import importlib
import importlib.metadata
import json
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING: if TYPE_CHECKING:
from leann.interface import LeannBackendFactoryInterface from leann.interface import LeannBackendFactoryInterface
BACKEND_REGISTRY: Dict[str, 'LeannBackendFactoryInterface'] = {} # Set up logger for this module
logger = logging.getLogger(__name__)
BACKEND_REGISTRY: dict[str, "LeannBackendFactoryInterface"] = {}
def register_backend(name: str): def register_backend(name: str):
"""A decorator to register a new backend class.""" """A decorator to register a new backend class."""
def decorator(cls): def decorator(cls):
print(f"INFO: Registering backend '{name}'") logger.debug(f"Registering backend '{name}'")
BACKEND_REGISTRY[name] = cls BACKEND_REGISTRY[name] = cls
return cls return cls
return decorator return decorator
def autodiscover_backends():
"""Automatically discovers and imports all 'leann-backend-*' packages."""
# print("INFO: Starting backend auto-discovery...")
discovered_backends = []
for dist in importlib.metadata.distributions():
dist_name = dist.metadata["name"]
if dist_name.startswith("leann-backend-"):
backend_module_name = dist_name.replace("-", "_")
discovered_backends.append(backend_module_name)
for backend_module_name in sorted(discovered_backends): # sort for deterministic loading
try:
importlib.import_module(backend_module_name)
# Registration message is printed by the decorator
except ImportError:
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
pass
# print("INFO: Backend auto-discovery finished.")
def register_project_directory(project_dir: Optional[Union[str, Path]] = None):
"""
Register a project directory in the global LEANN registry.
This allows `leann list` to discover indexes created by apps or other tools.
Args:
project_dir: Directory to register. If None, uses current working directory.
"""
if project_dir is None:
project_dir = Path.cwd()
else:
project_dir = Path(project_dir)
# Only register directories that have some kind of LEANN content
# Either .leann/indexes/ (CLI format) or *.leann.meta.json files (apps format)
has_cli_indexes = (project_dir / ".leann" / "indexes").exists()
has_app_indexes = any(project_dir.rglob("*.leann.meta.json"))
if not (has_cli_indexes or has_app_indexes):
# Don't register if there are no LEANN indexes
return
global_registry = Path.home() / ".leann" / "projects.json"
global_registry.parent.mkdir(exist_ok=True)
project_str = str(project_dir.resolve())
# Load existing registry
projects = []
if global_registry.exists():
try:
with open(global_registry) as f:
projects = json.load(f)
except Exception:
logger.debug("Could not load existing project registry")
projects = []
# Add project if not already present
if project_str not in projects:
projects.append(project_str)
# Save updated registry
try:
with open(global_registry, "w") as f:
json.dump(projects, f, indent=2)
logger.debug(f"Registered project directory: {project_str}")
except Exception as e:
logger.warning(f"Could not save project registry: {e}")

Some files were not shown because too many files have changed in this diff Show More