Compare commits

..

123 Commits

Author SHA1 Message Date
Andy Lee
1d084f678c Merge remote-tracking branch 'origin/main' into perf-build 2025-07-21 20:13:12 -07:00
Andy Lee
54155e8b10 fix: same embedding logic 2025-07-21 20:12:40 -07:00
yichuan520030910320
5259ace111 [Readme] 2025-07-21 20:06:21 -07:00
yichuan520030910320
48ea5566e9 [Readme] detail number 2025-07-21 19:51:51 -07:00
yichuan520030910320
3f8b6c5bbd [Readme] 2025-07-21 18:15:00 -07:00
yichuan520030910320
725b32e74f [Readme] 2025-07-21 17:57:44 -07:00
yichuan520030910320
d0b71f393f [Readme] 2025-07-21 17:56:10 -07:00
yichuan520030910320
8a92efdae3 [Readme] 2025-07-21 17:53:59 -07:00
yichuan520030910320
019cdce2e8 [Readme] 2025-07-21 17:30:11 -07:00
yichuan520030910320
b64aa54fac fix break link 2025-07-21 17:29:35 -07:00
yichuan520030910320
c0d040f9d4 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-21 16:22:24 -07:00
yichuan520030910320
32364320f8 update wechat and we should fix the bug introduced in 1c5fec5 2025-07-21 16:22:16 -07:00
Andy Lee
f47f76d6d7 feat: cli more args 2025-07-20 22:17:55 -07:00
Andy Lee
1dc3923b53 feat: cli tool 2025-07-20 20:54:52 -07:00
Andy Lee
7e226a51c9 fix: do not reuse emb_server and close it properly 2025-07-20 18:07:51 -07:00
Andy Lee
f4998bb316 fix: no longger do embedding server reuse 2025-07-20 12:15:17 -07:00
Andy Lee
7522de1d41 chore: update faiss 2025-07-20 11:19:44 -07:00
Andy Lee
15f8bd1cc9 chore: shorter build time 2025-07-19 23:49:04 -07:00
Andy Lee
34c71c072d chore: parallel compile fix 2025-07-19 22:51:47 -07:00
Andy Lee
6d2149c503 chore: parallel compile fix 2025-07-19 22:46:24 -07:00
Andy Lee
043b0bf69d chore: parallel compile fix 2025-07-19 22:34:19 -07:00
Andy Lee
9b07e392c6 chore: parallel compile 2025-07-19 22:32:13 -07:00
Andy Lee
e60fad8c73 chore: mark diskann as optional 2025-07-19 22:24:44 -07:00
Andy Lee
19c1b182c3 docs: effects figure 2025-07-19 22:07:04 -07:00
Andy Lee
49edea780c docs: figure 2025-07-19 21:59:58 -07:00
Andy Lee
12ef5a1900 docs: effects 2025-07-19 21:57:12 -07:00
Andy Lee
d21a134b2a docs: polish 2025-07-19 21:53:41 -07:00
Andy Lee
1cd809aa41 [Docs] README polished version (#4)
* docs: polish

* docs: logo

* docs: logo

* docs: logo with text

* docs: readme effects

* docs: polish

* docs: highlight applications

* docs: polish

* docs: how it works earlier

* docs: polish

* docs: polish

* docs: follow yichuan's suggestion

* docs: follow yichuan's suggestion

---------

Co-authored-by: Yichuan Wang <73766326+yichuan-w@users.noreply.github.com>
2025-07-19 21:47:25 -07:00
yichuan520030910320
e728449b8f change chinese 2025-07-19 19:54:02 -07:00
yichuan520030910320
d0c20b14d5 clear output pf ipynb 2025-07-19 19:48:56 -07:00
yichuan520030910320
83b7ea5a59 change wecaht app split logic& merge 2025-07-19 19:44:33 -07:00
yichuan520030910320
0796a52df1 change wecaht app split logic 2025-07-19 19:43:30 -07:00
Andy Lee
85b7ba0168 feat: allow build from existed embeddings 2025-07-19 01:27:37 -07:00
yichuan520030910320
e117743d24 Merge branch 'main' of github.com:yichuan520030910320/LEANN-RAG 2025-07-17 22:29:39 -07:00
yichuan520030910320
aec2291f04 add embedding api 2025-07-17 22:29:31 -07:00
yichuan520030910320
335ae003ac add data 2025-07-17 22:29:03 -07:00
Andy Lee
71c7de9c84 fix: build with direct embedding 2025-07-17 21:49:36 -07:00
Andy Lee
1c5fec5565 perf: make embedder loading faster by 6x, and embed queries through the server 2025-07-17 20:08:06 -07:00
yichuan520030910320
99d439577d Merge branch 'main' of github.com:yichuan520030910320/LEANN-RAG 2025-07-17 18:15:27 -07:00
yichuan520030910320
4f83086788 update readme and auto find email 2025-07-17 18:15:17 -07:00
Andy Lee
a13c527e39 feat: openai embeddings 2025-07-17 17:02:47 -07:00
yichuan520030910320
90d9f27383 update readme and main example 2025-07-17 15:03:22 -07:00
yichuan520030910320
0db81c16cd update readme and chrome example 2025-07-17 12:58:11 -07:00
yichuan520030910320
e115e186b7 update example and more stats on result 2025-07-16 22:07:15 -07:00
yichuan520030910320
6546b29ef7 update readme 2025-07-16 20:29:45 -07:00
yichuan520030910320
51255bdffa update readme and add timer 2025-07-16 17:15:51 -07:00
Andy Lee
f77c4e38cb perf: reuse embedding server for query embed 2025-07-16 16:12:15 -07:00
Andy Lee
2a1a152073 refactor: nits 2025-07-16 15:39:58 -07:00
Andy Lee
7b9406a3ea feat: different search_args and docstrings 2025-07-16 15:25:58 -07:00
Andy Lee
c3fb949693 docs: ollama 2025-07-16 15:12:37 -07:00
yichuan520030910320
ed3f8dbfd6 update readme 2025-07-15 23:32:25 -07:00
yichuan520030910320
42aa6db170 update readme 2025-07-15 23:23:04 -07:00
yichuan520030910320
a6591d20ca Merge branch 'main' of github.com:yichuan520030910320/LEANN-RAG 2025-07-15 23:18:08 -07:00
yichuan520030910320
c1bc2603a2 update readme and 30 seconds example 2025-07-15 23:18:01 -07:00
Andy Lee
e595bbb5fb feat: hint for users about wrong model name 2025-07-15 22:40:40 -07:00
yichuan520030910320
4a2cb914d7 clean dict 2025-07-15 22:30:52 -07:00
yichuan520030910320
b1c93fe178 Merge branch 'main' of github.com:yichuan520030910320/LEANN-RAG 2025-07-15 22:29:09 -07:00
yichuan520030910320
0719458775 upd readme stats 2025-07-15 22:28:59 -07:00
Andy Lee
6a1dc895fb feat: disable warmup by default 2025-07-15 22:16:02 -07:00
Andy Lee
125c1f6f25 fix: model name 2025-07-15 21:48:45 -07:00
yichuan520030910320
1ceaa7d709 Merge branch 'main' of github.com:yichuan520030910320/LEANN-RAG 2025-07-15 21:19:25 -07:00
yichuan520030910320
dec3ee85fd fix main cli 2025-07-15 21:19:16 -07:00
Andy Lee
d94a5176dc docs: storage reduction data 2025-07-15 15:37:17 -07:00
yichuan520030910320
326783f7f1 fix mem compare fix split 2025-07-14 23:07:46 -07:00
yichuan520030910320
e5a9ca8787 fix mem compare 2025-07-14 22:55:10 -07:00
Andy Lee
f2feccdbd0 fix: mem compare 2025-07-14 16:35:08 -07:00
yichuan520030910320
246a077d64 upd readme 2025-07-14 16:21:34 -07:00
yichuan520030910320
3ba100ff25 upd readme 2025-07-14 16:18:39 -07:00
yichuan520030910320
1e3b571e72 add readme bench 2025-07-14 16:13:21 -07:00
Andy Lee
b89e56e9c2 fix: file name 2025-07-14 15:34:56 -07:00
yichuan520030910320
ed8a02e721 update readme and mlx support 2025-07-14 15:23:56 -07:00
Andy Lee
baa60b40d1 fix: smaller warmup id 2025-07-14 15:20:45 -07:00
Andy Lee
ef01d6997a fix: faiss only 2025-07-14 13:15:55 -07:00
Andy Lee
3da5b44d7f fix: mlx when searching, added to embedding_server 2025-07-14 01:11:21 -07:00
Andy Lee
8b4654921b fix: run faiss in subprocess to prevent kmp 2025-07-14 00:29:21 -07:00
yichuan520030910320
cf1cbafa78 Merge branch 'main' of github.com:yichuan520030910320/LEANN-RAG 2025-07-13 23:19:54 -07:00
yichuan520030910320
c96091744b update readme 2025-07-13 23:19:44 -07:00
Andy Lee
711fb4a775 feat: compare faiss 2025-07-13 22:44:16 -07:00
Andy Lee
3b5a185e60 refactor: check if current emb_server has correct passages/embedder 2025-07-13 22:43:51 -07:00
yichuan520030910320
77ac013a74 update readem 2025-07-13 22:37:41 -07:00
yichuan520030910320
b8e5728e6a fix wechat application 2025-07-13 22:29:54 -07:00
yichuan520030910320
d038319d8b upd readme wechat application 2025-07-13 22:00:49 -07:00
yichuan520030910320
c611d0f30f upd readme mail application 2025-07-13 21:48:57 -07:00
yichuan520030910320
c17899662f upd readme mail application 2025-07-13 21:30:08 -07:00
yichuan520030910320
c51d5320fa upd test/mlx 2025-07-13 20:16:02 -07:00
yichuan520030910320
6fa9512a64 fix wechat path 2025-07-13 18:23:31 -07:00
Andy Lee
fddc61df5e chore: reset to latest version 2025-07-13 17:06:48 -07:00
Andy Lee
53c58fa755 perf: switch to tranditional pdf reader 2025-07-13 17:04:23 -07:00
yichuan520030910320
c69afb56e4 Resolve submodule conflict - update to af2a264 2025-07-13 17:03:42 -07:00
yichuan520030910320
0fa8a9191f add wechat history extract app 2025-07-13 16:52:08 -07:00
Andy Lee
48dda1cb5b feat: mlx 2025-07-13 02:13:04 -07:00
Andy Lee
71ef4b7d4c fix: reproducible dpr on mac 2025-07-12 18:13:22 -07:00
Andy Lee
ecab43e307 feat: dataset for evaluation 2025-07-12 23:43:10 +00:00
Fangzhou66
88ca09440d fix some hf problem 2025-07-12 16:13:15 -07:00
Andy Lee
8e0ab4a28d chore: update deps 2025-07-12 22:48:13 +00:00
yichuan520030910320
9b8c5041dc update readme 2025-07-12 13:01:11 -07:00
yichuan520030910320
74ffd7ec64 add email test code 2025-07-11 23:59:47 -07:00
Andy Lee
eb6f504789 Datastore reproduce (#3)
* fix: diskann zmq port and passages

* feat: auto discovery of packages and fix passage gen for diskann

* docs: embedding pruning

* refactor: passage structure

* feat: reproducible research datas, rpj_wiki & dpr

* refactor: chat and base searcher

* feat: chat on mps
2025-07-11 23:37:23 -07:00
yichuan520030910320
91a026f38b polish readme 2025-07-11 23:06:08 -07:00
yichuan520030910320
595138a0a3 upd readme 2025-07-11 22:43:48 -07:00
yichuan520030910320
19df04095f add readme 2025-07-11 22:34:54 -07:00
yichuan520030910320
8239bbb48f add google hostory api 2025-07-11 21:21:36 -07:00
yichuan520030910320
16ee9d0422 add traverse all dict interface 2025-07-10 15:59:16 -07:00
yichuan520030910320
8a961f8ab3 align the llamaindex result w leann& test attachment 2025-07-09 21:42:15 -07:00
yichuan520030910320
558126c46e add leann and llamaindex email infra, and need to align the results 2025-07-09 16:27:11 -07:00
yichuan520030910320
04c9684488 add email test code 2025-07-09 15:06:31 -07:00
Andy Lee
b744faa7e6 chore: all deps 2025-07-08 23:37:40 +00:00
Andy Lee
27b3a26e75 fix(deps): Update DiskANN with cleaned up CMake configuration 2025-07-08 23:27:05 +00:00
Andy Lee
41d872504e feat(deps): Update DiskANN to use system-installed Boost and Protobuf 2025-07-08 23:13:36 +00:00
Andy Lee
963cd05273 chore: diskann modules 2025-07-08 21:57:38 +00:00
Andy Lee
09b6e67baf chore: diskann upg boost 2025-07-08 21:44:44 +00:00
yichuan520030910320
dafb2aacab update macos env 2025-07-08 14:37:41 -07:00
Andy Lee
a6c400cd4f chroe: linux boost and protobuf 2025-07-08 21:25:43 +00:00
Andy Lee
c013e5ccce chore: linux deps 2025-07-08 13:55:39 -07:00
Andy Lee
f25a1a3840 chore: macos compatible 2025-07-08 13:32:00 -07:00
yichuan520030910320
6497e17671 add gpu chunk embedd and add complexity in hnsw 2025-07-08 18:40:52 +00:00
yichuan520030910320
44369a8138 update diskann module 2025-07-07 18:27:07 -07:00
yichuan520030910320
dfca00c21b add mac support in this repo 2025-07-07 18:22:24 -07:00
yichuan520030910320
637dab379e add workaround code 2025-07-07 23:13:47 +00:00
yichuan520030910320
6fc57eb48e add reuse code 2025-07-07 21:07:00 +00:00
yichuan520030910320
95a653993a rm useless 2025-07-06 06:47:20 +00:00
yichuan520030910320
af0959818d rm useless 2025-07-06 05:21:05 +00:00
Andy Lee
cf17c85607 Make DiskANN and HNSW work on main example (#2)
* fix: diskann zmq port and passages

* feat: auto discovery of packages and fix passage gen for diskann
2025-07-05 22:18:12 -07:00
78 changed files with 24377 additions and 3607 deletions

13
.gitignore vendored
View File

@@ -8,11 +8,17 @@ demo/indices/
*pycache*
outputs/
*.pkl
*.pdf
*.idx
*.map
.history/
scripts/
lm_eval.egg-info/
demo/experiment_results/**/*.json
*.jsonl
*.eml
*.emlx
*.json
*.sh
*.txt
!CMakeLists.txt
@@ -29,7 +35,11 @@ build/
nprobe_logs/
micro/results
micro/contriever-INT8
examples/data/
examples/data/*
!examples/data/2501.14312v1 (1).pdf
!examples/data/2506.08276v1.pdf
!examples/data/PrideandPrejudice.txt
!examples/data/README.md
*.qdstrm
benchmark_results/
results/
@@ -42,6 +52,7 @@ embedding_comparison_results/
*.ivecs
*.index
*.bin
*.old
read_graph
analyze_diskann_graph

14
.gitmodules vendored
View File

@@ -1,6 +1,16 @@
[submodule "packages/leann-backend-diskann/third_party/DiskANN"]
path = packages/leann-backend-diskann/third_party/DiskANN
url = https://github.com/yichuan520030910320/DiskANN.git
url = https://github.com/yichuan-w/DiskANN.git
[submodule "packages/leann-backend-hnsw/third_party/faiss"]
path = packages/leann-backend-hnsw/third_party/faiss
url = https://github.com/yichuan520030910320/faiss.git
url = https://github.com/yichuan-w/faiss.git
[submodule "packages/leann-backend-hnsw/third_party/msgpack-c"]
path = packages/leann-backend-hnsw/third_party/msgpack-c
url = https://github.com/msgpack/msgpack-c.git
branch = cpp_master
[submodule "packages/leann-backend-hnsw/third_party/cppzmq"]
path = packages/leann-backend-hnsw/third_party/cppzmq
url = https://github.com/zeromq/cppzmq.git
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
path = packages/leann-backend-hnsw/third_party/libzmq
url = https://github.com/zeromq/libzmq.git

View File

@@ -1,6 +1,6 @@
MIT License
Copyright (c) 2024 Rulin Shao
Copyright (c) 2025 LEANN Contributors
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

523
README.md
View File

@@ -1,171 +1,360 @@
# 🚀 LEANN: A Low-Storage Vector Index
<p align="center">
<img src="assets/logo-text.png" alt="LEANN Logo" width="400">
</p>
<p align="center">
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
<img src="https://img.shields.io/badge/PRs-welcome-brightgreen.svg" alt="PRs Welcome">
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS%20%7C%20Windows-lightgrey" alt="Platform">
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
</p>
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
The smallest vector index in the world. RAG Everything with LEANN!
</h2>
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **[97% less storage]** than traditional solutions **without accuracy loss**.
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration →](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#process-any-documents-pdf-txt-md)**, **[emails](#search-your-entire-life)**, **[browser history](#time-machine-for-the-web)**, **[chat history](#wechat-detective)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
## Why LEANN?
<p align="center">
<strong>⚡ Real-time embedding computation for large-scale RAG on consumer hardware</strong>
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
</p>
<p align="center">
<a href="#-quick-start">Quick Start</a> •
<a href="#-features">Features</a> •
<a href="#-benchmarks">Benchmarks</a> •
<a href="#-documentation">Documentation</a> •
<a href="#-paper">Paper</a>
</p>
**The numbers speak for themselves:** Index 60 million Wikipedia articles in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks below ↓](#storage-usage-comparison)
---
## Why This Matters
## 🌟 What is Leann?
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
**Leann** revolutionizes Retrieval-Augmented Generation (RAG) by eliminating the storage bottleneck of traditional vector databases. Instead of pre-computing and storing billions of embeddings, Leann dynamically computes embeddings at query time using highly optimized graph-based search algorithms.
🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!
### 🎯 Why Leann?
📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!
Traditional RAG systems face a fundamental trade-off:
- **💾 Storage**: Storing embeddings for millions of documents requires massive disk space
- **🔄 Freshness**: Pre-computed embeddings become stale when documents change
- **💰 Cost**: Vector databases are expensive to scale
**No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
**Leann solves this by:**
-**Zero embedding storage** - Only graph structure is persisted
-**Real-time computation** - Embeddings computed on-demand with ms latency
-**Memory efficient** - Runs on consumer hardware (8GB RAM)
-**Always fresh** - No stale embeddings, ever
## 🚀 Quick Start
### Installation
## Quick Start in 1 minute
```bash
git clone git@github.com:yichuan520030910320/LEANN-RAG.git leann
git clone git@github.com:yichuan-w/LEANN.git leann
cd leann
git submodule update --init --recursive
uv sync
```
### 30-Second Example
**macOS:**
```bash
brew install llvm libomp boost protobuf zeromq
export CC=$(brew --prefix llvm)/bin/clang
export CXX=$(brew --prefix llvm)/bin/clang++
# Install with HNSW backend (default, recommended for most users)
uv sync
# Or add DiskANN backend if you want to test more options
uv sync --extra diskann
```
**Linux (Ubuntu/Debian):**
```bash
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
# Install with HNSW backend (default, recommended for most users)
uv sync
# Or add DiskANN backend if you want to test more options
uv sync --extra diskann
```
**Ollama Setup (Optional for Local LLM):**
*We support both hf-transformers and Ollama for local LLMs. Ollama is recommended for faster performance.*
*macOS:*
First, [download Ollama for macOS](https://ollama.com/download/mac).
```bash
# Pull a lightweight model (recommended for consumer hardware)
ollama pull llama3.2:1b
```
*Linux:*
```bash
# Install Ollama
curl -fsSL https://ollama.ai/install.sh | sh
# Start Ollama service manually
ollama serve &
# Pull a lightweight model (recommended for consumer hardware)
ollama pull llama3.2:1b
```
You can also replace `llama3.2:1b` to `deepseek-r1:1.5b` or `qwen3:4b` for better performance but higher memory usage.
## Dead Simple API
Just 3 lines of code. Our declarative API makes RAG as easy as writing a config file:
```python
from leann.api import LeannBuilder, LeannSearcher
# 1. Build index (no embeddings stored!)
builder = LeannBuilder(backend_name="diskann")
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("C# is a powerful programming language")
builder.add_text("Python is a powerful programming language")
builder.add_text("Machine learning transforms industries")
builder.add_text("Neural networks process complex data")
builder.add_text("Leann is a great storage saving engine for RAG on your macbook")
builder.build_index("knowledge.leann")
# 2. Search with real-time embeddings
searcher = LeannSearcher("knowledge.leann")
results = searcher.search("programming languages", top_k=2)
for result in results:
print(f"Score: {result['score']:.3f} - {result['text']}")
results = searcher.search("C++ programming languages", top_k=2, recompute_beighbor_embeddings=True)
print(results)
```
### Run the Demo
**That's it.** No cloud setup, no API keys, no "fine-tuning". Just your data, your questions, your laptop.
[Try the interactive demo →](demo.ipynb)
## Wild Things You Can Do
LEANN supports RAGing a lot of data sources, like .pdf, .txt, .md, and also supports RAGing your WeChat, Google Search History, and more.
### Process Any Documents (.pdf, .txt, .md)
Above we showed the Python API, while this CLI script demonstrates the same concepts while directly processing PDFs and documents.
```bash
uv run examples/document_search.py
# Drop your PDFs, .txt, .md files into examples/data/
uv run ./examples/main_cli_example.py
# Or use python directly
source .venv/bin/activate
python ./examples/main_cli_example.py
```
**PDF RAG Demo (using LlamaIndex for document parsing and Leann for indexing/search)**
Uses Ollama `qwen3:8b` by default. For other models: `--llm openai --model gpt-4o` (requires `OPENAI_API_KEY` environment variable) or `--llm hf --model Qwen/Qwen3-4B`.
This demo showcases how to build a RAG system for PDF documents using Leann.
1. Place your PDF files (and other supported formats like .docx, .pptx, .xlsx) into the `examples/data/` directory.
2. Ensure you have an `OPENAI_API_KEY` set in your environment variables or in a `.env` file for the LLM to function.
**Works with any text format** - research papers, personal notes, presentations. Built with LlamaIndex for document parsing.
### Search Your Entire Life
```bash
python examples/mail_reader_leann.py
# "What did my boss say about the Christmas party last year?"
# "Find all emails from my mom about birthday plans"
```
**90K emails → 14MB.** Finally, search your email like you search Google.
<details>
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
```bash
uv run examples/main_cli_example.py
# Use default mail path (works for most macOS setups)
python examples/mail_reader_leann.py
# Run with custom index directory
python examples/mail_reader_leann.py --index-dir "./my_mail_index"
# Process all emails (may take time but indexes everything)
python examples/mail_reader_leann.py --max-emails -1
# Limit number of emails processed (useful for testing)
python examples/mail_reader_leann.py --max-emails 1000
# Run a single query
python examples/mail_reader_leann.py --query "What did my boss say about deadlines?"
```
## ✨ Features
</details>
### 🔥 Core Features
- **📊 Multiple Distance Functions**: L2, Cosine, MIPS (Maximum Inner Product Search)
- **🏗️ Pluggable Backends**: DiskANN, HNSW/FAISS with unified API
- **🔄 Real-time Embeddings**: Dynamic computation using optimized ZMQ servers
- **📈 Scalable Architecture**: Handles millions of documents on consumer hardware
- **🎯 Graph Pruning**: Advanced techniques for memory-efficient search
<details>
<summary><strong>📋 Click to expand: Example queries you can try</strong></summary>
### 🛠️ Technical Highlights
- **Zero-copy operations** for maximum performance
- **SIMD-optimized** distance computations (AVX2/AVX512)
- **Async embedding pipeline** with batched processing
- **Memory-mapped indices** for fast startup
- **Recompute mode** for highest accuracy scenarios
### 🎨 Developer Experience
- **Simple Python API** - Get started in minutes
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment
- **Rich debugging tools** - Built-in performance profiling
## 📊 Benchmarks
### Memory Usage Comparison
| System | 1M Documents | 10M Documents | 100M Documents |
|--------|-------------|---------------|----------------|
| Traditional Vector DB | 3.1 GB | 31 GB | 310 GB |
| **Leann** | **180 MB** | **1.2 GB** | **8.4 GB** |
| **Reduction** | **94.2%** | **96.1%** | **97.3%** |
### Query Performance
| Backend | Index Size | Query Time | Recall@10 |
|---------|------------|------------|-----------|
| DiskANN | 1M docs | 12ms | 0.95 |
| DiskANN + Recompute | 1M docs | 145ms | 0.98 |
| HNSW | 1M docs | 8ms | 0.93 |
*Benchmarks run on AMD Ryzen 7 with 32GB RAM*
## 🏗️ Architecture
Once the index is built, you can ask questions like:
- "Find emails from my boss about deadlines"
- "What did John say about the project timeline?"
- "Show me emails about travel expenses"
</details>
### Time Machine for the Web
```bash
python examples/google_history_reader_leann.py
# "What was that AI paper I read last month?"
# "Show me all the cooking videos I watched"
```
┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐
│ Query Text │───▶│ Embedding │───▶│ Graph-based │
│ │ │ Computation │ │ Search │
└─────────────────┘ └──────────────────┘ └─────────────────┘
│ │
▼ ▼
┌──────────────┐ ┌──────────────┐
│ ZMQ Server │ │ Pruned Graph │
│ (Cached) │ │ Index │
└──────────────┘ └──────────────┘
**38K browser entries → 6MB.** Your browser history becomes your personal search engine.
<details>
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
```bash
# Use default Chrome profile (auto-finds all profiles)
python examples/google_history_reader_leann.py
# Run with custom index directory
python examples/google_history_reader_leann.py --index-dir "./my_chrome_index"
# Limit number of history entries processed (useful for testing)
python examples/google_history_reader_leann.py --max-entries 500
# Run a single query
python examples/google_history_reader_leann.py --query "What websites did I visit about machine learning?"
```
### Key Components
</details>
1. **🧠 Embedding Engine**: Real-time transformer inference with caching
2. **📊 Graph Index**: Memory-efficient navigation structures
3. **🔄 Search Coordinator**: Orchestrates embedding + graph search
4. **⚡ Backend Adapters**: Pluggable algorithm implementations
<details>
<summary><strong>📋 Click to expand: How to find your Chrome profile</strong></summary>
## 🎓 Supported Models & Backends
The default Chrome profile path is configured for a typical macOS setup. If you need to find your specific Chrome profile:
### 🤖 Embedding Models
- **sentence-transformers/all-mpnet-base-v2** (default)
- **sentence-transformers/all-MiniLM-L6-v2** (lightweight)
- Any HuggingFace sentence-transformer model
- Custom model support via API
1. Open Terminal
2. Run: `ls ~/Library/Application\ Support/Google/Chrome/`
3. Look for folders like "Default", "Profile 1", "Profile 2", etc.
4. Use the full path as your `--chrome-profile` argument
### 🔧 Search Backends
- **DiskANN**: Microsoft's billion-scale ANN algorithm
- **HNSW**: Hierarchical Navigable Small World graphs
- **Coming soon**: ScaNN, Faiss-IVF, NGT
**Common Chrome profile locations:**
- macOS: `~/Library/Application Support/Google/Chrome/Default`
- Linux: `~/.config/google-chrome/Default`
### 📏 Distance Functions
- **L2**: Euclidean distance for precise similarity
- **Cosine**: Angular similarity for normalized vectors
- **MIPS**: Maximum Inner Product Search for recommendation systems
</details>
<details>
<summary><strong>💬 Click to expand: Example queries you can try</strong></summary>
Once the index is built, you can ask questions like:
- "What websites did I visit about machine learning?"
- "Find my search history about programming"
- "What YouTube videos did I watch recently?"
- "Show me websites I visited about travel planning"
</details>
### WeChat Detective
```bash
python examples/wechat_history_reader_leann.py
# "Show me all group chats about weekend plans"
```
**400K messages → 64MB.** Search years of chat history in any language.
<details>
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
First, you need to install the WeChat exporter:
```bash
sudo packages/wechat-exporter/wechattweak-cli install
```
**Troubleshooting**: If you encounter installation issues, check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41).
</details>
<details>
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
```bash
# Use default settings (recommended for first run)
python examples/wechat_history_reader_leann.py
# Run with custom export directory and wehn we run the first time, LEANN will export all chat history automatically for you
python examples/wechat_history_reader_leann.py --export-dir "./my_wechat_exports"
# Run with custom index directory
python examples/wechat_history_reader_leann.py --index-dir "./my_wechat_index"
# Limit number of chat entries processed (useful for testing)
python examples/wechat_history_reader_leann.py --max-entries 1000
# Run a single query
python examples/wechat_history_reader_leann.py --query "Show me conversations about travel plans"
```
</details>
<details>
<summary><strong>💬 Click to expand: Example queries you can try</strong></summary>
Once the index is built, you can ask questions like:
- "我想买魔术师约翰逊的球衣,给我一些对应聊天记录?" (Chinese: Show me chat records about buying Magic Johnson's jersey)
</details>
## 🏗️ Architecture & How It Works
<p align="center">
<img src="assets/arch.png" alt="LEANN Architecture" width="800">
</p>
**The magic:** Most vector DBs store every single embedding (expensive). LEANN stores a pruned graph structure (cheap) and recomputes embeddings only when needed (fast).
**Core techniques:**
- **Graph-based selective recomputation:** Only compute embeddings for nodes in the search path
- **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
**Backends:** DiskANN or HNSW - pick what works for your data size.
## Benchmarks
Run the comparison yourself:
```bash
python examples/compare_faiss_vs_leann.py
```
| System | Storage |
|--------|---------|
| FAISS HNSW | 5.5 MB |
| LEANN | 0.5 MB |
| **Savings** | **91%** |
Same dataset, same hardware, same embedding model. LEANN just works better.
## Reproduce Our Results
```bash
uv pip install -e ".[dev]" # Install dev dependencies
python examples/run_evaluation.py data/indices/dpr/dpr_diskann # DPR dataset
python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
```
The evaluation script downloads data automatically on first run.
### Storage Usage Comparison
| System | DPR (2.1M chunks) | RPJ-wiki (60M chunks) | Chat history (400K messages) | Apple emails (90K messages chunks) |Google Search History (38K entries)
|-----------------------|------------------|------------------------|-----------------------------|------------------------------|------------------------------|
| Traditional Vector DB(FAISS) | 3.8 GB | 201 GB | 1.8G | 305.8 MB |130.4 MB |
| **LEANN** | **324 MB** | **6 GB** | **64 MB** | **14.8 MB** |**6.4MB** |
| **Reduction** | **91% smaller** | **97% smaller** | **97% smaller** | **95% smaller** |**95% smaller** |
<!-- ### Memory Usage Comparison
| System j | DPR(2M docs) | RPJ-wiki(60M docs) | Chat history() |
| --------------------- | ---------------- | ---------------- | ---------------- |
| Traditional Vector DB(LLamaindex faiss) | x GB | x GB | x GB |
| **Leann** | **xx MB** | **x GB** | **x GB** |
| **Reduction** | **x%** | **x%** | **x%** |
### Query Performance of LEANN
| Backend | Index Size | Query Time | Recall@3 |
| ------------------- | ---------- | ---------- | --------- |
| DiskANN | 1M docs | xms | 0.95 |
| HNSW | 1M docs | xms | 0.95 | -->
*Benchmarks run on Apple M3 Pro 36 GB*
## 🔬 Paper
@@ -185,73 +374,56 @@ If you find Leann useful, please cite:
}
```
## 🌍 Use Cases
## ✨ Features
### 💼 Enterprise RAG
```python
# Handle millions of documents with limited resources
builder = LeannBuilder(
backend_name="diskann",
distance_metric="cosine",
graph_degree=64,
memory_budget="4GB"
)
```
### 🔥 Core Features
### 🔬 Research & Experimentation
```python
# Quick prototyping with different algorithms
for backend in ["diskann", "hnsw"]:
searcher = LeannSearcher(index_path, backend=backend)
evaluate_recall(searcher, queries, ground_truth)
```
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
### 🚀 Real-time Applications
```python
# Sub-second response times
chat = LeannChat("knowledge.leann")
response = chat.ask("What is quantum computing?")
# Returns in <100ms with recompute mode
```
### 🛠️ Technical Highlights
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
### 🎨 Developer Experience
- **Simple Python API** - Get started in minutes
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment
## 🤝 Contributing
We welcome contributions! Leann is built by the community, for the community.
### Ways to Contribute
- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results
### Development Setup
```bash
git clone https://github.com/yourname/leann
cd leann
uv sync --dev
uv run pytest tests/
```
### Quick Tests
```bash
# Sanity check all distance functions
uv run python tests/sanity_checks/test_distance_functions.py
# Verify L2 implementation
uv run python tests/sanity_checks/test_l2_verification.py
```
## ❓ FAQ
<!-- ## ❓ FAQ
### Common Issues
#### NCCL Topology Error
**Problem**: You encounter `ncclTopoComputePaths` error during document processing:
```
ncclTopoComputePaths (system=<optimized out>, comm=comm@entry=0x5555a82fa3c0) at graph/paths.cc:688
```
**Solution**: Set these environment variables before running your script:
```bash
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
export NCCL_DEBUG=INFO
@@ -259,36 +431,30 @@ export NCCL_DEBUG_SUBSYS=INIT,GRAPH
export NCCL_IB_DISABLE=1
export NCCL_NET_PLUGIN=none
export NCCL_SOCKET_IFNAME=ens5
``` -->
## 📈 Roadmap
### 🎯 Q1 2024
- [x] DiskANN backend with MIPS/L2/Cosine support
- [x] HNSW backend integration
- [x] Real-time embedding pipeline
- [x] Memory-efficient graph pruning
### 🎯 Q2 2025
- [X] DiskANN backend with MIPS/L2/Cosine support
- [X] HNSW backend integration
- [X] Real-time embedding pipeline
- [X] Memory-efficient graph pruning
### 🚀 Q3 2025
### 🚀 Q2 2024
- [ ] Distributed search across multiple nodes
- [ ] ScaNN backend support
- [ ] Advanced caching strategies
- [ ] Kubernetes deployment guides
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
- [ ] Add sleep-time-compute and summarize agent! to summarilze the file on computer!
- [ ] Add OpenAI recompute API
### 🌟 Q4 2025
### 🌟 Q3 2024
- [ ] GPU-accelerated embedding computation
- [ ] Approximate distance functions
- [ ] Integration with LangChain/LlamaIndex
- [ ] Visual similarity search
## 💬 Community
Join our growing community of researchers and engineers!
- 🐦 **Twitter**: [@LeannAI](https://twitter.com/LeannAI)
- 💬 **Discord**: [Join our server](https://discord.gg/leann)
- 📧 **Email**: leann@yourcompany.com
- 🐙 **GitHub Discussions**: [Ask questions here](https://github.com/yourname/leann/discussions)
- [ ] Query rewrtiting, rerank and expansion
## 📄 License
@@ -297,7 +463,7 @@ MIT License - see [LICENSE](LICENSE) for details.
## 🙏 Acknowledgments
- **Microsoft Research** for the DiskANN algorithm
- **Meta AI** for FAISS and optimization insights
- **Meta AI** for FAISS and optimization insights
- **HuggingFace** for the transformer ecosystem
- **Our amazing contributors** who make this possible
@@ -309,4 +475,5 @@ MIT License - see [LICENSE](LICENSE) for details.
<p align="center">
Made with ❤️ by the Leann team
</p>
</p>

BIN
assets/arch.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 78 KiB

BIN
assets/effects.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 339 KiB

BIN
assets/logo-text.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 818 KiB

BIN
assets/logo.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 276 KiB

82
data/.gitattributes vendored Normal file
View File

@@ -0,0 +1,82 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.lz4 filter=lfs diff=lfs merge=lfs -text
*.mds filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
# Audio files - uncompressed
*.pcm filter=lfs diff=lfs merge=lfs -text
*.sam filter=lfs diff=lfs merge=lfs -text
*.raw filter=lfs diff=lfs merge=lfs -text
# Audio files - compressed
*.aac filter=lfs diff=lfs merge=lfs -text
*.flac filter=lfs diff=lfs merge=lfs -text
*.mp3 filter=lfs diff=lfs merge=lfs -text
*.ogg filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text
# Image files - uncompressed
*.bmp filter=lfs diff=lfs merge=lfs -text
*.gif filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.tiff filter=lfs diff=lfs merge=lfs -text
# Image files - compressed
*.jpg filter=lfs diff=lfs merge=lfs -text
*.jpeg filter=lfs diff=lfs merge=lfs -text
*.webp filter=lfs diff=lfs merge=lfs -text
# Video files - compressed
*.mp4 filter=lfs diff=lfs merge=lfs -text
*.webm filter=lfs diff=lfs merge=lfs -text
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text

44
data/README.md Normal file
View File

@@ -0,0 +1,44 @@
---
license: mit
---
# LEANN-RAG Evaluation Data
This repository contains the necessary data to run the recall evaluation scripts for the [LEANN-RAG](https://huggingface.co/LEANN-RAG) project.
## Dataset Components
This dataset is structured into three main parts:
1. **Pre-built LEANN Indices**:
* `dpr/`: A pre-built index for the DPR dataset.
* `rpj_wiki/`: A pre-built index for the RPJ-Wiki dataset.
These indices were created using the `leann-core` library and are required by the `LeannSearcher`.
2. **Ground Truth Data**:
* `ground_truth/`: Contains the ground truth files (`flat_results_nq_k3.json`) for both the DPR and RPJ-Wiki datasets. These files map queries to the original passage IDs from the Natural Questions benchmark, evaluated using the Contriever model.
3. **Queries**:
* `queries/`: Contains the `nq_open.jsonl` file with the Natural Questions queries used for the evaluation.
## Usage
To use this data, you can download it locally using the `huggingface-hub` library. First, install the library:
```bash
pip install huggingface-hub
```
Then, you can download the entire dataset to a local directory (e.g., `data/`) with the following Python script:
```python
from huggingface_hub import snapshot_download
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir="data"
)
```
This will download all the necessary files into a local `data` folder, preserving the repository structure. The evaluation scripts in the main [LEANN-RAG Space](https://huggingface.co/LEANN-RAG) are configured to work with this data structure.

View File

@@ -2,361 +2,34 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Initializing leann-backend-diskann...\n",
"INFO: Registering backend 'diskann'\n",
"INFO: DiskANN backend loaded successfully\n",
"INFO: LeannBuilder initialized with 'diskann' backend.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ubuntu/LEANN_clean/leann/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO: Computing embeddings for 6 chunks using 'sentence-transformers/all-mpnet-base-v2'...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Batches: 100%|██████████| 1/1 [00:00<00:00, 2.91it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO: Building DiskANN index for 6 vectors with metric Metric.INNER_PRODUCT...\n",
"Using Inner Product search, so need to pre-process base data into temp file. Please ensure there is additional (n*(d+1)*4) bytes for storing pre-processed base vectors, apart from the interim indices created by DiskANN and the final index.\n",
"Pre-processing base file by adding extra coordinate\n",
"✅ DiskANN index built successfully at 'knowledge'\n",
"Writing bin: knowledge_disk.index_max_base_norm.bin\n",
"bin: #pts = 1, #dims = 1, size = 12B\n",
"Finished writing bin.\n",
"Time for preprocessing data for inner product: 0.000172 seconds\n",
"Reading max_norm_of_base from knowledge_disk.index_max_base_norm.bin\n",
"Reading bin file knowledge_disk.index_max_base_norm.bin ...\n",
"Opening bin file knowledge_disk.index_max_base_norm.bin... \n",
"Metadata: #pts = 1, #dims = 1...\n",
"done.\n",
"max_norm_of_base: 1\n",
"! Using prepped_base file at knowledge_prepped_base.bin\n",
"Starting index build: R=32 L=64 Query RAM budget: 4.02653e+09 Indexing ram budget: 8 T: 8\n",
"getting bin metadata\n",
"Time for getting bin metadata: 0.000019 seconds\n",
"Compressing 769-dimensional data into 512 bytes per vector.\n",
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
"Training data with 6 samples loaded.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 256, #dims = 769...\n",
"done.\n",
"PQ pivot file exists. Not generating again\n",
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 4, #dims = 1...\n",
"done.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 256, #dims = 769...\n",
"done.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 769, #dims = 1...\n",
"done.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 513, #dims = 1...\n",
"done.\n",
"Loaded PQ pivot information\n",
"Processing points [0, 6)...done.\n",
"Time for generating quantized data: 0.055587 seconds\n",
"Full index fits in RAM budget, should consume at most 2.03973e-05GiBs, so building in one shot\n",
"L2: Using AVX2 distance computation DistanceL2Float\n",
"Passed, empty search_params while creating index config\n",
"Using only first 6 from file.. \n",
"Starting index build with 6 points... \n",
"0% of index build completed.Starting final cleanup..done. Link time: 0.00011s\n",
"Index built with degree: max:5 avg:5 min:5 count(deg<2):0\n",
"Not saving tags as they are not enabled.\n",
"Time taken for save: 0.000148s.\n",
"Time for building merged vamana index: 0.000836 seconds\n",
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
"Vamana index file size=168\n",
"Opened: knowledge_disk.index, cache_size: 67108864\n",
"medoid: 0B\n",
"max_node_len: 3100B\n",
"nnodes_per_sector: 1B\n",
"# sectors: 6\n",
"Sector #0written\n",
"Finished writing 28672B\n",
"Writing bin: knowledge_disk.index\n",
"bin: #pts = 9, #dims = 1, size = 80B\n",
"Finished writing bin.\n",
"Output disk index file written to knowledge_disk.index\n",
"Finished writing 28672B\n",
"Time for generating disk layout: 0.040268 seconds\n",
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
"Loading base knowledge_prepped_base.bin. #points: 6. #dim: 769.\n",
"Wrote 1 points to sample file: knowledge_sample_data.bin\n",
"Indexing time: 0.0970594\n",
"INFO: Leann metadata saved to knowledge.leann.meta.json\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Opened file : knowledge_disk.index\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ DiskANN index loaded successfully.\n",
"INFO: LeannSearcher initialized with 'diskann' backend using index 'knowledge.leann'.\n",
"Since data is floating point, we assume that it has been appropriately pre-processed (normalization for cosine, and convert-to-l2 by adding extra dimension for MIPS). So we shall invoke an l2 distance function.\n",
"L2: Using AVX2 distance computation DistanceL2Float\n",
"L2: Using AVX2 distance computation DistanceL2Float\n",
"Before index load\n",
"Reading bin file knowledge_pq_compressed.bin ...\n",
"Opening bin file knowledge_pq_compressed.bin... \n",
"Metadata: #pts = 6, #dims = 512...\n",
"done.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 4, #dims = 1...\n",
"done.\n",
"Offsets: 4096 791560 794644 796704\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 256, #dims = 769...\n",
"done.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 769, #dims = 1...\n",
"done.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 513, #dims = 1...\n",
"done.\n",
"Loaded PQ Pivots: #ctrs: 256, #dims: 769, #chunks: 512\n",
"Loaded PQ centroids and in-memory compressed vectors. #points: 6 #dim: 769 #aligned_dim: 776 #chunks: 512\n",
"Loading index metadata from knowledge_disk.index\n",
"Disk-Index File Meta-data: # nodes per sector: 1, max node len (bytes): 3100, max node degree: 5\n",
"Disk-Index Meta: nodes per sector: 1, max node len: 3100, max node degree: 5\n",
"Setting up thread-specific contexts for nthreads: 8\n",
"allocating ctx: 0x7a33f7204000 to thread-id:134367072315200\n",
"allocating ctx: 0x7a33f6805000 to thread-id:134355206802368\n",
"allocating ctx: 0x7a33f5e72000 to thread-id:134355217288000\n",
"allocating ctx: 0x7a33f5e61000 to thread-id:134355227773632\n",
"allocating ctx: 0x7a33f5e50000 to thread-id:134355196316736\n",
"allocating ctx: 0x7a33f5e3f000 to thread-id:134355164859840\n",
"allocating ctx: 0x7a33f5e2e000 to thread-id:134355175345472\n",
"allocating ctx: 0x7a33f5e1d000 to thread-id:134355185831104\n",
"Loading centroid data from medoids vector data of 1 medoid(s)\n",
"Reading bin file knowledge_disk.index_max_base_norm.bin ...\n",
"Opening bin file knowledge_disk.index_max_base_norm.bin... \n",
"Metadata: #pts = 1, #dims = 1...\n",
"done.\n",
"Setting re-scaling factor of base vectors to 1\n",
"load_from_separate_paths done.\n",
"Reading (with alignment) bin file knowledge_sample_data.bin ...Metadata: #pts = 1, #dims = 769, aligned_dim = 776... allocating aligned memory of 3104 bytes... done. Copying data to mem_aligned buffer... done.\n",
"reserve ratio: 1\n",
"Graph traversal completed, hops: 3\n",
"Loading the cache list into memory....done.\n",
"After index load\n",
"INFO: Computing embeddings for 1 chunks using 'sentence-transformers/all-mpnet-base-v2'...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Batches: 100%|██████████| 1/1 [00:00<00:00, 60.54it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running\n",
"INFO: Starting session-level embedding server as a background process...\n",
"INFO: Running command from project root: /home/ubuntu/LEANN_clean/leann\n",
"INFO: Server process started with PID: 424761\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ Embedding server is up and ready for this session.\n",
"[EmbeddingServer LOG]: Initializing leann-backend-diskann...\n",
"[EmbeddingServer LOG]: WARNING: Could not import DiskANN backend: cannot import name '_diskannpy' from partially initialized module 'packages.leann-backend-diskann.leann_backend_diskann' (most likely due to a circular import) (/home/ubuntu/LEANN_clean/leann/packages/leann-backend-diskann/leann_backend_diskann/__init__.py)\n",
"[EmbeddingServer LOG]: INFO: Initializing embedding server thread on port 5555\n",
"[EmbeddingServer LOG]: INFO: Using CUDA device\n",
"[EmbeddingServer LOG]: INFO: Loading model sentence-transformers/all-mpnet-base-v2\n",
"[EmbeddingServer LOG]: INFO: Using FP16 precision with model: sentence-transformers/all-mpnet-base-v2\n",
"[EmbeddingServer LOG]: INFO: Loaded 6 demo documents\n",
"[EmbeddingServer LOG]: INFO: ZMQ ROUTER server listening on port 5555\n",
"[EmbeddingServer LOG]: INFO: Embedding server ready to serve requests\n",
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 3 bytes\n",
"[EmbeddingServer LOG]: INFO: Request for 1 node embeddings: [0]\n",
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 0\n",
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000028 seconds\n",
"[EmbeddingServer LOG]: INFO: Total batch size: 1, max_batch_size: 128\n",
"[EmbeddingServer LOG]: INFO: Processing batch of size 1\n",
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.019294 seconds\n",
"[EmbeddingServer LOG]: Batch size: 1, Sequence length: 256\n",
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000210 seconds\n",
"[EmbeddingServer LOG]: Time taken for embedding (batch): 3.065444 seconds\n",
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.041810 seconds\n",
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000194 seconds\n",
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 3.128073 seconds\n",
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [1, 2, 3, 4, 5]\n",
"[EmbeddingServer LOG]: DEBUG: Node ID range: 1 to 5\n",
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000042 seconds\n",
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001791 seconds\n",
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000112 seconds\n",
"[EmbeddingServer LOG]: Time taken for embedding (batch): 3.674183 seconds\n",
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000372 seconds\n",
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000177 seconds\n",
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 3.677425 seconds\n",
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [3, 4, 2, 1, 0]\n",
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 4\n",
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000030 seconds\n",
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001550 seconds\n",
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000097 seconds\n",
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.009335 seconds\n",
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000154 seconds\n",
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000073 seconds\n",
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.011773 seconds\n",
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [0, 1, 2, 4, 5]\n",
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 5\n",
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000020 seconds\n",
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001041 seconds\n",
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000125 seconds\n",
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.008972 seconds\n",
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000151 seconds\n",
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000048 seconds\n",
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.010853 seconds\n",
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [3, 1, 0, 2, 5]\n",
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 5\n",
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000020 seconds\n",
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001350 seconds\n",
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000088 seconds\n",
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.008869 seconds\n",
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000146 seconds\n",
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000063 seconds\n",
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.011054 seconds\n",
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [0, 2, 3, 4, 5]\n",
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 5\n",
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000022 seconds\n",
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001195 seconds\n",
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000087 seconds\n",
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.008903 seconds\n",
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000145 seconds\n",
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000060 seconds\n",
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.010921 seconds\n",
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [1, 0, 3, 4, 5]\n",
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 5\n",
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000020 seconds\n",
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001188 seconds\n",
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000087 seconds\n",
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.008858 seconds\n",
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000153 seconds\n",
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000052 seconds\n",
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.010886 seconds\n",
"reserve ratio: Score: -0.481 - C++ is a powerful programming language1\n",
"Graph traversal completed, hops: 3\n",
"\n",
"Score: -1.049 - Java is a powerful programming language\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n"
]
}
],
"outputs": [],
"source": [
"from leann.api import LeannBuilder, LeannSearcher\n",
"import leann_backend_diskann\n",
"from leann.api import LeannBuilder, LeannSearcher, LeannChat\n",
"# 1. Build index (no embeddings stored!)\n",
"builder = LeannBuilder(backend_name=\"diskann\")\n",
"builder.add_text(\"Python is a powerful programming language\")\n",
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
"builder.add_text(\"C# is a powerful programming language but it is not very popular\")\n",
"builder.add_text(\"Python is a powerful programming language and it is very popular\")\n",
"builder.add_text(\"Machine learning transforms industries\") \n",
"builder.add_text(\"Neural networks process complex data\")\n",
"builder.add_text(\"Java is a powerful programming language\")\n",
"builder.add_text(\"C++ is a powerful programming language\")\n",
"builder.add_text(\"C# is a powerful programming language\")\n",
"builder.add_text(\"Leann is a great storage saving engine for RAG on your macbook\")\n",
"builder.build_index(\"knowledge.leann\")\n",
"\n",
"# 2. Search with real-time embeddings\n",
"searcher = LeannSearcher(\"knowledge.leann\")\n",
"results = searcher.search(\"C++ programming languages\", top_k=2,recompute_beighbor_embeddings=True)\n",
"results = searcher.search(\"programming languages\", top_k=2, recompute_beighbor_embeddings=True)\n",
"print(results)\n",
"\n",
"for result in results:\n",
" print(f\"Score: {result['score']:.3f} - {result['text']}\")"
"llm_config = {\"type\": \"ollama\", \"model\": \"qwen3:8b\"}\n",
"\n",
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
"\n",
"response = chat.ask(\n",
" \"Compare the two retrieved programming languages and say which one is more popular today. Respond in a single well-formed sentence.\",\n",
" top_k=2,\n",
" recompute_beighbor_embeddings=True,\n",
")\n",
"print(response)"
]
}
],
@@ -376,7 +49,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
"version": "3.11.12"
}
},
"nbformat": 4,

View File

@@ -0,0 +1,335 @@
#!/usr/bin/env python3
"""
Memory comparison between Faiss HNSW and LEANN HNSW backend
"""
import logging
import os
import sys
import time
import psutil
import gc
import subprocess
from pathlib import Path
from llama_index.core.node_parser import SentenceSplitter
# Setup logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
def get_memory_usage():
"""Get current memory usage in MB"""
process = psutil.Process()
return process.memory_info().rss / 1024 / 1024
def print_memory_stats(stage: str, start_mem: float):
"""Print memory statistics"""
current_mem = get_memory_usage()
diff = current_mem - start_mem
print(f"[{stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
return current_mem
class MemoryTracker:
def __init__(self, name: str):
self.name = name
self.start_mem = get_memory_usage()
self.stages = []
def checkpoint(self, stage: str):
current_mem = print_memory_stats(f"{self.name} - {stage}", self.start_mem)
self.stages.append((stage, current_mem))
return current_mem
def summary(self):
print(f"\n=== {self.name} Memory Summary ===")
for stage, mem in self.stages:
print(f"{stage}: {mem:.1f} MB")
peak_mem = max(mem for _, mem in self.stages)
print(f"Peak Memory: {peak_mem:.1f} MB")
print(f"Total Memory Increase: {peak_mem - self.start_mem:.1f} MB")
return peak_mem
def test_faiss_hnsw():
"""Test Faiss HNSW Vector Store in subprocess"""
print("\n" + "=" * 50)
print("TESTING FAISS HNSW VECTOR STORE")
print("=" * 50)
try:
result = subprocess.run(
[sys.executable, "examples/faiss_only.py"],
capture_output=True,
text=True,
timeout=300,
)
print(result.stdout)
if result.stderr:
print("Stderr:", result.stderr)
if result.returncode != 0:
return {
"peak_memory": float("inf"),
"error": f"Process failed with code {result.returncode}",
}
# Parse peak memory from output
lines = result.stdout.split("\n")
peak_memory = 0.0
for line in lines:
if "Peak Memory:" in line:
peak_memory = float(
line.split("Peak Memory:")[1].split("MB")[0].strip()
)
return {"peak_memory": peak_memory}
except Exception as e:
return {
"peak_memory": float("inf"),
"error": str(e),
}
def test_leann_hnsw():
"""Test LEANN HNSW Search Memory (load existing index)"""
print("\n" + "=" * 50)
print("TESTING LEANN HNSW SEARCH MEMORY")
print("=" * 50)
tracker = MemoryTracker("LEANN HNSW Search")
# Import and setup
tracker.checkpoint("Initial")
from leann.api import LeannSearcher
tracker.checkpoint("After imports")
from llama_index.core import SimpleDirectoryReader
from leann.api import LeannBuilder, LeannSearcher
# Load and parse documents
documents = SimpleDirectoryReader(
"examples/data",
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data()
tracker.checkpoint("After document loading")
# Parse into chunks
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
)
all_texts = []
for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
tracker.checkpoint("After text chunking")
# Build LEANN index
INDEX_DIR = Path("./test_leann_comparison")
INDEX_PATH = str(INDEX_DIR / "comparison.leann")
# Check if index already exists
if os.path.exists(INDEX_PATH + ".meta.json"):
print("Loading existing LEANN HNSW index...")
tracker.checkpoint("After loading existing index")
else:
print("Building new LEANN HNSW index...")
# Clean up previous index
import shutil
if INDEX_DIR.exists():
shutil.rmtree(INDEX_DIR)
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1,
)
tracker.checkpoint("After builder setup")
print("Building LEANN HNSW index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(INDEX_PATH)
del builder
gc.collect()
tracker.checkpoint("After index building")
# Find existing LEANN index
index_paths = [
"./test_leann_comparison/comparison.leann",
]
index_path = None
for path in index_paths:
if os.path.exists(path + ".meta.json"):
index_path = path
break
if not index_path:
print("❌ LEANN index not found. Please build it first")
return {"peak_memory": float("inf"), "error": "Index not found"}
# Measure runtime memory overhead
print("\nMeasuring runtime memory overhead...")
runtime_start_mem = get_memory_usage()
print(f"Before load memory: {runtime_start_mem:.1f} MB")
tracker.checkpoint("Before load memory")
# Load searcher
searcher = LeannSearcher(index_path)
tracker.checkpoint("After searcher loading")
print("Running search queries...")
queries = [
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
"What is LEANN and how does it work?",
"华为诺亚方舟实验室的主要研究内容",
]
for i, query in enumerate(queries):
start_time = time.time()
# Use same parameters as Faiss: top_k=20, ef=120 (complexity parameter)
_ = searcher.search(query, top_k=20, ef=120)
query_time = time.time() - start_time
print(f"Query {i + 1} time: {query_time:.3f}s")
tracker.checkpoint(f"After query {i + 1}")
runtime_end_mem = get_memory_usage()
runtime_overhead = runtime_end_mem - runtime_start_mem
peak_memory = tracker.summary()
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
# Get storage size before cleanup
storage_size = 0
INDEX_DIR = Path(index_path).parent
if INDEX_DIR.exists():
total_size = 0
for dirpath, _, filenames in os.walk(str(INDEX_DIR)):
for filename in filenames:
# Only count actual index files, skip text data and backups
if filename.endswith((".old", ".tmp", ".bak", ".jsonl", ".json")):
continue
# Count .index, .idx, .map files (actual index structures)
if filename.endswith((".index", ".idx", ".map")):
filepath = os.path.join(dirpath, filename)
total_size += os.path.getsize(filepath)
storage_size = total_size / (1024 * 1024) # Convert to MB
# Clean up
del searcher
gc.collect()
return {
"peak_memory": peak_memory,
"storage_size": storage_size,
}
def main():
"""Run comparison tests"""
print("Storage + Search Memory Comparison: Faiss HNSW vs LEANN HNSW")
print("=" * 60)
# Test Faiss HNSW
faiss_results = test_faiss_hnsw()
# Force garbage collection
gc.collect()
time.sleep(2)
# Test LEANN HNSW
leann_results = test_leann_hnsw()
# Final comparison
print("\n" + "=" * 60)
print("STORAGE + SEARCH MEMORY COMPARISON")
print("=" * 60)
# Get storage sizes
faiss_storage_size = 0
leann_storage_size = leann_results.get("storage_size", 0)
# Get Faiss storage size using Python
if os.path.exists("./storage_faiss"):
total_size = 0
for dirpath, _, filenames in os.walk("./storage_faiss"):
for filename in filenames:
filepath = os.path.join(dirpath, filename)
total_size += os.path.getsize(filepath)
faiss_storage_size = total_size / (1024 * 1024) # Convert to MB
print("Faiss HNSW:")
if "error" in faiss_results:
print(f" ❌ Failed: {faiss_results['error']}")
else:
print(f" Search Memory: {faiss_results['peak_memory']:.1f} MB")
print(f" Storage Size: {faiss_storage_size:.1f} MB")
print("\nLEANN HNSW:")
if "error" in leann_results:
print(f" ❌ Failed: {leann_results['error']}")
else:
print(f" Search Memory: {leann_results['peak_memory']:.1f} MB")
print(f" Storage Size: {leann_storage_size:.1f} MB")
# Calculate improvements only if both tests succeeded
if "error" not in faiss_results and "error" not in leann_results:
memory_ratio = faiss_results["peak_memory"] / leann_results["peak_memory"]
print("\nLEANN vs Faiss Performance:")
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
print(
f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)"
)
# Storage comparison
if leann_storage_size > faiss_storage_size:
storage_ratio = leann_storage_size / faiss_storage_size
print(
f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)"
)
elif faiss_storage_size > leann_storage_size:
storage_ratio = faiss_storage_size / leann_storage_size
print(
f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)"
)
else:
print(" Storage Size: similar")
else:
if "error" not in leann_results:
print("\n✅ LEANN HNSW completed successfully!")
print(f"📊 Search Memory: {leann_results['peak_memory']:.1f} MB")
print(f"📊 Storage Size: {leann_storage_size:.1f} MB")
if "error" not in faiss_results:
print("\n✅ Faiss HNSW completed successfully!")
print(f"📊 Search Memory: {faiss_results['peak_memory']:.1f} MB")
print(f"📊 Storage Size: {faiss_storage_size:.1f} MB")
if __name__ == "__main__":
main()

View File

Binary file not shown.

View File

Binary file not shown.

View File

File diff suppressed because it is too large Load Diff

82
examples/data/README.md Normal file
View File

@@ -0,0 +1,82 @@
# 盘古之殇:华为诺亚盘古大模型研发历程的心酸与黑暗
各位好,
我是一名盘古大模型团队,华为诺亚方舟实验室的员工。
首先为自证身份,列举一些细节:
1. 现诺亚主任,前算法应用部部长,后改名为小模型实验室的主任王云鹤。前诺亚主任:姚骏(大家称姚老师)。几个实验室主任:唐睿明(明哥,明队,已离职),尚利峰,张维(维哥),郝建业(郝老师),刘武龙(称呼为武龙所)等。其他骨干成员和专家陆续有很多人离职。
2. 我们隶属于“四野”这个组织。四野下属有许多纵队,基础语言大模型是四纵。王云鹤的小模型是十六纵队。我们参加过苏州的集结,有各种月份的时间节点。在苏州攻关会颁发任务令,需要在节点前达成目标。苏州集结会把各地的人员都集中在苏州研究所,平常住宾馆,比如在甪直的酒店,与家人孩子天各一方。
3. 在苏州集结的时候周六默认上班,非常辛苦,不过周六有下午茶,有一次还有小龙虾。在苏州研究所的工位搬迁过一次,从一栋楼换到了另一栋。苏州研究所楼栋都是欧式装修,门口有大坡,里面景色很不错。去苏州集结一般至少要去一周,甚至更久,多的人甚至一两个月都回不了家。
4. 诺亚曾经传说是研究型的但是来了之后因为在四野做大模型项目项目成员完全变成了交付型的且充满了例会评审汇报。很多时候做实验都要申请。团队需要对接终端小艺华为云ICT等诸多业务线交付压力不小。
5. 诺亚研发的盘古模型早期内部代号叫做“盘古智子”一开始只有内部需要申请试用的网页版到后续迫于压力在welink上接入和公测开放。
这些天发生关于质疑盘古大模型抄袭千问的事情闹的沸沸扬扬。作为一个盘古团队的成员,我最近夜夜辗转反侧,难以入眠。盘古的品牌受到如此大的影响,一方面,我自私的为我的职业发展担忧,也为自己过去的努力工作感到不值。另一方面,由于有人开始揭露这些事情我内心又感到大快人心。在多少个日日夜夜,我们对内部某些人一次次靠着造假而又获得了无数利益的行为咬牙切齿而又无能为力。这种压抑和羞辱也逐渐消磨了我对华为的感情,让我在这里的时日逐渐浑浑噩噩,迷茫无措,时常怀疑自己的人生和自我价值。
我承认我是一个懦弱的人,作为一个小小的打工人,我不仅不敢和王云鹤等内部手眼通天的人做对,更不敢和华为这样的庞然大物做对。我很怕失去我的工作,毕竟我也有家人和孩子,所以我打心眼里很佩服揭露者。但是,看到内部还在试图洗地掩盖事实,蒙蔽公众的时候,我实在不能容忍了。我也希望勇敢一次,顺从自己本心。就算自损八百,我也希望能伤敌一千。我决定把我在这里的所见所闻(部分来自于同事口述)公布出来,关于盘古大模型的“传奇故事”:
华为确实主要在昇腾卡上训练大模型小模型实验室有不少英伟达的卡他们之前也会用来训练后面转移到昇腾。曾经我被华为“打造世界第二选择”的决心而折服我本身也曾经对华为有深厚的感情。我们陪着昇腾一步步摸爬滚打从充满bug到现在能训出模型付出了巨大的心血和代价。
最初我们的算力非常有限在910A上训练模型。那会只支持fp16训练的稳定性远不如bf16。盘古的moe开始很早23年就主要是训练38Bmoe模型和后续的71B dense模型。71B的dense模型通过扩增变成了第一代的135Bdense模型后面主力模型也逐渐在910B上训练。
71B和135B模型都有一个巨大的硬伤就是tokenizer。当时使用的tokenizer编码效率极低每个单个的符号数字空格乃至汉字都会占用一个token。可想而知这会非常浪费算力且使得模型的效果很差。这时候小模型实验室正好有个自己训的词表。姚老师当时怀疑是不是模型的tokenizer不好虽然事后来看他的怀疑是无疑正确的于是就决定让71B和135B换tokenizer因为小模型实验室曾经尝试过。团队缝合了两个tokenizer开始了tokenizer的更换。71B模型的更换失败了而135B因为采用了更精细的embedding初始化策略续训了至少1T的数据后词表总算更换成功但可想而知效果并不会变好。
于此同期阿里和智谱等国内其他公司在GPU上训练且已经摸索出了正确的方法盘古和竞品的差距越来越大。内部一个230B从头训练的dense模型又因为各种原因训练失败导致项目的状况几乎陷入绝境。面临几个节点的压力以及内部对盘古的强烈质疑时团队的士气低迷到了极点。团队在算力极其有限的时候做出了很多努力和挣扎。比如团队偶然发现当时的38B moe并没有预期moe的效果。于是去掉了moe参数还原为了13B的dense模型。由于38B的moe源自很早的pangu alpha 13B架构相对落后团队进行了一系列的操作比如切换绝对位置编码到rope去掉bias切换为rmsnorm。同时鉴于tokenizer的一些失败和换词表的经验这个模型的词表也更换为了王云鹤的小模型实验室7B模型所使用的词表。后面这个13B模型进行了扩增续训变成了第二代38B dense模型在几个月内这个模型都是主要的盘古中档位模型曾经具有一定的竞争力。但是由于更大的135B模型架构落后且更换词表模型损伤巨大后续分析发现当时更换的缝合词表有更严重的bug续训后也与千问等当时国内领先模型存在很大差距。这时由于内部的质疑声和领导的压力也越来越大。团队的状态几乎陷入了绝境。
在这种情况下王云鹤和他的小模型实验室出手了。他们声称是从旧的135B参数继承改造而来通过训练短短的几百B数据各项指标平均提升了十个点左右。实际上这就是他们套壳应用到大模型的第一次杰作。华为的外行领导内行使得领导完全对于这种扯淡的事情没有概念他们只会觉得肯定是有什么算法创新。经过内部的分析他们实际上是使用Qwen 1.5 110B续训而来通过加层扩增ffn维度添加盘古pi论文的一些机制得来凑够了大概135B的参数。实际上旧的135B有107层而这个模型只有82层各种配置也都不一样。新的来路不明的135B训练完很多参数的分布也和Qwen 110B几乎一模一样。连模型代码的类名当时都是Qwen甚至懒得改名。后续这个模型就是所谓的135B V2。而这个模型当时也提供给了很多下游甚至包括外部客户。
这件事对于我们这些认真诚实做事的同事们带来了巨大的冲击内部很多人其实都知道这件事甚至包括终端和华为云。我们都戏称以后别叫盘古模型了叫千古吧。当时团队成员就想向bcg举报了毕竟这已经是重大的业务造假了。但是后面据说被领导拦了下来因为更高级别的领导比如姚老师以及可能熊总和查老其实后面也知道了但是并不管因为通过套壳拿出好的结果对他们也是有利的。这件事使得当时团队几位最强的同事开始心灰意冷离职跑路也逐渐成为挂在嘴边的事。
此时盘古似乎迎来了转机。由于前面所述的这些盘古模型基本都是续训和改造而来当时诺亚完全没有掌握从头训练的技术何况还是在昇腾的NPU上进行训练。在当时团队的核心成员的极力争取下盘古开始了第三代模型的训练付出了巨大的努力后在数据架构和训练算法方面都与业界逐渐接轨而这其中的艰辛和小模型实验室的人一点关系都没有。
一开始团队成员毫无信心只从一个13B的模型开始训练但是后面发现效果还不错于是这个模型后续再次进行了一次参数扩增变成了第三代的38B代号38B V3。想必很多产品线的兄弟都对这个模型很熟悉。当时这个模型的tokenizer是基于llama的词表进行扩展的也是业界常见的做法。而当时王云鹤的实验室做出来了另一个词表也就是后续pangu系列的词表。当时两个词表还被迫进行了一次赛马最终没有明显的好坏结论。于是领导当即决定应该统一词表使用王云鹤他们的。于是在后续从头训练的135B V3也就是对外的Pangu Ultra便是采用了这个tokenizer。这也解释了很多使用我们模型的兄弟的疑惑为什么当时同为V3代的两个不同档位的模型会使用不同的tokenizer。
我们打心眼里觉得135B V3是我们四纵团队当时的骄傲。这是第一个真正意义上的华为全栈自研正经从头训练的千亿级别的模型且效果与24年同期竞品可比的。写到这里我已经热泪盈眶太不容易了。当时为了稳定训练团队做了大量实验对比并且多次在模型梯度出现异常的时候进行及时回退重启。这个模型真正做到了后面技术报告所说的训练全程没有一个loss spike。我们克服了不知道多少困难我们做到了我们愿用生命和荣誉保证这个模型训练的真实性。多少个凌晨我们为了它的训练而不眠。在被内部心声骂的一文不值的时候我们有多么不甘有多少的委屈我们挺住了。
我们这帮人是真的在为打磨国产算力底座燃烧自己的青春啊……客居他乡,我们放弃了家庭,放弃了假期,放弃了健康,放弃了娱乐,抛头颅洒热血,其中的艰辛与困苦,寥寥数笔不足以概括其万一。在各种动员大会上,当时口号中喊出的盘古必胜,华为必胜,我们心里是真的深深被感动。
然而我们的所有辛苦的成果经常被小模型实验室轻飘飘的拿走了。数据直接要走。代码直接要走还要求我们配合适配到能一键运行。我们当时戏称小模型实验室为点鼠标实验室。我们付出辛苦他们取得荣耀。果然应了那句话你在负重前行是因为有人替你岁月静好。在这种情况下越来越多的战友再也坚持不下去了选择了离开。看到身边那些优秀的同事一个个离职我的内心又感叹又难过。在这种作战一样的环境下我们比起同事来说更像是战友。他们在技术上也有无数值得我学习的地方堪称良师。看到他们去了诸如字节SeedDeepseek月之暗面腾讯和快手等等很多出色的团队我打心眼里为他们高兴和祝福脱离了这个辛苦却肮脏的地方。我至今还对一位离职同事的话记忆犹新ta说“来这里是我技术生涯中的耻辱在这里再呆每一天都是浪费生命”。话虽难听却让我无言以对。我担心我自己技术方面的积累不足以及没法适应互联网公司高淘汰的环境让我多次想离职的心始终没有迈出这一步。
盘古除了dense模型后续也启动了moe的探索。一开始训练的是一个224B的moe模型。而与之平行的小模型实验室也开启了第二次主要的套壳行动次要的插曲可能还包括一些别的模型比如math模型即这次流传甚广的pangu pro moe 72B。这个模型内部自称是从小模型实验室的7B扩增上来的就算如此这也与技术报告不符何况是套壳qwen 2.5的14b续训。还记得他们训了没几天内部的评测就立刻追上了当时的38B V3。AI系统实验室很多兄弟因为需要适配模型都知道他们的套壳行动只是迫于各种原因无法伸张正义。实际上对于后续训了很久很久的这个模型Honestagi能够分析出这个量级的相似性我已经很诧异了因为这个模型为了续训洗参数所付出的算力甚至早就足够从头训一个同档位的模型了。听同事说他们为了洗掉千问的水印采取了不少办法甚至包括故意训了脏数据。这也为学术界研究模型血缘提供了一个前所未有的特殊模范吧。以后新的血缘方法提出可以拿出来溜溜。
24年底和25年初在Deepseek v3和r1发布之后由于其惊艳的技术水平团队受到了巨大的冲击也受到了更大的质疑。于是为了紧跟潮流盘古模仿Deepseek的模型尺寸开启了718B moe的训练。这个时候小模型实验室再次出手了。他们选择了套壳Deepseekv3续训。他们通过冻住Deepseek加载的参数进行训练。连任务加载ckpt的目录都是deepseekv3改都不改何其嚣张与之相反一些有真正技术信仰的同事在从头训练另一个718B的moe。但其中出现了各种各样的问题。但是很显然这个模型怎么可能比直接套壳的好呢如果不是团队leader坚持早就被叫停了。
华为的流程管理之繁重,严重拖累了大模型的研发节奏,例如版本管理,模型血缘,各种流程化,各种可追溯。讽刺的是,小模型实验室的模型似乎从来不受这些流程的约束,想套壳就套壳,想续训就续训,算力源源不断的伸手拿走。这种强烈到近乎魔幻的对比,说明了当前流程管理的情况:只许州官放火,不许百姓点灯。何其可笑?何其可悲?何其可恶?何其可耻!
HonestAGI的事情出来后内部让大家不停的研讨分析如何公关和“回应”。诚然这个原文的分析也许不够有力给了王云鹤与小模型实验室他们狡辩和颠倒黑白的机会。为此这两天我内心感到作呕时时怀疑自己的人生意义以及苍天无眼。我不奉陪了我要离职了同时我也在申请从盘古部分技术报告的作者名单中移除。曾经在这些技术报告上署名是我一生都无法抹除的污点。当时我没想到他们竟然猖狂到敢开源。我没想到他们敢如此愚弄世人大肆宣发。当时我也许是存了侥幸心理没有拒绝署名。我相信很多扎实做事的战友也只是被迫上了贼船或者不知情。但这件事已经无法挽回我希望我的余生能够坚持扎实做真正有意义的事为我当时的软弱和不坚定赎罪。
深夜写到这里,我已经泪流满面,泣不成声。还记得一些出色的同事离职时,我苦笑问他们要不要发个长长的心声惯例帖,揭露一下现状。对方说:不了,浪费时间,而且我也怕揭露出来你们过的更糟。我当时一下黯然神伤,因为曾经共同为了理想奋斗过的战友已经彻底对华为彻底灰心了。当时大家调侃,我们用着当年共产党的小米加步枪,组织却有着堪比当年国民党的作风。
曾几何时,我为我们用着小米加步枪打败洋枪洋炮而自豪。
现在,我累了,我想投降。
其实时至今日我还是真心希望华为能认真吸取教训能做好盘古把盘古做到世界一流把昇腾变成英伟达的水平。内部的劣币驱逐良币使得诺亚乃至华为在短时间内急剧流失了大量出色的大模型人才。相信他们也正在如Deepseek等各个团队闪耀着施展着他们的抱负才华为中美在AI的激烈竞赛中奉献力量。我时常感叹华为不是没有人才而是根本不知道怎么留住人才。如果给这些人合适的环境合适的资源更少的枷锁更少的政治斗争盘古何愁不成
最后:我以生命,人格和荣誉发誓,我写的以上所有内容均为真实(至少在我有限的认知范围内)。我没有那么高的技术水平以及机会去做详尽扎实的分析,也不敢直接用内部记录举证,怕因为信息安全抓到。但是我相信我很多曾经的战友,会为我作证。在华为内部的兄弟,包括我们曾经服务过的产品线兄弟们,相信本文的无数细节能和你们的印象对照,印证我的说法。你们可能也曾经被蒙骗,但这些残酷的真相不会被尘封。我们奋战过的痕迹,也不应该被扭曲和埋葬。
写了这么多,某些人肯定想把我找出来,抹杀掉。公司搞不好也想让我噤声乃至追责。如果真的这样,我,乃至我的家人的人身乃至生命安全可能都会受到威胁。为了自我保护,我近期每天会跟大家报平安。
如果我消失了就当是我为了真理和理想为了华为乃至中国能够更好地发展算力和AI而牺牲了吧我愿埋葬于那片曾经奋斗过的地方。
诺亚,再见
2025年7月6日凌晨 写于深圳
---
各位好,
感谢大家的关心与祝福。我目前暂时安全,但公司应该在进行排查与某些名单收集,后续情况未知。
我补充一些细节,以免某些人继续颠倒黑白。
关于135B V2小模型实验室在迅速地完成套壳并拿完所有套壳带来的好处后比如任务令表彰和及时激励因为不想继续支撑下游应用和模型迭代又把这个烫手山芋甩给了四纵。确实技高一筹直接把四纵的兄弟们拉下水。同事提供过去一个老旧的模型最终拿回了一个当时一个魔改的先进的千问。做大模型的人自己做的模型就像自己孩子一样熟悉不要把别人都当傻子。就像自家儿子出门一趟回来个别人家孩子。
盘古report的署名是不符合学术规范的。例如135B V3有不少有技术贡献的人因为作者名额数量限制劳动成果没有得到应有的回报团队内曾经有不小的意见。这个模型当时是大家智慧和汗水的结晶甚至是团队当时的精神支柱支撑着不少兄弟们继续留在诺亚。所谓的名额限制以及挂名了一些毫无技术贡献的人如一些小模型实验室的人让兄弟们何其心寒。
---
暂时平安。另外,支持我勇于说出真相的战友们 https://github.com/HW-whistleblower/True-Story-of-Pangu/issues/317

View File

@@ -0,0 +1,124 @@
import os
import email
from pathlib import Path
from typing import List, Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
def find_all_messages_directories(root: str = None) -> List[Path]:
"""
Recursively find all 'Messages' directories under the given root.
Returns a list of Path objects.
"""
if root is None:
# Auto-detect user's mail path
home_dir = os.path.expanduser("~")
root = os.path.join(home_dir, "Library", "Mail")
messages_dirs = []
for dirpath, dirnames, filenames in os.walk(root):
if os.path.basename(dirpath) == "Messages":
messages_dirs.append(Path(dirpath))
return messages_dirs
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader with embedded metadata.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self, include_html: bool = False) -> None:
"""
Initialize.
Args:
include_html: Whether to include HTML content in the email body (default: False)
"""
self.include_html = include_html
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
count = 0
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
if count >= max_count:
break
if filename.endswith(".emlx"):
filepath = os.path.join(dirpath, filename)
try:
# Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split('\n', 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get('Subject', 'No Subject')
from_addr = msg.get('From', 'Unknown')
to_addr = msg.get('To', 'Unknown')
date = msg.get('Date', 'Unknown')
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
if part.get_content_type() == "text/html" and not self.include_html:
continue
body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
# break
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
# Create document content with metadata embedded in text
doc_content = f"""
[EMAIL METADATA]
File: {filename}
From: {from_addr}
To: {to_addr}
Subject: {subject}
Date: {date}
[END METADATA]
{body}
"""
# No separate metadata - everything is in the text
doc = Document(text=doc_content, metadata={})
docs.append(doc)
count += 1
except Exception as e:
print(f"Error parsing email from {filepath}: {e}")
continue
except Exception as e:
print(f"Error reading file {filepath}: {e}")
continue
print(f"Loaded {len(docs)} email documents")
return docs

View File

@@ -0,0 +1,192 @@
"""
Mbox parser.
Contains simple parser for mbox files.
"""
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
from fsspec import AbstractFileSystem
from llama_index.core.readers.base import BaseReader
from llama_index.core.schema import Document
logger = logging.getLogger(__name__)
class MboxReader(BaseReader):
"""
Mbox parser.
Extract messages from mailbox files.
Returns string including date, subject, sender, receiver and
content for each message.
"""
DEFAULT_MESSAGE_FORMAT: str = (
"Date: {_date}\n"
"From: {_from}\n"
"To: {_to}\n"
"Subject: {_subject}\n"
"Content: {_content}"
)
def __init__(
self,
*args: Any,
max_count: int = 0,
message_format: str = DEFAULT_MESSAGE_FORMAT,
**kwargs: Any,
) -> None:
"""Init params."""
try:
from bs4 import BeautifulSoup # noqa
except ImportError:
raise ImportError(
"`beautifulsoup4` package not found: `pip install beautifulsoup4`"
)
super().__init__(*args, **kwargs)
self.max_count = max_count
self.message_format = message_format
def load_data(
self,
file: Path,
extra_info: Optional[Dict] = None,
fs: Optional[AbstractFileSystem] = None,
) -> List[Document]:
"""Parse file into string."""
# Import required libraries
import mailbox
from email.parser import BytesParser
from email.policy import default
from bs4 import BeautifulSoup
if fs:
logger.warning(
"fs was specified but MboxReader doesn't support loading "
"from fsspec filesystems. Will load from local filesystem instead."
)
i = 0
results: List[str] = []
# Load file using mailbox
bytes_parser = BytesParser(policy=default).parse
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
# Iterate through all messages
for _, _msg in enumerate(mbox):
try:
msg: mailbox.mboxMessage = _msg
# Parse multipart messages
if msg.is_multipart():
for part in msg.walk():
ctype = part.get_content_type()
cdispo = str(part.get("Content-Disposition"))
if "attachment" in cdispo:
print(f"Attachment found: {part.get_filename()}")
if ctype == "text/plain" and "attachment" not in cdispo:
content = part.get_payload(decode=True) # decode
break
# Get plain message payload for non-multipart messages
else:
content = msg.get_payload(decode=True)
# Parse message HTML content and remove unneeded whitespace
soup = BeautifulSoup(content)
stripped_content = " ".join(soup.get_text().split())
# Format message to include date, sender, receiver and subject
msg_string = self.message_format.format(
_date=msg["date"],
_from=msg["from"],
_to=msg["to"],
_subject=msg["subject"],
_content=stripped_content,
)
# Add message string to results
results.append(msg_string)
except Exception as e:
logger.warning(f"Failed to parse message:\n{_msg}\n with exception {e}")
# Increment counter and return if max count is met
i += 1
if self.max_count > 0 and i >= self.max_count:
break
return [Document(text=result, metadata=extra_info or {}) for result in results]
class EmlxMboxReader(MboxReader):
"""
EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.
Extends MboxReader to work with Apple Mail's .emlx format by:
1. Reading .emlx files from a directory
2. Converting them to mbox format in memory
3. Using the parent MboxReader's parsing logic
"""
def load_data(
self,
directory: Path,
extra_info: Optional[Dict] = None,
fs: Optional[AbstractFileSystem] = None,
) -> List[Document]:
"""Parse .emlx files from directory into strings using MboxReader logic."""
import tempfile
import os
if fs:
logger.warning(
"fs was specified but EmlxMboxReader doesn't support loading "
"from fsspec filesystems. Will load from local filesystem instead."
)
# Find all .emlx files in the directory
emlx_files = list(directory.glob("*.emlx"))
logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")
if not emlx_files:
logger.warning(f"No .emlx files found in {directory}")
return []
# Create a temporary mbox file
with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox:
temp_mbox_path = temp_mbox.name
# Convert .emlx files to mbox format
for emlx_file in emlx_files:
try:
# Read the .emlx file
with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# .emlx format: first line is length, rest is email content
lines = content.split('\n', 1)
if len(lines) >= 2:
email_content = lines[1] # Skip the length line
# Write to mbox format (each message starts with "From " and ends with blank line)
temp_mbox.write(f"From {emlx_file.name} {email_content}\n\n")
except Exception as e:
logger.warning(f"Failed to process {emlx_file}: {e}")
continue
# Close the temporary file so MboxReader can read it
temp_mbox.close()
try:
# Use the parent MboxReader's logic to parse the mbox file
return super().load_data(Path(temp_mbox_path), extra_info, fs)
finally:
# Clean up temporary file
try:
os.unlink(temp_mbox_path)
except:
pass

151
examples/faiss_only.py Normal file
View File

@@ -0,0 +1,151 @@
#!/usr/bin/env python3
"""Test only Faiss HNSW"""
import sys
import time
import psutil
import gc
import os
def get_memory_usage():
process = psutil.Process()
return process.memory_info().rss / 1024 / 1024
class MemoryTracker:
def __init__(self, name: str):
self.name = name
self.start_mem = get_memory_usage()
self.stages = []
def checkpoint(self, stage: str):
current_mem = get_memory_usage()
diff = current_mem - self.start_mem
print(f"[{self.name} - {stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
self.stages.append((stage, current_mem))
return current_mem
def summary(self):
peak_mem = max(mem for _, mem in self.stages)
print(f"Peak Memory: {peak_mem:.1f} MB")
return peak_mem
def main():
try:
import faiss
except ImportError:
print("Faiss is not installed.")
print("Please install it with `uv pip install faiss-cpu`")
sys.exit(1)
from llama_index.core import (
SimpleDirectoryReader,
VectorStoreIndex,
StorageContext,
Settings,
node_parser,
Document,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
tracker = MemoryTracker("Faiss HNSW")
tracker.checkpoint("Initial")
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
Settings.embed_model = embed_model
tracker.checkpoint("After embedding model setup")
d = 768
faiss_index = faiss.IndexHNSWFlat(d, 32)
faiss_index.hnsw.efConstruction = 64
tracker.checkpoint("After Faiss index creation")
documents = SimpleDirectoryReader(
"examples/data",
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data()
tracker.checkpoint("After document loading")
# Parse into chunks using the same splitter as LEANN
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
)
tracker.checkpoint("After text splitter setup")
# Check if index already exists and try to load it
index_loaded = False
if os.path.exists("./storage_faiss"):
print("Loading existing Faiss HNSW index...")
try:
# Use the correct Faiss loading pattern from the example
vector_store = FaissVectorStore.from_persist_dir("./storage_faiss")
storage_context = StorageContext.from_defaults(
vector_store=vector_store, persist_dir="./storage_faiss"
)
from llama_index.core import load_index_from_storage
index = load_index_from_storage(storage_context=storage_context)
print(f"Index loaded from ./storage_faiss")
tracker.checkpoint("After loading existing index")
index_loaded = True
except Exception as e:
print(f"Failed to load existing index: {e}")
print("Cleaning up corrupted index and building new one...")
# Clean up corrupted index
import shutil
if os.path.exists("./storage_faiss"):
shutil.rmtree("./storage_faiss")
if not index_loaded:
print("Building new Faiss HNSW index...")
# Use the correct Faiss building pattern from the example
vector_store = FaissVectorStore(faiss_index=faiss_index)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(
documents,
storage_context=storage_context,
transformations=[node_parser]
)
tracker.checkpoint("After index building")
# Save index to disk using the correct pattern
index.storage_context.persist(persist_dir="./storage_faiss")
tracker.checkpoint("After index saving")
# Measure runtime memory overhead
print("\nMeasuring runtime memory overhead...")
runtime_start_mem = get_memory_usage()
print(f"Before load memory: {runtime_start_mem:.1f} MB")
tracker.checkpoint("Before load memory")
query_engine = index.as_query_engine(similarity_top_k=20)
queries = [
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
"What is LEANN and how does it work?",
"华为诺亚方舟实验室的主要研究内容",
]
for i, query in enumerate(queries):
start_time = time.time()
_ = query_engine.query(query)
query_time = time.time() - start_time
print(f"Query {i + 1} time: {query_time:.3f}s")
tracker.checkpoint(f"After query {i + 1}")
runtime_end_mem = get_memory_usage()
runtime_overhead = runtime_end_mem - runtime_start_mem
peak_memory = tracker.summary()
print(f"Peak Memory: {peak_memory:.1f} MB")
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,281 @@
import os
import asyncio
import argparse
try:
import dotenv
dotenv.load_dotenv()
except ModuleNotFoundError:
# python-dotenv is not installed; skip loading environment variables
dotenv = None
from pathlib import Path
from typing import List, Any
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from llama_index.core.node_parser import SentenceSplitter
# dotenv.load_dotenv() # handled above if python-dotenv is available
# Default Chrome profile path
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1):
"""
Create LEANN index from multiple Chrome profile data sources.
Args:
profile_dirs: List of Path objects pointing to Chrome profile directories
index_path: Path to save the LEANN index
max_count: Maximum number of history entries to process per profile
"""
print("Creating LEANN index from multiple Chrome profile data sources...")
# Load documents using ChromeHistoryReader from history_data
from history_data.history import ChromeHistoryReader
reader = ChromeHistoryReader()
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
all_documents = []
total_processed = 0
# Process each Chrome profile directory
for i, profile_dir in enumerate(profile_dirs):
print(f"\nProcessing Chrome profile {i+1}/{len(profile_dirs)}: {profile_dir}")
try:
documents = reader.load_data(
chrome_profile_path=str(profile_dir),
max_count=max_count
)
if documents:
print(f"Loaded {len(documents)} history documents from {profile_dir}")
all_documents.extend(documents)
total_processed += len(documents)
# Check if we've reached the max count
if max_count > 0 and total_processed >= max_count:
print(f"Reached max count of {max_count} documents")
break
else:
print(f"No documents loaded from {profile_dir}")
except Exception as e:
print(f"Error processing {profile_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
return None
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in all_documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
print(f"Adding {len(all_texts)} history chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
def create_leann_index(profile_path: str = None, index_path: str = "chrome_history_index.leann", max_count: int = 1000):
"""
Create LEANN index from Chrome history data.
Args:
profile_path: Path to the Chrome profile directory (optional, uses default if None)
index_path: Path to save the LEANN index
max_count: Maximum number of history entries to process
"""
print("Creating LEANN index from Chrome history data...")
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Load documents using ChromeHistoryReader from history_data
from history_data.history import ChromeHistoryReader
reader = ChromeHistoryReader()
documents = reader.load_data(
chrome_profile_path=profile_path,
max_count=max_count
)
if not documents:
print("No documents loaded. Exiting.")
return None
print(f"Loaded {len(documents)} history documents")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
print(f"Adding {len(all_texts)} history chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
async def query_leann_index(index_path: str, query: str):
"""
Query the LEANN index.
Args:
index_path: Path to the LEANN index
query: The query string
"""
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path)
print(f"You: {query}")
chat_response = chat.ask(
query,
top_k=10,
recompute_beighbor_embeddings=True,
complexity=32,
beam_width=1,
llm_config={
"type": "openai",
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
llm_kwargs={
"temperature": 0.0,
"max_tokens": 1000
}
)
print(f"Leann: {chat_response}")
async def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
parser.add_argument('--index-dir', type=str, default="./chrome_history_index_leann_test",
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
parser.add_argument('--max-entries', type=int, default=1000,
help='Maximum number of history entries to process (default: 1000)')
parser.add_argument('--query', type=str, default=None,
help='Single query to run (default: runs example queries)')
parser.add_argument('--auto-find-profiles', action='store_true', default=True,
help='Automatically find all Chrome profiles (default: True)')
args = parser.parse_args()
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")
print(f"Using Chrome profile: {args.chrome_profile}")
print(f"Index directory: {INDEX_DIR}")
print(f"Max entries: {args.max_entries}")
# Find Chrome profile directories
from history_data.history import ChromeHistoryReader
if args.auto_find_profiles:
profile_dirs = ChromeHistoryReader.find_chrome_profiles()
if not profile_dirs:
print("No Chrome profiles found automatically. Exiting.")
return
else:
# Use single specified profile
profile_path = Path(args.chrome_profile)
if not profile_path.exists():
print(f"Chrome profile not found: {profile_path}")
return
profile_dirs = [profile_path]
# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_chrome_profiles(profile_dirs, INDEX_PATH, args.max_entries)
if index_path:
if args.query:
# Run single query
await query_leann_index(index_path, args.query)
else:
# Example queries
queries = [
"What websites did I visit about machine learning?",
"Find my search history about programming"
]
for query in queries:
print("\n" + "="*60)
await query_leann_index(index_path, query)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,3 @@
from .history import ChromeHistoryReader
__all__ = ['ChromeHistoryReader']

View File

@@ -0,0 +1,176 @@
import sqlite3
import os
from pathlib import Path
from typing import List, Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
class ChromeHistoryReader(BaseReader):
"""
Chrome browser history reader that extracts browsing data from SQLite database.
Reads Chrome history from the default Chrome profile location and creates documents
with embedded metadata similar to the email reader structure.
"""
def __init__(self) -> None:
"""Initialize."""
pass
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
"""
Load Chrome history data from the default Chrome profile location.
Args:
input_dir: Not used for Chrome history (kept for compatibility)
**load_kwargs:
max_count (int): Maximum amount of history entries to read.
chrome_profile_path (str): Custom path to Chrome profile directory.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
chrome_profile_path = load_kwargs.get('chrome_profile_path', None)
# Default Chrome profile path on macOS
if chrome_profile_path is None:
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
history_db_path = os.path.join(chrome_profile_path, "History")
if not os.path.exists(history_db_path):
print(f"Chrome history database not found at: {history_db_path}")
return docs
try:
# Connect to the Chrome history database
print(f"Connecting to database: {history_db_path}")
conn = sqlite3.connect(history_db_path)
cursor = conn.cursor()
# Query to get browsing history with metadata (removed created_time column)
query = """
SELECT
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
url,
title,
visit_count,
typed_count,
hidden
FROM urls
ORDER BY last_visit_time DESC
"""
print(f"Executing query on database: {history_db_path}")
cursor.execute(query)
rows = cursor.fetchall()
print(f"Query returned {len(rows)} rows")
count = 0
for row in rows:
if count >= max_count and max_count > 0:
break
last_visit, url, title, visit_count, typed_count, hidden = row
# Create document content with metadata embedded in text
doc_content = f"""
[BROWSING HISTORY METADATA]
URL: {url}
Title: {title}
Last Visit: {last_visit}
Visit Count: {visit_count}
Typed Count: {typed_count}
Hidden: {hidden}
[END METADATA]
Title: {title}
URL: {url}
Last visited: {last_visit}
"""
# Create document with embedded metadata
doc = Document(text=doc_content, metadata={})
docs.append(doc)
count += 1
conn.close()
print(f"Loaded {len(docs)} Chrome history documents")
except Exception as e:
print(f"Error reading Chrome history: {e}")
return docs
return docs
@staticmethod
def find_chrome_profiles() -> List[Path]:
"""
Find all Chrome profile directories.
Returns:
List of Path objects pointing to Chrome profile directories
"""
chrome_base_path = Path(os.path.expanduser("~/Library/Application Support/Google/Chrome"))
profile_dirs = []
if not chrome_base_path.exists():
print(f"Chrome directory not found at: {chrome_base_path}")
return profile_dirs
# Find all profile directories
for profile_dir in chrome_base_path.iterdir():
if profile_dir.is_dir() and profile_dir.name != "System Profile":
history_path = profile_dir / "History"
if history_path.exists():
profile_dirs.append(profile_dir)
print(f"Found Chrome profile: {profile_dir}")
print(f"Found {len(profile_dirs)} Chrome profiles")
return profile_dirs
@staticmethod
def export_history_to_file(output_file: str = "chrome_history_export.txt", max_count: int = 1000):
"""
Export Chrome history to a text file using the same SQL query format.
Args:
output_file: Path to the output file
max_count: Maximum number of entries to export
"""
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
history_db_path = os.path.join(chrome_profile_path, "History")
if not os.path.exists(history_db_path):
print(f"Chrome history database not found at: {history_db_path}")
return
try:
conn = sqlite3.connect(history_db_path)
cursor = conn.cursor()
query = """
SELECT
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
url,
title,
visit_count,
typed_count,
hidden
FROM urls
ORDER BY last_visit_time DESC
LIMIT ?
"""
cursor.execute(query, (max_count,))
rows = cursor.fetchall()
with open(output_file, 'w', encoding='utf-8') as f:
for row in rows:
last_visit, url, title, visit_count, typed_count, hidden = row
f.write(f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n")
conn.close()
print(f"Exported {len(rows)} history entries to {output_file}")
except Exception as e:
print(f"Error exporting Chrome history: {e}")

View File

@@ -0,0 +1,720 @@
import json
import os
import re
import subprocess
import sys
import time
from pathlib import Path
from typing import List, Any, Dict, Optional
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
from datetime import datetime
class WeChatHistoryReader(BaseReader):
"""
WeChat chat history reader that extracts chat data from exported JSON files.
Reads WeChat chat history from exported JSON files (from wechat-exporter tool)
and creates documents with embedded metadata similar to the Chrome history reader structure.
Also includes utilities for automatic WeChat chat history export.
"""
def __init__(self) -> None:
"""Initialize."""
self.packages_dir = Path(__file__).parent.parent.parent / "packages"
self.wechat_exporter_dir = self.packages_dir / "wechat-exporter"
self.wechat_decipher_dir = self.packages_dir / "wechat-decipher-macos"
def check_wechat_running(self) -> bool:
"""Check if WeChat is currently running."""
try:
result = subprocess.run(["pgrep", "-f", "WeChat"], capture_output=True, text=True)
return result.returncode == 0
except Exception:
return False
def install_wechattweak(self) -> bool:
"""Install WeChatTweak CLI tool."""
try:
# Create wechat-exporter directory if it doesn't exist
self.wechat_exporter_dir.mkdir(parents=True, exist_ok=True)
wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
if not wechattweak_path.exists():
print("Downloading WeChatTweak CLI...")
subprocess.run([
"curl", "-L", "-o", str(wechattweak_path),
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli"
], check=True)
# Make executable
wechattweak_path.chmod(0o755)
# Install WeChatTweak
print("Installing WeChatTweak...")
subprocess.run(["sudo", str(wechattweak_path), "install"], check=True)
return True
except Exception as e:
print(f"Error installing WeChatTweak: {e}")
return False
def restart_wechat(self):
"""Restart WeChat to apply WeChatTweak."""
try:
print("Restarting WeChat...")
subprocess.run(["pkill", "-f", "WeChat"], check=False)
time.sleep(2)
subprocess.run(["open", "-a", "WeChat"], check=True)
time.sleep(5) # Wait for WeChat to start
except Exception as e:
print(f"Error restarting WeChat: {e}")
def check_api_available(self) -> bool:
"""Check if WeChatTweak API is available."""
try:
result = subprocess.run([
"curl", "-s", "http://localhost:48065/wechat/allcontacts"
], capture_output=True, text=True, timeout=5)
return result.returncode == 0 and result.stdout.strip()
except Exception:
return False
def _extract_readable_text(self, content: str) -> str:
"""
Extract readable text from message content, removing XML and system messages.
Args:
content: The raw message content (can be string or dict)
Returns:
Cleaned, readable text
"""
if not content:
return ""
# Handle dictionary content (like quoted messages)
if isinstance(content, dict):
# Extract text from dictionary structure
text_parts = []
if 'title' in content:
text_parts.append(str(content['title']))
if 'quoted' in content:
text_parts.append(str(content['quoted']))
if 'content' in content:
text_parts.append(str(content['content']))
if 'text' in content:
text_parts.append(str(content['text']))
if text_parts:
return " | ".join(text_parts)
else:
# If we can't extract meaningful text from dict, return empty
return ""
# Handle string content
if not isinstance(content, str):
return ""
# Remove common prefixes like "wxid_xxx:\n"
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
# If it's just XML or system message, return empty
if clean_content.strip().startswith('<') or 'recalled a message' in clean_content:
return ""
return clean_content.strip()
def _is_text_message(self, content: str) -> bool:
"""
Check if a message contains readable text content.
Args:
content: The message content (can be string or dict)
Returns:
True if the message contains readable text, False otherwise
"""
if not content:
return False
# Handle dictionary content
if isinstance(content, dict):
# Check if dict has any readable text fields
text_fields = ['title', 'quoted', 'content', 'text']
for field in text_fields:
if field in content and content[field]:
return True
return False
# Handle string content
if not isinstance(content, str):
return False
# Skip image messages (contain XML with img tags)
if '<img' in content and 'cdnurl' in content:
return False
# Skip emoji messages (contain emoji XML tags)
if '<emoji' in content and 'productid' in content:
return False
# Skip voice messages
if '<voice' in content:
return False
# Skip video messages
if '<video' in content:
return False
# Skip file messages
if '<appmsg' in content and 'appid' in content:
return False
# Skip system messages (like "recalled a message")
if 'recalled a message' in content:
return False
# Check if there's actual readable text (not just XML or system messages)
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
# If after cleaning we have meaningful text, consider it readable
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith('<'):
return True
return False
def _concatenate_messages(self, messages: List[Dict], max_length: int = 128,
time_window_minutes: int = 30, overlap_messages: int = 0) -> List[Dict]:
"""
Concatenate messages based on length and time rules.
Args:
messages: List of message dictionaries
max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
overlap_messages: Number of messages to overlap between consecutive groups
Returns:
List of concatenated message groups
"""
if not messages:
return []
concatenated_groups = []
current_group = []
current_length = 0
last_timestamp = None
for message in messages:
# Extract message info
content = message.get('content', '')
message_text = message.get('message', '')
create_time = message.get('createTime', 0)
from_user = message.get('fromUser', '')
to_user = message.get('toUser', '')
is_sent_from_self = message.get('isSentFromSelf', False)
# Extract readable text
readable_text = self._extract_readable_text(content)
if not readable_text:
readable_text = message_text
# Skip empty messages
if not readable_text.strip():
continue
# Check time window constraint (only if time_window_minutes != -1)
if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
time_diff_minutes = (create_time - last_timestamp) / 60
if time_diff_minutes > time_window_minutes:
# Time gap too large, start new group
if current_group:
concatenated_groups.append({
'messages': current_group,
'total_length': current_length,
'start_time': current_group[0].get('createTime', 0),
'end_time': current_group[-1].get('createTime', 0)
})
# Keep last few messages for overlap
if overlap_messages > 0 and len(current_group) > overlap_messages:
current_group = current_group[-overlap_messages:]
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
else:
current_group = []
current_length = 0
# Check length constraint (only if max_length != -1)
message_length = len(readable_text)
if max_length != -1 and current_length + message_length > max_length and current_group:
# Current group would exceed max length, save it and start new
concatenated_groups.append({
'messages': current_group,
'total_length': current_length,
'start_time': current_group[0].get('createTime', 0),
'end_time': current_group[-1].get('createTime', 0)
})
# Keep last few messages for overlap
if overlap_messages > 0 and len(current_group) > overlap_messages:
current_group = current_group[-overlap_messages:]
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
else:
current_group = []
current_length = 0
# Add message to current group
current_group.append(message)
current_length += message_length
last_timestamp = create_time
# Add the last group if it exists
if current_group:
concatenated_groups.append({
'messages': current_group,
'total_length': current_length,
'start_time': current_group[0].get('createTime', 0),
'end_time': current_group[-1].get('createTime', 0)
})
return concatenated_groups
def _create_concatenated_content(self, message_group: Dict, contact_name: str) -> str:
"""
Create concatenated content from a group of messages.
Args:
message_group: Dictionary containing messages and metadata
contact_name: Name of the contact
Returns:
Formatted concatenated content
"""
messages = message_group['messages']
start_time = message_group['start_time']
end_time = message_group['end_time']
# Format timestamps
if start_time:
try:
start_timestamp = datetime.fromtimestamp(start_time)
start_time_str = start_timestamp.strftime('%Y-%m-%d %H:%M:%S')
except:
start_time_str = str(start_time)
else:
start_time_str = "Unknown"
if end_time:
try:
end_timestamp = datetime.fromtimestamp(end_time)
end_time_str = end_timestamp.strftime('%Y-%m-%d %H:%M:%S')
except:
end_time_str = str(end_time)
else:
end_time_str = "Unknown"
# Build concatenated message content
message_parts = []
for message in messages:
content = message.get('content', '')
message_text = message.get('message', '')
create_time = message.get('createTime', 0)
is_sent_from_self = message.get('isSentFromSelf', False)
# Extract readable text
readable_text = self._extract_readable_text(content)
if not readable_text:
readable_text = message_text
# Format individual message
if create_time:
try:
timestamp = datetime.fromtimestamp(create_time)
time_str = timestamp.strftime('%H:%M:%S')
except:
time_str = str(create_time)
else:
time_str = "Unknown"
sender = "Me" if is_sent_from_self else "Contact"
message_parts.append(f"[{time_str}] {sender}: {readable_text}")
concatenated_text = "\n".join(message_parts)
# Create final document content
doc_content = f"""
Contact: {contact_name}
Time Range: {start_time_str} - {end_time_str}
Messages ({len(messages)} messages, {message_group['total_length']} chars):
{concatenated_text}
"""
doc_content = f"""
Contact: {contact_name}
{concatenated_text}
"""
return doc_content
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
"""
Load WeChat chat history data from exported JSON files.
Args:
input_dir: Directory containing exported WeChat JSON files
**load_kwargs:
max_count (int): Maximum amount of chat entries to read.
wechat_export_dir (str): Custom path to WeChat export directory.
include_non_text (bool): Whether to include non-text messages (images, emojis, etc.)
concatenate_messages (bool): Whether to concatenate messages based on length rules.
max_length (int): Maximum length for concatenated message groups (default: 1000).
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
wechat_export_dir = load_kwargs.get('wechat_export_dir', None)
include_non_text = load_kwargs.get('include_non_text', False)
concatenate_messages = load_kwargs.get('concatenate_messages', False)
max_length = load_kwargs.get('max_length', 1000)
time_window_minutes = load_kwargs.get('time_window_minutes', 30)
# Default WeChat export path
if wechat_export_dir is None:
wechat_export_dir = "./wechat_export_test"
if not os.path.exists(wechat_export_dir):
print(f"WeChat export directory not found at: {wechat_export_dir}")
return docs
try:
# Find all JSON files in the export directory
json_files = list(Path(wechat_export_dir).glob("*.json"))
print(f"Found {len(json_files)} WeChat chat history files")
count = 0
for json_file in json_files:
if count >= max_count and max_count > 0:
break
try:
with open(json_file, 'r', encoding='utf-8') as f:
chat_data = json.load(f)
# Extract contact name from filename
contact_name = json_file.stem
if concatenate_messages:
# Filter messages to only include readable text messages
readable_messages = []
for message in chat_data:
try:
content = message.get('content', '')
if not include_non_text and not self._is_text_message(content):
continue
readable_text = self._extract_readable_text(content)
if not readable_text and not include_non_text:
continue
readable_messages.append(message)
except Exception as e:
print(f"Error processing message in {json_file}: {e}")
continue
# Concatenate messages based on rules
message_groups = self._concatenate_messages(
readable_messages,
max_length=-1,
time_window_minutes=-1,
overlap_messages=0 # Keep 2 messages overlap between groups
)
# Create documents from concatenated groups
for message_group in message_groups:
if count >= max_count and max_count > 0:
break
doc_content = self._create_concatenated_content(message_group, contact_name)
doc = Document(text=doc_content, metadata={})
docs.append(doc)
count += 1
print(f"Created {len(message_groups)} concatenated message groups for {contact_name}")
else:
# Original single-message processing
for message in chat_data:
if count >= max_count and max_count > 0:
break
# Extract message information
from_user = message.get('fromUser', '')
to_user = message.get('toUser', '')
content = message.get('content', '')
message_text = message.get('message', '')
create_time = message.get('createTime', 0)
is_sent_from_self = message.get('isSentFromSelf', False)
# Handle content that might be dict or string
try:
# Check if this is a readable text message
if not include_non_text and not self._is_text_message(content):
continue
# Extract readable text
readable_text = self._extract_readable_text(content)
if not readable_text and not include_non_text:
continue
except Exception as e:
# Skip messages that cause processing errors
print(f"Error processing message in {json_file}: {e}")
continue
# Convert timestamp to readable format
if create_time:
try:
timestamp = datetime.fromtimestamp(create_time)
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
except:
time_str = str(create_time)
else:
time_str = "Unknown"
# Create document content with metadata header and contact info
doc_content = f"""
Contact: {contact_name}
Is sent from self: {is_sent_from_self}
Time: {time_str}
Message: {readable_text if readable_text else message_text}
"""
# Create document with embedded metadata
doc = Document(text=doc_content, metadata={})
docs.append(doc)
count += 1
except Exception as e:
print(f"Error reading {json_file}: {e}")
continue
print(f"Loaded {len(docs)} WeChat chat documents")
except Exception as e:
print(f"Error reading WeChat history: {e}")
return docs
return docs
@staticmethod
def find_wechat_export_dirs() -> List[Path]:
"""
Find all WeChat export directories.
Returns:
List of Path objects pointing to WeChat export directories
"""
export_dirs = []
# Look for common export directory names
possible_dirs = [
Path("./wechat_export_test"),
Path("./wechat_export"),
Path("./wechat_chat_history"),
Path("./chat_export")
]
for export_dir in possible_dirs:
if export_dir.exists() and export_dir.is_dir():
json_files = list(export_dir.glob("*.json"))
if json_files:
export_dirs.append(export_dir)
print(f"Found WeChat export directory: {export_dir} with {len(json_files)} files")
print(f"Found {len(export_dirs)} WeChat export directories")
return export_dirs
@staticmethod
def export_chat_to_file(output_file: str = "wechat_chat_export.txt", max_count: int = 1000, export_dir: str = None, include_non_text: bool = False):
"""
Export WeChat chat history to a text file.
Args:
output_file: Path to the output file
max_count: Maximum number of entries to export
export_dir: Directory containing WeChat JSON files
include_non_text: Whether to include non-text messages
"""
if export_dir is None:
export_dir = "./wechat_export_test"
if not os.path.exists(export_dir):
print(f"WeChat export directory not found at: {export_dir}")
return
try:
json_files = list(Path(export_dir).glob("*.json"))
with open(output_file, 'w', encoding='utf-8') as f:
count = 0
for json_file in json_files:
if count >= max_count and max_count > 0:
break
try:
with open(json_file, 'r', encoding='utf-8') as json_f:
chat_data = json.load(json_f)
contact_name = json_file.stem
f.write(f"\n=== Chat with {contact_name} ===\n")
for message in chat_data:
if count >= max_count and max_count > 0:
break
from_user = message.get('fromUser', '')
content = message.get('content', '')
message_text = message.get('message', '')
create_time = message.get('createTime', 0)
# Skip non-text messages unless requested
if not include_non_text:
reader = WeChatHistoryReader()
if not reader._is_text_message(content):
continue
readable_text = reader._extract_readable_text(content)
if not readable_text:
continue
message_text = readable_text
if create_time:
try:
timestamp = datetime.fromtimestamp(create_time)
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
except:
time_str = str(create_time)
else:
time_str = "Unknown"
f.write(f"[{time_str}] {from_user}: {message_text}\n")
count += 1
except Exception as e:
print(f"Error processing {json_file}: {e}")
continue
print(f"Exported {count} chat entries to {output_file}")
except Exception as e:
print(f"Error exporting WeChat chat history: {e}")
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Optional[Path]:
"""
Export WeChat chat history using wechat-exporter tool.
Args:
export_dir: Directory to save exported chat history
Returns:
Path to export directory if successful, None otherwise
"""
try:
import subprocess
import sys
# Create export directory
export_path = Path(export_dir)
export_path.mkdir(exist_ok=True)
print(f"Exporting WeChat chat history to {export_path}...")
# Check if wechat-exporter directory exists
if not self.wechat_exporter_dir.exists():
print(f"wechat-exporter directory not found at: {self.wechat_exporter_dir}")
return None
# Install requirements if needed
requirements_file = self.wechat_exporter_dir / "requirements.txt"
if requirements_file.exists():
print("Installing wechat-exporter requirements...")
subprocess.run([
"uv", "pip", "install", "-r", str(requirements_file)
], check=True)
# Run the export command
print("Running wechat-exporter...")
result = subprocess.run([
sys.executable, str(self.wechat_exporter_dir / "main.py"),
"export-all", str(export_path)
], capture_output=True, text=True, check=True)
print("Export command output:")
print(result.stdout)
if result.stderr:
print("Export errors:")
print(result.stderr)
# Check if export was successful
if export_path.exists() and any(export_path.glob("*.json")):
json_files = list(export_path.glob("*.json"))
print(f"Successfully exported {len(json_files)} chat history files to {export_path}")
return export_path
else:
print("Export completed but no JSON files found")
return None
except subprocess.CalledProcessError as e:
print(f"Export command failed: {e}")
print(f"Command output: {e.stdout}")
print(f"Command errors: {e.stderr}")
return None
except Exception as e:
print(f"Export failed: {e}")
print("Please ensure WeChat is running and WeChatTweak is installed.")
return None
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> List[Path]:
"""
Find existing WeChat exports or create new ones.
Args:
export_dir: Directory to save exported chat history if needed
Returns:
List of Path objects pointing to WeChat export directories
"""
export_dirs = []
# Look for existing exports in common locations
possible_export_dirs = [
Path("./wechat_database_export"),
Path("./wechat_export_test"),
Path("./wechat_export"),
Path("./wechat_export_direct"),
Path("./wechat_chat_history"),
Path("./chat_export")
]
for export_dir_path in possible_export_dirs:
if export_dir_path.exists() and any(export_dir_path.glob("*.json")):
export_dirs.append(export_dir_path)
print(f"Found existing export: {export_dir_path}")
# If no existing exports, try to export automatically
if not export_dirs:
print("No existing WeChat exports found. Starting direct export...")
# Try to export using wechat-exporter
exported_path = self.export_wechat_chat_history(export_dir)
if exported_path:
export_dirs = [exported_path]
else:
print("Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.")
return export_dirs

View File

@@ -0,0 +1,286 @@
import os
import sys
import asyncio
import dotenv
import argparse
from pathlib import Path
from typing import List, Any
# Add the project root to Python path so we can import from examples
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from llama_index.core.node_parser import SentenceSplitter
dotenv.load_dotenv()
# Auto-detect user's mail path
def get_mail_path():
"""Get the mail path for the current user"""
home_dir = os.path.expanduser("~")
return os.path.join(home_dir, "Library", "Mail")
# Default mail path for macOS
# DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
"""
Create LEANN index from multiple mail data sources.
Args:
messages_dirs: List of Path objects pointing to Messages directories
index_path: Path to save the LEANN index
max_count: Maximum number of emails to process per directory
include_html: Whether to include HTML content in email processing
"""
print("Creating LEANN index from multiple mail data sources...")
# Load documents using EmlxReader from LEANN_email_reader
from examples.email_data.LEANN_email_reader import EmlxReader
reader = EmlxReader(include_html=include_html)
# from email_data.email import EmlxMboxReader
# from pathlib import Path
# reader = EmlxMboxReader()
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
all_documents = []
total_processed = 0
# Process each Messages directory
for i, messages_dir in enumerate(messages_dirs):
print(f"\nProcessing Messages directory {i+1}/{len(messages_dirs)}: {messages_dir}")
try:
documents = reader.load_data(messages_dir)
if documents:
print(f"Loaded {len(documents)} email documents from {messages_dir}")
all_documents.extend(documents)
total_processed += len(documents)
# Check if we've reached the max count
if max_count > 0 and total_processed >= max_count:
print(f"Reached max count of {max_count} documents")
break
else:
print(f"No documents loaded from {messages_dir}")
except Exception as e:
print(f"Error processing {messages_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
return None
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in all_documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=embedding_model,
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
print(f"Adding {len(all_texts)} email chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max_count: int = 1000, include_html: bool = False, embedding_model: str = "facebook/contriever"):
"""
Create LEANN index from mail data.
Args:
mail_path: Path to the mail directory
index_path: Path to save the LEANN index
max_count: Maximum number of emails to process
include_html: Whether to include HTML content in email processing
"""
print("Creating LEANN index from mail data...")
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Load documents using EmlxReader from LEANN_email_reader
from examples.email_data.LEANN_email_reader import EmlxReader
reader = EmlxReader(include_html=include_html)
# from email_data.email import EmlxMboxReader
# from pathlib import Path
# reader = EmlxMboxReader()
documents = reader.load_data(Path(mail_path))
if not documents:
print("No documents loaded. Exiting.")
return None
print(f"Loaded {len(documents)} email documents")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=embedding_model,
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
print(f"Adding {len(all_texts)} email chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
async def query_leann_index(index_path: str, query: str):
"""
Query the LEANN index.
Args:
index_path: Path to the LEANN index
query: The query string
"""
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path,
llm_config={"type": "openai", "model": "gpt-4o"})
print(f"You: {query}")
import time
start_time = time.time()
chat_response = chat.ask(
query,
top_k=10,
recompute_beighbor_embeddings=True,
complexity=12,
beam_width=1,
)
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")
print(f"Leann: {chat_response}")
async def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
# Remove --mail-path argument and auto-detect all Messages directories
# Remove DEFAULT_MAIL_PATH
parser.add_argument('--index-dir', type=str, default="./mail_index_leann_raw_text_all_dicts",
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
parser.add_argument('--max-emails', type=int, default=1000,
help='Maximum number of emails to process (-1 means all)')
parser.add_argument('--query', type=str, default="Give me some funny advertisement about apple or other companies",
help='Single query to run (default: runs example queries)')
parser.add_argument('--include-html', action='store_true', default=False,
help='Include HTML content in email processing (default: False)')
parser.add_argument('--embedding-model', type=str, default="facebook/contriever",
help='Embedding model to use (default: facebook/contriever)')
args = parser.parse_args()
print(f"args: {args}")
# Automatically find all Messages directories under the current user's Mail directory
from examples.email_data.LEANN_email_reader import find_all_messages_directories
mail_path = get_mail_path()
print(f"Searching for email data in: {mail_path}")
messages_dirs = find_all_messages_directories(mail_path)
print('len(messages_dirs): ', len(messages_dirs))
if not messages_dirs:
print("No Messages directories found. Exiting.")
return
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
print(f"Index directory: {INDEX_DIR}")
print(f"Found {len(messages_dirs)} Messages directories.")
# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model)
if index_path:
if args.query:
# Run single query
await query_leann_index(index_path, args.query)
else:
# Example queries
queries = [
"Hows Berkeley Graduate Student Instructor",
"how's the icloud related advertisement saying",
"Whats the number of class recommend to take per semester for incoming EECS students"
]
for query in queries:
print("\n" + "="*60)
await query_leann_index(index_path, query)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,108 @@
import os
import sys
import argparse
from pathlib import Path
from typing import List, Any
# Add the project root to Python path so we can import from examples
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.core.node_parser import SentenceSplitter
# --- EMBEDDING MODEL ---
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import torch
# --- END EMBEDDING MODEL ---
# Import EmlxReader from the new module
from examples.email_data.LEANN_email_reader import EmlxReader
def create_and_save_index(mail_path: str, save_dir: str = "mail_index_embedded", max_count: int = 1000, include_html: bool = False):
print("Creating index from mail data with embedded metadata...")
documents = EmlxReader(include_html=include_html).load_data(mail_path, max_count=max_count)
if not documents:
print("No documents loaded. Exiting.")
return None
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Use facebook/contriever as the embedder
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
# set on device
import torch
if torch.cuda.is_available():
embed_model._model.to("cuda")
# set mps
elif torch.backends.mps.is_available():
embed_model._model.to("mps")
else:
embed_model._model.to("cpu")
index = VectorStoreIndex.from_documents(
documents,
transformations=[text_splitter],
embed_model=embed_model
)
os.makedirs(save_dir, exist_ok=True)
index.storage_context.persist(persist_dir=save_dir)
print(f"Index saved to {save_dir}")
return index
def load_index(save_dir: str = "mail_index_embedded"):
try:
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
index = VectorStoreIndex.from_vector_store(
storage_context.vector_store,
storage_context=storage_context
)
print(f"Index loaded from {save_dir}")
return index
except Exception as e:
print(f"Error loading index: {e}")
return None
def query_index(index, query: str):
if index is None:
print("No index available for querying.")
return
query_engine = index.as_query_engine()
response = query_engine.query(query)
print(f"Query: {query}")
print(f"Response: {response}")
def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='LlamaIndex Mail Reader - Create and query email index')
parser.add_argument('--mail-path', type=str,
default="/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
help='Path to mail data directory')
parser.add_argument('--save-dir', type=str, default="mail_index_embedded",
help='Directory to store the index (default: mail_index_embedded)')
parser.add_argument('--max-emails', type=int, default=10000,
help='Maximum number of emails to process')
parser.add_argument('--include-html', action='store_true', default=False,
help='Include HTML content in email processing (default: False)')
args = parser.parse_args()
mail_path = args.mail_path
save_dir = args.save_dir
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
print("Loading existing index...")
index = load_index(save_dir)
else:
print("Creating new index...")
index = create_and_save_index(mail_path, save_dir, max_count=args.max_emails, include_html=args.include_html)
if index:
queries = [
"Hows Berkeley Graduate Student Instructor",
"how's the icloud related advertisement saying",
"Whats the number of class recommend to take per semester for incoming EECS students"
]
for query in queries:
print("\n" + "="*50)
query_index(index, query)
if __name__ == "__main__":
main()

View File

@@ -1,13 +1,7 @@
import faulthandler
faulthandler.enable()
import argparse
from llama_index.core import SimpleDirectoryReader, Settings
from llama_index.core.readers.base import BaseReader
from llama_index.node_parser.docling import DoclingNodeParser
from llama_index.readers.docling import DoclingReader
from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
from llama_index.core.node_parser import SentenceSplitter
import asyncio
import os
import dotenv
from leann.api import LeannBuilder, LeannSearcher, LeannChat
import shutil
@@ -15,23 +9,15 @@ from pathlib import Path
dotenv.load_dotenv()
reader = DoclingReader(export_type=DoclingReader.ExportType.JSON)
file_extractor: dict[str, BaseReader] = {
".docx": reader,
".pptx": reader,
".pdf": reader,
".xlsx": reader,
}
node_parser = DoclingNodeParser(
chunker=HybridChunker(tokenizer="Qwen/Qwen3-Embedding-4B", max_tokens=64)
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
)
print("Loading documents...")
documents = SimpleDirectoryReader(
"examples/data",
recursive=True,
file_extractor=file_extractor,
"examples/data",
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".docx", ".pptx", ".xlsx"]
required_exts=[".pdf", ".txt", ".md"],
).load_data(show_progress=True)
print("Documents loaded.")
all_texts = []
@@ -40,41 +26,85 @@ for doc in documents:
for node in nodes:
all_texts.append(node.get_content())
INDEX_DIR = Path("./test_pdf_index")
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
print(f"\n[PHASE 1] Building Leann index...")
async def main(args):
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
# CSR compact mode with recompute
builder = LeannBuilder(
backend_name="diskann",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True
)
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
print(f"Loaded {len(all_texts)} text chunks from documents.")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(INDEX_PATH)
print(f"\nLeann index built at {INDEX_PATH}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1, # Force single-threaded mode
)
print(f"Loaded {len(all_texts)} text chunks from documents.")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(INDEX_PATH)
print(f"\nLeann index built at {INDEX_PATH}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
async def main():
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=INDEX_PATH)
# llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
llm_config = {"type": "ollama", "model": "qwen3:8b"}
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
# query = (
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
# )
print(f"You: {query}")
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True)
chat_response = chat.ask(
query, top_k=20, recompute_beighbor_embeddings=True, complexity=32
)
print(f"Leann: {chat_response}")
if __name__ == "__main__":
asyncio.run(main())
parser = argparse.ArgumentParser(
description="Run Leann Chat with various LLM backends."
)
parser.add_argument(
"--llm",
type=str,
default="hf",
choices=["simulated", "ollama", "hf", "openai"],
help="The LLM backend to use.",
)
parser.add_argument(
"--model",
type=str,
default="Qwen/Qwen3-0.6B",
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
)
parser.add_argument(
"--host",
type=str,
default="http://localhost:11434",
help="The host for the Ollama API.",
)
parser.add_argument(
"--index-dir",
type=str,
default="./test_doc_files",
help="Directory where the Leann index will be stored.",
)
args = parser.parse_args()
asyncio.run(main(args))

View File

@@ -0,0 +1,319 @@
#!/usr/bin/env python3
"""
Multi-Vector Aggregator for Fat Embeddings
==========================================
This module implements aggregation strategies for multi-vector embeddings,
similar to ColPali's approach where multiple patch vectors represent a single document.
Key features:
- MaxSim aggregation (take maximum similarity across patches)
- Voting-based aggregation (count patch matches)
- Weighted aggregation (attention-score weighted)
- Spatial clustering of matching patches
- Document-level result consolidation
"""
import numpy as np
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
from collections import defaultdict
import json
@dataclass
class PatchResult:
"""Represents a single patch search result."""
patch_id: int
image_name: str
image_path: str
coordinates: Tuple[int, int, int, int] # (x1, y1, x2, y2)
score: float
attention_score: float
scale: float
metadata: Dict[str, Any]
@dataclass
class AggregatedResult:
"""Represents an aggregated document-level result."""
image_name: str
image_path: str
doc_score: float
patch_count: int
best_patch: PatchResult
all_patches: List[PatchResult]
aggregation_method: str
spatial_clusters: Optional[List[List[PatchResult]]] = None
class MultiVectorAggregator:
"""
Aggregates multiple patch-level results into document-level results.
"""
def __init__(self,
aggregation_method: str = "maxsim",
spatial_clustering: bool = True,
cluster_distance_threshold: float = 100.0):
"""
Initialize the aggregator.
Args:
aggregation_method: "maxsim", "voting", "weighted", or "mean"
spatial_clustering: Whether to cluster spatially close patches
cluster_distance_threshold: Distance threshold for spatial clustering
"""
self.aggregation_method = aggregation_method
self.spatial_clustering = spatial_clustering
self.cluster_distance_threshold = cluster_distance_threshold
def aggregate_results(self,
search_results: List[Dict[str, Any]],
top_k: int = 10) -> List[AggregatedResult]:
"""
Aggregate patch-level search results into document-level results.
Args:
search_results: List of search results from LeannSearcher
top_k: Number of top documents to return
Returns:
List of aggregated document results
"""
# Group results by image
image_groups = defaultdict(list)
for result in search_results:
metadata = result.metadata
if "image_name" in metadata and "patch_id" in metadata:
patch_result = PatchResult(
patch_id=metadata["patch_id"],
image_name=metadata["image_name"],
image_path=metadata["image_path"],
coordinates=tuple(metadata["coordinates"]),
score=result.score,
attention_score=metadata.get("attention_score", 0.0),
scale=metadata.get("scale", 1.0),
metadata=metadata
)
image_groups[metadata["image_name"]].append(patch_result)
# Aggregate each image group
aggregated_results = []
for image_name, patches in image_groups.items():
if len(patches) == 0:
continue
agg_result = self._aggregate_image_patches(image_name, patches)
aggregated_results.append(agg_result)
# Sort by aggregated score and return top-k
aggregated_results.sort(key=lambda x: x.doc_score, reverse=True)
return aggregated_results[:top_k]
def _aggregate_image_patches(self, image_name: str, patches: List[PatchResult]) -> AggregatedResult:
"""Aggregate patches for a single image."""
if self.aggregation_method == "maxsim":
doc_score = max(patch.score for patch in patches)
best_patch = max(patches, key=lambda p: p.score)
elif self.aggregation_method == "voting":
# Count patches above threshold
threshold = np.percentile([p.score for p in patches], 75)
doc_score = sum(1 for patch in patches if patch.score >= threshold)
best_patch = max(patches, key=lambda p: p.score)
elif self.aggregation_method == "weighted":
# Weight by attention scores
total_weighted_score = sum(p.score * p.attention_score for p in patches)
total_weights = sum(p.attention_score for p in patches)
doc_score = total_weighted_score / max(total_weights, 1e-8)
best_patch = max(patches, key=lambda p: p.score * p.attention_score)
elif self.aggregation_method == "mean":
doc_score = np.mean([patch.score for patch in patches])
best_patch = max(patches, key=lambda p: p.score)
else:
raise ValueError(f"Unknown aggregation method: {self.aggregation_method}")
# Spatial clustering if enabled
spatial_clusters = None
if self.spatial_clustering:
spatial_clusters = self._cluster_patches_spatially(patches)
return AggregatedResult(
image_name=image_name,
image_path=patches[0].image_path,
doc_score=float(doc_score),
patch_count=len(patches),
best_patch=best_patch,
all_patches=sorted(patches, key=lambda p: p.score, reverse=True),
aggregation_method=self.aggregation_method,
spatial_clusters=spatial_clusters
)
def _cluster_patches_spatially(self, patches: List[PatchResult]) -> List[List[PatchResult]]:
"""Cluster patches that are spatially close to each other."""
if len(patches) <= 1:
return [patches]
clusters = []
remaining_patches = patches.copy()
while remaining_patches:
# Start new cluster with highest scoring remaining patch
seed_patch = max(remaining_patches, key=lambda p: p.score)
current_cluster = [seed_patch]
remaining_patches.remove(seed_patch)
# Add nearby patches to cluster
added_to_cluster = True
while added_to_cluster:
added_to_cluster = False
for patch in remaining_patches.copy():
if self._is_patch_nearby(patch, current_cluster):
current_cluster.append(patch)
remaining_patches.remove(patch)
added_to_cluster = True
clusters.append(current_cluster)
return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True)
def _is_patch_nearby(self, patch: PatchResult, cluster: List[PatchResult]) -> bool:
"""Check if a patch is spatially close to any patch in the cluster."""
patch_center = self._get_patch_center(patch.coordinates)
for cluster_patch in cluster:
cluster_center = self._get_patch_center(cluster_patch.coordinates)
distance = np.sqrt((patch_center[0] - cluster_center[0])**2 +
(patch_center[1] - cluster_center[1])**2)
if distance <= self.cluster_distance_threshold:
return True
return False
def _get_patch_center(self, coordinates: Tuple[int, int, int, int]) -> Tuple[float, float]:
"""Get center point of a patch."""
x1, y1, x2, y2 = coordinates
return ((x1 + x2) / 2, (y1 + y2) / 2)
def print_aggregated_results(self, results: List[AggregatedResult], max_patches_per_doc: int = 3):
"""Pretty print aggregated results."""
print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})")
print("=" * 80)
for i, result in enumerate(results):
print(f"\n{i+1}. {result.image_name}")
print(f" Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}")
print(f" Path: {result.image_path}")
# Show best patch
best = result.best_patch
print(f" 🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})")
# Show top patches
print(f" 📍 Top Patches:")
for j, patch in enumerate(result.all_patches[:max_patches_per_doc]):
print(f" {j+1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}")
# Show spatial clusters if available
if result.spatial_clusters and len(result.spatial_clusters) > 1:
print(f" 🗂️ Spatial Clusters: {len(result.spatial_clusters)}")
for j, cluster in enumerate(result.spatial_clusters[:2]): # Show top 2 clusters
cluster_score = max(p.score for p in cluster)
print(f" Cluster {j+1}: {len(cluster)} patches (best: {cluster_score:.4f})")
def demo_aggregation():
"""Demonstrate the multi-vector aggregation functionality."""
print("=== Multi-Vector Aggregation Demo ===")
# Simulate some patch-level search results
# In real usage, these would come from LeannSearcher.search()
class MockResult:
def __init__(self, score, metadata):
self.score = score
self.metadata = metadata
# Simulate results for 2 images with multiple patches each
mock_results = [
# Image 1: cats_and_kitchen.jpg - 4 patches
MockResult(0.85, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 3,
"coordinates": [100, 50, 224, 174], # Kitchen area
"attention_score": 0.92,
"scale": 1.0
}),
MockResult(0.78, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 7,
"coordinates": [200, 300, 324, 424], # Cat area
"attention_score": 0.88,
"scale": 1.0
}),
MockResult(0.72, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 12,
"coordinates": [150, 100, 274, 224], # Appliances
"attention_score": 0.75,
"scale": 1.0
}),
MockResult(0.65, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 15,
"coordinates": [50, 250, 174, 374], # Furniture
"attention_score": 0.70,
"scale": 1.0
}),
# Image 2: city_street.jpg - 3 patches
MockResult(0.68, {
"image_name": "city_street.jpg",
"image_path": "/path/to/city_street.jpg",
"patch_id": 2,
"coordinates": [300, 100, 424, 224], # Buildings
"attention_score": 0.80,
"scale": 1.0
}),
MockResult(0.62, {
"image_name": "city_street.jpg",
"image_path": "/path/to/city_street.jpg",
"patch_id": 8,
"coordinates": [100, 350, 224, 474], # Street level
"attention_score": 0.75,
"scale": 1.0
}),
MockResult(0.55, {
"image_name": "city_street.jpg",
"image_path": "/path/to/city_street.jpg",
"patch_id": 11,
"coordinates": [400, 200, 524, 324], # Sky area
"attention_score": 0.60,
"scale": 1.0
}),
]
# Test different aggregation methods
methods = ["maxsim", "voting", "weighted", "mean"]
for method in methods:
print(f"\n{'='*20} {method.upper()} AGGREGATION {'='*20}")
aggregator = MultiVectorAggregator(
aggregation_method=method,
spatial_clustering=True,
cluster_distance_threshold=100.0
)
aggregated = aggregator.aggregate_results(mock_results, top_k=5)
aggregator.print_aggregated_results(aggregated)
if __name__ == "__main__":
demo_aggregation()

View File

@@ -0,0 +1,108 @@
#!/usr/bin/env python3
"""
OpenAI Embedding Example
Complete example showing how to build and search with OpenAI embeddings using HNSW backend.
"""
import os
import dotenv
from pathlib import Path
from leann.api import LeannBuilder, LeannSearcher
# Load environment variables
dotenv.load_dotenv()
def main():
# Check if OpenAI API key is available
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
print("ERROR: OPENAI_API_KEY environment variable not set")
return False
print(f"✅ OpenAI API key found: {api_key[:10]}...")
# Sample texts
sample_texts = [
"Machine learning is a powerful technology that enables computers to learn from data.",
"Natural language processing helps computers understand and generate human language.",
"Deep learning uses neural networks with multiple layers to solve complex problems.",
"Computer vision allows machines to interpret and understand visual information.",
"Reinforcement learning trains agents to make decisions through trial and error.",
"Data science combines statistics, math, and programming to extract insights from data.",
"Artificial intelligence aims to create machines that can perform human-like tasks.",
"Python is a popular programming language used extensively in data science and AI.",
"Neural networks are inspired by the structure and function of the human brain.",
"Big data refers to extremely large datasets that require special tools to process."
]
INDEX_DIR = Path("./simple_openai_test_index")
INDEX_PATH = str(INDEX_DIR / "simple_test.leann")
print(f"\n=== Building Index with OpenAI Embeddings ===")
print(f"Index path: {INDEX_PATH}")
try:
# Use proper configuration for OpenAI embeddings
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
# HNSW settings for OpenAI embeddings
M=16, # Smaller graph degree
efConstruction=64, # Smaller construction complexity
is_compact=True, # Enable compact storage for recompute
is_recompute=True, # MUST enable for OpenAI embeddings
num_threads=1,
)
print(f"Adding {len(sample_texts)} texts to the index...")
for i, text in enumerate(sample_texts):
metadata = {"id": f"doc_{i}", "topic": "AI"}
builder.add_text(text, metadata)
print("Building index...")
builder.build_index(INDEX_PATH)
print(f"✅ Index built successfully!")
except Exception as e:
print(f"❌ Error building index: {e}")
import traceback
traceback.print_exc()
return False
print(f"\n=== Testing Search ===")
try:
searcher = LeannSearcher(INDEX_PATH)
test_queries = [
"What is machine learning?",
"How do neural networks work?",
"Programming languages for data science"
]
for query in test_queries:
print(f"\n🔍 Query: '{query}'")
results = searcher.search(query, top_k=3)
print(f" Found {len(results)} results:")
for i, result in enumerate(results):
print(f" {i+1}. Score: {result.score:.4f}")
print(f" Text: {result.text[:80]}...")
print(f"\n✅ Search test completed successfully!")
return True
except Exception as e:
print(f"❌ Error during search: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
success = main()
if success:
print(f"\n🎉 Simple OpenAI index test completed successfully!")
else:
print(f"\n💥 Simple OpenAI index test failed!")

18
examples/resue_index.py Normal file
View File

@@ -0,0 +1,18 @@
import asyncio
from leann.api import LeannChat
from pathlib import Path
INDEX_DIR = Path("./test_pdf_index_huawei")
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
async def main():
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=INDEX_PATH)
query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
# query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
response = chat.ask(query,top_k=20,recompute_beighbor_embeddings=True,complexity=32,beam_width=1)
print(f"\n[PHASE 2] Response: {response}")
if __name__ == "__main__":
asyncio.run(main())

382
examples/run_evaluation.py Normal file
View File

@@ -0,0 +1,382 @@
#!/usr/bin/env python3
"""
This script runs a recall evaluation on a given LEANN index.
It correctly compares results by fetching the text content for both the new search
results and the golden standard results, making the comparison robust to ID changes.
"""
import json
import argparse
import time
from pathlib import Path
import sys
import numpy as np
from typing import List
from leann.api import LeannSearcher, LeannBuilder
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
if not data_root.exists():
print(f"Data directory '{data_root}' not found.")
print(
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
)
try:
from huggingface_hub import snapshot_download
if download_embeddings:
# Download everything including embeddings (large files)
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
)
print("Data download complete (including embeddings)!")
else:
# Download only specific folders, excluding embeddings
allow_patterns = [
"ground_truth/**",
"indices/**",
"queries/**",
"*.md",
"*.txt",
]
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
allow_patterns=allow_patterns,
)
print("Data download complete (excluding embeddings)!")
except ImportError:
print(
"Error: huggingface_hub is not installed. Please install it to download the data:"
)
print("uv pip install -e '.[dev]'")
sys.exit(1)
except Exception as e:
print(f"An error occurred during data download: {e}")
sys.exit(1)
def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
"""Download embeddings files specifically."""
embeddings_dir = data_root / "embeddings"
if dataset_type:
# Check if specific dataset embeddings exist
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
if target_file.exists():
print(f"Embeddings for {dataset_type} already exist")
return str(target_file)
print("Downloading embeddings from HuggingFace Hub...")
try:
from huggingface_hub import snapshot_download
# Download only embeddings folder
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
allow_patterns=["embeddings/**/*.pkl"],
)
print("Embeddings download complete!")
if dataset_type:
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
if target_file.exists():
return str(target_file)
return str(embeddings_dir)
except Exception as e:
print(f"Error downloading embeddings: {e}")
sys.exit(1)
# --- Helper Function to get Golden Passages ---
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
"""
Retrieves the text for golden passage IDs directly from the LeannSearcher's
passage manager.
"""
golden_texts = set()
for gid in golden_ids:
try:
# PassageManager uses string IDs
passage_data = searcher.passage_manager.get_passage(str(gid))
golden_texts.add(passage_data["text"])
except KeyError:
print(
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
)
return golden_texts
def load_queries(file_path: Path) -> List[str]:
queries = []
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
data = json.loads(line)
queries.append(data["query"])
return queries
def build_index_from_embeddings(
embeddings_file: str, output_path: str, backend: str = "hnsw"
):
"""
Build a LEANN index from pre-computed embeddings.
Args:
embeddings_file: Path to pickle file with (ids, embeddings) tuple
output_path: Path where to save the index
backend: Backend to use ("hnsw" or "diskann")
"""
print(f"Building {backend} index from embeddings: {embeddings_file}")
# Create builder with appropriate parameters
if backend == "hnsw":
builder_kwargs = {
"M": 32, # Graph degree
"efConstruction": 256, # Construction complexity
"is_compact": True, # Use compact storage
"is_recompute": True, # Enable pruning for better recall
}
elif backend == "diskann":
builder_kwargs = {
"complexity": 64,
"graph_degree": 32,
"search_memory_maximum": 8.0, # GB
"build_memory_maximum": 16.0, # GB
}
else:
builder_kwargs = {}
builder = LeannBuilder(
backend_name=backend,
embedding_model="facebook/contriever-msmarco", # Model used to create embeddings
dimensions=768, # Will be auto-detected from embeddings
**builder_kwargs,
)
# Build index from precomputed embeddings
builder.build_index_from_embeddings(output_path, embeddings_file)
print(f"Index saved to: {output_path}")
return output_path
def main():
parser = argparse.ArgumentParser(
description="Run recall evaluation on a LEANN index."
)
parser.add_argument(
"index_path",
type=str,
nargs="?",
help="Path to the LEANN index to evaluate or build (optional).",
)
parser.add_argument(
"--mode",
choices=["evaluate", "build"],
default="evaluate",
help="Mode: 'evaluate' existing index or 'build' from embeddings",
)
parser.add_argument(
"--embeddings-file",
type=str,
help="Path to embeddings pickle file (optional for build mode)",
)
parser.add_argument(
"--backend",
choices=["hnsw", "diskann"],
default="hnsw",
help="Backend to use for building index (default: hnsw)",
)
parser.add_argument(
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
)
parser.add_argument(
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
)
parser.add_argument(
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
)
args = parser.parse_args()
# --- Path Configuration ---
# Assumes a project structure where the script is in 'examples/'
# and data is in 'data/' at the project root.
project_root = Path(__file__).resolve().parent.parent
data_root = project_root / "data"
# Download data based on mode
if args.mode == "build":
# For building mode, we need embeddings
download_data_if_needed(
data_root, download_embeddings=False
) # Basic data first
# Auto-detect dataset type and download embeddings
if args.embeddings_file:
embeddings_file = args.embeddings_file
# Try to detect dataset type from embeddings file path
if "rpj_wiki" in str(embeddings_file):
dataset_type = "rpj_wiki"
elif "dpr" in str(embeddings_file):
dataset_type = "dpr"
else:
dataset_type = "dpr" # Default
else:
# Auto-detect from index path if provided, otherwise default to DPR
if args.index_path:
index_path_str = str(args.index_path)
if "rpj_wiki" in index_path_str:
dataset_type = "rpj_wiki"
elif "dpr" in index_path_str:
dataset_type = "dpr"
else:
dataset_type = "dpr" # Default to DPR
else:
dataset_type = "dpr" # Default to DPR
embeddings_file = download_embeddings_if_needed(data_root, dataset_type)
# Auto-generate index path if not provided
if not args.index_path:
indices_dir = data_root / "indices" / dataset_type
indices_dir.mkdir(parents=True, exist_ok=True)
args.index_path = str(indices_dir / f"{dataset_type}_from_embeddings")
print(f"Auto-generated index path: {args.index_path}")
print(f"Building index from embeddings: {embeddings_file}")
built_index_path = build_index_from_embeddings(
embeddings_file, args.index_path, args.backend
)
print(f"Index built successfully: {built_index_path}")
# Ask if user wants to run evaluation
eval_response = (
input("Run evaluation on the built index? (y/n): ").strip().lower()
)
if eval_response != "y":
print("Index building complete. Exiting.")
return
else:
# For evaluation mode, don't need embeddings
download_data_if_needed(data_root, download_embeddings=False)
# Auto-detect index path if not provided
if not args.index_path:
# Default to using downloaded indices
indices_dir = data_root / "indices"
# Try common datasets in order of preference
for dataset in ["dpr", "rpj_wiki"]:
dataset_dir = indices_dir / dataset
if dataset_dir.exists():
# Look for index files
index_files = list(dataset_dir.glob("*.index")) + list(
dataset_dir.glob("*_disk.index")
)
if index_files:
args.index_path = str(
index_files[0].with_suffix("")
) # Remove .index extension
print(f"Using index: {args.index_path}")
break
if not args.index_path:
print(
"No indices found. The data download should have included pre-built indices."
)
print(
"Please check the data/indices/ directory or provide --index-path manually."
)
sys.exit(1)
# Detect dataset type from index path to select the correct ground truth
index_path_str = str(args.index_path)
if "rpj_wiki" in index_path_str:
dataset_type = "rpj_wiki"
elif "dpr" in index_path_str:
dataset_type = "dpr"
else:
# Fallback: try to infer from the index directory name
dataset_type = Path(args.index_path).name
print(
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
)
queries_file = data_root / "queries" / "nq_open.jsonl"
golden_results_file = (
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
)
print(f"INFO: Detected dataset type: {dataset_type}")
print(f"INFO: Using queries file: {queries_file}")
print(f"INFO: Using ground truth file: {golden_results_file}")
try:
searcher = LeannSearcher(args.index_path)
queries = load_queries(queries_file)
with open(golden_results_file, "r") as f:
golden_results_data = json.load(f)
num_eval_queries = min(args.num_queries, len(queries))
queries = queries[:num_eval_queries]
print(f"\nRunning evaluation on {num_eval_queries} queries...")
recall_scores = []
search_times = []
for i in range(num_eval_queries):
start_time = time.time()
new_results = searcher.search(
queries[i], top_k=args.top_k, ef=args.ef_search
)
search_times.append(time.time() - start_time)
# Correct Recall Calculation: Based on TEXT content
new_texts = {result.text for result in new_results}
# Get golden texts directly from the searcher's passage manager
golden_ids = golden_results_data["indices"][i][: args.top_k]
golden_texts = get_golden_texts(searcher, golden_ids)
overlap = len(new_texts & golden_texts)
recall = overlap / len(golden_texts) if golden_texts else 0
recall_scores.append(recall)
print("\n--- EVALUATION RESULTS ---")
print(f"Query: {queries[i]}")
print(f"New Results: {new_texts}")
print(f"Golden Results: {golden_texts}")
print(f"Overlap: {overlap}")
print(f"Recall: {recall}")
print(f"Search Time: {search_times[-1]:.4f}s")
print("--------------------------------")
avg_recall = np.mean(recall_scores) if recall_scores else 0
avg_time = np.mean(search_times) if search_times else 0
print("\n🎉 --- Evaluation Complete ---")
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
print(f"Avg. Search Time: {avg_time:.4f}s")
except Exception as e:
print(f"\n❌ An error occurred during evaluation: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,318 @@
import os
import asyncio
import dotenv
import argparse
from pathlib import Path
from typing import List, Any, Optional
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from llama_index.core.node_parser import SentenceSplitter
import requests
import time
dotenv.load_dotenv()
# Default WeChat export directory
DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
def create_leann_index_from_multiple_wechat_exports(
export_dirs: List[Path],
index_path: str = "wechat_history_index.leann",
max_count: int = -1,
):
"""
Create LEANN index from multiple WeChat export data sources.
Args:
export_dirs: List of Path objects pointing to WeChat export directories
index_path: Path to save the LEANN index
max_count: Maximum number of chat entries to process per export
"""
print("Creating LEANN index from multiple WeChat export data sources...")
# Load documents using WeChatHistoryReader from history_data
from history_data.wechat_history import WeChatHistoryReader
reader = WeChatHistoryReader()
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
all_documents = []
total_processed = 0
# Process each WeChat export directory
for i, export_dir in enumerate(export_dirs):
print(
f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}"
)
try:
documents = reader.load_data(
wechat_export_dir=str(export_dir),
max_count=max_count,
concatenate_messages=False, # Disable concatenation - one message per document
)
if documents:
print(f"Loaded {len(documents)} chat documents from {export_dir}")
all_documents.extend(documents)
total_processed += len(documents)
# Check if we've reached the max count
if max_count > 0 and total_processed >= max_count:
print(f"Reached max count of {max_count} documents")
break
else:
print(f"No documents loaded from {export_dir}")
except Exception as e:
print(f"Error processing {export_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
return None
print(
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports"
)
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=128, chunk_overlap=64)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in all_documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(
f"Created {len(all_texts)} text chunks from {len(all_documents)} documents"
)
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="Qwen/Qwen3-Embedding-0.6B",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1, # Force single-threaded mode
)
print(f"Adding {len(all_texts)} chat chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
def create_leann_index(
export_dir: str = None,
index_path: str = "wechat_history_index.leann",
max_count: int = 1000,
):
"""
Create LEANN index from WeChat chat history data.
Args:
export_dir: Path to the WeChat export directory (optional, uses default if None)
index_path: Path to save the LEANN index
max_count: Maximum number of chat entries to process
"""
print("Creating LEANN index from WeChat chat history data...")
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Load documents using WeChatHistoryReader from history_data
from history_data.wechat_history import WeChatHistoryReader
reader = WeChatHistoryReader()
documents = reader.load_data(
wechat_export_dir=export_dir,
max_count=max_count,
concatenate_messages=False, # Disable concatenation - one message per document
)
if not documents:
print("No documents loaded. Exiting.")
return None
print(f"Loaded {len(documents)} chat documents")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ", # MLX-optimized model
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1, # Force single-threaded mode
)
print(f"Adding {len(all_texts)} chat chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
async def query_leann_index(index_path: str, query: str):
"""
Query the LEANN index.
Args:
index_path: Path to the LEANN index
query: The query string
"""
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path)
print(f"You: {query}")
chat_response = chat.ask(
query,
top_k=20,
recompute_beighbor_embeddings=True,
complexity=128,
beam_width=1,
llm_config={
"type": "openai",
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
)
print(f"Leann: {chat_response}")
async def main():
"""Main function with integrated WeChat export functionality."""
# Parse command line arguments
parser = argparse.ArgumentParser(
description="LEANN WeChat History Reader - Create and query WeChat chat history index"
)
parser.add_argument(
"--export-dir",
type=str,
default=DEFAULT_WECHAT_EXPORT_DIR,
help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})",
)
parser.add_argument(
"--index-dir",
type=str,
default="./wechat_history_june19_test",
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
)
parser.add_argument(
"--max-entries",
type=int,
default=5000,
help="Maximum number of chat entries to process (default: 5000)",
)
parser.add_argument(
"--query",
type=str,
default=None,
help="Single query to run (default: runs example queries)",
)
parser.add_argument(
"--force-export",
action="store_true",
default=False,
help="Force re-export of WeChat data even if exports exist",
)
args = parser.parse_args()
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "wechat_history.leann")
print(f"Using WeChat export directory: {args.export_dir}")
print(f"Index directory: {INDEX_DIR}")
print(f"Max entries: {args.max_entries}")
# Initialize WeChat reader with export capabilities
from history_data.wechat_history import WeChatHistoryReader
reader = WeChatHistoryReader()
# Find existing exports or create new ones using the centralized method
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
if not export_dirs:
print("Failed to find or export WeChat data. Exiting.")
return
# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_wechat_exports(
export_dirs, INDEX_PATH, max_count=args.max_entries
)
if index_path:
if args.query:
# Run single query
await query_leann_index(index_path, args.query)
else:
# Example queries
queries = [
"我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
]
for query in queries:
print("\n" + "=" * 60)
await query_leann_index(index_path, query)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,32 +0,0 @@
{
"version": "0.1.0",
"backend_name": "diskann",
"embedding_model": "sentence-transformers/all-mpnet-base-v2",
"num_chunks": 6,
"chunks": [
{
"text": "Python is a powerful programming language",
"metadata": {}
},
{
"text": "Machine learning transforms industries",
"metadata": {}
},
{
"text": "Neural networks process complex data",
"metadata": {}
},
{
"text": "Java is a powerful programming language",
"metadata": {}
},
{
"text": "C++ is a powerful programming language",
"metadata": {}
},
{
"text": "C# is a powerful programming language",
"metadata": {}
}
]
}

1
packages/__init__.py Normal file
View File

@@ -0,0 +1 @@

View File

@@ -1,8 +1,8 @@
# packages/leann-backend-diskann/CMakeLists.txt (最终简化版)
# packages/leann-backend-diskann/CMakeLists.txt (simplified version)
cmake_minimum_required(VERSION 3.20)
project(leann_backend_diskann_wrapper)
# 告诉 CMake 直接进入 DiskANN 子模块并执行它自己的 CMakeLists.txt
# DiskANN 会自己处理所有事情,包括编译 Python 绑定
# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt
# DiskANN will handle everything itself, including compiling Python bindings
add_subdirectory(src/third_party/DiskANN)

View File

@@ -0,0 +1 @@
# This file makes the directory a Python package

View File

@@ -1,33 +1,30 @@
import numpy as np
import os
import json
import struct
from pathlib import Path
from typing import Dict, Any, List
from typing import Dict, Any, List, Literal
import contextlib
import threading
import time
import atexit
import socket
import subprocess
import sys
import pickle
from leann.embedding_server_manager import EmbeddingServerManager
from leann.searcher_base import BaseSearcher
from leann.registry import register_backend
from leann.interface import (
LeannBackendFactoryInterface,
LeannBackendBuilderInterface,
LeannBackendSearcherInterface
LeannBackendSearcherInterface,
)
def _get_diskann_metrics():
from . import _diskannpy as diskannpy
from . import _diskannpy as diskannpy # type: ignore
return {
"mips": diskannpy.Metric.INNER_PRODUCT,
"l2": diskannpy.Metric.L2,
"cosine": diskannpy.Metric.COSINE,
}
@contextlib.contextmanager
def chdir(path):
original_dir = os.getcwd()
@@ -37,13 +34,15 @@ def chdir(path):
finally:
os.chdir(original_dir)
def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
num_vectors, dim = data.shape
with open(file_path, 'wb') as f:
f.write(struct.pack('I', num_vectors))
f.write(struct.pack('I', dim))
with open(file_path, "wb") as f:
f.write(struct.pack("I", num_vectors))
f.write(struct.pack("I", dim))
f.write(data.tobytes())
@register_backend("diskann")
class DiskannBackend(LeannBackendFactoryInterface):
@staticmethod
@@ -52,211 +51,164 @@ class DiskannBackend(LeannBackendFactoryInterface):
@staticmethod
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
path = Path(index_path)
meta_path = path.parent / f"{path.name}.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}.")
with open(meta_path, 'r') as f:
meta = json.load(f)
# Pass essential metadata to the searcher
kwargs['meta'] = meta
return DiskannSearcher(index_path, **kwargs)
class DiskannBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs):
self.build_params = kwargs
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
path = Path(index_path)
index_dir = path.parent
index_prefix = path.stem
index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32:
data = data.astype(np.float32)
if not data.flags['C_CONTIGUOUS']:
data = np.ascontiguousarray(data)
data_filename = f"{index_prefix}_data.bin"
_write_vectors_to_bin(data, index_dir / data_filename)
# Create label map: integer -> string_id
label_map = {i: str_id for i, str_id in enumerate(ids)}
label_map_file = index_dir / "leann.labels.map"
with open(label_map_file, 'wb') as f:
pickle.dump(label_map, f)
build_kwargs = {**self.build_params, **kwargs}
metric_str = build_kwargs.get("distance_metric", "mips").lower()
METRIC_MAP = _get_diskann_metrics()
metric_enum = METRIC_MAP.get(metric_str)
metric_enum = _get_diskann_metrics().get(
build_kwargs.get("distance_metric", "mips").lower()
)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
raise ValueError("Unsupported distance_metric.")
complexity = build_kwargs.get("complexity", 64)
graph_degree = build_kwargs.get("graph_degree", 32)
final_index_ram_limit = build_kwargs.get("search_memory_maximum", 4.0)
indexing_ram_budget = build_kwargs.get("build_memory_maximum", 8.0)
num_threads = build_kwargs.get("num_threads", 8)
pq_disk_bytes = build_kwargs.get("pq_disk_bytes", 0)
codebook_prefix = ""
print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
try:
from . import _diskannpy as diskannpy
from . import _diskannpy as diskannpy # type: ignore
with chdir(index_dir):
diskannpy.build_disk_float_index(
metric_enum,
data_filename,
index_prefix,
complexity,
graph_degree,
final_index_ram_limit,
indexing_ram_budget,
num_threads,
pq_disk_bytes,
codebook_prefix
build_kwargs.get("complexity", 64),
build_kwargs.get("graph_degree", 32),
build_kwargs.get("search_memory_maximum", 4.0),
build_kwargs.get("build_memory_maximum", 8.0),
build_kwargs.get("num_threads", 8),
build_kwargs.get("pq_disk_bytes", 0),
"",
)
print(f"✅ DiskANN index built successfully at '{index_dir / index_prefix}'")
except Exception as e:
print(f"💥 ERROR: DiskANN index build failed. Exception: {e}")
raise
finally:
temp_data_file = index_dir / data_filename
if temp_data_file.exists():
os.remove(temp_data_file)
class DiskannSearcher(LeannBackendSearcherInterface):
class DiskannSearcher(BaseSearcher):
def __init__(self, index_path: str, **kwargs):
self.meta = kwargs.get("meta", {})
if not self.meta:
raise ValueError("DiskannSearcher requires metadata from .meta.json.")
super().__init__(
index_path,
backend_module_name="leann_backend_diskann.embedding_server",
**kwargs,
)
from . import _diskannpy as diskannpy # type: ignore
self.embedding_model = self.meta.get("embedding_model")
if not self.embedding_model:
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
path = Path(index_path)
self.index_dir = path.parent
self.index_prefix = path.stem
# Load the label map
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
with open(label_map_file, 'rb') as f:
self.label_map = pickle.load(f)
# Extract parameters for DiskANN
distance_metric = kwargs.get("distance_metric", "mips").lower()
METRIC_MAP = _get_diskann_metrics()
metric_enum = METRIC_MAP.get(distance_metric)
metric_enum = _get_diskann_metrics().get(distance_metric)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
num_threads = kwargs.get("num_threads", 8)
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
self.num_threads = kwargs.get("num_threads", 8)
self.zmq_port = kwargs.get("zmq_port", 6666)
try:
from . import _diskannpy as diskannpy
full_index_prefix = str(self.index_dir / self.index_prefix)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, self.zmq_port, "", ""
)
self.num_threads = num_threads
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_diskann.embedding_server"
)
print("✅ DiskANN index loaded successfully.")
except Exception as e:
print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
raise
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
complexity = kwargs.get("complexity", 256)
beam_width = kwargs.get("beam_width", 4)
USE_DEFERRED_FETCH = kwargs.get("USE_DEFERRED_FETCH", False)
skip_search_reorder = kwargs.get("skip_search_reorder", False)
recompute_beighbor_embeddings = kwargs.get("recompute_beighbor_embeddings", False)
dedup_node_dis = kwargs.get("dedup_node_dis", False)
prune_ratio = kwargs.get("prune_ratio", 0.0)
batch_recompute = kwargs.get("batch_recompute", False)
global_pruning = kwargs.get("global_pruning", False)
port = kwargs.get("zmq_port", self.zmq_port)
if recompute_beighbor_embeddings:
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
if not self.embedding_model:
raise ValueError("Cannot use recompute_beighbor_embeddings without 'embedding_model' in meta.json.")
full_index_prefix = str(self.index_dir / self.index_path.stem)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum,
full_index_prefix,
self.num_threads,
kwargs.get("num_nodes_to_cache", 0),
1,
self.zmq_port,
"",
"",
)
passages_file = kwargs.get("passages_file")
if not passages_file:
# Get the passages file path from meta.json
if 'passage_sources' in self.meta and self.meta['passage_sources']:
passage_source = self.meta['passage_sources'][0]
passages_file = passage_source['path']
print(f"INFO: Found passages file from metadata: {passages_file}")
else:
raise RuntimeError(f"FATAL: Recompute mode enabled but no passage_sources found in metadata.")
def search(
self,
query: np.ndarray,
top_k: int,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
batch_recompute: bool = False,
dedup_node_dis: bool = False,
**kwargs,
) -> Dict[str, Any]:
"""
Search for nearest neighbors using DiskANN index.
server_started = self.embedding_server_manager.start_server(
port=self.zmq_port,
model_name=self.embedding_model,
distance_metric=kwargs.get("distance_metric", "mips"),
passages_file=passages_file
Args:
query: Query vectors (B, D) where B is batch size, D is dimension
top_k: Number of nearest neighbors to return
complexity: Search complexity/candidate list size, higher = more accurate but slower
beam_width: Number of parallel IO requests per iteration
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server
pruning_strategy: PQ candidate selection strategy:
- "global": Use global pruning strategy (default)
- "local": Use local pruning strategy
- "proportional": Not supported in DiskANN, falls back to global
zmq_port: ZMQ port for embedding server
batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
**kwargs: Additional DiskANN-specific parameters (for legacy compatibility)
Returns:
Dict with 'labels' (list of lists) and 'distances' (ndarray)
"""
# DiskANN doesn't support "proportional" strategy
if pruning_strategy == "proportional":
raise NotImplementedError(
"DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead."
)
if not server_started:
raise RuntimeError(f"Failed to start DiskANN embedding server on port {self.zmq_port}")
# Use recompute_embeddings parameter
use_recompute = recompute_embeddings
if use_recompute:
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_file_path.exists():
raise RuntimeError(
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
)
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
if query.dtype != np.float32:
query = query.astype(np.float32)
if query.ndim == 1:
query = np.expand_dims(query, axis=0)
try:
labels, distances = self._index.batch_search(
query,
query.shape[0],
top_k,
complexity,
beam_width,
self.num_threads,
USE_DEFERRED_FETCH,
skip_search_reorder,
recompute_beighbor_embeddings,
dedup_node_dis,
prune_ratio,
batch_recompute,
global_pruning
)
# Convert integer labels to string IDs
string_labels = []
for batch_labels in labels:
batch_string_labels = []
for int_label in batch_labels:
if int_label in self.label_map:
batch_string_labels.append(self.label_map[int_label])
else:
batch_string_labels.append(f"unknown_{int_label}")
string_labels.append(batch_string_labels)
return {"labels": string_labels, "distances": distances}
except Exception as e:
print(f"💥 ERROR: DiskANN search failed. Exception: {e}")
batch_size = query.shape[0]
return {"labels": [[f"error_{i}" for i in range(top_k)] for _ in range(batch_size)],
"distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}
def __del__(self):
if hasattr(self, 'embedding_server_manager'):
self.embedding_server_manager.stop_server()
# Map pruning_strategy to DiskANN's global_pruning parameter
if pruning_strategy == "local":
use_global_pruning = False
else: # "global"
use_global_pruning = True
labels, distances = self._index.batch_search(
query,
query.shape[0],
top_k,
complexity,
beam_width,
self.num_threads,
kwargs.get("USE_DEFERRED_FETCH", False),
kwargs.get("skip_search_reorder", False),
use_recompute,
dedup_node_dis,
prune_ratio,
batch_recompute,
use_global_pruning,
)
string_labels = [
[str(int_label) for int_label in batch_labels]
for batch_labels in labels
]
return {"labels": string_labels, "distances": distances}

View File

@@ -5,7 +5,6 @@ Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern
import pickle
import argparse
import threading
import time
import json
from typing import Dict, Any, Optional, Union
@@ -15,8 +14,19 @@ import os
from contextlib import contextmanager
import zmq
import numpy as np
import msgpack
from pathlib import Path
import logging
RED = "\033[91m"
# Set up logging based on environment variable
LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
logging.basicConfig(
level=getattr(logging, LOG_LEVEL, logging.INFO),
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
RESET = "\033[0m"
# --- New Passage Loader from HNSW backend ---
@@ -26,6 +36,7 @@ class SimplePassageLoader:
"""
def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
self.passages_data = passages_data or {}
self._meta_path = ''
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID"""
@@ -38,14 +49,69 @@ class SimplePassageLoader:
def __len__(self) -> int:
return len(self.passages_data)
def keys(self):
return self.passages_data.keys()
def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
"""
Load passages using metadata file with PassageManager for lazy loading
"""
# Load metadata to get passage sources
with open(meta_file, 'r') as f:
meta = json.load(f)
# Import PassageManager dynamically to avoid circular imports
import sys
from pathlib import Path
# Find the leann package directory relative to this file
current_dir = Path(__file__).parent
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
sys.path.insert(0, str(leann_core_path))
try:
from leann.api import PassageManager
passage_manager = PassageManager(meta['passage_sources'])
finally:
sys.path.pop(0)
print(f"Initialized lazy passage loading for {len(passage_manager.global_offset_map)} passages")
class LazyPassageLoader(SimplePassageLoader):
def __init__(self, passage_manager):
self.passage_manager = passage_manager
# Initialize parent with empty data
super().__init__({})
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID with lazy loading"""
try:
int_id = int(passage_id)
string_id = str(int_id)
passage_data = self.passage_manager.get_passage(string_id)
if passage_data and passage_data.get("text"):
return {"text": passage_data["text"]}
else:
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
except Exception as e:
raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
def __len__(self) -> int:
return len(self.passage_manager.global_offset_map)
def keys(self):
return self.passage_manager.global_offset_map.keys()
loader = LazyPassageLoader(passage_manager)
loader._meta_path = meta_file
return loader
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
"""
Load passages from a JSONL file with label map support
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
"""
from pathlib import Path
import pickle
if not os.path.exists(passages_file):
raise FileNotFoundError(f"Passages file {passages_file} not found.")
@@ -53,35 +119,15 @@ def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
if not passages_file.endswith('.jsonl'):
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
# Load label map (int -> string_id)
passages_dir = Path(passages_file).parent
label_map_file = passages_dir / "leann.labels.map"
label_map = {}
if label_map_file.exists():
with open(label_map_file, 'rb') as f:
label_map = pickle.load(f)
print(f"Loaded label map with {len(label_map)} entries")
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
# Load passages by string ID
string_id_passages = {}
# Load passages directly by their sequential IDs
passages_data = {}
with open(passages_file, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
passage = json.loads(line)
string_id_passages[passage['id']] = passage['text']
passages_data[passage['id']] = passage['text']
# Create int ID -> text mapping using label map
passages_data = {}
for int_id, string_id in label_map.items():
if string_id in string_id_passages:
passages_data[str(int_id)] = string_id_passages[string_id]
else:
print(f"WARNING: String ID {string_id} from label map not found in passages")
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file}")
return SimplePassageLoader(passages_data)
def create_embedding_server_thread(
@@ -89,15 +135,17 @@ def create_embedding_server_thread(
model_name="sentence-transformers/all-mpnet-base-v2",
max_batch_size=128,
passages_file: Optional[str] = None,
embedding_mode: str = "sentence-transformers",
enable_warmup: bool = False,
):
"""
在当前线程中创建并运行 embedding server
这个函数设计为在单独的线程中调用
Create and run embedding server in the current thread
This function is designed to be called in a separate thread
"""
print(f"INFO: Initializing embedding server thread on port {zmq_port}")
logger.info(f"Initializing embedding server thread on port {zmq_port}")
try:
# 检查端口是否已被占用
# Check if port is already occupied
import socket
def check_port(port):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -107,55 +155,147 @@ def create_embedding_server_thread(
print(f"{RED}Port {zmq_port} is already in use{RESET}")
return
# 初始化模型
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
import torch
# 选择设备
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
cuda_available = torch.cuda.is_available()
# Auto-detect mode based on model name if not explicitly set
if embedding_mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
embedding_mode = "openai"
if cuda_available:
device = torch.device("cuda")
print("INFO: Using CUDA device")
elif mps_available:
device = torch.device("mps")
print("INFO: Using MPS device (Apple Silicon)")
else:
if embedding_mode == "mlx":
from leann.api import compute_embeddings_mlx
import torch
logger.info("Using MLX for embeddings")
# Set device to CPU for compatibility with DeviceTimer class
device = torch.device("cpu")
print("INFO: Using CPU device")
# 加载模型
print(f"INFO: Loading model {model_name}")
model = AutoModel.from_pretrained(model_name).to(device).eval()
cuda_available = False
mps_available = False
elif embedding_mode == "openai":
from leann.api import compute_embeddings_openai
import torch
logger.info("Using OpenAI API for embeddings")
# Set device to CPU for compatibility with DeviceTimer class
device = torch.device("cpu")
cuda_available = False
mps_available = False
elif embedding_mode == "sentence-transformers":
# Initialize model
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
import torch
# 优化模型
if cuda_available or mps_available:
try:
model = model.half()
model = torch.compile(model)
print(f"INFO: Using FP16 precision with model: {model_name}")
except Exception as e:
print(f"WARNING: Model optimization failed: {e}")
# Select device
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
cuda_available = torch.cuda.is_available()
if cuda_available:
device = torch.device("cuda")
logger.info("Using CUDA device")
elif mps_available:
device = torch.device("mps")
logger.info("Using MPS device (Apple Silicon)")
else:
device = torch.device("cpu")
logger.info("Using CPU device")
# Load model
logger.info(f"Loading model {model_name}")
model = AutoModel.from_pretrained(model_name).to(device).eval()
# Optimize model
if cuda_available or mps_available:
try:
model = model.half()
model = torch.compile(model)
logger.info(f"Using FP16 precision with model: {model_name}")
except Exception as e:
print(f"WARNING: Model optimization failed: {e}")
else:
raise ValueError(f"Unsupported embedding mode: {embedding_mode}. Supported modes: sentence-transformers, mlx, openai")
# Load passages from file if provided
if passages_file and os.path.exists(passages_file):
passages = load_passages_from_file(passages_file)
# Check if it's a metadata file or a single passages file
if passages_file.endswith('.meta.json'):
passages = load_passages_from_metadata(passages_file)
else:
# Try to find metadata file in same directory
passages_dir = Path(passages_file).parent
meta_files = list(passages_dir.glob("*.meta.json"))
if meta_files:
print(f"Found metadata file: {meta_files[0]}, using lazy loading")
passages = load_passages_from_metadata(str(meta_files[0]))
else:
# Fallback to original single file loading (will cause warnings)
print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
passages = load_passages_from_file(passages_file)
else:
print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
passages = SimplePassageLoader()
print(f"INFO: Loaded {len(passages)} passages.")
logger.info(f"Loaded {len(passages)} passages.")
def client_warmup(zmq_port):
"""Perform client-side warmup for DiskANN server"""
time.sleep(2)
print(f"Performing client-side warmup with model {model_name}...")
# Get actual passage IDs from the loaded passages
sample_ids = []
if hasattr(passages, 'keys') and len(passages) > 0:
available_ids = list(passages.keys())
# Take up to 5 actual IDs, but at least 1
sample_ids = available_ids[:min(5, len(available_ids))]
print(f"Using actual passage IDs for warmup: {sample_ids}")
else:
print("No passages available for warmup, skipping warmup...")
return
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(f"tcp://localhost:{zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 30000)
socket.setsockopt(zmq.SNDTIMEO, 30000)
try:
ids_to_send = [int(x) for x in sample_ids]
except ValueError:
print("Warning: Could not convert sample IDs to integers, skipping warmup")
return
if not ids_to_send:
print("Skipping warmup send.")
return
# Use protobuf format for warmup
from . import embedding_pb2
req_proto = embedding_pb2.NodeEmbeddingRequest()
req_proto.node_ids.extend(ids_to_send)
request_bytes = req_proto.SerializeToString()
for i in range(3):
print(f"Sending warmup request {i + 1}/3 via ZMQ (Protobuf)...")
socket.send(request_bytes)
response_bytes = socket.recv()
resp_proto = embedding_pb2.NodeEmbeddingResponse()
resp_proto.ParseFromString(response_bytes)
embeddings_count = resp_proto.dimensions[0] if resp_proto.dimensions else 0
print(f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings")
time.sleep(0.1)
print("Client-side Protobuf ZMQ warmup complete")
socket.close()
context.term()
except Exception as e:
print(f"Error during Protobuf ZMQ warmup: {e}")
class DeviceTimer:
"""设备计时器"""
"""Device timer"""
def __init__(self, name="", device=device):
self.name = name
self.device = device
self.start_time = 0
self.end_time = 0
if cuda_available:
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
else:
@@ -169,123 +309,230 @@ def create_embedding_server_thread(
self.end()
def start(self):
if cuda_available:
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
torch.cuda.synchronize()
self.start_event.record()
else:
if self.device.type == "mps":
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
torch.mps.synchronize()
self.start_time = time.time()
def end(self):
if cuda_available:
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
self.end_event.record()
torch.cuda.synchronize()
else:
if self.device.type == "mps":
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
torch.mps.synchronize()
self.end_time = time.time()
def elapsed_time(self):
if cuda_available:
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
return self.start_event.elapsed_time(self.end_event) / 1000.0
else:
return self.end_time - self.start_time
def print_elapsed(self):
print(f"Time taken for {self.name}: {self.elapsed_time():.6f} seconds")
elapsed = self.elapsed_time()
print(f"[{self.name}] Elapsed time: {elapsed:.3f}s")
def process_batch(texts_batch, ids_batch, missing_ids):
"""处理文本批次"""
batch_size = len(texts_batch)
print(f"INFO: Processing batch of size {batch_size}")
def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
"""Process text batch"""
if not texts_batch:
return np.array([])
tokenize_timer = DeviceTimer("tokenization (batch)", device)
to_device_timer = DeviceTimer("transfer to device (batch)", device)
embed_timer = DeviceTimer("embedding (batch)", device)
pool_timer = DeviceTimer("mean pooling (batch)", device)
# Filter out empty texts and their corresponding IDs
valid_texts = []
valid_ids = []
for i, text in enumerate(texts_batch):
if text.strip(): # Only include non-empty texts
valid_texts.append(text)
valid_ids.append(ids_batch[i])
with tokenize_timer.timing():
encoded_batch = tokenizer.batch_encode_plus(
texts_batch,
padding="max_length",
if not valid_texts:
print("WARNING: No valid texts in batch")
return np.array([])
# Tokenize
token_timer = DeviceTimer("tokenization")
with token_timer.timing():
inputs = tokenizer(
valid_texts,
padding=True,
truncation=True,
max_length=256,
return_tensors="pt",
return_token_type_ids=False,
)
tokenize_timer.print_elapsed()
max_length=512,
return_tensors="pt"
).to(device)
seq_length = encoded_batch["input_ids"].size(1)
print(f"Batch size: {batch_size}, Sequence length: {seq_length}")
with to_device_timer.timing():
enc = {k: v.to(device) for k, v in encoded_batch.items()}
to_device_timer.print_elapsed()
with torch.no_grad():
with embed_timer.timing():
out = model(enc["input_ids"], enc["attention_mask"])
embed_timer.print_elapsed()
with pool_timer.timing():
hidden_states = out.last_hidden_state if hasattr(out, "last_hidden_state") else out
mask_expanded = enc["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float()
# Compute embeddings
embed_timer = DeviceTimer("embedding computation")
with embed_timer.timing():
with torch.no_grad():
outputs = model(**inputs)
hidden_states = outputs.last_hidden_state
# Mean pooling
attention_mask = inputs['attention_mask']
mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
batch_embeddings = sum_embeddings / sum_mask
pool_timer.print_elapsed()
embed_timer.print_elapsed()
return batch_embeddings.cpu().numpy()
# ZMQ server 主循环 - 修改为REP套接字
# ZMQ server main loop - modified to use REP socket
context = zmq.Context()
socket = context.socket(zmq.ROUTER) # 改为REP套接字
socket = context.socket(zmq.ROUTER) # Changed to REP socket
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
print(f"INFO: ZMQ ROUTER server listening on port {zmq_port}")
# 设置超时
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5秒接收超时
socket.setsockopt(zmq.SNDTIMEO, 300000) # 300秒发送超时
# Set timeouts
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second receive timeout
socket.setsockopt(zmq.SNDTIMEO, 300000) # 300 second send timeout
from . import embedding_pb2
print(f"INFO: Embedding server ready to serve requests")
# Start warmup thread if enabled
if enable_warmup and len(passages) > 0:
import threading
print(f"Warmup enabled: starting warmup thread")
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
warmup_thread.daemon = True
warmup_thread.start()
else:
print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})")
while True:
try:
parts = socket.recv_multipart()
# --- 恢复稳健的消息格式判断 ---
# 必须检查 parts 的长度,避免 IndexError
# --- Restore robust message format detection ---
# Must check parts length to avoid IndexError
if len(parts) >= 3:
identity = parts[0]
# empty = parts[1] # 中间的空帧我们通常不关心
# empty = parts[1] # We usually don't care about the middle empty frame
message = parts[2]
elif len(parts) == 2:
# 也能处理没有空帧的情况
# Can also handle cases without empty frame
identity = parts[0]
message = parts[1]
else:
# 如果收到格式错误的消息,打印警告并忽略它,而不是崩溃
# If received message format is wrong, print warning and ignore it instead of crashing
print(f"WARNING: Received unexpected message format with {len(parts)} parts. Ignoring.")
continue
print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")
e2e_start = time.time()
lookup_timer = DeviceTimer("text lookup", device)
# Handle control messages (MessagePack format)
try:
request_payload = msgpack.unpackb(message)
if isinstance(request_payload, list) and len(request_payload) >= 1:
if request_payload[0] == "__QUERY_META_PATH__":
# Return the current meta path being used by the server
current_meta_path = getattr(passages, '_meta_path', '') if hasattr(passages, '_meta_path') else ''
response = [current_meta_path]
socket.send_multipart([identity, b'', msgpack.packb(response)])
continue
elif request_payload[0] == "__UPDATE_META_PATH__" and len(request_payload) >= 2:
# Update the server's meta path and reload passages
new_meta_path = request_payload[1]
try:
print(f"INFO: Updating server meta path to: {new_meta_path}")
# Reload passages from the new meta file
passages = load_passages_from_metadata(new_meta_path)
# Store the meta path for future queries
passages._meta_path = new_meta_path
response = ["SUCCESS"]
print(f"INFO: Successfully updated meta path and reloaded {len(passages)} passages")
except Exception as e:
print(f"ERROR: Failed to update meta path: {e}")
response = ["FAILED", str(e)]
socket.send_multipart([identity, b'', msgpack.packb(response)])
continue
elif request_payload[0] == "__QUERY_MODEL__":
# Return the current model being used by the server
response = [model_name]
socket.send_multipart([identity, b'', msgpack.packb(response)])
continue
elif request_payload[0] == "__UPDATE_MODEL__" and len(request_payload) >= 2:
# Update the server's embedding model
new_model_name = request_payload[1]
try:
print(f"INFO: Updating server model from {model_name} to: {new_model_name}")
# Clean up old model to free memory
if not use_mlx:
print("INFO: Releasing old model from memory...")
old_model = model
old_tokenizer = tokenizer
# Load new tokenizer first
print(f"Loading new tokenizer for {new_model_name}...")
tokenizer = AutoTokenizer.from_pretrained(new_model_name, use_fast=True)
# Load new model
print(f"Loading new model {new_model_name}...")
model = AutoModel.from_pretrained(new_model_name).to(device).eval()
# Optimize new model
if cuda_available or mps_available:
try:
model = model.half()
model = torch.compile(model)
print(f"INFO: Using FP16 precision with model: {new_model_name}")
except Exception as e:
print(f"WARNING: Model optimization failed: {e}")
# Now safely delete old model after new one is loaded
del old_model
del old_tokenizer
# Clear GPU cache if available
if device.type == "cuda":
torch.cuda.empty_cache()
print("INFO: Cleared CUDA cache")
elif device.type == "mps":
torch.mps.empty_cache()
print("INFO: Cleared MPS cache")
# Force garbage collection
import gc
gc.collect()
print("INFO: Memory cleanup completed")
# Update model name
model_name = new_model_name
response = ["SUCCESS"]
print(f"INFO: Successfully updated model to: {new_model_name}")
except Exception as e:
print(f"ERROR: Failed to update model: {e}")
response = ["FAILED", str(e)]
socket.send_multipart([identity, b'', msgpack.packb(response)])
continue
except:
# Not a control message, continue with normal protobuf processing
pass
# 解析请求
e2e_start = time.time()
lookup_timer = DeviceTimer("text lookup")
# Parse request
req_proto = embedding_pb2.NodeEmbeddingRequest()
req_proto.ParseFromString(message)
node_ids = req_proto.node_ids
print(f"INFO: Request for {len(node_ids)} node embeddings: {list(node_ids)}")
# 添加调试信息
# Add debug information
if len(node_ids) > 0:
print(f"DEBUG: Node ID range: {min(node_ids)} to {max(node_ids)}")
# 查找文本
# Look up texts
texts = []
missing_ids = []
with lookup_timer.timing():
@@ -295,8 +542,8 @@ def create_embedding_server_thread(
if txt:
texts.append(txt)
else:
# 如果文本为空,我们仍然需要一个占位符来进行批处理,
# 但将其ID记录为缺失
# If text is empty, we still need a placeholder for batch processing,
# but record its ID as missing
texts.append("")
missing_ids.append(nid)
lookup_timer.print_elapsed()
@@ -304,7 +551,7 @@ def create_embedding_server_thread(
if missing_ids:
print(f"WARNING: Missing passages for IDs: {missing_ids}")
# 处理批次
# Process batch
total_size = len(texts)
print(f"INFO: Total batch size: {total_size}, max_batch_size: {max_batch_size}")
@@ -319,20 +566,31 @@ def create_embedding_server_thread(
chunk_texts = texts[i:end_idx]
chunk_ids = node_ids[i:end_idx]
embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
if embedding_mode == "mlx":
embeddings_chunk = compute_embeddings_mlx(chunk_texts, model_name, batch_size=16)
elif embedding_mode == "openai":
embeddings_chunk = compute_embeddings_openai(chunk_texts, model_name)
else: # sentence-transformers
embeddings_chunk = process_batch_pytorch(chunk_texts, chunk_ids, missing_ids)
all_embeddings.append(embeddings_chunk)
if cuda_available:
torch.cuda.empty_cache()
elif device.type == "mps":
torch.mps.empty_cache()
if embedding_mode == "sentence-transformers":
if cuda_available:
torch.cuda.empty_cache()
elif device.type == "mps":
torch.mps.empty_cache()
hidden = np.vstack(all_embeddings)
print(f"INFO: Combined embeddings shape: {hidden.shape}")
else:
hidden = process_batch(texts, node_ids, missing_ids)
if embedding_mode == "mlx":
hidden = compute_embeddings_mlx(texts, model_name, batch_size=16)
elif embedding_mode == "openai":
hidden = compute_embeddings_openai(texts, model_name)
else: # sentence-transformers
hidden = process_batch_pytorch(texts, node_ids, missing_ids)
# 序列化响应
# Serialize response
ser_start = time.time()
resp_proto = embedding_pb2.NodeEmbeddingResponse()
@@ -344,32 +602,32 @@ def create_embedding_server_thread(
response_data = resp_proto.SerializeToString()
# REP 套接字发送单个响应
# REP socket sends a single response
socket.send_multipart([identity, b'', response_data])
ser_end = time.time()
print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
if embedding_mode == "sentence-transformers":
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
e2e_end = time.time()
print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
except zmq.Again:
print("INFO: ZMQ socket timeout, continuing to listen")
# REP套接字不需要重新创建只需要继续监听
continue
except Exception as e:
print(f"ERROR: Error in ZMQ server: {e}")
try:
# 发送空响应以维持REQ-REP状态
# Send empty response to maintain REQ-REP state
empty_resp = embedding_pb2.NodeEmbeddingResponse()
socket.send(empty_resp.SerializeToString())
except:
# 如果发送失败,重新创建socket
# If sending fails, recreate socket
socket.close()
socket = context.socket(zmq.REP)
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
@@ -382,7 +640,6 @@ def create_embedding_server_thread(
raise
# 保持原有的 create_embedding_server 函数不变,只添加线程化版本
def create_embedding_server(
domain="demo",
load_passages=True,
@@ -395,12 +652,14 @@ def create_embedding_server(
lazy_load_passages=False,
model_name="sentence-transformers/all-mpnet-base-v2",
passages_file: Optional[str] = None,
embedding_mode: str = "sentence-transformers",
enable_warmup: bool = False,
):
"""
原有的 create_embedding_server 函数保持不变
这个是阻塞版本,用于直接运行
"""
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file)
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, embedding_mode, enable_warmup)
if __name__ == "__main__":
@@ -417,7 +676,17 @@ if __name__ == "__main__":
parser.add_argument("--lazy-load-passages", action="store_true", default=True)
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model name")
parser.add_argument("--embedding-mode", type=str, default="sentence-transformers",
choices=["sentence-transformers", "mlx", "openai"],
help="Embedding backend mode")
parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings (deprecated: use --embedding-mode mlx)")
parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start")
args = parser.parse_args()
# Handle backward compatibility with use_mlx
embedding_mode = args.embedding_mode
if args.use_mlx:
embedding_mode = "mlx"
create_embedding_server(
domain=args.domain,
@@ -431,4 +700,6 @@ if __name__ == "__main__":
lazy_load_passages=args.lazy_load_passages,
model_name=args.model_name,
passages_file=args.passages_file,
)
embedding_mode=embedding_mode,
enable_warmup=not args.disable_warmup,
)

View File

@@ -8,9 +8,12 @@ version = "0.1.0"
dependencies = ["leann-core==0.1.0", "numpy"]
[tool.scikit-build]
# 关键:简化的 CMake 路径
# Key: simplified CMake path
cmake.source-dir = "third_party/DiskANN"
# 关键:Python 包在根目录,路径完全匹配
# Key: Python package in root directory, paths match exactly
wheel.packages = ["leann_backend_diskann"]
# 使用默认的 redirect 模式
editable.mode = "redirect"
# Use default redirect mode
editable.mode = "redirect"
cmake.build-type = "Release"
build.verbose = true
build.tool-args = ["-j8"]

View File

@@ -1,7 +1,30 @@
# 最终简化版
cmake_minimum_required(VERSION 3.24)
project(leann_backend_hnsw_wrapper)
set(CMAKE_C_COMPILER_WORKS 1)
set(CMAKE_CXX_COMPILER_WORKS 1)
# Set OpenMP path for macOS
if(APPLE)
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
set(OpenMP_C_LIB_NAMES "omp")
set(OpenMP_CXX_LIB_NAMES "omp")
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
endif()
# Use system ZeroMQ instead of building from source
find_package(PkgConfig REQUIRED)
pkg_check_modules(ZMQ REQUIRED libzmq)
# Add cppzmq headers
include_directories(third_party/cppzmq)
# Configure msgpack-c - disable boost dependency
set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
add_compile_definitions(MSGPACK_NO_BOOST)
include_directories(third_party/msgpack-c/include)
# Faiss configuration - streamlined build
set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE)
@@ -9,4 +32,24 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
# Disable additional SIMD versions to speed up compilation
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
# Additional optimization options from INSTALL.md
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) # Static library is faster to build
# Avoid building demos and benchmarks
set(BUILD_DEMOS OFF CACHE BOOL "" FORCE)
set(BUILD_BENCHS OFF CACHE BOOL "" FORCE)
# NEW: Tell Faiss to only build the generic version
set(FAISS_BUILD_GENERIC ON CACHE BOOL "" FORCE)
set(FAISS_BUILD_AVX2 OFF CACHE BOOL "" FORCE)
set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE)
# IMPORTANT: Disable building AVX versions to speed up compilation
set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)
add_subdirectory(third_party/faiss)

View File

@@ -1,36 +1,32 @@
import numpy as np
import os
import json
import struct
from pathlib import Path
from typing import Dict, Any, List
import contextlib
import threading
import time
import atexit
import socket
import subprocess
import sys
from typing import Dict, Any, List, Literal
import pickle
import shutil
import time
from leann.embedding_server_manager import EmbeddingServerManager
from leann.searcher_base import BaseSearcher
from .convert_to_csr import convert_hnsw_graph_to_csr
from leann.registry import register_backend
from leann.interface import (
LeannBackendFactoryInterface,
LeannBackendBuilderInterface,
LeannBackendSearcherInterface
LeannBackendSearcherInterface,
)
def get_metric_map():
from . import faiss
from . import faiss # type: ignore
return {
"mips": faiss.METRIC_INNER_PRODUCT,
"l2": faiss.METRIC_L2,
"cosine": faiss.METRIC_INNER_PRODUCT,
}
@register_backend("hnsw")
class HNSWBackend(LeannBackendFactoryInterface):
@staticmethod
@@ -39,345 +35,201 @@ class HNSWBackend(LeannBackendFactoryInterface):
@staticmethod
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
path = Path(index_path)
meta_path = path.parent / f"{path.name}.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}.")
with open(meta_path, 'r') as f:
meta = json.load(f)
kwargs['meta'] = meta
return HNSWSearcher(index_path, **kwargs)
class HNSWBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs):
self.build_params = kwargs.copy()
# --- Configuration defaults with standardized names ---
self.is_compact = self.build_params.setdefault("is_compact", True)
self.is_recompute = self.build_params.setdefault("is_recompute", True)
# --- Additional Options ---
self.is_skip_neighbors = self.build_params.setdefault("is_skip_neighbors", False)
self.disk_cache_ratio = self.build_params.setdefault("disk_cache_ratio", 0.0)
self.external_storage_path = self.build_params.get("external_storage_path", None)
# --- Standard HNSW parameters ---
self.M = self.build_params.setdefault("M", 32)
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
self.dimensions = self.build_params.get("dimensions")
if self.is_skip_neighbors and not self.is_compact:
raise ValueError("is_skip_neighbors can only be used with is_compact=True")
if self.is_recompute and not self.is_compact:
raise ValueError("is_recompute requires is_compact=True for efficiency")
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
"""Build HNSW index using FAISS"""
from . import faiss
from . import faiss # type: ignore
path = Path(index_path)
index_dir = path.parent
index_prefix = path.stem
index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32:
data = data.astype(np.float32)
if not data.flags['C_CONTIGUOUS']:
data = np.ascontiguousarray(data)
# Create label map: integer -> string_id
label_map = {i: str_id for i, str_id in enumerate(ids)}
label_map_file = index_dir / "leann.labels.map"
with open(label_map_file, 'wb') as f:
pickle.dump(label_map, f)
metric_str = self.distance_metric.lower()
metric_enum = get_metric_map().get(metric_str)
metric_enum = get_metric_map().get(self.distance_metric.lower())
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
M = self.M
efConstruction = self.efConstruction
dim = self.dimensions
if not dim:
dim = data.shape[1]
dim = self.dimensions or data.shape[1]
index = faiss.IndexHNSWFlat(dim, self.M, metric_enum)
index.hnsw.efConstruction = self.efConstruction
print(f"INFO: Building HNSW index for {data.shape[0]} vectors with metric {metric_enum}...")
try:
index = faiss.IndexHNSWFlat(dim, M, metric_enum)
index.hnsw.efConstruction = efConstruction
if metric_str == "cosine":
faiss.normalize_L2(data)
index.add(data.shape[0], faiss.swig_ptr(data))
index_file = index_dir / f"{index_prefix}.index"
faiss.write_index(index, str(index_file))
print(f"✅ HNSW index built successfully at '{index_file}'")
if self.distance_metric.lower() == "cosine":
faiss.normalize_L2(data)
if self.is_compact:
self._convert_to_csr(index_file)
except Exception as e:
print(f"💥 ERROR: HNSW index build failed. Exception: {e}")
raise
index.add(data.shape[0], faiss.swig_ptr(data))
index_file = index_dir / f"{index_prefix}.index"
faiss.write_index(index, str(index_file))
if self.is_compact:
self._convert_to_csr(index_file)
def _convert_to_csr(self, index_file: Path):
"""Convert built index to CSR format"""
try:
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
print(f"INFO: Converting HNSW index to {mode_str} format...")
csr_temp_file = index_file.with_suffix(".csr.tmp")
success = convert_hnsw_graph_to_csr(
str(index_file),
str(csr_temp_file),
prune_embeddings=self.is_recompute
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
print(f"INFO: Converting HNSW index to {mode_str} format...")
csr_temp_file = index_file.with_suffix(".csr.tmp")
success = convert_hnsw_graph_to_csr(
str(index_file), str(csr_temp_file), prune_embeddings=self.is_recompute
)
if success:
print("✅ CSR conversion successful.")
index_file_old = index_file.with_suffix(".old")
shutil.move(str(index_file), str(index_file_old))
shutil.move(str(csr_temp_file), str(index_file))
print(
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
)
else:
# Clean up and fail fast
if csr_temp_file.exists():
os.remove(csr_temp_file)
raise RuntimeError(
"CSR conversion failed - cannot proceed with compact format"
)
if success:
print("✅ CSR conversion successful.")
import shutil
shutil.move(str(csr_temp_file), str(index_file))
print(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
else:
# Clean up and fail fast
if csr_temp_file.exists():
os.remove(csr_temp_file)
raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
except Exception as e:
print(f"💥 ERROR: CSR conversion failed. Exception: {e}")
raise
class HNSWSearcher(LeannBackendSearcherInterface):
def _get_index_storage_status(self, index_file: Path) -> tuple[bool, bool]:
"""
Robustly determines the index's storage status by parsing the file.
Returns:
A tuple (is_compact, is_pruned).
"""
if not index_file.exists():
return False, False
with open(index_file, 'rb') as f:
try:
def read_struct(fmt):
size = struct.calcsize(fmt)
data = f.read(size)
if len(data) != size:
raise EOFError(f"File ended unexpectedly reading struct fmt '{fmt}'.")
return struct.unpack(fmt, data)[0]
def skip_vector(element_size):
count = read_struct('<Q')
f.seek(count * element_size, 1)
# 1. Read up to the compact flag
read_struct('<I'); read_struct('<i'); read_struct('<q');
read_struct('<q'); read_struct('<q'); read_struct('<?')
metric_type = read_struct('<i')
if metric_type > 1: read_struct('<f')
skip_vector(8); skip_vector(4); skip_vector(4)
# 2. Check if there's a compact flag byte
# Try to read the compact flag, but handle both old and new formats
pos_before_compact = f.tell()
try:
is_compact = read_struct('<?')
print(f"INFO: Detected is_compact flag as: {is_compact}")
except (EOFError, struct.error):
# Old format without compact flag - assume non-compact
f.seek(pos_before_compact)
is_compact = False
print(f"INFO: No compact flag found, assuming is_compact=False")
# 3. Read storage FourCC to determine if pruned
is_pruned = False
try:
if is_compact:
# For compact, we need to skip pointers and scalars to get to the storage FourCC
skip_vector(8) # level_ptr
skip_vector(8) # node_offsets
read_struct('<i'); read_struct('<i'); read_struct('<i');
read_struct('<i'); read_struct('<i')
storage_fourcc = read_struct('<I')
else:
# For non-compact, we need to read the flag probe, then skip offsets and neighbors
pos_before_probe = f.tell()
flag_byte = f.read(1)
if not (flag_byte and flag_byte == b'\x00'):
f.seek(pos_before_probe)
skip_vector(8); skip_vector(4) # offsets, neighbors
read_struct('<i'); read_struct('<i'); read_struct('<i');
read_struct('<i'); read_struct('<i')
# Now we are at the storage. The entire rest is storage blob.
storage_fourcc = struct.unpack('<I', f.read(4))[0]
NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little')
if storage_fourcc == NULL_INDEX_FOURCC:
is_pruned = True
except (EOFError, struct.error):
# Cannot determine pruning status, assume not pruned
pass
print(f"INFO: Detected is_pruned as: {is_pruned}")
return is_compact, is_pruned
except (EOFError, struct.error) as e:
print(f"WARNING: Could not parse index file to detect format: {e}. Assuming standard, not pruned.")
return False, False
class HNSWSearcher(BaseSearcher):
def __init__(self, index_path: str, **kwargs):
from . import faiss
self.meta = kwargs.get("meta", {})
if not self.meta:
raise ValueError("HNSWSearcher requires metadata from .meta.json.")
super().__init__(
index_path,
backend_module_name="leann_backend_hnsw.hnsw_embedding_server",
**kwargs,
)
from . import faiss # type: ignore
self.dimensions = self.meta.get("dimensions")
if not self.dimensions:
raise ValueError("Dimensions not found in Leann metadata.")
self.distance_metric = self.meta.get("distance_metric", "mips").lower()
metric_enum = get_metric_map().get(self.distance_metric)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
self.embedding_model = self.meta.get("embedding_model")
if not self.embedding_model:
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
self.is_compact, self.is_pruned = (
self.meta.get("is_compact", True),
self.meta.get("is_pruned", True),
)
path = Path(index_path)
self.index_dir = path.parent
self.index_prefix = path.stem
# Load the label map
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
with open(label_map_file, 'rb') as f:
self.label_map = pickle.load(f)
index_file = self.index_dir / f"{self.index_prefix}.index"
index_file = self.index_dir / f"{self.index_path.stem}.index"
if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found at {index_file}")
self.is_compact, self.is_pruned = self._get_index_storage_status(index_file)
# Validate configuration constraints
if not self.is_compact and kwargs.get("is_skip_neighbors", False):
raise ValueError("is_skip_neighbors can only be used with is_compact=True")
if kwargs.get("is_recompute", False) and kwargs.get("external_storage_path"):
raise ValueError("Cannot use both is_recompute and external_storage_path simultaneously")
hnsw_config = faiss.HNSWIndexConfig()
hnsw_config.is_compact = self.is_compact
# Apply additional configuration options with strict validation
hnsw_config.is_skip_neighbors = kwargs.get("is_skip_neighbors", False)
hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
hnsw_config.disk_cache_ratio = kwargs.get("disk_cache_ratio", 0.0)
hnsw_config.external_storage_path = kwargs.get("external_storage_path")
self.zmq_port = kwargs.get("zmq_port", 5557)
if self.is_pruned and not hnsw_config.is_recompute:
raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
print(f"INFO: Loading index with is_compact={self.is_compact}, is_pruned={self.is_pruned}")
print(f"INFO: Config - skip_neighbors={hnsw_config.is_skip_neighbors}, recompute={hnsw_config.is_recompute}")
raise RuntimeError("Index is pruned but recompute is disabled.")
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
if self.is_compact:
print("✅ Compact CSR format HNSW index loaded successfully.")
else:
print("✅ Standard HNSW index loaded successfully.")
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
def search(
self,
query: np.ndarray,
top_k: int,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
batch_size: int = 0,
**kwargs,
) -> Dict[str, Any]:
"""
Search for nearest neighbors using HNSW index.
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
"""Search using HNSW index with optional recompute functionality"""
from . import faiss
ef = kwargs.get("ef", 200)
if self.is_pruned:
print(f"INFO: Index is pruned - ensuring embedding server is running for recompute.")
if not self.embedding_model:
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.")
Args:
query: Query vectors (B, D) where B is batch size, D is dimension
top_k: Number of nearest neighbors to return
complexity: Search complexity/efSearch, higher = more accurate but slower
beam_width: Number of parallel search paths/beam_size
prune_ratio: Ratio of neighbors to prune via PQ (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server
pruning_strategy: PQ candidate selection strategy:
- "global": Use global PQ queue size for selection (default)
- "local": Local pruning, sort and select best candidates
- "proportional": Base selection on new neighbor count ratio
zmq_port: ZMQ port for embedding server
batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
**kwargs: Additional HNSW-specific parameters (for legacy compatibility)
passages_file = kwargs.get("passages_file")
if not passages_file:
# Get the passages file path from meta.json
if 'passage_sources' in self.meta and self.meta['passage_sources']:
passage_source = self.meta['passage_sources'][0]
passages_file = passage_source['path']
print(f"INFO: Found passages file from metadata: {passages_file}")
else:
raise RuntimeError(f"FATAL: Index is pruned but no passage_sources found in metadata.")
Returns:
Dict with 'labels' (list of lists) and 'distances' (ndarray)
"""
from . import faiss # type: ignore
# Use recompute_embeddings parameter
use_recompute = recompute_embeddings or self.is_pruned
if use_recompute:
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_file_path.exists():
raise RuntimeError(
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
)
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
zmq_port = kwargs.get("zmq_port", 5557)
server_started = self.embedding_server_manager.start_server(
port=zmq_port,
model_name=self.embedding_model,
passages_file=passages_file,
distance_metric=self.distance_metric
)
if not server_started:
raise RuntimeError(f"Failed to start HNSW embedding server on port {zmq_port}")
if query.dtype != np.float32:
query = query.astype(np.float32)
if query.ndim == 1:
query = np.expand_dims(query, axis=0)
if self.distance_metric == "cosine":
faiss.normalize_L2(query)
try:
params = faiss.SearchParametersHNSW()
params.efSearch = ef
params.zmq_port = kwargs.get("zmq_port", self.zmq_port)
batch_size = query.shape[0]
distances = np.empty((batch_size, top_k), dtype=np.float32)
labels = np.empty((batch_size, top_k), dtype=np.int64)
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)
# Convert integer labels to string IDs
string_labels = []
for batch_labels in labels:
batch_string_labels = []
for int_label in batch_labels:
if int_label in self.label_map:
batch_string_labels.append(self.label_map[int_label])
else:
batch_string_labels.append(f"unknown_{int_label}")
string_labels.append(batch_string_labels)
return {"labels": string_labels, "distances": distances}
except Exception as e:
print(f"💥 ERROR: HNSW search failed. Exception: {e}")
raise
def __del__(self):
if hasattr(self, 'embedding_server_manager'):
self.embedding_server_manager.stop_server()
params = faiss.SearchParametersHNSW()
params.zmq_port = zmq_port
params.efSearch = complexity
params.beam_size = beam_width
# PQ pruning: direct mapping to HNSW's pq_pruning_ratio
params.pq_pruning_ratio = prune_ratio
# Map pruning_strategy to HNSW parameters
if pruning_strategy == "local":
params.local_prune = True
params.send_neigh_times_ratio = 0.0
elif pruning_strategy == "proportional":
params.local_prune = False
params.send_neigh_times_ratio = (
1.0 # Any value > 1e-6 triggers proportional mode
)
else: # "global"
params.local_prune = False
params.send_neigh_times_ratio = 0.0
# HNSW-specific batch processing parameter
params.batch_size = batch_size
batch_size_query = query.shape[0]
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
labels = np.empty((batch_size_query, top_k), dtype=np.int64)
self._index.search(
query.shape[0],
faiss.swig_ptr(query),
top_k,
faiss.swig_ptr(distances),
faiss.swig_ptr(labels),
params,
)
string_labels = [
[str(int_label) for int_label in batch_labels]
for batch_labels in labels
]
return {"labels": string_labels, "distances": distances}

View File

@@ -1,369 +1,90 @@
#!/usr/bin/env python3
"""
HNSW-specific embedding server with removed config.py dependencies
Based on DiskANN embedding server architecture
HNSW-specific embedding server
"""
import pickle
import argparse
import threading
import time
from transformers import AutoTokenizer, AutoModel
import os
from contextlib import contextmanager
import zmq
import numpy as np
import msgpack
import json
from pathlib import Path
from typing import Dict, Any, Optional, Union
import sys
import logging
RED = "\033[91m"
RESET = "\033[0m"
def is_similarity_metric():
"""
Check if the metric type is similarity-based (like inner product).
0 = L2 (distance metric), 1 = Inner Product (similarity metric)
"""
return True # 1 is METRIC_INNER_PRODUCT in FAISS
# Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "INFO").upper()
logging.basicConfig(
level=getattr(logging, LOG_LEVEL, logging.INFO),
format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Function for E5-style average pooling
import torch
from torch import Tensor
import torch.nn.functional as F
def e5_average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
class SimplePassageLoader:
"""
Simple passage loader that replaces config.py dependencies
"""
def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
self.passages_data = passages_data or {}
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID"""
str_id = str(passage_id)
if str_id in self.passages_data:
return {"text": self.passages_data[str_id]}
else:
# Return empty text for missing passages
return {"text": ""}
def __len__(self) -> int:
return len(self.passages_data)
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
"""
Load passages from a JSONL file with label map support
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
"""
if not os.path.exists(passages_file):
raise FileNotFoundError(f"Passages file {passages_file} not found.")
if not passages_file.endswith('.jsonl'):
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
# Load label map (int -> string_id)
passages_dir = Path(passages_file).parent
label_map_file = passages_dir / "leann.labels.map"
label_map = {}
if label_map_file.exists():
import pickle
with open(label_map_file, 'rb') as f:
label_map = pickle.load(f)
print(f"Loaded label map with {len(label_map)} entries")
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
# Load passages by string ID
string_id_passages = {}
with open(passages_file, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
passage = json.loads(line)
string_id_passages[passage['id']] = passage['text']
# Create int ID -> text mapping using label map
passages_data = {}
for int_id, string_id in label_map.items():
if string_id in string_id_passages:
passages_data[str(int_id)] = string_id_passages[string_id]
else:
print(f"WARNING: String ID {string_id} from label map not found in passages")
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
return SimplePassageLoader(passages_data)
def create_hnsw_embedding_server(
passages_file: Optional[str] = None,
passages_data: Optional[Dict[str, str]] = None,
embeddings_file: Optional[str] = None,
use_fp16: bool = True,
use_int8: bool = False,
use_cuda_graphs: bool = False,
zmq_port: int = 5555,
max_batch_size: int = 128,
model_name: str = "sentence-transformers/all-mpnet-base-v2",
custom_max_length_param: Optional[int] = None,
distance_metric: str = "mips",
embedding_mode: str = "sentence-transformers",
):
"""
Create and start a ZMQ-based embedding server for HNSW backend.
Args:
passages_file: Path to JSON file containing passage ID -> text mapping
passages_data: Direct passage data dict (alternative to passages_file)
embeddings_file: Path to pre-computed embeddings file (optional)
use_fp16: Whether to use FP16 precision
use_int8: Whether to use INT8 quantization
use_cuda_graphs: Whether to use CUDA graphs
zmq_port: ZMQ port to bind to
max_batch_size: Maximum batch size for processing
model_name: Transformer model name
custom_max_length_param: Custom max sequence length
distance_metric: The distance metric to use
Simplified version using unified embedding computation module.
"""
print(f"Loading tokenizer for {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
print(f"Tokenizer loaded successfully!")
# Device setup
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
cuda_available = torch.cuda.is_available()
print(f"MPS available: {mps_available}")
print(f"CUDA available: {cuda_available}")
if cuda_available:
device = torch.device("cuda")
print("Using CUDA device")
elif mps_available:
device = torch.device("mps")
print("Using MPS device (Apple Silicon)")
else:
device = torch.device("cpu")
print("Using CPU device (no GPU acceleration available)")
# Load model to the appropriate device
# Auto-detect mode based on model name if not explicitly set
if embedding_mode == "sentence-transformers" and model_name.startswith(
"text-embedding-"
):
embedding_mode = "openai"
print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
print(f"Loading model {model_name}... (this may take a while if downloading)")
model = AutoModel.from_pretrained(model_name).to(device).eval()
print(f"Model {model_name} loaded successfully!")
print(f"Using embedding mode: {embedding_mode}")
# Add leann-core to path for unified embedding computation
current_dir = Path(__file__).parent
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
sys.path.insert(0, str(leann_core_path))
try:
from leann.embedding_compute import compute_embeddings
from leann.api import PassageManager
print("Successfully imported unified embedding computation module")
except ImportError as e:
print(f"ERROR: Failed to import embedding computation module: {e}")
return
finally:
sys.path.pop(0)
# Check port availability
import socket
def check_port(port):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
return s.connect_ex(("localhost", port)) == 0
if check_port(zmq_port):
print(f"{RED}Port {zmq_port} is already in use{RESET}")
return
# Apply model optimizations (similar to DiskANN version)
if use_fp16 and (cuda_available or mps_available):
model = model.half()
model = torch.compile(model)
print(f"Using FP16 precision with model: {model_name}")
elif use_int8:
print("- Using TorchAO for Int8 dynamic activation and Int8 weight quantization")
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
quantize_(model, Int8DynamicActivationInt8WeightConfig())
model = torch.compile(model)
model.eval()
print("- Model successfully quantized and compiled")
# Load passages
if passages_data:
passages = SimplePassageLoader(passages_data)
print(f"Using provided passages data: {len(passages)} passages")
elif passages_file:
passages = load_passages_from_file(passages_file)
else:
passages = SimplePassageLoader()
print("No passages provided, using empty loader")
# Load embeddings if provided
_embeddings = None
if embeddings_file and os.path.exists(embeddings_file):
try:
with open(embeddings_file, "rb") as f:
_embeddings = pickle.load(f)
print(f"Loaded embeddings from {embeddings_file}")
except Exception as e:
print(f"Error loading embeddings: {e}")
class DeviceTimer:
"""Device event-based timer for accurate timing."""
def __init__(self, name="", device=device):
self.name = name
self.device = device
self.start_time = 0
self.end_time = 0
if cuda_available:
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
else:
self.start_event = None
self.end_event = None
@contextmanager
def timing(self):
self.start()
yield
self.end()
def start(self):
if cuda_available:
torch.cuda.synchronize()
self.start_event.record()
else:
if self.device.type == "mps":
torch.mps.synchronize()
self.start_time = time.time()
def end(self):
if cuda_available:
self.end_event.record()
torch.cuda.synchronize()
else:
if self.device.type == "mps":
torch.mps.synchronize()
self.end_time = time.time()
def elapsed_time(self):
if cuda_available:
return self.start_event.elapsed_time(self.end_event) / 1000.0
else:
return self.end_time - self.start_time
def print_elapsed(self):
return # Disabled for now
def process_batch(texts_batch, ids_batch, missing_ids):
"""Process a batch of texts and return embeddings"""
_is_e5_model = "e5" in model_name.lower()
_is_bge_model = "bge" in model_name.lower()
batch_size = len(texts_batch)
# E5 model preprocessing
if _is_e5_model:
processed_texts_batch = [f"passage: {text}" for text in texts_batch]
else:
processed_texts_batch = texts_batch
# Set max length
if _is_e5_model:
current_max_length = custom_max_length_param if custom_max_length_param is not None else 512
else:
current_max_length = custom_max_length_param if custom_max_length_param is not None else 256
tokenize_timer = DeviceTimer("tokenization (batch)", device)
to_device_timer = DeviceTimer("transfer to device (batch)", device)
embed_timer = DeviceTimer("embedding (batch)", device)
pool_timer = DeviceTimer("pooling (batch)", device)
norm_timer = DeviceTimer("normalization (batch)", device)
with tokenize_timer.timing():
encoded_batch = tokenizer(
processed_texts_batch,
padding="max_length",
truncation=True,
max_length=current_max_length,
return_tensors="pt",
return_token_type_ids=False,
)
seq_length = encoded_batch["input_ids"].size(1)
with to_device_timer.timing():
enc = {k: v.to(device) for k, v in encoded_batch.items()}
with torch.no_grad():
with embed_timer.timing():
out = model(enc["input_ids"], enc["attention_mask"])
with pool_timer.timing():
if _is_bge_model:
pooled_embeddings = out.last_hidden_state[:, 0]
elif not hasattr(out, 'last_hidden_state'):
if isinstance(out, torch.Tensor) and len(out.shape) == 2:
pooled_embeddings = out
else:
print(f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}")
hidden_dim = getattr(model.config, 'hidden_size', 384 if _is_e5_model else 768)
pooled_embeddings = torch.zeros((batch_size, hidden_dim), device=device, dtype=enc["input_ids"].dtype if hasattr(enc["input_ids"], "dtype") else torch.float32)
elif _is_e5_model:
pooled_embeddings = e5_average_pool(out.last_hidden_state, enc['attention_mask'])
else:
hidden_states = out.last_hidden_state
mask_expanded = enc["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float()
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
pooled_embeddings = sum_embeddings / sum_mask
final_embeddings = pooled_embeddings
if _is_e5_model or _is_bge_model:
with norm_timer.timing():
final_embeddings = F.normalize(pooled_embeddings, p=2, dim=1)
if torch.isnan(final_embeddings).any() or torch.isinf(final_embeddings).any():
print(f"{RED}!!! In process_batch: NaN or Inf detected in final_embeddings! "
f"Model: {model_name}, E5: {_is_e5_model}. IDs (sample): {ids_batch[:5]}...{RESET}")
dim_size = final_embeddings.shape[-1]
error_output = torch.zeros((batch_size, dim_size), device='cpu', dtype=torch.float32).numpy()
print(f"{RED}Returning zero embeddings of shape ({batch_size}, {dim_size}) due to NaN/Inf.{RESET}")
return error_output
return final_embeddings.cpu().numpy()
def client_warmup(zmq_port):
"""Perform client-side warmup"""
time.sleep(2)
print(f"Performing client-side warmup with model {model_name}...")
sample_ids = ["1", "2", "3", "4", "5"]
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(f"tcp://localhost:{zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 30000)
socket.setsockopt(zmq.SNDTIMEO, 30000)
try:
ids_to_send = [int(x) for x in sample_ids]
except ValueError:
ids_to_send = []
if not ids_to_send:
print("Skipping warmup send.")
return
request_payload = [ids_to_send]
request_bytes = msgpack.packb(request_payload)
for i in range(3):
print(f"Sending warmup request {i+1}/3 via ZMQ (MessagePack)...")
socket.send(request_bytes)
response_bytes = socket.recv()
response_payload = msgpack.unpackb(response_bytes)
dimensions = response_payload[0]
embeddings_count = dimensions[0] if dimensions and len(dimensions) > 0 else 0
print(f"Warmup request {i+1}/3 successful, received {embeddings_count} embeddings")
time.sleep(0.1)
print("Client-side MessagePack ZMQ warmup complete")
socket.close()
context.term()
except Exception as e:
print(f"Error during MessagePack ZMQ warmup: {e}")
# Only support metadata file, fail fast for everything else
if not passages_file or not passages_file.endswith(".meta.json"):
raise ValueError("Only metadata files (.meta.json) are supported")
# Load metadata to get passage sources
with open(passages_file, "r") as f:
meta = json.load(f)
passages = PassageManager(meta["passage_sources"])
print(f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata")
def zmq_server_thread():
"""ZMQ server thread"""
@@ -381,206 +102,156 @@ def create_hnsw_embedding_server(
print(f"Received ZMQ request of size {len(message_bytes)} bytes")
e2e_start = time.time()
lookup_timer = DeviceTimer("text lookup", device)
request_payload = msgpack.unpackb(message_bytes)
try:
request_payload = msgpack.unpackb(message_bytes)
# Handle distance calculation requests
if isinstance(request_payload, list) and len(request_payload) == 2 and isinstance(request_payload[0], list) and isinstance(request_payload[1], list):
node_ids = request_payload[0]
query_vector = np.array(request_payload[1], dtype=np.float32)
print(f"Request for distance calculation: {len(node_ids)} nodes, query vector dim: {len(query_vector)}")
# Get embeddings for node IDs
texts = []
missing_ids = []
with lookup_timer.timing():
for nid in node_ids:
try:
txtinfo = passages[nid]
if txtinfo is None or txtinfo["text"] == "":
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
else:
txt = txtinfo["text"]
except (KeyError, IndexError):
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
texts.append(txt)
lookup_timer.print_elapsed()
# Process embeddings in chunks if needed
all_node_embeddings = []
total_size = len(texts)
if total_size > max_batch_size:
for i in range(0, total_size, max_batch_size):
end_idx = min(i + max_batch_size, total_size)
chunk_texts = texts[i:end_idx]
chunk_ids = node_ids[i:end_idx]
embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
all_node_embeddings.append(embeddings_chunk)
if cuda_available:
torch.cuda.empty_cache()
elif device.type == "mps":
torch.mps.empty_cache()
node_embeddings = np.vstack(all_node_embeddings)
else:
node_embeddings = process_batch(texts, node_ids, missing_ids)
# Calculate distances
query_tensor = torch.tensor(query_vector, device=device).float()
node_embeddings_tensor = torch.tensor(node_embeddings, device=device).float()
calc_timer = DeviceTimer("distance calculation", device)
with calc_timer.timing():
with torch.no_grad():
if distance_metric == "l2":
node_embeddings_np = node_embeddings_tensor.cpu().numpy().astype(np.float32)
query_np = query_tensor.cpu().numpy().astype(np.float32)
distances = np.sum(np.square(node_embeddings_np - query_np.reshape(1, -1)), axis=1)
else: # mips or cosine
node_embeddings_np = node_embeddings_tensor.cpu().numpy()
query_np = query_tensor.cpu().numpy()
distances = -np.dot(node_embeddings_np, query_np)
calc_timer.print_elapsed()
try:
response_payload = distances.flatten().tolist()
response_bytes = msgpack.packb([response_payload], use_single_float=True)
print(f"Sending distance response with {len(distances)} distances")
except Exception as pack_error:
print(f"Error packing MessagePack distance response: {pack_error}")
response_bytes = msgpack.packb([[]])
socket.send(response_bytes)
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
# Handle direct text embedding request (for OpenAI and sentence-transformers)
if isinstance(request_payload, list) and len(request_payload) > 0:
# Check if this is a direct text request (list of strings)
if all(isinstance(item, str) for item in request_payload):
logger.info(
f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
)
# Use unified embedding computation
embeddings = compute_embeddings(
request_payload, model_name, mode=embedding_mode
)
response = embeddings.tolist()
socket.send(msgpack.packb(response))
e2e_end = time.time()
print(f"Distance calculation E2E time: {e2e_end - e2e_start:.6f} seconds")
logger.info(
f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s"
)
continue
# Standard embedding request
if not isinstance(request_payload, list) or len(request_payload) != 1 or not isinstance(request_payload[0], list):
print(f"Error: Invalid MessagePack request format. Expected [[ids...]], got: {type(request_payload)}")
socket.send(msgpack.packb([[], []]))
continue
# Handle distance calculation requests
if (
isinstance(request_payload, list)
and len(request_payload) == 2
and isinstance(request_payload[0], list)
and isinstance(request_payload[1], list)
):
node_ids = request_payload[0]
print(f"Request for {len(node_ids)} node embeddings")
except Exception as unpack_error:
print(f"Error unpacking MessagePack request: {unpack_error}")
query_vector = np.array(request_payload[1], dtype=np.float32)
logger.debug("Distance calculation request received")
print(f" Node IDs: {node_ids}")
print(f" Query vector dim: {len(query_vector)}")
# Get embeddings for node IDs
texts = []
for nid in node_ids:
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
texts.append(txt)
except KeyError:
print(f"ERROR: Passage ID {nid} not found")
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
except Exception as e:
print(f"ERROR: Exception looking up passage ID {nid}: {e}")
raise
# Process embeddings
embeddings = compute_embeddings(
texts, model_name, mode=embedding_mode
)
print(
f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
# Calculate distances
if distance_metric == "l2":
distances = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else: # mips or cosine
distances = -np.dot(embeddings, query_vector)
response_payload = distances.flatten().tolist()
response_bytes = msgpack.packb(
[response_payload], use_single_float=True
)
print(f"Sending distance response with {len(distances)} distances")
socket.send(response_bytes)
e2e_end = time.time()
logger.info(
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
)
continue
# Standard embedding request (passage ID lookup)
if (
not isinstance(request_payload, list)
or len(request_payload) != 1
or not isinstance(request_payload[0], list)
):
print(
f"Error: Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
)
socket.send(msgpack.packb([[], []]))
continue
node_ids = request_payload[0]
print(f"Request for {len(node_ids)} node embeddings")
# Look up texts by node IDs
texts = []
missing_ids = []
with lookup_timer.timing():
for nid in node_ids:
try:
txtinfo = passages[nid]
if txtinfo is None or txtinfo["text"] == "":
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
else:
txt = txtinfo["text"]
except (KeyError, IndexError):
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
for nid in node_ids:
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
texts.append(txt)
lookup_timer.print_elapsed()
except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
except Exception as e:
print(f"ERROR: Exception looking up passage ID {nid}: {e}")
raise
if missing_ids:
print(f"Missing passages for IDs: {missing_ids}")
# Process in chunks
total_size = len(texts)
print(f"Total batch size: {total_size}, max_batch_size: {max_batch_size}")
all_embeddings = []
if total_size > max_batch_size:
print(f"Splitting batch of size {total_size} into chunks of {max_batch_size}")
for i in range(0, total_size, max_batch_size):
end_idx = min(i + max_batch_size, total_size)
print(f"Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")
chunk_texts = texts[i:end_idx]
chunk_ids = node_ids[i:end_idx]
embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
all_embeddings.append(embeddings_chunk)
if cuda_available:
torch.cuda.empty_cache()
elif device.type == "mps":
torch.mps.empty_cache()
hidden = np.vstack(all_embeddings)
print(f"Combined embeddings shape: {hidden.shape}")
else:
hidden = process_batch(texts, node_ids, missing_ids)
# Process embeddings
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
print(
f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
# Serialization and response
ser_start = time.time()
print(f"DEBUG zmq_server_thread: Final 'hidden' array | Shape: {hidden.shape} | Dtype: {hidden.dtype} | Has NaN/Inf: {np.isnan(hidden).any() or np.isinf(hidden).any()}")
if np.isnan(hidden).any() or np.isinf(hidden).any():
print(f"{RED}!!! ERROR: NaN or Inf detected in final 'hidden' numpy array BEFORE sending! "
f"Requested IDs (sample): {node_ids[:5]}...{RESET}")
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
print(
f"{RED}!!! ERROR: NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}...{RESET}"
)
assert False
try:
hidden_contiguous_f32 = np.ascontiguousarray(hidden, dtype=np.float32)
response_payload = [
list(hidden_contiguous_f32.shape),
hidden_contiguous_f32.flatten().tolist()
]
response_bytes = msgpack.packb(response_payload, use_single_float=True)
except Exception as pack_error:
print(f"Error packing MessagePack response: {pack_error}")
response_bytes = msgpack.packb([[], []])
hidden_contiguous_f32 = np.ascontiguousarray(
embeddings, dtype=np.float32
)
response_payload = [
list(hidden_contiguous_f32.shape),
hidden_contiguous_f32.flatten().tolist(),
]
response_bytes = msgpack.packb(response_payload, use_single_float=True)
socket.send(response_bytes)
ser_end = time.time()
print(f"Serialize time: {ser_end - ser_start:.6f} seconds")
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
e2e_end = time.time()
print(f"ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
print("ZMQ socket timeout, continuing to listen")
logger.debug("ZMQ socket timeout, continuing to listen")
continue
except Exception as e:
print(f"Error in ZMQ server loop: {e}")
import traceback
traceback.print_exc()
try:
socket.send(msgpack.packb([[], []]))
except:
pass
# Start warmup and server threads
if len(passages) > 0:
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
warmup_thread.daemon = True
warmup_thread.start()
traceback.print_exc()
socket.send(msgpack.packb([[], []]))
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
zmq_thread.start()
print(f"Started HNSW ZMQ server thread on port {zmq_port}")
# Keep the main thread alive
try:
while True:
@@ -593,29 +264,35 @@ def create_hnsw_embedding_server(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="HNSW Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
parser.add_argument("--embeddings-file", type=str, help="Pickle file containing pre-computed embeddings")
parser.add_argument("--use-fp16", action="store_true", default=False)
parser.add_argument("--use-int8", action="store_true", default=False)
parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting")
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model name")
parser.add_argument("--custom-max-length", type=int, default=None, help="Override model's default max sequence length")
parser.add_argument("--distance-metric", type=str, default="mips", help="Distance metric to use")
parser.add_argument(
"--passages-file",
type=str,
help="JSON file containing passage ID to text mapping",
)
parser.add_argument(
"--model-name",
type=str,
default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model name",
)
parser.add_argument(
"--distance-metric", type=str, default="mips", help="Distance metric to use"
)
parser.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai"],
help="Embedding backend mode",
)
args = parser.parse_args()
# Create and start the HNSW embedding server
create_hnsw_embedding_server(
passages_file=args.passages_file,
embeddings_file=args.embeddings_file,
use_fp16=args.use_fp16,
use_int8=args.use_int8,
use_cuda_graphs=args.use_cuda_graphs,
zmq_port=args.zmq_port,
max_batch_size=args.max_batch_size,
model_name=args.model_name,
custom_max_length_param=args.custom_max_length,
distance_metric=args.distance_metric,
)
embedding_mode=args.embedding_mode,
)

View File

@@ -1,4 +1,4 @@
# 文件: packages/leann-backend-hnsw/pyproject.toml
# packages/leann-backend-hnsw/pyproject.toml
[build-system]
requires = ["scikit-build-core>=0.10", "numpy", "swig"]
@@ -10,9 +10,13 @@ version = "0.1.0"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = ["leann-core==0.1.0", "numpy"]
# 回归到最标准的 scikit-build-core 配置
[tool.scikit-build]
wheel.packages = ["leann_backend_hnsw"]
editable.mode = "redirect"
cmake.build-type = "Debug"
build.verbose = true
cmake.build-type = "Release"
build.verbose = true
build.tool-args = ["-j8"]
# CMake definitions to optimize compilation
[tool.scikit-build.cmake.define]
CMAKE_BUILD_PARALLEL_LEVEL = "8"

View File

@@ -11,8 +11,12 @@ requires-python = ">=3.9"
license = { text = "MIT" }
dependencies = [
"numpy>=1.20.0"
"numpy>=1.20.0",
"tqdm>=4.60.0"
]
[project.scripts]
leann = "leann.cli:main"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -1,4 +1,14 @@
# packages/leann-core/src/leann/__init__.py
import os
import platform
# Fix OpenMP threading issues on macOS ARM64
if platform.system() == "Darwin":
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["KMP_BLOCKTIME"] = "0"
from .api import LeannBuilder, LeannChat, LeannSearcher
from .registry import BACKEND_REGISTRY, autodiscover_backends

View File

@@ -1,63 +1,97 @@
"""
This file contains the core API for the LEANN project, now definitively updated
with the correct, original embedding logic from the user's reference code.
"""
import json
import pickle
import numpy as np
from pathlib import Path
from typing import List, Dict, Any, Optional, Literal
from dataclasses import dataclass, field
from .registry import BACKEND_REGISTRY
from .interface import LeannBackendFactoryInterface
from typing import List, Dict, Any, Optional
import numpy as np
import os
import json
from pathlib import Path
import openai
from dataclasses import dataclass, field
import uuid
import pickle
from .chat import get_llm
# --- Helper Functions for Embeddings ---
def _get_openai_client():
"""Initializes and returns an OpenAI client, ensuring the API key is set."""
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("OPENAI_API_KEY environment variable not set, which is required for OpenAI models.")
return openai.OpenAI(api_key=api_key)
def compute_embeddings(
chunks: List[str],
model_name: str,
mode: str = "sentence-transformers",
use_server: bool = True,
port: int = 5557,
) -> np.ndarray:
"""
Computes embeddings using different backends.
def _is_openai_model(model_name: str) -> bool:
"""Checks if the model is likely an OpenAI embedding model."""
# This is a simple check, can be improved with a more robust list.
return "ada" in model_name or "babbage" in model_name or model_name.startswith("text-embedding-")
Args:
chunks: List of text chunks to embed
model_name: Name of the embedding model
mode: Embedding backend mode. Options:
- "sentence-transformers": Use sentence-transformers library (default)
- "mlx": Use MLX backend for Apple Silicon
- "openai": Use OpenAI embedding API
use_server: Whether to use embedding server (True for search, False for build)
def _compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
"""Computes embeddings for a list of text chunks using either SentenceTransformers or OpenAI."""
if _is_openai_model(model_name):
print(f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'...")
client = _get_openai_client()
response = client.embeddings.create(model=model_name, input=chunks)
embeddings = [item.embedding for item in response.data]
Returns:
numpy array of embeddings
"""
if use_server:
# Use embedding server (for search/query)
return compute_embeddings_via_server(chunks, model_name, port=port)
else:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name)
print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'...")
embeddings = model.encode(chunks, show_progress_bar=True)
return np.asarray(embeddings, dtype=np.float32)
# Use direct computation (for build_index)
from .embedding_compute import (
compute_embeddings as compute_embeddings_direct,
)
def _get_embedding_dimensions(model_name: str) -> int:
"""Gets the embedding dimensions for a given model."""
print(f"INFO: Calculating dimensions for model '{model_name}'...")
if _is_openai_model(model_name):
client = _get_openai_client()
response = client.embeddings.create(model=model_name, input=["dummy text"])
return len(response.data[0].embedding)
else:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name)
dimension = model.get_sentence_embedding_dimension()
if dimension is None:
raise ValueError(f"Model '{model_name}' does not have a valid embedding dimension.")
return dimension
return compute_embeddings_direct(
chunks,
model_name,
mode=mode,
)
def compute_embeddings_via_server(
chunks: List[str], model_name: str, port: int
) -> np.ndarray:
"""Computes embeddings using sentence-transformers.
Args:
chunks: List of text chunks to embed
model_name: Name of the sentence transformer model
"""
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
)
import zmq
import msgpack
import numpy as np
# Connect to embedding server
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(f"tcp://localhost:{port}")
# Send chunks to server for embedding computation
request = chunks
socket.send(msgpack.packb(request))
# Receive embeddings from server
response = socket.recv()
embeddings_list = msgpack.unpackb(response)
# Convert back to numpy array
embeddings = np.array(embeddings_list, dtype=np.float32)
socket.close()
context.term()
return embeddings
@dataclass
class SearchResult:
"""Represents a single search result."""
id: str
score: float
text: str
@@ -65,276 +99,429 @@ class SearchResult:
class PassageManager:
"""Manages passage data and lazy loading from JSONL files."""
def __init__(self, passage_sources: List[Dict[str, Any]]):
self.offset_maps = {}
self.passage_files = {}
self.global_offset_map = {} # Combined map for fast lookup
for source in passage_sources:
if source["type"] == "jsonl":
passage_file = source["path"]
index_file = source["index_path"]
if not os.path.exists(index_file):
raise FileNotFoundError(f"Passage index file not found: {index_file}")
with open(index_file, 'rb') as f:
if not Path(index_file).exists():
raise FileNotFoundError(
f"Passage index file not found: {index_file}"
)
with open(index_file, "rb") as f:
offset_map = pickle.load(f)
self.offset_maps[passage_file] = offset_map
self.passage_files[passage_file] = passage_file
self.offset_maps[passage_file] = offset_map
self.passage_files[passage_file] = passage_file
# Build global map for O(1) lookup
for passage_id, offset in offset_map.items():
self.global_offset_map[passage_id] = (passage_file, offset)
def get_passage(self, passage_id: str) -> Dict[str, Any]:
"""Lazy load a passage by ID."""
for passage_file, offset_map in self.offset_maps.items():
if passage_id in offset_map:
offset = offset_map[passage_id]
with open(passage_file, 'r', encoding='utf-8') as f:
f.seek(offset)
line = f.readline()
return json.loads(line)
if passage_id in self.global_offset_map:
passage_file, offset = self.global_offset_map[passage_id]
with open(passage_file, "r", encoding="utf-8") as f:
f.seek(offset)
return json.loads(f.readline())
raise KeyError(f"Passage ID not found: {passage_id}")
# --- Core Classes ---
class LeannBuilder:
"""
The builder is responsible for building the index, it will compute the embeddings and then build the index.
It will also save the metadata of the index.
"""
def __init__(self, backend_name: str, embedding_model: str = "sentence-transformers/all-mpnet-base-v2", dimensions: Optional[int] = None, **backend_kwargs):
def __init__(
self,
backend_name: str,
embedding_model: str = "facebook/contriever-msmarco",
dimensions: Optional[int] = None,
embedding_mode: str = "sentence-transformers",
**backend_kwargs,
):
self.backend_name = backend_name
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(
backend_name
)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
self.backend_factory = backend_factory
self.embedding_model = embedding_model
self.dimensions = dimensions
self.embedding_mode = embedding_mode
self.backend_kwargs = backend_kwargs
self.chunks: List[Dict[str, Any]] = []
print(f"INFO: LeannBuilder initialized with '{backend_name}' backend.")
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
if metadata is None:
metadata = {}
# Check if ID is provided in metadata
passage_id = metadata.get('id')
if passage_id is None:
passage_id = str(uuid.uuid4())
else:
# Validate uniqueness
existing_ids = {chunk['id'] for chunk in self.chunks}
if passage_id in existing_ids:
raise ValueError(f"Duplicate passage ID: {passage_id}")
# Store the definitive ID with the chunk
chunk_data = {
"id": passage_id,
"text": text,
"metadata": metadata
}
passage_id = metadata.get("id", str(len(self.chunks)))
chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
self.chunks.append(chunk_data)
def build_index(self, index_path: str):
if not self.chunks:
raise ValueError("No chunks added. Use add_text() first.")
raise ValueError("No chunks added.")
if self.dimensions is None:
self.dimensions = _get_embedding_dimensions(self.embedding_model)
print(f"INFO: Auto-detected dimensions for '{self.embedding_model}': {self.dimensions}")
self.dimensions = len(
compute_embeddings(
["dummy"],
self.embedding_model,
self.embedding_mode,
use_server=False,
)[0]
)
path = Path(index_path)
index_dir = path.parent
index_name = path.name
# Ensure the directory exists
index_dir.mkdir(parents=True, exist_ok=True)
# Create the passages.jsonl file and offset index
passages_file = index_dir / f"{index_name}.passages.jsonl"
offset_file = index_dir / f"{index_name}.passages.idx"
offset_map = {}
with open(passages_file, 'w', encoding='utf-8') as f:
for chunk in self.chunks:
offset = f.tell()
passage_data = {
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk["metadata"]
}
json.dump(passage_data, f, ensure_ascii=False)
f.write('\n')
offset_map[chunk["id"]] = offset
# Save the offset map
with open(offset_file, 'wb') as f:
pickle.dump(offset_map, f)
# Compute embeddings
texts_to_embed = [c["text"] for c in self.chunks]
embeddings = _compute_embeddings(texts_to_embed, self.embedding_model)
# Extract string IDs for the backend
string_ids = [chunk["id"] for chunk in self.chunks]
# Build the vector index
current_backend_kwargs = self.backend_kwargs.copy()
current_backend_kwargs['dimensions'] = self.dimensions
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
with open(passages_file, "w", encoding="utf-8") as f:
try:
from tqdm import tqdm
# Create the lightweight meta.json file
chunk_iterator = tqdm(
self.chunks, desc="Writing passages", unit="chunk"
)
except ImportError:
chunk_iterator = self.chunks
for chunk in chunk_iterator:
offset = f.tell()
json.dump(
{
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk["metadata"],
},
f,
ensure_ascii=False,
)
f.write("\n")
offset_map[chunk["id"]] = offset
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
texts_to_embed = [c["text"] for c in self.chunks]
embeddings = compute_embeddings(
texts_to_embed,
self.embedding_model,
self.embedding_mode,
use_server=False,
port=5557,
)
string_ids = [chunk["id"] for chunk in self.chunks]
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(
embeddings, string_ids, index_path, **current_backend_kwargs
)
leann_meta_path = index_dir / f"{index_name}.meta.json"
meta_data = {
"version": "1.0",
"backend_name": self.backend_name,
"embedding_model": self.embedding_model,
"dimensions": self.dimensions,
"backend_kwargs": self.backend_kwargs,
"embedding_mode": self.embedding_mode,
"passage_sources": [
{
"type": "jsonl",
"path": str(passages_file),
"index_path": str(offset_file)
"index_path": str(offset_file),
}
]
],
}
with open(leann_meta_path, 'w', encoding='utf-8') as f:
# Add storage status flags for HNSW backend
if self.backend_name == "hnsw":
is_compact = self.backend_kwargs.get("is_compact", True)
is_recompute = self.backend_kwargs.get("is_recompute", True)
meta_data["is_compact"] = is_compact
meta_data["is_pruned"] = (
is_compact and is_recompute
) # Pruned only if compact and recompute
with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
print(f"INFO: Leann metadata saved to {leann_meta_path}")
def build_index_from_embeddings(self, index_path: str, embeddings_file: str):
"""
Build an index from pre-computed embeddings stored in a pickle file.
Args:
index_path: Path where the index will be saved
embeddings_file: Path to pickle file containing (ids, embeddings) tuple
"""
# Load pre-computed embeddings
with open(embeddings_file, "rb") as f:
data = pickle.load(f)
if not isinstance(data, tuple) or len(data) != 2:
raise ValueError(
f"Invalid embeddings file format. Expected tuple with 2 elements, got {type(data)}"
)
ids, embeddings = data
if not isinstance(embeddings, np.ndarray):
raise ValueError(
f"Expected embeddings to be numpy array, got {type(embeddings)}"
)
if len(ids) != embeddings.shape[0]:
raise ValueError(
f"Mismatch between number of IDs ({len(ids)}) and embeddings ({embeddings.shape[0]})"
)
# Validate/set dimensions
embedding_dim = embeddings.shape[1]
if self.dimensions is None:
self.dimensions = embedding_dim
elif self.dimensions != embedding_dim:
raise ValueError(
f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
)
print(
f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
)
# Ensure we have text data for each embedding
if len(self.chunks) != len(ids):
# If no text chunks provided, create placeholder text entries
if not self.chunks:
print("No text chunks provided, creating placeholder entries...")
for id_val in ids:
self.add_text(
f"Document {id_val}",
metadata={"id": str(id_val), "from_embeddings": True},
)
else:
raise ValueError(
f"Number of text chunks ({len(self.chunks)}) doesn't match number of embeddings ({len(ids)})"
)
# Build file structure
path = Path(index_path)
index_dir = path.parent
index_name = path.name
index_dir.mkdir(parents=True, exist_ok=True)
passages_file = index_dir / f"{index_name}.passages.jsonl"
offset_file = index_dir / f"{index_name}.passages.idx"
# Write passages and create offset map
offset_map = {}
with open(passages_file, "w", encoding="utf-8") as f:
for chunk in self.chunks:
offset = f.tell()
json.dump(
{
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk["metadata"],
},
f,
ensure_ascii=False,
)
f.write("\n")
offset_map[chunk["id"]] = offset
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
# Build the vector index using precomputed embeddings
string_ids = [str(id_val) for id_val in ids]
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(embeddings, string_ids, index_path)
# Create metadata file
leann_meta_path = index_dir / f"{index_name}.meta.json"
meta_data = {
"version": "1.0",
"backend_name": self.backend_name,
"embedding_model": self.embedding_model,
"dimensions": self.dimensions,
"backend_kwargs": self.backend_kwargs,
"embedding_mode": self.embedding_mode,
"passage_sources": [
{
"type": "jsonl",
"path": str(passages_file),
"index_path": str(offset_file),
}
],
"built_from_precomputed_embeddings": True,
"embeddings_source": str(embeddings_file),
}
# Add storage status flags for HNSW backend
if self.backend_name == "hnsw":
is_compact = self.backend_kwargs.get("is_compact", True)
is_recompute = self.backend_kwargs.get("is_recompute", True)
meta_data["is_compact"] = is_compact
meta_data["is_pruned"] = is_compact and is_recompute
with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
print(f"Index built successfully from precomputed embeddings: {index_path}")
class LeannSearcher:
"""
The searcher is responsible for loading the index and performing the search.
It will also load the metadata of the index.
"""
def __init__(self, index_path: str, **backend_kwargs):
leann_meta_path = Path(index_path).parent / f"{Path(index_path).name}.meta.json"
if not leann_meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {leann_meta_path}. Was the index built with LeannBuilder?")
with open(leann_meta_path, 'r', encoding='utf-8') as f:
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
meta_path_str = f"{index_path}.meta.json"
if not Path(meta_path_str).exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
with open(meta_path_str, "r", encoding="utf-8") as f:
self.meta_data = json.load(f)
backend_name = self.meta_data['backend_name']
self.embedding_model = self.meta_data['embedding_model']
# Initialize the passage manager
passage_sources = self.meta_data.get('passage_sources', [])
self.passage_manager = PassageManager(passage_sources)
backend_name = self.meta_data["backend_name"]
self.embedding_model = self.meta_data["embedding_model"]
# Support both old and new format
self.embedding_mode = self.meta_data.get(
"embedding_mode", "sentence-transformers"
)
# Backward compatibility with use_mlx
if self.meta_data.get("use_mlx", False):
self.embedding_mode = "mlx"
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' (from index file) not found or not registered.")
final_kwargs = backend_kwargs.copy()
final_kwargs['meta'] = self.meta_data
raise ValueError(f"Backend '{backend_name}' not found.")
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
final_kwargs["enable_warmup"] = enable_warmup
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
print(f"INFO: LeannSearcher initialized with '{backend_name}' backend using index '{index_path}'.")
def search(self, query: str, top_k: int = 5, **search_kwargs):
query_embedding = _compute_embeddings([query], self.embedding_model)
search_kwargs['embedding_model'] = self.embedding_model
results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)
def search(
self,
query: str,
top_k: int = 5,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
**kwargs,
) -> List[SearchResult]:
print("🔍 DEBUG LeannSearcher.search() called:")
print(f" Query: '{query}'")
print(f" Top_k: {top_k}")
print(f" Additional kwargs: {kwargs}")
# Use backend's compute_query_embedding method
# This will automatically use embedding server if available and needed
import time
start_time = time.time()
query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
print(f" Generated embedding shape: {query_embedding.shape}")
embedding_time = time.time() - start_time
print(f" Embedding time: {embedding_time} seconds")
start_time = time.time()
results = self.backend_impl.search(
query_embedding,
top_k,
complexity=complexity,
beam_width=beam_width,
prune_ratio=prune_ratio,
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
zmq_port=zmq_port,
**kwargs,
)
search_time = time.time() - start_time
print(f" Search time: {search_time} seconds")
print(
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
)
enriched_results = []
for string_id, dist in zip(results['labels'][0], results['distances'][0]):
try:
passage_data = self.passage_manager.get_passage(string_id)
enriched_results.append(SearchResult(
id=string_id,
score=dist,
text=passage_data['text'],
metadata=passage_data.get('metadata', {})
))
except KeyError:
print(f"WARNING: Passage ID '{string_id}' not found in passage files")
if "labels" in results and "distances" in results:
print(f" Processing {len(results['labels'][0])} passage IDs:")
for i, (string_id, dist) in enumerate(
zip(results["labels"][0], results["distances"][0])
):
try:
passage_data = self.passage_manager.get_passage(string_id)
enriched_results.append(
SearchResult(
id=string_id,
score=dist,
text=passage_data["text"],
metadata=passage_data.get("metadata", {}),
)
)
print(
f" {i + 1}. passage_id='{string_id}' -> SUCCESS: {passage_data['text']}..."
)
except KeyError:
print(
f" {i + 1}. passage_id='{string_id}' -> ERROR: Passage not found in PassageManager!"
)
print(f" Final enriched results: {len(enriched_results)} passages")
return enriched_results
class LeannChat:
"""
The chat is responsible for the conversation with the LLM.
It will use the searcher to get the results and then use the LLM to generate the response.
"""
def __init__(self, index_path: str, backend_name: Optional[str] = None, llm_model: str = "gpt-4o", **kwargs):
if backend_name is None:
leann_meta_path = Path(index_path).parent / f"{Path(index_path).name}.meta.json"
if not leann_meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {leann_meta_path}.")
with open(leann_meta_path, 'r', encoding='utf-8') as f:
meta_data = json.load(f)
backend_name = meta_data['backend_name']
self.searcher = LeannSearcher(index_path, **kwargs)
self.llm_model = llm_model
def ask(self, question: str, top_k=5, **kwargs):
"""
Additional keyword arguments (kwargs) for advanced search customization. Example usage:
chat.ask(
"What is ANN?",
top_k=10,
complexity=64,
beam_width=8,
USE_DEFERRED_FETCH=True,
skip_search_reorder=True,
recompute_beighbor_embeddings=True,
dedup_node_dis=True,
prune_ratio=0.1,
batch_recompute=True,
global_pruning=True
)
Supported kwargs:
- complexity (int): Search complexity parameter (default: 32)
- beam_width (int): Beam width for search (default: 4)
- USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
- skip_search_reorder (bool): Skip search reorder step (default: False)
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
- prune_ratio (float): Pruning ratio for search (default: 0.0)
- batch_recompute (bool): Enable batch recomputation (default: False)
- global_pruning (bool): Enable global pruning (default: False)
"""
def __init__(
self,
index_path: str,
llm_config: Optional[Dict[str, Any]] = None,
enable_warmup: bool = False,
**kwargs,
):
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
self.llm = get_llm(llm_config)
results = self.searcher.search(question, top_k=top_k, **kwargs)
def ask(
self,
question: str,
top_k: int = 5,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
llm_kwargs: Optional[Dict[str, Any]] = None,
**search_kwargs,
):
if llm_kwargs is None:
llm_kwargs = {}
results = self.searcher.search(
question,
top_k=top_k,
complexity=complexity,
beam_width=beam_width,
prune_ratio=prune_ratio,
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
zmq_port=zmq_port,
**search_kwargs,
)
context = "\n\n".join([r.text for r in results])
prompt = (
"Here is some retrieved context that might help answer your question:\n\n"
f"{context}\n\n"
f"Question: {question}\n\n"
"Please provide the best answer you can based on this context and your knowledge."
)
prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
ans = self.llm.ask(prompt, **llm_kwargs)
return ans
print(f"DEBUG: Calling LLM with prompt: {prompt}...")
try:
client = _get_openai_client()
response = client.chat.completions.create(
model=self.llm_model,
messages=[
{"role": "system", "content": "You are a helpful assistant that answers questions based on the provided context."},
{"role": "user", "content": prompt}
]
)
return response.choices[0].message.content
except Exception as e:
print(f"ERROR: Failed to call OpenAI API: {e}")
return f"Error: Could not get a response from the LLM. {e}"
def start_interactive(self):
print("\nLeann Chat started (type 'quit' to exit)")
while True:
try:
user_input = input("You: ").strip()
if user_input.lower() in ['quit', 'exit']:
if user_input.lower() in ["quit", "exit"]:
break
if not user_input:
continue

View File

@@ -0,0 +1,562 @@
#!/usr/bin/env python3
"""
This file contains the chat generation logic for the LEANN project,
supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
import logging
import os
import difflib
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def check_ollama_models() -> List[str]:
"""Check available Ollama models and return a list"""
try:
import requests
response = requests.get("http://localhost:11434/api/tags", timeout=5)
if response.status_code == 200:
data = response.json()
return [model["name"] for model in data.get("models", [])]
return []
except Exception:
return []
def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
"""Use intelligent fuzzy search for Ollama models"""
if not available_models:
return []
query_lower = query.lower()
suggestions = []
# 1. Exact matches first
exact_matches = [m for m in available_models if query_lower == m.lower()]
suggestions.extend(exact_matches)
# 2. Starts with query
starts_with = [m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions]
suggestions.extend(starts_with)
# 3. Contains query
contains = [m for m in available_models if query_lower in m.lower() and m not in suggestions]
suggestions.extend(contains)
# 4. Base model name matching (remove version numbers)
def get_base_name(model_name: str) -> str:
"""Extract base name without version (e.g., 'llama3:8b' -> 'llama3')"""
return model_name.split(':')[0].split('-')[0]
query_base = get_base_name(query_lower)
base_matches = [
m for m in available_models
if get_base_name(m.lower()) == query_base and m not in suggestions
]
suggestions.extend(base_matches)
# 5. Family/variant matching
model_families = {
'llama': ['llama2', 'llama3', 'alpaca', 'vicuna', 'codellama'],
'qwen': ['qwen', 'qwen2', 'qwen3'],
'gemma': ['gemma', 'gemma2'],
'phi': ['phi', 'phi2', 'phi3'],
'mistral': ['mistral', 'mixtral', 'openhermes'],
'dolphin': ['dolphin', 'openchat'],
'deepseek': ['deepseek', 'deepseek-coder']
}
query_family = None
for family, variants in model_families.items():
if any(variant in query_lower for variant in variants):
query_family = family
break
if query_family:
family_variants = model_families[query_family]
family_matches = [
m for m in available_models
if any(variant in m.lower() for variant in family_variants) and m not in suggestions
]
suggestions.extend(family_matches)
# 6. Use difflib for remaining fuzzy matches
remaining_models = [m for m in available_models if m not in suggestions]
difflib_matches = difflib.get_close_matches(query_lower, remaining_models, n=3, cutoff=0.4)
suggestions.extend(difflib_matches)
return suggestions[:8] # Return top 8 suggestions
# Remove this function entirely - we don't need external API calls for Ollama
# Remove this too - no need for fallback
def suggest_similar_models(invalid_model: str, available_models: List[str]) -> List[str]:
"""Use difflib to find similar model names"""
if not available_models:
return []
# Get close matches using fuzzy matching
suggestions = difflib.get_close_matches(
invalid_model, available_models, n=3, cutoff=0.3
)
return suggestions
def check_hf_model_exists(model_name: str) -> bool:
"""Quick check if HuggingFace model exists without downloading"""
try:
from huggingface_hub import model_info
model_info(model_name)
return True
except Exception:
return False
def get_popular_hf_models() -> List[str]:
"""Return a list of popular HuggingFace models for suggestions"""
try:
from huggingface_hub import list_models
# Get popular text-generation models, sorted by downloads
models = list_models(
filter="text-generation",
sort="downloads",
direction=-1,
limit=20 # Get top 20 most downloaded
)
# Extract model names and filter for chat/conversation models
model_names = []
chat_keywords = ['chat', 'instruct', 'dialog', 'conversation', 'assistant']
for model in models:
model_name = model.id if hasattr(model, 'id') else str(model)
# Prioritize models with chat-related keywords
if any(keyword in model_name.lower() for keyword in chat_keywords):
model_names.append(model_name)
elif len(model_names) < 10: # Fill up with other popular models
model_names.append(model_name)
return model_names[:10] if model_names else _get_fallback_hf_models()
except Exception:
# Fallback to static list if API call fails
return _get_fallback_hf_models()
def _get_fallback_hf_models() -> List[str]:
"""Fallback list of popular HuggingFace models"""
return [
"microsoft/DialoGPT-medium",
"microsoft/DialoGPT-large",
"facebook/blenderbot-400M-distill",
"microsoft/phi-2",
"deepseek-ai/deepseek-llm-7b-chat",
"microsoft/DialoGPT-small",
"facebook/blenderbot_small-90M",
"microsoft/phi-1_5",
"facebook/opt-350m",
"EleutherAI/gpt-neo-1.3B"
]
def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
"""Use HuggingFace Hub's native fuzzy search for model suggestions"""
try:
from huggingface_hub import list_models
# HF Hub's search is already fuzzy! It handles typos and partial matches
models = list_models(
search=query,
filter="text-generation",
sort="downloads",
direction=-1,
limit=limit
)
model_names = [model.id if hasattr(model, 'id') else str(model) for model in models]
# If direct search doesn't return enough results, try some variations
if len(model_names) < 3:
# Try searching for partial matches or common variations
variations = []
# Extract base name (e.g., "gpt3" from "gpt-3.5")
base_query = query.lower().replace('-', '').replace('.', '').replace('_', '')
if base_query != query.lower():
variations.append(base_query)
# Try common model name patterns
if 'gpt' in query.lower():
variations.extend(['gpt2', 'gpt-neo', 'gpt-j', 'dialoGPT'])
elif 'llama' in query.lower():
variations.extend(['llama2', 'alpaca', 'vicuna'])
elif 'bert' in query.lower():
variations.extend(['roberta', 'distilbert', 'albert'])
# Search with variations
for var in variations[:2]: # Limit to 2 variations to avoid too many API calls
try:
var_models = list_models(
search=var,
filter="text-generation",
sort="downloads",
direction=-1,
limit=3
)
var_names = [model.id if hasattr(model, 'id') else str(model) for model in var_models]
model_names.extend(var_names)
except:
continue
# Remove duplicates while preserving order
seen = set()
unique_models = []
for model in model_names:
if model not in seen:
seen.add(model)
unique_models.append(model)
return unique_models[:limit]
except Exception:
# If search fails, return empty list
return []
def search_hf_models(query: str, limit: int = 10) -> List[str]:
"""Simple search for HuggingFace models based on query (kept for backward compatibility)"""
return search_hf_models_fuzzy(query, limit)
def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
"""Validate model name and provide suggestions if invalid"""
if llm_type == "ollama":
available_models = check_ollama_models()
if available_models and model_name not in available_models:
# Use intelligent fuzzy search based on locally installed models
suggestions = search_ollama_models_fuzzy(model_name, available_models)
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
if suggestions:
error_msg += "\n\nDid you mean one of these installed models?\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
else:
error_msg += "\n\nYour installed models:\n"
for i, model in enumerate(available_models[:8], 1):
error_msg += f" {i}. {model}\n"
if len(available_models) > 8:
error_msg += f" ... and {len(available_models) - 8} more\n"
error_msg += "\nTo list all models: ollama list"
error_msg += "\nTo download a new model: ollama pull <model_name>"
error_msg += "\nBrowse models: https://ollama.com/library"
return error_msg
elif llm_type == "hf":
# For HF models, we can do a quick existence check
if not check_hf_model_exists(model_name):
# Use HF Hub's native fuzzy search directly
search_suggestions = search_hf_models_fuzzy(model_name, limit=8)
error_msg = f"Model '{model_name}' not found on HuggingFace Hub."
if search_suggestions:
error_msg += "\n\nDid you mean one of these?\n"
for i, suggestion in enumerate(search_suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
else:
# Fallback to popular models if search returns nothing
popular_models = get_popular_hf_models()
error_msg += "\n\nPopular chat models:\n"
for i, model in enumerate(popular_models[:5], 1):
error_msg += f" {i}. {model}\n"
error_msg += f"\nSearch more: https://huggingface.co/models?search={model_name}&pipeline_tag=text-generation"
return error_msg
return None # Model is valid or we can't check
class LLMInterface(ABC):
"""Abstract base class for a generic Language Model (LLM) interface."""
@abstractmethod
def ask(self, prompt: str, **kwargs) -> str:
"""
Additional keyword arguments (kwargs) for advanced search customization. Example usage:
chat.ask(
"What is ANN?",
top_k=10,
complexity=64,
beam_width=8,
USE_DEFERRED_FETCH=True,
skip_search_reorder=True,
recompute_beighbor_embeddings=True,
dedup_node_dis=True,
prune_ratio=0.1,
batch_recompute=True,
global_pruning=True
)
Supported kwargs:
- complexity (int): Search complexity parameter (default: 32)
- beam_width (int): Beam width for search (default: 4)
- USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
- skip_search_reorder (bool): Skip search reorder step (default: False)
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
- prune_ratio (float): Pruning ratio for search (default: 0.0)
- batch_recompute (bool): Enable batch recomputation (default: False)
- global_pruning (bool): Enable global pruning (default: False)
"""
# """
# Sends a prompt to the LLM and returns the generated text.
# Args:
# prompt: The input prompt for the LLM.
# **kwargs: Additional keyword arguments for the LLM backend.
# Returns:
# The response string from the LLM.
# """
pass
class OllamaChat(LLMInterface):
"""LLM interface for Ollama models."""
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
self.model = model
self.host = host
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
try:
import requests
# Check if the Ollama server is responsive
if host:
requests.get(host)
# Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model, "ollama")
if model_error:
raise ValueError(model_error)
except ImportError:
raise ImportError(
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
)
except requests.exceptions.ConnectionError:
logger.error(
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
)
raise ConnectionError(
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
)
def ask(self, prompt: str, **kwargs) -> str:
import requests
import json
full_url = f"{self.host}/api/generate"
payload = {
"model": self.model,
"prompt": prompt,
"stream": False, # Keep it simple for now
"options": kwargs,
}
logger.info(f"Sending request to Ollama: {payload}")
try:
response = requests.post(full_url, data=json.dumps(payload))
response.raise_for_status()
# The response from Ollama can be a stream of JSON objects, handle this
response_parts = response.text.strip().split("\n")
full_response = ""
for part in response_parts:
if part:
json_part = json.loads(part)
full_response += json_part.get("response", "")
if json_part.get("done"):
break
return full_response
except requests.exceptions.RequestException as e:
logger.error(f"Error communicating with Ollama: {e}")
return f"Error: Could not get a response from Ollama. Details: {e}"
class HFChat(LLMInterface):
"""LLM interface for local Hugging Face Transformers models."""
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
logger.info(f"Initializing HFChat with model='{model_name}'")
# Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model_name, "hf")
if model_error:
raise ValueError(model_error)
try:
from transformers.pipelines import pipeline
import torch
except ImportError:
raise ImportError(
"The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'."
)
# Auto-detect device
if torch.cuda.is_available():
device = "cuda"
logger.info("CUDA is available. Using GPU.")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
logger.info("MPS is available. Using Apple Silicon GPU.")
else:
device = "cpu"
logger.info("No GPU detected. Using CPU.")
self.pipeline = pipeline("text-generation", model=model_name, device=device)
def ask(self, prompt: str, **kwargs) -> str:
# Map OpenAI-style arguments to Hugging Face equivalents
if "max_tokens" in kwargs:
# Prefer user-provided max_new_tokens if both are present
kwargs.setdefault("max_new_tokens", kwargs["max_tokens"])
# Remove the unsupported key to avoid errors in Transformers
kwargs.pop("max_tokens")
# Handle temperature=0 edge-case for greedy decoding
if "temperature" in kwargs and kwargs["temperature"] == 0.0:
# Remove unsupported zero temperature and use deterministic generation
kwargs.pop("temperature")
kwargs.setdefault("do_sample", False)
# Sensible defaults for text generation
params = {"max_length": 500, "num_return_sequences": 1, **kwargs}
logger.info(f"Generating text with Hugging Face model with params: {params}")
results = self.pipeline(prompt, **params)
# Handle different response formats from transformers
if isinstance(results, list) and len(results) > 0:
generated_text = (
results[0].get("generated_text", "")
if isinstance(results[0], dict)
else str(results[0])
)
else:
generated_text = str(results)
# Extract only the newly generated portion by removing the original prompt
if isinstance(generated_text, str) and generated_text.startswith(prompt):
response = generated_text[len(prompt) :].strip()
else:
# Fallback: return the full response if prompt removal fails
response = str(generated_text)
return response
class OpenAIChat(LLMInterface):
"""LLM interface for OpenAI models."""
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
self.model = model
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
if not self.api_key:
raise ValueError(
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
)
logger.info(f"Initializing OpenAI Chat with model='{model}'")
try:
import openai
self.client = openai.OpenAI(api_key=self.api_key)
except ImportError:
raise ImportError(
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
)
def ask(self, prompt: str, **kwargs) -> str:
# Default parameters for OpenAI
params = {
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": kwargs.get("max_tokens", 1000),
"temperature": kwargs.get("temperature", 0.7),
**{
k: v
for k, v in kwargs.items()
if k not in ["max_tokens", "temperature"]
},
}
logger.info(f"Sending request to OpenAI with model {self.model}")
try:
response = self.client.chat.completions.create(**params)
return response.choices[0].message.content.strip()
except Exception as e:
logger.error(f"Error communicating with OpenAI: {e}")
return f"Error: Could not get a response from OpenAI. Details: {e}"
class SimulatedChat(LLMInterface):
"""A simple simulated chat for testing and development."""
def ask(self, prompt: str, **kwargs) -> str:
logger.info("Simulating LLM call...")
print("Prompt sent to LLM (simulation):", prompt[:500] + "...")
return "This is a simulated answer from the LLM based on the retrieved context."
def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
"""
Factory function to get an LLM interface based on configuration.
Args:
llm_config: A dictionary specifying the LLM type and its parameters.
Example: {"type": "ollama", "model": "llama3"}
{"type": "hf", "model": "distilgpt2"}
None (for simulation mode)
Returns:
An instance of an LLMInterface subclass.
"""
if llm_config is None:
llm_config = {
"type": "openai",
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
}
llm_type = llm_config.get("type", "openai")
model = llm_config.get("model")
logger.info(f"Attempting to create LLM of type='{llm_type}' with model='{model}'")
if llm_type == "ollama":
return OllamaChat(
model=model or "llama3:8b",
host=llm_config.get("host", "http://localhost:11434"),
)
elif llm_type == "hf":
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
elif llm_type == "openai":
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
elif llm_type == "simulated":
return SimulatedChat()
else:
raise ValueError(f"Unknown LLM type: '{llm_type}'")

View File

@@ -0,0 +1,287 @@
#!/usr/bin/env python3
import argparse
import asyncio
import sys
from pathlib import Path
from typing import Optional
import os
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from .api import LeannBuilder, LeannSearcher, LeannChat
class LeannCLI:
def __init__(self):
self.indexes_dir = Path.home() / ".leann" / "indexes"
self.indexes_dir.mkdir(parents=True, exist_ok=True)
self.node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
)
def get_index_path(self, index_name: str) -> str:
index_dir = self.indexes_dir / index_name
return str(index_dir / "documents.leann")
def index_exists(self, index_name: str) -> bool:
index_dir = self.indexes_dir / index_name
meta_file = index_dir / "documents.leann.meta.json"
return meta_file.exists()
def create_parser(self) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
prog="leann",
description="LEANN - Local Enhanced AI Navigation",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
leann build my-docs --docs ./documents # Build index named my-docs
leann search my-docs "query" # Search in my-docs index
leann ask my-docs "question" # Ask my-docs index
leann list # List all stored indexes
"""
)
subparsers = parser.add_subparsers(dest="command", help="Available commands")
# Build command
build_parser = subparsers.add_parser("build", help="Build document index")
build_parser.add_argument("index_name", help="Index name")
build_parser.add_argument("--docs", type=str, required=True, help="Documents directory")
build_parser.add_argument("--backend", type=str, default="hnsw", choices=["hnsw", "diskann"])
build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever")
build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild")
build_parser.add_argument("--graph-degree", type=int, default=32)
build_parser.add_argument("--complexity", type=int, default=64)
build_parser.add_argument("--num-threads", type=int, default=1)
build_parser.add_argument("--compact", action="store_true", default=True)
build_parser.add_argument("--recompute", action="store_true", default=True)
# Search command
search_parser = subparsers.add_parser("search", help="Search documents")
search_parser.add_argument("index_name", help="Index name")
search_parser.add_argument("query", help="Search query")
search_parser.add_argument("--top-k", type=int, default=5)
search_parser.add_argument("--complexity", type=int, default=64)
search_parser.add_argument("--beam-width", type=int, default=1)
search_parser.add_argument("--prune-ratio", type=float, default=0.0)
search_parser.add_argument("--recompute-embeddings", action="store_true")
search_parser.add_argument("--pruning-strategy", choices=["global", "local", "proportional"], default="global")
# Ask command
ask_parser = subparsers.add_parser("ask", help="Ask questions")
ask_parser.add_argument("index_name", help="Index name")
ask_parser.add_argument("--llm", type=str, default="ollama", choices=["simulated", "ollama", "hf", "openai"])
ask_parser.add_argument("--model", type=str, default="qwen3:8b")
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
ask_parser.add_argument("--interactive", "-i", action="store_true")
ask_parser.add_argument("--top-k", type=int, default=20)
ask_parser.add_argument("--complexity", type=int, default=32)
ask_parser.add_argument("--beam-width", type=int, default=1)
ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
ask_parser.add_argument("--recompute-embeddings", action="store_true")
ask_parser.add_argument("--pruning-strategy", choices=["global", "local", "proportional"], default="global")
# List command
list_parser = subparsers.add_parser("list", help="List all indexes")
return parser
def list_indexes(self):
print("Stored LEANN indexes:")
if not self.indexes_dir.exists():
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
return
index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
if not index_dirs:
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
return
print(f"Found {len(index_dirs)} indexes:")
for i, index_dir in enumerate(index_dirs, 1):
index_name = index_dir.name
status = "" if self.index_exists(index_name) else ""
print(f" {i}. {index_name} [{status}]")
if self.index_exists(index_name):
meta_file = index_dir / "documents.leann.meta.json"
size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (1024 * 1024)
print(f" Size: {size_mb:.1f} MB")
if index_dirs:
example_name = index_dirs[0].name
print(f"\nUsage:")
print(f" leann search {example_name} \"your query\"")
print(f" leann ask {example_name} --interactive")
def load_documents(self, docs_dir: str):
print(f"Loading documents from {docs_dir}...")
documents = SimpleDirectoryReader(
docs_dir,
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md", ".docx"],
).load_data(show_progress=True)
all_texts = []
for doc in documents:
nodes = self.node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
return all_texts
async def build_index(self, args):
docs_dir = args.docs
index_name = args.index_name
index_dir = self.indexes_dir / index_name
index_path = self.get_index_path(index_name)
if index_dir.exists() and not args.force:
print(f"Index '{index_name}' already exists. Use --force to rebuild.")
return
all_texts = self.load_documents(docs_dir)
if not all_texts:
print("No documents found")
return
index_dir.mkdir(parents=True, exist_ok=True)
print(f"Building index '{index_name}' with {args.backend} backend...")
builder = LeannBuilder(
backend_name=args.backend,
embedding_model=args.embedding_model,
graph_degree=args.graph_degree,
complexity=args.complexity,
is_compact=args.compact,
is_recompute=args.recompute,
num_threads=args.num_threads,
)
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"Index built at {index_path}")
async def search_documents(self, args):
index_name = args.index_name
query = args.query
index_path = self.get_index_path(index_name)
if not self.index_exists(index_name):
print(f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it.")
return
searcher = LeannSearcher(index_path=index_path)
results = searcher.search(
query,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy
)
print(f"Search results for '{query}' (top {len(results)}):")
for i, result in enumerate(results, 1):
print(f"{i}. Score: {result.score:.3f}")
print(f" {result.text[:200]}...")
print()
async def ask_questions(self, args):
index_name = args.index_name
index_path = self.get_index_path(index_name)
if not self.index_exists(index_name):
print(f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it.")
return
print(f"Starting chat with index '{index_name}'...")
print(f"Using {args.model} ({args.llm})")
llm_config = {"type": args.llm, "model": args.model}
if args.llm == "ollama":
llm_config["host"] = args.host
chat = LeannChat(index_path=index_path, llm_config=llm_config)
if args.interactive:
print("LEANN Assistant ready! Type 'quit' to exit")
print("=" * 40)
while True:
user_input = input("\nYou: ").strip()
if user_input.lower() in ['quit', 'exit', 'q']:
print("Goodbye!")
break
if not user_input:
continue
response = chat.ask(
user_input,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy
)
print(f"LEANN: {response}")
else:
query = input("Enter your question: ").strip()
if query:
response = chat.ask(
query,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy
)
print(f"LEANN: {response}")
async def run(self, args=None):
parser = self.create_parser()
if args is None:
args = parser.parse_args()
if not args.command:
parser.print_help()
return
if args.command == "list":
self.list_indexes()
elif args.command == "build":
await self.build_index(args)
elif args.command == "search":
await self.search_documents(args)
elif args.command == "ask":
await self.ask_questions(args)
else:
parser.print_help()
def main():
import dotenv
dotenv.load_dotenv()
cli = LeannCLI()
asyncio.run(cli.run())
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,272 @@
"""
Unified embedding computation module
Consolidates all embedding computation logic using SentenceTransformer
Preserves all optimization parameters to ensure performance
"""
import numpy as np
import torch
from typing import List
import logging
logger = logging.getLogger(__name__)
def compute_embeddings(
texts: List[str], model_name: str, mode: str = "sentence-transformers"
) -> np.ndarray:
"""
Unified embedding computation entry point
Args:
texts: List of texts to compute embeddings for
model_name: Model name
mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
"""
if mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(texts, model_name)
elif mode == "openai":
return compute_embeddings_openai(texts, model_name)
elif mode == "mlx":
return compute_embeddings_mlx(texts, model_name)
else:
raise ValueError(f"Unsupported embedding mode: {mode}")
def compute_embeddings_sentence_transformers(
texts: List[str],
model_name: str,
use_fp16: bool = True,
device: str = "auto",
batch_size: int = 32,
) -> np.ndarray:
"""
Compute embeddings using SentenceTransformer
Preserves all optimization parameters to ensure consistency with original embedding_server
Args:
texts: List of texts to compute embeddings for
model_name: SentenceTransformer model name
use_fp16: Whether to use FP16 precision
device: Device selection ('auto', 'cuda', 'mps', 'cpu')
batch_size: Batch size for processing
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
"""
print(
f"INFO: Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
)
from sentence_transformers import SentenceTransformer
# Auto-detect device
if device == "auto":
if torch.cuda.is_available():
device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
print(f"INFO: Using device: {device}")
# Prepare model and tokenizer optimization parameters (consistent with original embedding_server)
model_kwargs = {
"torch_dtype": torch.float16 if use_fp16 else torch.float32,
"low_cpu_mem_usage": True,
"_fast_init": True, # Skip weight initialization checks for faster loading
}
tokenizer_kwargs = {
"use_fast": True, # Use fast tokenizer for better runtime performance
}
# Load SentenceTransformer (try local first, then network)
print(f"INFO: Loading SentenceTransformer model: {model_name}")
try:
# Try local loading (avoid network delays)
model_kwargs["local_files_only"] = True
tokenizer_kwargs["local_files_only"] = True
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=True,
)
print("✅ Model loaded successfully! (local + optimized)")
except Exception as e:
print(f"Local loading failed ({e}), trying network download...")
# Fallback to network loading
model_kwargs["local_files_only"] = False
tokenizer_kwargs["local_files_only"] = False
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=False,
)
print("✅ Model loaded successfully! (network + optimized)")
# Apply additional optimizations (if supported)
if use_fp16 and device in ["cuda", "mps"]:
try:
model = model.half()
model = torch.compile(model)
print(f"✅ Using FP16 precision and compile optimization: {model_name}")
except Exception as e:
print(
f"FP16 or compile optimization failed, continuing with default settings: {e}"
)
# Compute embeddings (using SentenceTransformer's optimized implementation)
print("INFO: Starting embedding computation...")
embeddings = model.encode(
texts,
batch_size=batch_size,
show_progress_bar=False, # Don't show progress bar in server environment
convert_to_numpy=True,
normalize_embeddings=False, # Keep consistent with original API behavior
device=device,
)
print(
f"INFO: Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
)
# Validate results
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
raise RuntimeError(
f"Detected NaN or Inf values in embeddings, model: {model_name}"
)
return embeddings
def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
"""Compute embeddings using OpenAI API"""
try:
import openai
import os
except ImportError as e:
raise ImportError(f"OpenAI package not installed: {e}")
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
client = openai.OpenAI(api_key=api_key)
print(
f"INFO: Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
)
# OpenAI has limits on batch size and input length
max_batch_size = 100 # Conservative batch size
all_embeddings = []
try:
from tqdm import tqdm
total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
batch_range = range(0, len(texts), max_batch_size)
batch_iterator = tqdm(
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
)
except ImportError:
# Fallback when tqdm is not available
batch_iterator = range(0, len(texts), max_batch_size)
for i in batch_iterator:
batch_texts = texts[i : i + max_batch_size]
try:
response = client.embeddings.create(model=model_name, input=batch_texts)
batch_embeddings = [embedding.embedding for embedding in response.data]
all_embeddings.extend(batch_embeddings)
except Exception as e:
print(f"ERROR: Batch {i} failed: {e}")
raise
embeddings = np.array(all_embeddings, dtype=np.float32)
print(
f"INFO: Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
)
return embeddings
def compute_embeddings_mlx(
chunks: List[str], model_name: str, batch_size: int = 16
) -> np.ndarray:
"""Computes embeddings using an MLX model."""
try:
import mlx.core as mx
from mlx_lm.utils import load
from tqdm import tqdm
except ImportError as e:
raise RuntimeError(
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
) from e
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
)
# Load model and tokenizer
model, tokenizer = load(model_name)
# Process chunks in batches with progress bar
all_embeddings = []
try:
from tqdm import tqdm
batch_iterator = tqdm(
range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch"
)
except ImportError:
batch_iterator = range(0, len(chunks), batch_size)
for i in batch_iterator:
batch_chunks = chunks[i : i + batch_size]
# Tokenize all chunks in the batch
batch_token_ids = []
for chunk in batch_chunks:
token_ids = tokenizer.encode(chunk) # type: ignore
batch_token_ids.append(token_ids)
# Pad sequences to the same length for batch processing
max_length = max(len(ids) for ids in batch_token_ids)
padded_token_ids = []
for token_ids in batch_token_ids:
# Pad with tokenizer.pad_token_id or 0
padded = token_ids + [0] * (max_length - len(token_ids))
padded_token_ids.append(padded)
# Convert to MLX array with batch dimension
input_ids = mx.array(padded_token_ids)
# Get embeddings for the batch
embeddings = model(input_ids)
# Mean pooling for each sequence in the batch
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
# Convert batch embeddings to numpy
for j in range(len(batch_chunks)):
pooled_list = pooled[j].tolist() # Convert to list
pooled_numpy = np.array(pooled_list, dtype=np.float32)
all_embeddings.append(pooled_numpy)
# Stack numpy arrays
return np.stack(all_embeddings)

View File

@@ -1,5 +1,3 @@
import os
import threading
import time
import atexit
@@ -8,16 +6,147 @@ import subprocess
import sys
from pathlib import Path
from typing import Optional
import select
import psutil
def _check_port(port: int) -> bool:
"""Check if a port is in use"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
return s.connect_ex(("localhost", port)) == 0
def _check_process_matches_config(
port: int, expected_model: str, expected_passages_file: str
) -> bool:
"""
Check if the process using the port matches our expected model and passages file.
Returns True if matches, False otherwise.
"""
try:
for proc in psutil.process_iter(["pid", "cmdline"]):
if not _is_process_listening_on_port(proc, port):
continue
cmdline = proc.info["cmdline"]
if not cmdline:
continue
return _check_cmdline_matches_config(
cmdline, port, expected_model, expected_passages_file
)
print(f"DEBUG: No process found listening on port {port}")
return False
except Exception as e:
print(f"WARNING: Could not check process on port {port}: {e}")
return False
def _is_process_listening_on_port(proc, port: int) -> bool:
"""Check if a process is listening on the given port."""
try:
connections = proc.net_connections()
for conn in connections:
if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
return True
return False
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
return False
def _check_cmdline_matches_config(
cmdline: list, port: int, expected_model: str, expected_passages_file: str
) -> bool:
"""Check if command line matches our expected configuration."""
cmdline_str = " ".join(cmdline)
print(f"DEBUG: Found process on port {port}: {cmdline_str}")
# Check if it's our embedding server
is_embedding_server = any(
server_type in cmdline_str
for server_type in [
"embedding_server",
"leann_backend_diskann.embedding_server",
"leann_backend_hnsw.hnsw_embedding_server",
]
)
if not is_embedding_server:
print(f"DEBUG: Process on port {port} is not our embedding server")
return False
# Check model name
model_matches = _check_model_in_cmdline(cmdline, expected_model)
# Check passages file if provided
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
result = model_matches and passages_matches
print(
f"DEBUG: model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
)
return result
def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
"""Check if the command line contains the expected model."""
if "--model-name" not in cmdline:
return False
model_idx = cmdline.index("--model-name")
if model_idx + 1 >= len(cmdline):
return False
actual_model = cmdline[model_idx + 1]
return actual_model == expected_model
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
"""Check if the command line contains the expected passages file."""
if "--passages-file" not in cmdline:
return False # Expected but not found
passages_idx = cmdline.index("--passages-file")
if passages_idx + 1 >= len(cmdline):
return False
actual_passages = cmdline[passages_idx + 1]
expected_path = Path(expected_passages_file).resolve()
actual_path = Path(actual_passages).resolve()
return actual_path == expected_path
def _find_compatible_port_or_next_available(
start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
) -> tuple[int, bool]:
"""
Find a port that either has a compatible server or is available.
Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
"""
for port in range(start_port, start_port + max_attempts):
if not _check_port(port):
# Port is available
return port, False
# Port is in use, check if it's compatible
if _check_process_matches_config(port, model_name, passages_file):
print(f"✅ Found compatible server on port {port}")
return port, True
else:
print(f"⚠️ Port {port} has incompatible server, trying next port...")
raise RuntimeError(
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
)
class EmbeddingServerManager:
"""
A generic manager for handling the lifecycle of a backend-specific embedding server process.
A simplified manager for embedding server processes that avoids complex update mechanisms.
"""
def __init__(self, backend_module_name: str):
"""
Initializes the manager for a specific backend.
@@ -29,78 +158,168 @@ class EmbeddingServerManager:
self.backend_module_name = backend_module_name
self.server_process: Optional[subprocess.Popen] = None
self.server_port: Optional[int] = None
atexit.register(self.stop_server)
self._atexit_registered = False
def start_server(self, port: int, model_name: str, **kwargs) -> bool:
def start_server(
self,
port: int,
model_name: str,
embedding_mode: str = "sentence-transformers",
**kwargs,
) -> tuple[bool, int]:
"""
Starts the embedding server process.
Args:
port (int): The ZMQ port for the server.
port (int): The preferred ZMQ port for the server.
model_name (str): The name of the embedding model to use.
**kwargs: Additional arguments for the server (e.g., passages_file, distance_metric).
**kwargs: Additional arguments for the server.
Returns:
bool: True if the server is started successfully or already running, False otherwise.
tuple[bool, int]: (success, actual_port_used)
"""
if self.server_process and self.server_process.poll() is None:
print(f"INFO: Reusing existing server process for this session (PID {self.server_process.pid})")
passages_file = kwargs.get("passages_file")
assert isinstance(passages_file, str), "passages_file must be a string"
# Check if we have a compatible running server
if self._has_compatible_running_server(model_name, passages_file):
assert self.server_port is not None, (
"a compatible running server should set server_port"
)
return True, self.server_port
# Find available port (compatible or free)
try:
actual_port, is_compatible = _find_compatible_port_or_next_available(
port, model_name, passages_file
)
except RuntimeError as e:
print(f"{e}")
return False, port
if is_compatible:
print(f"✅ Using existing compatible server on port {actual_port}")
self.server_port = actual_port
self.server_process = None # We don't own this process
return True, actual_port
if actual_port != port:
print(f"⚠️ Using port {actual_port} instead of {port}")
# Start new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
def _has_compatible_running_server(
self, model_name: str, passages_file: str
) -> bool:
"""Check if we have a compatible running server."""
if not (
self.server_process
and self.server_process.poll() is None
and self.server_port
):
return False
if _check_process_matches_config(self.server_port, model_name, passages_file):
print(
f"✅ Existing server process (PID {self.server_process.pid}) is compatible"
)
return True
if _check_port(port):
print(f"WARNING: Port {port} is already in use. Assuming an external server is running.")
return True
print("⚠️ Existing server process is incompatible. Should start a new server.")
return False
print(f"INFO: Starting session-level embedding server for '{self.backend_module_name}'...")
def _start_new_server(
self, port: int, model_name: str, embedding_mode: str, **kwargs
) -> tuple[bool, int]:
"""Start a new embedding server on the given port."""
print(f"INFO: Starting embedding server on port {port}...")
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
try:
command = [
sys.executable,
"-m", self.backend_module_name,
"--zmq-port", str(port),
"--model-name", model_name
]
# Add extra arguments for specific backends
if "passages_file" in kwargs and kwargs["passages_file"]:
command.extend(["--passages-file", str(kwargs["passages_file"])])
# if "distance_metric" in kwargs and kwargs["distance_metric"]:
# command.extend(["--distance-metric", kwargs["distance_metric"]])
project_root = Path(__file__).parent.parent.parent.parent.parent
print(f"INFO: Running command from project root: {project_root}")
self.server_process = subprocess.Popen(
command,
cwd=project_root,
# stdout=subprocess.PIPE,
# stderr=subprocess.PIPE,
text=True,
encoding='utf-8'
)
self.server_port = port
print(f"INFO: Server process started with PID: {self.server_process.pid}")
max_wait, wait_interval = 30, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
print(f"✅ Embedding server is up and ready for this session.")
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
log_thread.start()
return True
if self.server_process.poll() is not None:
print("❌ ERROR: Server process terminated unexpectedly during startup.")
self._log_monitor()
return False
time.sleep(wait_interval)
print(f"❌ ERROR: Server process failed to start listening within {max_wait} seconds.")
self.stop_server()
return False
self._launch_server_process(command, port)
return self._wait_for_server_ready(port)
except Exception as e:
print(f"❌ ERROR: Failed to start embedding server process: {e}")
return False
print(f"❌ ERROR: Failed to start embedding server: {e}")
return False, port
def _build_server_command(
self, port: int, model_name: str, embedding_mode: str, **kwargs
) -> list:
"""Build the command to start the embedding server."""
command = [
sys.executable,
"-m",
self.backend_module_name,
"--zmq-port",
str(port),
"--model-name",
model_name,
]
if kwargs.get("passages_file"):
command.extend(["--passages-file", str(kwargs["passages_file"])])
if embedding_mode != "sentence-transformers":
command.extend(["--embedding-mode", embedding_mode])
return command
def _launch_server_process(self, command: list, port: int) -> None:
"""Launch the server process."""
project_root = Path(__file__).parent.parent.parent.parent.parent
print(f"INFO: Command: {' '.join(command)}")
self.server_process = subprocess.Popen(
command,
cwd=project_root,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
encoding="utf-8",
bufsize=1,
universal_newlines=True,
)
self.server_port = port
print(f"INFO: Server process started with PID: {self.server_process.pid}")
# Register atexit callback only when we actually start a process
if not self._atexit_registered:
# Use a lambda to avoid issues with bound methods
atexit.register(lambda: self.stop_server() if self.server_process else None)
self._atexit_registered = True
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready."""
max_wait, wait_interval = 120, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
print("✅ Embedding server is ready!")
threading.Thread(target=self._log_monitor, daemon=True).start()
return True, port
if self.server_process.poll() is not None:
print("❌ ERROR: Server terminated during startup.")
self._print_recent_output()
return False, port
time.sleep(wait_interval)
print(f"❌ ERROR: Server failed to start within {max_wait} seconds.")
self.stop_server()
return False, port
def _print_recent_output(self):
"""Print any recent output from the server process."""
if not self.server_process or not self.server_process.stdout:
return
try:
if select.select([self.server_process.stdout], [], [], 0)[0]:
output = self.server_process.stdout.read()
if output:
print(f"[{self.backend_module_name} OUTPUT]: {output}")
except Exception as e:
print(f"Error reading server output: {e}")
def _log_monitor(self):
"""Monitors and prints the server's stdout and stderr."""
@@ -108,25 +327,38 @@ class EmbeddingServerManager:
return
try:
if self.server_process.stdout:
for line in iter(self.server_process.stdout.readline, ''):
print(f"[{self.backend_module_name} LOG]: {line.strip()}")
self.server_process.stdout.close()
if self.server_process.stderr:
for line in iter(self.server_process.stderr.readline, ''):
print(f"[{self.backend_module_name} ERROR]: {line.strip()}")
self.server_process.stderr.close()
while True:
line = self.server_process.stdout.readline()
if not line:
break
print(
f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True
)
except Exception as e:
print(f"Log monitor error: {e}")
def stop_server(self):
"""Stops the embedding server process if it's running."""
if self.server_process and self.server_process.poll() is None:
print(f"INFO: Terminating session server process (PID: {self.server_process.pid})...")
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
print("INFO: Server process terminated.")
except subprocess.TimeoutExpired:
print("WARNING: Server process did not terminate gracefully, killing it.")
self.server_process.kill()
if not self.server_process:
return
if self.server_process.poll() is not None:
# Process already terminated
self.server_process = None
return
print(
f"INFO: Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
)
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
print(f"INFO: Server process {self.server_process.pid} terminated.")
except subprocess.TimeoutExpired:
print(
f"WARNING: Server process {self.server_process.pid} did not terminate gracefully, killing it."
)
self.server_process.kill()
self.server_process = None

View File

@@ -1,59 +1,98 @@
from abc import ABC, abstractmethod
import numpy as np
from typing import Dict, Any
from typing import Dict, Any, List, Literal
class LeannBackendBuilderInterface(ABC):
"""用于构建索引的后端接口"""
@abstractmethod
def build(self, data: np.ndarray, index_path: str, **kwargs) -> None:
"""构建索引
"""Backend interface for building indexes"""
@abstractmethod
def build(
self, data: np.ndarray, ids: List[str], index_path: str, **kwargs
) -> None:
"""Build index
Args:
data: 向量数据 (N, D)
index_path: 索引保存路径
**kwargs: 后端特定的构建参数
data: Vector data (N, D)
ids: List of string IDs for each vector
index_path: Path to save index
**kwargs: Backend-specific build parameters
"""
pass
class LeannBackendSearcherInterface(ABC):
"""用于搜索的后端接口"""
"""Backend interface for searching"""
@abstractmethod
def __init__(self, index_path: str, **kwargs):
"""初始化搜索器
"""Initialize searcher
Args:
index_path: 索引文件路径
**kwargs: 后端特定的加载参数
index_path: Path to index file
**kwargs: Backend-specific loading parameters
"""
pass
@abstractmethod
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
"""搜索最近邻
def search(
self,
query: np.ndarray,
top_k: int,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
**kwargs,
) -> Dict[str, Any]:
"""Search for nearest neighbors
Args:
query: 查询向量 (1, D) 或 (B, D)
top_k: 返回的最近邻数量
**kwargs: 搜索参数
query: Query vectors (B, D) where B is batch size, D is dimension
top_k: Number of nearest neighbors to return
complexity: Search complexity/candidate list size, higher = more accurate but slower
beam_width: Number of parallel search paths/IO requests per iteration
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
zmq_port: ZMQ port for embedding server communication
**kwargs: Backend-specific parameters
Returns:
{"labels": [...], "distances": [...]}
"""
pass
@abstractmethod
def compute_query_embedding(
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
) -> np.ndarray:
"""Compute embedding for a query string
Args:
query: The query string to embed
zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first
Returns:
Query embedding as numpy array with shape (1, D)
"""
pass
class LeannBackendFactoryInterface(ABC):
"""后端工厂接口"""
"""Backend factory interface"""
@staticmethod
@abstractmethod
def builder(**kwargs) -> LeannBackendBuilderInterface:
"""创建 Builder 实例"""
"""Create Builder instance"""
pass
@staticmethod
@abstractmethod
@abstractmethod
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
"""创建 Searcher 实例"""
pass
"""Create Searcher instance"""
pass

View File

@@ -7,30 +7,37 @@ import importlib.metadata
if TYPE_CHECKING:
from leann.interface import LeannBackendFactoryInterface
BACKEND_REGISTRY: Dict[str, 'LeannBackendFactoryInterface'] = {}
BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {}
def register_backend(name: str):
"""A decorator to register a new backend class."""
def decorator(cls):
print(f"INFO: Registering backend '{name}'")
BACKEND_REGISTRY[name] = cls
return cls
return decorator
def autodiscover_backends():
"""Automatically discovers and imports all 'leann-backend-*' packages."""
print("INFO: Starting backend auto-discovery...")
# print("INFO: Starting backend auto-discovery...")
discovered_backends = []
for dist in importlib.metadata.distributions():
dist_name = dist.metadata['name']
if dist_name.startswith('leann-backend-'):
backend_module_name = dist_name.replace('-', '_')
dist_name = dist.metadata["name"]
if dist_name.startswith("leann-backend-"):
backend_module_name = dist_name.replace("-", "_")
discovered_backends.append(backend_module_name)
for backend_module_name in sorted(discovered_backends): # sort for deterministic loading
for backend_module_name in sorted(
discovered_backends
): # sort for deterministic loading
try:
importlib.import_module(backend_module_name)
# Registration message is printed by the decorator
except ImportError as e:
print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
print("INFO: Backend auto-discovery finished.")
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
pass
# print("INFO: Backend auto-discovery finished.")

View File

@@ -0,0 +1,193 @@
import json
import pickle
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Any, Literal
import numpy as np
from .embedding_server_manager import EmbeddingServerManager
from .interface import LeannBackendSearcherInterface
class BaseSearcher(LeannBackendSearcherInterface, ABC):
"""
Abstract base class for Leann searchers, containing common logic for
loading metadata, managing embedding servers, and handling file paths.
"""
def __init__(self, index_path: str, backend_module_name: str, **kwargs):
"""
Initializes the BaseSearcher.
Args:
index_path: Path to the Leann index file (e.g., '.../my_index.leann').
backend_module_name: The specific embedding server module to use
(e.g., 'leann_backend_hnsw.hnsw_embedding_server').
**kwargs: Additional keyword arguments.
"""
self.index_path = Path(index_path)
self.index_dir = self.index_path.parent
self.meta = kwargs.get("meta", self._load_meta())
if not self.meta:
raise ValueError("Searcher requires metadata from .meta.json.")
self.dimensions = self.meta.get("dimensions")
if not self.dimensions:
raise ValueError("Dimensions not found in Leann metadata.")
self.embedding_model = self.meta.get("embedding_model")
if not self.embedding_model:
print(
"WARNING: embedding_model not found in meta.json. Recompute will fail."
)
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name=backend_module_name
)
def _load_meta(self) -> Dict[str, Any]:
"""Loads the metadata file associated with the index."""
# This is the corrected logic for finding the meta file.
meta_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}")
with open(meta_path, "r", encoding="utf-8") as f:
return json.load(f)
def _ensure_server_running(
self, passages_source_file: str, port: int, **kwargs
) -> int:
"""
Ensures the embedding server is running if recompute is needed.
This is a helper for subclasses.
"""
if not self.embedding_model:
raise ValueError(
"Cannot use recompute mode without 'embedding_model' in meta.json."
)
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
server_started, actual_port = self.embedding_server_manager.start_server(
port=port,
model_name=self.embedding_model,
passages_file=passages_source_file,
distance_metric=kwargs.get("distance_metric"),
embedding_mode=embedding_mode,
enable_warmup=kwargs.get("enable_warmup", False),
)
if not server_started:
raise RuntimeError(
f"Failed to start embedding server on port {actual_port}"
)
return actual_port
def compute_query_embedding(
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
) -> np.ndarray:
"""
Compute embedding for a query string.
Args:
query: The query string to embed
zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first
Returns:
Query embedding as numpy array
"""
# Try to use embedding server if available and requested
if use_server_if_available:
try:
# Ensure we have a server with passages_file for compatibility
passages_source_file = (
self.index_dir / f"{self.index_path.name}.meta.json"
)
zmq_port = self._ensure_server_running(
str(passages_source_file), zmq_port
)
return self._compute_embedding_via_server([query], zmq_port)[
0:1
] # Return (1, D) shape
except Exception as e:
print(f"⚠️ Embedding server failed: {e}")
print("⏭️ Falling back to direct model loading...")
# Fallback to direct computation
from .api import compute_embeddings
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
return compute_embeddings([query], self.embedding_model, embedding_mode)
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
"""Compute embeddings using the ZMQ embedding server."""
import zmq
import msgpack
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout
socket.connect(f"tcp://localhost:{zmq_port}")
# Send embedding request
request = chunks
request_bytes = msgpack.packb(request)
socket.send(request_bytes)
# Wait for response
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Convert response to numpy array
if isinstance(response, list) and len(response) > 0:
return np.array(response, dtype=np.float32)
else:
raise RuntimeError("Invalid response from embedding server")
except Exception as e:
raise RuntimeError(f"Failed to compute embeddings via server: {e}")
@abstractmethod
def search(
self,
query: np.ndarray,
top_k: int,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
**kwargs,
) -> Dict[str, Any]:
"""
Search for the top_k nearest neighbors of the query vector.
Args:
query: Query vectors (B, D) where B is batch size, D is dimension
top_k: Number of nearest neighbors to return
complexity: Search complexity/candidate list size, higher = more accurate but slower
beam_width: Number of parallel search paths/IO requests per iteration
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
zmq_port: ZMQ port for embedding server communication
**kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)
Returns:
Dict with 'labels' (list of lists) and 'distances' (ndarray)
"""
pass
def __del__(self):
"""Ensures the embedding server is stopped when the searcher is destroyed."""
if hasattr(self, "embedding_server_manager"):
self.embedding_server_manager.stop_server()

View File

@@ -0,0 +1,115 @@
import json
import typer
from pathlib import Path
import requests
from tqdm import tqdm
import xml.etree.ElementTree as ET
from typing_extensions import Annotated
import sqlite3
app = typer.Typer()
def get_safe_path(s: str) -> str:
"""
Remove invalid characters to sanitize a path.
:param s: str to sanitize
:returns: sanitized str
"""
ban_chars = "\\ / : * ? \" ' < > | $ \r \n".replace(
' ', '')
for i in ban_chars:
s = s.replace(i, "")
return s
def process_history(history: str):
if history.startswith("<?xml") or history.startswith("<msg>"):
try:
root = ET.fromstring(history)
title = root.find('.//title').text if root.find('.//title') is not None else None
quoted = root.find('.//refermsg/content').text if root.find('.//refermsg/content') is not None else None
if title and quoted:
return {
"title": title,
"quoted": process_history(quoted)
}
if title:
return title
except Exception:
return history
return history
def get_message(history: dict | str):
if isinstance(history, dict):
if 'title' in history:
return history['title']
else:
return history
def export_chathistory(user_id: str):
res = requests.get("http://localhost:48065/wechat/chatlog", params={
"userId": user_id,
"count": 100000
}).json()
for i in range(len(res['chatLogs'])):
res['chatLogs'][i]['content'] = process_history(res['chatLogs'][i]['content'])
res['chatLogs'][i]['message'] = get_message(res['chatLogs'][i]['content'])
return res['chatLogs']
@app.command()
def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to export to.")]):
"""
Export all users' chat history to json files.
"""
if not dest.is_dir():
if not dest.exists():
inp = typer.prompt("Destination path does not exist, create it? (y/n)")
if inp.lower() == 'y':
dest.mkdir(parents=True)
else:
typer.echo("Aborted.", err=True)
return
else:
typer.echo("Destination path is not a directory!", err=True)
return
all_users = requests.get("http://localhost:48065/wechat/allcontacts").json()
exported_count = 0
for user in tqdm(all_users):
try:
usr_chatlog = export_chathistory(user['arg'])
# Only write file if there are messages
if len(usr_chatlog) > 0:
out_path = dest/get_safe_path((user['title'] or "")+"-"+user['arg']+'.json')
with open(out_path, 'w', encoding='utf-8') as f:
json.dump(usr_chatlog, f, ensure_ascii=False, indent=2)
exported_count += 1
except Exception as e:
print(f"Error exporting {user.get('title', 'Unknown')}: {e}")
continue
print(f"Exported {exported_count} users' chat history to {dest} in json.")
@app.command()
def export_sqlite(dest: Annotated[Path, typer.Argument(help="Destination path to export to.")] = Path("chatlog.db")):
"""
Export all users' chat history to a sqlite database.
"""
connection = sqlite3.connect(dest)
cursor = connection.cursor()
cursor.execute("CREATE TABLE IF NOT EXISTS chatlog (id INTEGER PRIMARY KEY AUTOINCREMENT, with_id TEXT, from_user TEXT, to_user TEXT, message TEXT, timest DATETIME, auxiliary TEXT)")
cursor.execute("CREATE INDEX IF NOT EXISTS chatlog_with_id_index ON chatlog (with_id)")
cursor.execute("CREATE TABLE iF NOT EXISTS users (id TEXT PRIMARY KEY, name TEXT)")
all_users = requests.get("http://localhost:48065/wechat/allcontacts").json()
for user in tqdm(all_users):
cursor.execute("INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)", (user['arg'], user['title']))
usr_chatlog = export_chathistory(user['arg'])
for msg in usr_chatlog:
cursor.execute("INSERT INTO chatlog (with_id, from_user, to_user, message, timest, auxiliary) VALUES (?, ?, ?, ?, ?, ?)", (user['arg'], msg['fromUser'], msg['toUser'], msg['message'], msg['createTime'], str(msg['content'])))
connection.commit()
if __name__ == "__main__":
app()

View File

Binary file not shown.

View File

@@ -9,7 +9,6 @@ requires-python = ">=3.10"
dependencies = [
"leann-core",
"leann-backend-diskann",
"leann-backend-hnsw",
"numpy>=1.26.0",
"torch",
@@ -21,7 +20,7 @@ dependencies = [
"colorama",
"boto3",
"protobuf==4.25.3",
"sglang[all]",
"sglang",
"ollama",
"requests>=2.25.0",
"sentence-transformers>=2.2.0",
@@ -32,6 +31,11 @@ dependencies = [
"llama-index-node-parser-docling",
"ipykernel==6.29.5",
"msgpack>=1.1.1",
"llama-index-vector-stores-faiss>=0.4.0",
"llama-index-embeddings-huggingface>=0.5.5",
"mlx>=0.26.3",
"mlx-lm>=0.26.0",
"psutil>=5.8.0",
]
[project.optional-dependencies]
@@ -41,6 +45,11 @@ dev = [
"black>=23.0",
"ruff>=0.1.0",
"matplotlib",
"huggingface-hub>=0.20.0",
]
diskann = [
"leann-backend-diskann",
]
[tool.setuptools]

View File

@@ -23,7 +23,7 @@ g++ ./demo_reader.cpp -o ./demo_reader && ./demo_reader --stats \
f.read(reinterpret_cast<char *>(&val), sizeof(uint32_t))
#define SECTOR_SIZE 4096
// 辅助:获取文件大小
// Helper: Get file size
static size_t get_file_size(const std::string &fname) {
std::ifstream ifs(fname, std::ios::binary | std::ios::ate);
if (ifs.fail() || !ifs.is_open()) {
@@ -32,7 +32,7 @@ static size_t get_file_size(const std::string &fname) {
return static_cast<size_t>(ifs.tellg());
}
// 打印 sector 的前若干 hex用于debug
// Print first few hex of sector for debug
static void print_hex(const char *buf, size_t len, size_t max_len = 64) {
size_t show_len = (len < max_len) ? len : max_len;
for (size_t i = 0; i < show_len; i++) {
@@ -46,19 +46,19 @@ static void print_hex(const char *buf, size_t len, size_t max_len = 64) {
}
/*
修正后的 demo_reader:
1) partition.bin:
Corrected demo_reader:
1) Read from partition.bin:
- C, partition_nums, nd
- graph_partitions[i]: 分区 i 的所有 nodeID
- graph_partitions[i]: all nodeIDs in partition i
- id2partition[nodeID]: nodeID => partition i
2) _disk_graph.index:
a) sector0 里先有 2 int: meta_n, meta_dim
b) 再有 meta_n uint64_t
例如: [0]=nd, [1]=dim, [2]=??, [3]=max_node_len, [4]=C, [5]..??,
[8]=file_size... 具体位置要结合 relayout 的写法 c) graph_node_len =
max_node_len - dim_in_meta*sizeof(float) 3) 用户给定 target_node_id =>
2) Read from _disk_graph.index:
a) sector0 first has 2 ints: meta_n, meta_dim
b) then meta_n uint64_t
e.g.: [0]=nd, [1]=dim, [2]=??, [3]=max_node_len, [4]=C, [5]..??,
[8]=file_size... specific positions need to be combined with relayout writing c) graph_node_len =
max_node_len - dim_in_meta*sizeof(float) 3) User given target_node_id =>
partition_id= id2partition[node_id]
graph_partitions[partition_id] 里找 node 的下标 j
find node index j in graph_partitions[partition_id]
offset = (partition_id+1)*4096 => sector
adjacency_offset= j*graph_node_len => neighbor_count => neighbors
*/
@@ -105,7 +105,7 @@ int main(int argc, char **argv) {
<< "\n";
}
// 1) 读取 partition.bin
// 1) Read partition.bin
std::ifstream pf(partition_bin, std::ios::binary);
if (!pf.is_open()) {
std::cerr << "Cannot open partition.bin: " << partition_bin << std::endl;
@@ -119,8 +119,8 @@ int main(int argc, char **argv) {
<< ", partition_nums=" << partition_nums << ", nd=" << nd
<< std::endl;
// 读取分区节点列表
std::vector<std::vector<uint32_t>> graph_partitions(partition_nums);
// Read partition node lists
std::vector<std::vector<uint32_t> > graph_partitions(partition_nums);
for (uint64_t i = 0; i < partition_nums; i++) {
uint32_t psize;
READ_U32(pf, psize);
@@ -128,7 +128,7 @@ int main(int argc, char **argv) {
pf.read(reinterpret_cast<char *>(graph_partitions[i].data()),
psize * sizeof(uint32_t));
}
// 读取 _id2partition[node], 大小= nd
// Read _id2partition[node], size= nd
std::vector<uint32_t> id2partition(nd);
pf.read(reinterpret_cast<char *>(id2partition.data()), nd * sizeof(uint32_t));
pf.close();
@@ -140,23 +140,23 @@ int main(int argc, char **argv) {
return 1;
}
// 2) 解析 _disk_graph.index
// 2) Parse _disk_graph.index
std::ifstream gf(graph_index, std::ios::binary);
if (!gf.is_open()) {
std::cerr << "Cannot open disk_graph.index: " << graph_index << std::endl;
return 1;
}
// (a) sector0 => 先读 2 int
// (a) sector0 => first read 2 ints
int meta_n, meta_dim;
gf.read((char *)&meta_n, sizeof(int));
gf.read((char *)&meta_dim, sizeof(int));
std::cout << "[debug] meta_n=" << meta_n << ", meta_dim=" << meta_dim << "\n";
// (b) meta_n uint64_t
// (b) Read meta_n uint64_t
std::vector<uint64_t> meta_info(meta_n);
gf.read(reinterpret_cast<char *>(meta_info.data()),
meta_n * sizeof(uint64_t));
// 打印
// Print
for (int i = 0; i < meta_n; i++) {
std::cout << " meta_info[" << i << "]= " << meta_info[i] << "\n";
}
@@ -164,11 +164,11 @@ int main(int argc, char **argv) {
size_t file_size = get_file_size(graph_index);
std::cout << "[disk_graph.index size] " << file_size << " bytes\n";
// **根据 relayout log** 你说: meta_info[0]=nd=60450220, meta_info[1]=dim=769,
// **According to relayout log** you said: meta_info[0]=nd=60450220, meta_info[1]=dim=769,
// meta_info[2]=??(16495248?), meta_info[3]=max_node_len=3320,
// meta_info[4]=16 (C),
// meta_info[8]= 15475261440(文件大小)
// 我们这里先手动解析:
// meta_info[8]= 15475261440(file size)
// We manually parse here first:
uint64_t nd_in_meta = meta_info[0];
uint64_t dim_in_meta = meta_info[1];
uint64_t max_node_len = meta_info[3];
@@ -182,7 +182,7 @@ int main(int argc, char **argv) {
<< ", c_in_meta= " << c_in_meta
<< ", entire_file_size= " << entire_file_sz << "\n";
// 计算 graph_node_len
// Calculate graph_node_len
uint64_t dim_size = dim_in_meta * sizeof(float);
uint64_t graph_node_len = max_node_len - dim_size;
std::cout << " => graph_node_len= " << graph_node_len << "\n\n";
@@ -305,7 +305,7 @@ int main(int argc, char **argv) {
// Error check pf_again if needed
}
// 3) target_node_id => partition_id => subIndex
// 3) Find target_node_id => partition_id => subIndex
uint32_t partition_id = id2partition[target_node_id];
if (partition_id >= partition_nums) {
std::cerr << "Partition ID out-of-range for target node.\n";

44
test/build_mlx_index.py Normal file
View File

@@ -0,0 +1,44 @@
import os
from leann.api import LeannBuilder, LeannSearcher, LeannChat
# Define the path for our new MLX-based index
INDEX_PATH = "./mlx_diskann_index/leann"
if os.path.exists(INDEX_PATH + ".meta.json"):
print(f"Index already exists at {INDEX_PATH}. Skipping build.")
else:
print("Initializing LeannBuilder with MLX support...")
# 1. Configure LeannBuilder to use MLX
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",
use_mlx=True,
)
# 2. Add documents
print("Adding documents...")
docs = [
"MLX is an array framework for machine learning on Apple silicon.",
"It was designed by Apple's machine learning research team.",
"The mlx-community organization provides pre-trained models in MLX format.",
"It supports operations on multi-dimensional arrays.",
"Leann can now use MLX for its embedding models.",
]
for doc in docs:
builder.add_text(doc)
# 3. Build the index
print(f"Building the MLX-based index at: {INDEX_PATH}")
builder.build_index(INDEX_PATH)
print("\nSuccessfully built the index with MLX embeddings!")
print(f"Check the metadata file: {INDEX_PATH}.meta.json")
chat = LeannChat(index_path=INDEX_PATH)
# add query
query = "MLX is an array framework for machine learning on Apple silicon."
print(f"Query: {query}")
response = chat.ask(
query, top_k=3, recompute_beighbor_embeddings=True, complexity=3, beam_width=1
)
print(f"Response: {response}")

View File

@@ -0,0 +1,147 @@
import os
import email
from pathlib import Path
from typing import List, Any
from llama_index.core import VectorStoreIndex, Document
from llama_index.core.readers.base import BaseReader
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self) -> None:
"""Initialize."""
pass
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
count = 0
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
if count >= max_count:
break
if filename.endswith(".emlx"):
filepath = os.path.join(dirpath, filename)
try:
# Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split('\n', 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get('Subject', 'No Subject')
from_addr = msg.get('From', 'Unknown')
to_addr = msg.get('To', 'Unknown')
date = msg.get('Date', 'Unknown')
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
# break
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
# Create document content
doc_content = f"""
From: {from_addr}
To: {to_addr}
Subject: {subject}
Date: {date}
{body}
"""
# Create metadata
metadata = {
'file_path': filepath,
'subject': subject,
'from': from_addr,
'to': to_addr,
'date': date,
'filename': filename
}
if count == 0:
print("--------------------------------")
print('dir path', dirpath)
print(metadata)
print(doc_content)
print("--------------------------------")
body=[]
if msg.is_multipart():
for part in msg.walk():
print("-------------------------------- get content type -------------------------------")
print(part.get_content_type())
print(part)
# body.append(part.get_payload(decode=True).decode('utf-8', errors='ignore'))
print("-------------------------------- get content type -------------------------------")
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
print(body)
print(body)
print("--------------------------------")
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
except Exception as e:
print(f"!!!!!!! Error parsing email from {filepath}: {e} !!!!!!!!")
continue
except Exception as e:
print(f"!!!!!!! Error reading file !!!!!!!! {filepath}: {e}")
continue
print(f"Loaded {len(docs)} email documents")
return docs
# Use the custom EmlxReader instead of MboxReader
documents = EmlxReader().load_data(
"/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
max_count=1000
) # Returns list of documents
# Configure the index with larger chunk size to handle long metadata
from llama_index.core.node_parser import SentenceSplitter
# Create a custom text splitter with larger chunk size
text_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=200)
index = VectorStoreIndex.from_documents(
documents,
transformations=[text_splitter]
) # Initialize index with documents
query_engine = index.as_query_engine()
res = query_engine.query("Hows Berkeley Graduate Student Instructor")
print(res)

View File

@@ -0,0 +1,213 @@
import os
import email
from pathlib import Path
from typing import List, Any
from llama_index.core import VectorStoreIndex, Document, StorageContext
from llama_index.core.readers.base import BaseReader
from llama_index.core.node_parser import SentenceSplitter
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self) -> None:
"""Initialize."""
pass
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
count = 0
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
if count >= max_count:
break
if filename.endswith(".emlx"):
filepath = os.path.join(dirpath, filename)
try:
# Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split('\n', 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get('Subject', 'No Subject')
from_addr = msg.get('From', 'Unknown')
to_addr = msg.get('To', 'Unknown')
date = msg.get('Date', 'Unknown')
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain":
body = part.get_payload(decode=True).decode('utf-8', errors='ignore')
break
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
# Create document content
doc_content = f"""
From: {from_addr}
To: {to_addr}
Subject: {subject}
Date: {date}
{body}
"""
# Create metadata
metadata = {
'file_path': filepath,
'subject': subject,
'from': from_addr,
'to': to_addr,
'date': date,
'filename': filename
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
except Exception as e:
print(f"Error parsing email from {filepath}: {e}")
continue
except Exception as e:
print(f"Error reading file {filepath}: {e}")
continue
print(f"Loaded {len(docs)} email documents")
return docs
def create_and_save_index(mail_path: str, save_dir: str = "mail_index", max_count: int = 1000):
"""
Create the index from mail data and save it to disk.
Args:
mail_path: Path to the mail directory
save_dir: Directory to save the index
max_count: Maximum number of emails to process
"""
print("Creating index from mail data...")
# Load documents
documents = EmlxReader().load_data(mail_path, max_count=max_count)
if not documents:
print("No documents loaded. Exiting.")
return None
# Create text splitter
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=0)
# Create index
index = VectorStoreIndex.from_documents(
documents,
transformations=[text_splitter]
)
# Save the index
os.makedirs(save_dir, exist_ok=True)
index.storage_context.persist(persist_dir=save_dir)
print(f"Index saved to {save_dir}")
return index
def load_index(save_dir: str = "mail_index"):
"""
Load the saved index from disk.
Args:
save_dir: Directory where the index is saved
Returns:
Loaded index or None if loading fails
"""
try:
# Load storage context
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
# Load index
index = VectorStoreIndex.from_vector_store(
storage_context.vector_store,
storage_context=storage_context
)
print(f"Index loaded from {save_dir}")
return index
except Exception as e:
print(f"Error loading index: {e}")
return None
def query_index(index, query: str):
"""
Query the loaded index.
Args:
index: The loaded index
query: The query string
"""
if index is None:
print("No index available for querying.")
return
query_engine = index.as_query_engine()
response = query_engine.query(query)
print(f"Query: {query}")
print(f"Response: {response}")
def main():
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
save_dir = "mail_index"
# Check if index already exists
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
print("Loading existing index...")
index = load_index(save_dir)
else:
print("Creating new index...")
index = create_and_save_index(mail_path, save_dir, max_count=1000)
if index:
# Example queries
queries = [
"Hows Berkeley Graduate Student Instructor",
"What emails mention GSR appointments?",
"Find emails about deadlines"
]
for query in queries:
print("\n" + "="*50)
query_index(index, query)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,211 @@
import os
import email
from pathlib import Path
from typing import List, Any
from llama_index.core import VectorStoreIndex, Document, StorageContext
from llama_index.core.readers.base import BaseReader
from llama_index.core.node_parser import SentenceSplitter
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader with reduced metadata.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self) -> None:
"""Initialize."""
pass
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
count = 0
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
if count >= max_count:
break
if filename.endswith(".emlx"):
filepath = os.path.join(dirpath, filename)
try:
# Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split('\n', 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get('Subject', 'No Subject')
from_addr = msg.get('From', 'Unknown')
to_addr = msg.get('To', 'Unknown')
date = msg.get('Date', 'Unknown')
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain":
body = part.get_payload(decode=True).decode('utf-8', errors='ignore')
break
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
# Create document content with metadata embedded in text
doc_content = f"""
From: {from_addr}
To: {to_addr}
Subject: {subject}
Date: {date}
{body}
"""
# Create minimal metadata (only essential info)
metadata = {
'subject': subject[:50], # Truncate subject
'from': from_addr[:30], # Truncate from
'date': date[:20], # Truncate date
'filename': filename # Keep filename
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
except Exception as e:
print(f"Error parsing email from {filepath}: {e}")
continue
except Exception as e:
print(f"Error reading file {filepath}: {e}")
continue
print(f"Loaded {len(docs)} email documents")
return docs
def create_and_save_index(mail_path: str, save_dir: str = "mail_index_small", max_count: int = 1000):
"""
Create the index from mail data and save it to disk.
Args:
mail_path: Path to the mail directory
save_dir: Directory to save the index
max_count: Maximum number of emails to process
"""
print("Creating index from mail data with small chunks...")
# Load documents
documents = EmlxReader().load_data(mail_path, max_count=max_count)
if not documents:
print("No documents loaded. Exiting.")
return None
# Create text splitter with small chunk size
text_splitter = SentenceSplitter(chunk_size=512, chunk_overlap=50)
# Create index
index = VectorStoreIndex.from_documents(
documents,
transformations=[text_splitter]
)
# Save the index
os.makedirs(save_dir, exist_ok=True)
index.storage_context.persist(persist_dir=save_dir)
print(f"Index saved to {save_dir}")
return index
def load_index(save_dir: str = "mail_index_small"):
"""
Load the saved index from disk.
Args:
save_dir: Directory where the index is saved
Returns:
Loaded index or None if loading fails
"""
try:
# Load storage context
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
# Load index
index = VectorStoreIndex.from_vector_store(
storage_context.vector_store,
storage_context=storage_context
)
print(f"Index loaded from {save_dir}")
return index
except Exception as e:
print(f"Error loading index: {e}")
return None
def query_index(index, query: str):
"""
Query the loaded index.
Args:
index: The loaded index
query: The query string
"""
if index is None:
print("No index available for querying.")
return
query_engine = index.as_query_engine()
response = query_engine.query(query)
print(f"Query: {query}")
print(f"Response: {response}")
def main():
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
save_dir = "mail_index_small"
# Check if index already exists
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
print("Loading existing index...")
index = load_index(save_dir)
else:
print("Creating new index...")
index = create_and_save_index(mail_path, save_dir, max_count=1000)
if index:
# Example queries
queries = [
"Hows Berkeley Graduate Student Instructor",
"What emails mention GSR appointments?",
"Find emails about deadlines"
]
for query in queries:
print("\n" + "="*50)
query_index(index, query)
if __name__ == "__main__":
main()

147
test/mail_reader_test.py Normal file
View File

@@ -0,0 +1,147 @@
import os
import email
from pathlib import Path
from typing import List, Any
from llama_index.core import VectorStoreIndex, Document
from llama_index.core.readers.base import BaseReader
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self) -> None:
"""Initialize."""
pass
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
count = 0
# Check if directory exists and is accessible
if not os.path.exists(input_dir):
print(f"Error: Directory '{input_dir}' does not exist")
return docs
if not os.access(input_dir, os.R_OK):
print(f"Error: Directory '{input_dir}' is not accessible (permission denied)")
print("This is likely due to macOS security restrictions on Mail app data")
return docs
print(f"Scanning directory: {input_dir}")
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
if count >= max_count:
break
if filename.endswith(".emlx"):
filepath = os.path.join(dirpath, filename)
print(f"Found .emlx file: {filepath}")
try:
# Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split('\n', 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get('Subject', 'No Subject')
from_addr = msg.get('From', 'Unknown')
to_addr = msg.get('To', 'Unknown')
date = msg.get('Date', 'Unknown')
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain":
body = part.get_payload(decode=True).decode('utf-8', errors='ignore')
break
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
# Create document content
doc_content = f"""
From: {from_addr}
To: {to_addr}
Subject: {subject}
Date: {date}
{body}
"""
# Create metadata
metadata = {
'file_path': filepath,
'subject': subject,
'from': from_addr,
'to': to_addr,
'date': date,
'filename': filename
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
except Exception as e:
print(f"Error parsing email from {filepath}: {e}")
continue
except Exception as e:
print(f"Error reading file {filepath}: {e}")
continue
print(f"Loaded {len(docs)} email documents")
return docs
def main():
# Use the current directory where the sample.emlx file is located
current_dir = os.path.dirname(os.path.abspath(__file__))
print("Testing EmlxReader with sample .emlx file...")
print(f"Scanning directory: {current_dir}")
# Use the custom EmlxReader
documents = EmlxReader().load_data(current_dir, max_count=1000)
if not documents:
print("No documents loaded. Make sure sample.emlx exists in the examples directory.")
return
print(f"\nSuccessfully loaded {len(documents)} document(s)")
# Initialize index with documents
index = VectorStoreIndex.from_documents(documents)
query_engine = index.as_query_engine()
print("\nTesting query: 'Hows Berkeley Graduate Student Instructor'")
res = query_engine.query("Hows Berkeley Graduate Student Instructor")
print(f"Response: {res}")
if __name__ == "__main__":
main()

630
test/micro_tpt.py Normal file
View File

@@ -0,0 +1,630 @@
# python embedd_micro.py --use_int8 Fastest
import argparse
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from transformers import AutoModel, BitsAndBytesConfig
from tqdm import tqdm
from contextlib import contextmanager
@dataclass
class BenchmarkConfig:
model_path: str
batch_sizes: List[int]
seq_length: int
num_runs: int
use_fp16: bool = True
use_int4: bool = False
use_int8: bool = False # Add this parameter
use_cuda_graphs: bool = False
use_flash_attention: bool = False
use_linear8bitlt: bool = False
class GraphContainer:
"""Container for managing graphs for different batch sizes (CUDA graphs on NVIDIA, regular on others)."""
def __init__(self, model: nn.Module, seq_length: int):
self.model = model
self.seq_length = seq_length
self.graphs: Dict[int, 'GraphWrapper'] = {}
def get_or_create(self, batch_size: int) -> 'GraphWrapper':
if batch_size not in self.graphs:
self.graphs[batch_size] = GraphWrapper(
self.model, batch_size, self.seq_length
)
return self.graphs[batch_size]
class GraphWrapper:
"""Wrapper for graph capture and replay (CUDA graphs on NVIDIA, regular on others)."""
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
self.model = model
self.device = self._get_device()
self.static_input = self._create_random_batch(batch_size, seq_length)
self.static_attention_mask = torch.ones_like(self.static_input)
# Warm up
self._warmup()
# Only use CUDA graphs on NVIDIA GPUs
if torch.cuda.is_available() and hasattr(torch.cuda, 'CUDAGraph'):
# Capture graph
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
self.static_output = self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask
)
self.use_cuda_graph = True
else:
# For MPS or CPU, just store the model
self.use_cuda_graph = False
self.static_output = None
def _get_device(self) -> str:
if torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_available():
return "mps"
else:
return "cpu"
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
return torch.randint(
0, 1000, (batch_size, seq_length),
device=self.device,
dtype=torch.long
)
def _warmup(self, num_warmup: int = 3):
with torch.no_grad():
for _ in range(num_warmup):
self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask
)
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
if self.use_cuda_graph:
self.static_input.copy_(input_ids)
self.static_attention_mask.copy_(attention_mask)
self.graph.replay()
return self.static_output
else:
# For MPS/CPU, just run normally
return self.model(input_ids=input_ids, attention_mask=attention_mask)
class ModelOptimizer:
"""Applies various optimizations to the model."""
@staticmethod
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
print("\nApplying model optimizations:")
if model is None:
raise ValueError("Cannot optimize None model")
# Move to GPU
if torch.cuda.is_available():
model = model.cuda()
device = "cuda"
elif torch.backends.mps.is_available():
model = model.to("mps")
device = "mps"
else:
model = model.cpu()
device = "cpu"
print(f"- Model moved to {device}")
# FP16
if config.use_fp16 and not config.use_int4:
model = model.half()
# use torch compile
model = torch.compile(model)
print("- Using FP16 precision")
# Check if using SDPA (only on CUDA)
if torch.cuda.is_available() and torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else:
print("- PyTorch SDPA not available")
# Flash Attention (only on CUDA)
if config.use_flash_attention and torch.cuda.is_available():
try:
from flash_attn.flash_attention import FlashAttention
print("- Flash Attention 2 available")
if hasattr(model.config, "attention_mode"):
model.config.attention_mode = "flash_attention_2"
print(" - Enabled Flash Attention 2 mode")
except ImportError:
print("- Flash Attention not available")
# Memory efficient attention (only on CUDA)
if torch.cuda.is_available():
try:
from xformers.ops import memory_efficient_attention
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
else:
print("- Model doesn't support xformers")
except (ImportError, AttributeError):
print("- Xformers not available")
model.eval()
print("- Model set to eval mode")
return model
class Timer:
"""Handles accurate GPU timing using GPU events or CPU timing."""
def __init__(self):
if torch.cuda.is_available():
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
self.use_gpu_timing = True
elif torch.backends.mps.is_available():
# MPS doesn't have events, use CPU timing
self.use_gpu_timing = False
else:
# CPU timing
self.use_gpu_timing = False
@contextmanager
def timing(self):
if self.use_gpu_timing:
self.start_event.record()
yield
self.end_event.record()
self.end_event.synchronize()
else:
# Use CPU timing for MPS/CPU
start_time = time.time()
yield
self.cpu_elapsed = time.time() - start_time
def elapsed_time(self) -> float:
if self.use_gpu_timing:
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
else:
return self.cpu_elapsed
class Benchmark:
"""Main benchmark runner."""
def __init__(self, config: BenchmarkConfig):
self.config = config
try:
self.model = self._load_model()
if self.model is None:
raise ValueError("Model initialization failed - model is None")
# Only use CUDA graphs on NVIDIA GPUs
if config.use_cuda_graphs and torch.cuda.is_available():
self.graphs = GraphContainer(self.model, config.seq_length)
else:
self.graphs = None
self.timer = Timer()
except Exception as e:
print(f"ERROR in benchmark initialization: {str(e)}")
raise
def _load_model(self) -> nn.Module:
print(f"Loading model from {self.config.model_path}...")
try:
# Int4 quantization using HuggingFace integration
if self.config.use_int4:
import bitsandbytes as bnb
print(f"- bitsandbytes version: {bnb.__version__}")
# 检查是否使用自定义的8bit量化
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
print("- Using custom Linear8bitLt replacement for all linear layers")
# 加载原始模型(不使用量化配置)
import bitsandbytes as bnb
import torch
# set default to half
torch.set_default_dtype(torch.float16)
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
model = AutoModel.from_pretrained(
self.config.model_path,
torch_dtype=compute_dtype,
)
# 定义替换函数
def replace_linear_with_linear8bitlt(model):
"""递归地将模型中的所有nn.Linear层替换为Linear8bitLt"""
for name, module in list(model.named_children()):
if isinstance(module, nn.Linear):
# 获取原始线性层的参数
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
# 创建8bit线性层
# print size
print(f"in_features: {in_features}, out_features: {out_features}")
new_module = bnb.nn.Linear8bitLt(
in_features,
out_features,
bias=bias,
has_fp16_weights=False
)
# 复制权重和偏置
new_module.weight.data = module.weight.data
if bias:
new_module.bias.data = module.bias.data
# 替换模块
setattr(model, name, new_module)
else:
# 递归处理子模块
replace_linear_with_linear8bitlt(module)
return model
# 替换所有线性层
model = replace_linear_with_linear8bitlt(model)
# add torch compile
model = torch.compile(model)
# 将模型移到GPU量化发生在这里
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model = model.to(device)
print("- All linear layers replaced with Linear8bitLt")
else:
# 使用原来的Int4量化方法
print("- Using bitsandbytes for Int4 quantization")
# Create quantization config
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)
print("- Quantization config:", quantization_config)
# Load model directly with quantization config
model = AutoModel.from_pretrained(
self.config.model_path,
quantization_config=quantization_config,
torch_dtype=compute_dtype,
device_map="auto" # Let HF decide on device mapping
)
# Check if model loaded successfully
if model is None:
raise ValueError("Model loading returned None")
print(f"- Model type: {type(model)}")
# Apply optimizations directly here
print("\nApplying model optimizations:")
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
print("- Model moved to GPU with Linear8bitLt quantization")
else:
# Skip moving to GPU since device_map="auto" already did that
print("- Model already on GPU due to device_map='auto'")
# Skip FP16 conversion since we specified compute_dtype
print(f"- Using {compute_dtype} for compute dtype")
# Check CUDA and SDPA
if torch.cuda.is_available() and torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else:
print("- PyTorch SDPA not available")
# Try xformers if available (only on CUDA)
if torch.cuda.is_available():
try:
from xformers.ops import memory_efficient_attention
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
else:
print("- Model doesn't support xformers")
except (ImportError, AttributeError):
print("- Xformers not available")
# Set to eval mode
model.eval()
print("- Model set to eval mode")
# Int8 quantization using HuggingFace integration
elif self.config.use_int8:
print("- Using INT8 quantization")
# For now, just use standard loading with INT8 config
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
)
model = AutoModel.from_pretrained(
self.config.model_path,
quantization_config=quantization_config,
torch_dtype=compute_dtype,
device_map="auto"
)
if model is None:
raise ValueError("Model loading returned None")
print(f"- Model type: {type(model)}")
model.eval()
print("- Model set to eval mode")
else:
# Standard loading for FP16/FP32
model = AutoModel.from_pretrained(self.config.model_path)
print("- Model loaded in standard precision")
print(f"- Model type: {type(model)}")
# Apply standard optimizations
# set default to half
import torch
torch.set_default_dtype(torch.bfloat16)
model = ModelOptimizer.optimize(model, self.config)
model = model.half()
# add torch compile
model = torch.compile(model)
# Final check to ensure model is not None
if model is None:
raise ValueError("Model is None after optimization")
print(f"- Final model type: {type(model)}")
return model
except Exception as e:
print(f"ERROR loading model: {str(e)}")
import traceback
traceback.print_exc()
raise
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
return torch.randint(
0, 1000,
(batch_size, self.config.seq_length),
device=device,
dtype=torch.long
)
def _run_inference(
self,
input_ids: torch.Tensor,
graph_wrapper: Optional[GraphWrapper] = None
) -> Tuple[float, torch.Tensor]:
attention_mask = torch.ones_like(input_ids)
with torch.no_grad(), self.timer.timing():
if graph_wrapper is not None:
output = graph_wrapper(input_ids, attention_mask)
else:
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
return self.timer.elapsed_time(), output
def run(self) -> Dict[int, Dict[str, float]]:
results = {}
# Reset peak memory stats
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
elif torch.backends.mps.is_available():
# MPS doesn't have reset_peak_memory_stats, skip it
pass
else:
print("- No GPU memory stats available")
for batch_size in self.config.batch_sizes:
print(f"\nTesting batch size: {batch_size}")
times = []
# Get or create graph for this batch size
graph_wrapper = (
self.graphs.get_or_create(batch_size)
if self.graphs is not None
else None
)
# Pre-allocate input tensor
input_ids = self._create_random_batch(batch_size)
print(f"Input shape: {input_ids.shape}")
# Run benchmark
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
try:
elapsed_time, output = self._run_inference(input_ids, graph_wrapper)
if i == 0: # Only print on first run
print(f"Output shape: {output.last_hidden_state.shape}")
times.append(elapsed_time)
except Exception as e:
print(f"Error during inference: {e}")
break
if not times:
print(f"No successful runs for batch size {batch_size}, skipping")
continue
# Calculate statistics
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
"throughput": throughput,
}
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f"Throughput: {throughput:.2f} sequences/second")
# Log memory usage
if torch.cuda.is_available():
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
elif torch.backends.mps.is_available():
# MPS doesn't have max_memory_allocated, use 0
peak_memory_gb = 0.0
else:
peak_memory_gb = 0.0
print("- No GPU memory usage available")
if peak_memory_gb > 0:
print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
else:
print("\n- GPU memory usage not available")
# Add memory info to results
for batch_size in results:
results[batch_size]["peak_memory_gb"] = peak_memory_gb
return results
def main():
parser = argparse.ArgumentParser(description="Model Inference Benchmark")
parser.add_argument(
"--model_path",
type=str,
default="facebook/contriever",
help="Path to the model",
)
parser.add_argument(
"--batch_sizes",
type=str,
default="1,2,4,8,16,32",
help="Comma-separated list of batch sizes",
)
parser.add_argument(
"--seq_length",
type=int,
default=256,
help="Sequence length for input",
)
parser.add_argument(
"--num_runs",
type=int,
default=5,
help="Number of runs for each batch size",
)
parser.add_argument(
"--use_fp16",
action="store_true",
help="Enable FP16 inference",
)
parser.add_argument(
"--use_int4",
action="store_true",
help="Enable INT4 quantization using bitsandbytes",
)
parser.add_argument(
"--use_int8",
action="store_true",
help="Enable INT8 quantization for both activations and weights using bitsandbytes",
)
parser.add_argument(
"--use_cuda_graphs",
action="store_true",
help="Enable CUDA Graphs optimization (only on NVIDIA GPUs)",
)
parser.add_argument(
"--use_flash_attention",
action="store_true",
help="Enable Flash Attention 2 if available (only on NVIDIA GPUs)",
)
parser.add_argument(
"--use_linear8bitlt",
action="store_true",
help="Enable Linear8bitLt quantization for all linear layers",
)
args = parser.parse_args()
# Print arguments for debugging
print("\nCommand line arguments:")
for arg, value in vars(args).items():
print(f"- {arg}: {value}")
config = BenchmarkConfig(
model_path=args.model_path,
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
seq_length=args.seq_length,
num_runs=args.num_runs,
use_fp16=args.use_fp16,
use_int4=args.use_int4,
use_int8=args.use_int8, # Add this line
use_cuda_graphs=args.use_cuda_graphs,
use_flash_attention=args.use_flash_attention,
use_linear8bitlt=args.use_linear8bitlt,
)
# Print configuration for debugging
print("\nBenchmark configuration:")
for field, value in vars(config).items():
print(f"- {field}: {value}")
try:
benchmark = Benchmark(config)
results = benchmark.run()
# Save results to file
import json
import os
# Create results directory if it doesn't exist
os.makedirs("results", exist_ok=True)
# Generate filename based on configuration
precision_type = "int4" if config.use_int4 else "int8" if config.use_int8 else "fp16" if config.use_fp16 else "fp32"
model_name = os.path.basename(config.model_path)
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
# Save results
with open(output_file, "w") as f:
json.dump(
{
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()},
"results": {str(k): v for k, v in results.items()}
},
f,
indent=2
)
print(f"Results saved to {output_file}")
except Exception as e:
print(f"Benchmark failed: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

99
test/query_saved_index.py Normal file
View File

@@ -0,0 +1,99 @@
import os
from llama_index.core import VectorStoreIndex, StorageContext
def load_index(save_dir: str = "mail_index"):
"""
Load the saved index from disk.
Args:
save_dir: Directory where the index is saved
Returns:
Loaded index or None if loading fails
"""
try:
# Load storage context
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
# Load index
index = VectorStoreIndex.from_vector_store(
storage_context.vector_store,
storage_context=storage_context
)
print(f"Index loaded from {save_dir}")
return index
except Exception as e:
print(f"Error loading index: {e}")
return None
def query_index(index, query: str):
"""
Query the loaded index.
Args:
index: The loaded index
query: The query string
"""
if index is None:
print("No index available for querying.")
return
query_engine = index.as_query_engine()
response = query_engine.query(query)
print(f"\nQuery: {query}")
print(f"Response: {response}")
def main():
save_dir = "mail_index"
# Check if index exists
if not os.path.exists(save_dir) or not os.path.exists(os.path.join(save_dir, "vector_store.json")):
print(f"Index not found in {save_dir}")
print("Please run mail_reader_save_load.py first to create the index.")
return
# Load the index
index = load_index(save_dir)
if not index:
print("Failed to load index.")
return
print("\n" + "="*60)
print("Email Query Interface")
print("="*60)
print("Type 'quit' to exit")
print("Type 'help' for example queries")
print("="*60)
# Interactive query loop
while True:
try:
query = input("\nEnter your query: ").strip()
if query.lower() == 'quit':
print("Goodbye!")
break
elif query.lower() == 'help':
print("\nExample queries:")
print("- Hows Berkeley Graduate Student Instructor")
print("- What emails mention GSR appointments?")
print("- Find emails about deadlines")
print("- Search for emails from specific sender")
print("- Find emails about meetings")
continue
elif not query:
continue
query_index(index, query)
except KeyboardInterrupt:
print("\nGoodbye!")
break
except Exception as e:
print(f"Error processing query: {e}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,128 @@
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
from sentence_transformers import SentenceTransformer
import mlx.core as mx
from mlx_lm import load
# --- Configuration ---
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
BATCH_SIZES = [1, 8, 16, 32, 64, 128]
NUM_RUNS = 10 # Number of runs to average for each batch size
WARMUP_RUNS = 2 # Number of warm-up runs
# --- Generate Dummy Data ---
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
# --- Benchmark Functions ---b
def benchmark_torch(model, sentences):
start_time = time.time()
model.encode(sentences, convert_to_numpy=True)
end_time = time.time()
return (end_time - start_time) * 1000 # Return time in ms
def benchmark_mlx(model, tokenizer, sentences):
start_time = time.time()
# Tokenize sentences using MLX tokenizer
tokens = []
for sentence in sentences:
token_ids = tokenizer.encode(sentence)
tokens.append(token_ids)
# Pad sequences to the same length
max_len = max(len(t) for t in tokens)
input_ids = []
attention_mask = []
for token_seq in tokens:
# Pad sequence
padded = token_seq + [tokenizer.eos_token_id] * (max_len - len(token_seq))
input_ids.append(padded)
# Create attention mask (1 for real tokens, 0 for padding)
mask = [1] * len(token_seq) + [0] * (max_len - len(token_seq))
attention_mask.append(mask)
# Convert to MLX arrays
input_ids = mx.array(input_ids)
attention_mask = mx.array(attention_mask)
# Get embeddings
embeddings = model(input_ids)
# Mean pooling
mask = mx.expand_dims(attention_mask, -1)
sum_embeddings = (embeddings * mask).sum(axis=1)
sum_mask = mask.sum(axis=1)
_ = sum_embeddings / sum_mask
mx.eval() # Ensure computation is finished
end_time = time.time()
return (end_time - start_time) * 1000 # Return time in ms
# --- Main Execution ---
def main():
print("--- Initializing Models ---")
# Load PyTorch model
print(f"Loading PyTorch model: {MODEL_NAME_TORCH}")
device = "mps" if torch.backends.mps.is_available() else "cpu"
model_torch = SentenceTransformer(MODEL_NAME_TORCH, device=device)
print(f"PyTorch model loaded on: {device}")
# Load MLX model
print(f"Loading MLX model: {MODEL_NAME_MLX}")
model_mlx, tokenizer_mlx = load(MODEL_NAME_MLX)
print("MLX model loaded.")
# --- Warm-up ---
print("\n--- Performing Warm-up Runs ---")
for _ in range(WARMUP_RUNS):
benchmark_torch(model_torch, DUMMY_SENTENCES[:1])
benchmark_mlx(model_mlx, tokenizer_mlx, DUMMY_SENTENCES[:1])
print("Warm-up complete.")
# --- Benchmarking ---
print("\n--- Starting Benchmark ---")
results_torch = []
results_mlx = []
for batch_size in BATCH_SIZES:
print(f"Benchmarking batch size: {batch_size}")
sentences_batch = DUMMY_SENTENCES[:batch_size]
# Benchmark PyTorch
torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
results_torch.append(np.mean(torch_times))
# Benchmark MLX
mlx_times = [benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)]
results_mlx.append(np.mean(mlx_times))
print("\n--- Benchmark Results (Average time per batch in ms) ---")
print(f"Batch Sizes: {BATCH_SIZES}")
print(f"PyTorch (mps): {[f'{t:.2f}' for t in results_torch]}")
print(f"MLX: {[f'{t:.2f}' for t in results_mlx]}")
# --- Plotting ---
print("\n--- Generating Plot ---")
plt.figure(figsize=(10, 6))
plt.plot(BATCH_SIZES, results_torch, marker='o', linestyle='-', label=f'PyTorch ({device})')
plt.plot(BATCH_SIZES, results_mlx, marker='s', linestyle='-', label='MLX')
plt.title(f'Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}')
plt.xlabel("Batch Size")
plt.ylabel("Average Time per Batch (ms)")
plt.xticks(BATCH_SIZES)
plt.grid(True)
plt.legend()
# Save the plot
output_filename = "embedding_benchmark.png"
plt.savefig(output_filename)
print(f"Plot saved to {output_filename}")
if __name__ == "__main__":
main()

314
test/simple_mac_tpt_test.py Normal file
View File

@@ -0,0 +1,314 @@
import time
from dataclasses import dataclass
from typing import Dict, List
import numpy as np
import torch
from torch import nn
from transformers import AutoModel, BitsAndBytesConfig
from tqdm import tqdm
# Add MLX imports
try:
import mlx.core as mx
from mlx_lm.utils import load
MLX_AVAILABLE = True
except ImportError as e:
print("MLX not available. Install with: uv pip install mlx mlx-lm")
MLX_AVAILABLE = False
@dataclass
class BenchmarkConfig:
model_path: str = "facebook/contriever"
batch_sizes: List[int] = None
seq_length: int = 256
num_runs: int = 5
use_fp16: bool = True
use_int4: bool = False
use_int8: bool = False
use_cuda_graphs: bool = False
use_flash_attention: bool = False
use_linear8bitlt: bool = False
use_mlx: bool = False # New flag for MLX testing
def __post_init__(self):
if self.batch_sizes is None:
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
class MLXBenchmark:
"""MLX-specific benchmark for embedding models"""
def __init__(self, config: BenchmarkConfig):
self.config = config
self.model, self.tokenizer = self._load_model()
def _load_model(self):
"""Load MLX model and tokenizer following the API pattern"""
print(f"Loading MLX model from {self.config.model_path}...")
try:
model, tokenizer = load(self.config.model_path)
print("MLX model loaded successfully")
return model, tokenizer
except Exception as e:
print(f"Error loading MLX model: {e}")
raise
def _create_random_batch(self, batch_size: int):
"""Create random input batches for MLX testing - same as PyTorch"""
return torch.randint(
0, 1000,
(batch_size, self.config.seq_length),
dtype=torch.long
)
def _run_inference(self, input_ids: torch.Tensor) -> float:
"""Run MLX inference with same input as PyTorch"""
start_time = time.time()
try:
# Convert PyTorch tensor to MLX array
input_ids_mlx = mx.array(input_ids.numpy())
# Get embeddings
embeddings = self.model(input_ids_mlx)
# Mean pooling (following the API pattern)
pooled = embeddings.mean(axis=1)
# Convert to numpy (following the API pattern)
pooled_numpy = np.array(pooled.tolist(), dtype=np.float32)
# Force computation
_ = pooled_numpy.shape
except Exception as e:
print(f"MLX inference error: {e}")
return float('inf')
end_time = time.time()
return end_time - start_time
def run(self) -> Dict[int, Dict[str, float]]:
"""Run the MLX benchmark across all batch sizes"""
results = {}
print(f"Starting MLX benchmark with model: {self.config.model_path}")
print(f"Testing batch sizes: {self.config.batch_sizes}")
for batch_size in self.config.batch_sizes:
print(f"\n=== Testing MLX batch size: {batch_size} ===")
times = []
# Create input batch (same as PyTorch)
input_ids = self._create_random_batch(batch_size)
# Warm up
print("Warming up...")
for _ in range(3):
try:
self._run_inference(input_ids[:2]) # Warm up with smaller batch
except Exception as e:
print(f"Warmup error: {e}")
break
# Run benchmark
for i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
try:
elapsed_time = self._run_inference(input_ids)
if elapsed_time != float('inf'):
times.append(elapsed_time)
except Exception as e:
print(f"Error during MLX inference: {e}")
break
if not times:
print(f"Skipping batch size {batch_size} due to errors")
continue
# Calculate statistics
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
"throughput": throughput,
"min_time": np.min(times),
"max_time": np.max(times),
}
print(f"MLX Results for batch size {batch_size}:")
print(f" Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f" Min Time: {np.min(times):.4f}s")
print(f" Max Time: {np.max(times):.4f}s")
print(f" Throughput: {throughput:.2f} sequences/second")
return results
class Benchmark:
def __init__(self, config: BenchmarkConfig):
self.config = config
self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
self.model = self._load_model()
def _load_model(self) -> nn.Module:
print(f"Loading model from {self.config.model_path}...")
model = AutoModel.from_pretrained(self.config.model_path)
if self.config.use_fp16:
model = model.half()
model = torch.compile(model)
model = model.to(self.device)
model.eval()
return model
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
return torch.randint(
0, 1000,
(batch_size, self.config.seq_length),
device=self.device,
dtype=torch.long
)
def _run_inference(self, input_ids: torch.Tensor) -> float:
attention_mask = torch.ones_like(input_ids)
start_time = time.time()
with torch.no_grad():
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
end_time = time.time()
return end_time - start_time
def run(self) -> Dict[int, Dict[str, float]]:
results = {}
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
for batch_size in self.config.batch_sizes:
print(f"\nTesting batch size: {batch_size}")
times = []
input_ids = self._create_random_batch(batch_size)
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
try:
elapsed_time = self._run_inference(input_ids)
times.append(elapsed_time)
except Exception as e:
print(f"Error during inference: {e}")
break
if not times:
continue
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
"throughput": throughput,
}
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f"Throughput: {throughput:.2f} sequences/second")
if torch.cuda.is_available():
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
else:
peak_memory_gb = 0.0
for batch_size in results:
results[batch_size]["peak_memory_gb"] = peak_memory_gb
return results
def run_benchmark():
"""Main function to run the benchmark with optimized parameters."""
config = BenchmarkConfig()
try:
benchmark = Benchmark(config)
results = benchmark.run()
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
return {
"max_throughput": max_throughput,
"avg_throughput": avg_throughput,
"results": results
}
except Exception as e:
print(f"Benchmark failed: {e}")
return {
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": str(e)
}
def run_mlx_benchmark():
"""Run MLX-specific benchmark"""
if not MLX_AVAILABLE:
print("MLX not available, skipping MLX benchmark")
return {
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": "MLX not available"
}
config = BenchmarkConfig(
model_path="mlx-community/all-MiniLM-L6-v2-4bit",
use_mlx=True
)
try:
benchmark = MLXBenchmark(config)
results = benchmark.run()
if not results:
return {
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": "No valid results"
}
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
return {
"max_throughput": max_throughput,
"avg_throughput": avg_throughput,
"results": results
}
except Exception as e:
print(f"MLX benchmark failed: {e}")
return {
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": str(e)
}
if __name__ == "__main__":
print("=== PyTorch Benchmark ===")
pytorch_result = run_benchmark()
print(f"PyTorch Max throughput: {pytorch_result['max_throughput']:.2f} sequences/second")
print(f"PyTorch Average throughput: {pytorch_result['avg_throughput']:.2f} sequences/second")
print("\n=== MLX Benchmark ===")
mlx_result = run_mlx_benchmark()
print(f"MLX Max throughput: {mlx_result['max_throughput']:.2f} sequences/second")
print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")
# Compare results
if pytorch_result['max_throughput'] > 0 and mlx_result['max_throughput'] > 0:
speedup = mlx_result['max_throughput'] / pytorch_result['max_throughput']
print(f"\n=== Comparison ===")
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")

View File

@@ -1,68 +0,0 @@
# HNSW Index Storage Optimization
This document explains the storage optimization features available in the HNSW backend.
## Storage Modes
The HNSW backend supports two orthogonal optimization techniques:
### 1. CSR Compression (`is_compact=True`)
- Converts the graph structure from standard format to Compressed Sparse Row (CSR) format
- Reduces memory overhead from graph adjacency storage
- Maintains all embedding data for direct access
### 2. Embedding Pruning (`is_recompute=True`)
- Removes embedding vectors from the index file
- Replaces them with a NULL storage marker
- Requires recomputation via embedding server during search
- Must be used with `is_compact=True` for efficiency
## Performance Impact
**Storage Reduction (100 vectors, 384 dimensions):**
```
Standard format: 168 KB (embeddings + graph)
CSR only: 160 KB (embeddings + compressed graph)
CSR + Pruned: 6 KB (compressed graph only)
```
**Key Benefits:**
- **CSR compression**: ~5% size reduction from graph optimization
- **Embedding pruning**: ~95% size reduction by removing embeddings
- **Combined**: Up to 96% total storage reduction
## Usage
```python
# Standard format (largest)
builder = LeannBuilder(
backend_name="hnsw",
is_compact=False,
is_recompute=False
)
# CSR compressed (medium)
builder = LeannBuilder(
backend_name="hnsw",
is_compact=True,
is_recompute=False
)
# CSR + Pruned (smallest, requires embedding server)
builder = LeannBuilder(
backend_name="hnsw",
is_compact=True, # Required for pruning
is_recompute=True # Default: enabled
)
```
## Trade-offs
| Mode | Storage | Search Speed | Memory Usage | Setup Complexity |
|------|---------|--------------|--------------|------------------|
| Standard | Largest | Fastest | Highest | Simple |
| CSR | Medium | Fast | Medium | Simple |
| CSR + Pruned | Smallest | Slower* | Lowest | Complex** |
*Requires network round-trip to embedding server for recomputation
**Needs embedding server and passages file for search

View File

@@ -1,107 +0,0 @@
#!/usr/bin/env python3
"""
DiskANN 距离函数测试
"""
import os
from pathlib import Path
import shutil
import time
# 导入后端包以触发插件注册
try:
import leann_backend_diskann
import leann_backend_hnsw
print("INFO: Backend packages imported successfully.")
except ImportError as e:
print(f"WARNING: Could not import backend packages. Error: {e}")
# 从 leann-core 导入上层 API
from leann.api import LeannBuilder, LeannSearcher
def load_sample_documents():
"""创建用于演示的样本文档"""
docs = [
{"title": "Intro to Python", "content": "Python is a programming language for machine learning"},
{"title": "ML Basics", "content": "Machine learning algorithms build intelligent systems"},
{"title": "Data Structures", "content": "Data structures like arrays and graphs organize information"},
]
return docs
def test_distance_function(distance_func, test_name):
"""测试特定距离函数"""
print(f"\n=== 测试 {test_name} ({distance_func}) ===")
INDEX_DIR = Path(f"./test_indices_{distance_func}")
INDEX_PATH = str(INDEX_DIR / "documents.diskann")
if INDEX_DIR.exists():
shutil.rmtree(INDEX_DIR)
# 构建索引
print(f"构建索引 (距离函数: {distance_func})...")
builder = LeannBuilder(
backend_name="diskann",
distance_metric=distance_func,
graph_degree=16,
complexity=32
)
documents = load_sample_documents()
for doc in documents:
builder.add_text(doc["content"], metadata=doc)
try:
builder.build_index(INDEX_PATH)
print(f"✅ 索引构建成功")
# 测试搜索
searcher = LeannSearcher(INDEX_PATH, distance_metric=distance_func)
results = searcher.search("machine learning programming", top_k=2)
print(f"搜索结果:")
for i, result in enumerate(results):
print(f" {i+1}. Score: {result['score']:.4f}")
print(f" Text: {result['text'][:50]}...")
return True
except Exception as e:
print(f"❌ 测试失败: {e}")
return False
def main():
print("🔍 DiskANN 距离函数测试")
print("=" * 50)
# 测试不同距离函数
distance_tests = [
("mips", "Maximum Inner Product Search"),
("l2", "L2 Euclidean Distance"),
("cosine", "Cosine Similarity")
]
results = {}
for distance_func, test_name in distance_tests:
try:
success = test_distance_function(distance_func, test_name)
results[distance_func] = success
except Exception as e:
print(f"{distance_func} 测试异常: {e}")
results[distance_func] = False
# 总结
print("\n" + "=" * 50)
print("📊 测试结果总结:")
for distance_func, success in results.items():
status = "✅ 通过" if success else "❌ 失败"
print(f" {distance_func:10s}: {status}")
print("\n🎉 测试完成!")
if __name__ == "__main__":
main()

View File

@@ -1,156 +0,0 @@
#!/usr/bin/env python3
"""
Sanity check script to verify HNSW index pruning effectiveness.
Tests the difference in file sizes between pruned and non-pruned indices.
"""
import os
import sys
import tempfile
from pathlib import Path
import numpy as np
import json
# Add the project root to the Python path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
# Import backend packages to trigger plugin registration
import leann_backend_hnsw
from leann.api import LeannBuilder
def create_sample_documents(num_docs=1000):
"""Create sample documents for testing"""
documents = []
for i in range(num_docs):
documents.append(f"Sample document {i} with some random text content for testing purposes.")
return documents
def build_index(documents, output_dir, is_recompute=True):
"""Build HNSW index with specified recompute setting"""
index_path = os.path.join(output_dir, "test_index.hnsw")
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
M=16,
efConstruction=100,
distance_metric="mips",
is_compact=True,
is_recompute=is_recompute
)
for doc in documents:
builder.add_text(doc)
builder.build_index(index_path)
return index_path
def get_file_size(filepath):
"""Get file size in bytes"""
return os.path.getsize(filepath)
def main():
print("🔍 HNSW Pruning Sanity Check")
print("=" * 50)
# Create sample data
print("📊 Creating sample documents...")
documents = create_sample_documents(num_docs=1000)
print(f" Number of documents: {len(documents)}")
with tempfile.TemporaryDirectory() as temp_dir:
print(f"📁 Working in temporary directory: {temp_dir}")
# Build index with pruning (is_recompute=True)
print("\n🔨 Building index with pruning enabled (is_recompute=True)...")
pruned_dir = os.path.join(temp_dir, "pruned")
os.makedirs(pruned_dir, exist_ok=True)
pruned_index_path = build_index(documents, pruned_dir, is_recompute=True)
# Check what files were actually created
print(f" Looking for index files at: {pruned_index_path}")
import glob
files = glob.glob(f"{pruned_index_path}*")
print(f" Found files: {files}")
# Try to find the actual index file
if os.path.exists(f"{pruned_index_path}.index"):
pruned_index_file = f"{pruned_index_path}.index"
else:
# Look for any .index file in the directory
index_files = glob.glob(f"{pruned_dir}/*.index")
if index_files:
pruned_index_file = index_files[0]
else:
raise FileNotFoundError(f"No .index file found in {pruned_dir}")
pruned_size = get_file_size(pruned_index_file)
print(f" ✅ Pruned index built successfully")
print(f" 📏 Pruned index size: {pruned_size:,} bytes ({pruned_size/1024:.1f} KB)")
# Build index without pruning (is_recompute=False)
print("\n🔨 Building index without pruning (is_recompute=False)...")
non_pruned_dir = os.path.join(temp_dir, "non_pruned")
os.makedirs(non_pruned_dir, exist_ok=True)
non_pruned_index_path = build_index(documents, non_pruned_dir, is_recompute=False)
# Check what files were actually created
print(f" Looking for index files at: {non_pruned_index_path}")
files = glob.glob(f"{non_pruned_index_path}*")
print(f" Found files: {files}")
# Try to find the actual index file
if os.path.exists(f"{non_pruned_index_path}.index"):
non_pruned_index_file = f"{non_pruned_index_path}.index"
else:
# Look for any .index file in the directory
index_files = glob.glob(f"{non_pruned_dir}/*.index")
if index_files:
non_pruned_index_file = index_files[0]
else:
raise FileNotFoundError(f"No .index file found in {non_pruned_dir}")
non_pruned_size = get_file_size(non_pruned_index_file)
print(f" ✅ Non-pruned index built successfully")
print(f" 📏 Non-pruned index size: {non_pruned_size:,} bytes ({non_pruned_size/1024:.1f} KB)")
# Compare sizes
print("\n📊 Comparison Results:")
print("=" * 30)
size_diff = non_pruned_size - pruned_size
size_ratio = pruned_size / non_pruned_size if non_pruned_size > 0 else 0
reduction_percent = (1 - size_ratio) * 100
print(f"Non-pruned index: {non_pruned_size:,} bytes ({non_pruned_size/1024:.1f} KB)")
print(f"Pruned index: {pruned_size:,} bytes ({pruned_size/1024:.1f} KB)")
print(f"Size difference: {size_diff:,} bytes ({size_diff/1024:.1f} KB)")
print(f"Size ratio: {size_ratio:.3f}")
print(f"Size reduction: {reduction_percent:.1f}%")
# Verify pruning effectiveness
print("\n🔍 Verification:")
if size_diff > 0:
print(" ✅ Pruning is effective - pruned index is smaller")
if reduction_percent > 10:
print(f" ✅ Significant size reduction: {reduction_percent:.1f}%")
else:
print(f" ⚠️ Small size reduction: {reduction_percent:.1f}%")
else:
print(" ❌ Pruning appears ineffective - no size reduction")
# Check if passages files were created
pruned_passages = f"{pruned_index_path}.passages.json"
non_pruned_passages = f"{non_pruned_index_path}.passages.json"
print(f"\n📄 Passages files:")
print(f" Pruned passages file exists: {os.path.exists(pruned_passages)}")
print(f" Non-pruned passages file exists: {os.path.exists(non_pruned_passages)}")
return True
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)

View File

@@ -1,127 +0,0 @@
#!/usr/bin/env python3
"""
验证DiskANN L2距离是否真正工作
"""
import numpy as np
from pathlib import Path
import shutil
# 导入后端包以触发插件注册
try:
import leann_backend_diskann
print("INFO: Backend packages imported successfully.")
except ImportError as e:
print(f"WARNING: Could not import backend packages. Error: {e}")
from leann.api import LeannBuilder, LeannSearcher
def test_l2_verification():
"""验证L2距离是否真正被使用"""
print("=== 验证DiskANN L2距离实现 ===")
INDEX_DIR = Path("./test_l2_verification")
INDEX_PATH = str(INDEX_DIR / "documents.diskann")
if INDEX_DIR.exists():
shutil.rmtree(INDEX_DIR)
# 创建特殊的测试文档使L2和cosine产生不同结果
documents = [
"machine learning artificial intelligence", # 文档0
"computer programming software development", # 文档1
"data science analytics statistics" # 文档2
]
print("构建索引...")
builder = LeannBuilder(
backend_name="diskann",
distance_metric="l2", # 明确指定L2
graph_degree=16,
complexity=32
)
for i, doc in enumerate(documents):
builder.add_text(doc, metadata={"id": i, "text": doc})
builder.build_index(INDEX_PATH)
print("✅ 索引构建完成")
# 测试搜索
searcher = LeannSearcher(INDEX_PATH, distance_metric="l2")
# 用一个与文档0非常相似的查询
query = "machine learning AI technology"
results = searcher.search(query, top_k=3)
print(f"\n查询: '{query}'")
print("L2距离搜索结果:")
for i, result in enumerate(results):
print(f" {i+1}. ID:{result['id']}, Score:{result['score']:.6f}")
print(f" Text: {result['text']}")
# 现在用cosine重新测试同样的数据
print(f"\n--- 用Cosine距离对比测试 ---")
INDEX_DIR_COS = Path("./test_cosine_verification")
INDEX_PATH_COS = str(INDEX_DIR_COS / "documents.diskann")
if INDEX_DIR_COS.exists():
shutil.rmtree(INDEX_DIR_COS)
builder_cos = LeannBuilder(
backend_name="diskann",
distance_metric="cosine", # 使用cosine
graph_degree=16,
complexity=32
)
for i, doc in enumerate(documents):
builder_cos.add_text(doc, metadata={"id": i, "text": doc})
builder_cos.build_index(INDEX_PATH_COS)
searcher_cos = LeannSearcher(INDEX_PATH_COS, distance_metric="cosine")
results_cos = searcher_cos.search(query, top_k=3)
print("Cosine距离搜索结果:")
for i, result in enumerate(results_cos):
print(f" {i+1}. ID:{result['id']}, Score:{result['score']:.6f}")
print(f" Text: {result['text']}")
# 对比分析
print(f"\n--- 结果对比分析 ---")
print("L2距离的分数是欧几里得距离平方越小越相似")
print("Cosine距离的分数是余弦相似度的负值越小越相似")
l2_top = results[0]
cos_top = results_cos[0]
print(f"L2最佳匹配: ID{l2_top['id']}, Score={l2_top['score']:.6f}")
print(f"Cosine最佳匹配: ID{cos_top['id']}, Score={cos_top['score']:.6f}")
if l2_top['id'] == cos_top['id']:
print("✅ 两种距离函数返回相同的最佳匹配")
else:
print("⚠️ 两种距离函数返回不同的最佳匹配 - 这表明它们确实使用了不同的距离计算")
# 验证分数范围的合理性
l2_scores = [r['score'] for r in results]
cos_scores = [r['score'] for r in results_cos]
print(f"L2分数范围: {min(l2_scores):.6f}{max(l2_scores):.6f}")
print(f"Cosine分数范围: {min(cos_scores):.6f}{max(cos_scores):.6f}")
# L2分数应该是正数cosine分数应该在-1到0之间因为是负的相似度
if all(score >= 0 for score in l2_scores):
print("✅ L2分数都是正数符合预期")
else:
print("❌ L2分数有负数可能有问题")
if all(-1 <= score <= 0 for score in cos_scores):
print("✅ Cosine分数在合理范围内")
else:
print(f"⚠️ Cosine分数超出预期范围: {cos_scores}")
if __name__ == "__main__":
test_l2_verification()

View File

@@ -1,190 +0,0 @@
#!/usr/bin/env python3
"""
Sanity check script for Leann DiskANN backend
Tests different distance functions and embedding models
"""
import os
import numpy as np
from pathlib import Path
import shutil
import time
# 导入后端包以触发插件注册
import sys
sys.path.append('packages/leann-core/src')
sys.path.append('packages/leann-backend-diskann')
sys.path.append('packages/leann-backend-hnsw')
try:
import leann_backend_diskann
import leann_backend_hnsw
print("INFO: Backend packages imported successfully.")
except ImportError as e:
print(f"WARNING: Could not import backend packages. Error: {e}")
# 从 leann-core 导入上层 API
from leann.api import LeannBuilder, LeannSearcher
def test_distance_functions():
"""测试不同的距离函数"""
print("\n=== 测试不同距离函数 ===")
# 测试数据
documents = [
"Machine learning is a powerful technology",
"Deep learning uses neural networks",
"Artificial intelligence transforms industries"
]
distance_functions = ["mips", "l2", "cosine"]
for distance_func in distance_functions:
print(f"\n[测试 {distance_func} 距离函数]")
try:
index_path = f"test_indices/test_{distance_func}.diskann"
if Path(index_path).parent.exists():
shutil.rmtree(Path(index_path).parent)
# 构建索引
builder = LeannBuilder(
backend_name="diskann",
distance_metric=distance_func,
graph_degree=16,
complexity=32
)
for doc in documents:
builder.add_text(doc)
builder.build_index(index_path)
# 测试搜索
searcher = LeannSearcher(index_path, distance_metric=distance_func)
results = searcher.search("neural network technology", top_k=2)
print(f"{distance_func} 距离函数工作正常")
for i, result in enumerate(results):
print(f" {i+1}. Score: {result['score']:.4f}, Text: {result['text'][:50]}...")
except Exception as e:
print(f"{distance_func} 距离函数失败: {e}")
def test_embedding_models():
"""测试不同的embedding模型"""
print("\n=== 测试不同Embedding模型 ===")
documents = ["AI is transforming the world", "Technology advances rapidly"]
# 测试不同的embedding模型
models_to_test = [
"sentence-transformers/all-mpnet-base-v2",
"sentence-transformers/all-MiniLM-L6-v2",
# "sentence-transformers/distilbert-base-nli-mean-tokens", # 可能不存在
]
for model_name in models_to_test:
print(f"\n[测试 {model_name}]")
try:
index_path = f"test_indices/test_model.diskann"
if Path(index_path).parent.exists():
shutil.rmtree(Path(index_path).parent)
# 构建索引
builder = LeannBuilder(
backend_name="diskann",
embedding_model=model_name,
distance_metric="cosine"
)
for doc in documents:
builder.add_text(doc)
builder.build_index(index_path)
# 测试搜索
searcher = LeannSearcher(index_path)
results = searcher.search("artificial intelligence", top_k=1)
print(f"{model_name} 模型工作正常")
print(f" 结果: {results[0]['text'][:50]}...")
except Exception as e:
print(f"{model_name} 模型失败: {e}")
def test_search_correctness():
"""验证搜索结果的正确性"""
print("\n=== 验证搜索结果正确性 ===")
# 创建有明确相关性的测试文档
documents = [
"Python is a programming language used for machine learning", # 与编程相关
"Dogs are loyal pets that love to play fetch", # 与动物相关
"Machine learning algorithms can predict future trends", # 与ML相关
"Cats are independent animals that sleep a lot", # 与动物相关
"Deep learning neural networks process complex data" # 与ML相关
]
try:
index_path = "test_indices/correctness_test.diskann"
if Path(index_path).parent.exists():
shutil.rmtree(Path(index_path).parent)
# 构建索引
builder = LeannBuilder(
backend_name="diskann",
distance_metric="cosine"
)
for doc in documents:
builder.add_text(doc)
builder.build_index(index_path)
# 测试相关性查询
searcher = LeannSearcher(index_path)
test_queries = [
("machine learning programming", [0, 2, 4]), # 应该返回ML相关文档
("pet animals behavior", [1, 3]), # 应该返回动物相关文档
]
for query, expected_topics in test_queries:
print(f"\n查询: '{query}'")
results = searcher.search(query, top_k=3)
print("搜索结果:")
for i, result in enumerate(results):
print(f" {i+1}. ID:{result['id']}, Score:{result['score']:.4f}")
print(f" Text: {result['text'][:60]}...")
# 简单验证:检查前两个结果是否在预期范围内
top_ids = [result['id'] for result in results[:2]]
relevant_found = any(id in expected_topics for id in top_ids)
if relevant_found:
print("✅ 搜索结果相关性正确")
else:
print("⚠️ 搜索结果相关性可能有问题")
except Exception as e:
print(f"❌ 正确性测试失败: {e}")
def main():
print("🔍 Leann DiskANN Sanity Check")
print("=" * 50)
# 清理旧的测试数据
if Path("test_indices").exists():
shutil.rmtree("test_indices")
# 运行测试
test_distance_functions()
test_embedding_models()
test_search_correctness()
print("\n" + "=" * 50)
print("🎉 Sanity check 完成!")
if __name__ == "__main__":
main()

898
uv.lock generated
View File

File diff suppressed because it is too large Load Diff