Compare commits: debug/clea... → refactor-a... (38 commits)

SHA1: 0877960547, d68af63d05, b844aca968, 85277ba67a, e9562acdc2, 7fd3db1ddb, c1ccc51a75, b0239b6e4d, 58556ef44c, 87c930d705, 86f919a6da, f8d34663b4, 568cf597f4, baf70dc411, 7ad2ec39d6, 31fd3c816a, 1f6c7f2f5a, c1124eb349, 274bbb19ea, 8c152c7a31, ce77eef13a, 9d77175ac8, 7fbb6c98ef, 914a248c28, 55fc5862f9, fd97b8dfa8, 57959947a1, cc0c091ca5, ff389c7d8d, 6780a8eaba, 984056f126, bd4451bf50, 34e313f64a, ddc789b231, ff1b622bdd, 3cde4fc7b3, 4e3bcda5fa, 46f6f76fc3
.github/workflows/link-check.yml (vendored, new file, 19 lines)
@@ -0,0 +1,19 @@
name: Link Check

on:
  push:
    branches: [ main, master ]
  pull_request:
  schedule:
    - cron: "0 3 * * 1"

jobs:
  link-check:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: lycheeverse/lychee-action@v2
        with:
          args: --no-progress --insecure README.md docs/ apps/ examples/ benchmarks/
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
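For reference, the same check can presumably be reproduced locally with the lychee CLI, reusing the arguments from the workflow above (assuming lychee is installed on the machine):

```bash
# Run roughly the same link check locally (hypothetical local equivalent of the CI job)
lychee --no-progress --insecure README.md docs/ apps/ examples/ benchmarks/
```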
.gitignore (vendored, 14 lines changed)
@@ -34,11 +34,15 @@ build/
nprobe_logs/
micro/results
micro/contriever-INT8
examples/data/*
!examples/data/2501.14312v1 (1).pdf
!examples/data/2506.08276v1.pdf
!examples/data/PrideandPrejudice.txt
!examples/data/README.md
data/*
!data/2501.14312v1 (1).pdf
!data/2506.08276v1.pdf
!data/PrideandPrejudice.txt
!data/README.md
!data/ground_truth/
!data/indices/
!data/queries/
!data/.gitattributes
*.qdstrm
benchmark_results/
results/
README.md (238 lines changed)
@@ -41,40 +41,40 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
|
||||
|
||||
## Installation
|
||||
|
||||
<details>
|
||||
<summary><strong>📦 Prerequisites: Install uv (if you don't have it)</strong></summary>
|
||||
### 📦 Prerequisites: Install uv
|
||||
|
||||
Install uv first if you don't have it:
|
||||
[Install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) first if you don't have it. Typically, you can install it with:
|
||||
|
||||
```bash
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
```
|
||||
|
||||
📖 [Detailed uv installation methods →](https://docs.astral.sh/uv/getting-started/installation/#installation-methods)
|
||||
### 🚀 Quick Install
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
LEANN provides two installation methods: **pip install** (quick and easy) and **build from source** (recommended for development).
|
||||
|
||||
|
||||
|
||||
### 🚀 Quick Install (Recommended for most users)
|
||||
|
||||
Clone the repository to access all examples and install LEANN from [PyPI](https://pypi.org/project/leann/) to run them immediately:
|
||||
Clone the repository to access all examples and try amazing applications,
|
||||
|
||||
```bash
|
||||
git clone git@github.com:yichuan-w/LEANN.git leann
|
||||
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||
cd leann
|
||||
```
|
||||
|
||||
and install LEANN from [PyPI](https://pypi.org/project/leann/) to run them immediately:
|
||||
|
||||
```bash
|
||||
uv venv
|
||||
source .venv/bin/activate
|
||||
uv pip install leann
|
||||
```
|
||||
|
||||
### 🔧 Build from Source (Recommended for development)
|
||||
<details>
|
||||
<summary>
|
||||
<strong>🔧 Build from Source (Recommended for development)</strong>
|
||||
</summary>
|
||||
|
||||
|
||||
|
||||
```bash
|
||||
git clone git@github.com:yichuan-w/LEANN.git leann
|
||||
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||
cd leann
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
@@ -91,14 +91,14 @@ sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev l
|
||||
uv sync
|
||||
```
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## Quick Start
|
||||
|
||||
Our declarative API makes RAG as easy as writing a config file.
|
||||
|
||||
[](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb) [Try in this ipynb file →](demo.ipynb)
|
||||
Check out [demo.ipynb](demo.ipynb) or [](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
|
||||
|
||||
```python
|
||||
from leann import LeannBuilder, LeannSearcher, LeannChat
|
||||
@@ -122,11 +122,11 @@ response = chat.ask("How much storage does LEANN save?", top_k=1)
|
||||
|
||||
## RAG on Everything!
|
||||
|
||||
LEANN supports RAG on various data sources including documents (.pdf, .txt, .md), Apple Mail, Google Search History, WeChat, and more.
|
||||
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
||||
|
||||
### Generation Model Setup
|
||||
|
||||
> **Generation Model Setup**
|
||||
> LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
|
||||
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
|
||||
|
||||
<details>
|
||||
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
|
||||
@@ -166,7 +166,49 @@ ollama pull llama3.2:1b
|
||||
|
||||
</details>
|
||||
|
||||
### 📄 Personal Data Manager: Process Any Documents (.pdf, .txt, .md)!
|
||||
### Flexible Configuration
|
||||
|
||||
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Common Parameters (Available in All Examples)</strong></summary>
|
||||
|
||||
All RAG examples share these common parameters. **Interactive mode** is available in all examples - simply run without `--query` to start a continuous Q&A session where you can ask multiple questions. Type 'quit' to exit.
|
||||
|
||||
```bash
|
||||
# Core Parameters (General preprocessing for all examples)
|
||||
--index-dir DIR # Directory to store the index (default: current directory)
|
||||
--query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively
|
||||
--max-items N # Limit data preprocessing (default: -1, process all data)
|
||||
--force-rebuild # Force rebuild index even if it exists
|
||||
|
||||
# Embedding Parameters
|
||||
--embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small or mlx-community/multilingual-e5-base-mlx
|
||||
--embedding-mode MODE # sentence-transformers, openai, or mlx
|
||||
|
||||
# LLM Parameters (Text generation models)
|
||||
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
|
||||
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
|
||||
|
||||
# Search Parameters
|
||||
--top-k N # Number of results to retrieve (default: 20)
|
||||
--search-complexity N # Search complexity for graph traversal (default: 32)
|
||||
|
||||
# Chunking Parameters
|
||||
--chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat)
|
||||
--chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source)
|
||||
|
||||
# Index Building Parameters
|
||||
--backend-name NAME # Backend to use: hnsw or diskann (default: hnsw)
|
||||
--graph-degree N # Graph degree for index construction (default: 32)
|
||||
--build-complexity N # Build complexity for index construction (default: 64)
|
||||
--no-compact # Disable compact index storage (compact storage IS enabled to save storage by default)
|
||||
--no-recompute # Disable embedding recomputation (recomputation IS enabled to save storage by default)
|
||||
```
|
||||
|
||||
</details>
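Taken together, a hypothetical invocation that combines several of the shared flags above might look like the sketch below; the model, backend, and search settings are illustrative choices, not defaults.

```bash
# Illustrative combination of the common flags documented above
python -m apps.document_rag \
  --data-dir ./my_documents \
  --embedding-model facebook/contriever \
  --llm ollama --llm-model llama3.2:1b \
  --backend-name hnsw --graph-degree 32 \
  --top-k 10 --search-complexity 32
```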
|
||||
|
||||
### 📄 Personal Data Manager: Process Any Documents (`.pdf`, `.txt`, `.md`)!
|
||||
|
||||
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
|
||||
|
||||
@@ -174,25 +216,29 @@ Ask questions directly about your personal PDFs, documents, and any directory co
|
||||
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
The example below asks a question about summarizing two papers (uses default data in `examples/data`) and this is the easiest example to run here:
|
||||
The example below asks a question about summarizing our paper (uses default data in `data/`, which is a directory with diverse data sources: two papers, Pride and Prejudice, and a README in Chinese) and this is the **easiest example** to run here:
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
python ./examples/main_cli_example.py
|
||||
source .venv/bin/activate # Don't forget to activate the virtual environment
|
||||
python -m apps.document_rag --query "What are the main techniques LEANN explores?"
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: User Configurable Arguments</strong></summary>
|
||||
<summary><strong>📋 Click to expand: Document-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
# Use custom index directory
|
||||
python examples/main_cli_example.py --index-dir "./my_custom_index"
|
||||
--data-dir DIR # Directory containing documents to process (default: data)
|
||||
--file-types .ext .ext # Filter by specific file types (optional - all LlamaIndex supported types if omitted)
|
||||
```
|
||||
|
||||
# Use custom data directory
|
||||
python examples/main_cli_example.py --data-dir "./my_documents"
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Process all documents with larger chunks for academic papers
|
||||
python -m apps.document_rag --data-dir "~/Documents/Papers" --chunk-size 1024
|
||||
|
||||
# Ask a specific question
|
||||
python examples/main_cli_example.py --query "What are the main findings in these papers?"
|
||||
# Filter only markdown and Python files with smaller chunks
|
||||
python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -206,30 +252,29 @@ python examples/main_cli_example.py --query "What are the main findings in these
|
||||
<img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
**Note:** You need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
|
||||
Before running the example below, you need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
|
||||
|
||||
```bash
|
||||
python examples/mail_reader_leann.py --query "What's the food I ordered by DoorDash or Uber Eats mostly?"
|
||||
python -m apps.email_rag --query "What's the food I ordered by DoorDash or Uber Eats mostly?"
|
||||
```
|
||||
**780K email chunks → 78MB storage.** Finally, search your email like you search Google.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: User Configurable Arguments</strong></summary>
|
||||
<summary><strong>📋 Click to expand: Email-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
# Use default mail path (works for most macOS setups)
|
||||
python examples/mail_reader_leann.py
|
||||
--mail-path PATH # Path to specific mail directory (auto-detects if omitted)
|
||||
--include-html # Include HTML content in processing (useful for newsletters)
|
||||
```
|
||||
|
||||
# Run with custom index directory
|
||||
python examples/mail_reader_leann.py --index-dir "./my_mail_index"
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Search work emails from a specific account
|
||||
python -m apps.email_rag --mail-path "~/Library/Mail/V10/WORK_ACCOUNT"
|
||||
|
||||
# Process all emails (may take time but indexes everything)
|
||||
python examples/mail_reader_leann.py --max-emails -1
|
||||
|
||||
# Limit number of emails processed (useful for testing)
|
||||
python examples/mail_reader_leann.py --max-emails 1000
|
||||
|
||||
# Run a single query
|
||||
python examples/mail_reader_leann.py --query "What did my boss say about deadlines?"
|
||||
# Find all receipts and order confirmations (includes HTML)
|
||||
python -m apps.email_rag --query "receipt order confirmation invoice" --include-html
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -250,25 +295,25 @@ Once the index is built, you can ask questions like:
|
||||
</p>
|
||||
|
||||
```bash
|
||||
python examples/google_history_reader_leann.py --query "Tell me my browser history about machine learning?"
|
||||
python -m apps.browser_rag --query "Tell me my browser history about machine learning?"
|
||||
```
|
||||
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: User Configurable Arguments</strong></summary>
|
||||
<summary><strong>📋 Click to expand: Browser-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
# Use default Chrome profile (auto-finds all profiles)
|
||||
python examples/google_history_reader_leann.py
|
||||
--chrome-profile PATH # Path to Chrome profile directory (auto-detects if omitted)
|
||||
```
|
||||
|
||||
# Run with custom index directory
|
||||
python examples/google_history_reader_leann.py --index-dir "./my_chrome_index"
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Search academic research from your browsing history
|
||||
python -m apps.browser_rag --query "arxiv papers machine learning transformer architecture"
|
||||
|
||||
# Limit number of history entries processed (useful for testing)
|
||||
python examples/google_history_reader_leann.py --max-entries 500
|
||||
|
||||
# Run a single query
|
||||
python examples/google_history_reader_leann.py --query "What websites did I visit about machine learning?"
|
||||
# Track competitor analysis across work profile
|
||||
python -m apps.browser_rag --chrome-profile "~/Library/Application Support/Google/Chrome/Work Profile" --max-items 5000
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -308,7 +353,7 @@ Once the index is built, you can ask questions like:
|
||||
</p>
|
||||
|
||||
```bash
|
||||
python examples/wechat_history_reader_leann.py --query "Show me all group chats about weekend plans"
|
||||
python -m apps.wechat_rag --query "Show me all group chats about weekend plans"
|
||||
```
|
||||
**400K messages → 64MB storage.** Search years of chat history in any language.
|
||||
|
||||
@@ -316,7 +361,13 @@ python examples/wechat_history_reader_leann.py --query "Show me all group chats
|
||||
<details>
|
||||
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
|
||||
|
||||
First, you need to install the WeChat exporter:
|
||||
First, you need to install the [WeChat exporter](https://github.com/sunnyyoung/WeChatTweak-CLI),
|
||||
|
||||
```bash
|
||||
brew install sunnyyoung/repo/wechattweak-cli
|
||||
```
|
||||
|
||||
or install it manually (if you have issues with Homebrew):
|
||||
|
||||
```bash
|
||||
sudo packages/wechat-exporter/wechattweak-cli install
|
||||
@@ -325,30 +376,28 @@ sudo packages/wechat-exporter/wechattweak-cli install
|
||||
**Troubleshooting:**
|
||||
- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
|
||||
- **Export errors**: If you encounter the error below, try restarting WeChat
|
||||
```
|
||||
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
|
||||
Failed to find or export WeChat data. Exiting.
|
||||
```
|
||||
```bash
|
||||
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
|
||||
Failed to find or export WeChat data. Exiting.
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: User Configurable Arguments</strong></summary>
|
||||
<summary><strong>📋 Click to expand: WeChat-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
# Use default settings (recommended for first run)
|
||||
python examples/wechat_history_reader_leann.py
|
||||
--export-dir DIR # Directory to store exported WeChat data (default: wechat_export_direct)
|
||||
--force-export # Force re-export even if data exists
|
||||
```
|
||||
|
||||
# Run with a custom export directory; on the first run, LEANN exports all chat history automatically for you
|
||||
python examples/wechat_history_reader_leann.py --export-dir "./my_wechat_exports"
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Search for travel plans discussed in group chats
|
||||
python -m apps.wechat_rag --query "travel plans" --max-items 10000
|
||||
|
||||
# Run with custom index directory
|
||||
python examples/wechat_history_reader_leann.py --index-dir "./my_wechat_index"
|
||||
|
||||
# Limit number of chat entries processed (useful for testing)
|
||||
python examples/wechat_history_reader_leann.py --max-entries 1000
|
||||
|
||||
# Run a single query
|
||||
python examples/wechat_history_reader_leann.py --query "Show me conversations about travel plans"
|
||||
# Re-export and search recent chats (useful after new messages)
|
||||
python -m apps.wechat_rag --force-export --query "work schedule"
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -368,6 +417,27 @@ Once the index is built, you can ask questions like:
|
||||
|
||||
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
|
||||
|
||||
### Installation
|
||||
|
||||
If you followed the Quick Start, `leann` is already installed in your virtual environment:
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
leann --help
|
||||
```
|
||||
|
||||
**To make it globally available (recommended for daily use):**
|
||||
```bash
|
||||
# Install the LEANN CLI globally using uv tool
|
||||
uv tool install leann
|
||||
|
||||
# Now you can use leann from anywhere without activating venv
|
||||
leann --help
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Usage Examples
|
||||
|
||||
```bash
|
||||
# Build an index from documents
|
||||
leann build my-docs --docs ./documents
|
||||
@@ -449,8 +519,8 @@ Options:
|
||||
## Benchmarks
|
||||
|
||||
|
||||
📊 **[Simple Example: Compare LEANN vs FAISS →](examples/compare_faiss_vs_leann.py)**
|
||||
### Storage Comparison
|
||||
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)**
|
||||
### 📊 Storage Comparison
|
||||
|
||||
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|
||||
|--------|-------------|------------|-------------|--------------|---------------|
|
||||
@@ -464,8 +534,8 @@ Options:
|
||||
|
||||
```bash
|
||||
uv pip install -e ".[dev]" # Install dev dependencies
|
||||
python examples/run_evaluation.py data/indices/dpr/dpr_diskann # DPR dataset
|
||||
python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
|
||||
python benchmarks/run_evaluation.py data/indices/dpr/dpr_diskann # DPR dataset
|
||||
python benchmarks/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
|
||||
```
|
||||
|
||||
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
||||
@@ -503,7 +573,9 @@ MIT License - see [LICENSE](LICENSE) for details.
|
||||
|
||||
## 🙏 Acknowledgments
|
||||
|
||||
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/)
|
||||
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
|
||||
|
||||
|
||||
---
|
||||
|
||||
<p align="center">
|
||||
|
||||
apps/__init__.py (new file, empty)
apps/base_rag_example.py (new file, 296 lines)
@@ -0,0 +1,296 @@
|
||||
"""
|
||||
Base class for unified RAG examples interface.
|
||||
Provides common parameters and functionality for all RAG examples.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
class BaseRAGExample(ABC):
|
||||
"""Base class for all RAG examples with unified interface."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
default_index_name: str,
|
||||
):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.default_index_name = default_index_name
|
||||
self.parser = self._create_parser()
|
||||
|
||||
def _create_parser(self) -> argparse.ArgumentParser:
|
||||
"""Create argument parser with common parameters."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description=self.description, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
|
||||
# Core parameters (all examples share these)
|
||||
core_group = parser.add_argument_group("Core Parameters")
|
||||
core_group.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default=f"./{self.default_index_name}",
|
||||
help=f"Directory to store the index (default: ./{self.default_index_name})",
|
||||
)
|
||||
core_group.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Query to run (if not provided, will run in interactive mode)",
|
||||
)
|
||||
# Allow subclasses to override default max_items
|
||||
max_items_default = getattr(self, "max_items_default", -1)
|
||||
core_group.add_argument(
|
||||
"--max-items",
|
||||
type=int,
|
||||
default=max_items_default,
|
||||
help="Maximum number of items to process -1 for all, means index all documents, and you should set it to a reasonable number if you have a large dataset and try at the first time)",
|
||||
)
|
||||
core_group.add_argument(
|
||||
"--force-rebuild", action="store_true", help="Force rebuild index even if it exists"
|
||||
)
|
||||
|
||||
# Embedding parameters
|
||||
embedding_group = parser.add_argument_group("Embedding Parameters")
|
||||
# Allow subclasses to override default embedding_model
|
||||
embedding_model_default = getattr(self, "embedding_model_default", "facebook/contriever")
|
||||
embedding_group.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default=embedding_model_default,
|
||||
help=f"Embedding model to use (default: {embedding_model_default})",
|
||||
)
|
||||
embedding_group.add_argument(
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx"],
|
||||
help="Embedding backend mode (default: sentence-transformers)",
|
||||
)
|
||||
|
||||
# LLM parameters
|
||||
llm_group = parser.add_argument_group("LLM Parameters")
|
||||
llm_group.add_argument(
|
||||
"--llm",
|
||||
type=str,
|
||||
default="openai",
|
||||
choices=["openai", "ollama", "hf"],
|
||||
help="LLM backend to use (default: openai)",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="LLM model name (default: gpt-4o for openai, llama3.2:1b for ollama)",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-host",
|
||||
type=str,
|
||||
default="http://localhost:11434",
|
||||
help="Host for Ollama API (default: http://localhost:11434)",
|
||||
)
|
||||
|
||||
# Search parameters
|
||||
search_group = parser.add_argument_group("Search Parameters")
|
||||
search_group.add_argument(
|
||||
"--top-k", type=int, default=20, help="Number of results to retrieve (default: 20)"
|
||||
)
|
||||
search_group.add_argument(
|
||||
"--search-complexity",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Search complexity for graph traversal (default: 64)",
|
||||
)
|
||||
|
||||
# Index building parameters
|
||||
index_group = parser.add_argument_group("Index Building Parameters")
|
||||
index_group.add_argument(
|
||||
"--backend-name",
|
||||
type=str,
|
||||
default="hnsw",
|
||||
choices=["hnsw", "diskann"],
|
||||
help="Backend to use for index (default: hnsw)",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--graph-degree",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Graph degree for index construction (default: 32)",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--build-complexity",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Build complexity for index construction (default: 64)",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--no-compact",
|
||||
action="store_true",
|
||||
help="Disable compact index storage",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--no-recompute",
|
||||
action="store_true",
|
||||
help="Disable embedding recomputation",
|
||||
)
|
||||
|
||||
# Add source-specific parameters
|
||||
self._add_specific_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
@abstractmethod
|
||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||
"""Add source-specific arguments. Override in subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load data from the source. Returns list of text chunks."""
|
||||
pass
|
||||
|
||||
def get_llm_config(self, args) -> dict[str, Any]:
|
||||
"""Get LLM configuration based on arguments."""
|
||||
config = {"type": args.llm}
|
||||
|
||||
if args.llm == "openai":
|
||||
config["model"] = args.llm_model or "gpt-4o"
|
||||
elif args.llm == "ollama":
|
||||
config["model"] = args.llm_model or "llama3.2:1b"
|
||||
config["host"] = args.llm_host
|
||||
elif args.llm == "hf":
|
||||
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
|
||||
return config
|
||||
|
||||
async def build_index(self, args, texts: list[str]) -> str:
|
||||
"""Build LEANN index from texts."""
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
|
||||
print(f"\n[Building Index] Creating {self.name} index...")
|
||||
print(f"Total text chunks: {len(texts)}")
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend_name,
|
||||
embedding_model=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
graph_degree=args.graph_degree,
|
||||
complexity=args.build_complexity,
|
||||
is_compact=not args.no_compact,
|
||||
is_recompute=not args.no_recompute,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
# Add texts in batches for better progress tracking
|
||||
batch_size = 1000
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
for text in batch:
|
||||
builder.add_text(text)
|
||||
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
|
||||
|
||||
print("Building index structure...")
|
||||
builder.build_index(index_path)
|
||||
print(f"Index saved to: {index_path}")
|
||||
|
||||
return index_path
|
||||
|
||||
async def run_interactive_chat(self, args, index_path: str):
|
||||
"""Run interactive chat with the index."""
|
||||
chat = LeannChat(
|
||||
index_path,
|
||||
llm_config=self.get_llm_config(args),
|
||||
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
||||
complexity=args.search_complexity,
|
||||
)
|
||||
|
||||
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
|
||||
print("Type 'quit' or 'exit' to stop.\n")
|
||||
|
||||
while True:
|
||||
try:
|
||||
query = input("You: ").strip()
|
||||
if query.lower() in ["quit", "exit", "q"]:
|
||||
print("Goodbye!")
|
||||
break
|
||||
|
||||
if not query:
|
||||
continue
|
||||
|
||||
response = chat.ask(query, top_k=args.top_k, complexity=args.search_complexity)
|
||||
print(f"\nAssistant: {response}\n")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
async def run_single_query(self, args, index_path: str, query: str):
|
||||
"""Run a single query against the index."""
|
||||
chat = LeannChat(
|
||||
index_path,
|
||||
llm_config=self.get_llm_config(args),
|
||||
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
||||
complexity=args.search_complexity,
|
||||
)
|
||||
|
||||
print(f"\n[Query]: \033[36m{query}\033[0m")
|
||||
response = chat.ask(query, top_k=args.top_k, complexity=args.search_complexity)
|
||||
print(f"\n[Response]: \033[36m{response}\033[0m")
|
||||
|
||||
async def run(self):
|
||||
"""Main entry point for the example."""
|
||||
args = self.parser.parse_args()
|
||||
|
||||
# Check if index exists
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
index_exists = Path(args.index_dir).exists()
|
||||
|
||||
if not index_exists or args.force_rebuild:
|
||||
# Load data and build index
|
||||
print(f"\n{'Rebuilding' if index_exists else 'Building'} index...")
|
||||
texts = await self.load_data(args)
|
||||
|
||||
if not texts:
|
||||
print("No data found to index!")
|
||||
return
|
||||
|
||||
index_path = await self.build_index(args, texts)
|
||||
else:
|
||||
print(f"\nUsing existing index in {args.index_dir}")
|
||||
|
||||
# Run query or interactive mode
|
||||
if args.query:
|
||||
await self.run_single_query(args, index_path, args.query)
|
||||
else:
|
||||
await self.run_interactive_chat(args, index_path)
|
||||
|
||||
|
||||
def create_text_chunks(documents, chunk_size=256, chunk_overlap=25) -> list[str]:
|
||||
"""Helper function to create text chunks from documents."""
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
separator=" ",
|
||||
paragraph_separator="\n\n",
|
||||
)
|
||||
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
if nodes:
|
||||
all_texts.extend(node.get_content() for node in nodes)
|
||||
|
||||
return all_texts
|
||||
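As a rough illustration of how this base class is consumed (the apps below follow the same pattern), a minimal hypothetical subclass might look like the sketch below; `NotesRAG`, `--notes-dir`, and `notes_index` are made-up names for the example and are not part of this PR.

```python
# Hypothetical minimal subclass of BaseRAGExample (illustration only, not part of this PR)
import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))  # mirror the import style used by the apps

from base_rag_example import BaseRAGExample, create_text_chunks
from llama_index.core import SimpleDirectoryReader


class NotesRAG(BaseRAGExample):
    """Toy example: index a folder of plain-text notes."""

    def __init__(self):
        super().__init__(
            name="Notes",
            description="Process and query a folder of notes with LEANN",
            default_index_name="notes_index",
        )

    def _add_specific_arguments(self, parser):
        # Source-specific flag for this toy example
        parser.add_argument("--notes-dir", type=str, default="notes")

    async def load_data(self, args) -> list[str]:
        # Read the notes directory and reuse the shared chunking helper
        documents = SimpleDirectoryReader(args.notes_dir, recursive=True).load_data()
        return create_text_chunks(documents)


if __name__ == "__main__":
    asyncio.run(NotesRAG().run())
```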
apps/browser_rag.py (new file, 170 lines)
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Browser History RAG example using the unified interface.
|
||||
Supports Chrome browser history.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||
|
||||
from .history_data.history import ChromeHistoryReader
|
||||
|
||||
|
||||
class BrowserRAG(BaseRAGExample):
|
||||
"""RAG example for Chrome browser history."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="Browser History",
|
||||
description="Process and query Chrome browser history with LEANN",
|
||||
default_index_name="google_history_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add browser-specific arguments."""
|
||||
browser_group = parser.add_argument_group("Browser Parameters")
|
||||
browser_group.add_argument(
|
||||
"--chrome-profile",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Chrome profile directory (auto-detected if not specified)",
|
||||
)
|
||||
browser_group.add_argument(
|
||||
"--auto-find-profiles",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Automatically find all Chrome profiles (default: True)",
|
||||
)
|
||||
browser_group.add_argument(
|
||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||
)
|
||||
browser_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
|
||||
def _get_chrome_base_path(self) -> Path:
|
||||
"""Get the base Chrome profile path based on OS."""
|
||||
if sys.platform == "darwin":
|
||||
return Path.home() / "Library" / "Application Support" / "Google" / "Chrome"
|
||||
elif sys.platform.startswith("linux"):
|
||||
return Path.home() / ".config" / "google-chrome"
|
||||
elif sys.platform == "win32":
|
||||
return Path(os.environ["LOCALAPPDATA"]) / "Google" / "Chrome" / "User Data"
|
||||
else:
|
||||
raise ValueError(f"Unsupported platform: {sys.platform}")
|
||||
|
||||
def _find_chrome_profiles(self) -> list[Path]:
|
||||
"""Auto-detect all Chrome profiles."""
|
||||
base_path = self._get_chrome_base_path()
|
||||
if not base_path.exists():
|
||||
return []
|
||||
|
||||
profiles = []
|
||||
|
||||
# Check Default profile
|
||||
default_profile = base_path / "Default"
|
||||
if default_profile.exists() and (default_profile / "History").exists():
|
||||
profiles.append(default_profile)
|
||||
|
||||
# Check numbered profiles
|
||||
for item in base_path.iterdir():
|
||||
if item.is_dir() and item.name.startswith("Profile "):
|
||||
if (item / "History").exists():
|
||||
profiles.append(item)
|
||||
|
||||
return profiles
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load browser history and convert to text chunks."""
|
||||
# Determine Chrome profiles
|
||||
if args.chrome_profile and not args.auto_find_profiles:
|
||||
profile_dirs = [Path(args.chrome_profile)]
|
||||
else:
|
||||
print("Auto-detecting Chrome profiles...")
|
||||
profile_dirs = self._find_chrome_profiles()
|
||||
|
||||
# If specific profile given, filter to just that one
|
||||
if args.chrome_profile:
|
||||
profile_path = Path(args.chrome_profile)
|
||||
profile_dirs = [p for p in profile_dirs if p == profile_path]
|
||||
|
||||
if not profile_dirs:
|
||||
print("No Chrome profiles found!")
|
||||
print("Please specify --chrome-profile manually")
|
||||
return []
|
||||
|
||||
print(f"Found {len(profile_dirs)} Chrome profiles")
|
||||
|
||||
# Create reader
|
||||
reader = ChromeHistoryReader()
|
||||
|
||||
# Process each profile
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, profile_dir in enumerate(profile_dirs):
|
||||
print(f"\nProcessing profile {i + 1}/{len(profile_dirs)}: {profile_dir.name}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per profile
|
||||
max_per_profile = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_profile = remaining
|
||||
|
||||
# Load history
|
||||
documents = reader.load_data(
|
||||
chrome_profile_path=str(profile_dir),
|
||||
max_count=max_per_profile,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} history entries from this profile")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {profile_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No browser history found to process!")
|
||||
return []
|
||||
|
||||
print(f"\nTotal history entries processed: {len(all_documents)}")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for browser history RAG
|
||||
print("\n🌐 Browser History RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What websites did I visit about machine learning?'")
|
||||
print("- 'Find my search history about programming'")
|
||||
print("- 'What YouTube videos did I watch recently?'")
|
||||
print("- 'Show me websites about travel planning'")
|
||||
print("\nNote: Make sure Chrome is closed before running\n")
|
||||
|
||||
rag = BrowserRAG()
|
||||
asyncio.run(rag.run())
|
||||
apps/document_rag.py (new file, 106 lines)
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
Document RAG example using the unified interface.
|
||||
Supports PDF, TXT, MD, and other document formats.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
|
||||
class DocumentRAG(BaseRAGExample):
|
||||
"""RAG example for document processing (PDF, TXT, MD, etc.)."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Document",
|
||||
description="Process and query documents (PDF, TXT, MD, etc.) with LEANN",
|
||||
default_index_name="test_doc_files",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add document-specific arguments."""
|
||||
doc_group = parser.add_argument_group("Document Parameters")
|
||||
doc_group.add_argument(
|
||||
"--data-dir",
|
||||
type=str,
|
||||
default="data",
|
||||
help="Directory containing documents to index (default: data)",
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--file-types",
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Filter by file types (e.g., .pdf .txt .md). If not specified, all supported types are processed",
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load documents and convert to text chunks."""
|
||||
print(f"Loading documents from: {args.data_dir}")
|
||||
if args.file_types:
|
||||
print(f"Filtering by file types: {args.file_types}")
|
||||
else:
|
||||
print("Processing all supported file types")
|
||||
|
||||
# Check if data directory exists
|
||||
data_path = Path(args.data_dir)
|
||||
if not data_path.exists():
|
||||
raise ValueError(f"Data directory not found: {args.data_dir}")
|
||||
|
||||
# Load documents
|
||||
reader_kwargs = {
|
||||
"recursive": True,
|
||||
"encoding": "utf-8",
|
||||
}
|
||||
if args.file_types:
|
||||
reader_kwargs["required_exts"] = args.file_types
|
||||
|
||||
documents = SimpleDirectoryReader(args.data_dir, **reader_kwargs).load_data(
|
||||
show_progress=True
|
||||
)
|
||||
|
||||
if not documents:
|
||||
print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
|
||||
return []
|
||||
|
||||
print(f"Loaded {len(documents)} documents")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
# Apply max_items limit if specified
|
||||
if args.max_items > 0 and len(all_texts) > args.max_items:
|
||||
print(f"Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
||||
all_texts = all_texts[: args.max_items]
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for document RAG
|
||||
print("\n📄 Document RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What are the main techniques LEANN uses?'")
|
||||
print("- 'What is the technique DLPM?'")
|
||||
print("- 'Who does Elizabeth Bennet marry?'")
|
||||
print("- 'What is the problem of developing pan gu model? (盘古大模型开发中遇到什么问题?)'")
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
rag = DocumentRAG()
|
||||
asyncio.run(rag.run())
|
||||
@@ -52,6 +52,11 @@ class EmlxReader(BaseReader):
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
count = 0
|
||||
total_files = 0
|
||||
successful_files = 0
|
||||
failed_files = 0
|
||||
|
||||
print(f"Starting to process directory: {input_dir}")
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
@@ -59,10 +64,12 @@ class EmlxReader(BaseReader):
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
if count >= max_count:
|
||||
# Check if we've reached the max count (skip if max_count == -1)
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
total_files += 1
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
try:
|
||||
# Read the .emlx file
|
||||
@@ -98,17 +105,26 @@ class EmlxReader(BaseReader):
|
||||
and not self.include_html
|
||||
):
|
||||
continue
|
||||
body += part.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
# break
|
||||
try:
|
||||
payload = part.get_payload(decode=True)
|
||||
if payload:
|
||||
body += payload.decode("utf-8", errors="ignore")
|
||||
except Exception as e:
|
||||
print(f"Error decoding payload: {e}")
|
||||
continue
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
try:
|
||||
payload = msg.get_payload(decode=True)
|
||||
if payload:
|
||||
body = payload.decode("utf-8", errors="ignore")
|
||||
except Exception as e:
|
||||
print(f"Error decoding single part payload: {e}")
|
||||
body = ""
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
# Only create document if we have some content
|
||||
if body.strip() or subject != "No Subject":
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[File]: {filename}
|
||||
[From]: {from_addr}
|
||||
[To]: {to_addr}
|
||||
@@ -118,18 +134,34 @@ class EmlxReader(BaseReader):
|
||||
{body}
|
||||
"""
|
||||
|
||||
# No separate metadata - everything is in the text
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
# No separate metadata - everything is in the text
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
successful_files += 1
|
||||
|
||||
# Print first few successful files for debugging
|
||||
if successful_files <= 3:
|
||||
print(
|
||||
f"Successfully loaded: {filename} - Subject: {subject[:50]}..."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
failed_files += 1
|
||||
if failed_files <= 5: # Only print first few errors
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
failed_files += 1
|
||||
if failed_files <= 5: # Only print first few errors
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Loaded {len(docs)} email documents")
|
||||
print("Processing summary:")
|
||||
print(f" Total .emlx files found: {total_files}")
|
||||
print(f" Successfully loaded: {successful_files}")
|
||||
print(f" Failed to load: {failed_files}")
|
||||
print(f" Final documents: {len(docs)}")
|
||||
|
||||
return docs
|
||||
apps/email_rag.py (new file, 156 lines)
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
Email RAG example using the unified interface.
|
||||
Supports Apple Mail on macOS.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||
|
||||
from .email_data.LEANN_email_reader import EmlxReader
|
||||
|
||||
|
||||
class EmailRAG(BaseRAGExample):
|
||||
"""RAG example for Apple Mail processing."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Process all emails by default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="Email",
|
||||
description="Process and query Apple Mail emails with LEANN",
|
||||
default_index_name="mail_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add email-specific arguments."""
|
||||
email_group = parser.add_argument_group("Email Parameters")
|
||||
email_group.add_argument(
|
||||
"--mail-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Apple Mail directory (auto-detected if not specified)",
|
||||
)
|
||||
email_group.add_argument(
|
||||
"--include-html", action="store_true", help="Include HTML content in email processing"
|
||||
)
|
||||
email_group.add_argument(
|
||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||
)
|
||||
email_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=25, help="Text chunk overlap (default: 25)"
|
||||
)
|
||||
|
||||
def _find_mail_directories(self) -> list[Path]:
|
||||
"""Auto-detect all Apple Mail directories."""
|
||||
mail_base = Path.home() / "Library" / "Mail"
|
||||
if not mail_base.exists():
|
||||
return []
|
||||
|
||||
# Find all Messages directories
|
||||
messages_dirs = []
|
||||
for item in mail_base.rglob("Messages"):
|
||||
if item.is_dir():
|
||||
messages_dirs.append(item)
|
||||
|
||||
return messages_dirs
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load emails and convert to text chunks."""
|
||||
# Determine mail directories
|
||||
if args.mail_path:
|
||||
messages_dirs = [Path(args.mail_path)]
|
||||
else:
|
||||
print("Auto-detecting Apple Mail directories...")
|
||||
messages_dirs = self._find_mail_directories()
|
||||
|
||||
if not messages_dirs:
|
||||
print("No Apple Mail directories found!")
|
||||
print("Please specify --mail-path manually")
|
||||
return []
|
||||
|
||||
print(f"Found {len(messages_dirs)} mail directories")
|
||||
|
||||
# Create reader
|
||||
reader = EmlxReader(include_html=args.include_html)
|
||||
|
||||
# Process each directory
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, messages_dir in enumerate(messages_dirs):
|
||||
print(f"\nProcessing directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
|
||||
|
||||
try:
|
||||
# Count emlx files
|
||||
emlx_files = list(messages_dir.glob("*.emlx"))
|
||||
print(f"Found {len(emlx_files)} email files")
|
||||
|
||||
# Apply max_items limit per directory
|
||||
max_per_dir = -1 # Default to process all
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_dir = remaining
|
||||
# If args.max_items == -1, max_per_dir stays -1 (process all)
|
||||
|
||||
# Load emails - fix the parameter passing
|
||||
documents = reader.load_data(
|
||||
input_dir=str(messages_dir),
|
||||
max_count=max_per_dir,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} emails from this directory")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {messages_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No emails found to process!")
|
||||
return []
|
||||
|
||||
print(f"\nTotal emails processed: {len(all_documents)}")
|
||||
print("now starting to split into text chunks ... take some time")
|
||||
|
||||
# Convert to text chunks
|
||||
# Email reader uses chunk_overlap=25 as in original
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Check platform
|
||||
if sys.platform != "darwin":
|
||||
print("\n⚠️ Warning: This example is designed for macOS (Apple Mail)")
|
||||
print(" Windows/Linux support coming soon!\n")
|
||||
|
||||
# Example queries for email RAG
|
||||
print("\n📧 Email RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What did my boss say about deadlines?'")
|
||||
print("- 'Find emails about travel expenses'")
|
||||
print("- 'Show me emails from last month about the project'")
|
||||
print("- 'What food did I order from DoorDash?'")
|
||||
print("\nNote: You may need to grant Full Disk Access to your terminal\n")
|
||||
|
||||
rag = EmailRAG()
|
||||
asyncio.run(rag.run())
|
||||
@@ -97,6 +97,11 @@ class ChromeHistoryReader(BaseReader):
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading Chrome history: {e}")
|
||||
# Hint (printed in red): the browser may need to be closed before the History database file becomes readable
|
||||
print(
|
||||
"\033[91mYou may need to close your browser to make the database file available\033[0m"
|
||||
)
|
||||
return docs
|
||||
|
||||
return docs
|
||||
@@ -411,8 +411,8 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
||||
wechat_export_dir = load_kwargs.get("wechat_export_dir", None)
|
||||
include_non_text = load_kwargs.get("include_non_text", False)
|
||||
concatenate_messages = load_kwargs.get("concatenate_messages", False)
|
||||
load_kwargs.get("max_length", 1000)
|
||||
load_kwargs.get("time_window_minutes", 30)
|
||||
max_length = load_kwargs.get("max_length", 1000)
|
||||
time_window_minutes = load_kwargs.get("time_window_minutes", 30)
|
||||
|
||||
# Default WeChat export path
|
||||
if wechat_export_dir is None:
|
||||
@@ -460,9 +460,9 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
||||
# Concatenate messages based on rules
|
||||
message_groups = self._concatenate_messages(
|
||||
readable_messages,
|
||||
max_length=-1,
|
||||
time_window_minutes=-1,
|
||||
overlap_messages=0, # Keep 2 messages overlap between groups
|
||||
max_length=max_length,
|
||||
time_window_minutes=time_window_minutes,
|
||||
overlap_messages=0, # No overlap between groups
|
||||
)
|
||||
|
||||
# Create documents from concatenated groups
|
||||
@@ -532,7 +532,9 @@ Message: {readable_text if readable_text else message_text}
|
||||
"""
|
||||
|
||||
# Create document with embedded metadata
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
doc = Document(
|
||||
text=doc_content, metadata={"contact_name": contact_name}
|
||||
)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
@@ -560,8 +562,8 @@ Message: {readable_text if readable_text else message_text}
|
||||
|
||||
# Look for common export directory names
|
||||
possible_dirs = [
|
||||
Path("./wechat_export_test"),
|
||||
Path("./wechat_export"),
|
||||
Path("./wechat_export_direct"),
|
||||
Path("./wechat_chat_history"),
|
||||
Path("./chat_export"),
|
||||
]
|
||||
apps/wechat_rag.py (new file, 189 lines)
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
WeChat History RAG example using the unified interface.
|
||||
Supports WeChat chat history export and search.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
|
||||
from .history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
|
||||
class WeChatRAG(BaseRAGExample):
|
||||
"""RAG example for WeChat chat history."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Match original default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="WeChat History",
|
||||
description="Process and query WeChat chat history with LEANN",
|
||||
default_index_name="wechat_history_magic_test_11Debug_new",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add WeChat-specific arguments."""
|
||||
wechat_group = parser.add_argument_group("WeChat Parameters")
|
||||
wechat_group.add_argument(
|
||||
"--export-dir",
|
||||
type=str,
|
||||
default="./wechat_export",
|
||||
help="Directory to store WeChat exports (default: ./wechat_export)",
|
||||
)
|
||||
wechat_group.add_argument(
|
||||
"--force-export",
|
||||
action="store_true",
|
||||
help="Force re-export of WeChat data even if exports exist",
|
||||
)
|
||||
wechat_group.add_argument(
|
||||
"--chunk-size", type=int, default=192, help="Text chunk size (default: 192)"
|
||||
)
|
||||
wechat_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=64, help="Text chunk overlap (default: 64)"
|
||||
)
|
||||
|
||||
def _export_wechat_data(self, export_dir: Path) -> bool:
|
||||
"""Export WeChat data using wechattweak-cli."""
|
||||
print("Exporting WeChat data...")
|
||||
|
||||
# Check if WeChat is running
|
||||
try:
|
||||
result = subprocess.run(["pgrep", "WeChat"], capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
print("WeChat is not running. Please start WeChat first.")
|
||||
return False
|
||||
except Exception:
|
||||
pass # pgrep might not be available on all systems
|
||||
|
||||
# Create export directory
|
||||
export_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Run export command
|
||||
cmd = ["packages/wechat-exporter/wechattweak-cli", "export", str(export_dir)]
|
||||
|
||||
try:
|
||||
print(f"Running: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
print("WeChat data exported successfully!")
|
||||
return True
|
||||
else:
|
||||
print(f"Export failed: {result.stderr}")
|
||||
return False
|
||||
|
||||
except FileNotFoundError:
|
||||
print("\nError: wechattweak-cli not found!")
|
||||
print("Please install it first:")
|
||||
print(" sudo packages/wechat-exporter/wechattweak-cli install")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Export error: {e}")
|
||||
return False
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load WeChat history and convert to text chunks."""
|
||||
# Initialize WeChat reader with export capabilities
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
# Find existing exports or create new ones using the centralized method
|
||||
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
||||
if not export_dirs:
|
||||
print("Failed to find or export WeChat data. Trying to find any existing exports...")
|
||||
# Try to find any existing exports in common locations
|
||||
export_dirs = reader.find_wechat_export_dirs()
|
||||
if not export_dirs:
|
||||
print("No WeChat data found. Please ensure WeChat exports exist.")
|
||||
return []
|
||||
|
||||
# Load documents from all found export directories
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, export_dir in enumerate(export_dirs):
|
||||
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per export
|
||||
max_per_export = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_export = remaining
|
||||
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=str(export_dir),
|
||||
max_count=max_per_export,
|
||||
concatenate_messages=True, # Enable message concatenation for better context
|
||||
)
|
||||
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
else:
|
||||
print(f"No documents loaded from {export_dir}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return []
|
||||
|
||||
print(f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports")
|
||||
print("now starting to split into text chunks ... take some time")
|
||||
|
||||
# Convert to text chunks with contact information
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
text_splitter = SentenceSplitter(
|
||||
chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
|
||||
for node in nodes:
|
||||
# Add contact information to each chunk
|
||||
contact_name = doc.metadata.get("contact_name", "Unknown")
|
||||
text = f"[Contact] means the message is from: {contact_name}\n" + node.get_content()
|
||||
all_texts.append(text)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Check platform
|
||||
if sys.platform != "darwin":
|
||||
print("\n⚠️ Warning: WeChat export is only supported on macOS")
|
||||
print(" You can still query existing exports on other platforms\n")
|
||||
|
||||
# Example queries for WeChat RAG
|
||||
print("\n💬 WeChat History RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'Show me conversations about travel plans'")
|
||||
print("- 'Find group chats about weekend activities'")
|
||||
print("- '我想买魔术师约翰逊的球衣,给我一些对应聊天记录?'")
|
||||
print("- 'What did we discuss about the project last month?'")
|
||||
print("\nNote: WeChat must be running for export to work\n")
|
||||
|
||||
rag = WeChatRAG()
|
||||
asyncio.run(rag.run())
|
||||
@@ -62,7 +62,7 @@ def test_faiss_hnsw():
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[sys.executable, "examples/faiss_only.py"],
|
||||
[sys.executable, "benchmarks/faiss_only.py"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
@@ -115,7 +115,7 @@ def test_leann_hnsw():
|
||||
|
||||
# Load and parse documents
|
||||
documents = SimpleDirectoryReader(
|
||||
"examples/data",
|
||||
"data",
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
@@ -65,7 +65,7 @@ def main():
|
||||
tracker.checkpoint("After Faiss index creation")
|
||||
|
||||
documents = SimpleDirectoryReader(
|
||||
"examples/data",
|
||||
"data",
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
@@ -200,10 +200,10 @@ def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
# --- Path Configuration ---
|
||||
# Assumes a project structure where the script is in 'examples/'
|
||||
# and data is in 'data/' at the project root.
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
data_root = project_root / "data"
|
||||
# Assumes a project structure where the script is in 'benchmarks/'
|
||||
# and evaluation data is in 'benchmarks/data/'.
|
||||
script_dir = Path(__file__).resolve().parent
|
||||
data_root = script_dir / "data"
|
||||
|
||||
# Download data based on mode
|
||||
if args.mode == "build":
|
||||
@@ -279,7 +279,9 @@ def main():
|
||||
|
||||
if not args.index_path:
|
||||
print("No indices found. The data download should have included pre-built indices.")
|
||||
print("Please check the data/indices/ directory or provide --index-path manually.")
|
||||
print(
|
||||
"Please check the benchmarks/data/indices/ directory or provide --index-path manually."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Detect dataset type from index path to select the correct ground truth
|
||||
110
data/README.md
@@ -1,44 +1,82 @@
|
||||
---
|
||||
license: mit
|
||||
# 盘古之殇:华为诺亚盘古大模型研发历程的心酸与黑暗
|
||||
|
||||
各位好,
|
||||
|
||||
我是一名盘古大模型团队,华为诺亚方舟实验室的员工。
|
||||
|
||||
首先为自证身份,列举一些细节:
|
||||
|
||||
1. 现诺亚主任,前算法应用部部长,后改名为小模型实验室的主任王云鹤。前诺亚主任:姚骏(大家称姚老师)。几个实验室主任:唐睿明(明哥,明队,已离职),尚利峰,张维(维哥),郝建业(郝老师),刘武龙(称呼为武龙所)等。其他骨干成员和专家陆续有很多人离职。
|
||||
2. 我们隶属于“四野”这个组织。四野下属有许多纵队,基础语言大模型是四纵。王云鹤的小模型是十六纵队。我们参加过苏州的集结,有各种月份的时间节点。在苏州攻关会颁发任务令,需要在节点前达成目标。苏州集结会把各地的人员都集中在苏州研究所,平常住宾馆,比如在甪直的酒店,与家人孩子天各一方。
|
||||
3. 在苏州集结的时候周六默认上班,非常辛苦,不过周六有下午茶,有一次还有小龙虾。在苏州研究所的工位搬迁过一次,从一栋楼换到了另一栋。苏州研究所楼栋都是欧式装修,门口有大坡,里面景色很不错。去苏州集结一般至少要去一周,甚至更久,多的人甚至一两个月都回不了家。
|
||||
4. 诺亚曾经传说是研究型的,但是来了之后因为在四野做大模型项目,项目成员完全变成了交付型的,且充满了例会,评审,汇报。很多时候做实验都要申请。团队需要对接终端小艺,华为云,ICT等诸多业务线,交付压力不小。
|
||||
5. 诺亚研发的盘古模型早期内部代号叫做“盘古智子”,一开始只有内部需要申请试用的网页版,到后续迫于压力在welink上接入和公测开放。
|
||||
|
||||
这些天发生关于质疑盘古大模型抄袭千问的事情闹的沸沸扬扬。作为一个盘古团队的成员,我最近夜夜辗转反侧,难以入眠。盘古的品牌受到如此大的影响,一方面,我自私的为我的职业发展担忧,也为自己过去的努力工作感到不值。另一方面,由于有人开始揭露这些事情我内心又感到大快人心。在多少个日日夜夜,我们对内部某些人一次次靠着造假而又获得了无数利益的行为咬牙切齿而又无能为力。这种压抑和羞辱也逐渐消磨了我对华为的感情,让我在这里的时日逐渐浑浑噩噩,迷茫无措,时常怀疑自己的人生和自我价值。
|
||||
|
||||
我承认我是一个懦弱的人,作为一个小小的打工人,我不仅不敢和王云鹤等内部手眼通天的人做对,更不敢和华为这样的庞然大物做对。我很怕失去我的工作,毕竟我也有家人和孩子,所以我打心眼里很佩服揭露者。但是,看到内部还在试图洗地掩盖事实,蒙蔽公众的时候,我实在不能容忍了。我也希望勇敢一次,顺从自己本心。就算自损八百,我也希望能伤敌一千。我决定把我在这里的所见所闻(部分来自于同事口述)公布出来,关于盘古大模型的“传奇故事”:
|
||||
|
||||
华为确实主要在昇腾卡上训练大模型(小模型实验室有不少英伟达的卡,他们之前也会用来训练,后面转移到昇腾)。曾经我被华为“打造世界第二选择”的决心而折服,我本身也曾经对华为有深厚的感情。我们陪着昇腾一步步摸爬滚打,从充满bug到现在能训出模型,付出了巨大的心血和代价。
|
||||
|
||||
最初我们的算力非常有限,在910A上训练模型。那会只支持fp16,训练的稳定性远不如bf16。盘古的moe开始很早,23年就主要是训练38Bmoe模型和后续的71B dense模型。71B的dense模型通过扩增变成了第一代的135Bdense模型,后面主力模型也逐渐在910B上训练。
|
||||
|
||||
71B和135B模型都有一个巨大的硬伤就是tokenizer。当时使用的tokenizer编码效率极低,每个单个的符号,数字,空格,乃至汉字都会占用一个token。可想而知这会非常浪费算力,且使得模型的效果很差。这时候小模型实验室正好有个自己训的词表。姚老师当时怀疑是不是模型的tokenizer不好(虽然事后来看,他的怀疑是无疑正确的),于是就决定,让71B和135B换tokenizer,因为小模型实验室曾经尝试过。团队缝合了两个tokenizer,开始了tokenizer的更换。71B模型的更换失败了,而135B因为采用了更精细的embedding初始化策略,续训了至少1T的数据后词表总算更换成功,但可想而知,效果并不会变好。
|
||||
|
||||
于此同期,阿里和智谱等国内其他公司在GPU上训练,且已经摸索出了正确的方法,盘古和竞品的差距越来越大。内部一个230B从头训练的dense模型又因为各种原因训练失败,导致项目的状况几乎陷入绝境。面临几个节点的压力以及内部对盘古的强烈质疑时,团队的士气低迷到了极点。团队在算力极其有限的时候,做出了很多努力和挣扎。比如,团队偶然发现当时的38B moe并没有预期moe的效果。于是去掉了moe参数,还原为了13B的dense模型。由于38B的moe源自很早的pangu alpha 13B,架构相对落后,团队进行了一系列的操作,比如切换绝对位置编码到rope,去掉bias,切换为rmsnorm。同时鉴于tokenizer的一些失败和换词表的经验,这个模型的词表也更换为了王云鹤的小模型实验室7B模型所使用的词表。后面这个13B模型进行了扩增续训,变成了第二代38B dense模型(在几个月内这个模型都是主要的盘古中档位模型),曾经具有一定的竞争力。但是,由于更大的135B模型架构落后,且更换词表模型损伤巨大(后续分析发现当时更换的缝合词表有更严重的bug),续训后也与千问等当时国内领先模型存在很大差距。这时由于内部的质疑声和领导的压力也越来越大。团队的状态几乎陷入了绝境。
|
||||
|
||||
在这种情况下,王云鹤和他的小模型实验室出手了。他们声称是从旧的135B参数继承改造而来,通过训练短短的几百B数据,各项指标平均提升了十个点左右。实际上,这就是他们套壳应用到大模型的第一次杰作。华为的外行领导内行,使得领导完全对于这种扯淡的事情没有概念,他们只会觉得肯定是有什么算法创新。经过内部的分析,他们实际上是使用Qwen 1.5 110B续训而来,通过加层,扩增ffn维度,添加盘古pi论文的一些机制得来,凑够了大概135B的参数。实际上,旧的135B有107层,而这个模型只有82层,各种配置也都不一样。新的来路不明的135B训练完很多参数的分布也和Qwen 110B几乎一模一样。连模型代码的类名当时都是Qwen,甚至懒得改名。后续这个模型就是所谓的135B V2。而这个模型当时也提供给了很多下游,甚至包括外部客户。
|
||||
|
||||
这件事对于我们这些认真诚实做事的同事们带来了巨大的冲击,内部很多人其实都知道这件事,甚至包括终端和华为云。我们都戏称以后别叫盘古模型了,叫千古吧。当时团队成员就想向bcg举报了,毕竟这已经是重大的业务造假了。但是后面据说被领导拦了下来,因为更高级别的领导(比如姚老师,以及可能熊总和查老)其实后面也知道了,但是并不管,因为通过套壳拿出好的结果,对他们也是有利的。这件事使得当时团队几位最强的同事开始心灰意冷,离职跑路也逐渐成为挂在嘴边的事。
|
||||
|
||||
此时,盘古似乎迎来了转机。由于前面所述的这些盘古模型基本都是续训和改造而来,当时诺亚完全没有掌握从头训练的技术,何况还是在昇腾的NPU上进行训练。在当时团队的核心成员的极力争取下,盘古开始了第三代模型的训练,付出了巨大的努力后,在数据架构和训练算法方面都与业界逐渐接轨,而这其中的艰辛和小模型实验室的人一点关系都没有。
|
||||
|
||||
一开始团队成员毫无信心,只从一个13B的模型开始训练,但是后面发现效果还不错,于是这个模型后续再次进行了一次参数扩增,变成了第三代的38B,代号38B V3。想必很多产品线的兄弟都对这个模型很熟悉。当时这个模型的tokenizer是基于llama的词表进行扩展的(也是业界常见的做法)。而当时王云鹤的实验室做出来了另一个词表(也就是后续pangu系列的词表)。当时两个词表还被迫进行了一次赛马,最终没有明显的好坏结论。于是,领导当即决定,应该统一词表,使用王云鹤他们的。于是,在后续从头训练的135B V3(也就是对外的Pangu Ultra),便是采用了这个tokenizer。这也解释了很多使用我们模型的兄弟的疑惑,为什么当时同为V3代的两个不同档位的模型,会使用不同的tokenizer。
|
||||
|
||||
|
||||
我们打心眼里觉得,135B V3是我们四纵团队当时的骄傲。这是第一个真正意义上的,华为全栈自研,正经从头训练的千亿级别的模型,且效果与24年同期竞品可比的。写到这里我已经热泪盈眶,太不容易了。当时为了稳定训练,团队做了大量实验对比,并且多次在模型梯度出现异常的时候进行及时回退重启。这个模型真正做到了后面技术报告所说的训练全程没有一个loss spike。我们克服了不知道多少困难,我们做到了,我们愿用生命和荣誉保证这个模型训练的真实性。多少个凌晨,我们为了它的训练而不眠。在被内部心声骂的一文不值的时候,我们有多么不甘,有多少的委屈,我们挺住了。
|
||||
|
||||
我们这帮人是真的在为打磨国产算力底座燃烧自己的青春啊……客居他乡,我们放弃了家庭,放弃了假期,放弃了健康,放弃了娱乐,抛头颅洒热血,其中的艰辛与困苦,寥寥数笔不足以概括其万一。在各种动员大会上,当时口号中喊出的盘古必胜,华为必胜,我们心里是真的深深被感动。
|
||||
|
||||
然而,我们的所有辛苦的成果,经常被小模型实验室轻飘飘的拿走了。数据,直接要走。代码,直接要走,还要求我们配合适配到能一键运行。我们当时戏称小模型实验室为点鼠标实验室。我们付出辛苦,他们取得荣耀。果然应了那句话,你在负重前行是因为有人替你岁月静好。在这种情况下,越来越多的战友再也坚持不下去了,选择了离开。看到身边那些优秀的同事一个个离职,我的内心又感叹又难过。在这种作战一样的环境下,我们比起同事来说更像是战友。他们在技术上也有无数值得我学习的地方,堪称良师。看到他们去了诸如字节Seed,Deepseek,月之暗面,腾讯和快手等等很多出色的团队,我打心眼里为他们高兴和祝福,脱离了这个辛苦却肮脏的地方。我至今还对一位离职同事的话记忆犹新,ta说:“来这里是我技术生涯中的耻辱,在这里再呆每一天都是浪费生命”。话虽难听却让我无言以对。我担心我自己技术方面的积累不足,以及没法适应互联网公司高淘汰的环境,让我多次想离职的心始终没有迈出这一步。
|
||||
|
||||
盘古除了dense模型,后续也启动了moe的探索。一开始训练的是一个224B的moe模型。而与之平行的,小模型实验室也开启了第二次主要的套壳行动(次要的插曲可能还包括一些别的模型,比如math模型),即这次流传甚广的pangu pro moe 72B。这个模型内部自称是从小模型实验室的7B扩增上来的(就算如此,这也与技术报告不符,何况是套壳qwen 2.5的14b续训)。还记得他们训了没几天,内部的评测就立刻追上了当时的38B V3。AI系统实验室很多兄弟因为需要适配模型,都知道他们的套壳行动,只是迫于各种原因,无法伸张正义。实际上,对于后续训了很久很久的这个模型,Honestagi能够分析出这个量级的相似性我已经很诧异了,因为这个模型为了续训洗参数,所付出的算力甚至早就足够从头训一个同档位的模型了。听同事说他们为了洗掉千问的水印,采取了不少办法,甚至包括故意训了脏数据。这也为学术界研究模型血缘提供了一个前所未有的特殊模范吧。以后新的血缘方法提出可以拿出来溜溜。
|
||||
|
||||
24年底和25年初,在Deepseek v3和r1发布之后,由于其惊艳的技术水平,团队受到了巨大的冲击,也受到了更大的质疑。于是为了紧跟潮流,盘古模仿Deepseek的模型尺寸,开启了718B moe的训练。这个时候,小模型实验室再次出手了。他们选择了套壳Deepseekv3续训。他们通过冻住Deepseek加载的参数,进行训练。连任务加载ckpt的目录都是deepseekv3,改都不改,何其嚣张?与之相反,一些有真正技术信仰的同事,在从头训练另一个718B的moe。但其中出现了各种各样的问题。但是很显然,这个模型怎么可能比直接套壳的好呢?如果不是团队leader坚持,早就被叫停了。
|
||||
|
||||
华为的流程管理之繁重,严重拖累了大模型的研发节奏,例如版本管理,模型血缘,各种流程化,各种可追溯。讽刺的是,小模型实验室的模型似乎从来不受这些流程的约束,想套壳就套壳,想续训就续训,算力源源不断的伸手拿走。这种强烈到近乎魔幻的对比,说明了当前流程管理的情况:只许州官放火,不许百姓点灯。何其可笑?何其可悲?何其可恶?何其可耻!
|
||||
|
||||
HonestAGI的事情出来后,内部让大家不停的研讨分析,如何公关和“回应”。诚然,这个原文的分析也许不够有力,给了王云鹤与小模型实验室他们狡辩和颠倒黑白的机会。为此,这两天我内心感到作呕,时时怀疑自己的人生意义以及苍天无眼。我不奉陪了,我要离职了,同时我也在申请从盘古部分技术报告的作者名单中移除。曾经在这些技术报告上署名是我一生都无法抹除的污点。当时我没想到,他们竟然猖狂到敢开源。我没想到,他们敢如此愚弄世人,大肆宣发。当时,我也许是存了侥幸心理,没有拒绝署名。我相信很多扎实做事的战友,也只是被迫上了贼船,或者不知情。但这件事已经无法挽回,我希望我的余生能够坚持扎实做真正有意义的事,为我当时的软弱和不坚定赎罪。
|
||||
|
||||
深夜写到这里,我已经泪流满面,泣不成声。还记得一些出色的同事离职时,我苦笑问他们要不要发个长长的心声惯例帖,揭露一下现状。对方说:不了,浪费时间,而且我也怕揭露出来你们过的更糟。我当时一下黯然神伤,因为曾经共同为了理想奋斗过的战友已经彻底对华为彻底灰心了。当时大家调侃,我们用着当年共产党的小米加步枪,组织却有着堪比当年国民党的作风。
|
||||
|
||||
曾几何时,我为我们用着小米加步枪打败洋枪洋炮而自豪。
|
||||
|
||||
现在,我累了,我想投降。
|
||||
|
||||
其实时至今日,我还是真心希望华为能认真吸取教训,能做好盘古,把盘古做到世界一流,把昇腾变成英伟达的水平。内部的劣币驱逐良币,使得诺亚乃至华为在短时间内急剧流失了大量出色的大模型人才。相信他们也正在如Deepseek等各个团队闪耀着,施展着他们的抱负才华,为中美在AI的激烈竞赛中奉献力量。我时常感叹,华为不是没有人才,而是根本不知道怎么留住人才。如果给这些人合适的环境,合适的资源,更少的枷锁,更少的政治斗争,盘古何愁不成?
|
||||
|
||||
最后:我以生命,人格和荣誉发誓,我写的以上所有内容均为真实(至少在我有限的认知范围内)。我没有那么高的技术水平以及机会去做详尽扎实的分析,也不敢直接用内部记录举证,怕因为信息安全抓到。但是我相信我很多曾经的战友,会为我作证。在华为内部的兄弟,包括我们曾经服务过的产品线兄弟们,相信本文的无数细节能和你们的印象对照,印证我的说法。你们可能也曾经被蒙骗,但这些残酷的真相不会被尘封。我们奋战过的痕迹,也不应该被扭曲和埋葬。
|
||||
|
||||
写了这么多,某些人肯定想把我找出来,抹杀掉。公司搞不好也想让我噤声乃至追责。如果真的这样,我,乃至我的家人的人身乃至生命安全可能都会受到威胁。为了自我保护,我近期每天会跟大家报平安。
|
||||
|
||||
如果我消失了,就当是我为了真理和理想,为了华为乃至中国能够更好地发展算力和AI而牺牲了吧,我愿埋葬于那片曾经奋斗过的地方。
|
||||
|
||||
诺亚,再见
|
||||
|
||||
2025年7月6日凌晨 写于深圳
|
||||
|
||||
---
|
||||
|
||||
# LEANN-RAG Evaluation Data
|
||||
各位好,
|
||||
|
||||
This repository contains the necessary data to run the recall evaluation scripts for the [LEANN-RAG](https://huggingface.co/LEANN-RAG) project.
|
||||
感谢大家的关心与祝福。我目前暂时安全,但公司应该在进行排查与某些名单收集,后续情况未知。
|
||||
|
||||
## Dataset Components
|
||||
我补充一些细节,以免某些人继续颠倒黑白。
|
||||
|
||||
This dataset is structured into three main parts:
|
||||
关于135B V2,小模型实验室在迅速地完成套壳并拿完所有套壳带来的好处后(比如任务令表彰和及时激励),因为不想继续支撑下游应用和模型迭代,又把这个烫手山芋甩给了四纵。确实技高一筹,直接把四纵的兄弟们拉下水。同事提供过去一个老旧的模型,最终拿回了一个当时一个魔改的先进的千问。做大模型的人,自己做的模型就像自己孩子一样熟悉,不要把别人都当傻子。就像自家儿子出门一趟,回来个别人家孩子。
|
||||
|
||||
1. **Pre-built LEANN Indices**:
|
||||
* `dpr/`: A pre-built index for the DPR dataset.
|
||||
* `rpj_wiki/`: A pre-built index for the RPJ-Wiki dataset.
|
||||
These indices were created using the `leann-core` library and are required by the `LeannSearcher`.
|
||||
盘古report的署名是不符合学术规范的。例如,135B V3有不少有技术贡献的人,因为作者名额数量限制,劳动成果没有得到应有的回报,团队内曾经有不小的意见。这个模型当时是大家智慧和汗水的结晶,甚至是团队当时的精神支柱,支撑着不少兄弟们继续留在诺亚。所谓的名额限制,以及挂名了一些毫无技术贡献的人(如一些小模型实验室的人),让兄弟们何其心寒。
|
||||
|
||||
2. **Ground Truth Data**:
|
||||
* `ground_truth/`: Contains the ground truth files (`flat_results_nq_k3.json`) for both the DPR and RPJ-Wiki datasets. These files map queries to the original passage IDs from the Natural Questions benchmark, evaluated using the Contriever model.
|
||||
---
|
||||
|
||||
3. **Queries**:
|
||||
* `queries/`: Contains the `nq_open.jsonl` file with the Natural Questions queries used for the evaluation.
|
||||
|
||||
## Usage
|
||||
|
||||
To use this data, you can download it locally using the `huggingface-hub` library. First, install the library:
|
||||
|
||||
```bash
|
||||
pip install huggingface-hub
|
||||
```
|
||||
|
||||
Then, you can download the entire dataset to a local directory (e.g., `data/`) with the following Python script:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
snapshot_download(
|
||||
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||
repo_type="dataset",
|
||||
local_dir="data"
|
||||
)
|
||||
```
|
||||
|
||||
This will download all the necessary files into a local `data` folder, preserving the repository structure. The evaluation scripts in the main [LEANN-RAG Space](https://huggingface.co/LEANN-RAG) are configured to work with this data structure.
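
Once the snapshot is downloaded, the pre-built indices can be opened directly with `LeannSearcher`, the same interface the evaluation scripts use. Below is a minimal sketch; the exact index filename inside `data/indices/dpr/`, the choice of backend package, and the sample query are assumptions for illustration only, so adjust them to whatever the snapshot actually contains:

```python
# Importing a backend package registers the backend plugin, as in the repo's examples.
# Which backend the pre-built indices use is an assumption here.
import leann_backend_hnsw  # noqa: F401

from leann.api import LeannSearcher

# Hypothetical path: replace with the actual index file inside the downloaded snapshot.
searcher = LeannSearcher(index_path="data/indices/dpr/dpr_index.leann")

results = searcher.search("who wrote the declaration of independence", top_k=3)
for res in results:
    print(f"{res.score:.4f}  {res.text[:80]}")
```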
|
||||
暂时平安。另外,支持我勇于说出真相的战友们 https://github.com/HW-whistleblower/True-Story-of-Pangu/issues/317
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
|
||||
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](../examples/mlx_demo.py))
|
||||
|
||||
## 🎨 Developer Experience
|
||||
|
||||
|
||||
@@ -72,4 +72,4 @@ Using the wrong distance metric with normalized embeddings can lead to:
|
||||
- **Incorrect ranking** of search results
|
||||
- **Suboptimal performance** compared to using the correct metric
|
||||
|
||||
For more details on why this happens, see our analysis of [OpenAI embeddings with MIPS](../examples/main_cli_example.py).
|
||||
For more details on why this happens, see our analysis in the [embedding detection code](../packages/leann-core/src/leann/api.py) which automatically handles normalized embeddings and MIPS distance metric issues.
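
As a quick illustration of the point above (added here for clarity; numpy and the toy vectors are not part of the original docs): with raw, un-normalized vectors, ranking by inner product (MIPS) can disagree with ranking by cosine similarity, while after L2-normalization the two orderings always coincide, which is what the automatic metric handling relies on. A small self-contained sketch:

```python
import numpy as np

rng = np.random.default_rng(0)
query = rng.normal(size=4)
# Documents with deliberately different norms, so inner product and cosine can disagree.
docs = rng.normal(size=(3, 4)) * np.array([[1.0], [5.0], [0.2]])

def ranking(scores: np.ndarray) -> np.ndarray:
    return np.argsort(-scores)

inner = docs @ query
cosine = inner / (np.linalg.norm(docs, axis=1) * np.linalg.norm(query))
print("inner-product ranking:", ranking(inner))   # may differ from the cosine ranking
print("cosine ranking:       ", ranking(cosine))

# After L2-normalization, inner product and cosine give identical rankings,
# so MIPS over normalized embeddings behaves like cosine similarity.
docs_n = docs / np.linalg.norm(docs, axis=1, keepdims=True)
query_n = query / np.linalg.norm(query)
print("normalized MIPS ranking:", ranking(docs_n @ query_n))
```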
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Simple demo showing basic leann usage
|
||||
Run: uv run python examples/simple_demo.py
|
||||
Run: uv run python examples/basic_demo.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -81,7 +81,7 @@ def main():
|
||||
print()
|
||||
|
||||
print("Demo completed! Try running:")
|
||||
print(" uv run python examples/document_search.py")
|
||||
print(" uv run python apps/document_rag.py")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -1,82 +0,0 @@
|
||||
# 盘古之殇:华为诺亚盘古大模型研发历程的心酸与黑暗
|
||||
|
||||
各位好,
|
||||
|
||||
我是一名盘古大模型团队,华为诺亚方舟实验室的员工。
|
||||
|
||||
首先为自证身份,列举一些细节:
|
||||
|
||||
1. 现诺亚主任,前算法应用部部长,后改名为小模型实验室的主任王云鹤。前诺亚主任:姚骏(大家称姚老师)。几个实验室主任:唐睿明(明哥,明队,已离职),尚利峰,张维(维哥),郝建业(郝老师),刘武龙(称呼为武龙所)等。其他骨干成员和专家陆续有很多人离职。
|
||||
2. 我们隶属于“四野”这个组织。四野下属有许多纵队,基础语言大模型是四纵。王云鹤的小模型是十六纵队。我们参加过苏州的集结,有各种月份的时间节点。在苏州攻关会颁发任务令,需要在节点前达成目标。苏州集结会把各地的人员都集中在苏州研究所,平常住宾馆,比如在甪直的酒店,与家人孩子天各一方。
|
||||
3. 在苏州集结的时候周六默认上班,非常辛苦,不过周六有下午茶,有一次还有小龙虾。在苏州研究所的工位搬迁过一次,从一栋楼换到了另一栋。苏州研究所楼栋都是欧式装修,门口有大坡,里面景色很不错。去苏州集结一般至少要去一周,甚至更久,多的人甚至一两个月都回不了家。
|
||||
4. 诺亚曾经传说是研究型的,但是来了之后因为在四野做大模型项目,项目成员完全变成了交付型的,且充满了例会,评审,汇报。很多时候做实验都要申请。团队需要对接终端小艺,华为云,ICT等诸多业务线,交付压力不小。
|
||||
5. 诺亚研发的盘古模型早期内部代号叫做“盘古智子”,一开始只有内部需要申请试用的网页版,到后续迫于压力在welink上接入和公测开放。
|
||||
|
||||
这些天发生关于质疑盘古大模型抄袭千问的事情闹的沸沸扬扬。作为一个盘古团队的成员,我最近夜夜辗转反侧,难以入眠。盘古的品牌受到如此大的影响,一方面,我自私的为我的职业发展担忧,也为自己过去的努力工作感到不值。另一方面,由于有人开始揭露这些事情我内心又感到大快人心。在多少个日日夜夜,我们对内部某些人一次次靠着造假而又获得了无数利益的行为咬牙切齿而又无能为力。这种压抑和羞辱也逐渐消磨了我对华为的感情,让我在这里的时日逐渐浑浑噩噩,迷茫无措,时常怀疑自己的人生和自我价值。
|
||||
|
||||
我承认我是一个懦弱的人,作为一个小小的打工人,我不仅不敢和王云鹤等内部手眼通天的人做对,更不敢和华为这样的庞然大物做对。我很怕失去我的工作,毕竟我也有家人和孩子,所以我打心眼里很佩服揭露者。但是,看到内部还在试图洗地掩盖事实,蒙蔽公众的时候,我实在不能容忍了。我也希望勇敢一次,顺从自己本心。就算自损八百,我也希望能伤敌一千。我决定把我在这里的所见所闻(部分来自于同事口述)公布出来,关于盘古大模型的“传奇故事”:
|
||||
|
||||
华为确实主要在昇腾卡上训练大模型(小模型实验室有不少英伟达的卡,他们之前也会用来训练,后面转移到昇腾)。曾经我被华为“打造世界第二选择”的决心而折服,我本身也曾经对华为有深厚的感情。我们陪着昇腾一步步摸爬滚打,从充满bug到现在能训出模型,付出了巨大的心血和代价。
|
||||
|
||||
最初我们的算力非常有限,在910A上训练模型。那会只支持fp16,训练的稳定性远不如bf16。盘古的moe开始很早,23年就主要是训练38Bmoe模型和后续的71B dense模型。71B的dense模型通过扩增变成了第一代的135Bdense模型,后面主力模型也逐渐在910B上训练。
|
||||
|
||||
71B和135B模型都有一个巨大的硬伤就是tokenizer。当时使用的tokenizer编码效率极低,每个单个的符号,数字,空格,乃至汉字都会占用一个token。可想而知这会非常浪费算力,且使得模型的效果很差。这时候小模型实验室正好有个自己训的词表。姚老师当时怀疑是不是模型的tokenizer不好(虽然事后来看,他的怀疑是无疑正确的),于是就决定,让71B和135B换tokenizer,因为小模型实验室曾经尝试过。团队缝合了两个tokenizer,开始了tokenizer的更换。71B模型的更换失败了,而135B因为采用了更精细的embedding初始化策略,续训了至少1T的数据后词表总算更换成功,但可想而知,效果并不会变好。
|
||||
|
||||
于此同期,阿里和智谱等国内其他公司在GPU上训练,且已经摸索出了正确的方法,盘古和竞品的差距越来越大。内部一个230B从头训练的dense模型又因为各种原因训练失败,导致项目的状况几乎陷入绝境。面临几个节点的压力以及内部对盘古的强烈质疑时,团队的士气低迷到了极点。团队在算力极其有限的时候,做出了很多努力和挣扎。比如,团队偶然发现当时的38B moe并没有预期moe的效果。于是去掉了moe参数,还原为了13B的dense模型。由于38B的moe源自很早的pangu alpha 13B,架构相对落后,团队进行了一系列的操作,比如切换绝对位置编码到rope,去掉bias,切换为rmsnorm。同时鉴于tokenizer的一些失败和换词表的经验,这个模型的词表也更换为了王云鹤的小模型实验室7B模型所使用的词表。后面这个13B模型进行了扩增续训,变成了第二代38B dense模型(在几个月内这个模型都是主要的盘古中档位模型),曾经具有一定的竞争力。但是,由于更大的135B模型架构落后,且更换词表模型损伤巨大(后续分析发现当时更换的缝合词表有更严重的bug),续训后也与千问等当时国内领先模型存在很大差距。这时由于内部的质疑声和领导的压力也越来越大。团队的状态几乎陷入了绝境。
|
||||
|
||||
在这种情况下,王云鹤和他的小模型实验室出手了。他们声称是从旧的135B参数继承改造而来,通过训练短短的几百B数据,各项指标平均提升了十个点左右。实际上,这就是他们套壳应用到大模型的第一次杰作。华为的外行领导内行,使得领导完全对于这种扯淡的事情没有概念,他们只会觉得肯定是有什么算法创新。经过内部的分析,他们实际上是使用Qwen 1.5 110B续训而来,通过加层,扩增ffn维度,添加盘古pi论文的一些机制得来,凑够了大概135B的参数。实际上,旧的135B有107层,而这个模型只有82层,各种配置也都不一样。新的来路不明的135B训练完很多参数的分布也和Qwen 110B几乎一模一样。连模型代码的类名当时都是Qwen,甚至懒得改名。后续这个模型就是所谓的135B V2。而这个模型当时也提供给了很多下游,甚至包括外部客户。
|
||||
|
||||
这件事对于我们这些认真诚实做事的同事们带来了巨大的冲击,内部很多人其实都知道这件事,甚至包括终端和华为云。我们都戏称以后别叫盘古模型了,叫千古吧。当时团队成员就想向bcg举报了,毕竟这已经是重大的业务造假了。但是后面据说被领导拦了下来,因为更高级别的领导(比如姚老师,以及可能熊总和查老)其实后面也知道了,但是并不管,因为通过套壳拿出好的结果,对他们也是有利的。这件事使得当时团队几位最强的同事开始心灰意冷,离职跑路也逐渐成为挂在嘴边的事。
|
||||
|
||||
此时,盘古似乎迎来了转机。由于前面所述的这些盘古模型基本都是续训和改造而来,当时诺亚完全没有掌握从头训练的技术,何况还是在昇腾的NPU上进行训练。在当时团队的核心成员的极力争取下,盘古开始了第三代模型的训练,付出了巨大的努力后,在数据架构和训练算法方面都与业界逐渐接轨,而这其中的艰辛和小模型实验室的人一点关系都没有。
|
||||
|
||||
一开始团队成员毫无信心,只从一个13B的模型开始训练,但是后面发现效果还不错,于是这个模型后续再次进行了一次参数扩增,变成了第三代的38B,代号38B V3。想必很多产品线的兄弟都对这个模型很熟悉。当时这个模型的tokenizer是基于llama的词表进行扩展的(也是业界常见的做法)。而当时王云鹤的实验室做出来了另一个词表(也就是后续pangu系列的词表)。当时两个词表还被迫进行了一次赛马,最终没有明显的好坏结论。于是,领导当即决定,应该统一词表,使用王云鹤他们的。于是,在后续从头训练的135B V3(也就是对外的Pangu Ultra),便是采用了这个tokenizer。这也解释了很多使用我们模型的兄弟的疑惑,为什么当时同为V3代的两个不同档位的模型,会使用不同的tokenizer。
|
||||
|
||||
|
||||
我们打心眼里觉得,135B V3是我们四纵团队当时的骄傲。这是第一个真正意义上的,华为全栈自研,正经从头训练的千亿级别的模型,且效果与24年同期竞品可比的。写到这里我已经热泪盈眶,太不容易了。当时为了稳定训练,团队做了大量实验对比,并且多次在模型梯度出现异常的时候进行及时回退重启。这个模型真正做到了后面技术报告所说的训练全程没有一个loss spike。我们克服了不知道多少困难,我们做到了,我们愿用生命和荣誉保证这个模型训练的真实性。多少个凌晨,我们为了它的训练而不眠。在被内部心声骂的一文不值的时候,我们有多么不甘,有多少的委屈,我们挺住了。
|
||||
|
||||
我们这帮人是真的在为打磨国产算力底座燃烧自己的青春啊……客居他乡,我们放弃了家庭,放弃了假期,放弃了健康,放弃了娱乐,抛头颅洒热血,其中的艰辛与困苦,寥寥数笔不足以概括其万一。在各种动员大会上,当时口号中喊出的盘古必胜,华为必胜,我们心里是真的深深被感动。
|
||||
|
||||
然而,我们的所有辛苦的成果,经常被小模型实验室轻飘飘的拿走了。数据,直接要走。代码,直接要走,还要求我们配合适配到能一键运行。我们当时戏称小模型实验室为点鼠标实验室。我们付出辛苦,他们取得荣耀。果然应了那句话,你在负重前行是因为有人替你岁月静好。在这种情况下,越来越多的战友再也坚持不下去了,选择了离开。看到身边那些优秀的同事一个个离职,我的内心又感叹又难过。在这种作战一样的环境下,我们比起同事来说更像是战友。他们在技术上也有无数值得我学习的地方,堪称良师。看到他们去了诸如字节Seed,Deepseek,月之暗面,腾讯和快手等等很多出色的团队,我打心眼里为他们高兴和祝福,脱离了这个辛苦却肮脏的地方。我至今还对一位离职同事的话记忆犹新,ta说:“来这里是我技术生涯中的耻辱,在这里再呆每一天都是浪费生命”。话虽难听却让我无言以对。我担心我自己技术方面的积累不足,以及没法适应互联网公司高淘汰的环境,让我多次想离职的心始终没有迈出这一步。
|
||||
|
||||
盘古除了dense模型,后续也启动了moe的探索。一开始训练的是一个224B的moe模型。而与之平行的,小模型实验室也开启了第二次主要的套壳行动(次要的插曲可能还包括一些别的模型,比如math模型),即这次流传甚广的pangu pro moe 72B。这个模型内部自称是从小模型实验室的7B扩增上来的(就算如此,这也与技术报告不符,何况是套壳qwen 2.5的14b续训)。还记得他们训了没几天,内部的评测就立刻追上了当时的38B V3。AI系统实验室很多兄弟因为需要适配模型,都知道他们的套壳行动,只是迫于各种原因,无法伸张正义。实际上,对于后续训了很久很久的这个模型,Honestagi能够分析出这个量级的相似性我已经很诧异了,因为这个模型为了续训洗参数,所付出的算力甚至早就足够从头训一个同档位的模型了。听同事说他们为了洗掉千问的水印,采取了不少办法,甚至包括故意训了脏数据。这也为学术界研究模型血缘提供了一个前所未有的特殊模范吧。以后新的血缘方法提出可以拿出来溜溜。
|
||||
|
||||
24年底和25年初,在Deepseek v3和r1发布之后,由于其惊艳的技术水平,团队受到了巨大的冲击,也受到了更大的质疑。于是为了紧跟潮流,盘古模仿Deepseek的模型尺寸,开启了718B moe的训练。这个时候,小模型实验室再次出手了。他们选择了套壳Deepseekv3续训。他们通过冻住Deepseek加载的参数,进行训练。连任务加载ckpt的目录都是deepseekv3,改都不改,何其嚣张?与之相反,一些有真正技术信仰的同事,在从头训练另一个718B的moe。但其中出现了各种各样的问题。但是很显然,这个模型怎么可能比直接套壳的好呢?如果不是团队leader坚持,早就被叫停了。
|
||||
|
||||
华为的流程管理之繁重,严重拖累了大模型的研发节奏,例如版本管理,模型血缘,各种流程化,各种可追溯。讽刺的是,小模型实验室的模型似乎从来不受这些流程的约束,想套壳就套壳,想续训就续训,算力源源不断的伸手拿走。这种强烈到近乎魔幻的对比,说明了当前流程管理的情况:只许州官放火,不许百姓点灯。何其可笑?何其可悲?何其可恶?何其可耻!
|
||||
|
||||
HonestAGI的事情出来后,内部让大家不停的研讨分析,如何公关和“回应”。诚然,这个原文的分析也许不够有力,给了王云鹤与小模型实验室他们狡辩和颠倒黑白的机会。为此,这两天我内心感到作呕,时时怀疑自己的人生意义以及苍天无眼。我不奉陪了,我要离职了,同时我也在申请从盘古部分技术报告的作者名单中移除。曾经在这些技术报告上署名是我一生都无法抹除的污点。当时我没想到,他们竟然猖狂到敢开源。我没想到,他们敢如此愚弄世人,大肆宣发。当时,我也许是存了侥幸心理,没有拒绝署名。我相信很多扎实做事的战友,也只是被迫上了贼船,或者不知情。但这件事已经无法挽回,我希望我的余生能够坚持扎实做真正有意义的事,为我当时的软弱和不坚定赎罪。
|
||||
|
||||
深夜写到这里,我已经泪流满面,泣不成声。还记得一些出色的同事离职时,我苦笑问他们要不要发个长长的心声惯例帖,揭露一下现状。对方说:不了,浪费时间,而且我也怕揭露出来你们过的更糟。我当时一下黯然神伤,因为曾经共同为了理想奋斗过的战友已经彻底对华为彻底灰心了。当时大家调侃,我们用着当年共产党的小米加步枪,组织却有着堪比当年国民党的作风。
|
||||
|
||||
曾几何时,我为我们用着小米加步枪打败洋枪洋炮而自豪。
|
||||
|
||||
现在,我累了,我想投降。
|
||||
|
||||
其实时至今日,我还是真心希望华为能认真吸取教训,能做好盘古,把盘古做到世界一流,把昇腾变成英伟达的水平。内部的劣币驱逐良币,使得诺亚乃至华为在短时间内急剧流失了大量出色的大模型人才。相信他们也正在如Deepseek等各个团队闪耀着,施展着他们的抱负才华,为中美在AI的激烈竞赛中奉献力量。我时常感叹,华为不是没有人才,而是根本不知道怎么留住人才。如果给这些人合适的环境,合适的资源,更少的枷锁,更少的政治斗争,盘古何愁不成?
|
||||
|
||||
最后:我以生命,人格和荣誉发誓,我写的以上所有内容均为真实(至少在我有限的认知范围内)。我没有那么高的技术水平以及机会去做详尽扎实的分析,也不敢直接用内部记录举证,怕因为信息安全抓到。但是我相信我很多曾经的战友,会为我作证。在华为内部的兄弟,包括我们曾经服务过的产品线兄弟们,相信本文的无数细节能和你们的印象对照,印证我的说法。你们可能也曾经被蒙骗,但这些残酷的真相不会被尘封。我们奋战过的痕迹,也不应该被扭曲和埋葬。
|
||||
|
||||
写了这么多,某些人肯定想把我找出来,抹杀掉。公司搞不好也想让我噤声乃至追责。如果真的这样,我,乃至我的家人的人身乃至生命安全可能都会受到威胁。为了自我保护,我近期每天会跟大家报平安。
|
||||
|
||||
如果我消失了,就当是我为了真理和理想,为了华为乃至中国能够更好地发展算力和AI而牺牲了吧,我愿埋葬于那片曾经奋斗过的地方。
|
||||
|
||||
诺亚,再见
|
||||
|
||||
2025年7月6日凌晨 写于深圳
|
||||
|
||||
---
|
||||
|
||||
各位好,
|
||||
|
||||
感谢大家的关心与祝福。我目前暂时安全,但公司应该在进行排查与某些名单收集,后续情况未知。
|
||||
|
||||
我补充一些细节,以免某些人继续颠倒黑白。
|
||||
|
||||
关于135B V2,小模型实验室在迅速地完成套壳并拿完所有套壳带来的好处后(比如任务令表彰和及时激励),因为不想继续支撑下游应用和模型迭代,又把这个烫手山芋甩给了四纵。确实技高一筹,直接把四纵的兄弟们拉下水。同事提供过去一个老旧的模型,最终拿回了一个当时一个魔改的先进的千问。做大模型的人,自己做的模型就像自己孩子一样熟悉,不要把别人都当傻子。就像自家儿子出门一趟,回来个别人家孩子。
|
||||
|
||||
盘古report的署名是不符合学术规范的。例如,135B V3有不少有技术贡献的人,因为作者名额数量限制,劳动成果没有得到应有的回报,团队内曾经有不小的意见。这个模型当时是大家智慧和汗水的结晶,甚至是团队当时的精神支柱,支撑着不少兄弟们继续留在诺亚。所谓的名额限制,以及挂名了一些毫无技术贡献的人(如一些小模型实验室的人),让兄弟们何其心寒。
|
||||
|
||||
---
|
||||
|
||||
暂时平安。另外,支持我勇于说出真相的战友们 https://github.com/HW-whistleblower/True-Story-of-Pangu/issues/317
|
||||
@@ -1,158 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Document search demo with recompute mode
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Import backend packages to trigger plugin registration
|
||||
try:
|
||||
import leann_backend_diskann # noqa: F401
|
||||
import leann_backend_hnsw # noqa: F401
|
||||
|
||||
print("INFO: Backend packages imported successfully.")
|
||||
except ImportError as e:
|
||||
print(f"WARNING: Could not import backend packages. Error: {e}")
|
||||
|
||||
# Import upper-level API from leann-core
|
||||
from leann.api import LeannBuilder, LeannChat, LeannSearcher
|
||||
|
||||
|
||||
def load_sample_documents():
|
||||
"""Create sample documents for demonstration"""
|
||||
docs = [
|
||||
{
|
||||
"title": "Intro to Python",
|
||||
"content": "Python is a high-level, interpreted language known for simplicity.",
|
||||
},
|
||||
{
|
||||
"title": "ML Basics",
|
||||
"content": "Machine learning builds systems that learn from data.",
|
||||
},
|
||||
{
|
||||
"title": "Data Structures",
|
||||
"content": "Data structures like arrays, lists, and graphs organize data.",
|
||||
},
|
||||
]
|
||||
return docs
|
||||
|
||||
|
||||
def main():
|
||||
print("==========================================================")
|
||||
print("=== Leann Document Search Demo (DiskANN + Recompute) ===")
|
||||
print("==========================================================")
|
||||
|
||||
INDEX_DIR = Path("./test_indices")
|
||||
INDEX_PATH = str(INDEX_DIR / "documents.diskann")
|
||||
BACKEND_TO_TEST = "diskann"
|
||||
|
||||
if INDEX_DIR.exists():
|
||||
print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
|
||||
shutil.rmtree(INDEX_DIR)
|
||||
|
||||
# --- 1. Build index ---
|
||||
print(f"\n[PHASE 1] Building index using '{BACKEND_TO_TEST}' backend...")
|
||||
|
||||
builder = LeannBuilder(backend_name=BACKEND_TO_TEST, graph_degree=32, complexity=64)
|
||||
|
||||
documents = load_sample_documents()
|
||||
print(f"Loaded {len(documents)} sample documents.")
|
||||
for doc in documents:
|
||||
builder.add_text(doc["content"], metadata={"title": doc["title"]})
|
||||
|
||||
builder.build_index(INDEX_PATH)
|
||||
print("\nIndex built!")
|
||||
|
||||
# --- 2. Basic search demo ---
|
||||
print(f"\n[PHASE 2] Basic search using '{BACKEND_TO_TEST}' backend...")
|
||||
searcher = LeannSearcher(index_path=INDEX_PATH)
|
||||
|
||||
query = "What is machine learning?"
|
||||
print(f"\nQuery: '{query}'")
|
||||
|
||||
print("\n--- Basic search mode (PQ computation) ---")
|
||||
start_time = time.time()
|
||||
results = searcher.search(query, top_k=2)
|
||||
basic_time = time.time() - start_time
|
||||
|
||||
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
|
||||
print(">>> Basic search results <<<")
|
||||
for i, res in enumerate(results, 1):
|
||||
print(
|
||||
f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}"
|
||||
)
|
||||
|
||||
# --- 3. Recompute search demo ---
|
||||
print("\n[PHASE 3] Recompute search using embedding server...")
|
||||
|
||||
print("\n--- Recompute search mode (get real embeddings via network) ---")
|
||||
|
||||
# Configure recompute parameters
|
||||
recompute_params = {
|
||||
"recompute_beighbor_embeddings": True, # Enable network recomputation
|
||||
"USE_DEFERRED_FETCH": False, # Don't use deferred fetch
|
||||
"skip_search_reorder": True, # Skip search reordering
|
||||
"dedup_node_dis": True, # Enable node distance deduplication
|
||||
"prune_ratio": 0.1, # Pruning ratio 10%
|
||||
"batch_recompute": False, # Don't use batch recomputation
|
||||
"global_pruning": False, # Don't use global pruning
|
||||
"zmq_port": 5555, # ZMQ port
|
||||
"embedding_model": "sentence-transformers/all-mpnet-base-v2",
|
||||
}
|
||||
|
||||
print("Recompute parameter configuration:")
|
||||
for key, value in recompute_params.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
print("\n🔄 Executing Recompute search...")
|
||||
try:
|
||||
start_time = time.time()
|
||||
recompute_results = searcher.search(query, top_k=2, **recompute_params)
|
||||
recompute_time = time.time() - start_time
|
||||
|
||||
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
|
||||
print(">>> Recompute search results <<<")
|
||||
for i, res in enumerate(recompute_results, 1):
|
||||
print(
|
||||
f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}"
|
||||
)
|
||||
|
||||
# Compare results
|
||||
print("\n--- Result comparison ---")
|
||||
print(f"Basic search time: {basic_time:.3f} seconds")
|
||||
print(f"Recompute time: {recompute_time:.3f} seconds")
|
||||
|
||||
print("\nBasic search vs Recompute results:")
|
||||
for i in range(min(len(results), len(recompute_results))):
|
||||
basic_score = results[i].score
|
||||
recompute_score = recompute_results[i].score
|
||||
score_diff = abs(basic_score - recompute_score)
|
||||
print(
|
||||
f" Position {i + 1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}"
|
||||
)
|
||||
|
||||
if recompute_time > basic_time:
|
||||
print("✅ Recompute mode working correctly (more accurate but slower)")
|
||||
else:
|
||||
print("i️ Recompute time is unusually fast, network recomputation may not be enabled")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Recompute search failed: {e}")
|
||||
print("This usually indicates an embedding server connection issue")
|
||||
|
||||
# --- 4. Chat demo ---
|
||||
print("\n[PHASE 4] Starting chat session...")
|
||||
chat = LeannChat(index_path=INDEX_PATH)
|
||||
chat_response = chat.ask(query)
|
||||
print(f"You: {query}")
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
print("\n==========================================================")
|
||||
print("✅ Demo finished successfully!")
|
||||
print("==========================================================")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,362 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
try:
|
||||
import dotenv
|
||||
|
||||
dotenv.load_dotenv()
|
||||
except ModuleNotFoundError:
|
||||
# python-dotenv is not installed; skip loading environment variables
|
||||
dotenv = None
|
||||
from pathlib import Path
|
||||
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
# dotenv.load_dotenv() # handled above if python-dotenv is available
|
||||
|
||||
# Default Chrome profile path
|
||||
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||
|
||||
|
||||
def create_leann_index_from_multiple_chrome_profiles(
|
||||
profile_dirs: list[Path],
|
||||
index_path: str = "chrome_history_index.leann",
|
||||
max_count: int = -1,
|
||||
embedding_model: str = "facebook/contriever",
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
):
|
||||
"""
|
||||
Create LEANN index from multiple Chrome profile data sources.
|
||||
|
||||
Args:
|
||||
profile_dirs: List of Path objects pointing to Chrome profile directories
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of history entries to process per profile
|
||||
embedding_model: The embedding model to use
|
||||
embedding_mode: The embedding backend mode
|
||||
"""
|
||||
print("Creating LEANN index from multiple Chrome profile data sources...")
|
||||
|
||||
# Load documents using ChromeHistoryReader from history_data
|
||||
from history_data.history import ChromeHistoryReader
|
||||
|
||||
reader = ChromeHistoryReader()
|
||||
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print("--- Index directory not found, building new index ---")
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
# Process each Chrome profile directory
|
||||
for i, profile_dir in enumerate(profile_dirs):
|
||||
print(f"\nProcessing Chrome profile {i + 1}/{len(profile_dirs)}: {profile_dir}")
|
||||
|
||||
try:
|
||||
documents = reader.load_data(
|
||||
chrome_profile_path=str(profile_dir), max_count=max_count
|
||||
)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} history documents from {profile_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
|
||||
# Check if we've reached the max count
|
||||
if max_count > 0 and total_processed >= max_count:
|
||||
print(f"Reached max count of {max_count} documents")
|
||||
break
|
||||
else:
|
||||
print(f"No documents loaded from {profile_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {profile_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
# Highlight that all Chrome browsers must be closed before running this script
|
||||
print(
|
||||
"\033[91mYou need to close or quit all chrome browser before running this script\033[0m"
|
||||
)
|
||||
return None
|
||||
|
||||
print(
|
||||
f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles"
|
||||
)
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
text = node.get_content()
|
||||
# text = '[Title] ' + doc.metadata["title"] + '\n' + text
|
||||
all_texts.append(text)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
# LeannBuilder will automatically detect normalized embeddings and set appropriate distance metric
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=embedding_model,
|
||||
embedding_mode=embedding_mode,
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} history chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
def create_leann_index(
|
||||
profile_path: str | None = None,
|
||||
index_path: str = "chrome_history_index.leann",
|
||||
max_count: int = 1000,
|
||||
embedding_model: str = "facebook/contriever",
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
):
|
||||
"""
|
||||
Create LEANN index from Chrome history data.
|
||||
|
||||
Args:
|
||||
profile_path: Path to the Chrome profile directory (optional, uses default if None)
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of history entries to process
|
||||
embedding_model: The embedding model to use
|
||||
embedding_mode: The embedding backend mode
|
||||
"""
|
||||
print("Creating LEANN index from Chrome history data...")
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Load documents using ChromeHistoryReader from history_data
|
||||
from history_data.history import ChromeHistoryReader
|
||||
|
||||
reader = ChromeHistoryReader()
|
||||
|
||||
documents = reader.load_data(chrome_profile_path=profile_path, max_count=max_count)
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"Loaded {len(documents)} history documents")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
# LeannBuilder will automatically detect normalized embeddings and set appropriate distance metric
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=embedding_model,
|
||||
embedding_mode=embedding_mode,
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} history chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
async def query_leann_index(index_path: str, query: str):
|
||||
"""
|
||||
Query the LEANN index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: The query string
|
||||
"""
|
||||
print("\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=index_path)
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=10,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=32,
|
||||
beam_width=1,
|
||||
llm_config={
|
||||
"type": "openai",
|
||||
"model": "gpt-4o",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
||||
)
|
||||
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
|
||||
|
||||
async def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LEANN Chrome History Reader - Create and query browser history index"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chrome-profile",
|
||||
type=str,
|
||||
default=DEFAULT_CHROME_PROFILE,
|
||||
help=f"Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./google_history_index",
|
||||
help="Directory to store the LEANN index (default: ./chrome_history_index_leann_test)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-entries",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Maximum number of history entries to process (default: 1000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Single query to run (default: runs example queries)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--auto-find-profiles",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Automatically find all Chrome profiles (default: True)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default="facebook/contriever",
|
||||
help="The embedding model to use (e.g., 'facebook/contriever', 'text-embedding-3-small')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx"],
|
||||
help="The embedding backend mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-existing-index",
|
||||
action="store_true",
|
||||
help="Use existing index without rebuilding",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")
|
||||
|
||||
print(f"Using Chrome profile: {args.chrome_profile}")
|
||||
print(f"Index directory: {INDEX_DIR}")
|
||||
print(f"Max entries: {args.max_entries}")
|
||||
|
||||
if args.use_existing_index:
|
||||
# Use existing index without rebuilding
|
||||
if not Path(INDEX_PATH).exists():
|
||||
print(f"Error: Index file not found at {INDEX_PATH}")
|
||||
return
|
||||
print(f"Using existing index at {INDEX_PATH}")
|
||||
index_path = INDEX_PATH
|
||||
else:
|
||||
# Find Chrome profile directories
|
||||
from history_data.history import ChromeHistoryReader
|
||||
|
||||
if args.auto_find_profiles:
|
||||
profile_dirs = ChromeHistoryReader.find_chrome_profiles()
|
||||
if not profile_dirs:
|
||||
print("No Chrome profiles found automatically. Exiting.")
|
||||
return
|
||||
else:
|
||||
# Use single specified profile
|
||||
profile_path = Path(args.chrome_profile)
|
||||
if not profile_path.exists():
|
||||
print(f"Chrome profile not found: {profile_path}")
|
||||
return
|
||||
profile_dirs = [profile_path]
|
||||
|
||||
# Create or load the LEANN index from all sources
|
||||
index_path = create_leann_index_from_multiple_chrome_profiles(
|
||||
profile_dirs, INDEX_PATH, args.max_entries, args.embedding_model, args.embedding_mode
|
||||
)
|
||||
|
||||
if index_path:
|
||||
if args.query:
|
||||
# Run single query
|
||||
await query_leann_index(index_path, args.query)
|
||||
else:
|
||||
# Example queries
|
||||
queries = [
|
||||
"What websites did I visit about machine learning?",
|
||||
"Find my search history about programming",
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print("\n" + "=" * 60)
|
||||
await query_leann_index(index_path, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,342 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import dotenv
|
||||
|
||||
# Add the project root to Python path so we can import from examples
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
# Auto-detect user's mail path
|
||||
def get_mail_path():
|
||||
"""Get the mail path for the current user"""
|
||||
home_dir = os.path.expanduser("~")
|
||||
return os.path.join(home_dir, "Library", "Mail")
|
||||
|
||||
|
||||
# Default mail path for macOS
|
||||
DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
||||
|
||||
|
||||
def create_leann_index_from_multiple_sources(
|
||||
messages_dirs: list[Path],
|
||||
index_path: str = "mail_index.leann",
|
||||
max_count: int = -1,
|
||||
include_html: bool = False,
|
||||
embedding_model: str = "facebook/contriever",
|
||||
):
|
||||
"""
|
||||
Create LEANN index from multiple mail data sources.
|
||||
|
||||
Args:
|
||||
messages_dirs: List of Path objects pointing to Messages directories
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of emails to process per directory
|
||||
include_html: Whether to include HTML content in email processing
|
||||
"""
|
||||
print("Creating LEANN index from multiple mail data sources...")
|
||||
|
||||
# Load documents using EmlxReader from LEANN_email_reader
|
||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||
|
||||
reader = EmlxReader(include_html=include_html)
|
||||
# from email_data.email import EmlxMboxReader
|
||||
# from pathlib import Path
|
||||
# reader = EmlxMboxReader()
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print("--- Index directory not found, building new index ---")
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
# Process each Messages directory
|
||||
for i, messages_dir in enumerate(messages_dirs):
|
||||
print(f"\nProcessing Messages directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
|
||||
|
||||
try:
|
||||
documents = reader.load_data(messages_dir)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} email documents from {messages_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
|
||||
# Check if we've reached the max count
|
||||
if max_count > 0 and total_processed >= max_count:
|
||||
print(f"Reached max count of {max_count} documents")
|
||||
break
|
||||
else:
|
||||
print(f"No documents loaded from {messages_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {messages_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return None
|
||||
|
||||
print(
|
||||
f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks"
|
||||
)
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
text = node.get_content()
|
||||
# text = '[subject] ' + doc.metadata["subject"] + '\n' + text
|
||||
all_texts.append(text)
|
||||
|
||||
print(
|
||||
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
|
||||
)
|
||||
|
||||
# Create LEANN index directory
|
||||
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=embedding_model,
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
def create_leann_index(
|
||||
mail_path: str,
|
||||
index_path: str = "mail_index.leann",
|
||||
max_count: int = 1000,
|
||||
include_html: bool = False,
|
||||
embedding_model: str = "facebook/contriever",
|
||||
):
|
||||
"""
|
||||
Create LEANN index from mail data.
|
||||
|
||||
Args:
|
||||
mail_path: Path to the mail directory
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of emails to process
|
||||
include_html: Whether to include HTML content in email processing
|
||||
"""
|
||||
print("Creating LEANN index from mail data...")
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Load documents using EmlxReader from LEANN_email_reader
|
||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||
|
||||
reader = EmlxReader(include_html=include_html)
|
||||
# from email_data.email import EmlxMboxReader
|
||||
# from pathlib import Path
|
||||
# reader = EmlxMboxReader()
|
||||
documents = reader.load_data(Path(mail_path))
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"Loaded {len(documents)} email documents")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=embedding_model,
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
async def query_leann_index(index_path: str, query: str):
|
||||
"""
|
||||
Query the LEANN index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: The query string
|
||||
"""
|
||||
print("\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=index_path, llm_config={"type": "openai", "model": "gpt-4o"})
|
||||
|
||||
print(f"You: {query}")
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=20,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=32,
|
||||
beam_width=1,
|
||||
)
|
||||
end_time = time.time()
|
||||
# print(f"Time taken: {end_time - start_time} seconds")
|
||||
# highlight the answer
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
|
||||
|
||||
async def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description="LEANN Mail Reader - Create and query email index")
|
||||
# Remove --mail-path argument and auto-detect all Messages directories
|
||||
# Remove DEFAULT_MAIL_PATH
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./mail_index",
|
||||
help="Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-emails",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Maximum number of emails to process (-1 means all)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default="Give me some funny advertisement about apple or other companies",
|
||||
help="Single query to run (default: runs example queries)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-html",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Include HTML content in email processing (default: False)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default="facebook/contriever",
|
||||
help="Embedding model to use (default: facebook/contriever)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"args: {args}")
|
||||
|
||||
# Automatically find all Messages directories under the current user's Mail directory
|
||||
from examples.email_data.LEANN_email_reader import find_all_messages_directories
|
||||
|
||||
mail_path = get_mail_path()
|
||||
print(f"Searching for email data in: {mail_path}")
|
||||
messages_dirs = find_all_messages_directories(mail_path)
|
||||
# messages_dirs = find_all_messages_directories(DEFAULT_MAIL_PATH)
|
||||
# messages_dirs = [DEFAULT_MAIL_PATH]
|
||||
# messages_dirs = messages_dirs[:1]
|
||||
|
||||
print("len(messages_dirs): ", len(messages_dirs))
|
||||
|
||||
if not messages_dirs:
|
||||
print("No Messages directories found. Exiting.")
|
||||
return
|
||||
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
|
||||
print(f"Index directory: {INDEX_DIR}")
|
||||
print(f"Found {len(messages_dirs)} Messages directories.")
|
||||
|
||||
# Create or load the LEANN index from all sources
|
||||
index_path = create_leann_index_from_multiple_sources(
|
||||
messages_dirs,
|
||||
INDEX_PATH,
|
||||
args.max_emails,
|
||||
args.include_html,
|
||||
args.embedding_model,
|
||||
)
|
||||
|
||||
if index_path:
|
||||
if args.query:
|
||||
# Run single query
|
||||
await query_leann_index(index_path, args.query)
|
||||
else:
|
||||
# Example queries
|
||||
queries = [
|
||||
"Hows Berkeley Graduate Student Instructor",
|
||||
"how's the icloud related advertisement saying",
|
||||
"Whats the number of class recommend to take per semester for incoming EECS students",
|
||||
]
|
||||
for query in queries:
|
||||
print("\n" + "=" * 60)
|
||||
await query_leann_index(index_path, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,135 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the project root to Python path so we can import from examples
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
import torch
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
# --- EMBEDDING MODEL ---
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
|
||||
# --- END EMBEDDING MODEL ---
|
||||
# Import EmlxReader from the new module
|
||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||
|
||||
|
||||
def create_and_save_index(
|
||||
mail_path: str,
|
||||
save_dir: str = "mail_index_embedded",
|
||||
max_count: int = 1000,
|
||||
include_html: bool = False,
|
||||
):
|
||||
print("Creating index from mail data with embedded metadata...")
|
||||
documents = EmlxReader(include_html=include_html).load_data(mail_path, max_count=max_count)
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
# Use facebook/contriever as the embedder
|
||||
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
|
||||
# Move the embedding model to the best available device (CUDA, MPS, or CPU)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
embed_model._model.to("cuda")
|
||||
# set mps
|
||||
elif torch.backends.mps.is_available():
|
||||
embed_model._model.to("mps")
|
||||
else:
|
||||
embed_model._model.to("cpu")
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents, transformations=[text_splitter], embed_model=embed_model
|
||||
)
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
index.storage_context.persist(persist_dir=save_dir)
|
||||
print(f"Index saved to {save_dir}")
|
||||
return index
|
||||
|
||||
|
||||
def load_index(save_dir: str = "mail_index_embedded"):
|
||||
try:
|
||||
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||
index = VectorStoreIndex.from_vector_store(
|
||||
storage_context.vector_store, storage_context=storage_context
|
||||
)
|
||||
print(f"Index loaded from {save_dir}")
|
||||
return index
|
||||
except Exception as e:
|
||||
print(f"Error loading index: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def query_index(index, query: str):
|
||||
if index is None:
|
||||
print("No index available for querying.")
|
||||
return
|
||||
query_engine = index.as_query_engine()
|
||||
response = query_engine.query(query)
|
||||
print(f"Query: {query}")
|
||||
print(f"Response: {response}")
|
||||
|
||||
|
||||
def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LlamaIndex Mail Reader - Create and query email index"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mail-path",
|
||||
type=str,
|
||||
default="/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
|
||||
help="Path to mail data directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-dir",
|
||||
type=str,
|
||||
default="mail_index_embedded",
|
||||
help="Directory to store the index (default: mail_index_embedded)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-emails",
|
||||
type=int,
|
||||
default=10000,
|
||||
help="Maximum number of emails to process",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-html",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Include HTML content in email processing (default: False)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
mail_path = args.mail_path
|
||||
save_dir = args.save_dir
|
||||
|
||||
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
||||
print("Loading existing index...")
|
||||
index = load_index(save_dir)
|
||||
else:
|
||||
print("Creating new index...")
|
||||
index = create_and_save_index(
|
||||
mail_path,
|
||||
save_dir,
|
||||
max_count=args.max_emails,
|
||||
include_html=args.include_html,
|
||||
)
|
||||
if index:
|
||||
queries = [
|
||||
"Hows Berkeley Graduate Student Instructor",
|
||||
"how's the icloud related advertisement saying",
|
||||
"Whats the number of class recommend to take per semester for incoming EECS students",
|
||||
]
|
||||
for query in queries:
|
||||
print("\n" + "=" * 50)
|
||||
query_index(index, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,146 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
async def main(args):
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
|
||||
print("Loading documents...")
|
||||
documents = SimpleDirectoryReader(
|
||||
args.data_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
).load_data(show_progress=True)
|
||||
print("Documents loaded.")
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
if nodes:
|
||||
all_texts.extend(node.get_content() for node in nodes)
|
||||
|
||||
print("--- Index directory not found, building new index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# LeannBuilder now automatically detects normalized embeddings and sets appropriate distance metric
|
||||
print(f"Using {args.embedding_model} with {args.embedding_mode} mode")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
# distance_metric is automatically set based on embedding model
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Loaded {len(all_texts)} text chunks from documents.")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(INDEX_PATH)
|
||||
print(f"\nLeann index built at {INDEX_PATH}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
print("\n[PHASE 2] Starting Leann chat session...")
|
||||
|
||||
# Build llm_config based on command line arguments
|
||||
if args.llm == "simulated":
|
||||
llm_config = {"type": "simulated"}
|
||||
elif args.llm == "ollama":
|
||||
llm_config = {"type": "ollama", "model": args.model, "host": args.host}
|
||||
elif args.llm == "hf":
|
||||
llm_config = {"type": "hf", "model": args.model}
|
||||
elif args.llm == "openai":
|
||||
llm_config = {"type": "openai", "model": args.model}
|
||||
else:
|
||||
raise ValueError(f"Unknown LLM type: {args.llm}")
|
||||
|
||||
print(f"Using LLM: {args.llm} with model: {args.model if args.llm != 'simulated' else 'N/A'}")
|
||||
|
||||
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
||||
# query = (
|
||||
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||
# )
|
||||
query = args.query
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run Leann Chat with various LLM backends.")
|
||||
parser.add_argument(
|
||||
"--llm",
|
||||
type=str,
|
||||
default="openai",
|
||||
choices=["simulated", "ollama", "hf", "openai"],
|
||||
help="The LLM backend to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="gpt-4o",
|
||||
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default="facebook/contriever",
|
||||
help="The embedding model to use (e.g., 'facebook/contriever', 'text-embedding-3-small').",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx"],
|
||||
help="The embedding backend mode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="http://localhost:11434",
|
||||
help="The host for the Ollama API.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./test_doc_files",
|
||||
help="Directory where the Leann index will be stored.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
type=str,
|
||||
default="examples/data",
|
||||
help="Directory containing documents to index (PDF, TXT, MD files).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default="Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?",
|
||||
help="The query to ask the Leann chat system.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(main(args))
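# A compact, hedged sketch of the same build-then-chat flow without the argparse
# plumbing above (not part of the original script). The index path, sample text, and
# question are placeholders; the keyword arguments mirror the LeannBuilder/LeannChat
# calls used earlier in this file.
from leann.api import LeannBuilder, LeannChat

builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="facebook/contriever",
    embedding_mode="sentence-transformers",
    graph_degree=32,
    complexity=64,
    is_compact=True,
    is_recompute=True,
    num_threads=1,
)
builder.add_text("Example passage to index.")  # add real document chunks here
builder.build_index("./test_doc_files/pdf_documents.leann")

chat = LeannChat(
    index_path="./test_doc_files/pdf_documents.leann",
    llm_config={"type": "openai", "model": "gpt-4o"},
)
print(chat.ask("What does the indexed document say?", top_k=20,
               recompute_embeddings=True, complexity=32))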
|
||||
@@ -1,360 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Multi-Vector Aggregator for Fat Embeddings
|
||||
==========================================
|
||||
|
||||
This module implements aggregation strategies for multi-vector embeddings,
|
||||
similar to ColPali's approach where multiple patch vectors represent a single document.
|
||||
|
||||
Key features:
|
||||
- MaxSim aggregation (take maximum similarity across patches)
|
||||
- Voting-based aggregation (count patch matches)
|
||||
- Weighted aggregation (attention-score weighted)
|
||||
- Spatial clustering of matching patches
|
||||
- Document-level result consolidation
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class PatchResult:
|
||||
"""Represents a single patch search result."""
|
||||
|
||||
patch_id: int
|
||||
image_name: str
|
||||
image_path: str
|
||||
coordinates: tuple[int, int, int, int] # (x1, y1, x2, y2)
|
||||
score: float
|
||||
attention_score: float
|
||||
scale: float
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AggregatedResult:
|
||||
"""Represents an aggregated document-level result."""
|
||||
|
||||
image_name: str
|
||||
image_path: str
|
||||
doc_score: float
|
||||
patch_count: int
|
||||
best_patch: PatchResult
|
||||
all_patches: list[PatchResult]
|
||||
aggregation_method: str
|
||||
spatial_clusters: list[list[PatchResult]] | None = None
|
||||
|
||||
|
||||
class MultiVectorAggregator:
|
||||
"""
|
||||
Aggregates multiple patch-level results into document-level results.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
aggregation_method: str = "maxsim",
|
||||
spatial_clustering: bool = True,
|
||||
cluster_distance_threshold: float = 100.0,
|
||||
):
|
||||
"""
|
||||
Initialize the aggregator.
|
||||
|
||||
Args:
|
||||
aggregation_method: "maxsim", "voting", "weighted", or "mean"
|
||||
spatial_clustering: Whether to cluster spatially close patches
|
||||
cluster_distance_threshold: Distance threshold for spatial clustering
|
||||
"""
|
||||
self.aggregation_method = aggregation_method
|
||||
self.spatial_clustering = spatial_clustering
|
||||
self.cluster_distance_threshold = cluster_distance_threshold
|
||||
|
||||
def aggregate_results(
|
||||
self, search_results: list[dict[str, Any]], top_k: int = 10
|
||||
) -> list[AggregatedResult]:
|
||||
"""
|
||||
Aggregate patch-level search results into document-level results.
|
||||
|
||||
Args:
|
||||
search_results: List of search results from LeannSearcher
|
||||
top_k: Number of top documents to return
|
||||
|
||||
Returns:
|
||||
List of aggregated document results
|
||||
"""
|
||||
# Group results by image
|
||||
image_groups = defaultdict(list)
|
||||
|
||||
for result in search_results:
|
||||
metadata = result.metadata
|
||||
if "image_name" in metadata and "patch_id" in metadata:
|
||||
patch_result = PatchResult(
|
||||
patch_id=metadata["patch_id"],
|
||||
image_name=metadata["image_name"],
|
||||
image_path=metadata["image_path"],
|
||||
coordinates=tuple(metadata["coordinates"]),
|
||||
score=result.score,
|
||||
attention_score=metadata.get("attention_score", 0.0),
|
||||
scale=metadata.get("scale", 1.0),
|
||||
metadata=metadata,
|
||||
)
|
||||
image_groups[metadata["image_name"]].append(patch_result)
|
||||
|
||||
# Aggregate each image group
|
||||
aggregated_results = []
|
||||
for image_name, patches in image_groups.items():
|
||||
if len(patches) == 0:
|
||||
continue
|
||||
|
||||
agg_result = self._aggregate_image_patches(image_name, patches)
|
||||
aggregated_results.append(agg_result)
|
||||
|
||||
# Sort by aggregated score and return top-k
|
||||
aggregated_results.sort(key=lambda x: x.doc_score, reverse=True)
|
||||
return aggregated_results[:top_k]
|
||||
|
||||
def _aggregate_image_patches(
|
||||
self, image_name: str, patches: list[PatchResult]
|
||||
) -> AggregatedResult:
|
||||
"""Aggregate patches for a single image."""
|
||||
|
||||
if self.aggregation_method == "maxsim":
|
||||
doc_score = max(patch.score for patch in patches)
|
||||
best_patch = max(patches, key=lambda p: p.score)
|
||||
|
||||
elif self.aggregation_method == "voting":
|
||||
# Count patches above threshold
|
||||
threshold = np.percentile([p.score for p in patches], 75)
|
||||
doc_score = sum(1 for patch in patches if patch.score >= threshold)
|
||||
best_patch = max(patches, key=lambda p: p.score)
|
||||
|
||||
elif self.aggregation_method == "weighted":
|
||||
# Weight by attention scores
|
||||
total_weighted_score = sum(p.score * p.attention_score for p in patches)
|
||||
total_weights = sum(p.attention_score for p in patches)
|
||||
doc_score = total_weighted_score / max(total_weights, 1e-8)
|
||||
best_patch = max(patches, key=lambda p: p.score * p.attention_score)
|
||||
|
||||
elif self.aggregation_method == "mean":
|
||||
doc_score = np.mean([patch.score for patch in patches])
|
||||
best_patch = max(patches, key=lambda p: p.score)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown aggregation method: {self.aggregation_method}")
|
||||
|
||||
# Spatial clustering if enabled
|
||||
spatial_clusters = None
|
||||
if self.spatial_clustering:
|
||||
spatial_clusters = self._cluster_patches_spatially(patches)
|
||||
|
||||
return AggregatedResult(
|
||||
image_name=image_name,
|
||||
image_path=patches[0].image_path,
|
||||
doc_score=float(doc_score),
|
||||
patch_count=len(patches),
|
||||
best_patch=best_patch,
|
||||
all_patches=sorted(patches, key=lambda p: p.score, reverse=True),
|
||||
aggregation_method=self.aggregation_method,
|
||||
spatial_clusters=spatial_clusters,
|
||||
)
|
||||
|
||||
def _cluster_patches_spatially(self, patches: list[PatchResult]) -> list[list[PatchResult]]:
|
||||
"""Cluster patches that are spatially close to each other."""
|
||||
if len(patches) <= 1:
|
||||
return [patches]
|
||||
|
||||
clusters = []
|
||||
remaining_patches = patches.copy()
|
||||
|
||||
while remaining_patches:
|
||||
# Start new cluster with highest scoring remaining patch
|
||||
seed_patch = max(remaining_patches, key=lambda p: p.score)
|
||||
current_cluster = [seed_patch]
|
||||
remaining_patches.remove(seed_patch)
|
||||
|
||||
# Add nearby patches to cluster
|
||||
added_to_cluster = True
|
||||
while added_to_cluster:
|
||||
added_to_cluster = False
|
||||
for patch in remaining_patches.copy():
|
||||
if self._is_patch_nearby(patch, current_cluster):
|
||||
current_cluster.append(patch)
|
||||
remaining_patches.remove(patch)
|
||||
added_to_cluster = True
|
||||
|
||||
clusters.append(current_cluster)
|
||||
|
||||
return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True)
|
||||
|
||||
def _is_patch_nearby(self, patch: PatchResult, cluster: list[PatchResult]) -> bool:
|
||||
"""Check if a patch is spatially close to any patch in the cluster."""
|
||||
patch_center = self._get_patch_center(patch.coordinates)
|
||||
|
||||
for cluster_patch in cluster:
|
||||
cluster_center = self._get_patch_center(cluster_patch.coordinates)
|
||||
distance = np.sqrt(
|
||||
(patch_center[0] - cluster_center[0]) ** 2
|
||||
+ (patch_center[1] - cluster_center[1]) ** 2
|
||||
)
|
||||
|
||||
if distance <= self.cluster_distance_threshold:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _get_patch_center(self, coordinates: tuple[int, int, int, int]) -> tuple[float, float]:
|
||||
"""Get center point of a patch."""
|
||||
x1, y1, x2, y2 = coordinates
|
||||
return ((x1 + x2) / 2, (y1 + y2) / 2)
|
||||
|
||||
def print_aggregated_results(
|
||||
self, results: list[AggregatedResult], max_patches_per_doc: int = 3
|
||||
):
|
||||
"""Pretty print aggregated results."""
|
||||
print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})")
|
||||
print("=" * 80)
|
||||
|
||||
for i, result in enumerate(results):
|
||||
print(f"\n{i + 1}. {result.image_name}")
|
||||
print(f" Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}")
|
||||
print(f" Path: {result.image_path}")
|
||||
|
||||
# Show best patch
|
||||
best = result.best_patch
|
||||
print(
|
||||
f" 🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})"
|
||||
)
|
||||
|
||||
# Show top patches
|
||||
print(" 📍 Top Patches:")
|
||||
for j, patch in enumerate(result.all_patches[:max_patches_per_doc]):
|
||||
print(
|
||||
f" {j + 1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}"
|
||||
)
|
||||
|
||||
# Show spatial clusters if available
|
||||
if result.spatial_clusters and len(result.spatial_clusters) > 1:
|
||||
print(f" 🗂️ Spatial Clusters: {len(result.spatial_clusters)}")
|
||||
for j, cluster in enumerate(result.spatial_clusters[:2]): # Show top 2 clusters
|
||||
cluster_score = max(p.score for p in cluster)
|
||||
print(
|
||||
f" Cluster {j + 1}: {len(cluster)} patches (best: {cluster_score:.4f})"
|
||||
)
|
||||
|
||||
|
||||
def demo_aggregation():
|
||||
"""Demonstrate the multi-vector aggregation functionality."""
|
||||
print("=== Multi-Vector Aggregation Demo ===")
|
||||
|
||||
# Simulate some patch-level search results
|
||||
# In real usage, these would come from LeannSearcher.search()
|
||||
|
||||
class MockResult:
|
||||
def __init__(self, score, metadata):
|
||||
self.score = score
|
||||
self.metadata = metadata
|
||||
|
||||
# Simulate results for 2 images with multiple patches each
|
||||
mock_results = [
|
||||
# Image 1: cats_and_kitchen.jpg - 4 patches
|
||||
MockResult(
|
||||
0.85,
|
||||
{
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 3,
|
||||
"coordinates": [100, 50, 224, 174], # Kitchen area
|
||||
"attention_score": 0.92,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
MockResult(
|
||||
0.78,
|
||||
{
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 7,
|
||||
"coordinates": [200, 300, 324, 424], # Cat area
|
||||
"attention_score": 0.88,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
MockResult(
|
||||
0.72,
|
||||
{
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 12,
|
||||
"coordinates": [150, 100, 274, 224], # Appliances
|
||||
"attention_score": 0.75,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
MockResult(
|
||||
0.65,
|
||||
{
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 15,
|
||||
"coordinates": [50, 250, 174, 374], # Furniture
|
||||
"attention_score": 0.70,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
# Image 2: city_street.jpg - 3 patches
|
||||
MockResult(
|
||||
0.68,
|
||||
{
|
||||
"image_name": "city_street.jpg",
|
||||
"image_path": "/path/to/city_street.jpg",
|
||||
"patch_id": 2,
|
||||
"coordinates": [300, 100, 424, 224], # Buildings
|
||||
"attention_score": 0.80,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
MockResult(
|
||||
0.62,
|
||||
{
|
||||
"image_name": "city_street.jpg",
|
||||
"image_path": "/path/to/city_street.jpg",
|
||||
"patch_id": 8,
|
||||
"coordinates": [100, 350, 224, 474], # Street level
|
||||
"attention_score": 0.75,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
MockResult(
|
||||
0.55,
|
||||
{
|
||||
"image_name": "city_street.jpg",
|
||||
"image_path": "/path/to/city_street.jpg",
|
||||
"patch_id": 11,
|
||||
"coordinates": [400, 200, 524, 324], # Sky area
|
||||
"attention_score": 0.60,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
# Test different aggregation methods
|
||||
methods = ["maxsim", "voting", "weighted", "mean"]
|
||||
|
||||
for method in methods:
|
||||
print(f"\n{'=' * 20} {method.upper()} AGGREGATION {'=' * 20}")
|
||||
|
||||
aggregator = MultiVectorAggregator(
|
||||
aggregation_method=method,
|
||||
spatial_clustering=True,
|
||||
cluster_distance_threshold=100.0,
|
||||
)
|
||||
|
||||
aggregated = aggregator.aggregate_results(mock_results, top_k=5)
|
||||
aggregator.print_aggregated_results(aggregated)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo_aggregation()
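# A minimal usage sketch tying the aggregator to a real search call (not part of the
# demo above). It assumes a patch-level LEANN index already exists at the path shown
# and that it was built with per-patch metadata (image_name, patch_id, coordinates);
# the path, query, and top_k values are illustrative only.
from leann.api import LeannSearcher

searcher = LeannSearcher("./image_patch_index.leann")  # assumed index path
patch_hits = searcher.search("a cat sitting in a kitchen", top_k=50)  # patch-level hits

aggregator = MultiVectorAggregator(aggregation_method="maxsim", spatial_clustering=True)
doc_results = aggregator.aggregate_results(patch_hits, top_k=5)  # document-level results
aggregator.print_aggregated_results(doc_results)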
|
||||
@@ -1,113 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
OpenAI Embedding Example
|
||||
|
||||
Complete example showing how to build and search with OpenAI embeddings using HNSW backend.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
# Load environment variables
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
def main():
|
||||
# Check if OpenAI API key is available
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
print("ERROR: OPENAI_API_KEY environment variable not set")
|
||||
return False
|
||||
|
||||
print(f"✅ OpenAI API key found: {api_key[:10]}...")
|
||||
|
||||
# Sample texts
|
||||
sample_texts = [
|
||||
"Machine learning is a powerful technology that enables computers to learn from data.",
|
||||
"Natural language processing helps computers understand and generate human language.",
|
||||
"Deep learning uses neural networks with multiple layers to solve complex problems.",
|
||||
"Computer vision allows machines to interpret and understand visual information.",
|
||||
"Reinforcement learning trains agents to make decisions through trial and error.",
|
||||
"Data science combines statistics, math, and programming to extract insights from data.",
|
||||
"Artificial intelligence aims to create machines that can perform human-like tasks.",
|
||||
"Python is a popular programming language used extensively in data science and AI.",
|
||||
"Neural networks are inspired by the structure and function of the human brain.",
|
||||
"Big data refers to extremely large datasets that require special tools to process.",
|
||||
]
|
||||
|
||||
INDEX_DIR = Path("./simple_openai_test_index")
|
||||
INDEX_PATH = str(INDEX_DIR / "simple_test.leann")
|
||||
|
||||
print("\n=== Building Index with OpenAI Embeddings ===")
|
||||
print(f"Index path: {INDEX_PATH}")
|
||||
|
||||
try:
|
||||
# Use proper configuration for OpenAI embeddings
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
# HNSW settings for OpenAI embeddings
|
||||
M=16, # Smaller graph degree
|
||||
efConstruction=64, # Smaller construction complexity
|
||||
is_compact=True, # Enable compact storage for recompute
|
||||
is_recompute=True, # MUST enable for OpenAI embeddings
|
||||
num_threads=1,
|
||||
)
|
||||
|
||||
print(f"Adding {len(sample_texts)} texts to the index...")
|
||||
for i, text in enumerate(sample_texts):
|
||||
metadata = {"id": f"doc_{i}", "topic": "AI"}
|
||||
builder.add_text(text, metadata)
|
||||
|
||||
print("Building index...")
|
||||
builder.build_index(INDEX_PATH)
|
||||
print("✅ Index built successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error building index: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
print("\n=== Testing Search ===")
|
||||
|
||||
try:
|
||||
searcher = LeannSearcher(INDEX_PATH)
|
||||
|
||||
test_queries = [
|
||||
"What is machine learning?",
|
||||
"How do neural networks work?",
|
||||
"Programming languages for data science",
|
||||
]
|
||||
|
||||
for query in test_queries:
|
||||
print(f"\n🔍 Query: '{query}'")
|
||||
results = searcher.search(query, top_k=3)
|
||||
|
||||
print(f" Found {len(results)} results:")
|
||||
for i, result in enumerate(results):
|
||||
print(f" {i + 1}. Score: {result.score:.4f}")
|
||||
print(f" Text: {result.text[:80]}...")
|
||||
|
||||
print("\n✅ Search test completed successfully!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error during search: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
if success:
|
||||
print("\n🎉 Simple OpenAI index test completed successfully!")
|
||||
else:
|
||||
print("\n💥 Simple OpenAI index test failed!")
|
||||
@@ -1,23 +0,0 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from leann.api import LeannChat
|
||||
|
||||
INDEX_DIR = Path("./test_pdf_index_huawei")
|
||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||
|
||||
|
||||
async def main():
|
||||
print("\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=INDEX_PATH)
|
||||
query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
|
||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
||||
# query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||
response = chat.ask(
|
||||
query, top_k=20, recompute_beighbor_embeddings=True, complexity=32, beam_width=1
|
||||
)
|
||||
print(f"\n[PHASE 2] Response: {response}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,320 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
# Default WeChat export directory
|
||||
DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
|
||||
|
||||
|
||||
def create_leann_index_from_multiple_wechat_exports(
|
||||
export_dirs: list[Path],
|
||||
index_path: str = "wechat_history_index.leann",
|
||||
max_count: int = -1,
|
||||
):
|
||||
"""
|
||||
Create LEANN index from multiple WeChat export data sources.
|
||||
|
||||
Args:
|
||||
export_dirs: List of Path objects pointing to WeChat export directories
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of chat entries to process per export
|
||||
"""
|
||||
print("Creating LEANN index from multiple WeChat export data sources...")
|
||||
|
||||
# Load documents using WeChatHistoryReader from history_data
|
||||
from history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print("--- Index directory not found, building new index ---")
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
# Process each WeChat export directory
|
||||
for i, export_dir in enumerate(export_dirs):
|
||||
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
|
||||
|
||||
try:
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=str(export_dir),
|
||||
max_count=max_count,
|
||||
concatenate_messages=True,  # Concatenate related messages into combined documents
|
||||
)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
|
||||
# Check if we've reached the max count
|
||||
if max_count > 0 and total_processed >= max_count:
|
||||
print(f"Reached max count of {max_count} documents")
|
||||
break
|
||||
else:
|
||||
print(f"No documents loaded from {export_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return None
|
||||
|
||||
print(
|
||||
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports and starting to split them into chunks"
|
||||
)
|
||||
|
||||
# Create text splitter with 192-token chunks and 64-token overlap
|
||||
text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=64)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
text = (
|
||||
"[Contact] means the message is from: "
|
||||
+ doc.metadata["contact_name"]
|
||||
+ "\n"
|
||||
+ node.get_content()
|
||||
)
|
||||
all_texts.append(text)
|
||||
|
||||
print(
|
||||
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
|
||||
)
|
||||
|
||||
# Create LEANN index directory
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="Qwen/Qwen3-Embedding-0.6B",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} chat chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
def create_leann_index(
|
||||
export_dir: str | None = None,
|
||||
index_path: str = "wechat_history_index.leann",
|
||||
max_count: int = 1000,
|
||||
):
|
||||
"""
|
||||
Create LEANN index from WeChat chat history data.
|
||||
|
||||
Args:
|
||||
export_dir: Path to the WeChat export directory (optional, uses default if None)
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of chat entries to process
|
||||
"""
|
||||
print("Creating LEANN index from WeChat chat history data...")
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Load documents using WeChatHistoryReader from history_data
|
||||
from history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=export_dir,
|
||||
max_count=max_count,
|
||||
concatenate_messages=False, # Disable concatenation - one message per document
|
||||
)
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"Loaded {len(documents)} chat documents")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ", # MLX-optimized model
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} chat chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
async def query_leann_index(index_path: str, query: str):
|
||||
"""
|
||||
Query the LEANN index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: The query string
|
||||
"""
|
||||
print("\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=index_path)
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=20,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=16,
|
||||
beam_width=1,
|
||||
llm_config={
|
||||
"type": "openai",
|
||||
"model": "gpt-4o",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
||||
)
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function with integrated WeChat export functionality."""
|
||||
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LEANN WeChat History Reader - Create and query WeChat chat history index"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export-dir",
|
||||
type=str,
|
||||
default=DEFAULT_WECHAT_EXPORT_DIR,
|
||||
help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./wechat_history_magic_test_11Debug_new",
|
||||
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-entries",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Maximum number of chat entries to process (default: 5000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Single query to run (default: runs example queries)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-export",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Force re-export of WeChat data even if exports exist",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "wechat_history.leann")
|
||||
|
||||
print(f"Using WeChat export directory: {args.export_dir}")
|
||||
print(f"Index directory: {INDEX_DIR}")
|
||||
print(f"Max entries: {args.max_entries}")
|
||||
|
||||
# Initialize WeChat reader with export capabilities
|
||||
from history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
# Find existing exports or create new ones using the centralized method
|
||||
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
||||
if not export_dirs:
|
||||
print("Failed to find or export WeChat data. Exiting.")
|
||||
return
|
||||
|
||||
# Create or load the LEANN index from all sources
|
||||
index_path = create_leann_index_from_multiple_wechat_exports(
|
||||
export_dirs, INDEX_PATH, max_count=args.max_entries
|
||||
)
|
||||
|
||||
if index_path:
|
||||
if args.query:
|
||||
# Run single query
|
||||
await query_leann_index(index_path, args.query)
|
||||
else:
|
||||
# Example queries
|
||||
queries = [
|
||||
"我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print("\n" + "=" * 60)
|
||||
await query_leann_index(index_path, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
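# A hedged sketch of querying an already-built WeChat index directly, skipping the
# export/build phase above (not part of the original script). The index path and the
# question are placeholders; the ask() arguments mirror query_leann_index() earlier
# in this file.
import os
from leann.api import LeannChat

chat = LeannChat(index_path="./wechat_history_magic_test_11Debug_new/wechat_history.leann")
answer = chat.ask(
    "Which chats mention buying a jersey?",
    top_k=20,
    complexity=16,
    llm_config={"type": "openai", "model": "gpt-4o", "api_key": os.getenv("OPENAI_API_KEY")},
)
print(answer)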
|
||||
@@ -459,7 +459,14 @@ class LeannSearcher:
|
||||
|
||||
self.meta_path_str = f"{index_path}.meta.json"
|
||||
if not Path(self.meta_path_str).exists():
|
||||
raise FileNotFoundError(f"Leann metadata file not found at {self.meta_path_str}")
|
||||
parent_dir = Path(index_path).parent
|
||||
print(
|
||||
f"Leann metadata file not found at {self.meta_path_str}, and you may need to rm -rf {parent_dir}"
|
||||
)
|
||||
# Highlight the FileNotFoundError message in red so it stands out in the console
|
||||
raise FileNotFoundError(
|
||||
f"Leann metadata file not found at {self.meta_path_str}, \033[91m you may need to rm -rf {parent_dir}\033[0m"
|
||||
)
|
||||
with open(self.meta_path_str, encoding="utf-8") as f:
|
||||
self.meta_data = json.load(f)
|
||||
backend_name = self.meta_data["backend_name"]
|
||||
@@ -493,6 +500,16 @@ class LeannSearcher:
|
||||
logger.info(f" Top_k: {top_k}")
|
||||
logger.info(f" Additional kwargs: {kwargs}")
|
||||
|
||||
# Smart top_k detection and adjustment
|
||||
total_docs = len(self.passage_manager.global_offset_map)
|
||||
original_top_k = top_k
|
||||
if top_k > total_docs:
|
||||
top_k = total_docs
|
||||
logger.warning(
|
||||
f" ⚠️ Requested top_k ({original_top_k}) exceeds total documents ({total_docs})"
|
||||
)
|
||||
logger.warning(f" ✅ Auto-adjusted top_k to {top_k} to match available documents")
|
||||
|
||||
zmq_port = None
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
@@ -134,6 +134,14 @@ dev = [
|
||||
"ruff>=0.12.4",
|
||||
]
|
||||
|
||||
[tool.lychee]
|
||||
accept = ["200", "403", "429", "503"]
|
||||
timeout = 20
|
||||
max_retries = 2
|
||||
exclude = ["localhost", "127.0.0.1", "example.com"]
|
||||
exclude_path = [".git/", ".venv/", "__pycache__/", "third_party/"]
|
||||
scheme = ["https", "http"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py"]
|
||||
|
||||
@@ -1,161 +0,0 @@
|
||||
import email
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document, VectorStoreIndex
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class EmlxReader(BaseReader):
|
||||
"""
|
||||
Apple Mail .emlx file reader.
|
||||
|
||||
Reads individual .emlx files from Apple Mail's storage format.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
pass
|
||||
|
||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load data from the input directory containing .emlx files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing .emlx files
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of messages to read.
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
count = 0
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
# Skip hidden directories
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
if count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx files have a length prefix followed by the email content
|
||||
# The first line contains the length, followed by the email
|
||||
lines = content.split("\n", 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1]
|
||||
|
||||
# Parse the email using Python's email module
|
||||
try:
|
||||
msg = email.message_from_string(email_content)
|
||||
|
||||
# Extract email metadata
|
||||
subject = msg.get("Subject", "No Subject")
|
||||
from_addr = msg.get("From", "Unknown")
|
||||
to_addr = msg.get("To", "Unknown")
|
||||
date = msg.get("Date", "Unknown")
|
||||
|
||||
# Extract email body
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if (
|
||||
part.get_content_type() == "text/plain"
|
||||
or part.get_content_type() == "text/html"
|
||||
):
|
||||
body += part.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
# break
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
|
||||
# Create document content
|
||||
doc_content = f"""
|
||||
From: {from_addr}
|
||||
To: {to_addr}
|
||||
Subject: {subject}
|
||||
Date: {date}
|
||||
|
||||
{body}
|
||||
"""
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"file_path": filepath,
|
||||
"subject": subject,
|
||||
"from": from_addr,
|
||||
"to": to_addr,
|
||||
"date": date,
|
||||
"filename": filename,
|
||||
}
|
||||
if count == 0:
|
||||
print("--------------------------------")
|
||||
print("dir path", dirpath)
|
||||
print(metadata)
|
||||
print(doc_content)
|
||||
print("--------------------------------")
|
||||
body = []
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
print(
|
||||
"-------------------------------- get content type -------------------------------"
|
||||
)
|
||||
print(part.get_content_type())
|
||||
print(part)
|
||||
# body.append(part.get_payload(decode=True).decode('utf-8', errors='ignore'))
|
||||
print(
|
||||
"-------------------------------- get content type -------------------------------"
|
||||
)
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
print(body)
|
||||
|
||||
print(body)
|
||||
print("--------------------------------")
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"!!!!!!! Error parsing email from {filepath}: {e} !!!!!!!!")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"!!!!!!! Error reading file !!!!!!!! {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Loaded {len(docs)} email documents")
|
||||
return docs
|
||||
|
||||
|
||||
# Use the custom EmlxReader instead of MboxReader
|
||||
documents = EmlxReader().load_data(
|
||||
"/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
|
||||
max_count=1000,
|
||||
) # Returns list of documents
|
||||
|
||||
# Configure the index with larger chunk size to handle long metadata
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
# Create a custom text splitter with larger chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=200)
|
||||
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents, transformations=[text_splitter]
|
||||
) # Initialize index with documents
|
||||
|
||||
query_engine = index.as_query_engine()
|
||||
res = query_engine.query("Hows Berkeley Graduate Student Instructor")
|
||||
print(res)
|
||||
@@ -1,219 +0,0 @@
|
||||
import email
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document, StorageContext, VectorStoreIndex
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class EmlxReader(BaseReader):
|
||||
"""
|
||||
Apple Mail .emlx file reader.
|
||||
|
||||
Reads individual .emlx files from Apple Mail's storage format.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
pass
|
||||
|
||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load data from the input directory containing .emlx files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing .emlx files
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of messages to read.
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
count = 0
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
# Skip hidden directories
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
if count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx files have a length prefix followed by the email content
|
||||
# The first line contains the length, followed by the email
|
||||
lines = content.split("\n", 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1]
|
||||
|
||||
# Parse the email using Python's email module
|
||||
try:
|
||||
msg = email.message_from_string(email_content)
|
||||
|
||||
# Extract email metadata
|
||||
subject = msg.get("Subject", "No Subject")
|
||||
from_addr = msg.get("From", "Unknown")
|
||||
to_addr = msg.get("To", "Unknown")
|
||||
date = msg.get("Date", "Unknown")
|
||||
|
||||
# Extract email body
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if part.get_content_type() == "text/plain":
|
||||
body = part.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
break
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
|
||||
# Create document content
|
||||
doc_content = f"""
|
||||
From: {from_addr}
|
||||
To: {to_addr}
|
||||
Subject: {subject}
|
||||
Date: {date}
|
||||
|
||||
{body}
|
||||
"""
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"file_path": filepath,
|
||||
"subject": subject,
|
||||
"from": from_addr,
|
||||
"to": to_addr,
|
||||
"date": date,
|
||||
"filename": filename,
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Loaded {len(docs)} email documents")
|
||||
return docs
|
||||
|
||||
|
||||
def create_and_save_index(mail_path: str, save_dir: str = "mail_index", max_count: int = 1000):
|
||||
"""
|
||||
Create the index from mail data and save it to disk.
|
||||
|
||||
Args:
|
||||
mail_path: Path to the mail directory
|
||||
save_dir: Directory to save the index
|
||||
max_count: Maximum number of emails to process
|
||||
"""
|
||||
print("Creating index from mail data...")
|
||||
|
||||
# Load documents
|
||||
documents = EmlxReader().load_data(mail_path, max_count=max_count)
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
# Create text splitter
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=0)
|
||||
|
||||
# Create index
|
||||
index = VectorStoreIndex.from_documents(documents, transformations=[text_splitter])
|
||||
|
||||
# Save the index
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
index.storage_context.persist(persist_dir=save_dir)
|
||||
print(f"Index saved to {save_dir}")
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def load_index(save_dir: str = "mail_index"):
|
||||
"""
|
||||
Load the saved index from disk.
|
||||
|
||||
Args:
|
||||
save_dir: Directory where the index is saved
|
||||
|
||||
Returns:
|
||||
Loaded index or None if loading fails
|
||||
"""
|
||||
try:
|
||||
# Load storage context
|
||||
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||
|
||||
# Load index
|
||||
index = VectorStoreIndex.from_vector_store(
|
||||
storage_context.vector_store, storage_context=storage_context
|
||||
)
|
||||
|
||||
print(f"Index loaded from {save_dir}")
|
||||
return index
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading index: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def query_index(index, query: str):
|
||||
"""
|
||||
Query the loaded index.
|
||||
|
||||
Args:
|
||||
index: The loaded index
|
||||
query: The query string
|
||||
"""
|
||||
if index is None:
|
||||
print("No index available for querying.")
|
||||
return
|
||||
|
||||
query_engine = index.as_query_engine()
|
||||
response = query_engine.query(query)
|
||||
print(f"Query: {query}")
|
||||
print(f"Response: {response}")
|
||||
|
||||
|
||||
def main():
|
||||
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
|
||||
save_dir = "mail_index"
|
||||
|
||||
# Check if index already exists
|
||||
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
||||
print("Loading existing index...")
|
||||
index = load_index(save_dir)
|
||||
else:
|
||||
print("Creating new index...")
|
||||
index = create_and_save_index(mail_path, save_dir, max_count=1000)
|
||||
|
||||
if index:
|
||||
# Example queries
|
||||
queries = [
|
||||
"Hows Berkeley Graduate Student Instructor",
|
||||
"What emails mention GSR appointments?",
|
||||
"Find emails about deadlines",
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print("\n" + "=" * 50)
|
||||
query_index(index, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,219 +0,0 @@
|
||||
import email
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document, StorageContext, VectorStoreIndex
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class EmlxReader(BaseReader):
|
||||
"""
|
||||
Apple Mail .emlx file reader with reduced metadata.
|
||||
|
||||
Reads individual .emlx files from Apple Mail's storage format.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
pass
|
||||
|
||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load data from the input directory containing .emlx files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing .emlx files
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of messages to read.
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
count = 0
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
# Skip hidden directories
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
if count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx files have a length prefix followed by the email content
|
||||
# The first line contains the length, followed by the email
|
||||
lines = content.split("\n", 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1]
|
||||
|
||||
# Parse the email using Python's email module
|
||||
try:
|
||||
msg = email.message_from_string(email_content)
|
||||
|
||||
# Extract email metadata
|
||||
subject = msg.get("Subject", "No Subject")
|
||||
from_addr = msg.get("From", "Unknown")
|
||||
to_addr = msg.get("To", "Unknown")
|
||||
date = msg.get("Date", "Unknown")
|
||||
|
||||
# Extract email body
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if part.get_content_type() == "text/plain":
|
||||
body = part.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
break
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
From: {from_addr}
|
||||
To: {to_addr}
|
||||
Subject: {subject}
|
||||
Date: {date}
|
||||
|
||||
{body}
|
||||
"""
|
||||
|
||||
# Create minimal metadata (only essential info)
|
||||
metadata = {
|
||||
"subject": subject[:50], # Truncate subject
|
||||
"from": from_addr[:30], # Truncate from
|
||||
"date": date[:20], # Truncate date
|
||||
"filename": filename, # Keep filename
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Loaded {len(docs)} email documents")
|
||||
return docs
|
||||
|
||||
|
||||
def create_and_save_index(
|
||||
mail_path: str, save_dir: str = "mail_index_small", max_count: int = 1000
|
||||
):
|
||||
"""
|
||||
Create the index from mail data and save it to disk.
|
||||
|
||||
Args:
|
||||
mail_path: Path to the mail directory
|
||||
save_dir: Directory to save the index
|
||||
max_count: Maximum number of emails to process
|
||||
"""
|
||||
print("Creating index from mail data with small chunks...")
|
||||
|
||||
# Load documents
|
||||
documents = EmlxReader().load_data(mail_path, max_count=max_count)
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
# Create text splitter with small chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=512, chunk_overlap=50)
|
||||
|
||||
# Create index
|
||||
index = VectorStoreIndex.from_documents(documents, transformations=[text_splitter])
|
||||
|
||||
# Save the index
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
index.storage_context.persist(persist_dir=save_dir)
|
||||
print(f"Index saved to {save_dir}")
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def load_index(save_dir: str = "mail_index_small"):
|
||||
"""
|
||||
Load the saved index from disk.
|
||||
|
||||
Args:
|
||||
save_dir: Directory where the index is saved
|
||||
|
||||
Returns:
|
||||
Loaded index or None if loading fails
|
||||
"""
|
||||
try:
|
||||
# Load storage context
|
||||
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||
|
||||
# Load index
|
||||
index = VectorStoreIndex.from_vector_store(
|
||||
storage_context.vector_store, storage_context=storage_context
|
||||
)
|
||||
|
||||
print(f"Index loaded from {save_dir}")
|
||||
return index
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading index: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def query_index(index, query: str):
|
||||
"""
|
||||
Query the loaded index.
|
||||
|
||||
Args:
|
||||
index: The loaded index
|
||||
query: The query string
|
||||
"""
|
||||
if index is None:
|
||||
print("No index available for querying.")
|
||||
return
|
||||
|
||||
query_engine = index.as_query_engine()
|
||||
response = query_engine.query(query)
|
||||
print(f"Query: {query}")
|
||||
print(f"Response: {response}")
|
||||
|
||||
|
||||
def main():
|
||||
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
|
||||
save_dir = "mail_index_small"
|
||||
|
||||
# Check if index already exists
|
||||
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
||||
print("Loading existing index...")
|
||||
index = load_index(save_dir)
|
||||
else:
|
||||
print("Creating new index...")
|
||||
index = create_and_save_index(mail_path, save_dir, max_count=1000)
|
||||
|
||||
if index:
|
||||
# Example queries
|
||||
queries = [
|
||||
"Hows Berkeley Graduate Student Instructor",
|
||||
"What emails mention GSR appointments?",
|
||||
"Find emails about deadlines",
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print("\n" + "=" * 50)
|
||||
query_index(index, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,154 +0,0 @@
|
||||
import email
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document, VectorStoreIndex
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class EmlxReader(BaseReader):
|
||||
"""
|
||||
Apple Mail .emlx file reader.
|
||||
|
||||
Reads individual .emlx files from Apple Mail's storage format.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
pass
|
||||
|
||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load data from the input directory containing .emlx files.
|
||||
|
||||
Args:
|
||||
            input_dir: Directory containing .emlx files
            **load_kwargs:
                max_count (int): Maximum amount of messages to read.
        """
        docs: list[Document] = []
        max_count = load_kwargs.get("max_count", 1000)
        count = 0

        # Check if directory exists and is accessible
        if not os.path.exists(input_dir):
            print(f"Error: Directory '{input_dir}' does not exist")
            return docs

        if not os.access(input_dir, os.R_OK):
            print(f"Error: Directory '{input_dir}' is not accessible (permission denied)")
            print("This is likely due to macOS security restrictions on Mail app data")
            return docs

        print(f"Scanning directory: {input_dir}")

        # Walk through the directory recursively
        for dirpath, dirnames, filenames in os.walk(input_dir):
            # Skip hidden directories
            dirnames[:] = [d for d in dirnames if not d.startswith(".")]

            for filename in filenames:
                if count >= max_count:
                    break

                if filename.endswith(".emlx"):
                    filepath = os.path.join(dirpath, filename)
                    print(f"Found .emlx file: {filepath}")
                    try:
                        # Read the .emlx file
                        with open(filepath, encoding="utf-8", errors="ignore") as f:
                            content = f.read()

                        # .emlx files have a length prefix followed by the email content
                        # The first line contains the length, followed by the email
                        lines = content.split("\n", 1)
                        if len(lines) >= 2:
                            email_content = lines[1]

                            # Parse the email using Python's email module
                            try:
                                msg = email.message_from_string(email_content)

                                # Extract email metadata
                                subject = msg.get("Subject", "No Subject")
                                from_addr = msg.get("From", "Unknown")
                                to_addr = msg.get("To", "Unknown")
                                date = msg.get("Date", "Unknown")

                                # Extract email body
                                body = ""
                                if msg.is_multipart():
                                    for part in msg.walk():
                                        if part.get_content_type() == "text/plain":
                                            body = part.get_payload(decode=True).decode(
                                                "utf-8", errors="ignore"
                                            )
                                            break
                                else:
                                    body = msg.get_payload(decode=True).decode(
                                        "utf-8", errors="ignore"
                                    )

                                # Create document content
                                doc_content = f"""
From: {from_addr}
To: {to_addr}
Subject: {subject}
Date: {date}

{body}
"""

                                # Create metadata
                                metadata = {
                                    "file_path": filepath,
                                    "subject": subject,
                                    "from": from_addr,
                                    "to": to_addr,
                                    "date": date,
                                    "filename": filename,
                                }

                                doc = Document(text=doc_content, metadata=metadata)
                                docs.append(doc)
                                count += 1

                            except Exception as e:
                                print(f"Error parsing email from {filepath}: {e}")
                                continue

                    except Exception as e:
                        print(f"Error reading file {filepath}: {e}")
                        continue

        print(f"Loaded {len(docs)} email documents")
        return docs


def main():
    # Use the current directory where the sample.emlx file is located
    current_dir = os.path.dirname(os.path.abspath(__file__))

    print("Testing EmlxReader with sample .emlx file...")
    print(f"Scanning directory: {current_dir}")

    # Use the custom EmlxReader
    documents = EmlxReader().load_data(current_dir, max_count=1000)

    if not documents:
        print("No documents loaded. Make sure sample.emlx exists in the examples directory.")
        return

    print(f"\nSuccessfully loaded {len(documents)} document(s)")

    # Initialize index with documents
    index = VectorStoreIndex.from_documents(documents)
    query_engine = index.as_query_engine()

    print("\nTesting query: 'Hows Berkeley Graduate Student Instructor'")
    res = query_engine.query("Hows Berkeley Graduate Student Instructor")
    print(f"Response: {res}")


if __name__ == "__main__":
    main()
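Note: the two comments above about the length prefix are the only .emlx-specific logic in the reader. As a standalone illustration (not part of this diff; `sample.emlx` is a hypothetical path), an Apple Mail `.emlx` file is the raw RFC 822 message preceded by a one-line byte count and typically followed by an XML plist of message flags:

```python
import email

# First line is the declared message length; the remainder is the message itself,
# which is what EmlxReader.load_data() hands to the email module.
with open("sample.emlx", encoding="utf-8", errors="ignore") as f:
    length_line, rest = f.read().split("\n", 1)

print(f"Declared message length: {int(length_line.strip())} bytes")
msg = email.message_from_string(rest)
print(msg.get("Subject", "No Subject"))
```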
@@ -1,105 +0,0 @@
import os

from llama_index.core import StorageContext, VectorStoreIndex


def load_index(save_dir: str = "mail_index"):
    """
    Load the saved index from disk.

    Args:
        save_dir: Directory where the index is saved

    Returns:
        Loaded index or None if loading fails
    """
    try:
        # Load storage context
        storage_context = StorageContext.from_defaults(persist_dir=save_dir)

        # Load index
        index = VectorStoreIndex.from_vector_store(
            storage_context.vector_store, storage_context=storage_context
        )

        print(f"Index loaded from {save_dir}")
        return index

    except Exception as e:
        print(f"Error loading index: {e}")
        return None


def query_index(index, query: str):
    """
    Query the loaded index.

    Args:
        index: The loaded index
        query: The query string
    """
    if index is None:
        print("No index available for querying.")
        return

    query_engine = index.as_query_engine()
    response = query_engine.query(query)
    print(f"\nQuery: {query}")
    print(f"Response: {response}")


def main():
    save_dir = "mail_index"

    # Check if index exists
    if not os.path.exists(save_dir) or not os.path.exists(
        os.path.join(save_dir, "vector_store.json")
    ):
        print(f"Index not found in {save_dir}")
        print("Please run mail_reader_save_load.py first to create the index.")
        return

    # Load the index
    index = load_index(save_dir)

    if not index:
        print("Failed to load index.")
        return

    print("\n" + "=" * 60)
    print("Email Query Interface")
    print("=" * 60)
    print("Type 'quit' to exit")
    print("Type 'help' for example queries")
    print("=" * 60)

    # Interactive query loop
    while True:
        try:
            query = input("\nEnter your query: ").strip()

            if query.lower() == "quit":
                print("Goodbye!")
                break
            elif query.lower() == "help":
                print("\nExample queries:")
                print("- Hows Berkeley Graduate Student Instructor")
                print("- What emails mention GSR appointments?")
                print("- Find emails about deadlines")
                print("- Search for emails from specific sender")
                print("- Find emails about meetings")
                continue
            elif not query:
                continue

            query_index(index, query)

        except KeyboardInterrupt:
            print("\nGoodbye!")
            break
        except Exception as e:
            print(f"Error processing query: {e}")


if __name__ == "__main__":
    main()
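For reference, the loader removed above rebuilds the index from the vector store alone; the more common LlamaIndex pattern is `load_index_from_storage`, which reconstructs the index from the full storage context. A minimal sketch (assuming the same `mail_index` persist directory):

```python
from llama_index.core import StorageContext, load_index_from_storage

# Reload everything persisted under mail_index/ and query it directly.
storage_context = StorageContext.from_defaults(persist_dir="mail_index")
index = load_index_from_storage(storage_context)
print(index.as_query_engine().query("Find emails about deadlines"))
```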
@@ -1,117 +0,0 @@
#!/usr/bin/env python3
"""
Debug script to test ZMQ communication with the exact same setup as main_cli_example.py
"""

import sys
import time

import zmq

sys.path.append("packages/leann-backend-diskann")
from leann_backend_diskann import embedding_pb2


def test_zmq_with_same_model():
    print("=== Testing ZMQ with same model as main_cli_example.py ===")

    # Test the exact same model that main_cli_example.py uses
    model_name = "sentence-transformers/all-mpnet-base-v2"

    # Start server with the same model
    import subprocess

    server_cmd = [
        sys.executable,
        "-m",
        "packages.leann-backend-diskann.leann_backend_diskann.embedding_server",
        "--zmq-port",
        "5556",  # Use different port to avoid conflicts
        "--model-name",
        model_name,
    ]

    print(f"Starting server with command: {' '.join(server_cmd)}")
    server_process = subprocess.Popen(
        server_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
    )

    # Wait for server to start
    print("Waiting for server to start...")
    time.sleep(10)

    # Check if server is running
    if server_process.poll() is not None:
        stdout, stderr = server_process.communicate()
        print(f"Server failed to start. stdout: {stdout}")
        print(f"Server failed to start. stderr: {stderr}")
        return False

    print(f"Server started with PID: {server_process.pid}")

    try:
        # Test client
        context = zmq.Context()
        socket = context.socket(zmq.REQ)
        socket.connect("tcp://127.0.0.1:5556")
        socket.setsockopt(zmq.RCVTIMEO, 30000)  # 30 second timeout like C++
        socket.setsockopt(zmq.SNDTIMEO, 30000)

        # Create request with same format as C++
        request = embedding_pb2.NodeEmbeddingRequest()
        request.node_ids.extend([0, 1, 2, 3, 4])  # Test with some node IDs

        print(f"Sending request with {len(request.node_ids)} node IDs...")
        start_time = time.time()

        # Send request
        socket.send(request.SerializeToString())

        # Receive response
        response_data = socket.recv()
        end_time = time.time()

        print(f"Received response in {end_time - start_time:.3f} seconds")
        print(f"Response size: {len(response_data)} bytes")

        # Parse response
        response = embedding_pb2.NodeEmbeddingResponse()
        response.ParseFromString(response_data)

        print(f"Response dimensions: {list(response.dimensions)}")
        print(f"Embeddings data size: {len(response.embeddings_data)} bytes")
        print(f"Missing IDs: {list(response.missing_ids)}")

        # Calculate expected size
        if len(response.dimensions) == 2:
            batch_size = response.dimensions[0]
            embedding_dim = response.dimensions[1]
            expected_bytes = batch_size * embedding_dim * 4  # 4 bytes per float
            print(f"Expected bytes: {expected_bytes}, Actual: {len(response.embeddings_data)}")

            if len(response.embeddings_data) == expected_bytes:
                print("✅ Response format is correct!")
                return True
            else:
                print("❌ Response format mismatch!")
                return False
        else:
            print("❌ Invalid response dimensions!")
            return False

    except Exception as e:
        print(f"❌ Error during ZMQ test: {e}")
        return False
    finally:
        # Clean up
        server_process.terminate()
        server_process.wait()
        print("Server terminated")


if __name__ == "__main__":
    success = test_zmq_with_same_model()
    if success:
        print("\n✅ ZMQ communication test passed!")
    else:
        print("\n❌ ZMQ communication test failed!")
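The size check at the end of the removed script implies the embeddings arrive as a flat float32 buffer of shape `(batch_size, embedding_dim)`. A minimal sketch of decoding it, continuing from the `response` object above (row-major layout is an assumption; all-mpnet-base-v2 produces 768-dimensional vectors):

```python
import numpy as np

# Interpret the raw protobuf bytes as a (batch_size, embedding_dim) float32 matrix.
batch_size, embedding_dim = response.dimensions
embeddings = np.frombuffer(response.embeddings_data, dtype=np.float32).reshape(
    batch_size, embedding_dim
)
print(embeddings.shape)  # expected (5, 768) for the five node IDs requested above
```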
@@ -18,8 +18,8 @@ Basic functionality tests that verify:
 - Basic index building and searching works for both HNSW and DiskANN backends
 - Uses parametrized tests to test both backends
 
-### `test_main_cli.py`
-Tests the main CLI example functionality:
+### `test_document_rag.py`
+Tests the document RAG example functionality:
 - Tests with facebook/contriever embeddings
 - Tests with OpenAI embeddings (if API key is available)
 - Tests error handling with invalid parameters
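To run just this test module (a sketch; the file is assumed to live at `tests/test_document_rag.py` in the repository):

```python
# Equivalent to running `pytest tests/test_document_rag.py -v` from the repo root.
import pytest

raise SystemExit(pytest.main(["-v", "tests/test_document_rag.py"]))
```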
@@ -20,7 +20,7 @@ def test_package_imports():
 def test_cli_help():
     """Test that CLI example shows help."""
     result = subprocess.run(
-        [sys.executable, "examples/main_cli_example.py", "--help"], capture_output=True, text=True
+        [sys.executable, "apps/document_rag.py", "--help"], capture_output=True, text=True
     )
 
     assert result.returncode == 0
@@ -1,5 +1,5 @@
 """
-Test main_cli_example functionality using pytest.
+Test document_rag functionality using pytest.
 """
 
 import os
@@ -14,20 +14,20 @@ import pytest
 @pytest.fixture
 def test_data_dir():
     """Return the path to test data directory."""
-    return Path("examples/data")
+    return Path("data")
 
 
 @pytest.mark.skipif(
     os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues"
 )
-def test_main_cli_simulated(test_data_dir):
-    """Test main_cli with simulated LLM."""
+def test_document_rag_simulated(test_data_dir):
+    """Test document_rag with simulated LLM."""
     with tempfile.TemporaryDirectory() as temp_dir:
         # Use a subdirectory that doesn't exist yet to force index creation
         index_dir = Path(temp_dir) / "test_index"
         cmd = [
             sys.executable,
-            "examples/main_cli_example.py",
+            "apps/document_rag.py",
             "--llm",
             "simulated",
             "--embedding-model",
@@ -53,19 +53,19 @@ def test_main_cli_simulated(test_data_dir):
 
         # Verify output
         output = result.stdout + result.stderr
-        assert "Leann index built at" in output or "Using existing index" in output
+        assert "Index saved to" in output or "Using existing index" in output
         assert "This is a simulated answer" in output
 
 
 @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OpenAI API key not available")
-def test_main_cli_openai(test_data_dir):
-    """Test main_cli with OpenAI embeddings."""
+def test_document_rag_openai(test_data_dir):
+    """Test document_rag with OpenAI embeddings."""
     with tempfile.TemporaryDirectory() as temp_dir:
         # Use a subdirectory that doesn't exist yet to force index creation
         index_dir = Path(temp_dir) / "test_index_openai"
         cmd = [
             sys.executable,
-            "examples/main_cli_example.py",
+            "apps/document_rag.py",
             "--llm",
             "simulated",  # Use simulated LLM to avoid GPT-4 costs
             "--embedding-model",
@@ -99,12 +99,12 @@ def test_main_cli_openai(test_data_dir):
         )
 
 
-def test_main_cli_error_handling(test_data_dir):
-    """Test main_cli with invalid parameters."""
+def test_document_rag_error_handling(test_data_dir):
+    """Test document_rag with invalid parameters."""
     with tempfile.TemporaryDirectory() as temp_dir:
         cmd = [
             sys.executable,
-            "examples/main_cli_example.py",
+            "apps/document_rag.py",
             "--llm",
             "invalid_llm_type",
             "--index-dir",
@@ -117,4 +117,4 @@ def test_main_cli_error_handling(test_data_dir):
 
         # Should fail with invalid LLM type
         assert result.returncode != 0
-        assert "Unknown LLM type" in result.stderr or "invalid_llm_type" in result.stderr
+        assert "invalid choice" in result.stderr or "invalid_llm_type" in result.stderr
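The updated assertion reflects argparse's built-in validation: when `--llm` is declared with a `choices` list, an unrecognized value makes argparse print an "invalid choice" message to stderr and exit with a non-zero status, so the app no longer needs its own "Unknown LLM type" error. A minimal sketch of that behavior (the flag name mirrors the test; the actual choices in `document_rag.py` are assumed):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--llm", choices=["simulated", "openai", "ollama", "hf"])

# parser.parse_args(["--llm", "invalid_llm_type"]) would print
# "error: argument --llm: invalid choice: 'invalid_llm_type' ..." to stderr
# and exit with status 2, which is what the test asserts.
```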