Compare commits
138 Commits
master
...
perf-build
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1d084f678c | ||
|
|
54155e8b10 | ||
|
|
5259ace111 | ||
|
|
48ea5566e9 | ||
|
|
3f8b6c5bbd | ||
|
|
725b32e74f | ||
|
|
d0b71f393f | ||
|
|
8a92efdae3 | ||
|
|
019cdce2e8 | ||
|
|
b64aa54fac | ||
|
|
c0d040f9d4 | ||
|
|
32364320f8 | ||
|
|
f47f76d6d7 | ||
|
|
1dc3923b53 | ||
|
|
7e226a51c9 | ||
|
|
f4998bb316 | ||
|
|
7522de1d41 | ||
|
|
15f8bd1cc9 | ||
|
|
34c71c072d | ||
|
|
6d2149c503 | ||
|
|
043b0bf69d | ||
|
|
9b07e392c6 | ||
|
|
e60fad8c73 | ||
|
|
19c1b182c3 | ||
|
|
49edea780c | ||
|
|
12ef5a1900 | ||
|
|
d21a134b2a | ||
|
|
1cd809aa41 | ||
|
|
e728449b8f | ||
|
|
d0c20b14d5 | ||
|
|
83b7ea5a59 | ||
|
|
0796a52df1 | ||
|
|
85b7ba0168 | ||
|
|
e117743d24 | ||
|
|
aec2291f04 | ||
|
|
335ae003ac | ||
|
|
71c7de9c84 | ||
|
|
1c5fec5565 | ||
|
|
99d439577d | ||
|
|
4f83086788 | ||
|
|
a13c527e39 | ||
|
|
90d9f27383 | ||
|
|
0db81c16cd | ||
|
|
e115e186b7 | ||
|
|
6546b29ef7 | ||
|
|
51255bdffa | ||
|
|
f77c4e38cb | ||
|
|
2a1a152073 | ||
|
|
7b9406a3ea | ||
|
|
c3fb949693 | ||
|
|
ed3f8dbfd6 | ||
|
|
42aa6db170 | ||
|
|
a6591d20ca | ||
|
|
c1bc2603a2 | ||
|
|
e595bbb5fb | ||
|
|
4a2cb914d7 | ||
|
|
b1c93fe178 | ||
|
|
0719458775 | ||
|
|
6a1dc895fb | ||
|
|
125c1f6f25 | ||
|
|
1ceaa7d709 | ||
|
|
dec3ee85fd | ||
|
|
d94a5176dc | ||
|
|
326783f7f1 | ||
|
|
e5a9ca8787 | ||
|
|
f2feccdbd0 | ||
|
|
246a077d64 | ||
|
|
3ba100ff25 | ||
|
|
1e3b571e72 | ||
|
|
b89e56e9c2 | ||
|
|
ed8a02e721 | ||
|
|
baa60b40d1 | ||
|
|
ef01d6997a | ||
|
|
3da5b44d7f | ||
|
|
8b4654921b | ||
|
|
cf1cbafa78 | ||
|
|
c96091744b | ||
|
|
711fb4a775 | ||
|
|
3b5a185e60 | ||
|
|
77ac013a74 | ||
|
|
b8e5728e6a | ||
|
|
d038319d8b | ||
|
|
c611d0f30f | ||
|
|
c17899662f | ||
|
|
c51d5320fa | ||
|
|
6fa9512a64 | ||
|
|
fddc61df5e | ||
|
|
53c58fa755 | ||
|
|
c69afb56e4 | ||
|
|
0fa8a9191f | ||
|
|
48dda1cb5b | ||
|
|
71ef4b7d4c | ||
|
|
ecab43e307 | ||
|
|
88ca09440d | ||
|
|
8e0ab4a28d | ||
|
|
9b8c5041dc | ||
|
|
74ffd7ec64 | ||
|
|
eb6f504789 | ||
|
|
91a026f38b | ||
|
|
595138a0a3 | ||
|
|
19df04095f | ||
|
|
8239bbb48f | ||
|
|
16ee9d0422 | ||
|
|
8a961f8ab3 | ||
|
|
558126c46e | ||
|
|
04c9684488 | ||
|
|
b744faa7e6 | ||
|
|
27b3a26e75 | ||
|
|
41d872504e | ||
|
|
963cd05273 | ||
|
|
09b6e67baf | ||
|
|
dafb2aacab | ||
|
|
a6c400cd4f | ||
|
|
c013e5ccce | ||
|
|
f25a1a3840 | ||
|
|
6497e17671 | ||
|
|
44369a8138 | ||
|
|
dfca00c21b | ||
|
|
637dab379e | ||
|
|
6fc57eb48e | ||
|
|
95a653993a | ||
|
|
af0959818d | ||
|
|
cf17c85607 | ||
|
|
a38bc0a3fc | ||
|
|
449983c937 | ||
|
|
df63526503 | ||
|
|
e92deee1e8 | ||
|
|
910927a405 | ||
|
|
0aa84e147b | ||
|
|
368474d036 | ||
|
|
a627abe794 | ||
|
|
44815ee7fd | ||
|
|
371e3de04e | ||
|
|
b81b5d0f86 | ||
|
|
ee507bfe7a | ||
|
|
30898814ae | ||
|
|
a075fd6f47 | ||
|
|
303ff6fe1d |
15
.gitignore
vendored
15
.gitignore
vendored
@@ -8,11 +8,17 @@ demo/indices/
|
|||||||
*pycache*
|
*pycache*
|
||||||
outputs/
|
outputs/
|
||||||
*.pkl
|
*.pkl
|
||||||
|
*.pdf
|
||||||
|
*.idx
|
||||||
|
*.map
|
||||||
.history/
|
.history/
|
||||||
scripts/
|
scripts/
|
||||||
lm_eval.egg-info/
|
lm_eval.egg-info/
|
||||||
demo/experiment_results/**/*.json
|
demo/experiment_results/**/*.json
|
||||||
*.jsonl
|
*.jsonl
|
||||||
|
*.eml
|
||||||
|
*.emlx
|
||||||
|
*.json
|
||||||
*.sh
|
*.sh
|
||||||
*.txt
|
*.txt
|
||||||
!CMakeLists.txt
|
!CMakeLists.txt
|
||||||
@@ -29,6 +35,11 @@ build/
|
|||||||
nprobe_logs/
|
nprobe_logs/
|
||||||
micro/results
|
micro/results
|
||||||
micro/contriever-INT8
|
micro/contriever-INT8
|
||||||
|
examples/data/*
|
||||||
|
!examples/data/2501.14312v1 (1).pdf
|
||||||
|
!examples/data/2506.08276v1.pdf
|
||||||
|
!examples/data/PrideandPrejudice.txt
|
||||||
|
!examples/data/README.md
|
||||||
*.qdstrm
|
*.qdstrm
|
||||||
benchmark_results/
|
benchmark_results/
|
||||||
results/
|
results/
|
||||||
@@ -41,6 +52,7 @@ embedding_comparison_results/
|
|||||||
*.ivecs
|
*.ivecs
|
||||||
*.index
|
*.index
|
||||||
*.bin
|
*.bin
|
||||||
|
*.old
|
||||||
|
|
||||||
read_graph
|
read_graph
|
||||||
analyze_diskann_graph
|
analyze_diskann_graph
|
||||||
@@ -70,3 +82,6 @@ test_indices*/
|
|||||||
test_*.py
|
test_*.py
|
||||||
!tests/**
|
!tests/**
|
||||||
packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
||||||
|
|
||||||
|
*.meta.json
|
||||||
|
*.passages.json
|
||||||
14
.gitmodules
vendored
14
.gitmodules
vendored
@@ -1,6 +1,16 @@
|
|||||||
[submodule "packages/leann-backend-diskann/third_party/DiskANN"]
|
[submodule "packages/leann-backend-diskann/third_party/DiskANN"]
|
||||||
path = packages/leann-backend-diskann/third_party/DiskANN
|
path = packages/leann-backend-diskann/third_party/DiskANN
|
||||||
url = https://github.com/yichuan520030910320/DiskANN.git
|
url = https://github.com/yichuan-w/DiskANN.git
|
||||||
[submodule "packages/leann-backend-hnsw/third_party/faiss"]
|
[submodule "packages/leann-backend-hnsw/third_party/faiss"]
|
||||||
path = packages/leann-backend-hnsw/third_party/faiss
|
path = packages/leann-backend-hnsw/third_party/faiss
|
||||||
url = https://github.com/yichuan520030910320/faiss.git
|
url = https://github.com/yichuan-w/faiss.git
|
||||||
|
[submodule "packages/leann-backend-hnsw/third_party/msgpack-c"]
|
||||||
|
path = packages/leann-backend-hnsw/third_party/msgpack-c
|
||||||
|
url = https://github.com/msgpack/msgpack-c.git
|
||||||
|
branch = cpp_master
|
||||||
|
[submodule "packages/leann-backend-hnsw/third_party/cppzmq"]
|
||||||
|
path = packages/leann-backend-hnsw/third_party/cppzmq
|
||||||
|
url = https://github.com/zeromq/cppzmq.git
|
||||||
|
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
|
||||||
|
path = packages/leann-backend-hnsw/third_party/libzmq
|
||||||
|
url = https://github.com/zeromq/libzmq.git
|
||||||
|
|||||||
2
LICENSE
2
LICENSE
@@ -1,6 +1,6 @@
|
|||||||
MIT License
|
MIT License
|
||||||
|
|
||||||
Copyright (c) 2024 Rulin Shao
|
Copyright (c) 2025 LEANN Contributors
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
|||||||
531
README.md
531
README.md
@@ -1,170 +1,360 @@
|
|||||||
# 🚀 LEANN: A Low-Storage Vector Index
|
<p align="center">
|
||||||
|
<img src="assets/logo-text.png" alt="LEANN Logo" width="400">
|
||||||
|
</p>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
|
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
|
||||||
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
||||||
<img src="https://img.shields.io/badge/PRs-welcome-brightgreen.svg" alt="PRs Welcome">
|
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
|
||||||
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS%20%7C%20Windows-lightgrey" alt="Platform">
|
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
||||||
|
The smallest vector index in the world. RAG Everything with LEANN!
|
||||||
|
</h2>
|
||||||
|
|
||||||
|
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **[97% less storage]** than traditional solutions **without accuracy loss**.
|
||||||
|
|
||||||
|
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||||
|
|
||||||
|
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#process-any-documents-pdf-txt-md)**, **[emails](#search-your-entire-life)**, **[browser history](#time-machine-for-the-web)**, **[chat history](#wechat-detective)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Why LEANN?
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<strong>⚡ Real-time embedding computation for large-scale RAG on consumer hardware</strong>
|
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<p align="center">
|
**The numbers speak for themselves:** Index 60 million Wikipedia articles in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks below ↓](#storage-usage-comparison)
|
||||||
<a href="#-quick-start">Quick Start</a> •
|
|
||||||
<a href="#-features">Features</a> •
|
|
||||||
<a href="#-benchmarks">Benchmarks</a> •
|
|
||||||
<a href="#-documentation">Documentation</a> •
|
|
||||||
<a href="#-paper">Paper</a>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
---
|
## Why This Matters
|
||||||
|
|
||||||
## 🌟 What is Leann?
|
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
||||||
|
|
||||||
**Leann** revolutionizes Retrieval-Augmented Generation (RAG) by eliminating the storage bottleneck of traditional vector databases. Instead of pre-computing and storing billions of embeddings, Leann dynamically computes embeddings at query time using highly optimized graph-based search algorithms.
|
🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!
|
||||||
|
|
||||||
### 🎯 Why Leann?
|
📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!
|
||||||
|
|
||||||
Traditional RAG systems face a fundamental trade-off:
|
✨ **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
|
||||||
- **💾 Storage**: Storing embeddings for millions of documents requires massive disk space
|
|
||||||
- **🔄 Freshness**: Pre-computed embeddings become stale when documents change
|
|
||||||
- **💰 Cost**: Vector databases are expensive to scale
|
|
||||||
|
|
||||||
**Leann solves this by:**
|
## Quick Start in 1 minute
|
||||||
- ✅ **Zero embedding storage** - Only graph structure is persisted
|
|
||||||
- ✅ **Real-time computation** - Embeddings computed on-demand with ms latency
|
|
||||||
- ✅ **Memory efficient** - Runs on consumer hardware (8GB RAM)
|
|
||||||
- ✅ **Always fresh** - No stale embeddings, ever
|
|
||||||
|
|
||||||
## 🚀 Quick Start
|
|
||||||
|
|
||||||
### Installation
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/yichuan520030910320/Power-RAG.git leann
|
git clone git@github.com:yichuan-w/LEANN.git leann
|
||||||
cd leann
|
cd leann
|
||||||
uv sync
|
git submodule update --init --recursive
|
||||||
```
|
```
|
||||||
|
|
||||||
### 30-Second Example
|
**macOS:**
|
||||||
|
```bash
|
||||||
|
brew install llvm libomp boost protobuf zeromq
|
||||||
|
export CC=$(brew --prefix llvm)/bin/clang
|
||||||
|
export CXX=$(brew --prefix llvm)/bin/clang++
|
||||||
|
|
||||||
|
# Install with HNSW backend (default, recommended for most users)
|
||||||
|
uv sync
|
||||||
|
|
||||||
|
# Or add DiskANN backend if you want to test more options
|
||||||
|
uv sync --extra diskann
|
||||||
|
```
|
||||||
|
|
||||||
|
**Linux (Ubuntu/Debian):**
|
||||||
|
```bash
|
||||||
|
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||||
|
|
||||||
|
# Install with HNSW backend (default, recommended for most users)
|
||||||
|
uv sync
|
||||||
|
|
||||||
|
# Or add DiskANN backend if you want to test more options
|
||||||
|
uv sync --extra diskann
|
||||||
|
```
|
||||||
|
|
||||||
|
**Ollama Setup (Optional for Local LLM):**
|
||||||
|
|
||||||
|
*We support both hf-transformers and Ollama for local LLMs. Ollama is recommended for faster performance.*
|
||||||
|
|
||||||
|
*macOS:*
|
||||||
|
|
||||||
|
First, [download Ollama for macOS](https://ollama.com/download/mac).
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Pull a lightweight model (recommended for consumer hardware)
|
||||||
|
ollama pull llama3.2:1b
|
||||||
|
```
|
||||||
|
|
||||||
|
*Linux:*
|
||||||
|
```bash
|
||||||
|
# Install Ollama
|
||||||
|
curl -fsSL https://ollama.ai/install.sh | sh
|
||||||
|
|
||||||
|
# Start Ollama service manually
|
||||||
|
ollama serve &
|
||||||
|
|
||||||
|
# Pull a lightweight model (recommended for consumer hardware)
|
||||||
|
ollama pull llama3.2:1b
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also replace `llama3.2:1b` to `deepseek-r1:1.5b` or `qwen3:4b` for better performance but higher memory usage.
|
||||||
|
|
||||||
|
## Dead Simple API
|
||||||
|
|
||||||
|
Just 3 lines of code. Our declarative API makes RAG as easy as writing a config file:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from leann.api import LeannBuilder, LeannSearcher
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
# 1. Build index (no embeddings stored!)
|
# 1. Build index (no embeddings stored!)
|
||||||
builder = LeannBuilder(backend_name="diskann")
|
builder = LeannBuilder(backend_name="hnsw")
|
||||||
|
builder.add_text("C# is a powerful programming language")
|
||||||
builder.add_text("Python is a powerful programming language")
|
builder.add_text("Python is a powerful programming language")
|
||||||
builder.add_text("Machine learning transforms industries")
|
builder.add_text("Machine learning transforms industries")
|
||||||
builder.add_text("Neural networks process complex data")
|
builder.add_text("Neural networks process complex data")
|
||||||
|
builder.add_text("Leann is a great storage saving engine for RAG on your macbook")
|
||||||
builder.build_index("knowledge.leann")
|
builder.build_index("knowledge.leann")
|
||||||
|
|
||||||
# 2. Search with real-time embeddings
|
# 2. Search with real-time embeddings
|
||||||
searcher = LeannSearcher("knowledge.leann")
|
searcher = LeannSearcher("knowledge.leann")
|
||||||
results = searcher.search("programming languages", top_k=2)
|
results = searcher.search("C++ programming languages", top_k=2, recompute_beighbor_embeddings=True)
|
||||||
|
print(results)
|
||||||
for result in results:
|
|
||||||
print(f"Score: {result['score']:.3f} - {result['text']}")
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Run the Demo
|
**That's it.** No cloud setup, no API keys, no "fine-tuning". Just your data, your questions, your laptop.
|
||||||
|
|
||||||
|
[Try the interactive demo →](demo.ipynb)
|
||||||
|
|
||||||
|
## Wild Things You Can Do
|
||||||
|
|
||||||
|
LEANN supports RAGing a lot of data sources, like .pdf, .txt, .md, and also supports RAGing your WeChat, Google Search History, and more.
|
||||||
|
|
||||||
|
### Process Any Documents (.pdf, .txt, .md)
|
||||||
|
|
||||||
|
Above we showed the Python API, while this CLI script demonstrates the same concepts while directly processing PDFs and documents.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv run examples/document_search.py
|
# Drop your PDFs, .txt, .md files into examples/data/
|
||||||
|
uv run ./examples/main_cli_example.py
|
||||||
|
|
||||||
|
# Or use python directly
|
||||||
|
source .venv/bin/activate
|
||||||
|
python ./examples/main_cli_example.py
|
||||||
```
|
```
|
||||||
|
|
||||||
**PDF RAG Demo (using LlamaIndex for document parsing and Leann for indexing/search)**
|
Uses Ollama `qwen3:8b` by default. For other models: `--llm openai --model gpt-4o` (requires `OPENAI_API_KEY` environment variable) or `--llm hf --model Qwen/Qwen3-4B`.
|
||||||
|
|
||||||
This demo showcases how to build a RAG system for PDF documents using Leann.
|
**Works with any text format** - research papers, personal notes, presentations. Built with LlamaIndex for document parsing.
|
||||||
1. Place your PDF files (and other supported formats like .docx, .pptx, .xlsx) into the `examples/data/` directory.
|
|
||||||
2. Ensure you have an `OPENAI_API_KEY` set in your environment variables or in a `.env` file for the LLM to function.
|
### Search Your Entire Life
|
||||||
|
```bash
|
||||||
|
python examples/mail_reader_leann.py
|
||||||
|
# "What did my boss say about the Christmas party last year?"
|
||||||
|
# "Find all emails from my mom about birthday plans"
|
||||||
|
```
|
||||||
|
**90K emails → 14MB.** Finally, search your email like you search Google.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv run examples/main_cli_example.py
|
# Use default mail path (works for most macOS setups)
|
||||||
|
python examples/mail_reader_leann.py
|
||||||
|
|
||||||
|
# Run with custom index directory
|
||||||
|
python examples/mail_reader_leann.py --index-dir "./my_mail_index"
|
||||||
|
|
||||||
|
# Process all emails (may take time but indexes everything)
|
||||||
|
python examples/mail_reader_leann.py --max-emails -1
|
||||||
|
|
||||||
|
# Limit number of emails processed (useful for testing)
|
||||||
|
python examples/mail_reader_leann.py --max-emails 1000
|
||||||
|
|
||||||
|
# Run a single query
|
||||||
|
python examples/mail_reader_leann.py --query "What did my boss say about deadlines?"
|
||||||
```
|
```
|
||||||
|
|
||||||
## ✨ Features
|
</details>
|
||||||
|
|
||||||
### 🔥 Core Features
|
<details>
|
||||||
- **📊 Multiple Distance Functions**: L2, Cosine, MIPS (Maximum Inner Product Search)
|
<summary><strong>📋 Click to expand: Example queries you can try</strong></summary>
|
||||||
- **🏗️ Pluggable Backends**: DiskANN, HNSW/FAISS with unified API
|
|
||||||
- **🔄 Real-time Embeddings**: Dynamic computation using optimized ZMQ servers
|
|
||||||
- **📈 Scalable Architecture**: Handles millions of documents on consumer hardware
|
|
||||||
- **🎯 Graph Pruning**: Advanced techniques for memory-efficient search
|
|
||||||
|
|
||||||
### 🛠️ Technical Highlights
|
Once the index is built, you can ask questions like:
|
||||||
- **Zero-copy operations** for maximum performance
|
- "Find emails from my boss about deadlines"
|
||||||
- **SIMD-optimized** distance computations (AVX2/AVX512)
|
- "What did John say about the project timeline?"
|
||||||
- **Async embedding pipeline** with batched processing
|
- "Show me emails about travel expenses"
|
||||||
- **Memory-mapped indices** for fast startup
|
</details>
|
||||||
- **Recompute mode** for highest accuracy scenarios
|
|
||||||
|
|
||||||
### 🎨 Developer Experience
|
|
||||||
- **Simple Python API** - Get started in minutes
|
|
||||||
- **Extensible backend system** - Easy to add new algorithms
|
|
||||||
- **Comprehensive examples** - From basic usage to production deployment
|
|
||||||
- **Rich debugging tools** - Built-in performance profiling
|
|
||||||
|
|
||||||
## 📊 Benchmarks
|
|
||||||
|
|
||||||
### Memory Usage Comparison
|
|
||||||
|
|
||||||
| System | 1M Documents | 10M Documents | 100M Documents |
|
|
||||||
|--------|-------------|---------------|----------------|
|
|
||||||
| Traditional Vector DB | 3.1 GB | 31 GB | 310 GB |
|
|
||||||
| **Leann** | **180 MB** | **1.2 GB** | **8.4 GB** |
|
|
||||||
| **Reduction** | **94.2%** | **96.1%** | **97.3%** |
|
|
||||||
|
|
||||||
### Query Performance
|
|
||||||
|
|
||||||
| Backend | Index Size | Query Time | Recall@10 |
|
|
||||||
|---------|------------|------------|-----------|
|
|
||||||
| DiskANN | 1M docs | 12ms | 0.95 |
|
|
||||||
| DiskANN + Recompute | 1M docs | 145ms | 0.98 |
|
|
||||||
| HNSW | 1M docs | 8ms | 0.93 |
|
|
||||||
|
|
||||||
*Benchmarks run on AMD Ryzen 7 with 32GB RAM*
|
|
||||||
|
|
||||||
## 🏗️ Architecture
|
|
||||||
|
|
||||||
|
### Time Machine for the Web
|
||||||
|
```bash
|
||||||
|
python examples/google_history_reader_leann.py
|
||||||
|
# "What was that AI paper I read last month?"
|
||||||
|
# "Show me all the cooking videos I watched"
|
||||||
```
|
```
|
||||||
┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐
|
**38K browser entries → 6MB.** Your browser history becomes your personal search engine.
|
||||||
│ Query Text │───▶│ Embedding │───▶│ Graph-based │
|
|
||||||
│ │ │ Computation │ │ Search │
|
<details>
|
||||||
└─────────────────┘ └──────────────────┘ └─────────────────┘
|
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||||
│ │
|
|
||||||
▼ ▼
|
```bash
|
||||||
┌──────────────┐ ┌──────────────┐
|
# Use default Chrome profile (auto-finds all profiles)
|
||||||
│ ZMQ Server │ │ Pruned Graph │
|
python examples/google_history_reader_leann.py
|
||||||
│ (Cached) │ │ Index │
|
|
||||||
└──────────────┘ └──────────────┘
|
# Run with custom index directory
|
||||||
|
python examples/google_history_reader_leann.py --index-dir "./my_chrome_index"
|
||||||
|
|
||||||
|
# Limit number of history entries processed (useful for testing)
|
||||||
|
python examples/google_history_reader_leann.py --max-entries 500
|
||||||
|
|
||||||
|
# Run a single query
|
||||||
|
python examples/google_history_reader_leann.py --query "What websites did I visit about machine learning?"
|
||||||
```
|
```
|
||||||
|
|
||||||
### Key Components
|
</details>
|
||||||
|
|
||||||
1. **🧠 Embedding Engine**: Real-time transformer inference with caching
|
<details>
|
||||||
2. **📊 Graph Index**: Memory-efficient navigation structures
|
<summary><strong>📋 Click to expand: How to find your Chrome profile</strong></summary>
|
||||||
3. **🔄 Search Coordinator**: Orchestrates embedding + graph search
|
|
||||||
4. **⚡ Backend Adapters**: Pluggable algorithm implementations
|
|
||||||
|
|
||||||
## 🎓 Supported Models & Backends
|
The default Chrome profile path is configured for a typical macOS setup. If you need to find your specific Chrome profile:
|
||||||
|
|
||||||
### 🤖 Embedding Models
|
1. Open Terminal
|
||||||
- **sentence-transformers/all-mpnet-base-v2** (default)
|
2. Run: `ls ~/Library/Application\ Support/Google/Chrome/`
|
||||||
- **sentence-transformers/all-MiniLM-L6-v2** (lightweight)
|
3. Look for folders like "Default", "Profile 1", "Profile 2", etc.
|
||||||
- Any HuggingFace sentence-transformer model
|
4. Use the full path as your `--chrome-profile` argument
|
||||||
- Custom model support via API
|
|
||||||
|
|
||||||
### 🔧 Search Backends
|
**Common Chrome profile locations:**
|
||||||
- **DiskANN**: Microsoft's billion-scale ANN algorithm
|
- macOS: `~/Library/Application Support/Google/Chrome/Default`
|
||||||
- **HNSW**: Hierarchical Navigable Small World graphs
|
- Linux: `~/.config/google-chrome/Default`
|
||||||
- **Coming soon**: ScaNN, Faiss-IVF, NGT
|
|
||||||
|
|
||||||
### 📏 Distance Functions
|
</details>
|
||||||
- **L2**: Euclidean distance for precise similarity
|
|
||||||
- **Cosine**: Angular similarity for normalized vectors
|
<details>
|
||||||
- **MIPS**: Maximum Inner Product Search for recommendation systems
|
<summary><strong>💬 Click to expand: Example queries you can try</strong></summary>
|
||||||
|
|
||||||
|
Once the index is built, you can ask questions like:
|
||||||
|
|
||||||
|
- "What websites did I visit about machine learning?"
|
||||||
|
- "Find my search history about programming"
|
||||||
|
- "What YouTube videos did I watch recently?"
|
||||||
|
- "Show me websites I visited about travel planning"
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### WeChat Detective
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python examples/wechat_history_reader_leann.py
|
||||||
|
# "Show me all group chats about weekend plans"
|
||||||
|
```
|
||||||
|
**400K messages → 64MB.** Search years of chat history in any language.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
|
||||||
|
|
||||||
|
First, you need to install the WeChat exporter:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo packages/wechat-exporter/wechattweak-cli install
|
||||||
|
```
|
||||||
|
|
||||||
|
**Troubleshooting**: If you encounter installation issues, check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41).
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Use default settings (recommended for first run)
|
||||||
|
python examples/wechat_history_reader_leann.py
|
||||||
|
|
||||||
|
# Run with custom export directory and wehn we run the first time, LEANN will export all chat history automatically for you
|
||||||
|
python examples/wechat_history_reader_leann.py --export-dir "./my_wechat_exports"
|
||||||
|
|
||||||
|
# Run with custom index directory
|
||||||
|
python examples/wechat_history_reader_leann.py --index-dir "./my_wechat_index"
|
||||||
|
|
||||||
|
# Limit number of chat entries processed (useful for testing)
|
||||||
|
python examples/wechat_history_reader_leann.py --max-entries 1000
|
||||||
|
|
||||||
|
# Run a single query
|
||||||
|
python examples/wechat_history_reader_leann.py --query "Show me conversations about travel plans"
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>💬 Click to expand: Example queries you can try</strong></summary>
|
||||||
|
|
||||||
|
Once the index is built, you can ask questions like:
|
||||||
|
|
||||||
|
- "我想买魔术师约翰逊的球衣,给我一些对应聊天记录?" (Chinese: Show me chat records about buying Magic Johnson's jersey)
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
## 🏗️ Architecture & How It Works
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="assets/arch.png" alt="LEANN Architecture" width="800">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
**The magic:** Most vector DBs store every single embedding (expensive). LEANN stores a pruned graph structure (cheap) and recomputes embeddings only when needed (fast).
|
||||||
|
|
||||||
|
**Core techniques:**
|
||||||
|
- **Graph-based selective recomputation:** Only compute embeddings for nodes in the search path
|
||||||
|
- **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections
|
||||||
|
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
|
||||||
|
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
|
||||||
|
|
||||||
|
**Backends:** DiskANN or HNSW - pick what works for your data size.
|
||||||
|
|
||||||
|
## Benchmarks
|
||||||
|
|
||||||
|
Run the comparison yourself:
|
||||||
|
```bash
|
||||||
|
python examples/compare_faiss_vs_leann.py
|
||||||
|
```
|
||||||
|
|
||||||
|
| System | Storage |
|
||||||
|
|--------|---------|
|
||||||
|
| FAISS HNSW | 5.5 MB |
|
||||||
|
| LEANN | 0.5 MB |
|
||||||
|
| **Savings** | **91%** |
|
||||||
|
|
||||||
|
Same dataset, same hardware, same embedding model. LEANN just works better.
|
||||||
|
|
||||||
|
## Reproduce Our Results
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv pip install -e ".[dev]" # Install dev dependencies
|
||||||
|
python examples/run_evaluation.py data/indices/dpr/dpr_diskann # DPR dataset
|
||||||
|
python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
|
||||||
|
```
|
||||||
|
|
||||||
|
The evaluation script downloads data automatically on first run.
|
||||||
|
|
||||||
|
### Storage Usage Comparison
|
||||||
|
|
||||||
|
| System | DPR (2.1M chunks) | RPJ-wiki (60M chunks) | Chat history (400K messages) | Apple emails (90K messages chunks) |Google Search History (38K entries)
|
||||||
|
|-----------------------|------------------|------------------------|-----------------------------|------------------------------|------------------------------|
|
||||||
|
| Traditional Vector DB(FAISS) | 3.8 GB | 201 GB | 1.8G | 305.8 MB |130.4 MB |
|
||||||
|
| **LEANN** | **324 MB** | **6 GB** | **64 MB** | **14.8 MB** |**6.4MB** |
|
||||||
|
| **Reduction** | **91% smaller** | **97% smaller** | **97% smaller** | **95% smaller** |**95% smaller** |
|
||||||
|
|
||||||
|
<!-- ### Memory Usage Comparison
|
||||||
|
|
||||||
|
| System j | DPR(2M docs) | RPJ-wiki(60M docs) | Chat history() |
|
||||||
|
| --------------------- | ---------------- | ---------------- | ---------------- |
|
||||||
|
| Traditional Vector DB(LLamaindex faiss) | x GB | x GB | x GB |
|
||||||
|
| **Leann** | **xx MB** | **x GB** | **x GB** |
|
||||||
|
| **Reduction** | **x%** | **x%** | **x%** |
|
||||||
|
|
||||||
|
### Query Performance of LEANN
|
||||||
|
|
||||||
|
| Backend | Index Size | Query Time | Recall@3 |
|
||||||
|
| ------------------- | ---------- | ---------- | --------- |
|
||||||
|
| DiskANN | 1M docs | xms | 0.95 |
|
||||||
|
| HNSW | 1M docs | xms | 0.95 | -->
|
||||||
|
|
||||||
|
*Benchmarks run on Apple M3 Pro 36 GB*
|
||||||
|
|
||||||
## 🔬 Paper
|
## 🔬 Paper
|
||||||
|
|
||||||
@@ -184,91 +374,87 @@ If you find Leann useful, please cite:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## 🌍 Use Cases
|
## ✨ Features
|
||||||
|
|
||||||
### 💼 Enterprise RAG
|
### 🔥 Core Features
|
||||||
```python
|
|
||||||
# Handle millions of documents with limited resources
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="diskann",
|
|
||||||
distance_metric="cosine",
|
|
||||||
graph_degree=64,
|
|
||||||
memory_budget="4GB"
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 🔬 Research & Experimentation
|
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage by computing embeddings on the fly, using optimized ZMQ servers and an overlapping, batched search paradigm built on a high-performance embedding engine
|
||||||
```python
|
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||||
# Quick prototyping with different algorithms
|
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||||
for backend in ["diskann", "hnsw"]:
|
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
|
||||||
searcher = LeannSearcher(index_path, backend=backend)
|
|
||||||
evaluate_recall(searcher, queries, ground_truth)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 🚀 Real-time Applications
|
### 🛠️ Technical Highlights
|
||||||
```python
|
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
|
||||||
# Sub-second response times
|
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
||||||
chat = LeannChat("knowledge.leann")
|
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||||
response = chat.ask("What is quantum computing?")
|
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||||
# Returns in <100ms with recompute mode
|
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||||
```
|
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
|
||||||
|
|
||||||
|
### 🎨 Developer Experience
|
||||||
|
|
||||||
|
- **Simple Python API** - Get started in minutes
|
||||||
|
- **Extensible backend system** - Easy to add new algorithms
|
||||||
|
- **Comprehensive examples** - From basic usage to production deployment
|
||||||
|
|
||||||
## 🤝 Contributing
|
## 🤝 Contributing
|
||||||
|
|
||||||
We welcome contributions! Leann is built by the community, for the community.
|
We welcome contributions! Leann is built by the community, for the community.
|
||||||
|
|
||||||
### Ways to Contribute
|
### Ways to Contribute
|
||||||
|
|
||||||
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
||||||
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
||||||
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
||||||
- 📖 **Documentation**: Help make Leann more accessible
|
- 📖 **Documentation**: Help make Leann more accessible
|
||||||
- 🧪 **Benchmarks**: Share your performance results
|
- 🧪 **Benchmarks**: Share your performance results
|
||||||
|
|
||||||
### Development Setup
|
|
||||||
```bash
|
<!-- ## ❓ FAQ
|
||||||
git clone https://github.com/yourname/leann
|
|
||||||
cd leann
|
### Common Issues
|
||||||
uv sync --dev
|
|
||||||
uv run pytest tests/
|
#### NCCL Topology Error
|
||||||
|
|
||||||
|
**Problem**: You encounter an `ncclTopoComputePaths` error during document processing:
|
||||||
|
|
||||||
|
```
|
||||||
|
ncclTopoComputePaths (system=<optimized out>, comm=comm@entry=0x5555a82fa3c0) at graph/paths.cc:688
|
||||||
```
|
```
|
||||||
|
|
||||||
### Quick Tests
|
**Solution**: Set these environment variables before running your script:
|
||||||
```bash
|
|
||||||
# Sanity check all distance functions
|
|
||||||
uv run python tests/sanity_checks/test_distance_functions.py
|
|
||||||
|
|
||||||
# Verify L2 implementation
|
```bash
|
||||||
uv run python tests/sanity_checks/test_l2_verification.py
|
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
|
||||||
```
|
export NCCL_DEBUG=INFO
|
||||||
|
export NCCL_DEBUG_SUBSYS=INIT,GRAPH
|
||||||
|
export NCCL_IB_DISABLE=1
|
||||||
|
export NCCL_NET_PLUGIN=none
|
||||||
|
export NCCL_SOCKET_IFNAME=ens5
|
||||||
|
``` -->
|
||||||
|
|
||||||
## 📈 Roadmap
|
## 📈 Roadmap
|
||||||
|
|
||||||
### 🎯 Q1 2024
|
### 🎯 Q2 2025
|
||||||
- [x] DiskANN backend with MIPS/L2/Cosine support
|
|
||||||
- [x] HNSW backend integration
|
- [X] DiskANN backend with MIPS/L2/Cosine support
|
||||||
- [x] Real-time embedding pipeline
|
- [X] HNSW backend integration
|
||||||
- [x] Memory-efficient graph pruning
|
- [X] Real-time embedding pipeline
|
||||||
|
- [X] Memory-efficient graph pruning
|
||||||
|
|
||||||
|
### 🚀 Q3 2025
|
||||||
|
|
||||||
|
|
||||||
### 🚀 Q2 2024
|
|
||||||
- [ ] Distributed search across multiple nodes
|
|
||||||
- [ ] ScaNN backend support
|
|
||||||
- [ ] Advanced caching strategies
|
- [ ] Advanced caching strategies
|
||||||
- [ ] Kubernetes deployment guides
|
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
|
||||||
|
- [ ] Add sleep-time compute and a summarization agent to summarize the files on your computer
|
||||||
|
- [ ] Add OpenAI recompute API
|
||||||
|
|
||||||
|
### 🌟 Q4 2025
|
||||||
|
|
||||||
### 🌟 Q3 2024
|
|
||||||
- [ ] GPU-accelerated embedding computation
|
|
||||||
- [ ] Approximate distance functions
|
|
||||||
- [ ] Integration with LangChain/LlamaIndex
|
- [ ] Integration with LangChain/LlamaIndex
|
||||||
- [ ] Visual similarity search
|
- [ ] Visual similarity search
|
||||||
|
- [ ] Query rewriting, reranking, and expansion
|
||||||
## 💬 Community
|
|
||||||
|
|
||||||
Join our growing community of researchers and engineers!
|
|
||||||
|
|
||||||
- 🐦 **Twitter**: [@LeannAI](https://twitter.com/LeannAI)
|
|
||||||
- 💬 **Discord**: [Join our server](https://discord.gg/leann)
|
|
||||||
- 📧 **Email**: leann@yourcompany.com
|
|
||||||
- 🐙 **GitHub Discussions**: [Ask questions here](https://github.com/yourname/leann/discussions)
|
|
||||||
|
|
||||||
## 📄 License
|
## 📄 License
|
||||||
|
|
||||||
@@ -290,3 +476,4 @@ MIT License - see [LICENSE](LICENSE) for details.
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
Made with ❤️ by the Leann team
|
Made with ❤️ by the Leann team
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
|||||||
BIN
assets/arch.png
Normal file
BIN
assets/arch.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 78 KiB |
BIN
assets/effects.png
Normal file
BIN
assets/effects.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 339 KiB |
BIN
assets/logo-text.png
Normal file
BIN
assets/logo-text.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 818 KiB |
BIN
assets/logo.png
Normal file
BIN
assets/logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 276 KiB |
82
data/.gitattributes
vendored
Normal file
82
data/.gitattributes
vendored
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.lz4 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mds filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.model filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||||
|
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Audio files - uncompressed
|
||||||
|
*.pcm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.sam filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.raw filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Audio files - compressed
|
||||||
|
*.aac filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.flac filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ogg filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.wav filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Image files - uncompressed
|
||||||
|
*.bmp filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.gif filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.png filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tiff filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Image files - compressed
|
||||||
|
*.jpg filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.webp filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Video files - compressed
|
||||||
|
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.webm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
44
data/README.md
Normal file
44
data/README.md
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
---
|
||||||
|
license: mit
|
||||||
|
---
|
||||||
|
|
||||||
|
# LEANN-RAG Evaluation Data
|
||||||
|
|
||||||
|
This repository contains the necessary data to run the recall evaluation scripts for the [LEANN-RAG](https://huggingface.co/LEANN-RAG) project.
|
||||||
|
|
||||||
|
## Dataset Components
|
||||||
|
|
||||||
|
This dataset is structured into three main parts:
|
||||||
|
|
||||||
|
1. **Pre-built LEANN Indices**:
|
||||||
|
* `dpr/`: A pre-built index for the DPR dataset.
|
||||||
|
* `rpj_wiki/`: A pre-built index for the RPJ-Wiki dataset.
|
||||||
|
These indices were created using the `leann-core` library and are required by the `LeannSearcher`.
|
||||||
|
|
||||||
|
2. **Ground Truth Data**:
|
||||||
|
* `ground_truth/`: Contains the ground truth files (`flat_results_nq_k3.json`) for both the DPR and RPJ-Wiki datasets. These files map queries to the original passage IDs from the Natural Questions benchmark, evaluated using the Contriever model.
|
||||||
|
|
||||||
|
3. **Queries**:
|
||||||
|
* `queries/`: Contains the `nq_open.jsonl` file with the Natural Questions queries used for the evaluation.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use this data, you can download it locally using the `huggingface-hub` library. First, install the library:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install huggingface-hub
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, you can download the entire dataset to a local directory (e.g., `data/`) with the following Python script:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir="data"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
This will download all the necessary files into a local `data` folder, preserving the repository structure. The evaluation scripts in the main [LEANN-RAG Space](https://huggingface.co/LEANN-RAG) are configured to work with this data structure.
|
||||||
231
demo.ipynb
231
demo.ipynb
@@ -2,225 +2,34 @@
|
|||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"INFO: LeannBuilder initialized with 'diskann' backend.\n",
|
|
||||||
"INFO: Computing embeddings for 6 chunks using 'sentence-transformers/all-mpnet-base-v2'...\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Batches: 100%|██████████| 1/1 [00:00<00:00, 77.61it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"INFO: Building DiskANN index for 6 vectors with metric Metric.INNER_PRODUCT...\n",
|
|
||||||
"Using Inner Product search, so need to pre-process base data into temp file. Please ensure there is additional (n*(d+1)*4) bytes for storing pre-processed base vectors, apart from the interim indices created by DiskANN and the final index.\n",
|
|
||||||
"Pre-processing base file by adding extra coordinate\n",
|
|
||||||
"✅ DiskANN index built successfully at 'knowledge'\n",
|
|
||||||
"Writing bin: knowledge_disk.index_max_base_norm.bin\n",
|
|
||||||
"bin: #pts = 1, #dims = 1, size = 12B\n",
|
|
||||||
"Finished writing bin.\n",
|
|
||||||
"Time for preprocessing data for inner product: 0.000165 seconds\n",
|
|
||||||
"Reading max_norm_of_base from knowledge_disk.index_max_base_norm.bin\n",
|
|
||||||
"Reading bin file knowledge_disk.index_max_base_norm.bin ...\n",
|
|
||||||
"Opening bin file knowledge_disk.index_max_base_norm.bin... \n",
|
|
||||||
"Metadata: #pts = 1, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"max_norm_of_base: 1\n",
|
|
||||||
"! Using prepped_base file at knowledge_prepped_base.bin\n",
|
|
||||||
"Starting index build: R=32 L=64 Query RAM budget: 4.02653e+09 Indexing ram budget: 8 T: 8\n",
|
|
||||||
"getting bin metadata\n",
|
|
||||||
"Time for getting bin metadata: 0.000008 seconds\n",
|
|
||||||
"Compressing 769-dimensional data into 512 bytes per vector.\n",
|
|
||||||
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
|
|
||||||
"Training data with 6 samples loaded.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 256, #dims = 769...\n",
|
|
||||||
"done.\n",
|
|
||||||
"PQ pivot file exists. Not generating again\n",
|
|
||||||
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 4, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 256, #dims = 769...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 769, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 513, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Loaded PQ pivot information\n",
|
|
||||||
"Processing points [0, 6)...done.\n",
|
|
||||||
"Time for generating quantized data: 0.023918 seconds\n",
|
|
||||||
"Full index fits in RAM budget, should consume at most 2.03973e-05GiBs, so building in one shot\n",
|
|
||||||
"L2: Using AVX2 distance computation DistanceL2Float\n",
|
|
||||||
"Passed, empty search_params while creating index config\n",
|
|
||||||
"Using only first 6 from file.. \n",
|
|
||||||
"Starting index build with 6 points... \n",
|
|
||||||
"0% of index build completed.Starting final cleanup..done. Link time: 9e-05s\n",
|
|
||||||
"Index built with degree: max:5 avg:5 min:5 count(deg<2):0\n",
|
|
||||||
"Not saving tags as they are not enabled.\n",
|
|
||||||
"Time taken for save: 0.000178s.\n",
|
|
||||||
"Time for building merged vamana index: 0.000579 seconds\n",
|
|
||||||
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
|
|
||||||
"Vamana index file size=168\n",
|
|
||||||
"Opened: knowledge_disk.index, cache_size: 67108864\n",
|
|
||||||
"medoid: 0B\n",
|
|
||||||
"max_node_len: 3100B\n",
|
|
||||||
"nnodes_per_sector: 1B\n",
|
|
||||||
"# sectors: 6\n",
|
|
||||||
"Sector #0written\n",
|
|
||||||
"Finished writing 28672B\n",
|
|
||||||
"Writing bin: knowledge_disk.index\n",
|
|
||||||
"bin: #pts = 9, #dims = 1, size = 80B\n",
|
|
||||||
"Finished writing bin.\n",
|
|
||||||
"Output disk index file written to knowledge_disk.index\n",
|
|
||||||
"Finished writing 28672B\n",
|
|
||||||
"Time for generating disk layout: 0.043488 seconds\n",
|
|
||||||
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
|
|
||||||
"Loading base knowledge_prepped_base.bin. #points: 6. #dim: 769.\n",
|
|
||||||
"Wrote 1 points to sample file: knowledge_sample_data.bin\n",
|
|
||||||
"Indexing time: 0.0684344\n",
|
|
||||||
"INFO: Leann metadata saved to knowledge.leann.meta.json\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"\n",
|
|
||||||
"Opened file : knowledge_disk.index\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Since data is floating point, we assume that it has been appropriately pre-processed (normalization for cosine, and convert-to-l2 by adding extra dimension for MIPS). So we shall invoke an l2 distance function.\n",
|
|
||||||
"L2: Using AVX2 distance computation DistanceL2Float\n",
|
|
||||||
"L2: Using AVX2 distance computation DistanceL2Float\n",
|
|
||||||
"Before index load\n",
|
|
||||||
"✅ DiskANN index loaded successfully.\n",
|
|
||||||
"INFO: LeannSearcher initialized with 'diskann' backend using index 'knowledge.leann'.\n",
|
|
||||||
"Reading bin file knowledge_pq_compressed.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_compressed.bin... \n",
|
|
||||||
"Metadata: #pts = 6, #dims = 512...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 4, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Offsets: 4096 791560 794644 796704\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 256, #dims = 769...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 769, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 513, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Loaded PQ Pivots: #ctrs: 256, #dims: 769, #chunks: 512\n",
|
|
||||||
"Loaded PQ centroids and in-memory compressed vectors. #points: 6 #dim: 769 #aligned_dim: 776 #chunks: 512\n",
|
|
||||||
"Loading index metadata from knowledge_disk.index\n",
|
|
||||||
"Disk-Index File Meta-data: # nodes per sector: 1, max node len (bytes): 3100, max node degree: 5\n",
|
|
||||||
"Disk-Index Meta: nodes per sector: 1, max node len: 3100, max node degree: 5\n",
|
|
||||||
"Setting up thread-specific contexts for nthreads: 8\n",
|
|
||||||
"allocating ctx: 0x78348f4de000 to thread-id:132170359560000\n",
|
|
||||||
"allocating ctx: 0x78348f4cd000 to thread-id:132158431693760\n",
|
|
||||||
"allocating ctx: 0x78348f4bc000 to thread-id:132158442179392\n",
|
|
||||||
"allocating ctx: 0x78348f4ab000 to thread-id:132158421208128\n",
|
|
||||||
"allocating ctx: 0x78348f49a000 to thread-id:132158452665024\n",
|
|
||||||
"allocating ctx: 0x78348f489000 to thread-id:132158389751232\n",
|
|
||||||
"allocating ctx: 0x78348f478000 to thread-id:132158410722496\n",
|
|
||||||
"allocating ctx: 0x78348f467000 to thread-id:132158400236864\n",
|
|
||||||
"Loading centroid data from medoids vector data of 1 medoid(s)\n",
|
|
||||||
"Reading bin file knowledge_disk.index_max_base_norm.bin ...\n",
|
|
||||||
"Opening bin file knowledge_disk.index_max_base_norm.bin... \n",
|
|
||||||
"Metadata: #pts = 1, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Setting re-scaling factor of base vectors to 1\n",
|
|
||||||
"load_from_separate_paths done.\n",
|
|
||||||
"Reading (with alignment) bin file knowledge_sample_data.bin ...Metadata: #pts = 1, #dims = 769, aligned_dim = 776... allocating aligned memory of 3104 bytes... done. Copying data to mem_aligned buffer... done.\n",
|
|
||||||
"reserve ratio: 1\n",
|
|
||||||
"Graph traversal completed, hops: 3\n",
|
|
||||||
"Loading the cache list into memory....done.\n",
|
|
||||||
"After index load\n",
|
|
||||||
"Clearing scratch\n",
|
|
||||||
"INFO: Computing embeddings for 1 chunks using 'sentence-transformers/all-mpnet-base-v2'...\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Batches: 100%|██████████| 1/1 [00:00<00:00, 92.66it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Score: -0.481 - C++ is a powerful programming language\n",
|
|
||||||
"Score: -1.049 - Java is a powerful programming language\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"reserve ratio: 1\n",
|
|
||||||
"Graph traversal completed, hops: 3\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"from leann.api import LeannBuilder, LeannSearcher\n",
|
"from leann.api import LeannBuilder, LeannSearcher, LeannChat\n",
|
||||||
"import leann_backend_diskann\n",
|
|
||||||
"# 1. Build index (no embeddings stored!)\n",
|
"# 1. Build index (no embeddings stored!)\n",
|
||||||
"builder = LeannBuilder(backend_name=\"diskann\")\n",
|
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
||||||
"builder.add_text(\"Python is a powerful programming language\")\n",
|
"builder.add_text(\"C# is a powerful programming language but it is not very popular\")\n",
|
||||||
|
"builder.add_text(\"Python is a powerful programming language and it is very popular\")\n",
|
||||||
"builder.add_text(\"Machine learning transforms industries\") \n",
|
"builder.add_text(\"Machine learning transforms industries\") \n",
|
||||||
"builder.add_text(\"Neural networks process complex data\")\n",
|
"builder.add_text(\"Neural networks process complex data\")\n",
|
||||||
"builder.add_text(\"Java is a powerful programming language\")\n",
|
"builder.add_text(\"Leann is a great storage saving engine for RAG on your macbook\")\n",
|
||||||
"builder.add_text(\"C++ is a powerful programming language\")\n",
|
|
||||||
"builder.add_text(\"C# is a powerful programming language\")\n",
|
|
||||||
"builder.build_index(\"knowledge.leann\")\n",
|
"builder.build_index(\"knowledge.leann\")\n",
|
||||||
"\n",
|
|
||||||
"# 2. Search with real-time embeddings\n",
|
"# 2. Search with real-time embeddings\n",
|
||||||
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
||||||
"results = searcher.search(\"C++ programming languages\", top_k=2)\n",
|
"results = searcher.search(\"programming languages\", top_k=2, recompute_beighbor_embeddings=True)\n",
|
||||||
|
"print(results)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"for result in results:\n",
|
"llm_config = {\"type\": \"ollama\", \"model\": \"qwen3:8b\"}\n",
|
||||||
" print(f\"Score: {result['score']:.3f} - {result['text']}\")"
|
"\n",
|
||||||
|
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
|
||||||
|
"\n",
|
||||||
|
"response = chat.ask(\n",
|
||||||
|
" \"Compare the two retrieved programming languages and say which one is more popular today. Respond in a single well-formed sentence.\",\n",
|
||||||
|
" top_k=2,\n",
|
||||||
|
" recompute_beighbor_embeddings=True,\n",
|
||||||
|
")\n",
|
||||||
|
"print(response)"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -240,7 +49,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.11"
|
"version": "3.11.12"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
335
examples/compare_faiss_vs_leann.py
Normal file
335
examples/compare_faiss_vs_leann.py
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Memory comparison between Faiss HNSW and LEANN HNSW backend
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import psutil
|
||||||
|
import gc
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_memory_usage():
|
||||||
|
"""Get current memory usage in MB"""
|
||||||
|
process = psutil.Process()
|
||||||
|
return process.memory_info().rss / 1024 / 1024
|
||||||
|
|
||||||
|
|
||||||
|
def print_memory_stats(stage: str, start_mem: float):
|
||||||
|
"""Print memory statistics"""
|
||||||
|
current_mem = get_memory_usage()
|
||||||
|
diff = current_mem - start_mem
|
||||||
|
print(f"[{stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
|
||||||
|
return current_mem
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryTracker:
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self.name = name
|
||||||
|
self.start_mem = get_memory_usage()
|
||||||
|
self.stages = []
|
||||||
|
|
||||||
|
def checkpoint(self, stage: str):
|
||||||
|
current_mem = print_memory_stats(f"{self.name} - {stage}", self.start_mem)
|
||||||
|
self.stages.append((stage, current_mem))
|
||||||
|
return current_mem
|
||||||
|
|
||||||
|
def summary(self):
|
||||||
|
print(f"\n=== {self.name} Memory Summary ===")
|
||||||
|
for stage, mem in self.stages:
|
||||||
|
print(f"{stage}: {mem:.1f} MB")
|
||||||
|
peak_mem = max(mem for _, mem in self.stages)
|
||||||
|
print(f"Peak Memory: {peak_mem:.1f} MB")
|
||||||
|
print(f"Total Memory Increase: {peak_mem - self.start_mem:.1f} MB")
|
||||||
|
return peak_mem
|
||||||
|
|
||||||
|
|
||||||
|
def test_faiss_hnsw():
|
||||||
|
"""Test Faiss HNSW Vector Store in subprocess"""
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("TESTING FAISS HNSW VECTOR STORE")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
[sys.executable, "examples/faiss_only.py"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=300,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(result.stdout)
|
||||||
|
if result.stderr:
|
||||||
|
print("Stderr:", result.stderr)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
return {
|
||||||
|
"peak_memory": float("inf"),
|
||||||
|
"error": f"Process failed with code {result.returncode}",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Parse peak memory from output
|
||||||
|
lines = result.stdout.split("\n")
|
||||||
|
peak_memory = 0.0
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
if "Peak Memory:" in line:
|
||||||
|
peak_memory = float(
|
||||||
|
line.split("Peak Memory:")[1].split("MB")[0].strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"peak_memory": peak_memory}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"peak_memory": float("inf"),
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_leann_hnsw():
    """Test LEANN HNSW Search Memory (load existing index).

    Builds (if needed) and then queries a LEANN HNSW index, tracking process
    memory at each stage.

    Returns:
        dict with "peak_memory" (MB) and "storage_size" (MB of index files),
        or "peak_memory"=inf plus an "error" message on failure.
    """
    print("\n" + "=" * 50)
    print("TESTING LEANN HNSW SEARCH MEMORY")
    print("=" * 50)

    tracker = MemoryTracker("LEANN HNSW Search")

    # Import and setup
    tracker.checkpoint("Initial")

    # FIX: LeannSearcher was previously imported twice (once alone, once
    # together with LeannBuilder); a single import is sufficient.
    from leann.api import LeannBuilder, LeannSearcher

    tracker.checkpoint("After imports")

    from llama_index.core import SimpleDirectoryReader

    # Load and parse documents
    documents = SimpleDirectoryReader(
        "examples/data",
        recursive=True,
        encoding="utf-8",
        required_exts=[".pdf", ".txt", ".md"],
    ).load_data()

    tracker.checkpoint("After document loading")

    # Parse into chunks (same splitter settings as the Faiss baseline so the
    # two measurements are comparable).
    node_parser = SentenceSplitter(
        chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
    )

    all_texts = []
    for doc in documents:
        nodes = node_parser.get_nodes_from_documents([doc])
        for node in nodes:
            all_texts.append(node.get_content())

    tracker.checkpoint("After text chunking")

    # Build LEANN index
    INDEX_DIR = Path("./test_leann_comparison")
    INDEX_PATH = str(INDEX_DIR / "comparison.leann")

    # Check if index already exists (LEANN writes a .meta.json next to it)
    if os.path.exists(INDEX_PATH + ".meta.json"):
        print("Loading existing LEANN HNSW index...")
        tracker.checkpoint("After loading existing index")
    else:
        print("Building new LEANN HNSW index...")
        # Clean up previous (possibly partial) index
        import shutil

        if INDEX_DIR.exists():
            shutil.rmtree(INDEX_DIR)

        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="facebook/contriever",
            graph_degree=32,
            complexity=64,
            is_compact=True,
            is_recompute=True,
            num_threads=1,
        )

        tracker.checkpoint("After builder setup")

        print("Building LEANN HNSW index...")

        for chunk_text in all_texts:
            builder.add_text(chunk_text)

        builder.build_index(INDEX_PATH)
        # Release builder memory before measuring search-time overhead.
        del builder
        gc.collect()

        tracker.checkpoint("After index building")

    # Find existing LEANN index
    index_paths = [
        "./test_leann_comparison/comparison.leann",
    ]
    index_path = None
    for path in index_paths:
        if os.path.exists(path + ".meta.json"):
            index_path = path
            break

    if not index_path:
        print("❌ LEANN index not found. Please build it first")
        return {"peak_memory": float("inf"), "error": "Index not found"}

    # Measure runtime memory overhead
    print("\nMeasuring runtime memory overhead...")
    runtime_start_mem = get_memory_usage()
    print(f"Before load memory: {runtime_start_mem:.1f} MB")
    tracker.checkpoint("Before load memory")

    # Load searcher
    searcher = LeannSearcher(index_path)
    tracker.checkpoint("After searcher loading")

    print("Running search queries...")
    queries = [
        "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
        "What is LEANN and how does it work?",
        "华为诺亚方舟实验室的主要研究内容",
    ]

    for i, query in enumerate(queries):
        start_time = time.time()
        # Use same parameters as Faiss: top_k=20, ef=120 (complexity parameter)
        _ = searcher.search(query, top_k=20, ef=120)
        query_time = time.time() - start_time
        print(f"Query {i + 1} time: {query_time:.3f}s")
        tracker.checkpoint(f"After query {i + 1}")

    runtime_end_mem = get_memory_usage()
    runtime_overhead = runtime_end_mem - runtime_start_mem

    peak_memory = tracker.summary()
    print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")

    # Get storage size before cleanup
    storage_size = 0
    INDEX_DIR = Path(index_path).parent
    if INDEX_DIR.exists():
        total_size = 0
        for dirpath, _, filenames in os.walk(str(INDEX_DIR)):
            for filename in filenames:
                # Only count actual index files, skip text data and backups
                if filename.endswith((".old", ".tmp", ".bak", ".jsonl", ".json")):
                    continue
                # Count .index, .idx, .map files (actual index structures)
                if filename.endswith((".index", ".idx", ".map")):
                    filepath = os.path.join(dirpath, filename)
                    total_size += os.path.getsize(filepath)
        storage_size = total_size / (1024 * 1024)  # Convert to MB

    # Clean up
    del searcher
    gc.collect()

    return {
        "peak_memory": peak_memory,
        "storage_size": storage_size,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run comparison tests"""
    print("Storage + Search Memory Comparison: Faiss HNSW vs LEANN HNSW")
    print("=" * 60)

    # Test Faiss HNSW
    faiss_results = test_faiss_hnsw()

    # Force garbage collection
    gc.collect()
    time.sleep(2)

    # Test LEANN HNSW
    leann_results = test_leann_hnsw()

    # Final comparison
    print("\n" + "=" * 60)
    print("STORAGE + SEARCH MEMORY COMPARISON")
    print("=" * 60)

    # Get storage sizes
    faiss_storage_size = 0
    leann_storage_size = leann_results.get("storage_size", 0)

    # Get Faiss storage size using Python
    if os.path.exists("./storage_faiss"):
        total_size = 0
        for dirpath, _, filenames in os.walk("./storage_faiss"):
            for filename in filenames:
                filepath = os.path.join(dirpath, filename)
                total_size += os.path.getsize(filepath)
        faiss_storage_size = total_size / (1024 * 1024)  # Convert to MB

    print("Faiss HNSW:")
    if "error" in faiss_results:
        print(f" ❌ Failed: {faiss_results['error']}")
    else:
        print(f" Search Memory: {faiss_results['peak_memory']:.1f} MB")
        print(f" Storage Size: {faiss_storage_size:.1f} MB")

    print("\nLEANN HNSW:")
    if "error" in leann_results:
        print(f" ❌ Failed: {leann_results['error']}")
    else:
        print(f" Search Memory: {leann_results['peak_memory']:.1f} MB")
        print(f" Storage Size: {leann_storage_size:.1f} MB")

    # Calculate improvements only if both tests succeeded
    if "error" not in faiss_results and "error" not in leann_results:
        faiss_peak = faiss_results["peak_memory"]
        leann_peak = leann_results["peak_memory"]
        # FIX: guard divisions — a zero peak/storage value previously raised
        # ZeroDivisionError (e.g. when ./storage_faiss is missing).
        memory_ratio = faiss_peak / leann_peak if leann_peak else float("inf")

        print("\nLEANN vs Faiss Performance:")
        memory_saving = faiss_peak - leann_peak
        print(
            f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)"
        )

        # Storage comparison
        if leann_storage_size > faiss_storage_size:
            storage_ratio = (
                leann_storage_size / faiss_storage_size
                if faiss_storage_size
                else float("inf")
            )
            print(
                f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)"
            )
        elif faiss_storage_size > leann_storage_size:
            storage_ratio = (
                faiss_storage_size / leann_storage_size
                if leann_storage_size
                else float("inf")
            )
            print(
                f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)"
            )
        else:
            print(" Storage Size: similar")
    else:
        if "error" not in leann_results:
            print("\n✅ LEANN HNSW completed successfully!")
            print(f"📊 Search Memory: {leann_results['peak_memory']:.1f} MB")
            print(f"📊 Storage Size: {leann_storage_size:.1f} MB")
        if "error" not in faiss_results:
            print("\n✅ Faiss HNSW completed successfully!")
            print(f"📊 Search Memory: {faiss_results['peak_memory']:.1f} MB")
            print(f"📊 Storage Size: {faiss_storage_size:.1f} MB")
|
||||||
|
|
||||||
|
|
||||||
|
# Entry point: run the Faiss-vs-LEANN comparison when executed as a script.
if __name__ == "__main__":
    main()
|
||||||
BIN
examples/data/2501.14312v1 (1).pdf
Normal file
BIN
examples/data/2501.14312v1 (1).pdf
Normal file
Binary file not shown.
14907
examples/data/PrideandPrejudice.txt
Normal file
14907
examples/data/PrideandPrejudice.txt
Normal file
File diff suppressed because it is too large
Load Diff
82
examples/data/README.md
Normal file
82
examples/data/README.md
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
# 盘古之殇:华为诺亚盘古大模型研发历程的心酸与黑暗
|
||||||
|
|
||||||
|
各位好,
|
||||||
|
|
||||||
|
我是一名盘古大模型团队,华为诺亚方舟实验室的员工。
|
||||||
|
|
||||||
|
首先为自证身份,列举一些细节:
|
||||||
|
|
||||||
|
1. 现诺亚主任,前算法应用部部长,后改名为小模型实验室的主任王云鹤。前诺亚主任:姚骏(大家称姚老师)。几个实验室主任:唐睿明(明哥,明队,已离职),尚利峰,张维(维哥),郝建业(郝老师),刘武龙(称呼为武龙所)等。其他骨干成员和专家陆续有很多人离职。
|
||||||
|
2. 我们隶属于“四野”这个组织。四野下属有许多纵队,基础语言大模型是四纵。王云鹤的小模型是十六纵队。我们参加过苏州的集结,有各种月份的时间节点。在苏州攻关会颁发任务令,需要在节点前达成目标。苏州集结会把各地的人员都集中在苏州研究所,平常住宾馆,比如在甪直的酒店,与家人孩子天各一方。
|
||||||
|
3. 在苏州集结的时候周六默认上班,非常辛苦,不过周六有下午茶,有一次还有小龙虾。在苏州研究所的工位搬迁过一次,从一栋楼换到了另一栋。苏州研究所楼栋都是欧式装修,门口有大坡,里面景色很不错。去苏州集结一般至少要去一周,甚至更久,多的人甚至一两个月都回不了家。
|
||||||
|
4. 诺亚曾经传说是研究型的,但是来了之后因为在四野做大模型项目,项目成员完全变成了交付型的,且充满了例会,评审,汇报。很多时候做实验都要申请。团队需要对接终端小艺,华为云,ICT等诸多业务线,交付压力不小。
|
||||||
|
5. 诺亚研发的盘古模型早期内部代号叫做“盘古智子”,一开始只有内部需要申请试用的网页版,到后续迫于压力在welink上接入和公测开放。
|
||||||
|
|
||||||
|
这些天发生关于质疑盘古大模型抄袭千问的事情闹的沸沸扬扬。作为一个盘古团队的成员,我最近夜夜辗转反侧,难以入眠。盘古的品牌受到如此大的影响,一方面,我自私的为我的职业发展担忧,也为自己过去的努力工作感到不值。另一方面,由于有人开始揭露这些事情我内心又感到大快人心。在多少个日日夜夜,我们对内部某些人一次次靠着造假而又获得了无数利益的行为咬牙切齿而又无能为力。这种压抑和羞辱也逐渐消磨了我对华为的感情,让我在这里的时日逐渐浑浑噩噩,迷茫无措,时常怀疑自己的人生和自我价值。
|
||||||
|
|
||||||
|
我承认我是一个懦弱的人,作为一个小小的打工人,我不仅不敢和王云鹤等内部手眼通天的人做对,更不敢和华为这样的庞然大物做对。我很怕失去我的工作,毕竟我也有家人和孩子,所以我打心眼里很佩服揭露者。但是,看到内部还在试图洗地掩盖事实,蒙蔽公众的时候,我实在不能容忍了。我也希望勇敢一次,顺从自己本心。就算自损八百,我也希望能伤敌一千。我决定把我在这里的所见所闻(部分来自于同事口述)公布出来,关于盘古大模型的“传奇故事”:
|
||||||
|
|
||||||
|
华为确实主要在昇腾卡上训练大模型(小模型实验室有不少英伟达的卡,他们之前也会用来训练,后面转移到昇腾)。曾经我被华为“打造世界第二选择”的决心而折服,我本身也曾经对华为有深厚的感情。我们陪着昇腾一步步摸爬滚打,从充满bug到现在能训出模型,付出了巨大的心血和代价。
|
||||||
|
|
||||||
|
最初我们的算力非常有限,在910A上训练模型。那会只支持fp16,训练的稳定性远不如bf16。盘古的moe开始很早,23年就主要是训练38Bmoe模型和后续的71B dense模型。71B的dense模型通过扩增变成了第一代的135Bdense模型,后面主力模型也逐渐在910B上训练。
|
||||||
|
|
||||||
|
71B和135B模型都有一个巨大的硬伤就是tokenizer。当时使用的tokenizer编码效率极低,每个单个的符号,数字,空格,乃至汉字都会占用一个token。可想而知这会非常浪费算力,且使得模型的效果很差。这时候小模型实验室正好有个自己训的词表。姚老师当时怀疑是不是模型的tokenizer不好(虽然事后来看,他的怀疑是无疑正确的),于是就决定,让71B和135B换tokenizer,因为小模型实验室曾经尝试过。团队缝合了两个tokenizer,开始了tokenizer的更换。71B模型的更换失败了,而135B因为采用了更精细的embedding初始化策略,续训了至少1T的数据后词表总算更换成功,但可想而知,效果并不会变好。
|
||||||
|
|
||||||
|
于此同期,阿里和智谱等国内其他公司在GPU上训练,且已经摸索出了正确的方法,盘古和竞品的差距越来越大。内部一个230B从头训练的dense模型又因为各种原因训练失败,导致项目的状况几乎陷入绝境。面临几个节点的压力以及内部对盘古的强烈质疑时,团队的士气低迷到了极点。团队在算力极其有限的时候,做出了很多努力和挣扎。比如,团队偶然发现当时的38B moe并没有预期moe的效果。于是去掉了moe参数,还原为了13B的dense模型。由于38B的moe源自很早的pangu alpha 13B,架构相对落后,团队进行了一系列的操作,比如切换绝对位置编码到rope,去掉bias,切换为rmsnorm。同时鉴于tokenizer的一些失败和换词表的经验,这个模型的词表也更换为了王云鹤的小模型实验室7B模型所使用的词表。后面这个13B模型进行了扩增续训,变成了第二代38B dense模型(在几个月内这个模型都是主要的盘古中档位模型),曾经具有一定的竞争力。但是,由于更大的135B模型架构落后,且更换词表模型损伤巨大(后续分析发现当时更换的缝合词表有更严重的bug),续训后也与千问等当时国内领先模型存在很大差距。这时由于内部的质疑声和领导的压力也越来越大。团队的状态几乎陷入了绝境。
|
||||||
|
|
||||||
|
在这种情况下,王云鹤和他的小模型实验室出手了。他们声称是从旧的135B参数继承改造而来,通过训练短短的几百B数据,各项指标平均提升了十个点左右。实际上,这就是他们套壳应用到大模型的第一次杰作。华为的外行领导内行,使得领导完全对于这种扯淡的事情没有概念,他们只会觉得肯定是有什么算法创新。经过内部的分析,他们实际上是使用Qwen 1.5 110B续训而来,通过加层,扩增ffn维度,添加盘古pi论文的一些机制得来,凑够了大概135B的参数。实际上,旧的135B有107层,而这个模型只有82层,各种配置也都不一样。新的来路不明的135B训练完很多参数的分布也和Qwen 110B几乎一模一样。连模型代码的类名当时都是Qwen,甚至懒得改名。后续这个模型就是所谓的135B V2。而这个模型当时也提供给了很多下游,甚至包括外部客户。
|
||||||
|
|
||||||
|
这件事对于我们这些认真诚实做事的同事们带来了巨大的冲击,内部很多人其实都知道这件事,甚至包括终端和华为云。我们都戏称以后别叫盘古模型了,叫千古吧。当时团队成员就想向bcg举报了,毕竟这已经是重大的业务造假了。但是后面据说被领导拦了下来,因为更高级别的领导(比如姚老师,以及可能熊总和查老)其实后面也知道了,但是并不管,因为通过套壳拿出好的结果,对他们也是有利的。这件事使得当时团队几位最强的同事开始心灰意冷,离职跑路也逐渐成为挂在嘴边的事。
|
||||||
|
|
||||||
|
此时,盘古似乎迎来了转机。由于前面所述的这些盘古模型基本都是续训和改造而来,当时诺亚完全没有掌握从头训练的技术,何况还是在昇腾的NPU上进行训练。在当时团队的核心成员的极力争取下,盘古开始了第三代模型的训练,付出了巨大的努力后,在数据架构和训练算法方面都与业界逐渐接轨,而这其中的艰辛和小模型实验室的人一点关系都没有。
|
||||||
|
|
||||||
|
一开始团队成员毫无信心,只从一个13B的模型开始训练,但是后面发现效果还不错,于是这个模型后续再次进行了一次参数扩增,变成了第三代的38B,代号38B V3。想必很多产品线的兄弟都对这个模型很熟悉。当时这个模型的tokenizer是基于llama的词表进行扩展的(也是业界常见的做法)。而当时王云鹤的实验室做出来了另一个词表(也就是后续pangu系列的词表)。当时两个词表还被迫进行了一次赛马,最终没有明显的好坏结论。于是,领导当即决定,应该统一词表,使用王云鹤他们的。于是,在后续从头训练的135B V3(也就是对外的Pangu Ultra),便是采用了这个tokenizer。这也解释了很多使用我们模型的兄弟的疑惑,为什么当时同为V3代的两个不同档位的模型,会使用不同的tokenizer。
|
||||||
|
|
||||||
|
|
||||||
|
我们打心眼里觉得,135B V3是我们四纵团队当时的骄傲。这是第一个真正意义上的,华为全栈自研,正经从头训练的千亿级别的模型,且效果与24年同期竞品可比的。写到这里我已经热泪盈眶,太不容易了。当时为了稳定训练,团队做了大量实验对比,并且多次在模型梯度出现异常的时候进行及时回退重启。这个模型真正做到了后面技术报告所说的训练全程没有一个loss spike。我们克服了不知道多少困难,我们做到了,我们愿用生命和荣誉保证这个模型训练的真实性。多少个凌晨,我们为了它的训练而不眠。在被内部心声骂的一文不值的时候,我们有多么不甘,有多少的委屈,我们挺住了。
|
||||||
|
|
||||||
|
我们这帮人是真的在为打磨国产算力底座燃烧自己的青春啊……客居他乡,我们放弃了家庭,放弃了假期,放弃了健康,放弃了娱乐,抛头颅洒热血,其中的艰辛与困苦,寥寥数笔不足以概括其万一。在各种动员大会上,当时口号中喊出的盘古必胜,华为必胜,我们心里是真的深深被感动。
|
||||||
|
|
||||||
|
然而,我们的所有辛苦的成果,经常被小模型实验室轻飘飘的拿走了。数据,直接要走。代码,直接要走,还要求我们配合适配到能一键运行。我们当时戏称小模型实验室为点鼠标实验室。我们付出辛苦,他们取得荣耀。果然应了那句话,你在负重前行是因为有人替你岁月静好。在这种情况下,越来越多的战友再也坚持不下去了,选择了离开。看到身边那些优秀的同事一个个离职,我的内心又感叹又难过。在这种作战一样的环境下,我们比起同事来说更像是战友。他们在技术上也有无数值得我学习的地方,堪称良师。看到他们去了诸如字节Seed,Deepseek,月之暗面,腾讯和快手等等很多出色的团队,我打心眼里为他们高兴和祝福,脱离了这个辛苦却肮脏的地方。我至今还对一位离职同事的话记忆犹新,ta说:“来这里是我技术生涯中的耻辱,在这里再呆每一天都是浪费生命”。话虽难听却让我无言以对。我担心我自己技术方面的积累不足,以及没法适应互联网公司高淘汰的环境,让我多次想离职的心始终没有迈出这一步。
|
||||||
|
|
||||||
|
盘古除了dense模型,后续也启动了moe的探索。一开始训练的是一个224B的moe模型。而与之平行的,小模型实验室也开启了第二次主要的套壳行动(次要的插曲可能还包括一些别的模型,比如math模型),即这次流传甚广的pangu pro moe 72B。这个模型内部自称是从小模型实验室的7B扩增上来的(就算如此,这也与技术报告不符,何况是套壳qwen 2.5的14b续训)。还记得他们训了没几天,内部的评测就立刻追上了当时的38B V3。AI系统实验室很多兄弟因为需要适配模型,都知道他们的套壳行动,只是迫于各种原因,无法伸张正义。实际上,对于后续训了很久很久的这个模型,Honestagi能够分析出这个量级的相似性我已经很诧异了,因为这个模型为了续训洗参数,所付出的算力甚至早就足够从头训一个同档位的模型了。听同事说他们为了洗掉千问的水印,采取了不少办法,甚至包括故意训了脏数据。这也为学术界研究模型血缘提供了一个前所未有的特殊模范吧。以后新的血缘方法提出可以拿出来溜溜。
|
||||||
|
|
||||||
|
24年底和25年初,在Deepseek v3和r1发布之后,由于其惊艳的技术水平,团队受到了巨大的冲击,也受到了更大的质疑。于是为了紧跟潮流,盘古模仿Deepseek的模型尺寸,开启了718B moe的训练。这个时候,小模型实验室再次出手了。他们选择了套壳Deepseekv3续训。他们通过冻住Deepseek加载的参数,进行训练。连任务加载ckpt的目录都是deepseekv3,改都不改,何其嚣张?与之相反,一些有真正技术信仰的同事,在从头训练另一个718B的moe。但其中出现了各种各样的问题。但是很显然,这个模型怎么可能比直接套壳的好呢?如果不是团队leader坚持,早就被叫停了。
|
||||||
|
|
||||||
|
华为的流程管理之繁重,严重拖累了大模型的研发节奏,例如版本管理,模型血缘,各种流程化,各种可追溯。讽刺的是,小模型实验室的模型似乎从来不受这些流程的约束,想套壳就套壳,想续训就续训,算力源源不断的伸手拿走。这种强烈到近乎魔幻的对比,说明了当前流程管理的情况:只许州官放火,不许百姓点灯。何其可笑?何其可悲?何其可恶?何其可耻!
|
||||||
|
|
||||||
|
HonestAGI的事情出来后,内部让大家不停的研讨分析,如何公关和“回应”。诚然,这个原文的分析也许不够有力,给了王云鹤与小模型实验室他们狡辩和颠倒黑白的机会。为此,这两天我内心感到作呕,时时怀疑自己的人生意义以及苍天无眼。我不奉陪了,我要离职了,同时我也在申请从盘古部分技术报告的作者名单中移除。曾经在这些技术报告上署名是我一生都无法抹除的污点。当时我没想到,他们竟然猖狂到敢开源。我没想到,他们敢如此愚弄世人,大肆宣发。当时,我也许是存了侥幸心理,没有拒绝署名。我相信很多扎实做事的战友,也只是被迫上了贼船,或者不知情。但这件事已经无法挽回,我希望我的余生能够坚持扎实做真正有意义的事,为我当时的软弱和不坚定赎罪。
|
||||||
|
|
||||||
|
深夜写到这里,我已经泪流满面,泣不成声。还记得一些出色的同事离职时,我苦笑问他们要不要发个长长的心声惯例帖,揭露一下现状。对方说:不了,浪费时间,而且我也怕揭露出来你们过的更糟。我当时一下黯然神伤,因为曾经共同为了理想奋斗过的战友已经彻底对华为彻底灰心了。当时大家调侃,我们用着当年共产党的小米加步枪,组织却有着堪比当年国民党的作风。
|
||||||
|
|
||||||
|
曾几何时,我为我们用着小米加步枪打败洋枪洋炮而自豪。
|
||||||
|
|
||||||
|
现在,我累了,我想投降。
|
||||||
|
|
||||||
|
其实时至今日,我还是真心希望华为能认真吸取教训,能做好盘古,把盘古做到世界一流,把昇腾变成英伟达的水平。内部的劣币驱逐良币,使得诺亚乃至华为在短时间内急剧流失了大量出色的大模型人才。相信他们也正在如Deepseek等各个团队闪耀着,施展着他们的抱负才华,为中美在AI的激烈竞赛中奉献力量。我时常感叹,华为不是没有人才,而是根本不知道怎么留住人才。如果给这些人合适的环境,合适的资源,更少的枷锁,更少的政治斗争,盘古何愁不成?
|
||||||
|
|
||||||
|
最后:我以生命,人格和荣誉发誓,我写的以上所有内容均为真实(至少在我有限的认知范围内)。我没有那么高的技术水平以及机会去做详尽扎实的分析,也不敢直接用内部记录举证,怕因为信息安全抓到。但是我相信我很多曾经的战友,会为我作证。在华为内部的兄弟,包括我们曾经服务过的产品线兄弟们,相信本文的无数细节能和你们的印象对照,印证我的说法。你们可能也曾经被蒙骗,但这些残酷的真相不会被尘封。我们奋战过的痕迹,也不应该被扭曲和埋葬。
|
||||||
|
|
||||||
|
写了这么多,某些人肯定想把我找出来,抹杀掉。公司搞不好也想让我噤声乃至追责。如果真的这样,我,乃至我的家人的人身乃至生命安全可能都会受到威胁。为了自我保护,我近期每天会跟大家报平安。
|
||||||
|
|
||||||
|
如果我消失了,就当是我为了真理和理想,为了华为乃至中国能够更好地发展算力和AI而牺牲了吧,我愿埋葬于那片曾经奋斗过的地方。
|
||||||
|
|
||||||
|
诺亚,再见
|
||||||
|
|
||||||
|
2025年7月6日凌晨 写于深圳
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
各位好,
|
||||||
|
|
||||||
|
感谢大家的关心与祝福。我目前暂时安全,但公司应该在进行排查与某些名单收集,后续情况未知。
|
||||||
|
|
||||||
|
我补充一些细节,以免某些人继续颠倒黑白。
|
||||||
|
|
||||||
|
关于135B V2,小模型实验室在迅速地完成套壳并拿完所有套壳带来的好处后(比如任务令表彰和及时激励),因为不想继续支撑下游应用和模型迭代,又把这个烫手山芋甩给了四纵。确实技高一筹,直接把四纵的兄弟们拉下水。同事提供过去一个老旧的模型,最终拿回了一个当时一个魔改的先进的千问。做大模型的人,自己做的模型就像自己孩子一样熟悉,不要把别人都当傻子。就像自家儿子出门一趟,回来个别人家孩子。
|
||||||
|
|
||||||
|
盘古report的署名是不符合学术规范的。例如,135B V3有不少有技术贡献的人,因为作者名额数量限制,劳动成果没有得到应有的回报,团队内曾经有不小的意见。这个模型当时是大家智慧和汗水的结晶,甚至是团队当时的精神支柱,支撑着不少兄弟们继续留在诺亚。所谓的名额限制,以及挂名了一些毫无技术贡献的人(如一些小模型实验室的人),让兄弟们何其心寒。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
暂时平安。另外,支持我勇于说出真相的战友们 https://github.com/HW-whistleblower/True-Story-of-Pangu/issues/317
|
||||||
@@ -74,7 +74,7 @@ def main():
|
|||||||
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
|
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
|
||||||
print(">>> Basic search results <<<")
|
print(">>> Basic search results <<<")
|
||||||
for i, res in enumerate(results, 1):
|
for i, res in enumerate(results, 1):
|
||||||
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
|
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
|
||||||
|
|
||||||
# --- 3. Recompute search demo ---
|
# --- 3. Recompute search demo ---
|
||||||
print(f"\n[PHASE 3] Recompute search using embedding server...")
|
print(f"\n[PHASE 3] Recompute search using embedding server...")
|
||||||
@@ -107,7 +107,7 @@ def main():
|
|||||||
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
|
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
|
||||||
print(">>> Recompute search results <<<")
|
print(">>> Recompute search results <<<")
|
||||||
for i, res in enumerate(recompute_results, 1):
|
for i, res in enumerate(recompute_results, 1):
|
||||||
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
|
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
|
||||||
|
|
||||||
# Compare results
|
# Compare results
|
||||||
print(f"\n--- Result comparison ---")
|
print(f"\n--- Result comparison ---")
|
||||||
@@ -116,8 +116,8 @@ def main():
|
|||||||
|
|
||||||
print("\nBasic search vs Recompute results:")
|
print("\nBasic search vs Recompute results:")
|
||||||
for i in range(min(len(results), len(recompute_results))):
|
for i in range(min(len(results), len(recompute_results))):
|
||||||
basic_score = results[i]['score']
|
basic_score = results[i].score
|
||||||
recompute_score = recompute_results[i]['score']
|
recompute_score = recompute_results[i].score
|
||||||
score_diff = abs(basic_score - recompute_score)
|
score_diff = abs(basic_score - recompute_score)
|
||||||
print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")
|
print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")
|
||||||
|
|
||||||
|
|||||||
124
examples/email_data/LEANN_email_reader.py
Normal file
124
examples/email_data/LEANN_email_reader.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
import os
|
||||||
|
import email
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any
|
||||||
|
from llama_index.core import Document
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
def find_all_messages_directories(root: str = None) -> List[Path]:
    """
    Recursively find all 'Messages' directories under the given root.

    Args:
        root: Directory to search. When omitted, defaults to the current
            user's Apple Mail storage (~/Library/Mail).

    Returns:
        A list of Path objects, one per directory literally named 'Messages'.
    """
    if root is None:
        # Auto-detect the user's Apple Mail path.
        root = os.path.join(os.path.expanduser("~"), "Library", "Mail")

    return [
        Path(current_dir)
        for current_dir, _subdirs, _files in os.walk(root)
        if os.path.basename(current_dir) == "Messages"
    ]
|
||||||
|
|
||||||
|
class EmlxReader(BaseReader):
    """
    Apple Mail .emlx file reader with embedded metadata.

    Reads individual .emlx files from Apple Mail's storage format and emits
    one Document per email, with From/To/Subject/Date embedded in the text.
    """

    def __init__(self, include_html: bool = False) -> None:
        """
        Initialize.

        Args:
            include_html: Whether to include HTML content in the email body (default: False)
        """
        self.include_html = include_html

    def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
        """
        Load data from the input directory containing .emlx files.

        Args:
            input_dir: Directory containing .emlx files (searched recursively,
                hidden directories skipped).
            **load_kwargs:
                max_count (int): Maximum amount of messages to read.

        Returns:
            List of Documents; files that fail to read or parse are skipped
            with a printed error.
        """
        docs: List[Document] = []
        max_count = load_kwargs.get('max_count', 1000)
        count = 0

        # Walk through the directory recursively
        for dirpath, dirnames, filenames in os.walk(input_dir):
            # Skip hidden directories (in-place edit prunes the walk)
            dirnames[:] = [d for d in dirnames if not d.startswith(".")]

            for filename in filenames:
                if count >= max_count:
                    break

                if filename.endswith(".emlx"):
                    filepath = os.path.join(dirpath, filename)
                    try:
                        # Read the .emlx file
                        with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                            content = f.read()

                        # .emlx files have a length prefix followed by the email content
                        # The first line contains the length, followed by the email
                        lines = content.split('\n', 1)
                        if len(lines) >= 2:
                            email_content = lines[1]

                            # Parse the email using Python's email module
                            try:
                                msg = email.message_from_string(email_content)

                                # Extract email metadata
                                subject = msg.get('Subject', 'No Subject')
                                from_addr = msg.get('From', 'Unknown')
                                to_addr = msg.get('To', 'Unknown')
                                date = msg.get('Date', 'Unknown')

                                # Extract email body
                                body = ""
                                if msg.is_multipart():
                                    for part in msg.walk():
                                        ctype = part.get_content_type()
                                        if ctype not in ("text/plain", "text/html"):
                                            continue
                                        if ctype == "text/html" and not self.include_html:
                                            continue
                                        # FIX: get_payload(decode=True) returns None for
                                        # parts with no decodable payload; previously this
                                        # raised AttributeError and the whole email was
                                        # dropped as a parse error.
                                        payload = part.get_payload(decode=True)
                                        if payload is not None:
                                            body += payload.decode('utf-8', errors='ignore')
                                else:
                                    payload = msg.get_payload(decode=True)
                                    if payload is not None:
                                        body = payload.decode('utf-8', errors='ignore')

                                # Create document content with metadata embedded in text
                                # FIX: embed the actual file name in the metadata header
                                # (it was previously a hard-coded placeholder).
                                doc_content = f"""
[EMAIL METADATA]
File: {filename}
From: {from_addr}
To: {to_addr}
Subject: {subject}
Date: {date}
[END METADATA]

{body}
"""

                                # No separate metadata - everything is in the text
                                doc = Document(text=doc_content, metadata={})
                                docs.append(doc)
                                count += 1

                            except Exception as e:
                                print(f"Error parsing email from {filepath}: {e}")
                                continue

                    except Exception as e:
                        print(f"Error reading file {filepath}: {e}")
                        continue

        print(f"Loaded {len(docs)} email documents")
        return docs
|
||||||
192
examples/email_data/email.py
Normal file
192
examples/email_data/email.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
"""
|
||||||
|
Mbox parser.
|
||||||
|
|
||||||
|
Contains simple parser for mbox files.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from fsspec import AbstractFileSystem
|
||||||
|
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
from llama_index.core.schema import Document
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MboxReader(BaseReader):
    """
    Mbox parser.

    Extract messages from mailbox files.
    Returns string including date, subject, sender, receiver and
    content for each message.

    """

    DEFAULT_MESSAGE_FORMAT: str = (
        "Date: {_date}\n"
        "From: {_from}\n"
        "To: {_to}\n"
        "Subject: {_subject}\n"
        "Content: {_content}"
    )

    def __init__(
        self,
        *args: Any,
        max_count: int = 0,
        message_format: str = DEFAULT_MESSAGE_FORMAT,
        **kwargs: Any,
    ) -> None:
        """Init params.

        Args:
            max_count: Maximum number of messages to process (0 = unlimited).
            message_format: Template rendered per message with _date, _from,
                _to, _subject and _content fields.

        Raises:
            ImportError: If beautifulsoup4 is not installed.
        """
        try:
            from bs4 import BeautifulSoup  # noqa
        except ImportError:
            raise ImportError(
                "`beautifulsoup4` package not found: `pip install beautifulsoup4`"
            )

        super().__init__(*args, **kwargs)
        self.max_count = max_count
        self.message_format = message_format

    def load_data(
        self,
        file: Path,
        extra_info: Optional[Dict] = None,
        fs: Optional[AbstractFileSystem] = None,
    ) -> List[Document]:
        """Parse file into string."""
        # Import required libraries
        import mailbox
        from email.parser import BytesParser
        from email.policy import default

        from bs4 import BeautifulSoup

        if fs:
            logger.warning(
                "fs was specified but MboxReader doesn't support loading "
                "from fsspec filesystems. Will load from local filesystem instead."
            )

        i = 0
        results: List[str] = []
        # Load file using mailbox
        bytes_parser = BytesParser(policy=default).parse
        mbox = mailbox.mbox(file, factory=bytes_parser)  # type: ignore

        # Iterate through all messages
        for _, _msg in enumerate(mbox):
            try:
                msg: mailbox.mboxMessage = _msg
                # FIX: reset per message. Previously `content` leaked from the
                # prior iteration, so a message with no text/plain part
                # silently reused the previous message's body.
                content = None
                # Parse multipart messages
                if msg.is_multipart():
                    for part in msg.walk():
                        ctype = part.get_content_type()
                        cdispo = str(part.get("Content-Disposition"))
                        if "attachment" in cdispo:
                            print(f"Attachment found: {part.get_filename()}")
                        if ctype == "text/plain" and "attachment" not in cdispo:
                            content = part.get_payload(decode=True)  # decode
                            break
                # Get plain message payload for non-multipart messages
                else:
                    content = msg.get_payload(decode=True)

                if content is None:
                    # Routed through the except-handler below so the message is
                    # logged and skipped, matching the failure path.
                    raise ValueError("no text/plain payload found")

                # Parse message HTML content and remove unneeded whitespace
                soup = BeautifulSoup(content)
                stripped_content = " ".join(soup.get_text().split())
                # Format message to include date, sender, receiver and subject
                msg_string = self.message_format.format(
                    _date=msg["date"],
                    _from=msg["from"],
                    _to=msg["to"],
                    _subject=msg["subject"],
                    _content=stripped_content,
                )
                # Add message string to results
                results.append(msg_string)
            except Exception as e:
                logger.warning(f"Failed to parse message:\n{_msg}\n with exception {e}")

            # Increment counter and return if max count is met
            i += 1
            if self.max_count > 0 and i >= self.max_count:
                break

        return [Document(text=result, metadata=extra_info or {}) for result in results]
|
||||||
|
|
||||||
|
|
||||||
|
class EmlxMboxReader(MboxReader):
    """
    EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.

    Extends MboxReader to work with Apple Mail's .emlx format by:
    1. Reading .emlx files from a directory
    2. Converting them to mbox format in memory
    3. Using the parent MboxReader's parsing logic
    """

    def load_data(
        self,
        directory: Path,
        extra_info: Optional[Dict] = None,
        fs: Optional[AbstractFileSystem] = None,
    ) -> List[Document]:
        """Parse .emlx files from directory into strings using MboxReader logic.

        Args:
            directory: Directory whose top level contains .emlx files.
            extra_info: Optional metadata attached to each resulting Document.
            fs: Unsupported; logged and ignored (local filesystem is used).

        Returns:
            Documents produced by the parent MboxReader over the converted mbox.
        """
        import os
        import tempfile

        if fs:
            logger.warning(
                "fs was specified but EmlxMboxReader doesn't support loading "
                "from fsspec filesystems. Will load from local filesystem instead."
            )

        # Find all .emlx files in the directory (non-recursive)
        emlx_files = list(directory.glob("*.emlx"))
        logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")

        if not emlx_files:
            logger.warning(f"No .emlx files found in {directory}")
            return []

        # Create a temporary mbox file and convert .emlx files into it
        with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox:
            temp_mbox_path = temp_mbox.name

            for emlx_file in emlx_files:
                try:
                    # Read the .emlx file
                    with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f:
                        content = f.read()

                    # .emlx format: first line is length, rest is email content
                    lines = content.split('\n', 1)
                    if len(lines) >= 2:
                        email_content = lines[1]  # Skip the length line

                        # FIX: the mbox "From " separator must be a line of its
                        # own; previously the email content was appended onto
                        # the separator line, corrupting the first header of
                        # every converted message.
                        temp_mbox.write(f"From {emlx_file.name}\n{email_content}\n\n")

                except Exception as e:
                    logger.warning(f"Failed to process {emlx_file}: {e}")
                    continue
        # Exiting the `with` block closes the file so MboxReader can read it.

        try:
            # Use the parent MboxReader's logic to parse the mbox file
            return super().load_data(Path(temp_mbox_path), extra_info, fs)
        finally:
            # Clean up temporary file; only filesystem errors are expected here
            # (FIX: bare `except:` narrowed to OSError).
            try:
                os.unlink(temp_mbox_path)
            except OSError:
                pass
|
||||||
151
examples/faiss_only.py
Normal file
151
examples/faiss_only.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Test only Faiss HNSW"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import psutil
|
||||||
|
import gc
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def get_memory_usage():
    """Return the current process's resident set size (RSS) in megabytes."""
    return psutil.Process().memory_info().rss / (1024 * 1024)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryTracker:
    """Samples process memory at named stages and reports the peak reading."""

    def __init__(self, name: str):
        self.name = name
        self.start_mem = get_memory_usage()  # baseline RSS in MB
        self.stages = []  # recorded (stage, rss_mb) tuples

    def checkpoint(self, stage: str):
        """Sample memory now, print it relative to the baseline, record it."""
        rss_now = get_memory_usage()
        delta = rss_now - self.start_mem
        print(f"[{self.name} - {stage}] Memory: {rss_now:.1f} MB (+{delta:.1f} MB)")
        self.stages.append((stage, rss_now))
        return rss_now

    def summary(self):
        """Print and return the highest reading seen across all checkpoints."""
        highest = max(sample for _, sample in self.stages)
        print(f"Peak Memory: {highest:.1f} MB")
        return highest
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Benchmark Faiss HNSW: build or load an index over examples/data and
    report memory usage and per-query latency at every stage."""
    try:
        import faiss
    except ImportError:
        print("Faiss is not installed.")
        print("Please install it with `uv pip install faiss-cpu`")
        sys.exit(1)

    # NOTE: the original imported `node_parser` (immediately shadowed by a
    # local variable of the same name) and the unused `Document`; both are
    # dropped from this import list.
    from llama_index.core import (
        SimpleDirectoryReader,
        VectorStoreIndex,
        StorageContext,
        Settings,
    )
    from llama_index.core.node_parser import SentenceSplitter
    from llama_index.vector_stores.faiss import FaissVectorStore
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding

    tracker = MemoryTracker("Faiss HNSW")
    tracker.checkpoint("Initial")

    embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
    Settings.embed_model = embed_model
    tracker.checkpoint("After embedding model setup")

    # contriever embeddings are 768-dim; 32 is the HNSW connectivity (M).
    d = 768
    faiss_index = faiss.IndexHNSWFlat(d, 32)
    faiss_index.hnsw.efConstruction = 64
    tracker.checkpoint("After Faiss index creation")

    documents = SimpleDirectoryReader(
        "examples/data",
        recursive=True,
        encoding="utf-8",
        required_exts=[".pdf", ".txt", ".md"],
    ).load_data()
    tracker.checkpoint("After document loading")

    # Parse into chunks using the same splitter as LEANN.
    # Renamed from `node_parser` so it no longer shadows the
    # `llama_index.core.node_parser` module imported above in the original.
    splitter = SentenceSplitter(
        chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
    )

    tracker.checkpoint("After text splitter setup")

    # Check if index already exists and try to load it
    index_loaded = False
    if os.path.exists("./storage_faiss"):
        print("Loading existing Faiss HNSW index...")
        try:
            # Use the correct Faiss loading pattern from the example
            vector_store = FaissVectorStore.from_persist_dir("./storage_faiss")
            storage_context = StorageContext.from_defaults(
                vector_store=vector_store, persist_dir="./storage_faiss"
            )
            from llama_index.core import load_index_from_storage
            index = load_index_from_storage(storage_context=storage_context)
            print("Index loaded from ./storage_faiss")
            tracker.checkpoint("After loading existing index")
            index_loaded = True
        except Exception as e:
            print(f"Failed to load existing index: {e}")
            print("Cleaning up corrupted index and building new one...")
            # Clean up corrupted index
            import shutil
            if os.path.exists("./storage_faiss"):
                shutil.rmtree("./storage_faiss")

    if not index_loaded:
        print("Building new Faiss HNSW index...")

        # Use the correct Faiss building pattern from the example
        vector_store = FaissVectorStore(faiss_index=faiss_index)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context,
            transformations=[splitter]
        )
        tracker.checkpoint("After index building")

        # Save index to disk using the correct pattern
        index.storage_context.persist(persist_dir="./storage_faiss")
        tracker.checkpoint("After index saving")

    # Measure runtime memory overhead
    print("\nMeasuring runtime memory overhead...")
    runtime_start_mem = get_memory_usage()
    print(f"Before load memory: {runtime_start_mem:.1f} MB")
    tracker.checkpoint("Before load memory")

    query_engine = index.as_query_engine(similarity_top_k=20)
    queries = [
        "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
        "What is LEANN and how does it work?",
        "华为诺亚方舟实验室的主要研究内容",
    ]

    for i, query in enumerate(queries):
        start_time = time.time()
        _ = query_engine.query(query)
        query_time = time.time() - start_time
        print(f"Query {i + 1} time: {query_time:.3f}s")
        tracker.checkpoint(f"After query {i + 1}")

    runtime_end_mem = get_memory_usage()
    runtime_overhead = runtime_end_mem - runtime_start_mem

    peak_memory = tracker.summary()
    print(f"Peak Memory: {peak_memory:.1f} MB")
    print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
|
# Script entry point: run the Faiss HNSW memory/latency benchmark.
if __name__ == "__main__":
    main()
|
||||||
281
examples/google_history_reader_leann.py
Normal file
281
examples/google_history_reader_leann.py
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import argparse
|
||||||
|
try:
|
||||||
|
import dotenv
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
# python-dotenv is not installed; skip loading environment variables
|
||||||
|
dotenv = None
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
# dotenv.load_dotenv() # handled above if python-dotenv is available
|
||||||
|
|
||||||
|
# Default Chrome profile path
|
||||||
|
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||||
|
|
||||||
|
def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1):
    """
    Create LEANN index from multiple Chrome profile data sources.

    Args:
        profile_dirs: List of Path objects pointing to Chrome profile directories
        index_path: Path to save the LEANN index
        max_count: Maximum total number of history entries to process across
            profiles; -1 (or any value <= 0) means no limit. The cap is only
            checked after each whole profile is loaded, so the total can
            overshoot by up to one profile's worth of documents.

    Returns:
        The index path, or None when no documents could be loaded.
    """
    print("Creating LEANN index from multiple Chrome profile data sources...")

    # Load documents using ChromeHistoryReader from history_data
    from history_data.history import ChromeHistoryReader
    reader = ChromeHistoryReader()

    INDEX_DIR = Path(index_path).parent

    if not INDEX_DIR.exists():
        print(f"--- Index directory not found, building new index ---")
        all_documents = []
        total_processed = 0

        # Process each Chrome profile directory
        for i, profile_dir in enumerate(profile_dirs):
            print(f"\nProcessing Chrome profile {i+1}/{len(profile_dirs)}: {profile_dir}")

            try:
                documents = reader.load_data(
                    chrome_profile_path=str(profile_dir),
                    max_count=max_count
                )
                if documents:
                    print(f"Loaded {len(documents)} history documents from {profile_dir}")
                    all_documents.extend(documents)
                    total_processed += len(documents)

                    # Stop once the global cap is reached (checked per profile).
                    if max_count > 0 and total_processed >= max_count:
                        print(f"Reached max count of {max_count} documents")
                        break
                else:
                    print(f"No documents loaded from {profile_dir}")
            except Exception as e:
                # A single unreadable profile must not abort the whole run.
                print(f"Error processing {profile_dir}: {e}")
                continue

        if not all_documents:
            print("No documents loaded from any source. Exiting.")
            return None

        print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")

        # Create text splitter with 256 chunk size
        text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)

        # Convert Documents to text strings and chunk them
        all_texts = []
        for doc in all_documents:
            # Split the document into chunks
            nodes = text_splitter.get_nodes_from_documents([doc])
            for node in nodes:
                all_texts.append(node.get_content())

        print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")

        # Create LEANN index directory. parents=True so nested index paths
        # work; the original also re-printed the "directory not found" banner
        # here, which was misleading mid-build and has been removed.
        INDEX_DIR.mkdir(parents=True, exist_ok=True)

        print(f"--- Building new LEANN index ---")

        print(f"\n[PHASE 1] Building Leann index...")

        # Use HNSW backend for better macOS compatibility
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="facebook/contriever",
            graph_degree=32,
            complexity=64,
            is_compact=True,
            is_recompute=True,
            num_threads=1  # Force single-threaded mode
        )

        print(f"Adding {len(all_texts)} history chunks to index...")
        for chunk_text in all_texts:
            builder.add_text(chunk_text)

        builder.build_index(index_path)
        print(f"\nLEANN index built at {index_path}!")
    else:
        print(f"--- Using existing index at {INDEX_DIR} ---")

    return index_path
|
def create_leann_index(profile_path: str = None, index_path: str = "chrome_history_index.leann", max_count: int = 1000):
    """
    Create LEANN index from Chrome history data (single profile).

    Args:
        profile_path: Path to the Chrome profile directory (optional, uses default if None)
        index_path: Path to save the LEANN index
        max_count: Maximum number of history entries to process

    Returns:
        The index path, or None when no documents could be loaded.
    """
    print("Creating LEANN index from Chrome history data...")
    INDEX_DIR = Path(index_path).parent

    if not INDEX_DIR.exists():
        # The original repeated this banner plus mkdir a second time midway
        # through the build; the duplicate has been removed, and parents=True
        # makes nested index paths work.
        print(f"--- Index directory not found, building new index ---")
        INDEX_DIR.mkdir(parents=True, exist_ok=True)

        # Load documents using ChromeHistoryReader from history_data
        from history_data.history import ChromeHistoryReader
        reader = ChromeHistoryReader()

        documents = reader.load_data(
            chrome_profile_path=profile_path,
            max_count=max_count
        )

        if not documents:
            print("No documents loaded. Exiting.")
            return None

        print(f"Loaded {len(documents)} history documents")

        # Create text splitter with 256 chunk size
        text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)

        # Convert Documents to text strings and chunk them
        all_texts = []
        for doc in documents:
            # Split the document into chunks
            nodes = text_splitter.get_nodes_from_documents([doc])
            for node in nodes:
                all_texts.append(node.get_content())

        print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")

        print(f"--- Building new LEANN index ---")

        print(f"\n[PHASE 1] Building Leann index...")

        # Use HNSW backend for better macOS compatibility
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="facebook/contriever",
            graph_degree=32,
            complexity=64,
            is_compact=True,
            is_recompute=True,
            num_threads=1  # Force single-threaded mode
        )

        print(f"Adding {len(all_texts)} history chunks to index...")
        for chunk_text in all_texts:
            builder.add_text(chunk_text)

        builder.build_index(index_path)
        print(f"\nLEANN index built at {index_path}!")
    else:
        print(f"--- Using existing index at {INDEX_DIR} ---")

    return index_path
|
async def query_leann_index(index_path: str, query: str):
    """
    Query the LEANN index via an interactive chat session backed by OpenAI.

    Args:
        index_path: Path to the LEANN index
        query: The query string
    """
    print(f"\n[PHASE 2] Starting Leann chat session...")
    chat = LeannChat(index_path=index_path)

    print(f"You: {query}")
    # NOTE(review): "recompute_beighbor_embeddings" looks like a typo of
    # "recompute_neighbor_embeddings" — confirm which spelling LeannChat.ask
    # actually accepts before renaming it here.
    chat_response = chat.ask(
        query,
        top_k=10,
        recompute_beighbor_embeddings=True,
        complexity=32,
        beam_width=1,
        # Requires OPENAI_API_KEY in the environment (loaded via dotenv at
        # module import when python-dotenv is installed).
        llm_config={
            "type": "openai",
            "model": "gpt-4o",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
        # temperature 0 keeps answers deterministic for the example runs.
        llm_kwargs={
            "temperature": 0.0,
            "max_tokens": 1000
        }
    )
    print(f"Leann: {chat_response}")
|
async def main():
    """CLI entry point: build/load the Chrome-history LEANN index and query it."""
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
    parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
                        help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
    parser.add_argument('--index-dir', type=str, default="./chrome_history_index_leann_test",
                        help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
    parser.add_argument('--max-entries', type=int, default=1000,
                        help='Maximum number of history entries to process (default: 1000)')
    parser.add_argument('--query', type=str, default=None,
                        help='Single query to run (default: runs example queries)')
    # BooleanOptionalAction also generates --no-auto-find-profiles. The
    # original used action='store_true' with default=True, which could never
    # be switched off, so the --chrome-profile branch was unreachable.
    parser.add_argument('--auto-find-profiles', action=argparse.BooleanOptionalAction, default=True,
                        help='Automatically find all Chrome profiles (default: True)')

    args = parser.parse_args()

    INDEX_DIR = Path(args.index_dir)
    INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")

    print(f"Using Chrome profile: {args.chrome_profile}")
    print(f"Index directory: {INDEX_DIR}")
    print(f"Max entries: {args.max_entries}")

    # Find Chrome profile directories
    from history_data.history import ChromeHistoryReader

    if args.auto_find_profiles:
        profile_dirs = ChromeHistoryReader.find_chrome_profiles()
        if not profile_dirs:
            print("No Chrome profiles found automatically. Exiting.")
            return
    else:
        # Use single specified profile
        profile_path = Path(args.chrome_profile)
        if not profile_path.exists():
            print(f"Chrome profile not found: {profile_path}")
            return
        profile_dirs = [profile_path]

    # Create or load the LEANN index from all sources
    index_path = create_leann_index_from_multiple_chrome_profiles(profile_dirs, INDEX_PATH, args.max_entries)

    if index_path:
        if args.query:
            # Run single query
            await query_leann_index(index_path, args.query)
        else:
            # Example queries
            queries = [
                "What websites did I visit about machine learning?",
                "Find my search history about programming"
            ]

            for query in queries:
                print("\n" + "="*60)
                await query_leann_index(index_path, query)
|
# Script entry point: drive the async index-build-and-query workflow.
if __name__ == "__main__":
    asyncio.run(main())
|
||||||
3
examples/history_data/__init__.py
Normal file
3
examples/history_data/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .history import ChromeHistoryReader
|
||||||
|
|
||||||
|
__all__ = ['ChromeHistoryReader']
|
||||||
176
examples/history_data/history.py
Normal file
176
examples/history_data/history.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
import sqlite3
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any
|
||||||
|
from llama_index.core import Document
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
class ChromeHistoryReader(BaseReader):
    """
    Chrome browser history reader that extracts browsing data from SQLite database.

    Reads Chrome history from the default Chrome profile location and creates documents
    with embedded metadata similar to the email reader structure.
    """

    def __init__(self) -> None:
        """Initialize."""
        pass

    def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
        """
        Load Chrome history data from the default Chrome profile location.

        Args:
            input_dir: Not used for Chrome history (kept for compatibility)
            **load_kwargs:
                max_count (int): Maximum amount of history entries to read.
                    Values <= 0 mean "no limit".
                chrome_profile_path (str): Custom path to Chrome profile directory.

        Returns:
            List of Documents, one per history row; empty when the database
            is missing or unreadable.
        """
        docs: List[Document] = []
        max_count = load_kwargs.get('max_count', 1000)
        chrome_profile_path = load_kwargs.get('chrome_profile_path', None)

        # Default Chrome profile path on macOS
        if chrome_profile_path is None:
            chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")

        history_db_path = os.path.join(chrome_profile_path, "History")

        if not os.path.exists(history_db_path):
            print(f"Chrome history database not found at: {history_db_path}")
            return docs

        conn = None
        try:
            # Connect to the Chrome history database
            print(f"Connecting to database: {history_db_path}")
            conn = sqlite3.connect(history_db_path)
            cursor = conn.cursor()

            # Chrome stores timestamps as microseconds since 1601-01-01;
            # the 11644473600s offset converts them to the Unix epoch.
            query = """
            SELECT
                datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
                url,
                title,
                visit_count,
                typed_count,
                hidden
            FROM urls
            ORDER BY last_visit_time DESC
            """
            params = ()
            if max_count > 0:
                # Push the cap into SQL so we never fetch more rows than
                # needed (the original fetched everything, then truncated
                # the loop in Python).
                query += " LIMIT ?"
                params = (max_count,)

            print(f"Executing query on database: {history_db_path}")
            cursor.execute(query, params)
            rows = cursor.fetchall()
            print(f"Query returned {len(rows)} rows")

            for row in rows:
                last_visit, url, title, visit_count, typed_count, hidden = row

                # Create document content with metadata embedded in text
                doc_content = f"""
[BROWSING HISTORY METADATA]
URL: {url}
Title: {title}
Last Visit: {last_visit}
Visit Count: {visit_count}
Typed Count: {typed_count}
Hidden: {hidden}
[END METADATA]

Title: {title}
URL: {url}
Last visited: {last_visit}
"""

                # Create document with embedded metadata
                docs.append(Document(text=doc_content, metadata={}))

            print(f"Loaded {len(docs)} Chrome history documents")

        except Exception as e:
            print(f"Error reading Chrome history: {e}")
            return docs
        finally:
            # Always release the SQLite handle, even on error — the original
            # leaked the connection when an exception was raised mid-read.
            if conn is not None:
                conn.close()

        return docs

    @staticmethod
    def find_chrome_profiles() -> List[Path]:
        """
        Find all Chrome profile directories.

        Returns:
            List of Path objects pointing to Chrome profile directories
        """
        chrome_base_path = Path(os.path.expanduser("~/Library/Application Support/Google/Chrome"))
        profile_dirs = []

        if not chrome_base_path.exists():
            print(f"Chrome directory not found at: {chrome_base_path}")
            return profile_dirs

        # A directory counts as a profile only if it contains a History DB.
        for profile_dir in chrome_base_path.iterdir():
            if profile_dir.is_dir() and profile_dir.name != "System Profile":
                history_path = profile_dir / "History"
                if history_path.exists():
                    profile_dirs.append(profile_dir)
                    print(f"Found Chrome profile: {profile_dir}")

        print(f"Found {len(profile_dirs)} Chrome profiles")
        return profile_dirs

    @staticmethod
    def export_history_to_file(output_file: str = "chrome_history_export.txt", max_count: int = 1000):
        """
        Export Chrome history to a text file using the same SQL query format.

        Args:
            output_file: Path to the output file (tab-separated values)
            max_count: Maximum number of entries to export
        """
        chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
        history_db_path = os.path.join(chrome_profile_path, "History")

        if not os.path.exists(history_db_path):
            print(f"Chrome history database not found at: {history_db_path}")
            return

        conn = None
        try:
            conn = sqlite3.connect(history_db_path)
            cursor = conn.cursor()

            query = """
            SELECT
                datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
                url,
                title,
                visit_count,
                typed_count,
                hidden
            FROM urls
            ORDER BY last_visit_time DESC
            LIMIT ?
            """

            cursor.execute(query, (max_count,))
            rows = cursor.fetchall()

            with open(output_file, 'w', encoding='utf-8') as f:
                for row in rows:
                    last_visit, url, title, visit_count, typed_count, hidden = row
                    f.write(f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n")

            conn.close()
            print(f"Exported {len(rows)} history entries to {output_file}")

        except Exception as e:
            print(f"Error exporting Chrome history: {e}")
        finally:
            # Close the SQLite handle even when export fails partway; a
            # double close() is a harmless no-op on sqlite3 connections.
            if conn is not None:
                conn.close()
720
examples/history_data/wechat_history.py
Normal file
720
examples/history_data/wechat_history.py
Normal file
@@ -0,0 +1,720 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any, Dict, Optional
|
||||||
|
from llama_index.core import Document
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
class WeChatHistoryReader(BaseReader):
|
||||||
|
"""
|
||||||
|
WeChat chat history reader that extracts chat data from exported JSON files.
|
||||||
|
|
||||||
|
Reads WeChat chat history from exported JSON files (from wechat-exporter tool)
|
||||||
|
and creates documents with embedded metadata similar to the Chrome history reader structure.
|
||||||
|
|
||||||
|
Also includes utilities for automatic WeChat chat history export.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
    """Locate the bundled wechat-exporter / wechat-decipher helper packages."""
    repo_root = Path(__file__).parent.parent.parent
    self.packages_dir = repo_root / "packages"
    self.wechat_exporter_dir = self.packages_dir / "wechat-exporter"
    self.wechat_decipher_dir = self.packages_dir / "wechat-decipher-macos"
def check_wechat_running(self) -> bool:
    """Return True when a WeChat process is currently running (via pgrep)."""
    try:
        probe = subprocess.run(["pgrep", "-f", "WeChat"], capture_output=True, text=True)
    except Exception:
        # pgrep missing or subprocess failure: treat as "not running".
        return False
    return probe.returncode == 0
|
def install_wechattweak(self) -> bool:
    """Download (if needed) and install the WeChatTweak CLI tool.

    Returns True on success, False when any step fails. Installation runs
    via sudo and may prompt for a password.
    """
    cli_path = self.wechat_exporter_dir / "wechattweak-cli"
    try:
        # Create wechat-exporter directory if it doesn't exist
        self.wechat_exporter_dir.mkdir(parents=True, exist_ok=True)

        if not cli_path.exists():
            print("Downloading WeChatTweak CLI...")
            download_cmd = [
                "curl", "-L", "-o", str(cli_path),
                "https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli"
            ]
            subprocess.run(download_cmd, check=True)

            # Make executable
            cli_path.chmod(0o755)

        # Install WeChatTweak
        print("Installing WeChatTweak...")
        subprocess.run(["sudo", str(cli_path), "install"], check=True)
        return True
    except Exception as e:
        print(f"Error installing WeChatTweak: {e}")
        return False
|
def restart_wechat(self):
    """Kill any running WeChat instance and relaunch it (to apply WeChatTweak)."""
    try:
        print("Restarting WeChat...")
        # Best-effort kill: check=False because WeChat may not be running.
        subprocess.run(["pkill", "-f", "WeChat"], check=False)
        time.sleep(2)
        subprocess.run(["open", "-a", "WeChat"], check=True)
        time.sleep(5)  # Wait for WeChat to start
    except Exception as e:
        print(f"Error restarting WeChat: {e}")
|
def check_api_available(self) -> bool:
    """Check if the local WeChatTweak HTTP API (port 48065) is reachable.

    Returns:
        True when the curl probe succeeds and produces non-empty output.
        The original returned ``result.stdout.strip()`` (a str) on the
        success path despite the ``-> bool`` annotation; the value is now
        coerced to a real bool.
    """
    try:
        result = subprocess.run([
            "curl", "-s", "http://localhost:48065/wechat/allcontacts"
        ], capture_output=True, text=True, timeout=5)
        return result.returncode == 0 and bool(result.stdout.strip())
    except Exception:
        # curl missing, timeout, or connection failure: API unavailable.
        return False
|
def _extract_readable_text(self, content: str) -> str:
    """
    Extract readable text from message content, removing XML and system messages.

    Args:
        content: The raw message content (can be string or dict)

    Returns:
        Cleaned, readable text; "" when nothing readable remains.
    """
    if not content:
        return ""

    # Dict payloads (e.g. quoted messages): join the known text fields
    # that are present, in a fixed order.
    if isinstance(content, dict):
        pieces = [
            str(content[key])
            for key in ('title', 'quoted', 'content', 'text')
            if key in content
        ]
        return " | ".join(pieces) if pieces else ""

    # Anything that is neither str nor dict carries no readable text.
    if not isinstance(content, str):
        return ""

    # Strip "wxid_xxx:" / "sender:" prefixes that precede the actual message.
    stripped = re.sub(r'^wxid_[^:]+:\s*', '', content)
    stripped = re.sub(r'^[^:]+:\s*', '', stripped)

    # Pure XML payloads and recall notices are system noise, not text.
    if stripped.strip().startswith('<') or 'recalled a message' in stripped:
        return ""

    return stripped.strip()
|
def _is_text_message(self, content: str) -> bool:
    """
    Check if a message contains readable text content.

    Args:
        content: The message content (can be string or dict)

    Returns:
        True if the message contains readable text, False otherwise
    """
    if not content:
        return False

    # Dict payloads: readable iff any known text field holds a truthy value.
    if isinstance(content, dict):
        return any(content.get(field) for field in ('title', 'quoted', 'content', 'text'))

    # Anything that is neither str nor dict is not text.
    if not isinstance(content, str):
        return False

    # Media and system payloads arrive as embedded XML — none are text.
    if '<img' in content and 'cdnurl' in content:
        return False  # image
    if '<emoji' in content and 'productid' in content:
        return False  # sticker/emoji
    if '<voice' in content:
        return False  # voice note
    if '<video' in content:
        return False  # video clip
    if '<appmsg' in content and 'appid' in content:
        return False  # file / app message
    if 'recalled a message' in content:
        return False  # system recall notice

    # Strip "wxid_xxx:" / "sender:" prefixes, then see whether real text remains.
    remainder = re.sub(r'^wxid_[^:]+:\s*', '', content)
    remainder = re.sub(r'^[^:]+:\s*', '', remainder)
    remainder = remainder.strip()

    return len(remainder) > 0 and not remainder.startswith('<')
|
def _concatenate_messages(self, messages: List[Dict], max_length: int = 128,
|
||||||
|
time_window_minutes: int = 30, overlap_messages: int = 0) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Concatenate messages based on length and time rules.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of message dictionaries
|
||||||
|
max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
|
||||||
|
time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
|
||||||
|
overlap_messages: Number of messages to overlap between consecutive groups
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of concatenated message groups
|
||||||
|
"""
|
||||||
|
if not messages:
|
||||||
|
return []
|
||||||
|
|
||||||
|
concatenated_groups = []
|
||||||
|
current_group = []
|
||||||
|
current_length = 0
|
||||||
|
last_timestamp = None
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
# Extract message info
|
||||||
|
content = message.get('content', '')
|
||||||
|
message_text = message.get('message', '')
|
||||||
|
create_time = message.get('createTime', 0)
|
||||||
|
from_user = message.get('fromUser', '')
|
||||||
|
to_user = message.get('toUser', '')
|
||||||
|
is_sent_from_self = message.get('isSentFromSelf', False)
|
||||||
|
|
||||||
|
# Extract readable text
|
||||||
|
readable_text = self._extract_readable_text(content)
|
||||||
|
if not readable_text:
|
||||||
|
readable_text = message_text
|
||||||
|
|
||||||
|
# Skip empty messages
|
||||||
|
if not readable_text.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check time window constraint (only if time_window_minutes != -1)
|
||||||
|
if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
|
||||||
|
time_diff_minutes = (create_time - last_timestamp) / 60
|
||||||
|
if time_diff_minutes > time_window_minutes:
|
||||||
|
# Time gap too large, start new group
|
||||||
|
if current_group:
|
||||||
|
concatenated_groups.append({
|
||||||
|
'messages': current_group,
|
||||||
|
'total_length': current_length,
|
||||||
|
'start_time': current_group[0].get('createTime', 0),
|
||||||
|
'end_time': current_group[-1].get('createTime', 0)
|
||||||
|
})
|
||||||
|
# Keep last few messages for overlap
|
||||||
|
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||||
|
current_group = current_group[-overlap_messages:]
|
||||||
|
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
||||||
|
else:
|
||||||
|
current_group = []
|
||||||
|
current_length = 0
|
||||||
|
|
||||||
|
# Check length constraint (only if max_length != -1)
|
||||||
|
message_length = len(readable_text)
|
||||||
|
if max_length != -1 and current_length + message_length > max_length and current_group:
|
||||||
|
# Current group would exceed max length, save it and start new
|
||||||
|
concatenated_groups.append({
|
||||||
|
'messages': current_group,
|
||||||
|
'total_length': current_length,
|
||||||
|
'start_time': current_group[0].get('createTime', 0),
|
||||||
|
'end_time': current_group[-1].get('createTime', 0)
|
||||||
|
})
|
||||||
|
# Keep last few messages for overlap
|
||||||
|
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||||
|
current_group = current_group[-overlap_messages:]
|
||||||
|
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
||||||
|
else:
|
||||||
|
current_group = []
|
||||||
|
current_length = 0
|
||||||
|
|
||||||
|
# Add message to current group
|
||||||
|
current_group.append(message)
|
||||||
|
current_length += message_length
|
||||||
|
last_timestamp = create_time
|
||||||
|
|
||||||
|
# Add the last group if it exists
|
||||||
|
if current_group:
|
||||||
|
concatenated_groups.append({
|
||||||
|
'messages': current_group,
|
||||||
|
'total_length': current_length,
|
||||||
|
'start_time': current_group[0].get('createTime', 0),
|
||||||
|
'end_time': current_group[-1].get('createTime', 0)
|
||||||
|
})
|
||||||
|
|
||||||
|
return concatenated_groups
|
||||||
|
|
||||||
|
def _create_concatenated_content(self, message_group: Dict, contact_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Create concatenated content from a group of messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message_group: Dictionary containing messages and metadata
|
||||||
|
contact_name: Name of the contact
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted concatenated content
|
||||||
|
"""
|
||||||
|
messages = message_group['messages']
|
||||||
|
start_time = message_group['start_time']
|
||||||
|
end_time = message_group['end_time']
|
||||||
|
|
||||||
|
# Format timestamps
|
||||||
|
if start_time:
|
||||||
|
try:
|
||||||
|
start_timestamp = datetime.fromtimestamp(start_time)
|
||||||
|
start_time_str = start_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
except:
|
||||||
|
start_time_str = str(start_time)
|
||||||
|
else:
|
||||||
|
start_time_str = "Unknown"
|
||||||
|
|
||||||
|
if end_time:
|
||||||
|
try:
|
||||||
|
end_timestamp = datetime.fromtimestamp(end_time)
|
||||||
|
end_time_str = end_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
except:
|
||||||
|
end_time_str = str(end_time)
|
||||||
|
else:
|
||||||
|
end_time_str = "Unknown"
|
||||||
|
|
||||||
|
# Build concatenated message content
|
||||||
|
message_parts = []
|
||||||
|
for message in messages:
|
||||||
|
content = message.get('content', '')
|
||||||
|
message_text = message.get('message', '')
|
||||||
|
create_time = message.get('createTime', 0)
|
||||||
|
is_sent_from_self = message.get('isSentFromSelf', False)
|
||||||
|
|
||||||
|
# Extract readable text
|
||||||
|
readable_text = self._extract_readable_text(content)
|
||||||
|
if not readable_text:
|
||||||
|
readable_text = message_text
|
||||||
|
|
||||||
|
# Format individual message
|
||||||
|
if create_time:
|
||||||
|
try:
|
||||||
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
|
time_str = timestamp.strftime('%H:%M:%S')
|
||||||
|
except:
|
||||||
|
time_str = str(create_time)
|
||||||
|
else:
|
||||||
|
time_str = "Unknown"
|
||||||
|
|
||||||
|
sender = "Me" if is_sent_from_self else "Contact"
|
||||||
|
message_parts.append(f"[{time_str}] {sender}: {readable_text}")
|
||||||
|
|
||||||
|
concatenated_text = "\n".join(message_parts)
|
||||||
|
|
||||||
|
# Create final document content
|
||||||
|
doc_content = f"""
|
||||||
|
Contact: {contact_name}
|
||||||
|
Time Range: {start_time_str} - {end_time_str}
|
||||||
|
Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
||||||
|
|
||||||
|
{concatenated_text}
|
||||||
|
"""
|
||||||
|
|
||||||
|
doc_content = f"""
|
||||||
|
Contact: {contact_name}
|
||||||
|
|
||||||
|
{concatenated_text}
|
||||||
|
"""
|
||||||
|
return doc_content
|
||||||
|
|
||||||
|
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
    """
    Load WeChat chat history data from exported JSON files.

    Args:
        input_dir: Directory containing exported WeChat JSON files
            (currently unused; pass wechat_export_dir instead).
        **load_kwargs:
            max_count (int): Maximum amount of chat entries to read (<=0 means no limit).
            wechat_export_dir (str): Custom path to WeChat export directory.
            include_non_text (bool): Whether to include non-text messages (images, emojis, etc.)
            concatenate_messages (bool): Whether to concatenate messages based on length rules.
            max_length (int): Maximum length for concatenated message groups (default: -1, no limit).
            time_window_minutes (int): Time window in minutes to group messages together (default: -1, no limit).
            overlap_messages (int): Number of messages to overlap between consecutive groups (default: 0).

    Returns:
        List of Document objects, one per message (or per concatenated group).
    """
    docs: List[Document] = []
    max_count = load_kwargs.get('max_count', 1000)
    wechat_export_dir = load_kwargs.get('wechat_export_dir', None)
    include_non_text = load_kwargs.get('include_non_text', False)
    concatenate_messages = load_kwargs.get('concatenate_messages', False)
    # These kwargs were previously read but then ignored (hard-coded -1/-1/0
    # in the _concatenate_messages call). They are now passed through; the
    # defaults match the previous hard-coded behavior, so callers that never
    # set them see no change.
    max_length = load_kwargs.get('max_length', -1)
    time_window_minutes = load_kwargs.get('time_window_minutes', -1)
    overlap_messages = load_kwargs.get('overlap_messages', 0)

    # Default WeChat export path
    if wechat_export_dir is None:
        wechat_export_dir = "./wechat_export_test"

    if not os.path.exists(wechat_export_dir):
        print(f"WeChat export directory not found at: {wechat_export_dir}")
        return docs

    try:
        # Find all JSON files in the export directory
        json_files = list(Path(wechat_export_dir).glob("*.json"))
        print(f"Found {len(json_files)} WeChat chat history files")

        count = 0
        for json_file in json_files:
            if count >= max_count and max_count > 0:
                break

            try:
                with open(json_file, 'r', encoding='utf-8') as f:
                    chat_data = json.load(f)

                # Extract contact name from filename
                contact_name = json_file.stem

                if concatenate_messages:
                    # Filter messages to only include readable text messages
                    readable_messages = []
                    for message in chat_data:
                        try:
                            content = message.get('content', '')
                            if not include_non_text and not self._is_text_message(content):
                                continue

                            readable_text = self._extract_readable_text(content)
                            if not readable_text and not include_non_text:
                                continue

                            readable_messages.append(message)
                        except Exception as e:
                            print(f"Error processing message in {json_file}: {e}")
                            continue

                    # Concatenate messages based on the configured rules.
                    message_groups = self._concatenate_messages(
                        readable_messages,
                        max_length=max_length,
                        time_window_minutes=time_window_minutes,
                        overlap_messages=overlap_messages
                    )

                    # Create one document per concatenated group.
                    for message_group in message_groups:
                        if count >= max_count and max_count > 0:
                            break

                        doc_content = self._create_concatenated_content(message_group, contact_name)
                        doc = Document(text=doc_content, metadata={})
                        docs.append(doc)
                        count += 1

                    print(f"Created {len(message_groups)} concatenated message groups for {contact_name}")

                else:
                    # Original single-message processing: one document per message.
                    for message in chat_data:
                        if count >= max_count and max_count > 0:
                            break

                        content = message.get('content', '')
                        message_text = message.get('message', '')
                        create_time = message.get('createTime', 0)
                        is_sent_from_self = message.get('isSentFromSelf', False)

                        # Handle content that might be dict or string
                        try:
                            # Check if this is a readable text message
                            if not include_non_text and not self._is_text_message(content):
                                continue

                            # Extract readable text
                            readable_text = self._extract_readable_text(content)
                            if not readable_text and not include_non_text:
                                continue
                        except Exception as e:
                            # Skip messages that cause processing errors
                            print(f"Error processing message in {json_file}: {e}")
                            continue

                        # Convert timestamp to readable format
                        if create_time:
                            try:
                                timestamp = datetime.fromtimestamp(create_time)
                                time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
                            except Exception:
                                time_str = str(create_time)
                        else:
                            time_str = "Unknown"

                        # Create document content with metadata header and contact info
                        doc_content = f"""
Contact: {contact_name}
Is sent from self: {is_sent_from_self}
Time: {time_str}
Message: {readable_text if readable_text else message_text}
"""

                        # Create document with embedded metadata
                        doc = Document(text=doc_content, metadata={})
                        docs.append(doc)
                        count += 1

            except Exception as e:
                print(f"Error reading {json_file}: {e}")
                continue

        print(f"Loaded {len(docs)} WeChat chat documents")

    except Exception as e:
        print(f"Error reading WeChat history: {e}")
        return docs

    return docs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def find_wechat_export_dirs() -> List[Path]:
|
||||||
|
"""
|
||||||
|
Find all WeChat export directories.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Path objects pointing to WeChat export directories
|
||||||
|
"""
|
||||||
|
export_dirs = []
|
||||||
|
|
||||||
|
# Look for common export directory names
|
||||||
|
possible_dirs = [
|
||||||
|
Path("./wechat_export_test"),
|
||||||
|
Path("./wechat_export"),
|
||||||
|
Path("./wechat_chat_history"),
|
||||||
|
Path("./chat_export")
|
||||||
|
]
|
||||||
|
|
||||||
|
for export_dir in possible_dirs:
|
||||||
|
if export_dir.exists() and export_dir.is_dir():
|
||||||
|
json_files = list(export_dir.glob("*.json"))
|
||||||
|
if json_files:
|
||||||
|
export_dirs.append(export_dir)
|
||||||
|
print(f"Found WeChat export directory: {export_dir} with {len(json_files)} files")
|
||||||
|
|
||||||
|
print(f"Found {len(export_dirs)} WeChat export directories")
|
||||||
|
return export_dirs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def export_chat_to_file(output_file: str = "wechat_chat_export.txt", max_count: int = 1000, export_dir: str = None, include_non_text: bool = False):
|
||||||
|
"""
|
||||||
|
Export WeChat chat history to a text file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_file: Path to the output file
|
||||||
|
max_count: Maximum number of entries to export
|
||||||
|
export_dir: Directory containing WeChat JSON files
|
||||||
|
include_non_text: Whether to include non-text messages
|
||||||
|
"""
|
||||||
|
if export_dir is None:
|
||||||
|
export_dir = "./wechat_export_test"
|
||||||
|
|
||||||
|
if not os.path.exists(export_dir):
|
||||||
|
print(f"WeChat export directory not found at: {export_dir}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
json_files = list(Path(export_dir).glob("*.json"))
|
||||||
|
|
||||||
|
with open(output_file, 'w', encoding='utf-8') as f:
|
||||||
|
count = 0
|
||||||
|
for json_file in json_files:
|
||||||
|
if count >= max_count and max_count > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(json_file, 'r', encoding='utf-8') as json_f:
|
||||||
|
chat_data = json.load(json_f)
|
||||||
|
|
||||||
|
contact_name = json_file.stem
|
||||||
|
f.write(f"\n=== Chat with {contact_name} ===\n")
|
||||||
|
|
||||||
|
for message in chat_data:
|
||||||
|
if count >= max_count and max_count > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
from_user = message.get('fromUser', '')
|
||||||
|
content = message.get('content', '')
|
||||||
|
message_text = message.get('message', '')
|
||||||
|
create_time = message.get('createTime', 0)
|
||||||
|
|
||||||
|
# Skip non-text messages unless requested
|
||||||
|
if not include_non_text:
|
||||||
|
reader = WeChatHistoryReader()
|
||||||
|
if not reader._is_text_message(content):
|
||||||
|
continue
|
||||||
|
readable_text = reader._extract_readable_text(content)
|
||||||
|
if not readable_text:
|
||||||
|
continue
|
||||||
|
message_text = readable_text
|
||||||
|
|
||||||
|
if create_time:
|
||||||
|
try:
|
||||||
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
|
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
except:
|
||||||
|
time_str = str(create_time)
|
||||||
|
else:
|
||||||
|
time_str = "Unknown"
|
||||||
|
|
||||||
|
f.write(f"[{time_str}] {from_user}: {message_text}\n")
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {json_file}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Exported {count} chat entries to {output_file}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error exporting WeChat chat history: {e}")
|
||||||
|
|
||||||
|
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Optional[Path]:
    """
    Export WeChat chat history using wechat-exporter tool.

    Installs the tool's requirements (via `uv pip`) when a requirements.txt
    is present, runs `main.py export-all <export_dir>` with the current
    Python interpreter, and verifies that JSON files were produced.

    Args:
        export_dir: Directory to save exported chat history

    Returns:
        Path to export directory if successful, None otherwise
    """
    try:
        import subprocess
        import sys

        # Create export directory (its parent must already exist).
        export_path = Path(export_dir)
        export_path.mkdir(exist_ok=True)

        print(f"Exporting WeChat chat history to {export_path}...")

        # Check if wechat-exporter directory exists.
        # NOTE(review): self.wechat_exporter_dir is set elsewhere in the
        # class (presumably __init__) — verify it is always defined before
        # this method is called.
        if not self.wechat_exporter_dir.exists():
            print(f"wechat-exporter directory not found at: {self.wechat_exporter_dir}")
            return None

        # Install requirements if needed; assumes `uv` is available on PATH.
        requirements_file = self.wechat_exporter_dir / "requirements.txt"
        if requirements_file.exists():
            print("Installing wechat-exporter requirements...")
            subprocess.run([
                "uv", "pip", "install", "-r", str(requirements_file)
            ], check=True)

        # Run the export command; check=True raises CalledProcessError on a
        # non-zero exit code (handled below), and capture_output collects
        # the tool's stdout/stderr for logging.
        print("Running wechat-exporter...")
        result = subprocess.run([
            sys.executable, str(self.wechat_exporter_dir / "main.py"),
            "export-all", str(export_path)
        ], capture_output=True, text=True, check=True)

        print("Export command output:")
        print(result.stdout)
        if result.stderr:
            print("Export errors:")
            print(result.stderr)

        # Treat the export as successful only if at least one JSON file
        # actually landed in the target directory.
        if export_path.exists() and any(export_path.glob("*.json")):
            json_files = list(export_path.glob("*.json"))
            print(f"Successfully exported {len(json_files)} chat history files to {export_path}")
            return export_path
        else:
            print("Export completed but no JSON files found")
            return None

    except subprocess.CalledProcessError as e:
        # Non-zero exit from the exporter; surface its captured output.
        print(f"Export command failed: {e}")
        print(f"Command output: {e.stdout}")
        print(f"Command errors: {e.stderr}")
        return None
    except Exception as e:
        print(f"Export failed: {e}")
        print("Please ensure WeChat is running and WeChatTweak is installed.")
        return None
|
||||||
|
|
||||||
|
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> List[Path]:
    """
    Find existing WeChat exports or create new ones.

    Args:
        export_dir: Directory to save exported chat history if needed

    Returns:
        List of Path objects pointing to WeChat export directories
    """
    # Conventional locations checked for previously exported data.
    candidate_dirs = [
        Path("./wechat_database_export"),
        Path("./wechat_export_test"),
        Path("./wechat_export"),
        Path("./wechat_export_direct"),
        Path("./wechat_chat_history"),
        Path("./chat_export")
    ]

    export_dirs: List[Path] = []
    for candidate in candidate_dirs:
        # Only keep directories that actually contain exported JSON files.
        if candidate.exists() and any(candidate.glob("*.json")):
            export_dirs.append(candidate)
            print(f"Found existing export: {candidate}")

    if export_dirs:
        return export_dirs

    # Nothing on disk — fall back to a fresh export via wechat-exporter.
    print("No existing WeChat exports found. Starting direct export...")
    exported_path = self.export_wechat_chat_history(export_dir)
    if exported_path:
        export_dirs = [exported_path]
    else:
        print("Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.")

    return export_dirs
|
||||||
286
examples/mail_reader_leann.py
Normal file
286
examples/mail_reader_leann.py
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
import dotenv
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any
|
||||||
|
|
||||||
|
# Add the project root to Python path so we can import from examples
|
||||||
|
project_root = Path(__file__).parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
# Auto-detect user's mail path
|
||||||
|
def get_mail_path():
    """Get the mail path for the current user"""
    # Apple Mail stores its data under ~/Library/Mail on macOS.
    return os.path.join(os.path.expanduser("~"), "Library", "Mail")
|
||||||
|
|
||||||
|
# Default mail path for macOS
|
||||||
|
# DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
||||||
|
|
||||||
|
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
    """
    Create LEANN index from multiple mail data sources.

    If the index directory already exists, the existing index is reused and
    nothing is rebuilt.

    Args:
        messages_dirs: List of Path objects pointing to Messages directories
        index_path: Path to save the LEANN index
        max_count: Maximum total number of emails to load across all
            directories (-1 means no limit)
        include_html: Whether to include HTML content in email processing
        embedding_model: Embedding model name used to build the index

    Returns:
        The index path if an index exists or was built, None if no documents
        could be loaded.
    """
    print("Creating LEANN index from multiple mail data sources...")

    # Load documents using EmlxReader from LEANN_email_reader
    from examples.email_data.LEANN_email_reader import EmlxReader
    reader = EmlxReader(include_html=include_html)
    INDEX_DIR = Path(index_path).parent

    if not INDEX_DIR.exists():
        print("--- Index directory not found, building new index ---")
        all_documents = []
        total_processed = 0

        # Process each Messages directory until done or max_count is reached.
        for i, messages_dir in enumerate(messages_dirs):
            print(f"\nProcessing Messages directory {i+1}/{len(messages_dirs)}: {messages_dir}")

            try:
                documents = reader.load_data(messages_dir)
                if documents:
                    print(f"Loaded {len(documents)} email documents from {messages_dir}")
                    all_documents.extend(documents)
                    total_processed += len(documents)

                    if max_count > 0 and total_processed >= max_count:
                        print(f"Reached max count of {max_count} documents")
                        break
                else:
                    print(f"No documents loaded from {messages_dir}")
            except Exception as e:
                print(f"Error processing {messages_dir}: {e}")
                continue

        if not all_documents:
            print("No documents loaded from any source. Exiting.")
            return None

        print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories")

        # Chunk each document (~256 tokens per chunk) for indexing.
        text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
        all_texts = []
        for doc in all_documents:
            for node in text_splitter.get_nodes_from_documents([doc]):
                all_texts.append(node.get_content())

        print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")

        # Create the LEANN index directory.
        INDEX_DIR.mkdir(exist_ok=True)

        print("--- Building new LEANN index ---")
        print("\n[PHASE 1] Building Leann index...")

        # Use HNSW backend for better macOS compatibility
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model=embedding_model,
            graph_degree=32,
            complexity=64,
            is_compact=True,
            is_recompute=True,
            num_threads=1  # Force single-threaded mode
        )

        print(f"Adding {len(all_texts)} email chunks to index...")
        for chunk_text in all_texts:
            builder.add_text(chunk_text)

        builder.build_index(index_path)
        print(f"\nLEANN index built at {index_path}!")
    else:
        print(f"--- Using existing index at {INDEX_DIR} ---")

    return index_path
|
||||||
|
|
||||||
|
def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max_count: int = 1000, include_html: bool = False, embedding_model: str = "facebook/contriever"):
    """
    Create LEANN index from mail data.

    If the index directory already exists, the existing index is reused and
    nothing is rebuilt.

    Args:
        mail_path: Path to the mail directory
        index_path: Path to save the LEANN index
        max_count: Maximum number of emails to process
            (NOTE(review): currently unused — every document the reader
            returns is indexed; confirm whether a cap should be enforced)
        include_html: Whether to include HTML content in email processing
        embedding_model: Embedding model name used to build the index

    Returns:
        The index path if an index exists or was built, None if no documents
        could be loaded.
    """
    print("Creating LEANN index from mail data...")
    INDEX_DIR = Path(index_path).parent

    if not INDEX_DIR.exists():
        print("--- Index directory not found, building new index ---")
        INDEX_DIR.mkdir(exist_ok=True)

        print("\n[PHASE 1] Building Leann index...")

        # Load documents using EmlxReader from LEANN_email_reader
        from examples.email_data.LEANN_email_reader import EmlxReader
        reader = EmlxReader(include_html=include_html)
        documents = reader.load_data(Path(mail_path))

        if not documents:
            print("No documents loaded. Exiting.")
            return None

        print(f"Loaded {len(documents)} email documents")

        # Chunk each document (~256 tokens per chunk) for indexing.
        text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
        all_texts = []
        for doc in documents:
            for node in text_splitter.get_nodes_from_documents([doc]):
                all_texts.append(node.get_content())

        print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")

        print("--- Building new LEANN index ---")

        # Use HNSW backend for better macOS compatibility
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model=embedding_model,
            graph_degree=32,
            complexity=64,
            is_compact=True,
            is_recompute=True,
            num_threads=1  # Force single-threaded mode
        )

        print(f"Adding {len(all_texts)} email chunks to index...")
        for chunk_text in all_texts:
            builder.add_text(chunk_text)

        builder.build_index(index_path)
        print(f"\nLEANN index built at {index_path}!")
    else:
        print(f"--- Using existing index at {INDEX_DIR} ---")

    return index_path
|
||||||
|
|
||||||
|
async def query_leann_index(index_path: str, query: str):
    """
    Query the LEANN index.

    Opens a LeannChat session backed by an OpenAI gpt-4o LLM, asks a single
    question, and prints the timed response. (Declared async for symmetry
    with main(); the body itself performs no awaits.)

    Args:
        index_path: Path to the LEANN index
        query: The query string
    """
    print(f"\n[PHASE 2] Starting Leann chat session...")
    chat = LeannChat(index_path=index_path,
                     llm_config={"type": "openai", "model": "gpt-4o"})

    print(f"You: {query}")
    import time
    start_time = time.time()
    # NOTE(review): 'recompute_beighbor_embeddings' looks like a typo for
    # 'recompute_neighbor_embeddings' — if chat.ask() accepts **kwargs this
    # flag may be silently ignored; verify against the LeannChat API.
    chat_response = chat.ask(
        query,
        top_k=10,
        recompute_beighbor_embeddings=True,
        complexity=12,
        beam_width=1,

    )
    end_time = time.time()
    print(f"Time taken: {end_time - start_time} seconds")
    print(f"Leann: {chat_response}")
|
||||||
|
|
||||||
|
async def main():
    """Entry point: parse CLI options, build (or load) the mail index from
    every detected Messages directory, then answer one or more queries."""
    parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
    # Mail location is auto-detected; only index/query options are exposed.
    parser.add_argument('--index-dir', type=str, default="./mail_index_leann_raw_text_all_dicts",
                        help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
    parser.add_argument('--max-emails', type=int, default=1000,
                        help='Maximum number of emails to process (-1 means all)')
    parser.add_argument('--query', type=str, default="Give me some funny advertisement about apple or other companies",
                        help='Single query to run (default: runs example queries)')
    parser.add_argument('--include-html', action='store_true', default=False,
                        help='Include HTML content in email processing (default: False)')
    parser.add_argument('--embedding-model', type=str, default="facebook/contriever",
                        help='Embedding model to use (default: facebook/contriever)')
    args = parser.parse_args()

    print(f"args: {args}")

    # Automatically find all Messages directories under the current user's Mail directory.
    from examples.email_data.LEANN_email_reader import find_all_messages_directories
    mail_path = get_mail_path()
    print(f"Searching for email data in: {mail_path}")
    messages_dirs = find_all_messages_directories(mail_path)

    print('len(messages_dirs): ', len(messages_dirs))

    # Guard clause: nothing to index.
    if not messages_dirs:
        print("No Messages directories found. Exiting.")
        return

    INDEX_DIR = Path(args.index_dir)
    INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
    print(f"Index directory: {INDEX_DIR}")
    print(f"Found {len(messages_dirs)} Messages directories.")

    # Create or load the LEANN index from all sources.
    index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model)
    if not index_path:
        return

    if args.query:
        # Run the single user-supplied query.
        await query_leann_index(index_path, args.query)
        return

    # No query given: walk through the built-in example queries.
    example_queries = [
        "Hows Berkeley Graduate Student Instructor",
        "how's the icloud related advertisement saying",
        "Whats the number of class recommend to take per semester for incoming EECS students"
    ]
    for example_query in example_queries:
        print("\n" + "="*60)
        await query_leann_index(index_path, example_query)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
108
examples/mail_reader_llamaindex.py
Normal file
108
examples/mail_reader_llamaindex.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any
|
||||||
|
|
||||||
|
# Add the project root to Python path so we can import from examples
|
||||||
|
project_root = Path(__file__).parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
from llama_index.core import VectorStoreIndex, StorageContext
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
# --- EMBEDDING MODEL ---
|
||||||
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# --- END EMBEDDING MODEL ---
|
||||||
|
|
||||||
|
# Import EmlxReader from the new module
|
||||||
|
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||||
|
|
||||||
|
def create_and_save_index(mail_path: str, save_dir: str = "mail_index_embedded", max_count: int = 1000, include_html: bool = False):
    """Build a VectorStoreIndex from .emlx mail data and persist it to disk.

    Args:
        mail_path: Directory containing the mail messages to ingest.
        save_dir: Directory where the index storage context is persisted.
        max_count: Maximum number of emails to load.
        include_html: Whether the reader keeps HTML content.

    Returns:
        The built VectorStoreIndex, or None when no documents were loaded.
    """
    print("Creating index from mail data with embedded metadata...")
    documents = EmlxReader(include_html=include_html).load_data(mail_path, max_count=max_count)
    if not documents:
        print("No documents loaded. Exiting.")
        return None

    text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)

    # Use facebook/contriever as the embedder.
    embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")

    # Pick the best available device once (torch is already imported at module
    # level, so the previous function-local re-import was redundant).
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    # NOTE(review): _model is a private attribute of HuggingFaceEmbedding —
    # this relies on llama_index internals; confirm against the installed version.
    embed_model._model.to(device)

    index = VectorStoreIndex.from_documents(
        documents,
        transformations=[text_splitter],
        embed_model=embed_model
    )

    os.makedirs(save_dir, exist_ok=True)
    index.storage_context.persist(persist_dir=save_dir)
    print(f"Index saved to {save_dir}")
    return index
||||||
|
|
||||||
|
def load_index(save_dir: str = "mail_index_embedded"):
    """Load a previously persisted VectorStoreIndex from `save_dir`.

    Returns:
        The loaded index, or None if loading fails for any reason.
    """
    try:
        ctx = StorageContext.from_defaults(persist_dir=save_dir)
        loaded = VectorStoreIndex.from_vector_store(
            ctx.vector_store,
            storage_context=ctx
        )
        print(f"Index loaded from {save_dir}")
        return loaded
    except Exception as e:
        # Any failure (missing files, version mismatch, ...) is reported and
        # turned into a None result so the caller can rebuild.
        print(f"Error loading index: {e}")
        return None
|
||||||
|
|
||||||
|
def query_index(index, query: str):
    """Run a single query against `index` and print the query and its response.

    A None index is tolerated: a message is printed and nothing is queried.
    """
    if index is None:
        print("No index available for querying.")
        return
    answer = index.as_query_engine().query(query)
    print(f"Query: {query}")
    print(f"Response: {answer}")
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: load or build the mail index, then run sample queries."""
    cli = argparse.ArgumentParser(description='LlamaIndex Mail Reader - Create and query email index')
    cli.add_argument('--mail-path', type=str,
                     default="/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
                     help='Path to mail data directory')
    cli.add_argument('--save-dir', type=str, default="mail_index_embedded",
                     help='Directory to store the index (default: mail_index_embedded)')
    cli.add_argument('--max-emails', type=int, default=10000,
                     help='Maximum number of emails to process')
    cli.add_argument('--include-html', action='store_true', default=False,
                     help='Include HTML content in email processing (default: False)')
    args = cli.parse_args()

    # Reuse a persisted index when its vector store file is already on disk.
    persisted_store = os.path.join(args.save_dir, "vector_store.json")
    if os.path.exists(args.save_dir) and os.path.exists(persisted_store):
        print("Loading existing index...")
        index = load_index(args.save_dir)
    else:
        print("Creating new index...")
        index = create_and_save_index(args.mail_path, args.save_dir,
                                      max_count=args.max_emails,
                                      include_html=args.include_html)

    if not index:
        return

    sample_queries = (
        "Hows Berkeley Graduate Student Instructor",
        "how's the icloud related advertisement saying",
        "Whats the number of class recommend to take per semester for incoming EECS students",
    )
    for q in sample_queries:
        print("\n" + "=" * 50)
        query_index(index, q)


if __name__ == "__main__":
    main()
|
||||||
@@ -1,76 +1,110 @@
|
|||||||
|
import argparse
|
||||||
from llama_index.core import SimpleDirectoryReader, Settings
|
from llama_index.core import SimpleDirectoryReader, Settings
|
||||||
from llama_index.core.readers.base import BaseReader
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
from llama_index.node_parser.docling import DoclingNodeParser
|
|
||||||
from llama_index.readers.docling import DoclingReader
|
|
||||||
from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
|
||||||
import dotenv
|
import dotenv
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||||
import leann_backend_diskann # Import to ensure backend registration
|
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
reader = DoclingReader(export_type=DoclingReader.ExportType.JSON)
|
node_parser = SentenceSplitter(
|
||||||
file_extractor: dict[str, BaseReader] = {
|
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||||
".docx": reader,
|
|
||||||
".pptx": reader,
|
|
||||||
".pdf": reader,
|
|
||||||
".xlsx": reader,
|
|
||||||
}
|
|
||||||
node_parser = DoclingNodeParser(
|
|
||||||
chunker=HybridChunker(tokenizer="Qwen/Qwen3-Embedding-4B", max_tokens=10240)
|
|
||||||
)
|
)
|
||||||
|
print("Loading documents...")
|
||||||
documents = SimpleDirectoryReader(
|
documents = SimpleDirectoryReader(
|
||||||
"examples/data",
|
"examples/data",
|
||||||
recursive=True,
|
recursive=True,
|
||||||
file_extractor=file_extractor,
|
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
required_exts=[".pdf", ".docx", ".pptx", ".xlsx"]
|
required_exts=[".pdf", ".txt", ".md"],
|
||||||
).load_data(show_progress=True)
|
).load_data(show_progress=True)
|
||||||
|
print("Documents loaded.")
|
||||||
# Extract text from documents and prepare for Leann
|
|
||||||
all_texts = []
|
all_texts = []
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
# DoclingNodeParser returns Node objects, which have a text attribute
|
|
||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
all_texts.append(node.text)
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
INDEX_DIR = Path("./test_pdf_index")
|
|
||||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
|
||||||
|
|
||||||
if INDEX_DIR.exists():
|
async def main(args):
|
||||||
print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
|
INDEX_DIR = Path(args.index_dir)
|
||||||
shutil.rmtree(INDEX_DIR)
|
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
if not INDEX_DIR.exists():
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
|
||||||
builder = LeannBuilder(
|
print(f"\n[PHASE 1] Building Leann index...")
|
||||||
backend_name="diskann",
|
|
||||||
embedding_model="sentence-transformers/all-mpnet-base-v2", # Using a common sentence transformer model
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Loaded {len(all_texts)} text chunks from documents.")
|
# Use HNSW backend for better macOS compatibility
|
||||||
for chunk_text in all_texts:
|
builder = LeannBuilder(
|
||||||
builder.add_text(chunk_text)
|
backend_name="hnsw",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1, # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
builder.build_index(INDEX_PATH)
|
print(f"Loaded {len(all_texts)} text chunks from documents.")
|
||||||
print(f"\nLeann index built at {INDEX_PATH}!")
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(INDEX_PATH)
|
||||||
|
print(f"\nLeann index built at {INDEX_PATH}!")
|
||||||
|
else:
|
||||||
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
async def main():
|
|
||||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||||
chat = LeannChat(index_path=INDEX_PATH)
|
|
||||||
|
|
||||||
query = "Based on the paper, what are the two main techniques LEANN uses to achieve low storage overhead and high retrieval accuracy?"
|
# llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
|
||||||
|
llm_config = {"type": "ollama", "model": "qwen3:8b"}
|
||||||
|
|
||||||
|
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
||||||
|
|
||||||
|
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
||||||
|
|
||||||
|
# query = (
|
||||||
|
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||||
|
# )
|
||||||
|
|
||||||
print(f"You: {query}")
|
print(f"You: {query}")
|
||||||
chat_response = chat.ask(query, recompute_beighbor_embeddings=True)
|
chat_response = chat.ask(
|
||||||
|
query, top_k=20, recompute_beighbor_embeddings=True, complexity=32
|
||||||
|
)
|
||||||
print(f"Leann: {chat_response}")
|
print(f"Leann: {chat_response}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run Leann Chat with various LLM backends."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm",
|
||||||
|
type=str,
|
||||||
|
default="hf",
|
||||||
|
choices=["simulated", "ollama", "hf", "openai"],
|
||||||
|
help="The LLM backend to use.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
default="Qwen/Qwen3-0.6B",
|
||||||
|
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--host",
|
||||||
|
type=str,
|
||||||
|
default="http://localhost:11434",
|
||||||
|
help="The host for the Ollama API.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-dir",
|
||||||
|
type=str,
|
||||||
|
default="./test_doc_files",
|
||||||
|
help="Directory where the Leann index will be stored.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
asyncio.run(main(args))
|
||||||
|
|||||||
319
examples/multi_vector_aggregator.py
Normal file
319
examples/multi_vector_aggregator.py
Normal file
@@ -0,0 +1,319 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Multi-Vector Aggregator for Fat Embeddings
|
||||||
|
==========================================
|
||||||
|
|
||||||
|
This module implements aggregation strategies for multi-vector embeddings,
|
||||||
|
similar to ColPali's approach where multiple patch vectors represent a single document.
|
||||||
|
|
||||||
|
Key features:
|
||||||
|
- MaxSim aggregation (take maximum similarity across patches)
|
||||||
|
- Voting-based aggregation (count patch matches)
|
||||||
|
- Weighted aggregation (attention-score weighted)
|
||||||
|
- Spatial clustering of matching patches
|
||||||
|
- Document-level result consolidation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from typing import List, Dict, Any, Tuple, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from collections import defaultdict
|
||||||
|
import json
|
||||||
|
|
||||||
|
@dataclass
class PatchResult:
    """Represents a single patch search result."""
    # Index of the patch within its source image.
    patch_id: int
    # Base name of the source image this patch came from.
    image_name: str
    # Filesystem path to the source image.
    image_path: str
    coordinates: Tuple[int, int, int, int]  # (x1, y1, x2, y2)
    # Similarity score returned by the searcher for this patch.
    score: float
    # Attention weight for the patch (defaults to 0.0 when absent in metadata).
    attention_score: float
    # Scale factor of the patch (defaults to 1.0 when absent in metadata).
    scale: float
    # Raw metadata dict this patch was built from.
    metadata: Dict[str, Any]
|
||||||
|
|
||||||
|
@dataclass
class AggregatedResult:
    """Represents an aggregated document-level result."""
    # Base name of the image the patches belong to.
    image_name: str
    # Filesystem path to that image.
    image_path: str
    # Document-level score produced by the chosen aggregation method.
    doc_score: float
    # Number of patches that contributed to this result.
    patch_count: int
    # Best patch according to the aggregation method's criterion.
    best_patch: PatchResult
    # All contributing patches, sorted by score descending.
    all_patches: List[PatchResult]
    # Aggregation method used ("maxsim", "voting", "weighted", or "mean").
    aggregation_method: str
    # Spatially grouped patches; None when clustering is disabled.
    spatial_clusters: Optional[List[List[PatchResult]]] = None
|
||||||
|
|
||||||
|
class MultiVectorAggregator:
    """
    Aggregates multiple patch-level results into document-level results.
    """

    def __init__(self,
                 aggregation_method: str = "maxsim",
                 spatial_clustering: bool = True,
                 cluster_distance_threshold: float = 100.0):
        """
        Initialize the aggregator.

        Args:
            aggregation_method: "maxsim", "voting", "weighted", or "mean"
            spatial_clustering: Whether to cluster spatially close patches
            cluster_distance_threshold: Distance threshold for spatial clustering
        """
        self.aggregation_method = aggregation_method
        self.spatial_clustering = spatial_clustering
        self.cluster_distance_threshold = cluster_distance_threshold

    def aggregate_results(self,
                          search_results: List[Dict[str, Any]],
                          top_k: int = 10) -> List[AggregatedResult]:
        """
        Aggregate patch-level search results into document-level results.

        Args:
            search_results: List of search results from LeannSearcher
            top_k: Number of top documents to return

        Returns:
            List of aggregated document results, sorted by doc_score descending.
        """
        # Group results by image. Results lacking "image_name"/"patch_id"
        # metadata are silently skipped.
        image_groups = defaultdict(list)

        for result in search_results:
            metadata = result.metadata
            if "image_name" in metadata and "patch_id" in metadata:
                patch_result = PatchResult(
                    patch_id=metadata["patch_id"],
                    image_name=metadata["image_name"],
                    image_path=metadata["image_path"],
                    coordinates=tuple(metadata["coordinates"]),
                    score=result.score,
                    # Missing optional fields get neutral defaults.
                    attention_score=metadata.get("attention_score", 0.0),
                    scale=metadata.get("scale", 1.0),
                    metadata=metadata
                )
                image_groups[metadata["image_name"]].append(patch_result)

        # Aggregate each image group into a single document-level result.
        aggregated_results = []
        for image_name, patches in image_groups.items():
            if len(patches) == 0:
                continue

            agg_result = self._aggregate_image_patches(image_name, patches)
            aggregated_results.append(agg_result)

        # Sort by aggregated score and return top-k
        aggregated_results.sort(key=lambda x: x.doc_score, reverse=True)
        return aggregated_results[:top_k]

    def _aggregate_image_patches(self, image_name: str, patches: List[PatchResult]) -> AggregatedResult:
        """Aggregate patches for a single image."""

        if self.aggregation_method == "maxsim":
            # Document score is the single best patch score (ColPali-style MaxSim).
            doc_score = max(patch.score for patch in patches)
            best_patch = max(patches, key=lambda p: p.score)

        elif self.aggregation_method == "voting":
            # Count patches above threshold (75th percentile of this image's scores).
            # NOTE(review): doc_score is a count here, so it is not on the same
            # scale as the other methods' scores.
            threshold = np.percentile([p.score for p in patches], 75)
            doc_score = sum(1 for patch in patches if patch.score >= threshold)
            best_patch = max(patches, key=lambda p: p.score)

        elif self.aggregation_method == "weighted":
            # Weight by attention scores; the max(..., 1e-8) guards against a
            # zero total weight (all attention scores absent/zero).
            total_weighted_score = sum(p.score * p.attention_score for p in patches)
            total_weights = sum(p.attention_score for p in patches)
            doc_score = total_weighted_score / max(total_weights, 1e-8)
            best_patch = max(patches, key=lambda p: p.score * p.attention_score)

        elif self.aggregation_method == "mean":
            doc_score = np.mean([patch.score for patch in patches])
            best_patch = max(patches, key=lambda p: p.score)

        else:
            raise ValueError(f"Unknown aggregation method: {self.aggregation_method}")

        # Spatial clustering if enabled
        spatial_clusters = None
        if self.spatial_clustering:
            spatial_clusters = self._cluster_patches_spatially(patches)

        return AggregatedResult(
            image_name=image_name,
            image_path=patches[0].image_path,
            doc_score=float(doc_score),
            patch_count=len(patches),
            best_patch=best_patch,
            all_patches=sorted(patches, key=lambda p: p.score, reverse=True),
            aggregation_method=self.aggregation_method,
            spatial_clusters=spatial_clusters
        )

    def _cluster_patches_spatially(self, patches: List[PatchResult]) -> List[List[PatchResult]]:
        """Cluster patches that are spatially close to each other.

        Greedy seeded clustering: repeatedly seed a cluster with the
        highest-scoring unassigned patch, then absorb every patch within
        cluster_distance_threshold of any cluster member until a fixed point.
        Returns clusters sorted by their best member's score, descending.
        """
        if len(patches) <= 1:
            return [patches]

        clusters = []
        remaining_patches = patches.copy()

        while remaining_patches:
            # Start new cluster with highest scoring remaining patch
            seed_patch = max(remaining_patches, key=lambda p: p.score)
            current_cluster = [seed_patch]
            remaining_patches.remove(seed_patch)

            # Add nearby patches to cluster; loop until no patch was absorbed
            # in a full pass (transitive closure of "nearby").
            added_to_cluster = True
            while added_to_cluster:
                added_to_cluster = False
                # Iterate over a copy because we mutate remaining_patches inside.
                for patch in remaining_patches.copy():
                    if self._is_patch_nearby(patch, current_cluster):
                        current_cluster.append(patch)
                        remaining_patches.remove(patch)
                        added_to_cluster = True

            clusters.append(current_cluster)

        return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True)

    def _is_patch_nearby(self, patch: PatchResult, cluster: List[PatchResult]) -> bool:
        """Check if a patch is spatially close to any patch in the cluster."""
        patch_center = self._get_patch_center(patch.coordinates)

        for cluster_patch in cluster:
            cluster_center = self._get_patch_center(cluster_patch.coordinates)
            # Euclidean distance between patch centers.
            distance = np.sqrt((patch_center[0] - cluster_center[0])**2 +
                               (patch_center[1] - cluster_center[1])**2)

            if distance <= self.cluster_distance_threshold:
                return True

        return False

    def _get_patch_center(self, coordinates: Tuple[int, int, int, int]) -> Tuple[float, float]:
        """Get center point of a patch given (x1, y1, x2, y2) coordinates."""
        x1, y1, x2, y2 = coordinates
        return ((x1 + x2) / 2, (y1 + y2) / 2)

    def print_aggregated_results(self, results: List[AggregatedResult], max_patches_per_doc: int = 3):
        """Pretty print aggregated results."""
        print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})")
        print("=" * 80)

        for i, result in enumerate(results):
            print(f"\n{i+1}. {result.image_name}")
            print(f"   Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}")
            print(f"   Path: {result.image_path}")

            # Show best patch
            best = result.best_patch
            print(f"   🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})")

            # Show top patches
            print(f"   📍 Top Patches:")
            for j, patch in enumerate(result.all_patches[:max_patches_per_doc]):
                print(f"      {j+1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}")

            # Show spatial clusters if available
            if result.spatial_clusters and len(result.spatial_clusters) > 1:
                print(f"   🗂️  Spatial Clusters: {len(result.spatial_clusters)}")
                for j, cluster in enumerate(result.spatial_clusters[:2]):  # Show top 2 clusters
                    cluster_score = max(p.score for p in cluster)
                    print(f"      Cluster {j+1}: {len(cluster)} patches (best: {cluster_score:.4f})")
|
||||||
|
|
||||||
|
def demo_aggregation():
    """Demonstrate the multi-vector aggregation functionality.

    Builds mock patch-level search results for two images and runs every
    supported aggregation method over them, printing the aggregated output.
    """
    print("=== Multi-Vector Aggregation Demo ===")

    # Simulate some patch-level search results
    # In real usage, these would come from LeannSearcher.search()

    class MockResult:
        # Minimal stand-in for a searcher result: only .score and .metadata
        # are read by MultiVectorAggregator.aggregate_results.
        def __init__(self, score, metadata):
            self.score = score
            self.metadata = metadata

    # One row per patch: (score, image_name, patch_id, coordinates, attention_score).
    # Replaces seven near-identical hand-written dicts with a data table.
    patch_specs = [
        # Image 1: cats_and_kitchen.jpg - 4 patches
        (0.85, "cats_and_kitchen.jpg", 3, [100, 50, 224, 174], 0.92),    # Kitchen area
        (0.78, "cats_and_kitchen.jpg", 7, [200, 300, 324, 424], 0.88),   # Cat area
        (0.72, "cats_and_kitchen.jpg", 12, [150, 100, 274, 224], 0.75),  # Appliances
        (0.65, "cats_and_kitchen.jpg", 15, [50, 250, 174, 374], 0.70),   # Furniture
        # Image 2: city_street.jpg - 3 patches
        (0.68, "city_street.jpg", 2, [300, 100, 424, 224], 0.80),        # Buildings
        (0.62, "city_street.jpg", 8, [100, 350, 224, 474], 0.75),        # Street level
        (0.55, "city_street.jpg", 11, [400, 200, 524, 324], 0.60),       # Sky area
    ]
    mock_results = [
        MockResult(score, {
            "image_name": name,
            "image_path": f"/path/to/{name}",
            "patch_id": pid,
            "coordinates": coords,
            "attention_score": attn,
            "scale": 1.0,
        })
        for score, name, pid, coords, attn in patch_specs
    ]

    # Test different aggregation methods
    methods = ["maxsim", "voting", "weighted", "mean"]

    for method in methods:
        print(f"\n{'='*20} {method.upper()} AGGREGATION {'='*20}")

        aggregator = MultiVectorAggregator(
            aggregation_method=method,
            spatial_clustering=True,
            cluster_distance_threshold=100.0
        )

        aggregated = aggregator.aggregate_results(mock_results, top_k=5)
        aggregator.print_aggregated_results(aggregated)


if __name__ == "__main__":
    demo_aggregation()
|
||||||
108
examples/openai_hnsw_example.py
Normal file
108
examples/openai_hnsw_example.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
OpenAI Embedding Example
|
||||||
|
|
||||||
|
Complete example showing how to build and search with OpenAI embeddings using HNSW backend.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import dotenv
|
||||||
|
from pathlib import Path
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
def main():
    """Build an HNSW index over sample texts with OpenAI embeddings, then search it.

    Returns:
        True when both the build and the search phase succeed, False otherwise
        (missing API key, build failure, or search failure).
    """
    # Check if OpenAI API key is available
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("ERROR: OPENAI_API_KEY environment variable not set")
        return False

    print(f"✅ OpenAI API key found: {api_key[:10]}...")

    # Sample texts
    sample_texts = [
        "Machine learning is a powerful technology that enables computers to learn from data.",
        "Natural language processing helps computers understand and generate human language.",
        "Deep learning uses neural networks with multiple layers to solve complex problems.",
        "Computer vision allows machines to interpret and understand visual information.",
        "Reinforcement learning trains agents to make decisions through trial and error.",
        "Data science combines statistics, math, and programming to extract insights from data.",
        "Artificial intelligence aims to create machines that can perform human-like tasks.",
        "Python is a popular programming language used extensively in data science and AI.",
        "Neural networks are inspired by the structure and function of the human brain.",
        "Big data refers to extremely large datasets that require special tools to process."
    ]

    INDEX_DIR = Path("./simple_openai_test_index")
    INDEX_PATH = str(INDEX_DIR / "simple_test.leann")

    print(f"\n=== Building Index with OpenAI Embeddings ===")
    print(f"Index path: {INDEX_PATH}")

    try:
        # Use proper configuration for OpenAI embeddings
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="text-embedding-3-small",
            embedding_mode="openai",
            # HNSW settings for OpenAI embeddings
            M=16,  # Smaller graph degree
            efConstruction=64,  # Smaller construction complexity
            is_compact=True,  # Enable compact storage for recompute
            is_recompute=True,  # MUST enable for OpenAI embeddings
            num_threads=1,
        )

        print(f"Adding {len(sample_texts)} texts to the index...")
        for i, text in enumerate(sample_texts):
            # Each text gets a synthetic id plus a fixed topic tag.
            metadata = {"id": f"doc_{i}", "topic": "AI"}
            builder.add_text(text, metadata)

        print("Building index...")
        builder.build_index(INDEX_PATH)
        print(f"✅ Index built successfully!")

    except Exception as e:
        # Build failures are reported with a traceback and abort the run.
        print(f"❌ Error building index: {e}")
        import traceback
        traceback.print_exc()
        return False

    print(f"\n=== Testing Search ===")

    try:
        searcher = LeannSearcher(INDEX_PATH)

        test_queries = [
            "What is machine learning?",
            "How do neural networks work?",
            "Programming languages for data science"
        ]

        for query in test_queries:
            print(f"\n🔍 Query: '{query}'")
            results = searcher.search(query, top_k=3)

            print(f"   Found {len(results)} results:")
            for i, result in enumerate(results):
                print(f"   {i+1}. Score: {result.score:.4f}")
                # Truncate long texts for readable output.
                print(f"      Text: {result.text[:80]}...")

        print(f"\n✅ Search test completed successfully!")
        return True

    except Exception as e:
        print(f"❌ Error during search: {e}")
        import traceback
        traceback.print_exc()
        return False
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Report overall success or failure of the end-to-end example.
    if main():
        print(f"\n🎉 Simple OpenAI index test completed successfully!")
    else:
        print(f"\n💥 Simple OpenAI index test failed!")
|
||||||
18
examples/resue_index.py
Normal file
18
examples/resue_index.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
import asyncio
|
||||||
|
from leann.api import LeannChat
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
INDEX_DIR = Path("./test_pdf_index_huawei")
|
||||||
|
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||||
|
|
||||||
|
async def main():
    """Reuse an existing LEANN index (INDEX_PATH) for a single chat query."""
    print(f"\n[PHASE 2] Starting Leann chat session...")
    chat = LeannChat(index_path=INDEX_PATH)
    # The original first assignment was dead code (immediately overwritten);
    # kept here for reference:
    # query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
    query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
    # query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
    # NOTE(review): "recompute_beighbor_embeddings" matches the spelling used by
    # every other chat.ask() call site in this repo — presumably the API's actual
    # keyword; confirm against leann.api before renaming.
    response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True, complexity=32, beam_width=1)
    print(f"\n[PHASE 2] Response: {response}")


if __name__ == "__main__":
    asyncio.run(main())
|
||||||
382
examples/run_evaluation.py
Normal file
382
examples/run_evaluation.py
Normal file
@@ -0,0 +1,382 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
This script runs a recall evaluation on a given LEANN index.
|
||||||
|
It correctly compares results by fetching the text content for both the new search
|
||||||
|
results and the golden standard results, making the comparison robust to ID changes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from leann.api import LeannSearcher, LeannBuilder
|
||||||
|
|
||||||
|
|
||||||
|
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||||
|
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
||||||
|
if not data_root.exists():
|
||||||
|
print(f"Data directory '{data_root}' not found.")
|
||||||
|
print(
|
||||||
|
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
if download_embeddings:
|
||||||
|
# Download everything including embeddings (large files)
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir=data_root,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
)
|
||||||
|
print("Data download complete (including embeddings)!")
|
||||||
|
else:
|
||||||
|
# Download only specific folders, excluding embeddings
|
||||||
|
allow_patterns = [
|
||||||
|
"ground_truth/**",
|
||||||
|
"indices/**",
|
||||||
|
"queries/**",
|
||||||
|
"*.md",
|
||||||
|
"*.txt",
|
||||||
|
]
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir=data_root,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
allow_patterns=allow_patterns,
|
||||||
|
)
|
||||||
|
print("Data download complete (excluding embeddings)!")
|
||||||
|
except ImportError:
|
||||||
|
print(
|
||||||
|
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
||||||
|
)
|
||||||
|
print("uv pip install -e '.[dev]'")
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An error occurred during data download: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
||||||
|
"""Download embeddings files specifically."""
|
||||||
|
embeddings_dir = data_root / "embeddings"
|
||||||
|
|
||||||
|
if dataset_type:
|
||||||
|
# Check if specific dataset embeddings exist
|
||||||
|
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
|
||||||
|
if target_file.exists():
|
||||||
|
print(f"Embeddings for {dataset_type} already exist")
|
||||||
|
return str(target_file)
|
||||||
|
|
||||||
|
print("Downloading embeddings from HuggingFace Hub...")
|
||||||
|
try:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
# Download only embeddings folder
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir=data_root,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
allow_patterns=["embeddings/**/*.pkl"],
|
||||||
|
)
|
||||||
|
print("Embeddings download complete!")
|
||||||
|
|
||||||
|
if dataset_type:
|
||||||
|
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
|
||||||
|
if target_file.exists():
|
||||||
|
return str(target_file)
|
||||||
|
|
||||||
|
return str(embeddings_dir)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error downloading embeddings: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Helper Function to get Golden Passages ---
|
||||||
|
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
||||||
|
"""
|
||||||
|
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
||||||
|
passage manager.
|
||||||
|
"""
|
||||||
|
golden_texts = set()
|
||||||
|
for gid in golden_ids:
|
||||||
|
try:
|
||||||
|
# PassageManager uses string IDs
|
||||||
|
passage_data = searcher.passage_manager.get_passage(str(gid))
|
||||||
|
golden_texts.add(passage_data["text"])
|
||||||
|
except KeyError:
|
||||||
|
print(
|
||||||
|
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
|
||||||
|
)
|
||||||
|
return golden_texts
|
||||||
|
|
||||||
|
|
||||||
|
def load_queries(file_path: Path) -> List[str]:
|
||||||
|
queries = []
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
data = json.loads(line)
|
||||||
|
queries.append(data["query"])
|
||||||
|
return queries
|
||||||
|
|
||||||
|
|
||||||
|
def build_index_from_embeddings(
|
||||||
|
embeddings_file: str, output_path: str, backend: str = "hnsw"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Build a LEANN index from pre-computed embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embeddings_file: Path to pickle file with (ids, embeddings) tuple
|
||||||
|
output_path: Path where to save the index
|
||||||
|
backend: Backend to use ("hnsw" or "diskann")
|
||||||
|
"""
|
||||||
|
print(f"Building {backend} index from embeddings: {embeddings_file}")
|
||||||
|
|
||||||
|
# Create builder with appropriate parameters
|
||||||
|
if backend == "hnsw":
|
||||||
|
builder_kwargs = {
|
||||||
|
"M": 32, # Graph degree
|
||||||
|
"efConstruction": 256, # Construction complexity
|
||||||
|
"is_compact": True, # Use compact storage
|
||||||
|
"is_recompute": True, # Enable pruning for better recall
|
||||||
|
}
|
||||||
|
elif backend == "diskann":
|
||||||
|
builder_kwargs = {
|
||||||
|
"complexity": 64,
|
||||||
|
"graph_degree": 32,
|
||||||
|
"search_memory_maximum": 8.0, # GB
|
||||||
|
"build_memory_maximum": 16.0, # GB
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
builder_kwargs = {}
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend,
|
||||||
|
embedding_model="facebook/contriever-msmarco", # Model used to create embeddings
|
||||||
|
dimensions=768, # Will be auto-detected from embeddings
|
||||||
|
**builder_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build index from precomputed embeddings
|
||||||
|
builder.build_index_from_embeddings(output_path, embeddings_file)
|
||||||
|
print(f"Index saved to: {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run recall evaluation on a LEANN index."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"index_path",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
help="Path to the LEANN index to evaluate or build (optional).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mode",
|
||||||
|
choices=["evaluate", "build"],
|
||||||
|
default="evaluate",
|
||||||
|
help="Mode: 'evaluate' existing index or 'build' from embeddings",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embeddings-file",
|
||||||
|
type=str,
|
||||||
|
help="Path to embeddings pickle file (optional for build mode)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend",
|
||||||
|
choices=["hnsw", "diskann"],
|
||||||
|
default="hnsw",
|
||||||
|
help="Backend to use for building index (default: hnsw)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# --- Path Configuration ---
|
||||||
|
# Assumes a project structure where the script is in 'examples/'
|
||||||
|
# and data is in 'data/' at the project root.
|
||||||
|
project_root = Path(__file__).resolve().parent.parent
|
||||||
|
data_root = project_root / "data"
|
||||||
|
|
||||||
|
# Download data based on mode
|
||||||
|
if args.mode == "build":
|
||||||
|
# For building mode, we need embeddings
|
||||||
|
download_data_if_needed(
|
||||||
|
data_root, download_embeddings=False
|
||||||
|
) # Basic data first
|
||||||
|
|
||||||
|
# Auto-detect dataset type and download embeddings
|
||||||
|
if args.embeddings_file:
|
||||||
|
embeddings_file = args.embeddings_file
|
||||||
|
# Try to detect dataset type from embeddings file path
|
||||||
|
if "rpj_wiki" in str(embeddings_file):
|
||||||
|
dataset_type = "rpj_wiki"
|
||||||
|
elif "dpr" in str(embeddings_file):
|
||||||
|
dataset_type = "dpr"
|
||||||
|
else:
|
||||||
|
dataset_type = "dpr" # Default
|
||||||
|
else:
|
||||||
|
# Auto-detect from index path if provided, otherwise default to DPR
|
||||||
|
if args.index_path:
|
||||||
|
index_path_str = str(args.index_path)
|
||||||
|
if "rpj_wiki" in index_path_str:
|
||||||
|
dataset_type = "rpj_wiki"
|
||||||
|
elif "dpr" in index_path_str:
|
||||||
|
dataset_type = "dpr"
|
||||||
|
else:
|
||||||
|
dataset_type = "dpr" # Default to DPR
|
||||||
|
else:
|
||||||
|
dataset_type = "dpr" # Default to DPR
|
||||||
|
|
||||||
|
embeddings_file = download_embeddings_if_needed(data_root, dataset_type)
|
||||||
|
|
||||||
|
# Auto-generate index path if not provided
|
||||||
|
if not args.index_path:
|
||||||
|
indices_dir = data_root / "indices" / dataset_type
|
||||||
|
indices_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
args.index_path = str(indices_dir / f"{dataset_type}_from_embeddings")
|
||||||
|
print(f"Auto-generated index path: {args.index_path}")
|
||||||
|
|
||||||
|
print(f"Building index from embeddings: {embeddings_file}")
|
||||||
|
built_index_path = build_index_from_embeddings(
|
||||||
|
embeddings_file, args.index_path, args.backend
|
||||||
|
)
|
||||||
|
print(f"Index built successfully: {built_index_path}")
|
||||||
|
|
||||||
|
# Ask if user wants to run evaluation
|
||||||
|
eval_response = (
|
||||||
|
input("Run evaluation on the built index? (y/n): ").strip().lower()
|
||||||
|
)
|
||||||
|
if eval_response != "y":
|
||||||
|
print("Index building complete. Exiting.")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# For evaluation mode, don't need embeddings
|
||||||
|
download_data_if_needed(data_root, download_embeddings=False)
|
||||||
|
|
||||||
|
# Auto-detect index path if not provided
|
||||||
|
if not args.index_path:
|
||||||
|
# Default to using downloaded indices
|
||||||
|
indices_dir = data_root / "indices"
|
||||||
|
|
||||||
|
# Try common datasets in order of preference
|
||||||
|
for dataset in ["dpr", "rpj_wiki"]:
|
||||||
|
dataset_dir = indices_dir / dataset
|
||||||
|
if dataset_dir.exists():
|
||||||
|
# Look for index files
|
||||||
|
index_files = list(dataset_dir.glob("*.index")) + list(
|
||||||
|
dataset_dir.glob("*_disk.index")
|
||||||
|
)
|
||||||
|
if index_files:
|
||||||
|
args.index_path = str(
|
||||||
|
index_files[0].with_suffix("")
|
||||||
|
) # Remove .index extension
|
||||||
|
print(f"Using index: {args.index_path}")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not args.index_path:
|
||||||
|
print(
|
||||||
|
"No indices found. The data download should have included pre-built indices."
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"Please check the data/indices/ directory or provide --index-path manually."
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Detect dataset type from index path to select the correct ground truth
|
||||||
|
index_path_str = str(args.index_path)
|
||||||
|
if "rpj_wiki" in index_path_str:
|
||||||
|
dataset_type = "rpj_wiki"
|
||||||
|
elif "dpr" in index_path_str:
|
||||||
|
dataset_type = "dpr"
|
||||||
|
else:
|
||||||
|
# Fallback: try to infer from the index directory name
|
||||||
|
dataset_type = Path(args.index_path).name
|
||||||
|
print(
|
||||||
|
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
|
||||||
|
)
|
||||||
|
|
||||||
|
queries_file = data_root / "queries" / "nq_open.jsonl"
|
||||||
|
golden_results_file = (
|
||||||
|
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"INFO: Detected dataset type: {dataset_type}")
|
||||||
|
print(f"INFO: Using queries file: {queries_file}")
|
||||||
|
print(f"INFO: Using ground truth file: {golden_results_file}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
searcher = LeannSearcher(args.index_path)
|
||||||
|
queries = load_queries(queries_file)
|
||||||
|
|
||||||
|
with open(golden_results_file, "r") as f:
|
||||||
|
golden_results_data = json.load(f)
|
||||||
|
|
||||||
|
num_eval_queries = min(args.num_queries, len(queries))
|
||||||
|
queries = queries[:num_eval_queries]
|
||||||
|
|
||||||
|
print(f"\nRunning evaluation on {num_eval_queries} queries...")
|
||||||
|
recall_scores = []
|
||||||
|
search_times = []
|
||||||
|
|
||||||
|
for i in range(num_eval_queries):
|
||||||
|
start_time = time.time()
|
||||||
|
new_results = searcher.search(
|
||||||
|
queries[i], top_k=args.top_k, ef=args.ef_search
|
||||||
|
)
|
||||||
|
search_times.append(time.time() - start_time)
|
||||||
|
|
||||||
|
# Correct Recall Calculation: Based on TEXT content
|
||||||
|
new_texts = {result.text for result in new_results}
|
||||||
|
|
||||||
|
# Get golden texts directly from the searcher's passage manager
|
||||||
|
golden_ids = golden_results_data["indices"][i][: args.top_k]
|
||||||
|
golden_texts = get_golden_texts(searcher, golden_ids)
|
||||||
|
|
||||||
|
overlap = len(new_texts & golden_texts)
|
||||||
|
recall = overlap / len(golden_texts) if golden_texts else 0
|
||||||
|
recall_scores.append(recall)
|
||||||
|
|
||||||
|
print("\n--- EVALUATION RESULTS ---")
|
||||||
|
print(f"Query: {queries[i]}")
|
||||||
|
print(f"New Results: {new_texts}")
|
||||||
|
print(f"Golden Results: {golden_texts}")
|
||||||
|
print(f"Overlap: {overlap}")
|
||||||
|
print(f"Recall: {recall}")
|
||||||
|
print(f"Search Time: {search_times[-1]:.4f}s")
|
||||||
|
print("--------------------------------")
|
||||||
|
|
||||||
|
avg_recall = np.mean(recall_scores) if recall_scores else 0
|
||||||
|
avg_time = np.mean(search_times) if search_times else 0
|
||||||
|
|
||||||
|
print("\n🎉 --- Evaluation Complete ---")
|
||||||
|
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
|
||||||
|
print(f"Avg. Search Time: {avg_time:.4f}s")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ An error occurred during evaluation: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -3,11 +3,17 @@ Simple demo showing basic leann usage
|
|||||||
Run: uv run python examples/simple_demo.py
|
Run: uv run python examples/simple_demo.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
from leann import LeannBuilder, LeannSearcher, LeannChat
|
from leann import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
print("=== Leann Simple Demo ===")
|
parser = argparse.ArgumentParser(description="Simple demo of Leann with selectable embedding models.")
|
||||||
|
parser.add_argument("--embedding_model", type=str, default="sentence-transformers/all-mpnet-base-v2",
|
||||||
|
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(f"=== Leann Simple Demo with {args.embedding_model} ===")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# Sample knowledge base
|
# Sample knowledge base
|
||||||
@@ -24,10 +30,11 @@ def main():
|
|||||||
|
|
||||||
print("1. Building index (no embeddings stored)...")
|
print("1. Building index (no embeddings stored)...")
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
embedding_model="sentence-transformers/all-mpnet-base-v2",
|
embedding_model=args.embedding_model,
|
||||||
prune_ratio=0.7, # Keep 30% of connections
|
backend_name="hnsw",
|
||||||
)
|
)
|
||||||
builder.add_chunks(chunks)
|
for chunk in chunks:
|
||||||
|
builder.add_text(chunk)
|
||||||
builder.build_index("demo_knowledge.leann")
|
builder.build_index("demo_knowledge.leann")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
@@ -49,14 +56,7 @@ def main():
|
|||||||
print(f" Text: {result.text[:100]}...")
|
print(f" Text: {result.text[:100]}...")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
print("3. Memory stats:")
|
print("3. Interactive chat demo:")
|
||||||
stats = searcher.get_memory_stats()
|
|
||||||
print(f" Cache size: {stats.embedding_cache_size}")
|
|
||||||
print(f" Cache memory: {stats.embedding_cache_memory_mb:.1f} MB")
|
|
||||||
print(f" Total chunks: {stats.total_chunks}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
print("4. Interactive chat demo:")
|
|
||||||
print(" (Note: Requires OpenAI API key for real responses)")
|
print(" (Note: Requires OpenAI API key for real responses)")
|
||||||
|
|
||||||
chat = LeannChat("demo_knowledge.leann")
|
chat = LeannChat("demo_knowledge.leann")
|
||||||
|
|||||||
318
examples/wechat_history_reader_leann.py
Normal file
318
examples/wechat_history_reader_leann.py
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import dotenv
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any, Optional
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
import requests
|
||||||
|
import time
|
||||||
|
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
# Default WeChat export directory
|
||||||
|
DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
|
||||||
|
|
||||||
|
|
||||||
|
def create_leann_index_from_multiple_wechat_exports(
|
||||||
|
export_dirs: List[Path],
|
||||||
|
index_path: str = "wechat_history_index.leann",
|
||||||
|
max_count: int = -1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create LEANN index from multiple WeChat export data sources.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
export_dirs: List of Path objects pointing to WeChat export directories
|
||||||
|
index_path: Path to save the LEANN index
|
||||||
|
max_count: Maximum number of chat entries to process per export
|
||||||
|
"""
|
||||||
|
print("Creating LEANN index from multiple WeChat export data sources...")
|
||||||
|
|
||||||
|
# Load documents using WeChatHistoryReader from history_data
|
||||||
|
from history_data.wechat_history import WeChatHistoryReader
|
||||||
|
|
||||||
|
reader = WeChatHistoryReader()
|
||||||
|
|
||||||
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
|
if not INDEX_DIR.exists():
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
# Process each WeChat export directory
|
||||||
|
for i, export_dir in enumerate(export_dirs):
|
||||||
|
print(
|
||||||
|
f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
documents = reader.load_data(
|
||||||
|
wechat_export_dir=str(export_dir),
|
||||||
|
max_count=max_count,
|
||||||
|
concatenate_messages=False, # Disable concatenation - one message per document
|
||||||
|
)
|
||||||
|
if documents:
|
||||||
|
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
|
||||||
|
# Check if we've reached the max count
|
||||||
|
if max_count > 0 and total_processed >= max_count:
|
||||||
|
print(f"Reached max count of {max_count} documents")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print(f"No documents loaded from {export_dir}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {export_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No documents loaded from any source. Exiting.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create text splitter with 256 chunk size
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=128, chunk_overlap=64)
|
||||||
|
|
||||||
|
# Convert Documents to text strings and chunk them
|
||||||
|
all_texts = []
|
||||||
|
for doc in all_documents:
|
||||||
|
# Split the document into chunks
|
||||||
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Created {len(all_texts)} text chunks from {len(all_documents)} documents"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create LEANN index directory
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print(f"--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print(f"\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Use HNSW backend for better macOS compatibility
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="Qwen/Qwen3-Embedding-0.6B",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1, # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Adding {len(all_texts)} chat chunks to index...")
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"\nLEANN index built at {index_path}!")
|
||||||
|
else:
|
||||||
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
|
||||||
|
def create_leann_index(
|
||||||
|
export_dir: str = None,
|
||||||
|
index_path: str = "wechat_history_index.leann",
|
||||||
|
max_count: int = 1000,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create LEANN index from WeChat chat history data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
export_dir: Path to the WeChat export directory (optional, uses default if None)
|
||||||
|
index_path: Path to save the LEANN index
|
||||||
|
max_count: Maximum number of chat entries to process
|
||||||
|
"""
|
||||||
|
print("Creating LEANN index from WeChat chat history data...")
|
||||||
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
|
if not INDEX_DIR.exists():
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print(f"--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print(f"\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Load documents using WeChatHistoryReader from history_data
|
||||||
|
from history_data.wechat_history import WeChatHistoryReader
|
||||||
|
|
||||||
|
reader = WeChatHistoryReader()
|
||||||
|
|
||||||
|
documents = reader.load_data(
|
||||||
|
wechat_export_dir=export_dir,
|
||||||
|
max_count=max_count,
|
||||||
|
concatenate_messages=False, # Disable concatenation - one message per document
|
||||||
|
)
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
print("No documents loaded. Exiting.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print(f"Loaded {len(documents)} chat documents")
|
||||||
|
|
||||||
|
# Create text splitter with 256 chunk size
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||||
|
|
||||||
|
# Convert Documents to text strings and chunk them
|
||||||
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
# Split the document into chunks
|
||||||
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
|
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||||
|
|
||||||
|
# Create LEANN index directory
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print(f"--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print(f"\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Use HNSW backend for better macOS compatibility
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ", # MLX-optimized model
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1, # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Adding {len(all_texts)} chat chunks to index...")
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"\nLEANN index built at {index_path}!")
|
||||||
|
else:
|
||||||
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
|
||||||
|
async def query_leann_index(index_path: str, query: str):
|
||||||
|
"""
|
||||||
|
Query the LEANN index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_path: Path to the LEANN index
|
||||||
|
query: The query string
|
||||||
|
"""
|
||||||
|
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||||
|
chat = LeannChat(index_path=index_path)
|
||||||
|
|
||||||
|
print(f"You: {query}")
|
||||||
|
chat_response = chat.ask(
|
||||||
|
query,
|
||||||
|
top_k=20,
|
||||||
|
recompute_beighbor_embeddings=True,
|
||||||
|
complexity=128,
|
||||||
|
beam_width=1,
|
||||||
|
llm_config={
|
||||||
|
"type": "openai",
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
},
|
||||||
|
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
||||||
|
)
|
||||||
|
print(f"Leann: {chat_response}")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Main function with integrated WeChat export functionality."""
|
||||||
|
|
||||||
|
# Parse command line arguments
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="LEANN WeChat History Reader - Create and query WeChat chat history index"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--export-dir",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_WECHAT_EXPORT_DIR,
|
||||||
|
help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-dir",
|
||||||
|
type=str,
|
||||||
|
default="./wechat_history_june19_test",
|
||||||
|
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-entries",
|
||||||
|
type=int,
|
||||||
|
default=5000,
|
||||||
|
help="Maximum number of chat entries to process (default: 5000)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Single query to run (default: runs example queries)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--force-export",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Force re-export of WeChat data even if exports exist",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
INDEX_DIR = Path(args.index_dir)
|
||||||
|
INDEX_PATH = str(INDEX_DIR / "wechat_history.leann")
|
||||||
|
|
||||||
|
print(f"Using WeChat export directory: {args.export_dir}")
|
||||||
|
print(f"Index directory: {INDEX_DIR}")
|
||||||
|
print(f"Max entries: {args.max_entries}")
|
||||||
|
|
||||||
|
# Initialize WeChat reader with export capabilities
|
||||||
|
from history_data.wechat_history import WeChatHistoryReader
|
||||||
|
|
||||||
|
reader = WeChatHistoryReader()
|
||||||
|
|
||||||
|
# Find existing exports or create new ones using the centralized method
|
||||||
|
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
||||||
|
if not export_dirs:
|
||||||
|
print("Failed to find or export WeChat data. Exiting.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create or load the LEANN index from all sources
|
||||||
|
index_path = create_leann_index_from_multiple_wechat_exports(
|
||||||
|
export_dirs, INDEX_PATH, max_count=args.max_entries
|
||||||
|
)
|
||||||
|
|
||||||
|
if index_path:
|
||||||
|
if args.query:
|
||||||
|
# Run single query
|
||||||
|
await query_leann_index(index_path, args.query)
|
||||||
|
else:
|
||||||
|
# Example queries
|
||||||
|
queries = [
|
||||||
|
"我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
|
||||||
|
]
|
||||||
|
|
||||||
|
for query in queries:
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
await query_leann_index(index_path, query)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
{
|
|
||||||
"version": "0.1.0",
|
|
||||||
"backend_name": "diskann",
|
|
||||||
"embedding_model": "sentence-transformers/all-mpnet-base-v2",
|
|
||||||
"num_chunks": 6,
|
|
||||||
"chunks": [
|
|
||||||
{
|
|
||||||
"text": "Python is a powerful programming language",
|
|
||||||
"metadata": {}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "Machine learning transforms industries",
|
|
||||||
"metadata": {}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "Neural networks process complex data",
|
|
||||||
"metadata": {}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "Java is a powerful programming language",
|
|
||||||
"metadata": {}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "C++ is a powerful programming language",
|
|
||||||
"metadata": {}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "C# is a powerful programming language",
|
|
||||||
"metadata": {}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
1
packages/__init__.py
Normal file
1
packages/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
# packages/leann-backend-diskann/CMakeLists.txt (最终简化版)
|
# packages/leann-backend-diskann/CMakeLists.txt (simplified version)
|
||||||
|
|
||||||
cmake_minimum_required(VERSION 3.20)
|
cmake_minimum_required(VERSION 3.20)
|
||||||
project(leann_backend_diskann_wrapper)
|
project(leann_backend_diskann_wrapper)
|
||||||
|
|
||||||
# 告诉 CMake 直接进入 DiskANN 子模块并执行它自己的 CMakeLists.txt
|
# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt
|
||||||
# DiskANN 会自己处理所有事情,包括编译 Python 绑定
|
# DiskANN will handle everything itself, including compiling Python bindings
|
||||||
add_subdirectory(src/third_party/DiskANN)
|
add_subdirectory(src/third_party/DiskANN)
|
||||||
|
|||||||
1
packages/leann-backend-diskann/__init__.py
Normal file
1
packages/leann-backend-diskann/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# This file makes the directory a Python package
|
||||||
@@ -1,7 +1 @@
|
|||||||
print("Initializing leann-backend-diskann...")
|
from . import diskann_backend
|
||||||
|
|
||||||
try:
|
|
||||||
from .diskann_backend import DiskannBackend
|
|
||||||
print("INFO: DiskANN backend loaded successfully")
|
|
||||||
except ImportError as e:
|
|
||||||
print(f"WARNING: Could not import DiskANN backend: {e}")
|
|
||||||
@@ -1,30 +1,29 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
import json
|
|
||||||
import struct
|
import struct
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from typing import Dict, Any, List, Literal
|
||||||
import contextlib
|
import contextlib
|
||||||
import threading
|
import pickle
|
||||||
import time
|
|
||||||
import atexit
|
|
||||||
import socket
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
|
|
||||||
|
from leann.searcher_base import BaseSearcher
|
||||||
from leann.registry import register_backend
|
from leann.registry import register_backend
|
||||||
from leann.interface import (
|
from leann.interface import (
|
||||||
LeannBackendFactoryInterface,
|
LeannBackendFactoryInterface,
|
||||||
LeannBackendBuilderInterface,
|
LeannBackendBuilderInterface,
|
||||||
LeannBackendSearcherInterface
|
LeannBackendSearcherInterface,
|
||||||
)
|
)
|
||||||
from . import _diskannpy as diskannpy
|
|
||||||
|
|
||||||
METRIC_MAP = {
|
|
||||||
"mips": diskannpy.Metric.INNER_PRODUCT,
|
def _get_diskann_metrics():
|
||||||
"l2": diskannpy.Metric.L2,
|
from . import _diskannpy as diskannpy # type: ignore
|
||||||
"cosine": diskannpy.Metric.COSINE,
|
|
||||||
}
|
return {
|
||||||
|
"mips": diskannpy.Metric.INNER_PRODUCT,
|
||||||
|
"l2": diskannpy.Metric.L2,
|
||||||
|
"cosine": diskannpy.Metric.COSINE,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def chdir(path):
|
def chdir(path):
|
||||||
@@ -35,102 +34,14 @@ def chdir(path):
|
|||||||
finally:
|
finally:
|
||||||
os.chdir(original_dir)
|
os.chdir(original_dir)
|
||||||
|
|
||||||
def _write_vectors_to_bin(data: np.ndarray, file_path: str):
|
|
||||||
|
def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
|
||||||
num_vectors, dim = data.shape
|
num_vectors, dim = data.shape
|
||||||
with open(file_path, 'wb') as f:
|
with open(file_path, "wb") as f:
|
||||||
f.write(struct.pack('I', num_vectors))
|
f.write(struct.pack("I", num_vectors))
|
||||||
f.write(struct.pack('I', dim))
|
f.write(struct.pack("I", dim))
|
||||||
f.write(data.tobytes())
|
f.write(data.tobytes())
|
||||||
|
|
||||||
def _check_port(port: int) -> bool:
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
||||||
return s.connect_ex(('localhost', port)) == 0
|
|
||||||
|
|
||||||
class EmbeddingServerManager:
|
|
||||||
def __init__(self):
|
|
||||||
self.server_process = None
|
|
||||||
self.server_port = None
|
|
||||||
atexit.register(self.stop_server)
|
|
||||||
|
|
||||||
def start_server(self, port=5555, model_name="sentence-transformers/all-mpnet-base-v2"):
|
|
||||||
if self.server_process and self.server_process.poll() is None:
|
|
||||||
print(f"INFO: Reusing existing server process for this session (PID {self.server_process.pid})")
|
|
||||||
return True
|
|
||||||
|
|
||||||
# 检查端口是否已被其他无关进程占用
|
|
||||||
if _check_port(port):
|
|
||||||
print(f"WARNING: Port {port} is already in use. Assuming an external server is running and connecting to it.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
print(f"INFO: Starting session-level embedding server as a background process...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
command = [
|
|
||||||
sys.executable,
|
|
||||||
"-m", "packages.leann-backend-diskann.leann_backend_diskann.embedding_server",
|
|
||||||
"--zmq-port", str(port),
|
|
||||||
"--model-name", model_name
|
|
||||||
]
|
|
||||||
project_root = Path(__file__).parent.parent.parent.parent
|
|
||||||
print(f"INFO: Running command from project root: {project_root}")
|
|
||||||
self.server_process = subprocess.Popen(
|
|
||||||
command,
|
|
||||||
cwd=project_root,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.PIPE,
|
|
||||||
text=True,
|
|
||||||
encoding='utf-8'
|
|
||||||
)
|
|
||||||
self.server_port = port
|
|
||||||
print(f"INFO: Server process started with PID: {self.server_process.pid}")
|
|
||||||
|
|
||||||
max_wait, wait_interval = 30, 0.5
|
|
||||||
for _ in range(int(max_wait / wait_interval)):
|
|
||||||
if _check_port(port):
|
|
||||||
print(f"✅ Embedding server is up and ready for this session.")
|
|
||||||
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
|
|
||||||
log_thread.start()
|
|
||||||
return True
|
|
||||||
if self.server_process.poll() is not None:
|
|
||||||
print("❌ ERROR: Server process terminated unexpectedly during startup.")
|
|
||||||
self._log_monitor()
|
|
||||||
return False
|
|
||||||
time.sleep(wait_interval)
|
|
||||||
|
|
||||||
print(f"❌ ERROR: Server process failed to start listening within {max_wait} seconds.")
|
|
||||||
self.stop_server()
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ ERROR: Failed to start embedding server process: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _log_monitor(self):
|
|
||||||
if not self.server_process:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
if self.server_process.stdout:
|
|
||||||
for line in iter(self.server_process.stdout.readline, ''):
|
|
||||||
print(f"[EmbeddingServer LOG]: {line.strip()}")
|
|
||||||
self.server_process.stdout.close()
|
|
||||||
if self.server_process.stderr:
|
|
||||||
for line in iter(self.server_process.stderr.readline, ''):
|
|
||||||
print(f"[EmbeddingServer ERROR]: {line.strip()}")
|
|
||||||
self.server_process.stderr.close()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Log monitor error: {e}")
|
|
||||||
|
|
||||||
def stop_server(self):
|
|
||||||
if self.server_process and self.server_process.poll() is None:
|
|
||||||
print(f"INFO: Terminating session server process (PID: {self.server_process.pid})...")
|
|
||||||
self.server_process.terminate()
|
|
||||||
try:
|
|
||||||
self.server_process.wait(timeout=5)
|
|
||||||
print("INFO: Server process terminated.")
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
print("WARNING: Server process did not terminate gracefully, killing it.")
|
|
||||||
self.server_process.kill()
|
|
||||||
self.server_process = None
|
|
||||||
|
|
||||||
@register_backend("diskann")
|
@register_backend("diskann")
|
||||||
class DiskannBackend(LeannBackendFactoryInterface):
|
class DiskannBackend(LeannBackendFactoryInterface):
|
||||||
@@ -140,160 +51,164 @@ class DiskannBackend(LeannBackendFactoryInterface):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
|
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
|
||||||
path = Path(index_path)
|
|
||||||
meta_path = path.parent / f"{path.name}.meta.json"
|
|
||||||
if not meta_path.exists():
|
|
||||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
|
|
||||||
with open(meta_path, 'r') as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
model = SentenceTransformer(meta.get("embedding_model"))
|
|
||||||
dimensions = model.get_sentence_embedding_dimension()
|
|
||||||
kwargs['dimensions'] = dimensions
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("sentence-transformers is required to infer embedding dimensions. Please install it.")
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Could not load SentenceTransformer model to get dimension: {e}")
|
|
||||||
|
|
||||||
return DiskannSearcher(index_path, **kwargs)
|
return DiskannSearcher(index_path, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class DiskannBuilder(LeannBackendBuilderInterface):
|
class DiskannBuilder(LeannBackendBuilderInterface):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.build_params = kwargs
|
self.build_params = kwargs
|
||||||
|
|
||||||
def build(self, data: np.ndarray, index_path: str, **kwargs):
|
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
|
||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
index_dir = path.parent
|
index_dir = path.parent
|
||||||
index_prefix = path.stem
|
index_prefix = path.stem
|
||||||
|
|
||||||
index_dir.mkdir(parents=True, exist_ok=True)
|
index_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
if data.dtype != np.float32:
|
if data.dtype != np.float32:
|
||||||
data = data.astype(np.float32)
|
data = data.astype(np.float32)
|
||||||
if not data.flags['C_CONTIGUOUS']:
|
|
||||||
data = np.ascontiguousarray(data)
|
|
||||||
|
|
||||||
data_filename = f"{index_prefix}_data.bin"
|
data_filename = f"{index_prefix}_data.bin"
|
||||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||||
|
|
||||||
|
|
||||||
build_kwargs = {**self.build_params, **kwargs}
|
build_kwargs = {**self.build_params, **kwargs}
|
||||||
metric_str = build_kwargs.get("distance_metric", "mips").lower()
|
metric_enum = _get_diskann_metrics().get(
|
||||||
metric_enum = METRIC_MAP.get(metric_str)
|
build_kwargs.get("distance_metric", "mips").lower()
|
||||||
|
)
|
||||||
if metric_enum is None:
|
if metric_enum is None:
|
||||||
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
|
raise ValueError("Unsupported distance_metric.")
|
||||||
|
|
||||||
complexity = build_kwargs.get("complexity", 64)
|
|
||||||
graph_degree = build_kwargs.get("graph_degree", 32)
|
|
||||||
final_index_ram_limit = build_kwargs.get("search_memory_maximum", 4.0)
|
|
||||||
indexing_ram_budget = build_kwargs.get("build_memory_maximum", 8.0)
|
|
||||||
num_threads = build_kwargs.get("num_threads", 8)
|
|
||||||
pq_disk_bytes = build_kwargs.get("pq_disk_bytes", 0)
|
|
||||||
codebook_prefix = ""
|
|
||||||
|
|
||||||
print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from . import _diskannpy as diskannpy # type: ignore
|
||||||
|
|
||||||
with chdir(index_dir):
|
with chdir(index_dir):
|
||||||
diskannpy.build_disk_float_index(
|
diskannpy.build_disk_float_index(
|
||||||
metric_enum,
|
metric_enum,
|
||||||
data_filename,
|
data_filename,
|
||||||
index_prefix,
|
index_prefix,
|
||||||
complexity,
|
build_kwargs.get("complexity", 64),
|
||||||
graph_degree,
|
build_kwargs.get("graph_degree", 32),
|
||||||
final_index_ram_limit,
|
build_kwargs.get("search_memory_maximum", 4.0),
|
||||||
indexing_ram_budget,
|
build_kwargs.get("build_memory_maximum", 8.0),
|
||||||
num_threads,
|
build_kwargs.get("num_threads", 8),
|
||||||
pq_disk_bytes,
|
build_kwargs.get("pq_disk_bytes", 0),
|
||||||
codebook_prefix
|
"",
|
||||||
)
|
)
|
||||||
print(f"✅ DiskANN index built successfully at '{index_dir / index_prefix}'")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"💥 ERROR: DiskANN index build failed. Exception: {e}")
|
|
||||||
raise
|
|
||||||
finally:
|
finally:
|
||||||
temp_data_file = index_dir / data_filename
|
temp_data_file = index_dir / data_filename
|
||||||
if temp_data_file.exists():
|
if temp_data_file.exists():
|
||||||
os.remove(temp_data_file)
|
os.remove(temp_data_file)
|
||||||
|
|
||||||
class DiskannSearcher(LeannBackendSearcherInterface):
|
|
||||||
|
class DiskannSearcher(BaseSearcher):
|
||||||
def __init__(self, index_path: str, **kwargs):
|
def __init__(self, index_path: str, **kwargs):
|
||||||
path = Path(index_path)
|
super().__init__(
|
||||||
index_dir = path.parent
|
index_path,
|
||||||
index_prefix = path.stem
|
backend_module_name="leann_backend_diskann.embedding_server",
|
||||||
metric_str = kwargs.get("distance_metric", "mips").lower()
|
**kwargs,
|
||||||
metric_enum = METRIC_MAP.get(metric_str)
|
)
|
||||||
|
from . import _diskannpy as diskannpy # type: ignore
|
||||||
|
|
||||||
|
distance_metric = kwargs.get("distance_metric", "mips").lower()
|
||||||
|
metric_enum = _get_diskann_metrics().get(distance_metric)
|
||||||
if metric_enum is None:
|
if metric_enum is None:
|
||||||
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
|
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
|
||||||
|
|
||||||
num_threads = kwargs.get("num_threads", 8)
|
self.num_threads = kwargs.get("num_threads", 8)
|
||||||
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
|
self.zmq_port = kwargs.get("zmq_port", 6666)
|
||||||
dimensions = kwargs.get("dimensions")
|
|
||||||
if not dimensions:
|
|
||||||
raise ValueError("Vector dimension not provided to DiskannSearcher.")
|
|
||||||
|
|
||||||
try:
|
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
||||||
full_index_prefix = str(index_dir / index_prefix)
|
self._index = diskannpy.StaticDiskFloatIndex(
|
||||||
self._index = diskannpy.StaticDiskFloatIndex(
|
metric_enum,
|
||||||
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, "", ""
|
full_index_prefix,
|
||||||
|
self.num_threads,
|
||||||
|
kwargs.get("num_nodes_to_cache", 0),
|
||||||
|
1,
|
||||||
|
self.zmq_port,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
query: np.ndarray,
|
||||||
|
top_k: int,
|
||||||
|
complexity: int = 64,
|
||||||
|
beam_width: int = 1,
|
||||||
|
prune_ratio: float = 0.0,
|
||||||
|
recompute_embeddings: bool = False,
|
||||||
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
|
zmq_port: int = 5557,
|
||||||
|
batch_recompute: bool = False,
|
||||||
|
dedup_node_dis: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Search for nearest neighbors using DiskANN index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Query vectors (B, D) where B is batch size, D is dimension
|
||||||
|
top_k: Number of nearest neighbors to return
|
||||||
|
complexity: Search complexity/candidate list size, higher = more accurate but slower
|
||||||
|
beam_width: Number of parallel IO requests per iteration
|
||||||
|
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||||
|
recompute_embeddings: Whether to fetch fresh embeddings from server
|
||||||
|
pruning_strategy: PQ candidate selection strategy:
|
||||||
|
- "global": Use global pruning strategy (default)
|
||||||
|
- "local": Use local pruning strategy
|
||||||
|
- "proportional": Not supported in DiskANN, falls back to global
|
||||||
|
zmq_port: ZMQ port for embedding server
|
||||||
|
batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
|
||||||
|
dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
|
||||||
|
**kwargs: Additional DiskANN-specific parameters (for legacy compatibility)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with 'labels' (list of lists) and 'distances' (ndarray)
|
||||||
|
"""
|
||||||
|
# DiskANN doesn't support "proportional" strategy
|
||||||
|
if pruning_strategy == "proportional":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead."
|
||||||
)
|
)
|
||||||
self.num_threads = num_threads
|
|
||||||
self.embedding_server_manager = EmbeddingServerManager()
|
|
||||||
print("✅ DiskANN index loaded successfully.")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]:
|
# Use recompute_embeddings parameter
|
||||||
complexity = kwargs.get("complexity", 100)
|
use_recompute = recompute_embeddings
|
||||||
beam_width = kwargs.get("beam_width", 4)
|
if use_recompute:
|
||||||
|
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
|
||||||
USE_DEFERRED_FETCH = kwargs.get("USE_DEFERRED_FETCH", False)
|
if not meta_file_path.exists():
|
||||||
skip_search_reorder = kwargs.get("skip_search_reorder", False)
|
raise RuntimeError(
|
||||||
recompute_beighbor_embeddings = kwargs.get("recompute_beighbor_embeddings", False)
|
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
|
||||||
dedup_node_dis = kwargs.get("dedup_node_dis", False)
|
)
|
||||||
prune_ratio = kwargs.get("prune_ratio", 0.0)
|
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
|
||||||
batch_recompute = kwargs.get("batch_recompute", False)
|
|
||||||
global_pruning = kwargs.get("global_pruning", False)
|
|
||||||
|
|
||||||
if recompute_beighbor_embeddings:
|
|
||||||
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
|
|
||||||
zmq_port = kwargs.get("zmq_port", 5555)
|
|
||||||
embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
|
|
||||||
|
|
||||||
if not self.embedding_server_manager.start_server(zmq_port, embedding_model):
|
|
||||||
print(f"WARNING: Failed to start embedding server, falling back to PQ computation")
|
|
||||||
kwargs['recompute_beighbor_embeddings'] = False
|
|
||||||
|
|
||||||
if query.dtype != np.float32:
|
if query.dtype != np.float32:
|
||||||
query = query.astype(np.float32)
|
query = query.astype(np.float32)
|
||||||
if query.ndim == 1:
|
|
||||||
query = np.expand_dims(query, axis=0)
|
|
||||||
|
|
||||||
try:
|
# Map pruning_strategy to DiskANN's global_pruning parameter
|
||||||
labels, distances = self._index.batch_search(
|
if pruning_strategy == "local":
|
||||||
query,
|
use_global_pruning = False
|
||||||
query.shape[0],
|
else: # "global"
|
||||||
top_k,
|
use_global_pruning = True
|
||||||
complexity,
|
|
||||||
beam_width,
|
|
||||||
self.num_threads,
|
|
||||||
USE_DEFERRED_FETCH,
|
|
||||||
skip_search_reorder,
|
|
||||||
recompute_beighbor_embeddings,
|
|
||||||
dedup_node_dis,
|
|
||||||
prune_ratio,
|
|
||||||
batch_recompute,
|
|
||||||
global_pruning
|
|
||||||
)
|
|
||||||
return {"labels": labels, "distances": distances}
|
|
||||||
except Exception as e:
|
|
||||||
print(f"💥 ERROR: DiskANN search failed. Exception: {e}")
|
|
||||||
batch_size = query.shape[0]
|
|
||||||
return {"labels": np.full((batch_size, top_k), -1, dtype=np.int64),
|
|
||||||
"distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}
|
|
||||||
|
|
||||||
def __del__(self):
|
labels, distances = self._index.batch_search(
|
||||||
if hasattr(self, 'embedding_server_manager'):
|
query,
|
||||||
self.embedding_server_manager.stop_server()
|
query.shape[0],
|
||||||
|
top_k,
|
||||||
|
complexity,
|
||||||
|
beam_width,
|
||||||
|
self.num_threads,
|
||||||
|
kwargs.get("USE_DEFERRED_FETCH", False),
|
||||||
|
kwargs.get("skip_search_reorder", False),
|
||||||
|
use_recompute,
|
||||||
|
dedup_node_dis,
|
||||||
|
prune_ratio,
|
||||||
|
batch_recompute,
|
||||||
|
use_global_pruning,
|
||||||
|
)
|
||||||
|
|
||||||
|
string_labels = [
|
||||||
|
[str(int_label) for int_label in batch_labels]
|
||||||
|
for batch_labels in labels
|
||||||
|
]
|
||||||
|
|
||||||
|
return {"labels": string_labels, "distances": distances}
|
||||||
|
|||||||
@@ -5,70 +5,147 @@ Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern
|
|||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
import argparse
|
import argparse
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
|
import json
|
||||||
|
from typing import Dict, Any, Optional, Union
|
||||||
|
|
||||||
from transformers import AutoTokenizer, AutoModel
|
from transformers import AutoTokenizer, AutoModel
|
||||||
import os
|
import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
import zmq
|
import zmq
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import msgpack
|
||||||
|
from pathlib import Path
|
||||||
|
import logging
|
||||||
|
|
||||||
RED = "\033[91m"
|
RED = "\033[91m"
|
||||||
|
|
||||||
|
# Set up logging based on environment variable
|
||||||
|
LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
|
||||||
|
logging.basicConfig(
|
||||||
|
level=getattr(logging, LOG_LEVEL, logging.INFO),
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
RESET = "\033[0m"
|
RESET = "\033[0m"
|
||||||
|
|
||||||
# 简化的文档存储 - 替代 LazyPassages
|
# --- New Passage Loader from HNSW backend ---
|
||||||
class SimpleDocumentStore:
|
class SimplePassageLoader:
|
||||||
"""简化的文档存储,支持任意ID"""
|
"""
|
||||||
def __init__(self, documents: dict = None):
|
Simple passage loader that replaces config.py dependencies
|
||||||
self.documents = documents or {}
|
"""
|
||||||
# 默认演示文档
|
def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
|
||||||
self.default_docs = {
|
self.passages_data = passages_data or {}
|
||||||
0: "Python is a high-level, interpreted language known for simplicity.",
|
self._meta_path = ''
|
||||||
1: "Machine learning builds systems that learn from data.",
|
|
||||||
2: "Data structures like arrays, lists, and graphs organize data.",
|
|
||||||
}
|
|
||||||
|
|
||||||
def __getitem__(self, doc_id):
|
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
|
||||||
doc_id = int(doc_id)
|
"""Get passage by ID"""
|
||||||
|
str_id = str(passage_id)
|
||||||
|
if str_id in self.passages_data:
|
||||||
|
return {"text": self.passages_data[str_id]}
|
||||||
|
else:
|
||||||
|
# Return empty text for missing passages
|
||||||
|
return {"text": ""}
|
||||||
|
|
||||||
# 优先使用指定的文档
|
def __len__(self) -> int:
|
||||||
if doc_id in self.documents:
|
return len(self.passages_data)
|
||||||
return {"text": self.documents[doc_id]}
|
|
||||||
|
|
||||||
# 其次使用默认演示文档
|
def keys(self):
|
||||||
if doc_id in self.default_docs:
|
return self.passages_data.keys()
|
||||||
return {"text": self.default_docs[doc_id]}
|
|
||||||
|
|
||||||
# 对于任意其他ID,返回通用文档
|
def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
|
||||||
fallback_docs = [
|
"""
|
||||||
"This is a general document about technology and programming concepts.",
|
Load passages using metadata file with PassageManager for lazy loading
|
||||||
"This document discusses machine learning and artificial intelligence topics.",
|
"""
|
||||||
"This content covers data structures, algorithms, and computer science fundamentals.",
|
# Load metadata to get passage sources
|
||||||
"This is a document about software engineering and development practices.",
|
with open(meta_file, 'r') as f:
|
||||||
"This content focuses on databases, data management, and information systems."
|
meta = json.load(f)
|
||||||
]
|
|
||||||
|
|
||||||
# 根据ID选择一个fallback文档
|
# Import PassageManager dynamically to avoid circular imports
|
||||||
fallback_text = fallback_docs[doc_id % len(fallback_docs)]
|
import sys
|
||||||
return {"text": f"[ID:{doc_id}] {fallback_text}"}
|
from pathlib import Path
|
||||||
|
|
||||||
def __len__(self):
|
# Find the leann package directory relative to this file
|
||||||
return len(self.documents) + len(self.default_docs)
|
current_dir = Path(__file__).parent
|
||||||
|
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
|
||||||
|
sys.path.insert(0, str(leann_core_path))
|
||||||
|
|
||||||
|
try:
|
||||||
|
from leann.api import PassageManager
|
||||||
|
passage_manager = PassageManager(meta['passage_sources'])
|
||||||
|
finally:
|
||||||
|
sys.path.pop(0)
|
||||||
|
|
||||||
|
print(f"Initialized lazy passage loading for {len(passage_manager.global_offset_map)} passages")
|
||||||
|
|
||||||
|
class LazyPassageLoader(SimplePassageLoader):
|
||||||
|
def __init__(self, passage_manager):
|
||||||
|
self.passage_manager = passage_manager
|
||||||
|
# Initialize parent with empty data
|
||||||
|
super().__init__({})
|
||||||
|
|
||||||
|
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
|
||||||
|
"""Get passage by ID with lazy loading"""
|
||||||
|
try:
|
||||||
|
int_id = int(passage_id)
|
||||||
|
string_id = str(int_id)
|
||||||
|
passage_data = self.passage_manager.get_passage(string_id)
|
||||||
|
if passage_data and passage_data.get("text"):
|
||||||
|
return {"text": passage_data["text"]}
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.passage_manager.global_offset_map)
|
||||||
|
|
||||||
|
def keys(self):
|
||||||
|
return self.passage_manager.global_offset_map.keys()
|
||||||
|
|
||||||
|
loader = LazyPassageLoader(passage_manager)
|
||||||
|
loader._meta_path = meta_file
|
||||||
|
return loader
|
||||||
|
|
||||||
|
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
|
||||||
|
"""
|
||||||
|
Load passages from a JSONL file with label map support
|
||||||
|
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not os.path.exists(passages_file):
|
||||||
|
raise FileNotFoundError(f"Passages file {passages_file} not found.")
|
||||||
|
|
||||||
|
if not passages_file.endswith('.jsonl'):
|
||||||
|
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
|
||||||
|
|
||||||
|
# Load passages directly by their sequential IDs
|
||||||
|
passages_data = {}
|
||||||
|
with open(passages_file, 'r', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
passage = json.loads(line)
|
||||||
|
passages_data[passage['id']] = passage['text']
|
||||||
|
|
||||||
|
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file}")
|
||||||
|
return SimplePassageLoader(passages_data)
|
||||||
|
|
||||||
def create_embedding_server_thread(
|
def create_embedding_server_thread(
|
||||||
zmq_port=5555,
|
zmq_port=5555,
|
||||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
model_name="sentence-transformers/all-mpnet-base-v2",
|
||||||
max_batch_size=128,
|
max_batch_size=128,
|
||||||
|
passages_file: Optional[str] = None,
|
||||||
|
embedding_mode: str = "sentence-transformers",
|
||||||
|
enable_warmup: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
在当前线程中创建并运行 embedding server
|
Create and run embedding server in the current thread
|
||||||
这个函数设计为在单独的线程中调用
|
This function is designed to be called in a separate thread
|
||||||
"""
|
"""
|
||||||
print(f"INFO: Initializing embedding server thread on port {zmq_port}")
|
logger.info(f"Initializing embedding server thread on port {zmq_port}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 检查端口是否已被占用
|
# Check if port is already occupied
|
||||||
import socket
|
import socket
|
||||||
def check_port(port):
|
def check_port(port):
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
@@ -78,56 +155,147 @@ def create_embedding_server_thread(
|
|||||||
print(f"{RED}Port {zmq_port} is already in use{RESET}")
|
print(f"{RED}Port {zmq_port} is already in use{RESET}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 初始化模型
|
# Auto-detect mode based on model name if not explicitly set
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
if embedding_mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
|
||||||
import torch
|
embedding_mode = "openai"
|
||||||
|
|
||||||
# 选择设备
|
if embedding_mode == "mlx":
|
||||||
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
|
from leann.api import compute_embeddings_mlx
|
||||||
cuda_available = torch.cuda.is_available()
|
import torch
|
||||||
|
logger.info("Using MLX for embeddings")
|
||||||
if cuda_available:
|
# Set device to CPU for compatibility with DeviceTimer class
|
||||||
device = torch.device("cuda")
|
|
||||||
print("INFO: Using CUDA device")
|
|
||||||
elif mps_available:
|
|
||||||
device = torch.device("mps")
|
|
||||||
print("INFO: Using MPS device (Apple Silicon)")
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
print("INFO: Using CPU device")
|
cuda_available = False
|
||||||
|
mps_available = False
|
||||||
|
elif embedding_mode == "openai":
|
||||||
|
from leann.api import compute_embeddings_openai
|
||||||
|
import torch
|
||||||
|
logger.info("Using OpenAI API for embeddings")
|
||||||
|
# Set device to CPU for compatibility with DeviceTimer class
|
||||||
|
device = torch.device("cpu")
|
||||||
|
cuda_available = False
|
||||||
|
mps_available = False
|
||||||
|
elif embedding_mode == "sentence-transformers":
|
||||||
|
# Initialize model
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||||
|
import torch
|
||||||
|
|
||||||
# 加载模型
|
# Select device
|
||||||
print(f"INFO: Loading model {model_name}")
|
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
|
||||||
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
cuda_available = torch.cuda.is_available()
|
||||||
|
|
||||||
|
if cuda_available:
|
||||||
|
device = torch.device("cuda")
|
||||||
|
logger.info("Using CUDA device")
|
||||||
|
elif mps_available:
|
||||||
|
device = torch.device("mps")
|
||||||
|
logger.info("Using MPS device (Apple Silicon)")
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
logger.info("Using CPU device")
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
logger.info(f"Loading model {model_name}")
|
||||||
|
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
||||||
|
|
||||||
|
# Optimize model
|
||||||
|
if cuda_available or mps_available:
|
||||||
|
try:
|
||||||
|
model = model.half()
|
||||||
|
model = torch.compile(model)
|
||||||
|
logger.info(f"Using FP16 precision with model: {model_name}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"WARNING: Model optimization failed: {e}")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported embedding mode: {embedding_mode}. Supported modes: sentence-transformers, mlx, openai")
|
||||||
|
|
||||||
|
# Load passages from file if provided
|
||||||
|
if passages_file and os.path.exists(passages_file):
|
||||||
|
# Check if it's a metadata file or a single passages file
|
||||||
|
if passages_file.endswith('.meta.json'):
|
||||||
|
passages = load_passages_from_metadata(passages_file)
|
||||||
|
else:
|
||||||
|
# Try to find metadata file in same directory
|
||||||
|
passages_dir = Path(passages_file).parent
|
||||||
|
meta_files = list(passages_dir.glob("*.meta.json"))
|
||||||
|
if meta_files:
|
||||||
|
print(f"Found metadata file: {meta_files[0]}, using lazy loading")
|
||||||
|
passages = load_passages_from_metadata(str(meta_files[0]))
|
||||||
|
else:
|
||||||
|
# Fallback to original single file loading (will cause warnings)
|
||||||
|
print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
|
||||||
|
passages = load_passages_from_file(passages_file)
|
||||||
|
else:
|
||||||
|
print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
|
||||||
|
passages = SimplePassageLoader()
|
||||||
|
|
||||||
|
logger.info(f"Loaded {len(passages)} passages.")
|
||||||
|
|
||||||
|
def client_warmup(zmq_port):
|
||||||
|
"""Perform client-side warmup for DiskANN server"""
|
||||||
|
time.sleep(2)
|
||||||
|
print(f"Performing client-side warmup with model {model_name}...")
|
||||||
|
|
||||||
|
# Get actual passage IDs from the loaded passages
|
||||||
|
sample_ids = []
|
||||||
|
if hasattr(passages, 'keys') and len(passages) > 0:
|
||||||
|
available_ids = list(passages.keys())
|
||||||
|
# Take up to 5 actual IDs, but at least 1
|
||||||
|
sample_ids = available_ids[:min(5, len(available_ids))]
|
||||||
|
print(f"Using actual passage IDs for warmup: {sample_ids}")
|
||||||
|
else:
|
||||||
|
print("No passages available for warmup, skipping warmup...")
|
||||||
|
return
|
||||||
|
|
||||||
# 优化模型
|
|
||||||
if cuda_available or mps_available:
|
|
||||||
try:
|
try:
|
||||||
model = model.half()
|
context = zmq.Context()
|
||||||
model = torch.compile(model)
|
socket = context.socket(zmq.REQ)
|
||||||
print(f"INFO: Using FP16 precision with model: {model_name}")
|
socket.connect(f"tcp://localhost:{zmq_port}")
|
||||||
|
socket.setsockopt(zmq.RCVTIMEO, 30000)
|
||||||
|
socket.setsockopt(zmq.SNDTIMEO, 30000)
|
||||||
|
|
||||||
|
try:
|
||||||
|
ids_to_send = [int(x) for x in sample_ids]
|
||||||
|
except ValueError:
|
||||||
|
print("Warning: Could not convert sample IDs to integers, skipping warmup")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not ids_to_send:
|
||||||
|
print("Skipping warmup send.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Use protobuf format for warmup
|
||||||
|
from . import embedding_pb2
|
||||||
|
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||||
|
req_proto.node_ids.extend(ids_to_send)
|
||||||
|
request_bytes = req_proto.SerializeToString()
|
||||||
|
|
||||||
|
for i in range(3):
|
||||||
|
print(f"Sending warmup request {i + 1}/3 via ZMQ (Protobuf)...")
|
||||||
|
socket.send(request_bytes)
|
||||||
|
response_bytes = socket.recv()
|
||||||
|
|
||||||
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
|
resp_proto.ParseFromString(response_bytes)
|
||||||
|
embeddings_count = resp_proto.dimensions[0] if resp_proto.dimensions else 0
|
||||||
|
print(f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings")
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
print("Client-side Protobuf ZMQ warmup complete")
|
||||||
|
socket.close()
|
||||||
|
context.term()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"WARNING: Model optimization failed: {e}")
|
print(f"Error during Protobuf ZMQ warmup: {e}")
|
||||||
|
|
||||||
# 默认演示文档
|
|
||||||
demo_documents = {
|
|
||||||
0: "Python is a high-level, interpreted language known for simplicity.",
|
|
||||||
1: "Machine learning builds systems that learn from data.",
|
|
||||||
2: "Data structures like arrays, lists, and graphs organize data.",
|
|
||||||
}
|
|
||||||
|
|
||||||
passages = SimpleDocumentStore(demo_documents)
|
|
||||||
print(f"INFO: Loaded {len(passages)} demo documents")
|
|
||||||
|
|
||||||
class DeviceTimer:
|
class DeviceTimer:
|
||||||
"""设备计时器"""
|
"""Device timer"""
|
||||||
def __init__(self, name="", device=device):
|
def __init__(self, name="", device=device):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.device = device
|
self.device = device
|
||||||
self.start_time = 0
|
self.start_time = 0
|
||||||
self.end_time = 0
|
self.end_time = 0
|
||||||
|
|
||||||
if cuda_available:
|
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
self.start_event = torch.cuda.Event(enable_timing=True)
|
||||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
self.end_event = torch.cuda.Event(enable_timing=True)
|
||||||
else:
|
else:
|
||||||
@@ -141,136 +309,249 @@ def create_embedding_server_thread(
|
|||||||
self.end()
|
self.end()
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
if cuda_available:
|
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
self.start_event.record()
|
self.start_event.record()
|
||||||
else:
|
else:
|
||||||
if self.device.type == "mps":
|
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
|
||||||
torch.mps.synchronize()
|
torch.mps.synchronize()
|
||||||
self.start_time = time.time()
|
self.start_time = time.time()
|
||||||
|
|
||||||
def end(self):
|
def end(self):
|
||||||
if cuda_available:
|
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||||
self.end_event.record()
|
self.end_event.record()
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
else:
|
else:
|
||||||
if self.device.type == "mps":
|
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
|
||||||
torch.mps.synchronize()
|
torch.mps.synchronize()
|
||||||
self.end_time = time.time()
|
self.end_time = time.time()
|
||||||
|
|
||||||
def elapsed_time(self):
|
def elapsed_time(self):
|
||||||
if cuda_available:
|
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||||
return self.start_event.elapsed_time(self.end_event) / 1000.0
|
return self.start_event.elapsed_time(self.end_event) / 1000.0
|
||||||
else:
|
else:
|
||||||
return self.end_time - self.start_time
|
return self.end_time - self.start_time
|
||||||
|
|
||||||
def print_elapsed(self):
|
def print_elapsed(self):
|
||||||
print(f"Time taken for {self.name}: {self.elapsed_time():.6f} seconds")
|
elapsed = self.elapsed_time()
|
||||||
|
print(f"[{self.name}] Elapsed time: {elapsed:.3f}s")
|
||||||
|
|
||||||
def process_batch(texts_batch, ids_batch, missing_ids):
|
def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
|
||||||
"""处理文本批次"""
|
"""Process text batch"""
|
||||||
batch_size = len(texts_batch)
|
if not texts_batch:
|
||||||
print(f"INFO: Processing batch of size {batch_size}")
|
return np.array([])
|
||||||
|
|
||||||
tokenize_timer = DeviceTimer("tokenization (batch)", device)
|
# Filter out empty texts and their corresponding IDs
|
||||||
to_device_timer = DeviceTimer("transfer to device (batch)", device)
|
valid_texts = []
|
||||||
embed_timer = DeviceTimer("embedding (batch)", device)
|
valid_ids = []
|
||||||
pool_timer = DeviceTimer("mean pooling (batch)", device)
|
for i, text in enumerate(texts_batch):
|
||||||
|
if text.strip(): # Only include non-empty texts
|
||||||
|
valid_texts.append(text)
|
||||||
|
valid_ids.append(ids_batch[i])
|
||||||
|
|
||||||
with tokenize_timer.timing():
|
if not valid_texts:
|
||||||
encoded_batch = tokenizer.batch_encode_plus(
|
print("WARNING: No valid texts in batch")
|
||||||
texts_batch,
|
return np.array([])
|
||||||
padding="max_length",
|
|
||||||
|
# Tokenize
|
||||||
|
token_timer = DeviceTimer("tokenization")
|
||||||
|
with token_timer.timing():
|
||||||
|
inputs = tokenizer(
|
||||||
|
valid_texts,
|
||||||
|
padding=True,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=256,
|
max_length=512,
|
||||||
return_tensors="pt",
|
return_tensors="pt"
|
||||||
return_token_type_ids=False,
|
).to(device)
|
||||||
)
|
|
||||||
tokenize_timer.print_elapsed()
|
|
||||||
|
|
||||||
seq_length = encoded_batch["input_ids"].size(1)
|
# Compute embeddings
|
||||||
print(f"Batch size: {batch_size}, Sequence length: {seq_length}")
|
embed_timer = DeviceTimer("embedding computation")
|
||||||
|
with embed_timer.timing():
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs)
|
||||||
|
hidden_states = outputs.last_hidden_state
|
||||||
|
|
||||||
with to_device_timer.timing():
|
# Mean pooling
|
||||||
enc = {k: v.to(device) for k, v in encoded_batch.items()}
|
attention_mask = inputs['attention_mask']
|
||||||
to_device_timer.print_elapsed()
|
mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
with embed_timer.timing():
|
|
||||||
out = model(enc["input_ids"], enc["attention_mask"])
|
|
||||||
embed_timer.print_elapsed()
|
|
||||||
|
|
||||||
with pool_timer.timing():
|
|
||||||
hidden_states = out.last_hidden_state if hasattr(out, "last_hidden_state") else out
|
|
||||||
mask_expanded = enc["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float()
|
|
||||||
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
|
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
|
||||||
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
|
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
|
||||||
batch_embeddings = sum_embeddings / sum_mask
|
batch_embeddings = sum_embeddings / sum_mask
|
||||||
pool_timer.print_elapsed()
|
embed_timer.print_elapsed()
|
||||||
|
|
||||||
return batch_embeddings.cpu().numpy()
|
return batch_embeddings.cpu().numpy()
|
||||||
|
|
||||||
# ZMQ server 主循环 - 修改为REP套接字
|
# ZMQ server main loop - modified to use REP socket
|
||||||
context = zmq.Context()
|
context = zmq.Context()
|
||||||
socket = context.socket(zmq.ROUTER) # 改为REP套接字
|
socket = context.socket(zmq.ROUTER) # Changed to REP socket
|
||||||
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
|
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
|
||||||
print(f"INFO: ZMQ ROUTER server listening on port {zmq_port}")
|
print(f"INFO: ZMQ ROUTER server listening on port {zmq_port}")
|
||||||
|
|
||||||
# 设置超时
|
# Set timeouts
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5秒接收超时
|
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second receive timeout
|
||||||
socket.setsockopt(zmq.SNDTIMEO, 300000) # 300秒发送超时
|
socket.setsockopt(zmq.SNDTIMEO, 300000) # 300 second send timeout
|
||||||
|
|
||||||
from . import embedding_pb2
|
from . import embedding_pb2
|
||||||
|
|
||||||
print(f"INFO: Embedding server ready to serve requests")
|
print(f"INFO: Embedding server ready to serve requests")
|
||||||
|
|
||||||
|
# Start warmup thread if enabled
|
||||||
|
if enable_warmup and len(passages) > 0:
|
||||||
|
import threading
|
||||||
|
print(f"Warmup enabled: starting warmup thread")
|
||||||
|
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
|
||||||
|
warmup_thread.daemon = True
|
||||||
|
warmup_thread.start()
|
||||||
|
else:
|
||||||
|
print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
parts = socket.recv_multipart()
|
parts = socket.recv_multipart()
|
||||||
|
|
||||||
# --- 恢复稳健的消息格式判断 ---
|
# --- Restore robust message format detection ---
|
||||||
# 必须检查 parts 的长度,避免 IndexError
|
# Must check parts length to avoid IndexError
|
||||||
if len(parts) >= 3:
|
if len(parts) >= 3:
|
||||||
identity = parts[0]
|
identity = parts[0]
|
||||||
# empty = parts[1] # 中间的空帧我们通常不关心
|
# empty = parts[1] # We usually don't care about the middle empty frame
|
||||||
message = parts[2]
|
message = parts[2]
|
||||||
elif len(parts) == 2:
|
elif len(parts) == 2:
|
||||||
# 也能处理没有空帧的情况
|
# Can also handle cases without empty frame
|
||||||
identity = parts[0]
|
identity = parts[0]
|
||||||
message = parts[1]
|
message = parts[1]
|
||||||
else:
|
else:
|
||||||
# 如果收到格式错误的消息,打印警告并忽略它,而不是崩溃
|
# If received message format is wrong, print warning and ignore it instead of crashing
|
||||||
print(f"WARNING: Received unexpected message format with {len(parts)} parts. Ignoring.")
|
print(f"WARNING: Received unexpected message format with {len(parts)} parts. Ignoring.")
|
||||||
continue
|
continue
|
||||||
print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")
|
print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")
|
||||||
|
|
||||||
e2e_start = time.time()
|
# Handle control messages (MessagePack format)
|
||||||
lookup_timer = DeviceTimer("text lookup", device)
|
try:
|
||||||
|
request_payload = msgpack.unpackb(message)
|
||||||
|
if isinstance(request_payload, list) and len(request_payload) >= 1:
|
||||||
|
if request_payload[0] == "__QUERY_META_PATH__":
|
||||||
|
# Return the current meta path being used by the server
|
||||||
|
current_meta_path = getattr(passages, '_meta_path', '') if hasattr(passages, '_meta_path') else ''
|
||||||
|
response = [current_meta_path]
|
||||||
|
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||||
|
continue
|
||||||
|
|
||||||
# 解析请求
|
elif request_payload[0] == "__UPDATE_META_PATH__" and len(request_payload) >= 2:
|
||||||
|
# Update the server's meta path and reload passages
|
||||||
|
new_meta_path = request_payload[1]
|
||||||
|
try:
|
||||||
|
print(f"INFO: Updating server meta path to: {new_meta_path}")
|
||||||
|
# Reload passages from the new meta file
|
||||||
|
passages = load_passages_from_metadata(new_meta_path)
|
||||||
|
# Store the meta path for future queries
|
||||||
|
passages._meta_path = new_meta_path
|
||||||
|
response = ["SUCCESS"]
|
||||||
|
print(f"INFO: Successfully updated meta path and reloaded {len(passages)} passages")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"ERROR: Failed to update meta path: {e}")
|
||||||
|
response = ["FAILED", str(e)]
|
||||||
|
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif request_payload[0] == "__QUERY_MODEL__":
|
||||||
|
# Return the current model being used by the server
|
||||||
|
response = [model_name]
|
||||||
|
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif request_payload[0] == "__UPDATE_MODEL__" and len(request_payload) >= 2:
|
||||||
|
# Update the server's embedding model
|
||||||
|
new_model_name = request_payload[1]
|
||||||
|
try:
|
||||||
|
print(f"INFO: Updating server model from {model_name} to: {new_model_name}")
|
||||||
|
|
||||||
|
# Clean up old model to free memory
|
||||||
|
if not use_mlx:
|
||||||
|
print("INFO: Releasing old model from memory...")
|
||||||
|
old_model = model
|
||||||
|
old_tokenizer = tokenizer
|
||||||
|
|
||||||
|
# Load new tokenizer first
|
||||||
|
print(f"Loading new tokenizer for {new_model_name}...")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(new_model_name, use_fast=True)
|
||||||
|
|
||||||
|
# Load new model
|
||||||
|
print(f"Loading new model {new_model_name}...")
|
||||||
|
model = AutoModel.from_pretrained(new_model_name).to(device).eval()
|
||||||
|
|
||||||
|
# Optimize new model
|
||||||
|
if cuda_available or mps_available:
|
||||||
|
try:
|
||||||
|
model = model.half()
|
||||||
|
model = torch.compile(model)
|
||||||
|
print(f"INFO: Using FP16 precision with model: {new_model_name}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"WARNING: Model optimization failed: {e}")
|
||||||
|
|
||||||
|
# Now safely delete old model after new one is loaded
|
||||||
|
del old_model
|
||||||
|
del old_tokenizer
|
||||||
|
|
||||||
|
# Clear GPU cache if available
|
||||||
|
if device.type == "cuda":
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
print("INFO: Cleared CUDA cache")
|
||||||
|
elif device.type == "mps":
|
||||||
|
torch.mps.empty_cache()
|
||||||
|
print("INFO: Cleared MPS cache")
|
||||||
|
|
||||||
|
# Force garbage collection
|
||||||
|
import gc
|
||||||
|
gc.collect()
|
||||||
|
print("INFO: Memory cleanup completed")
|
||||||
|
|
||||||
|
# Update model name
|
||||||
|
model_name = new_model_name
|
||||||
|
|
||||||
|
response = ["SUCCESS"]
|
||||||
|
print(f"INFO: Successfully updated model to: {new_model_name}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"ERROR: Failed to update model: {e}")
|
||||||
|
response = ["FAILED", str(e)]
|
||||||
|
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||||
|
continue
|
||||||
|
except:
|
||||||
|
# Not a control message, continue with normal protobuf processing
|
||||||
|
pass
|
||||||
|
|
||||||
|
e2e_start = time.time()
|
||||||
|
lookup_timer = DeviceTimer("text lookup")
|
||||||
|
|
||||||
|
# Parse request
|
||||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||||
req_proto.ParseFromString(message)
|
req_proto.ParseFromString(message)
|
||||||
node_ids = req_proto.node_ids
|
node_ids = req_proto.node_ids
|
||||||
print(f"INFO: Request for {len(node_ids)} node embeddings: {list(node_ids)}")
|
print(f"INFO: Request for {len(node_ids)} node embeddings: {list(node_ids)}")
|
||||||
|
|
||||||
# 添加调试信息
|
# Add debug information
|
||||||
if len(node_ids) > 0:
|
if len(node_ids) > 0:
|
||||||
print(f"DEBUG: Node ID range: {min(node_ids)} to {max(node_ids)}")
|
print(f"DEBUG: Node ID range: {min(node_ids)} to {max(node_ids)}")
|
||||||
|
|
||||||
# 查找文本
|
# Look up texts
|
||||||
texts = []
|
texts = []
|
||||||
missing_ids = []
|
missing_ids = []
|
||||||
with lookup_timer.timing():
|
with lookup_timer.timing():
|
||||||
for nid in node_ids:
|
for nid in node_ids:
|
||||||
txtinfo = passages[nid]
|
txtinfo = passages[nid]
|
||||||
txt = txtinfo["text"]
|
txt = txtinfo["text"]
|
||||||
texts.append(txt)
|
if txt:
|
||||||
|
texts.append(txt)
|
||||||
|
else:
|
||||||
|
# If text is empty, we still need a placeholder for batch processing,
|
||||||
|
# but record its ID as missing
|
||||||
|
texts.append("")
|
||||||
|
missing_ids.append(nid)
|
||||||
lookup_timer.print_elapsed()
|
lookup_timer.print_elapsed()
|
||||||
|
|
||||||
if missing_ids:
|
if missing_ids:
|
||||||
print(f"WARNING: Missing passages for IDs: {missing_ids}")
|
print(f"WARNING: Missing passages for IDs: {missing_ids}")
|
||||||
|
|
||||||
# 处理批次
|
# Process batch
|
||||||
total_size = len(texts)
|
total_size = len(texts)
|
||||||
print(f"INFO: Total batch size: {total_size}, max_batch_size: {max_batch_size}")
|
print(f"INFO: Total batch size: {total_size}, max_batch_size: {max_batch_size}")
|
||||||
|
|
||||||
@@ -285,20 +566,31 @@ def create_embedding_server_thread(
|
|||||||
chunk_texts = texts[i:end_idx]
|
chunk_texts = texts[i:end_idx]
|
||||||
chunk_ids = node_ids[i:end_idx]
|
chunk_ids = node_ids[i:end_idx]
|
||||||
|
|
||||||
embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
|
if embedding_mode == "mlx":
|
||||||
|
embeddings_chunk = compute_embeddings_mlx(chunk_texts, model_name, batch_size=16)
|
||||||
|
elif embedding_mode == "openai":
|
||||||
|
embeddings_chunk = compute_embeddings_openai(chunk_texts, model_name)
|
||||||
|
else: # sentence-transformers
|
||||||
|
embeddings_chunk = process_batch_pytorch(chunk_texts, chunk_ids, missing_ids)
|
||||||
all_embeddings.append(embeddings_chunk)
|
all_embeddings.append(embeddings_chunk)
|
||||||
|
|
||||||
if cuda_available:
|
if embedding_mode == "sentence-transformers":
|
||||||
torch.cuda.empty_cache()
|
if cuda_available:
|
||||||
elif device.type == "mps":
|
torch.cuda.empty_cache()
|
||||||
torch.mps.empty_cache()
|
elif device.type == "mps":
|
||||||
|
torch.mps.empty_cache()
|
||||||
|
|
||||||
hidden = np.vstack(all_embeddings)
|
hidden = np.vstack(all_embeddings)
|
||||||
print(f"INFO: Combined embeddings shape: {hidden.shape}")
|
print(f"INFO: Combined embeddings shape: {hidden.shape}")
|
||||||
else:
|
else:
|
||||||
hidden = process_batch(texts, node_ids, missing_ids)
|
if embedding_mode == "mlx":
|
||||||
|
hidden = compute_embeddings_mlx(texts, model_name, batch_size=16)
|
||||||
|
elif embedding_mode == "openai":
|
||||||
|
hidden = compute_embeddings_openai(texts, model_name)
|
||||||
|
else: # sentence-transformers
|
||||||
|
hidden = process_batch_pytorch(texts, node_ids, missing_ids)
|
||||||
|
|
||||||
# 序列化响应
|
# Serialize response
|
||||||
ser_start = time.time()
|
ser_start = time.time()
|
||||||
|
|
||||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
@@ -310,32 +602,32 @@ def create_embedding_server_thread(
|
|||||||
|
|
||||||
response_data = resp_proto.SerializeToString()
|
response_data = resp_proto.SerializeToString()
|
||||||
|
|
||||||
# REP 套接字发送单个响应
|
# REP socket sends a single response
|
||||||
socket.send_multipart([identity, b'', response_data])
|
socket.send_multipart([identity, b'', response_data])
|
||||||
|
|
||||||
ser_end = time.time()
|
ser_end = time.time()
|
||||||
|
|
||||||
print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")
|
print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")
|
||||||
|
|
||||||
if device.type == "cuda":
|
if embedding_mode == "sentence-transformers":
|
||||||
torch.cuda.synchronize()
|
if device.type == "cuda":
|
||||||
elif device.type == "mps":
|
torch.cuda.synchronize()
|
||||||
torch.mps.synchronize()
|
elif device.type == "mps":
|
||||||
|
torch.mps.synchronize()
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
|
print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
|
||||||
|
|
||||||
except zmq.Again:
|
except zmq.Again:
|
||||||
print("INFO: ZMQ socket timeout, continuing to listen")
|
print("INFO: ZMQ socket timeout, continuing to listen")
|
||||||
# REP套接字不需要重新创建,只需要继续监听
|
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"ERROR: Error in ZMQ server: {e}")
|
print(f"ERROR: Error in ZMQ server: {e}")
|
||||||
try:
|
try:
|
||||||
# 发送空响应以维持REQ-REP状态
|
# Send empty response to maintain REQ-REP state
|
||||||
empty_resp = embedding_pb2.NodeEmbeddingResponse()
|
empty_resp = embedding_pb2.NodeEmbeddingResponse()
|
||||||
socket.send(empty_resp.SerializeToString())
|
socket.send(empty_resp.SerializeToString())
|
||||||
except:
|
except:
|
||||||
# 如果发送失败,重新创建socket
|
# If sending fails, recreate socket
|
||||||
socket.close()
|
socket.close()
|
||||||
socket = context.socket(zmq.REP)
|
socket = context.socket(zmq.REP)
|
||||||
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
|
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
|
||||||
@@ -348,7 +640,6 @@ def create_embedding_server_thread(
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
# 保持原有的 create_embedding_server 函数不变,只添加线程化版本
|
|
||||||
def create_embedding_server(
|
def create_embedding_server(
|
||||||
domain="demo",
|
domain="demo",
|
||||||
load_passages=True,
|
load_passages=True,
|
||||||
@@ -360,18 +651,22 @@ def create_embedding_server(
|
|||||||
max_batch_size=128,
|
max_batch_size=128,
|
||||||
lazy_load_passages=False,
|
lazy_load_passages=False,
|
||||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
model_name="sentence-transformers/all-mpnet-base-v2",
|
||||||
|
passages_file: Optional[str] = None,
|
||||||
|
embedding_mode: str = "sentence-transformers",
|
||||||
|
enable_warmup: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
原有的 create_embedding_server 函数保持不变
|
原有的 create_embedding_server 函数保持不变
|
||||||
这个是阻塞版本,用于直接运行
|
这个是阻塞版本,用于直接运行
|
||||||
"""
|
"""
|
||||||
create_embedding_server_thread(zmq_port, model_name, max_batch_size)
|
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, embedding_mode, enable_warmup)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Embedding service")
|
parser = argparse.ArgumentParser(description="Embedding service")
|
||||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||||
parser.add_argument("--domain", type=str, default="demo", help="Domain name")
|
parser.add_argument("--domain", type=str, default="demo", help="Domain name")
|
||||||
|
parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
|
||||||
parser.add_argument("--load-passages", action="store_true", default=True)
|
parser.add_argument("--load-passages", action="store_true", default=True)
|
||||||
parser.add_argument("--load-embeddings", action="store_true", default=False)
|
parser.add_argument("--load-embeddings", action="store_true", default=False)
|
||||||
parser.add_argument("--use-fp16", action="store_true", default=False)
|
parser.add_argument("--use-fp16", action="store_true", default=False)
|
||||||
@@ -381,8 +676,18 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--lazy-load-passages", action="store_true", default=True)
|
parser.add_argument("--lazy-load-passages", action="store_true", default=True)
|
||||||
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
|
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
|
||||||
help="Embedding model name")
|
help="Embedding model name")
|
||||||
|
parser.add_argument("--embedding-mode", type=str, default="sentence-transformers",
|
||||||
|
choices=["sentence-transformers", "mlx", "openai"],
|
||||||
|
help="Embedding backend mode")
|
||||||
|
parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings (deprecated: use --embedding-mode mlx)")
|
||||||
|
parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Handle backward compatibility with use_mlx
|
||||||
|
embedding_mode = args.embedding_mode
|
||||||
|
if args.use_mlx:
|
||||||
|
embedding_mode = "mlx"
|
||||||
|
|
||||||
create_embedding_server(
|
create_embedding_server(
|
||||||
domain=args.domain,
|
domain=args.domain,
|
||||||
load_passages=args.load_passages,
|
load_passages=args.load_passages,
|
||||||
@@ -394,4 +699,7 @@ if __name__ == "__main__":
|
|||||||
max_batch_size=args.max_batch_size,
|
max_batch_size=args.max_batch_size,
|
||||||
lazy_load_passages=args.lazy_load_passages,
|
lazy_load_passages=args.lazy_load_passages,
|
||||||
model_name=args.model_name,
|
model_name=args.model_name,
|
||||||
|
passages_file=args.passages_file,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
enable_warmup=not args.disable_warmup,
|
||||||
)
|
)
|
||||||
@@ -8,9 +8,12 @@ version = "0.1.0"
|
|||||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
dependencies = ["leann-core==0.1.0", "numpy"]
|
||||||
|
|
||||||
[tool.scikit-build]
|
[tool.scikit-build]
|
||||||
# 关键:简化的 CMake 路径
|
# Key: simplified CMake path
|
||||||
cmake.source-dir = "third_party/DiskANN"
|
cmake.source-dir = "third_party/DiskANN"
|
||||||
# 关键:Python 包在根目录,路径完全匹配
|
# Key: Python package in root directory, paths match exactly
|
||||||
wheel.packages = ["leann_backend_diskann"]
|
wheel.packages = ["leann_backend_diskann"]
|
||||||
# 使用默认的 redirect 模式
|
# Use default redirect mode
|
||||||
editable.mode = "redirect"
|
editable.mode = "redirect"
|
||||||
|
cmake.build-type = "Release"
|
||||||
|
build.verbose = true
|
||||||
|
build.tool-args = ["-j8"]
|
||||||
1
packages/leann-backend-diskann/third_party/DiskANN
vendored
Submodule
1
packages/leann-backend-diskann/third_party/DiskANN
vendored
Submodule
Submodule packages/leann-backend-diskann/third_party/DiskANN added at af2a26481e
@@ -1,6 +0,0 @@
|
|||||||
---
|
|
||||||
BasedOnStyle: Microsoft
|
|
||||||
---
|
|
||||||
Language: Cpp
|
|
||||||
SortIncludes: false
|
|
||||||
...
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
# Set the default behavior, in case people don't have core.autocrlf set.
|
|
||||||
* text=auto
|
|
||||||
|
|
||||||
# Explicitly declare text files you want to always be normalized and converted
|
|
||||||
# to native line endings on checkout.
|
|
||||||
*.c text
|
|
||||||
*.h text
|
|
||||||
|
|
||||||
# Declare files that will always have CRLF line endings on checkout.
|
|
||||||
*.sln text eol=crlf
|
|
||||||
|
|
||||||
# Denote all files that are truly binary and should not be modified.
|
|
||||||
*.png binary
|
|
||||||
*.jpg binary
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
---
|
|
||||||
name: Bug report
|
|
||||||
about: Bug reports help us improve! Thanks for submitting yours!
|
|
||||||
title: "[BUG] "
|
|
||||||
labels: bug
|
|
||||||
assignees: ''
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Expected Behavior
|
|
||||||
Tell us what should happen
|
|
||||||
|
|
||||||
## Actual Behavior
|
|
||||||
Tell us what happens instead
|
|
||||||
|
|
||||||
## Example Code
|
|
||||||
Please see [How to create a Minimal, Reproducible example](https://stackoverflow.com/help/minimal-reproducible-example) for some guidance on creating the best possible example of the problem
|
|
||||||
```bash
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
## Dataset Description
|
|
||||||
Please tell us about the shape and datatype of your data, (e.g. 128 dimensions, 12.3 billion points, floats)
|
|
||||||
- Dimensions:
|
|
||||||
- Number of Points:
|
|
||||||
- Data type:
|
|
||||||
|
|
||||||
## Error
|
|
||||||
```
|
|
||||||
Paste the full error, with any sensitive information minimally redacted and marked $$REDACTED$$
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
## Your Environment
|
|
||||||
* Operating system (e.g. Windows 11 Pro, Ubuntu 22.04.1 LTS)
|
|
||||||
* DiskANN version (or commit built from)
|
|
||||||
|
|
||||||
## Additional Details
|
|
||||||
Any other contextual information you might feel is important.
|
|
||||||
|
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
blank_issues_enabled: false
|
|
||||||
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
---
|
|
||||||
name: Feature request
|
|
||||||
about: Suggest an idea for this project
|
|
||||||
title: ''
|
|
||||||
labels: enhancement
|
|
||||||
assignees: ''
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Is your feature request related to a problem? Please describe.
|
|
||||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
|
||||||
|
|
||||||
## Describe the solution you'd like
|
|
||||||
A clear and concise description of what you want to happen.
|
|
||||||
|
|
||||||
## Describe alternatives you've considered
|
|
||||||
A clear and concise description of any alternative solutions or features you've considered.
|
|
||||||
|
|
||||||
## Provide references (if applicable)
|
|
||||||
If your feature request is related to a published algorithm/idea, please provide links to
|
|
||||||
any relevant articles or webpages.
|
|
||||||
|
|
||||||
## Additional context
|
|
||||||
Add any other context or screenshots about the feature request here.
|
|
||||||
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
---
|
|
||||||
name: Usage Question
|
|
||||||
about: Ask us a question about DiskANN!
|
|
||||||
title: "[Question]"
|
|
||||||
labels: question
|
|
||||||
assignees: ''
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
This is our forum for asking whatever DiskANN question you'd like! No need to feel shy - we're happy to talk about use cases and optimal tuning strategies!
|
|
||||||
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
<!--
|
|
||||||
Thanks for contributing a pull request! Please ensure you have taken a look at
|
|
||||||
the contribution guidelines: https://github.com/microsoft/DiskANN/blob/main/CONTRIBUTING.md
|
|
||||||
-->
|
|
||||||
- [ ] Does this PR have a descriptive title that could go in our release notes?
|
|
||||||
- [ ] Does this PR add any new dependencies?
|
|
||||||
- [ ] Does this PR modify any existing APIs?
|
|
||||||
- [ ] Is the change to the API backwards compatible?
|
|
||||||
- [ ] Should this result in any changes to our documentation, either updating existing docs or adding new ones?
|
|
||||||
|
|
||||||
#### Reference Issues/PRs
|
|
||||||
<!--
|
|
||||||
Example: Fixes #1234. See also #3456.
|
|
||||||
Please use keywords (e.g., Fixes) to create link to the issues or pull requests
|
|
||||||
you resolved, so that they will automatically be closed when your pull request
|
|
||||||
is merged. See https://github.com/blog/1506-closing-issues-via-pull-requests
|
|
||||||
-->
|
|
||||||
|
|
||||||
#### What does this implement/fix? Briefly explain your changes.
|
|
||||||
|
|
||||||
#### Any other comments?
|
|
||||||
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
name: 'DiskANN Build Bootstrap'
|
|
||||||
description: 'Prepares DiskANN build environment and executes build'
|
|
||||||
runs:
|
|
||||||
using: "composite"
|
|
||||||
steps:
|
|
||||||
# ------------ Linux Build ---------------
|
|
||||||
- name: Prepare and Execute Build
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
run: |
|
|
||||||
sudo scripts/dev/install-dev-deps-ubuntu.bash
|
|
||||||
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DUNIT_TEST=True
|
|
||||||
cmake --build build -- -j
|
|
||||||
cmake --install build --prefix="dist"
|
|
||||||
shell: bash
|
|
||||||
# ------------ End Linux Build ---------------
|
|
||||||
# ------------ Windows Build ---------------
|
|
||||||
- name: Add VisualStudio command line tools into path
|
|
||||||
if: runner.os == 'Windows'
|
|
||||||
uses: ilammy/msvc-dev-cmd@v1
|
|
||||||
- name: Run configure and build for Windows
|
|
||||||
if: runner.os == 'Windows'
|
|
||||||
run: |
|
|
||||||
mkdir build && cd build && cmake .. -DUNIT_TEST=True && msbuild diskann.sln /m /nologo /t:Build /p:Configuration="Release" /property:Platform="x64" -consoleloggerparameters:"ErrorsOnly;Summary"
|
|
||||||
cd ..
|
|
||||||
mkdir dist
|
|
||||||
mklink /j .\dist\bin .\x64\Release\
|
|
||||||
shell: cmd
|
|
||||||
# ------------ End Windows Build ---------------
|
|
||||||
# ------------ Windows Build With EXEC_ENV_OLS and USE_BING_INFRA ---------------
|
|
||||||
- name: Add VisualStudio command line tools into path
|
|
||||||
if: runner.os == 'Windows'
|
|
||||||
uses: ilammy/msvc-dev-cmd@v1
|
|
||||||
- name: Run configure and build for Windows with Bing feature flags
|
|
||||||
if: runner.os == 'Windows'
|
|
||||||
run: |
|
|
||||||
mkdir build_bing && cd build_bing && cmake .. -DEXEC_ENV_OLS=1 -DUSE_BING_INFRA=1 -DUNIT_TEST=True && msbuild diskann.sln /m /nologo /t:Build /p:Configuration="Release" /property:Platform="x64" -consoleloggerparameters:"ErrorsOnly;Summary"
|
|
||||||
cd ..
|
|
||||||
shell: cmd
|
|
||||||
# ------------ End Windows Build ---------------
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
name: 'Checking code formatting...'
|
|
||||||
description: 'Ensures code complies with code formatting rules'
|
|
||||||
runs:
|
|
||||||
using: "composite"
|
|
||||||
steps:
|
|
||||||
- name: Checking code formatting...
|
|
||||||
run: |
|
|
||||||
sudo apt install clang-format
|
|
||||||
find include -name '*.h' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
|
|
||||||
find src -name '*.cpp' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
|
|
||||||
find apps -name '*.cpp' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
|
|
||||||
find python -name '*.cpp' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
|
|
||||||
shell: bash
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
name: 'Generating Random Data (Basic)'
|
|
||||||
description: 'Generates the random data files used in acceptance tests'
|
|
||||||
runs:
|
|
||||||
using: "composite"
|
|
||||||
steps:
|
|
||||||
- name: Generate Random Data (Basic)
|
|
||||||
run: |
|
|
||||||
mkdir data
|
|
||||||
|
|
||||||
echo "Generating random 1020,1024,1536D float and 4096 int8 vectors for index"
|
|
||||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1020D_5K_norm1.0.bin -D 1020 -N 5000 --norm 1.0
|
|
||||||
#dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1024D_5K_norm1.0.bin -D 1024 -N 5000 --norm 1.0
|
|
||||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1536D_5K_norm1.0.bin -D 1536 -N 5000 --norm 1.0
|
|
||||||
dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_4096D_5K_norm1.0.bin -D 4096 -N 5000 --norm 1.0
|
|
||||||
|
|
||||||
echo "Generating random 1020,1024,1536D float and 4096D int8 avectors for query"
|
|
||||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1020D_1K_norm1.0.bin -D 1020 -N 1000 --norm 1.0
|
|
||||||
#dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1024D_1K_norm1.0.bin -D 1024 -N 1000 --norm 1.0
|
|
||||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1536D_1K_norm1.0.bin -D 1536 -N 1000 --norm 1.0
|
|
||||||
dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_4096D_1K_norm1.0.bin -D 4096 -N 1000 --norm 1.0
|
|
||||||
|
|
||||||
echo "Computing ground truth for 1020,1024,1536D float and 4096D int8 avectors for query"
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_1020D_5K_norm1.0.bin --query_file data/rand_float_1020D_1K_norm1.0.bin --gt_file data/l2_rand_float_1020D_5K_norm1.0_1020D_1K_norm1.0_gt100 --K 100
|
|
||||||
#dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_1024D_5K_norm1.0.bin --query_file data/rand_float_1024D_1K_norm1.0.bin --gt_file data/l2_rand_float_1024D_5K_norm1.0_1024D_1K_norm1.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_1536D_5K_norm1.0.bin --query_file data/rand_float_1536D_1K_norm1.0.bin --gt_file data/l2_rand_float_1536D_5K_norm1.0_1536D_1K_norm1.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/rand_int8_4096D_5K_norm1.0.bin --query_file data/rand_int8_4096D_1K_norm1.0.bin --gt_file data/l2_rand_int8_4096D_5K_norm1.0_4096D_1K_norm1.0_gt100 --K 100
|
|
||||||
|
|
||||||
shell: bash
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
name: 'Generating Random Data (Basic)'
|
|
||||||
description: 'Generates the random data files used in acceptance tests'
|
|
||||||
runs:
|
|
||||||
using: "composite"
|
|
||||||
steps:
|
|
||||||
- name: Generate Random Data (Basic)
|
|
||||||
run: |
|
|
||||||
mkdir data
|
|
||||||
|
|
||||||
echo "Generating random vectors for index"
|
|
||||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_10K_norm1.0.bin -D 10 -N 10000 --norm 1.0
|
|
||||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_10K_unnorm.bin -D 10 -N 10000 --rand_scaling 2.0
|
|
||||||
dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0
|
|
||||||
dist/bin/rand_data_gen --data_type uint8 --output_file data/rand_uint8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0
|
|
||||||
|
|
||||||
echo "Generating random vectors for query"
|
|
||||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_1K_norm1.0.bin -D 10 -N 1000 --norm 1.0
|
|
||||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_1K_unnorm.bin -D 10 -N 1000 --rand_scaling 2.0
|
|
||||||
dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0
|
|
||||||
dist/bin/rand_data_gen --data_type uint8 --output_file data/rand_uint8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0
|
|
||||||
|
|
||||||
echo "Computing ground truth for floats across l2, mips, and cosine distance functions"
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn mips --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn cosine --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/cosine_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn cosine --base_file data/rand_float_10D_10K_unnorm.bin --query_file data/rand_float_10D_1K_unnorm.bin --gt_file data/cosine_rand_float_10D_10K_unnorm_10D_1K_unnorm_gt100 --K 100
|
|
||||||
|
|
||||||
echo "Computing ground truth for int8s across l2, mips, and cosine distance functions"
|
|
||||||
dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/rand_int8_10D_10K_norm50.0.bin --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type int8 --dist_fn mips --base_file data/rand_int8_10D_10K_norm50.0.bin --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/mips_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type int8 --dist_fn cosine --base_file data/rand_int8_10D_10K_norm50.0.bin --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/cosine_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
|
||||||
|
|
||||||
echo "Computing ground truth for uint8s across l2, mips, and cosine distance functions"
|
|
||||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn mips --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/mips_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn cosine --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
|
||||||
|
|
||||||
shell: bash
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
name: Build Python Wheel
|
|
||||||
description: Builds a python wheel with cibuildwheel
|
|
||||||
inputs:
|
|
||||||
cibw-identifier:
|
|
||||||
description: "CI build wheel identifier to build"
|
|
||||||
required: true
|
|
||||||
runs:
|
|
||||||
using: "composite"
|
|
||||||
steps:
|
|
||||||
- uses: actions/setup-python@v3
|
|
||||||
- name: Install cibuildwheel
|
|
||||||
run: python -m pip install cibuildwheel==2.11.3
|
|
||||||
shell: bash
|
|
||||||
- name: Building Python ${{inputs.cibw-identifier}} Wheel
|
|
||||||
run: python -m cibuildwheel --output-dir dist
|
|
||||||
env:
|
|
||||||
CIBW_BUILD: ${{inputs.cibw-identifier}}
|
|
||||||
shell: bash
|
|
||||||
- uses: actions/upload-artifact@v3
|
|
||||||
with:
|
|
||||||
name: wheels
|
|
||||||
path: ./dist/*.whl
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
name: DiskANN Build PDoc Documentation
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
build-reference-documentation:
|
|
||||||
permissions:
|
|
||||||
contents: write
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Set up Python 3.9
|
|
||||||
uses: actions/setup-python@v2
|
|
||||||
with:
|
|
||||||
python-version: 3.9
|
|
||||||
- name: Install python build
|
|
||||||
run: python -m pip install build
|
|
||||||
shell: bash
|
|
||||||
# Install required dependencies
|
|
||||||
- name: Prepare Linux environment
|
|
||||||
run: |
|
|
||||||
sudo scripts/dev/install-dev-deps-ubuntu.bash
|
|
||||||
shell: bash
|
|
||||||
# We need to build the wheel in order to run pdoc. pdoc does not seem to work if you just point it at
|
|
||||||
# our source directory.
|
|
||||||
- name: Building Python Wheel for documentation generation
|
|
||||||
run: python -m build --wheel --outdir documentation_dist
|
|
||||||
shell: bash
|
|
||||||
- name: "Run Reference Documentation Generation"
|
|
||||||
run: |
|
|
||||||
pip install pdoc pipdeptree
|
|
||||||
pip install documentation_dist/*.whl
|
|
||||||
echo "documentation" > dependencies_documentation.txt
|
|
||||||
pipdeptree >> dependencies_documentation.txt
|
|
||||||
pdoc -o docs/python/html diskannpy
|
|
||||||
- name: Create version environment variable
|
|
||||||
run: |
|
|
||||||
echo "DISKANN_VERSION=$(python <<EOF
|
|
||||||
from importlib.metadata import version
|
|
||||||
v = version('diskannpy')
|
|
||||||
print(v)
|
|
||||||
EOF
|
|
||||||
)" >> $GITHUB_ENV
|
|
||||||
- name: Archive documentation version artifact
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: dependencies
|
|
||||||
path: |
|
|
||||||
${{ github.run_id }}-dependencies_documentation.txt
|
|
||||||
overwrite: true
|
|
||||||
- name: Archive documentation artifacts
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: documentation-site
|
|
||||||
path: |
|
|
||||||
docs/python/html
|
|
||||||
# Publish to /dev if we are on the "main" branch
|
|
||||||
- name: Publish reference docs for latest development version (main branch)
|
|
||||||
uses: peaceiris/actions-gh-pages@v3
|
|
||||||
if: github.ref == 'refs/heads/main'
|
|
||||||
with:
|
|
||||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
publish_dir: docs/python/html
|
|
||||||
destination_dir: docs/python/dev
|
|
||||||
# Publish to /<version> if we are releasing
|
|
||||||
- name: Publish reference docs by version (main branch)
|
|
||||||
uses: peaceiris/actions-gh-pages@v3
|
|
||||||
if: github.event_name == 'release'
|
|
||||||
with:
|
|
||||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
publish_dir: docs/python/html
|
|
||||||
destination_dir: docs/python/${{ env.DISKANN_VERSION }}
|
|
||||||
# Publish to /latest if we are releasing
|
|
||||||
- name: Publish latest reference docs (main branch)
|
|
||||||
uses: peaceiris/actions-gh-pages@v3
|
|
||||||
if: github.event_name == 'release'
|
|
||||||
with:
|
|
||||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
publish_dir: docs/python/html
|
|
||||||
destination_dir: docs/python/latest
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
name: DiskANN Build Python Wheel
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
linux-build:
|
|
||||||
name: Python - Ubuntu - ${{matrix.cibw-identifier}}
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
cibw-identifier: ["cp39-manylinux_x86_64", "cp310-manylinux_x86_64", "cp311-manylinux_x86_64"]
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Building python wheel ${{matrix.cibw-identifier}}
|
|
||||||
uses: ./.github/actions/python-wheel
|
|
||||||
with:
|
|
||||||
cibw-identifier: ${{matrix.cibw-identifier}}
|
|
||||||
windows-build:
|
|
||||||
name: Python - Windows - ${{matrix.cibw-identifier}}
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
cibw-identifier: ["cp39-win_amd64", "cp310-win_amd64", "cp311-win_amd64"]
|
|
||||||
runs-on: windows-latest
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
submodules: true
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Building python wheel ${{matrix.cibw-identifier}}
|
|
||||||
uses: ./.github/actions/python-wheel
|
|
||||||
with:
|
|
||||||
cibw-identifier: ${{matrix.cibw-identifier}}
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
name: DiskANN Common Checks
|
|
||||||
# common means common to both pr-test and push-test
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
formatting-check:
|
|
||||||
strategy:
|
|
||||||
fail-fast: true
|
|
||||||
name: Code Formatting Test
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checking code formatting...
|
|
||||||
uses: ./.github/actions/format-check
|
|
||||||
docker-container-build:
|
|
||||||
name: Docker Container Build
|
|
||||||
needs: [formatting-check]
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Docker build
|
|
||||||
run: |
|
|
||||||
docker build .
|
|
||||||
@@ -1,117 +0,0 @@
|
|||||||
name: Disk With PQ
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
acceptance-tests-disk-pq:
|
|
||||||
name: Disk, PQ
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
|
||||||
runs-on: ${{matrix.os}}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Windows' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
submodules: true
|
|
||||||
- name: DiskANN Build CLI Applications
|
|
||||||
uses: ./.github/actions/build
|
|
||||||
|
|
||||||
- name: Generate Data
|
|
||||||
uses: ./.github/actions/generate-random
|
|
||||||
|
|
||||||
- name: build and search disk index (one shot graph build, L2, no diskPQ) (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
|
|
||||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search disk index (one shot graph build, cosine, no diskPQ) (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type float --dist_fn cosine --data_path data/rand_float_10D_10K_unnorm.bin --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
|
|
||||||
dist/bin/search_disk_index --data_type float --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_unnorm.bin --gt_file data/cosine_rand_float_10D_10K_unnorm_10D_1K_unnorm_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search disk index (one shot graph build, L2, no diskPQ) (int8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
|
|
||||||
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search disk index (one shot graph build, L2, no diskPQ) (uint8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
|
|
||||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
|
|
||||||
- name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
|
|
||||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) (int8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
|
|
||||||
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16\
|
|
||||||
- name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) (uint8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
|
|
||||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
|
|
||||||
- name: build and search disk index (sharded graph build, L2, no diskPQ) (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
|
|
||||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_sharded --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search disk index (sharded graph build, cosine, no diskPQ) (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type float --dist_fn cosine --data_path data/rand_float_10D_10K_unnorm.bin --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
|
|
||||||
dist/bin/search_disk_index --data_type float --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_sharded --result_path /tmp/res --query_file data/rand_float_10D_1K_unnorm.bin --gt_file data/cosine_rand_float_10D_10K_unnorm_10D_1K_unnorm_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search disk index (sharded graph build, L2, no diskPQ) (int8)
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
|
|
||||||
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_sharded --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search disk index (sharded graph build, L2, no diskPQ) (uint8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
|
|
||||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
|
|
||||||
- name: build and search disk index (one shot graph build, L2, diskPQ) (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
|
|
||||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskpq_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search disk index (one shot graph build, L2, diskPQ) (int8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
|
|
||||||
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search disk index (one shot graph build, L2, diskPQ) (uint8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
|
|
||||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
|
|
||||||
- name: build and search disk index (sharded graph build, MIPS, diskPQ) (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type float --dist_fn mips --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_mips_rand_float_10D_10K_norm1.0_diskpq_sharded -R 16 -L 32 -B 0.00003 -M 0.00006 --PQ_disk_bytes 5
|
|
||||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_mips_rand_float_10D_10K_norm1.0_diskpq_sharded --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
|
|
||||||
- name: upload data and bin
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: disk-pq-${{matrix.os}}
|
|
||||||
path: |
|
|
||||||
./dist/**
|
|
||||||
./data/**
|
|
||||||
@@ -1,102 +0,0 @@
|
|||||||
name: Dynamic-Labels
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
acceptance-tests-dynamic:
|
|
||||||
name: Dynamic-Labels
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
|
||||||
runs-on: ${{matrix.os}}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Windows' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
submodules: true
|
|
||||||
- name: DiskANN Build CLI Applications
|
|
||||||
uses: ./.github/actions/build
|
|
||||||
|
|
||||||
- name: Generate Data
|
|
||||||
uses: ./.github/actions/generate-random
|
|
||||||
|
|
||||||
- name: Generate Labels
|
|
||||||
run: |
|
|
||||||
echo "Generating synthetic labels and computing ground truth for filtered search with universal label"
|
|
||||||
dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/rand_labels_50_10K.txt --distribution_type random
|
|
||||||
|
|
||||||
echo "Generating synthetic labels with a zipf distribution and computing ground truth for filtered search with universal label"
|
|
||||||
dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/zipf_labels_50_10K.txt --distribution_type zipf
|
|
||||||
|
|
||||||
- name: Test a streaming index (float) with labels (Zipf distributed)
|
|
||||||
run: |
|
|
||||||
dist/bin/test_streaming_scenario --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --universal_label 0 --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_zipf_stream -R 64 --FilteredLbuild 200 -L 50 --alpha 1.2 --insert_threads 8 --consolidate_threads 8 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 3.2 --unique_labels_supported 51
|
|
||||||
|
|
||||||
echo "Computing groundtruth with filter"
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type float --universal_label 0 --filter_label 1 --dist_fn l2 --base_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_base-act4000-cons2000-max10000_1 --label_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000_raw_labels.txt --tags_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000.tags
|
|
||||||
echo "Searching with filter"
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 1 --fail_if_recall_below 40 --index_path_prefix data/index_zipf_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_base-act4000-cons2000-max10000_1 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1
|
|
||||||
|
|
||||||
echo "Computing groundtruth w/o filter"
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_base-act4000-cons2000-max10000
|
|
||||||
echo "Searching without filter"
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_zipf_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64
|
|
||||||
|
|
||||||
- name: Test a streaming index (float) with labels (random distributed)
|
|
||||||
run: |
|
|
||||||
dist/bin/test_streaming_scenario --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --universal_label 0 --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_rand_stream -R 64 --FilteredLbuild 200 -L 50 --alpha 1.2 --insert_threads 8 --consolidate_threads 8 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 3.2 --unique_labels_supported 51
|
|
||||||
|
|
||||||
echo "Computing groundtruth with filter"
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type float --universal_label 0 --filter_label 1 --dist_fn l2 --base_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_base-act4000-cons2000-max10000_1 --label_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000_raw_labels.txt --tags_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000.tags
|
|
||||||
echo "Searching with filter"
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 1 --fail_if_recall_below 40 --index_path_prefix data/index_rand_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_base-act4000-cons2000-max10000_1 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1
|
|
||||||
|
|
||||||
echo "Computing groundtruth w/o filter"
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_base-act4000-cons2000-max10000
|
|
||||||
echo "Searching without filter"
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_rand_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64
|
|
||||||
|
|
||||||
- name: Test Insert Delete Consolidate (float) with labels (zipf distributed)
|
|
||||||
run: |
|
|
||||||
dist/bin/test_insert_deletes_consolidate --data_type float --dist_fn l2 --universal_label 0 --label_file data/zipf_labels_50_10K.txt --FilteredLbuild 70 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_zipf_ins_del -R 64 -L 10 --alpha 1.2 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 3.2 --unique_labels_supported 51
|
|
||||||
|
|
||||||
echo "Computing groundtruth with filter"
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type float --filter_label 5 --universal_label 0 --dist_fn l2 --base_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_random10D_1K_wlabel_5 --label_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500_raw_labels.txt --tags_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500.tags
|
|
||||||
echo "Searching with filter"
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 5 --fail_if_recall_below 10 --index_path_prefix data/index_zipf_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_zipf_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_random10D_1K_wlabel_5 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1
|
|
||||||
|
|
||||||
echo "Computing groundtruth w/o filter"
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_random10D_1K
|
|
||||||
echo "Searching without filter"
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_zipf_ins_del.after-concurrent-delete-del2500-7500 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_random10D_1K -K 10 -L 20 40 60 80 100 -T 64
|
|
||||||
|
|
||||||
- name: Test Insert Delete Consolidate (float) with labels (random distributed)
|
|
||||||
run: |
|
|
||||||
dist/bin/test_insert_deletes_consolidate --data_type float --dist_fn l2 --universal_label 0 --label_file data/rand_labels_50_10K.txt --FilteredLbuild 70 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_rand_ins_del -R 64 -L 10 --alpha 1.2 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 3.2 --unique_labels_supported 51
|
|
||||||
|
|
||||||
echo "Computing groundtruth with filter"
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type float --filter_label 5 --universal_label 0 --dist_fn l2 --base_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_random10D_1K_wlabel_5 --label_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500_raw_labels.txt --tags_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500.tags
|
|
||||||
echo "Searching with filter"
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 5 --fail_if_recall_below 40 --index_path_prefix data/index_rand_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_rand_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_random10D_1K_wlabel_5 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1
|
|
||||||
|
|
||||||
echo "Computing groundtruth w/o filter"
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_random10D_1K
|
|
||||||
echo "Searching without filter"
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_rand_ins_del.after-concurrent-delete-del2500-7500 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_random10D_1K -K 10 -L 20 40 60 80 100 -T 64
|
|
||||||
|
|
||||||
- name: upload data and bin
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: dynamic-labels-${{matrix.os}}
|
|
||||||
path: |
|
|
||||||
./dist/**
|
|
||||||
./data/**
|
|
||||||
@@ -1,75 +0,0 @@
|
|||||||
name: Dynamic
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
acceptance-tests-dynamic:
|
|
||||||
name: Dynamic
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
|
||||||
runs-on: ${{matrix.os}}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Windows' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
submodules: true
|
|
||||||
- name: DiskANN Build CLI Applications
|
|
||||||
uses: ./.github/actions/build
|
|
||||||
|
|
||||||
- name: Generate Data
|
|
||||||
uses: ./.github/actions/generate-random
|
|
||||||
|
|
||||||
- name: test a streaming index (float)
|
|
||||||
run: |
|
|
||||||
dist/bin/test_streaming_scenario --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 3.2
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_base-act4000-cons2000-max10000 --tags_file data/index_stream.after-streaming-act4000-cons2000-max10000.tags
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1
|
|
||||||
- name: test a streaming index (int8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/test_streaming_scenario --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200
|
|
||||||
dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/index_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_base-act4000-cons2000-max10000 --tags_file data/index_stream.after-streaming-act4000-cons2000-max10000.tags
|
|
||||||
dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1
|
|
||||||
- name: test a streaming index
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/test_streaming_scenario --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200
|
|
||||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/index_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_base-act4000-cons2000-max10000 --tags_file data/index_stream.after-streaming-act4000-cons2000-max10000.tags
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1
|
|
||||||
|
|
||||||
- name: build and search an incremental index (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/test_insert_deletes_consolidate --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 3.2;
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_random10D_1K-conc-2500-7500 --tags_file data/index_ins_del.after-concurrent-delete-del2500-7500.tags
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_ins_del --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_random10D_1K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
|
|
||||||
- name: build and search an incremental index (int8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/test_insert_deletes_consolidate --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200
|
|
||||||
dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/index_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_random10D_1K-conc-2500-7500 --tags_file data/index_ins_del.after-concurrent-delete-del2500-7500.tags
|
|
||||||
dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_ins_del --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/gt100_random10D_1K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
|
|
||||||
- name: build and search an incremental index (uint8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/test_insert_deletes_consolidate --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200
|
|
||||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/index_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_random10D_10K-conc-2500-7500 --tags_file data/index_ins_del.after-concurrent-delete-del2500-7500.tags
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_ins_del --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/gt100_random10D_10K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
|
|
||||||
|
|
||||||
- name: upload data and bin
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: dynamic-${{matrix.os}}
|
|
||||||
path: |
|
|
||||||
./dist/**
|
|
||||||
./data/**
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
name: In-Memory Without PQ
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
acceptance-tests-mem-no-pq:
|
|
||||||
name: In-Mem, Without PQ
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
|
||||||
runs-on: ${{matrix.os}}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Windows' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
submodules: true
|
|
||||||
- name: DiskANN Build CLI Applications
|
|
||||||
uses: ./.github/actions/build
|
|
||||||
|
|
||||||
- name: Generate Data
|
|
||||||
uses: ./.github/actions/generate-random
|
|
||||||
|
|
||||||
- name: build and search in-memory index with L2 metrics (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
|
|
||||||
- name: build and search in-memory index with L2 metrics (int8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0
|
|
||||||
dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0 --query_file data/rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
|
||||||
- name: build and search in-memory index with L2 metrics (uint8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0 --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
|
||||||
|
|
||||||
- name: Searching with fast_l2 distance function (float)
|
|
||||||
if: runner.os != 'Windows' && (success() || failure())
|
|
||||||
run: |
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn fast_l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
|
|
||||||
|
|
||||||
- name: build and search in-memory index with MIPS metric (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type float --dist_fn mips --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_mips_rand_float_10D_10K_norm1.0
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn mips --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
|
|
||||||
|
|
||||||
- name: build and search in-memory index with cosine metric (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type float --dist_fn cosine --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_cosine_rand_float_10D_10K_norm1.0
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
|
|
||||||
- name: build and search in-memory index with cosine metric (int8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type int8 --dist_fn cosine --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_cosine_rand_int8_10D_10K_norm50.0
|
|
||||||
dist/bin/search_memory_index --data_type int8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0 --query_file data/rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
|
||||||
- name: build and search in-memory index with cosine metric
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn cosine --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50.0
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0 --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
|
||||||
|
|
||||||
- name: upload data and bin
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: in-memory-no-pq-${{matrix.os}}
|
|
||||||
path: |
|
|
||||||
./dist/**
|
|
||||||
./data/**
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
name: In-Memory With PQ
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
acceptance-tests-mem-pq:
|
|
||||||
name: In-Mem, PQ
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
|
||||||
runs-on: ${{matrix.os}}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Windows' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
submodules: true
|
|
||||||
- name: DiskANN Build CLI Applications
|
|
||||||
uses: ./.github/actions/build
|
|
||||||
|
|
||||||
- name: Generate Data
|
|
||||||
uses: ./.github/actions/generate-random
|
|
||||||
|
|
||||||
- name: build and search in-memory index with L2 metric with PQ based distance comparisons (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0_buildpq5 --build_PQ_bytes 5
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0_buildpq5 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
|
|
||||||
|
|
||||||
- name: build and search in-memory index with L2 metrics with PQ base distance comparisons (int8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0_buildpq5 --build_PQ_bytes 5
|
|
||||||
dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0_buildpq5 --query_file data/rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
|
||||||
|
|
||||||
- name: build and search in-memory index with L2 metrics with PQ base distance comparisons (uint8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --build_PQ_bytes 5
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
|
||||||
|
|
||||||
- name: upload data and bin
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: in-memory-pq-${{matrix.os}}
|
|
||||||
path: |
|
|
||||||
./dist/**
|
|
||||||
./data/**
|
|
||||||
@@ -1,120 +0,0 @@
|
|||||||
name: Labels
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
acceptance-tests-labels:
|
|
||||||
name: Labels
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
|
||||||
runs-on: ${{matrix.os}}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Windows' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
submodules: true
|
|
||||||
- name: DiskANN Build CLI Applications
|
|
||||||
uses: ./.github/actions/build
|
|
||||||
|
|
||||||
- name: Generate Data
|
|
||||||
uses: ./.github/actions/generate-random
|
|
||||||
|
|
||||||
- name: Generate Labels
|
|
||||||
run: |
|
|
||||||
echo "Generating synthetic labels and computing ground truth for filtered search with universal label"
|
|
||||||
dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/rand_labels_50_10K.txt --distribution_type random
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label 10 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn mips --universal_label 0 --filter_label 10 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --gt_file data/mips_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn cosine --universal_label 0 --filter_label 10 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
|
||||||
|
|
||||||
echo "Generating synthetic labels with a zipf distribution and computing ground truth for filtered search with universal label"
|
|
||||||
dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/zipf_labels_50_10K.txt --distribution_type zipf
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn mips --universal_label 0 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/mips_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn cosine --universal_label 0 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
|
||||||
|
|
||||||
echo "Generating synthetic labels and computing ground truth for filtered search without a universal label"
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel_nouniversal --K 100
|
|
||||||
dist/bin/generate_synthetic_labels --num_labels 10 --num_points 1000 --output_file data/query_labels_1K.txt --distribution_type one_per_point
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label_file data/query_labels_1K.txt --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/combined_l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
|
||||||
|
|
||||||
- name: build and search in-memory index with labels using L2 and Cosine metrics (random distributed labels)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn cosine --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50_wlabel
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --filter_label 10 --fail_if_recall_below 70 --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
|
||||||
|
|
||||||
echo "Searching without filters"
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64
|
|
||||||
|
|
||||||
- name: build and search disk index with labels using L2 and Cosine metrics (random distributed labels)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --universal_label 0 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50_wlabel -R 32 -L 5 -B 0.00003 -M 1
|
|
||||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 50 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50_wlabel --result_path temp --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search in-memory index with labels using L2 and Cosine metrics (zipf distributed labels)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn cosine --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_cosine_zipf_uint8_10D_10K_norm50_wlabel
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 5 --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --filter_label 5 --fail_if_recall_below 70 --index_path_prefix data/index_cosine_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
|
||||||
|
|
||||||
echo "Searching without filters"
|
|
||||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn cosine --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_cosine_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64
|
|
||||||
|
|
||||||
- name: build and search disk index with labels using L2 and Cosine metrics (zipf distributed labels)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --universal_label 0 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel -R 32 -L 5 -B 0.00003 -M 1
|
|
||||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --filter_label 5 --fail_if_recall_below 50 --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel --result_path temp --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
|
|
||||||
- name : build and search in-memory and disk index (without universal label, zipf distributed)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal
|
|
||||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal -R 32 -L 5 -B 0.00003 -M 1
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 5 --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel_nouniversal -L 16 32
|
|
||||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --filter_label 5 --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal --result_path temp --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel_nouniversal --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: Generate combined GT for each query with a separate label and search
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --query_filters_file data/query_labels_1K.txt --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/combined_l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
|
||||||
- name: build and search in-memory index with pq_dist of 5 with 10 dimensions
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --build_PQ_bytes 5
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
|
||||||
- name: Build and search stitched vamana with random and zipf distributed labels
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_stitched_index --num_threads 48 --data_type uint8 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt -R 32 -L 100 --alpha 1.2 --stitched_R 64 --index_path_prefix data/stit_rand_32_100_64_new --universal_label 0
|
|
||||||
dist/bin/build_stitched_index --num_threads 48 --data_type uint8 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt -R 32 -L 100 --alpha 1.2 --stitched_R 64 --index_path_prefix data/stit_zipf_32_100_64_new --universal_label 0
|
|
||||||
dist/bin/search_memory_index --num_threads 48 --data_type uint8 --dist_fn l2 --filter_label 10 --index_path_prefix data/stit_rand_32_100_64_new --query_file data/rand_uint8_10D_1K_norm50.0.bin --result_path data/rand_stit_96_10_90_new --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -K 10 -L 16 32 150
|
|
||||||
dist/bin/search_memory_index --num_threads 48 --data_type uint8 --dist_fn l2 --filter_label 5 --index_path_prefix data/stit_zipf_32_100_64_new --query_file data/rand_uint8_10D_1K_norm50.0.bin --result_path data/zipf_stit_96_10_90_new --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -K 10 -L 16 32 150
|
|
||||||
|
|
||||||
- name: upload data and bin
|
|
||||||
if: success() || failure()
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: labels-${{matrix.os}}
|
|
||||||
path: |
|
|
||||||
./dist/**
|
|
||||||
./data/**
|
|
||||||
@@ -1,60 +0,0 @@
|
|||||||
name: Disk With PQ
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
acceptance-tests-disk-pq:
|
|
||||||
name: Disk, PQ
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
|
||||||
runs-on: ${{matrix.os}}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Windows' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
submodules: true
|
|
||||||
- name: DiskANN Build CLI Applications
|
|
||||||
uses: ./.github/actions/build
|
|
||||||
|
|
||||||
- name: Generate Data
|
|
||||||
uses: ./.github/actions/generate-high-dim-random
|
|
||||||
|
|
||||||
- name: build and search disk index (1020D, one shot graph build, L2, no diskPQ) (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_1020D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_1020D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
|
|
||||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_1020D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_1020D_1K_norm1.0.bin --gt_file data/l2_rand_float_1020D_5K_norm1.0_1020D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16
|
|
||||||
#- name: build and search disk index (1024D, one shot graph build, L2, no diskPQ) (float)
|
|
||||||
# if: success() || failure()
|
|
||||||
# run: |
|
|
||||||
# dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_1024D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_1024D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
|
|
||||||
# dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_1024D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_1024D_1K_norm1.0.bin --gt_file data/l2_rand_float_1024D_5K_norm1.0_1024D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16
|
|
||||||
- name: build and search disk index (1536D, one shot graph build, L2, no diskPQ) (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_1536D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_1536D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
|
|
||||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_1536D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_1536D_1K_norm1.0.bin --gt_file data/l2_rand_float_1536D_5K_norm1.0_1536D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16
|
|
||||||
|
|
||||||
- name: build and search disk index (4096D, one shot graph build, L2, no diskPQ) (int8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_4096D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_int8_4096D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
|
|
||||||
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_4096D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_int8_4096D_1K_norm1.0.bin --gt_file data/l2_rand_int8_4096D_5K_norm1.0_4096D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16
|
|
||||||
|
|
||||||
- name: upload data and bin
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: multi-sector-disk-pq-${{matrix.os}}
|
|
||||||
path: |
|
|
||||||
./dist/**
|
|
||||||
./data/**
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
name: DiskANN Nightly Performance Metrics
|
|
||||||
on:
|
|
||||||
schedule:
|
|
||||||
- cron: "41 14 * * *" # 14:41 UTC, 7:41 PDT, 8:41 PST, 08:11 IST
|
|
||||||
jobs:
|
|
||||||
perf-test:
|
|
||||||
name: Run Perf Test from main
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Build Perf Container
|
|
||||||
run: |
|
|
||||||
docker build --build-arg GIT_COMMIT_ISH="$GITHUB_SHA" -t perf -f scripts/perf/Dockerfile scripts
|
|
||||||
- name: Performance Tests
|
|
||||||
run: |
|
|
||||||
mkdir metrics
|
|
||||||
docker run -v ./metrics:/app/logs perf &> ./metrics/combined_stdouterr.log
|
|
||||||
- name: Upload Metrics Logs
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: metrics-${{matrix.os}}
|
|
||||||
path: |
|
|
||||||
./metrics/**
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
name: DiskANN Pull Request Build and Test
|
|
||||||
on: [pull_request]
|
|
||||||
jobs:
|
|
||||||
common:
|
|
||||||
strategy:
|
|
||||||
fail-fast: true
|
|
||||||
name: DiskANN Common Build Checks
|
|
||||||
uses: ./.github/workflows/common.yml
|
|
||||||
unit-tests:
|
|
||||||
name: Unit tests
|
|
||||||
uses: ./.github/workflows/unit-tests.yml
|
|
||||||
in-mem-pq:
|
|
||||||
name: In-Memory with PQ
|
|
||||||
uses: ./.github/workflows/in-mem-pq.yml
|
|
||||||
in-mem-no-pq:
|
|
||||||
name: In-Memory without PQ
|
|
||||||
uses: ./.github/workflows/in-mem-no-pq.yml
|
|
||||||
disk-pq:
|
|
||||||
name: Disk with PQ
|
|
||||||
uses: ./.github/workflows/disk-pq.yml
|
|
||||||
multi-sector-disk-pq:
|
|
||||||
name: Multi-sector Disk with PQ
|
|
||||||
uses: ./.github/workflows/multi-sector-disk-pq.yml
|
|
||||||
labels:
|
|
||||||
name: Labels
|
|
||||||
uses: ./.github/workflows/labels.yml
|
|
||||||
dynamic:
|
|
||||||
name: Dynamic
|
|
||||||
uses: ./.github/workflows/dynamic.yml
|
|
||||||
dynamic-labels:
|
|
||||||
name: Dynamic Labels
|
|
||||||
uses: ./.github/workflows/dynamic-labels.yml
|
|
||||||
python:
|
|
||||||
name: Python
|
|
||||||
uses: ./.github/workflows/build-python.yml
|
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
name: DiskANN Push Build
|
|
||||||
on: [push]
|
|
||||||
jobs:
|
|
||||||
common:
|
|
||||||
strategy:
|
|
||||||
fail-fast: true
|
|
||||||
name: DiskANN Common Build Checks
|
|
||||||
uses: ./.github/workflows/common.yml
|
|
||||||
build-documentation:
|
|
||||||
permissions:
|
|
||||||
contents: write
|
|
||||||
strategy:
|
|
||||||
fail-fast: true
|
|
||||||
name: DiskANN Build Documentation
|
|
||||||
uses: ./.github/workflows/build-python-pdoc.yml
|
|
||||||
build:
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [ ubuntu-latest, windows-2019, windows-latest ]
|
|
||||||
name: Build for ${{matrix.os}}
|
|
||||||
runs-on: ${{matrix.os}}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Windows' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
submodules: true
|
|
||||||
- name: Build diskannpy dependency tree
|
|
||||||
run: |
|
|
||||||
pip install diskannpy pipdeptree
|
|
||||||
echo "dependencies" > dependencies_${{ matrix.os }}.txt
|
|
||||||
pipdeptree >> dependencies_${{ matrix.os }}.txt
|
|
||||||
- name: Archive diskannpy dependencies artifact
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: dependencies_${{ matrix.os }}
|
|
||||||
path: |
|
|
||||||
dependencies_${{ matrix.os }}.txt
|
|
||||||
- name: DiskANN Build CLI Applications
|
|
||||||
uses: ./.github/actions/build
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
name: Build and Release Python Wheels
|
|
||||||
on:
|
|
||||||
release:
|
|
||||||
types: [published]
|
|
||||||
jobs:
|
|
||||||
python-release-wheels:
|
|
||||||
name: Python
|
|
||||||
uses: ./.github/workflows/build-python.yml
|
|
||||||
build-documentation:
|
|
||||||
strategy:
|
|
||||||
fail-fast: true
|
|
||||||
name: DiskANN Build Documentation
|
|
||||||
uses: ./.github/workflows/build-python-pdoc.yml
|
|
||||||
release:
|
|
||||||
permissions:
|
|
||||||
contents: write
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
needs: python-release-wheels
|
|
||||||
steps:
|
|
||||||
- uses: actions/download-artifact@v3
|
|
||||||
with:
|
|
||||||
name: wheels
|
|
||||||
path: dist/
|
|
||||||
- name: Generate SHA256 files for each wheel
|
|
||||||
run: |
|
|
||||||
sha256sum dist/*.whl > checksums.txt
|
|
||||||
cat checksums.txt
|
|
||||||
- uses: actions/setup-python@v3
|
|
||||||
- name: Install twine
|
|
||||||
run: python -m pip install twine
|
|
||||||
- name: Publish with twine
|
|
||||||
env:
|
|
||||||
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
|
|
||||||
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
|
|
||||||
run: |
|
|
||||||
twine upload dist/*.whl
|
|
||||||
- name: Update release with SHA256 and Artifacts
|
|
||||||
uses: softprops/action-gh-release@v1
|
|
||||||
with:
|
|
||||||
token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
files: |
|
|
||||||
dist/*.whl
|
|
||||||
checksums.txt
|
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
name: Unit Tests
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
acceptance-tests-labels:
|
|
||||||
name: Unit Tests
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
|
||||||
runs-on: ${{matrix.os}}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Windows' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
submodules: true
|
|
||||||
- name: DiskANN Build CLI Applications
|
|
||||||
uses: ./.github/actions/build
|
|
||||||
|
|
||||||
- name: Run Unit Tests
|
|
||||||
run: |
|
|
||||||
cd build
|
|
||||||
ctest -C Release
|
|
||||||
@@ -1,384 +0,0 @@
|
|||||||
## Ignore Visual Studio temporary files, build results, and
|
|
||||||
## files generated by popular Visual Studio add-ons.
|
|
||||||
##
|
|
||||||
## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
|
|
||||||
|
|
||||||
# User-specific files
|
|
||||||
*.rsuser
|
|
||||||
*.suo
|
|
||||||
*.user
|
|
||||||
*.userosscache
|
|
||||||
*.sln.docstates
|
|
||||||
|
|
||||||
# User-specific files (MonoDevelop/Xamarin Studio)
|
|
||||||
*.userprefs
|
|
||||||
|
|
||||||
# Mono auto generated files
|
|
||||||
mono_crash.*
|
|
||||||
|
|
||||||
# Build results
|
|
||||||
[Dd]ebug/
|
|
||||||
[Dd]ebugPublic/
|
|
||||||
[Rr]elease/
|
|
||||||
[Rr]eleases/
|
|
||||||
x64/
|
|
||||||
x86/
|
|
||||||
[Aa][Rr][Mm]/
|
|
||||||
[Aa][Rr][Mm]64/
|
|
||||||
bld/
|
|
||||||
[Bb]in/
|
|
||||||
[Oo]bj/
|
|
||||||
[Ll]og/
|
|
||||||
[Ll]ogs/
|
|
||||||
|
|
||||||
# Visual Studio 2015/2017 cache/options directory
|
|
||||||
.vs/
|
|
||||||
# Uncomment if you have tasks that create the project's static files in wwwroot
|
|
||||||
#wwwroot/
|
|
||||||
|
|
||||||
# Visual Studio 2017 auto generated files
|
|
||||||
Generated\ Files/
|
|
||||||
|
|
||||||
# MSTest test Results
|
|
||||||
[Tt]est[Rr]esult*/
|
|
||||||
[Bb]uild[Ll]og.*
|
|
||||||
|
|
||||||
# NUnit
|
|
||||||
*.VisualState.xml
|
|
||||||
TestResult.xml
|
|
||||||
nunit-*.xml
|
|
||||||
|
|
||||||
# Build Results of an ATL Project
|
|
||||||
[Dd]ebugPS/
|
|
||||||
[Rr]eleasePS/
|
|
||||||
dlldata.c
|
|
||||||
|
|
||||||
# Benchmark Results
|
|
||||||
BenchmarkDotNet.Artifacts/
|
|
||||||
|
|
||||||
# .NET Core
|
|
||||||
project.lock.json
|
|
||||||
project.fragment.lock.json
|
|
||||||
artifacts/
|
|
||||||
|
|
||||||
# StyleCop
|
|
||||||
StyleCopReport.xml
|
|
||||||
|
|
||||||
# Files built by Visual Studio
|
|
||||||
*_i.c
|
|
||||||
*_p.c
|
|
||||||
*_h.h
|
|
||||||
*.ilk
|
|
||||||
*.meta
|
|
||||||
*.obj
|
|
||||||
*.iobj
|
|
||||||
*.pch
|
|
||||||
*.pdb
|
|
||||||
*.ipdb
|
|
||||||
*.pgc
|
|
||||||
*.pgd
|
|
||||||
*.rsp
|
|
||||||
*.sbr
|
|
||||||
*.tlb
|
|
||||||
*.tli
|
|
||||||
*.tlh
|
|
||||||
*.tmp
|
|
||||||
*.tmp_proj
|
|
||||||
*_wpftmp.csproj
|
|
||||||
*.log
|
|
||||||
*.vspscc
|
|
||||||
*.vssscc
|
|
||||||
.builds
|
|
||||||
*.pidb
|
|
||||||
*.svclog
|
|
||||||
*.scc
|
|
||||||
|
|
||||||
# Chutzpah Test files
|
|
||||||
_Chutzpah*
|
|
||||||
|
|
||||||
# Visual C++ cache files
|
|
||||||
ipch/
|
|
||||||
*.aps
|
|
||||||
*.ncb
|
|
||||||
*.opendb
|
|
||||||
*.opensdf
|
|
||||||
*.sdf
|
|
||||||
*.cachefile
|
|
||||||
*.VC.db
|
|
||||||
*.VC.VC.opendb
|
|
||||||
|
|
||||||
# Visual Studio profiler
|
|
||||||
*.psess
|
|
||||||
*.vsp
|
|
||||||
*.vspx
|
|
||||||
*.sap
|
|
||||||
|
|
||||||
# Visual Studio Trace Files
|
|
||||||
*.e2e
|
|
||||||
|
|
||||||
# TFS 2012 Local Workspace
|
|
||||||
$tf/
|
|
||||||
|
|
||||||
# Guidance Automation Toolkit
|
|
||||||
*.gpState
|
|
||||||
|
|
||||||
# ReSharper is a .NET coding add-in
|
|
||||||
_ReSharper*/
|
|
||||||
*.[Rr]e[Ss]harper
|
|
||||||
*.DotSettings.user
|
|
||||||
|
|
||||||
# TeamCity is a build add-in
|
|
||||||
_TeamCity*
|
|
||||||
|
|
||||||
# DotCover is a Code Coverage Tool
|
|
||||||
*.dotCover
|
|
||||||
|
|
||||||
# AxoCover is a Code Coverage Tool
|
|
||||||
.axoCover/*
|
|
||||||
!.axoCover/settings.json
|
|
||||||
|
|
||||||
# Visual Studio code coverage results
|
|
||||||
*.coverage
|
|
||||||
*.coveragexml
|
|
||||||
|
|
||||||
# NCrunch
|
|
||||||
_NCrunch_*
|
|
||||||
.*crunch*.local.xml
|
|
||||||
nCrunchTemp_*
|
|
||||||
|
|
||||||
# MightyMoose
|
|
||||||
*.mm.*
|
|
||||||
AutoTest.Net/
|
|
||||||
|
|
||||||
# Web workbench (sass)
|
|
||||||
.sass-cache/
|
|
||||||
|
|
||||||
# Installshield output folder
|
|
||||||
[Ee]xpress/
|
|
||||||
|
|
||||||
# DocProject is a documentation generator add-in
|
|
||||||
DocProject/buildhelp/
|
|
||||||
DocProject/Help/*.HxT
|
|
||||||
DocProject/Help/*.HxC
|
|
||||||
DocProject/Help/*.hhc
|
|
||||||
DocProject/Help/*.hhk
|
|
||||||
DocProject/Help/*.hhp
|
|
||||||
DocProject/Help/Html2
|
|
||||||
DocProject/Help/html
|
|
||||||
|
|
||||||
# Click-Once directory
|
|
||||||
publish/
|
|
||||||
|
|
||||||
# Publish Web Output
|
|
||||||
*.[Pp]ublish.xml
|
|
||||||
*.azurePubxml
|
|
||||||
# Note: Comment the next line if you want to checkin your web deploy settings,
|
|
||||||
# but database connection strings (with potential passwords) will be unencrypted
|
|
||||||
*.pubxml
|
|
||||||
*.publishproj
|
|
||||||
|
|
||||||
# Microsoft Azure Web App publish settings. Comment the next line if you want to
|
|
||||||
# checkin your Azure Web App publish settings, but sensitive information contained
|
|
||||||
# in these scripts will be unencrypted
|
|
||||||
PublishScripts/
|
|
||||||
|
|
||||||
# NuGet Packages
|
|
||||||
*.nupkg
|
|
||||||
# NuGet Symbol Packages
|
|
||||||
*.snupkg
|
|
||||||
# The packages folder can be ignored because of Package Restore
|
|
||||||
**/[Pp]ackages/*
|
|
||||||
# except build/, which is used as an MSBuild target.
|
|
||||||
!**/[Pp]ackages/build/
|
|
||||||
# Uncomment if necessary however generally it will be regenerated when needed
|
|
||||||
#!**/[Pp]ackages/repositories.config
|
|
||||||
# NuGet v3's project.json files produces more ignorable files
|
|
||||||
*.nuget.props
|
|
||||||
*.nuget.targets
|
|
||||||
|
|
||||||
# Microsoft Azure Build Output
|
|
||||||
csx/
|
|
||||||
*.build.csdef
|
|
||||||
|
|
||||||
# Microsoft Azure Emulator
|
|
||||||
ecf/
|
|
||||||
rcf/
|
|
||||||
|
|
||||||
# Windows Store app package directories and files
|
|
||||||
AppPackages/
|
|
||||||
BundleArtifacts/
|
|
||||||
Package.StoreAssociation.xml
|
|
||||||
_pkginfo.txt
|
|
||||||
*.appx
|
|
||||||
*.appxbundle
|
|
||||||
*.appxupload
|
|
||||||
|
|
||||||
# Visual Studio cache files
|
|
||||||
# files ending in .cache can be ignored
|
|
||||||
*.[Cc]ache
|
|
||||||
# but keep track of directories ending in .cache
|
|
||||||
!?*.[Cc]ache/
|
|
||||||
|
|
||||||
# Others
|
|
||||||
ClientBin/
|
|
||||||
~$*
|
|
||||||
*~
|
|
||||||
*.dbmdl
|
|
||||||
*.dbproj.schemaview
|
|
||||||
*.jfm
|
|
||||||
*.pfx
|
|
||||||
*.publishsettings
|
|
||||||
orleans.codegen.cs
|
|
||||||
|
|
||||||
# Including strong name files can present a security risk
|
|
||||||
# (https://github.com/github/gitignore/pull/2483#issue-259490424)
|
|
||||||
#*.snk
|
|
||||||
|
|
||||||
# Since there are multiple workflows, uncomment next line to ignore bower_components
|
|
||||||
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
|
|
||||||
#bower_components/
|
|
||||||
|
|
||||||
# RIA/Silverlight projects
|
|
||||||
Generated_Code/
|
|
||||||
|
|
||||||
# Backup & report files from converting an old project file
|
|
||||||
# to a newer Visual Studio version. Backup files are not needed,
|
|
||||||
# because we have git ;-)
|
|
||||||
_UpgradeReport_Files/
|
|
||||||
Backup*/
|
|
||||||
UpgradeLog*.XML
|
|
||||||
UpgradeLog*.htm
|
|
||||||
ServiceFabricBackup/
|
|
||||||
*.rptproj.bak
|
|
||||||
|
|
||||||
# SQL Server files
|
|
||||||
*.mdf
|
|
||||||
*.ldf
|
|
||||||
*.ndf
|
|
||||||
|
|
||||||
# Business Intelligence projects
|
|
||||||
*.rdl.data
|
|
||||||
*.bim.layout
|
|
||||||
*.bim_*.settings
|
|
||||||
*.rptproj.rsuser
|
|
||||||
*- [Bb]ackup.rdl
|
|
||||||
*- [Bb]ackup ([0-9]).rdl
|
|
||||||
*- [Bb]ackup ([0-9][0-9]).rdl
|
|
||||||
|
|
||||||
# Microsoft Fakes
|
|
||||||
FakesAssemblies/
|
|
||||||
|
|
||||||
# GhostDoc plugin setting file
|
|
||||||
*.GhostDoc.xml
|
|
||||||
|
|
||||||
# Node.js Tools for Visual Studio
|
|
||||||
.ntvs_analysis.dat
|
|
||||||
node_modules/
|
|
||||||
|
|
||||||
# Visual Studio 6 build log
|
|
||||||
*.plg
|
|
||||||
|
|
||||||
# Visual Studio 6 workspace options file
|
|
||||||
*.opt
|
|
||||||
|
|
||||||
# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
|
|
||||||
*.vbw
|
|
||||||
|
|
||||||
# Visual Studio LightSwitch build output
|
|
||||||
**/*.HTMLClient/GeneratedArtifacts
|
|
||||||
**/*.DesktopClient/GeneratedArtifacts
|
|
||||||
**/*.DesktopClient/ModelManifest.xml
|
|
||||||
**/*.Server/GeneratedArtifacts
|
|
||||||
**/*.Server/ModelManifest.xml
|
|
||||||
_Pvt_Extensions
|
|
||||||
|
|
||||||
# Paket dependency manager
|
|
||||||
.paket/paket.exe
|
|
||||||
paket-files/
|
|
||||||
|
|
||||||
# FAKE - F# Make
|
|
||||||
.fake/
|
|
||||||
|
|
||||||
# CodeRush personal settings
|
|
||||||
.cr/personal
|
|
||||||
|
|
||||||
# Python Tools for Visual Studio (PTVS)
|
|
||||||
__pycache__/
|
|
||||||
*.pyc
|
|
||||||
|
|
||||||
# Cake - Uncomment if you are using it
|
|
||||||
# tools/**
|
|
||||||
# !tools/packages.config
|
|
||||||
|
|
||||||
# Tabs Studio
|
|
||||||
*.tss
|
|
||||||
|
|
||||||
# Telerik's JustMock configuration file
|
|
||||||
*.jmconfig
|
|
||||||
|
|
||||||
# BizTalk build output
|
|
||||||
*.btp.cs
|
|
||||||
*.btm.cs
|
|
||||||
*.odx.cs
|
|
||||||
*.xsd.cs
|
|
||||||
|
|
||||||
# OpenCover UI analysis results
|
|
||||||
OpenCover/
|
|
||||||
|
|
||||||
# Azure Stream Analytics local run output
|
|
||||||
ASALocalRun/
|
|
||||||
|
|
||||||
# MSBuild Binary and Structured Log
|
|
||||||
*.binlog
|
|
||||||
|
|
||||||
# NVidia Nsight GPU debugger configuration file
|
|
||||||
*.nvuser
|
|
||||||
|
|
||||||
# MFractors (Xamarin productivity tool) working folder
|
|
||||||
.mfractor/
|
|
||||||
|
|
||||||
# Local History for Visual Studio
|
|
||||||
.localhistory/
|
|
||||||
|
|
||||||
# BeatPulse healthcheck temp database
|
|
||||||
healthchecksdb
|
|
||||||
|
|
||||||
# Backup folder for Package Reference Convert tool in Visual Studio 2017
|
|
||||||
MigrationBackup/
|
|
||||||
|
|
||||||
# Ionide (cross platform F# VS Code tools) working folder
|
|
||||||
.ionide/
|
|
||||||
|
|
||||||
/vcproj/nsg/x64/Debug/nsg.Build.CppClean.log
|
|
||||||
/vcproj/test_recall/x64/Debug/test_recall.Build.CppClean.log
|
|
||||||
/vcproj/test_recall/test_recall.vcxproj.user
|
|
||||||
/.vs
|
|
||||||
/out/build/x64-Debug
|
|
||||||
cscope*
|
|
||||||
|
|
||||||
build/
|
|
||||||
build_linux/
|
|
||||||
!.github/actions/build
|
|
||||||
|
|
||||||
# jetbrains specific stuff
|
|
||||||
.idea/
|
|
||||||
cmake-build-debug/
|
|
||||||
|
|
||||||
#python extension module ignores
|
|
||||||
python/diskannpy.egg-info/
|
|
||||||
python/dist/
|
|
||||||
|
|
||||||
**/*.egg-info
|
|
||||||
wheelhouse/*
|
|
||||||
dist/*
|
|
||||||
venv*/**
|
|
||||||
*.swp
|
|
||||||
|
|
||||||
gperftools
|
|
||||||
|
|
||||||
# Rust
|
|
||||||
rust/target
|
|
||||||
|
|
||||||
python/src/*.so
|
|
||||||
|
|
||||||
compile_commands.json
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
[submodule "gperftools"]
|
|
||||||
path = gperftools
|
|
||||||
url = https://github.com/gperftools/gperftools.git
|
|
||||||
@@ -1,563 +0,0 @@
|
|||||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
# Licensed under the MIT license.
|
|
||||||
|
|
||||||
# Parameters:
|
|
||||||
#
|
|
||||||
# BOOST_ROOT:
|
|
||||||
# Specify root of the Boost library if Boost cannot be auto-detected. On Windows, a fallback to a
|
|
||||||
# downloaded nuget version will be used if Boost cannot be found.
|
|
||||||
#
|
|
||||||
# DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS:
|
|
||||||
# This is a work-in-progress feature, not completed yet. The core DiskANN library will be split into
|
|
||||||
# build-related and search-related functionality. In build-related functionality, when using tcmalloc,
|
|
||||||
# it's possible to release memory that's free but reserved by tcmalloc. Setting this to true enables
|
|
||||||
# such behavior.
|
|
||||||
# Contact for this feature: gopalrs.
|
|
||||||
|
|
||||||
|
|
||||||
# Some variables like MSVC are defined only after project(), so put that first.
|
|
||||||
cmake_minimum_required(VERSION 3.20)
|
|
||||||
project(diskann)
|
|
||||||
|
|
||||||
#Set option to use tcmalloc
|
|
||||||
option(USE_TCMALLOC "Use tcmalloc from gperftools" ON)
|
|
||||||
|
|
||||||
# set tcmalloc to false when on macos
|
|
||||||
if(APPLE)
|
|
||||||
set(USE_TCMALLOC OFF)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
option(PYBIND "Build with Python bindings" ON)
|
|
||||||
|
|
||||||
if(PYBIND)
|
|
||||||
# Find Python
|
|
||||||
find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED)
|
|
||||||
execute_process(
|
|
||||||
COMMAND "${Python_EXECUTABLE}" -c "import pybind11; print(pybind11.get_cmake_dir())"
|
|
||||||
OUTPUT_VARIABLE pybind11_DIR
|
|
||||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
||||||
)
|
|
||||||
find_package(pybind11 CONFIG REQUIRED)
|
|
||||||
|
|
||||||
message(STATUS "Python include dirs: ${Python_INCLUDE_DIRS}")
|
|
||||||
message(STATUS "Pybind11 include dirs: ${pybind11_INCLUDE_DIRS}")
|
|
||||||
|
|
||||||
# Add pybind11 include directories
|
|
||||||
include_directories(SYSTEM ${pybind11_INCLUDE_DIRS} ${Python_INCLUDE_DIRS})
|
|
||||||
|
|
||||||
# Add compilation definitions
|
|
||||||
add_definitions(-DPYBIND11_EMBEDDED)
|
|
||||||
|
|
||||||
# Set visibility flags
|
|
||||||
if(NOT MSVC)
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(CMAKE_STANDARD 17)
|
|
||||||
set(CMAKE_CXX_STANDARD 17)
|
|
||||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
|
||||||
|
|
||||||
# if(NOT MSVC)
|
|
||||||
# set(CMAKE_CXX_COMPILER g++)
|
|
||||||
# endif()
|
|
||||||
|
|
||||||
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake;${CMAKE_MODULE_PATH}")
|
|
||||||
|
|
||||||
# Install nuget packages for dependencies.
|
|
||||||
if (MSVC)
|
|
||||||
find_program(NUGET_EXE NAMES nuget)
|
|
||||||
|
|
||||||
if (NOT NUGET_EXE)
|
|
||||||
message(FATAL_ERROR "Cannot find nuget command line tool.\nPlease install it from e.g. https://www.nuget.org/downloads")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(DISKANN_MSVC_PACKAGES_CONFIG ${CMAKE_BINARY_DIR}/packages.config)
|
|
||||||
set(DISKANN_MSVC_PACKAGES ${CMAKE_BINARY_DIR}/packages)
|
|
||||||
|
|
||||||
message(STATUS "Invoking nuget to download Boost, OpenMP and MKL dependencies...")
|
|
||||||
configure_file(${PROJECT_SOURCE_DIR}/windows/packages.config.in ${DISKANN_MSVC_PACKAGES_CONFIG})
|
|
||||||
exec_program(${NUGET_EXE} ARGS install \"${DISKANN_MSVC_PACKAGES_CONFIG}\" -ExcludeVersion -OutputDirectory \"${DISKANN_MSVC_PACKAGES}\")
|
|
||||||
if (RESTAPI)
|
|
||||||
set(DISKANN_MSVC_RESTAPI_PACKAGES_CONFIG ${CMAKE_BINARY_DIR}/restapi/packages.config)
|
|
||||||
configure_file(${PROJECT_SOURCE_DIR}/windows/packages_restapi.config.in ${DISKANN_MSVC_RESTAPI_PACKAGES_CONFIG})
|
|
||||||
exec_program(${NUGET_EXE} ARGS install \"${DISKANN_MSVC_RESTAPI_PACKAGES_CONFIG}\" -ExcludeVersion -OutputDirectory \"${DISKANN_MSVC_PACKAGES}\")
|
|
||||||
endif()
|
|
||||||
message(STATUS "Finished setting up nuget dependencies")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
|
||||||
|
|
||||||
include(FetchContent)
|
|
||||||
|
|
||||||
if(USE_TCMALLOC)
|
|
||||||
FetchContent_Declare(
|
|
||||||
tcmalloc
|
|
||||||
GIT_REPOSITORY https://github.com/google/tcmalloc.git
|
|
||||||
GIT_TAG origin/master # or specify a particular version or commit
|
|
||||||
)
|
|
||||||
|
|
||||||
FetchContent_MakeAvailable(tcmalloc)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if(NOT PYBIND)
|
|
||||||
set(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS ON)
|
|
||||||
endif()
|
|
||||||
# It's necessary to include tcmalloc headers only if calling into MallocExtension interface.
|
|
||||||
# For using tcmalloc in DiskANN tools, it's enough to just link with tcmalloc.
|
|
||||||
if (DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS)
|
|
||||||
include_directories(${tcmalloc_SOURCE_DIR}/src)
|
|
||||||
if (MSVC)
|
|
||||||
include_directories(${tcmalloc_SOURCE_DIR}/src/windows)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
#OpenMP
if (MSVC)
    # Do not use find_package here since it would use VisualStudio's built-in OpenMP, but MKL libraries
    # refer to Intel's OpenMP.
    #
    # No extra settings are needed for compilation: it only needs /openmp flag which is set further below,
    # in the common MSVC compiler options block.
    include_directories(BEFORE "${DISKANN_MSVC_PACKAGES}/intelopenmp.devel.win/lib/native/include")
    link_libraries("${DISKANN_MSVC_PACKAGES}/intelopenmp.devel.win/lib/native/win-x64/libiomp5md.lib")

    set(OPENMP_WINDOWS_RUNTIME_FILES
        "${DISKANN_MSVC_PACKAGES}/intelopenmp.redist.win/runtimes/win-x64/native/libiomp5md.dll"
        "${DISKANN_MSVC_PACKAGES}/intelopenmp.redist.win/runtimes/win-x64/native/libiomp5md.pdb")
elseif(APPLE)
    # Locate Homebrew's libomp and configure OpenMP from it. This exact
    # sequence was previously duplicated verbatim in two branches below;
    # factored into one macro so the two paths cannot drift apart.
    macro(diskann_setup_brew_openmp)
        execute_process(
            COMMAND brew --prefix libomp
            OUTPUT_VARIABLE LIBOMP_ROOT
            OUTPUT_STRIP_TRAILING_WHITESPACE
        )
        set(OpenMP_ROOT "${LIBOMP_ROOT}")
        find_package(OpenMP)

        if (OPENMP_FOUND)
            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
            set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
            link_libraries(OpenMP::OpenMP_CXX)
        else()
            message(FATAL_ERROR "No OpenMP support")
        endif()
    endmacro()

    # Check if we're building Python bindings
    if(PYBIND)
        # First look for PyTorch's OpenMP to avoid loading two OpenMP runtimes
        # into the same Python process.
        execute_process(
            COMMAND ${Python_EXECUTABLE} -c "import os; import torch; print(os.path.join(os.path.dirname(torch.__file__), 'lib', 'libomp.dylib'))"
            RESULT_VARIABLE TORCH_PATH_RESULT
            OUTPUT_VARIABLE TORCH_LIBOMP_PATH
            OUTPUT_STRIP_TRAILING_WHITESPACE
            ERROR_QUIET
        )

        # Homebrew's libomp still supplies omp.h for compilation.
        execute_process(
            COMMAND brew --prefix libomp
            OUTPUT_VARIABLE LIBOMP_ROOT
            OUTPUT_STRIP_TRAILING_WHITESPACE
        )

        if(EXISTS "${TORCH_LIBOMP_PATH}")
            message(STATUS "Found PyTorch's libomp: ${TORCH_LIBOMP_PATH}")
            set(OpenMP_CXX_FLAGS "-Xclang -fopenmp")
            set(OpenMP_C_FLAGS "-Xclang -fopenmp")
            set(OpenMP_CXX_LIBRARIES "${TORCH_LIBOMP_PATH}")
            set(OpenMP_C_LIBRARIES "${TORCH_LIBOMP_PATH}")
            set(OpenMP_FOUND TRUE)

            include_directories(${LIBOMP_ROOT}/include)

            # Set compiler flags and link against PyTorch's runtime.
            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
            set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
            link_libraries("${TORCH_LIBOMP_PATH}")
        else()
            message(STATUS "No PyTorch's libomp found, falling back to normal OpenMP detection")
            diskann_setup_brew_openmp()
        endif()
    else()
        # Regular OpenMP setup for non-Python builds.
        diskann_setup_brew_openmp()
    endif()
else()
    # Linux: the toolchain's own OpenMP is fine.
    find_package(OpenMP)

    if (OPENMP_FOUND)
        set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
        set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
    else()
        message(FATAL_ERROR "No OpenMP support")
    endif()
endif()
# DiskANN core uses header-only libraries. Only DiskANN tools need program_options which has a linker library,
# but its size is small. Reduce number of dependent DLLs by linking statically.
if (MSVC)
    set(Boost_USE_STATIC_LIBS ON)
endif()

if(NOT MSVC)
    find_package(Boost COMPONENTS program_options)
endif()

# For Windows, fall back to nuget version if find_package didn't find it.
if (MSVC AND NOT Boost_FOUND)
    set(DISKANN_BOOST_INCLUDE "${DISKANN_MSVC_PACKAGES}/boost/lib/native/include")

    # Multi-threaded static release library.
    set(PROGRAM_OPTIONS_LIB_PATTERN "${DISKANN_MSVC_PACKAGES}/boost_program_options-vc${MSVC_TOOLSET_VERSION}/lib/native/libboost_program_options-vc${MSVC_TOOLSET_VERSION}-mt-x64-*.lib")
    file(GLOB DISKANN_BOOST_PROGRAM_OPTIONS_LIB ${PROGRAM_OPTIONS_LIB_PATTERN})

    # Multi-threaded static debug ("gd") library.
    set(PROGRAM_OPTIONS_DLIB_PATTERN "${DISKANN_MSVC_PACKAGES}/boost_program_options-vc${MSVC_TOOLSET_VERSION}/lib/native/libboost_program_options-vc${MSVC_TOOLSET_VERSION}-mt-gd-x64-*.lib")
    file(GLOB DISKANN_BOOST_PROGRAM_OPTIONS_DLIB ${PROGRAM_OPTIONS_DLIB_PATTERN})

    # file(GLOB) may match several versioned .lib files; the EXISTS checks below
    # require a single path, so keep only the first match.
    if (DISKANN_BOOST_PROGRAM_OPTIONS_LIB)
        list(GET DISKANN_BOOST_PROGRAM_OPTIONS_LIB 0 DISKANN_BOOST_PROGRAM_OPTIONS_LIB)
    endif()
    if (DISKANN_BOOST_PROGRAM_OPTIONS_DLIB)
        list(GET DISKANN_BOOST_PROGRAM_OPTIONS_DLIB 0 DISKANN_BOOST_PROGRAM_OPTIONS_DLIB)
    endif()

    if (EXISTS "${DISKANN_BOOST_INCLUDE}" AND EXISTS "${DISKANN_BOOST_PROGRAM_OPTIONS_LIB}" AND EXISTS "${DISKANN_BOOST_PROGRAM_OPTIONS_DLIB}")
        set(Boost_FOUND ON)
        set(Boost_INCLUDE_DIR ${DISKANN_BOOST_INCLUDE})
        add_library(Boost::program_options STATIC IMPORTED)
        set_target_properties(Boost::program_options PROPERTIES IMPORTED_LOCATION_RELEASE "${DISKANN_BOOST_PROGRAM_OPTIONS_LIB}")
        set_target_properties(Boost::program_options PROPERTIES IMPORTED_LOCATION_DEBUG "${DISKANN_BOOST_PROGRAM_OPTIONS_DLIB}")
        message(STATUS "Falling back to using Boost from the nuget package")
    else()
        message(WARNING "Couldn't find Boost. Was looking for ${DISKANN_BOOST_INCLUDE} and ${PROGRAM_OPTIONS_LIB_PATTERN}")
    endif()
endif()

if (NOT Boost_FOUND)
    message(FATAL_ERROR "Couldn't find Boost dependency")
endif()

include_directories(${Boost_INCLUDE_DIR})
#MKL Config
if (MSVC)
    # Only the DiskANN DLL and one of the tools need MKL libraries. Additionally, only a small part of MKL is used.
    # Given that and given that MKL DLLs are huge, use static linking to end up with no MKL DLL dependencies and with
    # significantly smaller disk footprint.
    #
    # The compile options are not modified as there's already an unconditional -DMKL_ILP64 define below
    # for all architectures, which is all that's needed.
    set(DISKANN_MKL_INCLUDE_DIRECTORIES "${DISKANN_MSVC_PACKAGES}/intelmkl.static.win-x64/lib/native/include")
    set(DISKANN_MKL_LIB_PATH "${DISKANN_MSVC_PACKAGES}/intelmkl.static.win-x64/lib/native/win-x64")

    set(DISKANN_MKL_LINK_LIBRARIES
        "${DISKANN_MKL_LIB_PATH}/mkl_intel_ilp64.lib"
        "${DISKANN_MKL_LIB_PATH}/mkl_core.lib"
        "${DISKANN_MKL_LIB_PATH}/mkl_intel_thread.lib")
elseif(APPLE)
    # No MKL on non-Intel devices; use Apple's Accelerate framework instead.
    find_library(ACCELERATE_LIBRARY Accelerate)
    message(STATUS "Found Accelerate (${ACCELERATE_LIBRARY})")
    set(DISKANN_ACCEL_LINK_OPTIONS ${ACCELERATE_LIBRARY})
    add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
    # Expected paths for manual Intel MKL installs.
    set(POSSIBLE_OMP_PATHS "/opt/intel/oneapi/compiler/2025.0/lib/libiomp5.so;/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin/libiomp5.so;/usr/lib/x86_64-linux-gnu/libiomp5.so;/opt/intel/lib/intel64_lin/libiomp5.so")
    foreach(POSSIBLE_OMP_PATH ${POSSIBLE_OMP_PATHS})
        if (EXISTS ${POSSIBLE_OMP_PATH})
            get_filename_component(OMP_PATH ${POSSIBLE_OMP_PATH} DIRECTORY)
        endif()
    endforeach()

    if(NOT OMP_PATH)
        message(FATAL_ERROR "Could not find Intel OMP in standard locations; use -DOMP_PATH to specify the install location for your environment")
    endif()
    link_directories(${OMP_PATH})

    set(POSSIBLE_MKL_LIB_PATHS "/opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_core.so;/usr/lib/x86_64-linux-gnu/libmkl_core.so;/opt/intel/mkl/lib/intel64/libmkl_core.so")
    foreach(POSSIBLE_MKL_LIB_PATH ${POSSIBLE_MKL_LIB_PATHS})
        if (EXISTS ${POSSIBLE_MKL_LIB_PATH})
            get_filename_component(MKL_PATH ${POSSIBLE_MKL_LIB_PATH} DIRECTORY)
        endif()
    endforeach()

    set(POSSIBLE_MKL_INCLUDE_PATHS "/opt/intel/oneapi/mkl/latest/include;/usr/include/mkl;/opt/intel/mkl/include/;")
    foreach(POSSIBLE_MKL_INCLUDE_PATH ${POSSIBLE_MKL_INCLUDE_PATHS})
        if (EXISTS ${POSSIBLE_MKL_INCLUDE_PATH})
            set(MKL_INCLUDE_PATH ${POSSIBLE_MKL_INCLUDE_PATH})
        endif()
    endforeach()
    if(NOT MKL_PATH)
        message(FATAL_ERROR "Could not find Intel MKL in standard locations; use -DMKL_PATH to specify the install location for your environment")
    elseif(NOT MKL_INCLUDE_PATH)
        message(FATAL_ERROR "Could not find Intel MKL in standard locations; use -DMKL_INCLUDE_PATH to specify the install location for headers for your environment")
    endif()
    if (EXISTS ${MKL_PATH}/libmkl_def.so.2)
        set(MKL_DEF_SO ${MKL_PATH}/libmkl_def.so.2)
    elseif(EXISTS ${MKL_PATH}/libmkl_def.so)
        set(MKL_DEF_SO ${MKL_PATH}/libmkl_def.so)
    else()
        message(FATAL_ERROR "Despite finding MKL, libmkl_def.so was not found in expected locations.")
    endif()
    link_directories(${MKL_PATH})
    include_directories(${MKL_INCLUDE_PATH})

    # Compile flags and link libraries (gcc/g++ only).
    if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
        add_compile_options(-m64)
        # -Wl,--no-as-needed is a linker flag; it was previously passed via
        # add_compile_options, where the driver never forwards it to the link
        # step. Pass it through add_link_options so it actually takes effect.
        add_link_options(-Wl,--no-as-needed)
    endif()
    if (NOT PYBIND)
        link_libraries(mkl_intel_ilp64 mkl_intel_thread mkl_core iomp5 pthread m dl)
    else()
        # Static linking for python so as to minimize customer dependency issues.
        if (CMAKE_BUILD_TYPE STREQUAL "Debug")
            # In debug mode, use dynamic linking to ensure all symbols are available
            link_libraries(mkl_intel_ilp64 mkl_intel_thread mkl_core ${MKL_DEF_SO} iomp5 pthread m dl)
        else()
            # In release mode, use static linking to minimize dependencies
            link_libraries(
                ${MKL_PATH}/libmkl_intel_ilp64.a
                ${MKL_PATH}/libmkl_intel_thread.a
                ${MKL_PATH}/libmkl_core.a
                ${MKL_DEF_SO}
                iomp5
                pthread
                m
                dl
            )
        endif()
    endif()

    add_definitions(-DMKL_ILP64)
endif()
# tcmalloc wiring. The DiskANN tools are always linked to tcmalloc; on Windows
# they additionally force-include the _tcmalloc symbol to enable it.
#
# The DLL itself only links against tcmalloc when
# DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS is enabled.
if(USE_TCMALLOC)
    if (MSVC)
        if (NOT EXISTS "${PROJECT_SOURCE_DIR}/gperftools/gperftools.sln")
            message(FATAL_ERROR "The gperftools submodule was not found. "
                    "Please check-out git submodules by doing 'git submodule init' followed by 'git submodule update'")
        endif()

        set(TCMALLOC_LINK_LIBRARY "${PROJECT_SOURCE_DIR}/gperftools/x64/Release-Patch/libtcmalloc_minimal.lib")
        set(TCMALLOC_WINDOWS_RUNTIME_FILES
            "${PROJECT_SOURCE_DIR}/gperftools/x64/Release-Patch/libtcmalloc_minimal.dll"
            "${PROJECT_SOURCE_DIR}/gperftools/x64/Release-Patch/libtcmalloc_minimal.pdb")

        # Teach CMake how to produce the tcmalloc import library from the submodule.
        add_custom_target(build_libtcmalloc_minimal DEPENDS ${TCMALLOC_LINK_LIBRARY})
        add_custom_command(OUTPUT ${TCMALLOC_LINK_LIBRARY}
                           COMMAND ${CMAKE_VS_MSBUILD_COMMAND} gperftools.sln /m /nologo
                                   /t:libtcmalloc_minimal /p:Configuration="Release-Patch"
                                   /property:Platform="x64"
                                   /p:PlatformToolset=v${MSVC_TOOLSET_VERSION}
                                   /p:WindowsTargetPlatformVersion=${CMAKE_VS_WINDOWS_TARGET_PLATFORM_VERSION}
                           WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/gperftools)

        # Two imported views of the same .lib: executables need the extra
        # /INCLUDE:_tcmalloc link option, the DLL does not.
        add_library(libtcmalloc_minimal_for_exe STATIC IMPORTED)
        add_library(libtcmalloc_minimal_for_dll STATIC IMPORTED)

        set_target_properties(libtcmalloc_minimal_for_dll PROPERTIES
                              IMPORTED_LOCATION "${TCMALLOC_LINK_LIBRARY}")

        set_target_properties(libtcmalloc_minimal_for_exe PROPERTIES
                              IMPORTED_LOCATION "${TCMALLOC_LINK_LIBRARY}"
                              INTERFACE_LINK_OPTIONS /INCLUDE:_tcmalloc)

        # Ensure libtcmalloc_minimal is built before anything links against it.
        add_dependencies(libtcmalloc_minimal_for_dll build_libtcmalloc_minimal)
        add_dependencies(libtcmalloc_minimal_for_exe build_libtcmalloc_minimal)

        set(DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS libtcmalloc_minimal_for_exe)
    elseif(APPLE) # ! Inherited from #474, not been adjusted for TCMalloc Removal
        # Use the Homebrew-installed gperftools.
        execute_process(
            COMMAND brew --prefix gperftools
            OUTPUT_VARIABLE GPERFTOOLS_PREFIX
            OUTPUT_STRIP_TRAILING_WHITESPACE
        )
        set(DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS "-L${GPERFTOOLS_PREFIX}/lib -ltcmalloc")
    elseif(NOT PYBIND)
        set(DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS "-ltcmalloc")
    endif()

    if (DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS)
        add_definitions(-DRELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS)

        if (MSVC)
            set(DISKANN_DLL_TCMALLOC_LINK_OPTIONS libtcmalloc_minimal_for_dll)
        endif()
    endif()
endif()
# Linux-only: DiskANN uses libaio for asynchronous disk I/O.
if (NOT MSVC AND NOT APPLE)
    set(DISKANN_ASYNC_LIB aio)
endif()
#Main compiler/linker settings
if(MSVC)
    #language options
    add_compile_options(/permissive- /openmp:experimental /Zc:twoPhase- /Zc:inline /WX- /std:c++17 /Gd /W3 /MP /Zi /FC /nologo)
    #code generation options
    add_compile_options(/arch:AVX2 /fp:fast /fp:except- /EHsc /GS- /Gy)
    #optimization options
    add_compile_options(/Ot /Oy /Oi)
    #preprocessor definitions
    add_definitions(-DUSE_AVX2 -DUSE_ACCELERATED_PQ -D_WINDOWS -DNOMINMAX -DUNICODE)
    # Linker options. Exclude VCOMP/VCOMPD.LIB which contain VisualStudio's version of OpenMP.
    # MKL was linked against Intel's OpenMP and depends on the corresponding DLL.
    add_link_options(/NODEFAULTLIB:VCOMP.LIB /NODEFAULTLIB:VCOMPD.LIB /DEBUG:FULL /OPT:REF /OPT:ICF)

    set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/x64/Debug)
    set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/x64/Debug)
    set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/x64/Debug)

    set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release)
    set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release)
    set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release)
elseif(APPLE)
    set(ENV{TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD} 500000000000)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ftree-vectorize -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free -Xclang -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors -Wno-inconsistent-missing-override -Wno-return-type")
    set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG")
    # -DNDEBUG/-ftree-vectorize apply to every release build. -Ofast is added
    # only for non-Python builds below: it was previously appended here
    # unconditionally, contradicting the PYBIND branch's own comment that
    # -Ofast is not supported in a python extension module.
    set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -ftree-vectorize")
    if (NOT PYBIND)
        set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -Ofast")
        if (NOT PORTABLE)
            set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -mtune=native")
        endif()
    else()
        # -Ofast is not supported in a python extension module
        set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -fPIC")
    endif()
else()
    set(ENV{TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD} 500000000000)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2 -mfma -msse2 -ftree-vectorize -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors -DUSE_AVX2 -fPIC")
    if(USE_TCMALLOC)
        # Let tcmalloc intercept allocation without builtin shortcuts.
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free")
    endif()
    set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG")
    if (NOT PYBIND)
        set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -Ofast")
        if (NOT PORTABLE)
            set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -march=native -mtune=native")
        endif()
    else()
        # -Ofast is not supported in a python extension module
        set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG")
    endif()
endif()
# Core library first, then the command-line apps (skipped for Python-only builds).
add_subdirectory(src)
if (NOT PYBIND)
    add_subdirectory(apps)
    add_subdirectory(apps/utils)
endif()

# Unit tests are opt-in via -DUNIT_TEST=True.
if (UNIT_TEST)
    enable_testing()
    add_subdirectory(tests)
endif()
if (MSVC)
    # "opened it" was a typo in this user-facing message.
    message(STATUS "The ${PROJECT_NAME}.sln has been created, open it from VisualStudio to build Release or Debug configurations.\n"
            "Alternatively, use MSBuild to build:\n\n"
            "msbuild.exe ${PROJECT_NAME}.sln /m /nologo /t:Build /p:Configuration=\"Release\" /property:Platform=\"x64\"\n")
endif()

if (RESTAPI)
    if (MSVC)
        # Use the nuget-restored cpprestsdk package. (The old comment here about
        # "apt packaged intel mkl installs" was a copy-paste leftover.)
        set(DISKANN_CPPRESTSDK "${DISKANN_MSVC_PACKAGES}/cpprestsdk.v142/build/native")
        link_libraries("${DISKANN_CPPRESTSDK}/x64/lib/cpprest142_2_10.lib")
        include_directories("${DISKANN_CPPRESTSDK}/include")
    endif()
    add_subdirectory(apps/restapi)
endif()

include(clang-format.cmake)

if(PYBIND)
    add_subdirectory(python)

    # Install the compiled extension module where the leann backend expects it.
    install(TARGETS _diskannpy
        DESTINATION leann_backend_diskann
        COMPONENT python_modules
    )
endif()
###############################################################################
# PROTOBUF SECTION - Corrected to use CONFIG mode explicitly
###############################################################################
# Dynamic protobuf libraries; zlib is a transitive requirement of protobuf.
set(Protobuf_USE_STATIC_LIBS OFF)

find_package(ZLIB REQUIRED)
find_package(Protobuf REQUIRED)

message(STATUS "Protobuf found: ${Protobuf_VERSION}")
message(STATUS "Protobuf include dirs: ${Protobuf_INCLUDE_DIRS}")
message(STATUS "Protobuf libraries: ${Protobuf_LIBRARIES}")
message(STATUS "Protobuf protoc executable: ${Protobuf_PROTOC_EXECUTABLE}")

include_directories(${Protobuf_INCLUDE_DIRS})

# Generate C++ sources from the shared embedding schema one directory up.
set(PROTO_FILE "${CMAKE_CURRENT_SOURCE_DIR}/../embedding.proto")
protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS ${PROTO_FILE})
set(generated_proto_sources ${PROTO_SRCS})

add_library(proto_embeddings STATIC ${generated_proto_sources})
target_link_libraries(proto_embeddings PUBLIC protobuf::libprotobuf)
target_include_directories(proto_embeddings PUBLIC
    ${CMAKE_CURRENT_BINARY_DIR}
    ${Protobuf_INCLUDE_DIRS}
)

# Wire the generated messages into both the shared (diskann) and static
# (diskann_s) DiskANN targets identically.
foreach(diskann_proto_target diskann diskann_s)
    target_link_libraries(${diskann_proto_target} PRIVATE proto_embeddings protobuf::libprotobuf)
    target_include_directories(${diskann_proto_target} PRIVATE
        ${CMAKE_CURRENT_BINARY_DIR}
        ${Protobuf_INCLUDE_DIRS}
    )
endforeach()
###############################################################################
# ZEROMQ SECTION - REQUIRED
###############################################################################

find_package(ZeroMQ QUIET)
if(NOT ZeroMQ_FOUND)
    # No CMake config package available; probe for the header and library directly.
    find_path(ZeroMQ_INCLUDE_DIR zmq.h)
    find_library(ZeroMQ_LIBRARY zmq)
    if(ZeroMQ_INCLUDE_DIR AND ZeroMQ_LIBRARY)
        set(ZeroMQ_FOUND TRUE)
    endif()
endif()

if(ZeroMQ_FOUND)
    message(STATUS "Found ZeroMQ: ${ZeroMQ_LIBRARY}")
    include_directories(${ZeroMQ_INCLUDE_DIR})
    target_link_libraries(diskann PRIVATE ${ZeroMQ_LIBRARY})
    target_link_libraries(diskann_s PRIVATE ${ZeroMQ_LIBRARY})
    add_definitions(-DUSE_ZEROMQ)
else()
    message(FATAL_ERROR "ZeroMQ is required but not found. Please install ZeroMQ and try again.")
endif()

# These calls previously used the plain target_link_libraries() signature;
# CMake forbids mixing the plain and keyword (PRIVATE/PUBLIC) signatures on the
# same target, and the keyword form is already used for diskann/diskann_s
# above, so the plain calls were a hard configure error. Use PRIVATE.
target_link_libraries(diskann PRIVATE ${PYBIND11_LIBRARIES})
target_link_libraries(diskann_s PRIVATE ${PYBIND11_LIBRARIES})
@@ -1,28 +0,0 @@
|
|||||||
{
|
|
||||||
"configurations": [
|
|
||||||
{
|
|
||||||
"name": "x64-Release",
|
|
||||||
"generator": "Ninja",
|
|
||||||
"configurationType": "Release",
|
|
||||||
"inheritEnvironments": [ "msvc_x64" ],
|
|
||||||
"buildRoot": "${projectDir}\\out\\build\\${name}",
|
|
||||||
"installRoot": "${projectDir}\\out\\install\\${name}",
|
|
||||||
"cmakeCommandArgs": "",
|
|
||||||
"buildCommandArgs": "",
|
|
||||||
"ctestCommandArgs": ""
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "WSL-GCC-Release",
|
|
||||||
"generator": "Ninja",
|
|
||||||
"configurationType": "RelWithDebInfo",
|
|
||||||
"buildRoot": "${projectDir}\\out\\build\\${name}",
|
|
||||||
"installRoot": "${projectDir}\\out\\install\\${name}",
|
|
||||||
"cmakeExecutable": "cmake",
|
|
||||||
"cmakeCommandArgs": "",
|
|
||||||
"buildCommandArgs": "",
|
|
||||||
"ctestCommandArgs": "",
|
|
||||||
"inheritEnvironments": [ "linux_x64" ],
|
|
||||||
"wslPath": "${defaultWSLPath}"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
# Microsoft Open Source Code of Conduct
|
|
||||||
|
|
||||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
|
||||||
|
|
||||||
Resources:
|
|
||||||
|
|
||||||
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
|
|
||||||
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
|
||||||
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
# Contributing
|
|
||||||
|
|
||||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
|
||||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
|
||||||
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
|
||||||
|
|
||||||
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
|
||||||
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
|
||||||
provided by the bot. You will only need to do this once across all repos using our CLA.
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
|
|
||||||
|
|
||||||
FROM ubuntu:jammy
|
|
||||||
|
|
||||||
RUN apt update
|
|
||||||
RUN apt install -y software-properties-common
|
|
||||||
RUN add-apt-repository -y ppa:git-core/ppa
|
|
||||||
RUN apt update
|
|
||||||
RUN DEBIAN_FRONTEND=noninteractive apt install -y git make cmake g++ libaio-dev libgoogle-perftools-dev libunwind-dev clang-format libboost-dev libboost-program-options-dev libmkl-full-dev libcpprest-dev python3.10
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
RUN git clone https://github.com/microsoft/DiskANN.git
|
|
||||||
WORKDIR /app/DiskANN
|
|
||||||
RUN mkdir build
|
|
||||||
RUN cmake -S . -B build -DCMAKE_BUILD_TYPE=Release
|
|
||||||
RUN cmake --build build -- -j
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
#Copyright(c) Microsoft Corporation.All rights reserved.
|
|
||||||
#Licensed under the MIT license.
|
|
||||||
|
|
||||||
FROM ubuntu:jammy
|
|
||||||
|
|
||||||
RUN apt update
|
|
||||||
RUN apt install -y software-properties-common
|
|
||||||
RUN add-apt-repository -y ppa:git-core/ppa
|
|
||||||
RUN apt update
|
|
||||||
RUN DEBIAN_FRONTEND=noninteractive apt install -y git make cmake g++ libaio-dev libgoogle-perftools-dev libunwind-dev clang-format libboost-dev libboost-program-options-dev libboost-test-dev libmkl-full-dev libcpprest-dev python3.10
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
RUN git clone https://github.com/microsoft/DiskANN.git
|
|
||||||
WORKDIR /app/DiskANN
|
|
||||||
RUN mkdir build
|
|
||||||
RUN cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DUNIT_TEST=True
|
|
||||||
RUN cmake --build build -- -j
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
DiskANN
|
|
||||||
|
|
||||||
MIT License
|
|
||||||
|
|
||||||
Copyright (c) Microsoft Corporation.
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
include MANIFEST.in
|
|
||||||
include *.txt
|
|
||||||
include *.md
|
|
||||||
include setup.py
|
|
||||||
include pyproject.toml
|
|
||||||
include *.cmake
|
|
||||||
recursive-include gperftools *
|
|
||||||
recursive-include include *
|
|
||||||
recursive-include python *
|
|
||||||
recursive-include windows *
|
|
||||||
prune python/tests
|
|
||||||
recursive-include src *
|
|
||||||
@@ -1,135 +0,0 @@
|
|||||||
# DiskANN
|
|
||||||
|
|
||||||
[](https://github.com/microsoft/DiskANN/actions/workflows/push-test.yml)
|
|
||||||
[](https://pypi.org/project/diskannpy/)
|
|
||||||
[](https://pepy.tech/project/diskannpy)
|
|
||||||
[](https://opensource.org/licenses/MIT)
|
|
||||||
|
|
||||||
[](https://papers.nips.cc/paper/9527-rand-nsg-fast-accurate-billion-point-nearest-neighbor-search-on-a-single-node.pdf)
|
|
||||||
[](https://arxiv.org/abs/2105.09613)
|
|
||||||
[](https://harsha-simhadri.org/pubs/Filtered-DiskANN23.pdf)
|
|
||||||
|
|
||||||
|
|
||||||
DiskANN is a suite of scalable, accurate and cost-effective approximate nearest neighbor search algorithms for large-scale vector search that support real-time changes and simple filters.
|
|
||||||
This code is based on ideas from the [DiskANN](https://papers.nips.cc/paper/9527-rand-nsg-fast-accurate-billion-point-nearest-neighbor-search-on-a-single-node.pdf), [Fresh-DiskANN](https://arxiv.org/abs/2105.09613) and the [Filtered-DiskANN](https://harsha-simhadri.org/pubs/Filtered-DiskANN23.pdf) papers with further improvements.
|
|
||||||
This code forked off from [code for NSG](https://github.com/ZJULearning/nsg) algorithm.
|
|
||||||
|
|
||||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
|
||||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
|
||||||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
|
||||||
|
|
||||||
See [guidelines](CONTRIBUTING.md) for contributing to this project.
|
|
||||||
|
|
||||||
## Linux build:
|
|
||||||
|
|
||||||
Install the following packages through apt-get
|
|
||||||
|
|
||||||
```bash
|
|
||||||
sudo apt install make cmake g++ libaio-dev libgoogle-perftools-dev clang-format libboost-all-dev
|
|
||||||
```
|
|
||||||
|
|
||||||
### Install Intel MKL
|
|
||||||
#### Ubuntu 20.04 or newer
|
|
||||||
```bash
|
|
||||||
sudo apt install libmkl-full-dev
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Earlier versions of Ubuntu
|
|
||||||
Install Intel MKL either by downloading the [oneAPI MKL installer](https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl.html) or using [apt](https://software.intel.com/en-us/articles/installing-intel-free-libs-and-python-apt-repo) (we tested with build 2019.4-070 and 2022.1.2.146).
|
|
||||||
|
|
||||||
```
|
|
||||||
# OneAPI MKL Installer
|
|
||||||
wget https://registrationcenter-download.intel.com/akdlm/irc_nas/18487/l_BaseKit_p_2022.1.2.146.sh
|
|
||||||
sudo sh l_BaseKit_p_2022.1.2.146.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
|
||||||
```
|
|
||||||
|
|
||||||
### Build
|
|
||||||
```bash
|
|
||||||
mkdir build && cd build && cmake -DCMAKE_BUILD_TYPE=Release .. && make -j
|
|
||||||
```
|
|
||||||
|
|
||||||
## Windows build:
|
|
||||||
|
|
||||||
The Windows version has been tested with Enterprise editions of Visual Studio 2022, 2019 and 2017. It should work with the Community and Professional editions as well without any changes.
|
|
||||||
|
|
||||||
**Prerequisites:**
|
|
||||||
|
|
||||||
* CMake 3.15+ (available in VisualStudio 2019+ or from https://cmake.org)
|
|
||||||
* NuGet.exe (install from https://www.nuget.org/downloads)
|
|
||||||
* The build script will use NuGet to get MKL, OpenMP and Boost packages.
|
|
||||||
* DiskANN git repository checked out together with submodules. To check out submodules after git clone:
|
|
||||||
```
|
|
||||||
git submodule init
|
|
||||||
git submodule update
|
|
||||||
```
|
|
||||||
|
|
||||||
* Environment variables:
|
|
||||||
* [optional] If you would like to override the Boost library listed in windows/packages.config.in, set BOOST_ROOT to your Boost folder.
|
|
||||||
|
|
||||||
**Build steps:**
|
|
||||||
* Open the "x64 Native Tools Command Prompt for VS 2019" (or corresponding version) and change to DiskANN folder
|
|
||||||
* Create a "build" directory inside it
|
|
||||||
* Change to the "build" directory and run
|
|
||||||
```
|
|
||||||
cmake ..
|
|
||||||
```
|
|
||||||
OR for Visual Studio 2017 and earlier:
|
|
||||||
```
|
|
||||||
<full-path-to-installed-cmake>\cmake ..
|
|
||||||
```
|
|
||||||
**This will create a diskann.sln solution**. Now you can:
|
|
||||||
|
|
||||||
- Open it from VisualStudio and build either Release or Debug configuration.
|
|
||||||
- `<full-path-to-installed-cmake>\cmake --build build`
|
|
||||||
- Use MSBuild:
|
|
||||||
```
|
|
||||||
msbuild.exe diskann.sln /m /nologo /t:Build /p:Configuration="Release" /property:Platform="x64"
|
|
||||||
```
|
|
||||||
|
|
||||||
* This will also build gperftools submodule for libtcmalloc_minimal dependency.
|
|
||||||
* Generated binaries are stored in the x64/Release or x64/Debug directories.
|
|
||||||
|
|
||||||
## macOS Build
|
|
||||||
|
|
||||||
### Prerequisites
|
|
||||||
* Apple Silicon. The code should still work on Intel-based Macs, but there are no guarantees.
|
|
||||||
* macOS >= 12.0
|
|
||||||
* XCode Command Line Tools (install with `xcode-select --install`)
|
|
||||||
* [homebrew](https://brew.sh/)
|
|
||||||
|
|
||||||
### Install Required Packages
|
|
||||||
```zsh
|
|
||||||
brew install cmake
|
|
||||||
brew install boost
|
|
||||||
brew install gperftools
|
|
||||||
brew install libomp
|
|
||||||
```
|
|
||||||
|
|
||||||
### Build DiskANN
|
|
||||||
```zsh
|
|
||||||
# same as ubuntu instructions
|
|
||||||
mkdir build && cd build && cmake -DCMAKE_BUILD_TYPE=Release .. && make -j
|
|
||||||
```
|
|
||||||
|
|
||||||
## Usage:
|
|
||||||
|
|
||||||
Please see the following pages on using the compiled code:
|
|
||||||
|
|
||||||
- [Commandline interface for building and search SSD based indices](workflows/SSD_index.md)
|
|
||||||
- [Commandline interface for building and search in memory indices](workflows/in_memory_index.md)
|
|
||||||
- [Commandline examples for using in-memory streaming indices](workflows/dynamic_index.md)
|
|
||||||
- [Commandline interface for building and search in memory indices with label data and filters](workflows/filtered_in_memory.md)
|
|
||||||
- [Commandline interface for building and search SSD based indices with label data and filters](workflows/filtered_ssd_index.md)
|
|
||||||
- [diskannpy - DiskANN as a python extension module](python/README.md)
|
|
||||||
|
|
||||||
Please cite this software in your work as:
|
|
||||||
|
|
||||||
```
|
|
||||||
@misc{diskann-github,
|
|
||||||
author = {{Simhadri, Harsha Vardhan and Krishnaswamy, Ravishankar and Srinivasa, Gopal and Subramanya, Suhas Jayaram and Antonijevic, Andrija and Pryce, Dax and Kaczynski, David and Williams, Shane and Gollapudi, Siddarth and Sivashankar, Varun and Karia, Neel and Singh, Aditi and Jaiswal, Shikhar and Mahapatro, Neelam and Adams, Philip and Tower, Bryan and Patel, Yash}},
|
|
||||||
title = {{DiskANN: Graph-structured Indices for Scalable, Fast, Fresh and Filtered Approximate Nearest Neighbor Search}},
|
|
||||||
url = {https://github.com/Microsoft/DiskANN},
|
|
||||||
version = {0.6.1},
|
|
||||||
year = {2023}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->
|
|
||||||
|
|
||||||
## Security
|
|
||||||
|
|
||||||
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
|
|
||||||
|
|
||||||
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
|
|
||||||
|
|
||||||
## Reporting Security Issues
|
|
||||||
|
|
||||||
**Please do not report security vulnerabilities through public GitHub issues.**
|
|
||||||
|
|
||||||
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
|
|
||||||
|
|
||||||
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
|
|
||||||
|
|
||||||
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
|
|
||||||
|
|
||||||
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
|
|
||||||
|
|
||||||
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
|
|
||||||
* Full paths of source file(s) related to the manifestation of the issue
|
|
||||||
* The location of the affected source code (tag/branch/commit or direct URL)
|
|
||||||
* Any special configuration required to reproduce the issue
|
|
||||||
* Step-by-step instructions to reproduce the issue
|
|
||||||
* Proof-of-concept or exploit code (if possible)
|
|
||||||
* Impact of the issue, including how an attacker might exploit the issue
|
|
||||||
|
|
||||||
This information will help us triage your report more quickly.
|
|
||||||
|
|
||||||
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
|
|
||||||
|
|
||||||
## Preferred Languages
|
|
||||||
|
|
||||||
We prefer all communications to be in English.
|
|
||||||
|
|
||||||
## Policy
|
|
||||||
|
|
||||||
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
|
|
||||||
|
|
||||||
<!-- END MICROSOFT SECURITY.MD BLOCK -->
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
# Licensed under the MIT license.
|
|
||||||
|
|
||||||
# Build configuration for the DiskANN command-line tools in this directory.
# All tools compile as C++17.
set(CMAKE_CXX_STANDARD 17)
# Promote compiler warnings to errors so the tool sources stay warning-clean.
set(CMAKE_COMPILE_WARNING_AS_ERROR ON)

# Builder for static in-memory indices.
add_executable(build_memory_index build_memory_index.cpp)
target_link_libraries(build_memory_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)

# Builder for "stitched" filtered indices (see build_stitched_index.cpp).
add_executable(build_stitched_index build_stitched_index.cpp)
target_link_libraries(build_stitched_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)

# Search driver for in-memory indices; needs the async I/O helper library.
add_executable(search_memory_index search_memory_index.cpp)
target_link_libraries(search_memory_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)

# Builder for disk-based (SSD) indices.
add_executable(build_disk_index build_disk_index.cpp)
target_link_libraries(build_disk_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} ${DISKANN_ASYNC_LIB} Boost::program_options)

# k-NN search driver for disk-based indices.
add_executable(search_disk_index search_disk_index.cpp)
target_link_libraries(search_disk_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)

# Range-search driver for disk-based indices.
add_executable(range_search_disk_index range_search_disk_index.cpp)
target_link_libraries(range_search_disk_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)

# Test harnesses for streaming / insert-delete-consolidate scenarios.
add_executable(test_streaming_scenario test_streaming_scenario.cpp)
target_link_libraries(test_streaming_scenario ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)

add_executable(test_insert_deletes_consolidate test_insert_deletes_consolidate.cpp)
target_link_libraries(test_insert_deletes_consolidate ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)

# Install the tool binaries on non-MSVC platforms only (the Windows build
# places binaries under x64/<config> instead).
if (NOT MSVC)
    install(TARGETS build_memory_index
                    build_stitched_index
                    search_memory_index
                    build_disk_index
                    search_disk_index
                    range_search_disk_index
                    test_streaming_scenario
                    test_insert_deletes_consolidate
            RUNTIME
    )
endif()
|
|
||||||
@@ -1,191 +0,0 @@
|
|||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
// Licensed under the MIT license.
|
|
||||||
|
|
||||||
#include <omp.h>
|
|
||||||
#include <boost/program_options.hpp>
|
|
||||||
|
|
||||||
#include "utils.h"
|
|
||||||
#include "disk_utils.h"
|
|
||||||
#include "math_utils.h"
|
|
||||||
#include "index.h"
|
|
||||||
#include "partition.h"
|
|
||||||
#include "program_options_utils.hpp"
|
|
||||||
|
|
||||||
namespace po = boost::program_options;
|
|
||||||
|
|
||||||
int main(int argc, char **argv)
|
|
||||||
{
|
|
||||||
std::string data_type, dist_fn, data_path, index_path_prefix, codebook_prefix, label_file, universal_label,
|
|
||||||
label_type;
|
|
||||||
uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold;
|
|
||||||
float B, M;
|
|
||||||
bool append_reorder_data = false;
|
|
||||||
bool use_opq = false;
|
|
||||||
|
|
||||||
po::options_description desc{
|
|
||||||
program_options_utils::make_program_description("build_disk_index", "Build a disk-based index.")};
|
|
||||||
try
|
|
||||||
{
|
|
||||||
desc.add_options()("help,h", "Print information on arguments");
|
|
||||||
|
|
||||||
// Required parameters
|
|
||||||
po::options_description required_configs("Required");
|
|
||||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
|
||||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
|
||||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
|
||||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
|
||||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
|
||||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
|
||||||
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
|
|
||||||
program_options_utils::INPUT_DATA_PATH);
|
|
||||||
required_configs.add_options()("search_DRAM_budget,B", po::value<float>(&B)->required(),
|
|
||||||
"DRAM budget in GB for searching the index to set the "
|
|
||||||
"compressed level for data while search happens");
|
|
||||||
required_configs.add_options()("build_DRAM_budget,M", po::value<float>(&M)->required(),
|
|
||||||
"DRAM budget in GB for building the index");
|
|
||||||
|
|
||||||
// Optional parameters
|
|
||||||
po::options_description optional_configs("Optional");
|
|
||||||
optional_configs.add_options()("num_threads,T",
|
|
||||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
|
||||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
|
||||||
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
|
|
||||||
program_options_utils::MAX_BUILD_DEGREE);
|
|
||||||
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
|
|
||||||
program_options_utils::GRAPH_BUILD_COMPLEXITY);
|
|
||||||
optional_configs.add_options()("QD", po::value<uint32_t>(&QD)->default_value(0),
|
|
||||||
" Quantized Dimension for compression");
|
|
||||||
optional_configs.add_options()("codebook_prefix", po::value<std::string>(&codebook_prefix)->default_value(""),
|
|
||||||
"Path prefix for pre-trained codebook");
|
|
||||||
optional_configs.add_options()("PQ_disk_bytes", po::value<uint32_t>(&disk_PQ)->default_value(0),
|
|
||||||
"Number of bytes to which vectors should be compressed "
|
|
||||||
"on SSD; 0 for no compression");
|
|
||||||
optional_configs.add_options()("append_reorder_data", po::bool_switch()->default_value(false),
|
|
||||||
"Include full precision data in the index. Use only in "
|
|
||||||
"conjuction with compressed data on SSD.");
|
|
||||||
optional_configs.add_options()("build_PQ_bytes", po::value<uint32_t>(&build_PQ)->default_value(0),
|
|
||||||
program_options_utils::BUIlD_GRAPH_PQ_BYTES);
|
|
||||||
optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false),
|
|
||||||
program_options_utils::USE_OPQ);
|
|
||||||
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
|
|
||||||
program_options_utils::LABEL_FILE);
|
|
||||||
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
|
|
||||||
program_options_utils::UNIVERSAL_LABEL);
|
|
||||||
optional_configs.add_options()("FilteredLbuild", po::value<uint32_t>(&Lf)->default_value(0),
|
|
||||||
program_options_utils::FILTERED_LBUILD);
|
|
||||||
optional_configs.add_options()("filter_threshold,F", po::value<uint32_t>(&filter_threshold)->default_value(0),
|
|
||||||
"Threshold to break up the existing nodes to generate new graph "
|
|
||||||
"internally where each node has a maximum F labels.");
|
|
||||||
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
|
|
||||||
program_options_utils::LABEL_TYPE_DESCRIPTION);
|
|
||||||
|
|
||||||
// Merge required and optional parameters
|
|
||||||
desc.add(required_configs).add(optional_configs);
|
|
||||||
|
|
||||||
po::variables_map vm;
|
|
||||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
|
||||||
if (vm.count("help"))
|
|
||||||
{
|
|
||||||
std::cout << desc;
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
po::notify(vm);
|
|
||||||
if (vm["append_reorder_data"].as<bool>())
|
|
||||||
append_reorder_data = true;
|
|
||||||
if (vm["use_opq"].as<bool>())
|
|
||||||
use_opq = true;
|
|
||||||
}
|
|
||||||
catch (const std::exception &ex)
|
|
||||||
{
|
|
||||||
std::cerr << ex.what() << '\n';
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool use_filters = (label_file != "") ? true : false;
|
|
||||||
diskann::Metric metric;
|
|
||||||
if (dist_fn == std::string("l2"))
|
|
||||||
metric = diskann::Metric::L2;
|
|
||||||
else if (dist_fn == std::string("mips"))
|
|
||||||
metric = diskann::Metric::INNER_PRODUCT;
|
|
||||||
else if (dist_fn == std::string("cosine"))
|
|
||||||
metric = diskann::Metric::COSINE;
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (append_reorder_data)
|
|
||||||
{
|
|
||||||
if (disk_PQ == 0)
|
|
||||||
{
|
|
||||||
std::cout << "Error: It is not necessary to append data for reordering "
|
|
||||||
"when vectors are not compressed on disk."
|
|
||||||
<< std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
if (data_type != std::string("float"))
|
|
||||||
{
|
|
||||||
std::cout << "Error: Appending data for reordering currently only "
|
|
||||||
"supported for float data type."
|
|
||||||
<< std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string params = std::string(std::to_string(R)) + " " + std::string(std::to_string(L)) + " " +
|
|
||||||
std::string(std::to_string(B)) + " " + std::string(std::to_string(M)) + " " +
|
|
||||||
std::string(std::to_string(num_threads)) + " " + std::string(std::to_string(disk_PQ)) + " " +
|
|
||||||
std::string(std::to_string(append_reorder_data)) + " " +
|
|
||||||
std::string(std::to_string(build_PQ)) + " " + std::string(std::to_string(QD));
|
|
||||||
|
|
||||||
try
|
|
||||||
{
|
|
||||||
if (label_file != "" && label_type == "ushort")
|
|
||||||
{
|
|
||||||
if (data_type == std::string("int8"))
|
|
||||||
return diskann::build_disk_index<int8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
|
||||||
metric, use_opq, codebook_prefix, use_filters, label_file,
|
|
||||||
universal_label, filter_threshold, Lf);
|
|
||||||
else if (data_type == std::string("uint8"))
|
|
||||||
return diskann::build_disk_index<uint8_t, uint16_t>(
|
|
||||||
data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix,
|
|
||||||
use_filters, label_file, universal_label, filter_threshold, Lf);
|
|
||||||
else if (data_type == std::string("float"))
|
|
||||||
return diskann::build_disk_index<float, uint16_t>(
|
|
||||||
data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix,
|
|
||||||
use_filters, label_file, universal_label, filter_threshold, Lf);
|
|
||||||
else
|
|
||||||
{
|
|
||||||
diskann::cerr << "Error. Unsupported data type" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
if (data_type == std::string("int8"))
|
|
||||||
return diskann::build_disk_index<int8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
|
||||||
metric, use_opq, codebook_prefix, use_filters, label_file,
|
|
||||||
universal_label, filter_threshold, Lf);
|
|
||||||
else if (data_type == std::string("uint8"))
|
|
||||||
return diskann::build_disk_index<uint8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
|
||||||
metric, use_opq, codebook_prefix, use_filters, label_file,
|
|
||||||
universal_label, filter_threshold, Lf);
|
|
||||||
else if (data_type == std::string("float"))
|
|
||||||
return diskann::build_disk_index<float>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
|
||||||
metric, use_opq, codebook_prefix, use_filters, label_file,
|
|
||||||
universal_label, filter_threshold, Lf);
|
|
||||||
else
|
|
||||||
{
|
|
||||||
diskann::cerr << "Error. Unsupported data type" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
catch (const std::exception &e)
|
|
||||||
{
|
|
||||||
std::cout << std::string(e.what()) << std::endl;
|
|
||||||
diskann::cerr << "Index build failed." << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,164 +0,0 @@
|
|||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
// Licensed under the MIT license.
|
|
||||||
|
|
||||||
#include <omp.h>
|
|
||||||
#include <cstring>
|
|
||||||
#include <boost/program_options.hpp>
|
|
||||||
|
|
||||||
#include "index.h"
|
|
||||||
#include "utils.h"
|
|
||||||
#include "program_options_utils.hpp"
|
|
||||||
|
|
||||||
#ifndef _WINDOWS
|
|
||||||
#include <sys/mman.h>
|
|
||||||
#include <unistd.h>
|
|
||||||
#else
|
|
||||||
#include <Windows.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include "memory_mapper.h"
|
|
||||||
#include "ann_exception.h"
|
|
||||||
#include "index_factory.h"
|
|
||||||
|
|
||||||
namespace po = boost::program_options;
|
|
||||||
|
|
||||||
int main(int argc, char **argv)
|
|
||||||
{
|
|
||||||
std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type;
|
|
||||||
uint32_t num_threads, R, L, Lf, build_PQ_bytes;
|
|
||||||
float alpha;
|
|
||||||
bool use_pq_build, use_opq;
|
|
||||||
|
|
||||||
po::options_description desc{
|
|
||||||
program_options_utils::make_program_description("build_memory_index", "Build a memory-based DiskANN index.")};
|
|
||||||
try
|
|
||||||
{
|
|
||||||
desc.add_options()("help,h", "Print information on arguments");
|
|
||||||
|
|
||||||
// Required parameters
|
|
||||||
po::options_description required_configs("Required");
|
|
||||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
|
||||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
|
||||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
|
||||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
|
||||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
|
||||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
|
||||||
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
|
|
||||||
program_options_utils::INPUT_DATA_PATH);
|
|
||||||
|
|
||||||
// Optional parameters
|
|
||||||
po::options_description optional_configs("Optional");
|
|
||||||
optional_configs.add_options()("num_threads,T",
|
|
||||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
|
||||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
|
||||||
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
|
|
||||||
program_options_utils::MAX_BUILD_DEGREE);
|
|
||||||
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
|
|
||||||
program_options_utils::GRAPH_BUILD_COMPLEXITY);
|
|
||||||
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
|
|
||||||
program_options_utils::GRAPH_BUILD_ALPHA);
|
|
||||||
optional_configs.add_options()("build_PQ_bytes", po::value<uint32_t>(&build_PQ_bytes)->default_value(0),
|
|
||||||
program_options_utils::BUIlD_GRAPH_PQ_BYTES);
|
|
||||||
optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false),
|
|
||||||
program_options_utils::USE_OPQ);
|
|
||||||
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
|
|
||||||
program_options_utils::LABEL_FILE);
|
|
||||||
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
|
|
||||||
program_options_utils::UNIVERSAL_LABEL);
|
|
||||||
|
|
||||||
optional_configs.add_options()("FilteredLbuild", po::value<uint32_t>(&Lf)->default_value(0),
|
|
||||||
program_options_utils::FILTERED_LBUILD);
|
|
||||||
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
|
|
||||||
program_options_utils::LABEL_TYPE_DESCRIPTION);
|
|
||||||
|
|
||||||
// Merge required and optional parameters
|
|
||||||
desc.add(required_configs).add(optional_configs);
|
|
||||||
|
|
||||||
po::variables_map vm;
|
|
||||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
|
||||||
if (vm.count("help"))
|
|
||||||
{
|
|
||||||
std::cout << desc;
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
po::notify(vm);
|
|
||||||
use_pq_build = (build_PQ_bytes > 0);
|
|
||||||
use_opq = vm["use_opq"].as<bool>();
|
|
||||||
}
|
|
||||||
catch (const std::exception &ex)
|
|
||||||
{
|
|
||||||
std::cerr << ex.what() << '\n';
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
diskann::Metric metric;
|
|
||||||
if (dist_fn == std::string("mips"))
|
|
||||||
{
|
|
||||||
metric = diskann::Metric::INNER_PRODUCT;
|
|
||||||
}
|
|
||||||
else if (dist_fn == std::string("l2"))
|
|
||||||
{
|
|
||||||
metric = diskann::Metric::L2;
|
|
||||||
}
|
|
||||||
else if (dist_fn == std::string("cosine"))
|
|
||||||
{
|
|
||||||
metric = diskann::Metric::COSINE;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cout << "Unsupported distance function. Currently only L2/ Inner "
|
|
||||||
"Product/Cosine are supported."
|
|
||||||
<< std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
try
|
|
||||||
{
|
|
||||||
diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L << " alpha: " << alpha
|
|
||||||
<< " #threads: " << num_threads << std::endl;
|
|
||||||
|
|
||||||
size_t data_num, data_dim;
|
|
||||||
diskann::get_bin_metadata(data_path, data_num, data_dim);
|
|
||||||
|
|
||||||
auto index_build_params = diskann::IndexWriteParametersBuilder(L, R)
|
|
||||||
.with_filter_list_size(Lf)
|
|
||||||
.with_alpha(alpha)
|
|
||||||
.with_saturate_graph(false)
|
|
||||||
.with_num_threads(num_threads)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
auto filter_params = diskann::IndexFilterParamsBuilder()
|
|
||||||
.with_universal_label(universal_label)
|
|
||||||
.with_label_file(label_file)
|
|
||||||
.with_save_path_prefix(index_path_prefix)
|
|
||||||
.build();
|
|
||||||
auto config = diskann::IndexConfigBuilder()
|
|
||||||
.with_metric(metric)
|
|
||||||
.with_dimension(data_dim)
|
|
||||||
.with_max_points(data_num)
|
|
||||||
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
|
|
||||||
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
|
|
||||||
.with_data_type(data_type)
|
|
||||||
.with_label_type(label_type)
|
|
||||||
.is_dynamic_index(false)
|
|
||||||
.with_index_write_params(index_build_params)
|
|
||||||
.is_enable_tags(false)
|
|
||||||
.is_use_opq(use_opq)
|
|
||||||
.is_pq_dist_build(use_pq_build)
|
|
||||||
.with_num_pq_chunks(build_PQ_bytes)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
auto index_factory = diskann::IndexFactory(config);
|
|
||||||
auto index = index_factory.create_instance();
|
|
||||||
index->build(data_path, data_num, filter_params);
|
|
||||||
index->save(index_path_prefix.c_str());
|
|
||||||
index.reset();
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
catch (const std::exception &e)
|
|
||||||
{
|
|
||||||
std::cout << std::string(e.what()) << std::endl;
|
|
||||||
diskann::cerr << "Index build failed." << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,441 +0,0 @@
|
|||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
// Licensed under the MIT license.
|
|
||||||
|
|
||||||
#include <boost/program_options.hpp>
|
|
||||||
#include <chrono>
|
|
||||||
#include <cstdio>
|
|
||||||
#include <cstring>
|
|
||||||
#include <random>
|
|
||||||
#include <string>
|
|
||||||
#include <tuple>
|
|
||||||
#include "filter_utils.h"
|
|
||||||
#include <omp.h>
|
|
||||||
#ifndef _WINDOWS
|
|
||||||
#include <sys/uio.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include "index.h"
|
|
||||||
#include "memory_mapper.h"
|
|
||||||
#include "parameters.h"
|
|
||||||
#include "utils.h"
|
|
||||||
#include "program_options_utils.hpp"
|
|
||||||
|
|
||||||
namespace po = boost::program_options;
|
|
||||||
typedef std::tuple<std::vector<std::vector<uint32_t>>, uint64_t> stitch_indices_return_values;
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Inline function to display progress bar.
|
|
||||||
*/
|
|
||||||
inline void print_progress(double percentage)
|
|
||||||
{
|
|
||||||
int val = (int)(percentage * 100);
|
|
||||||
int lpad = (int)(percentage * PBWIDTH);
|
|
||||||
int rpad = PBWIDTH - lpad;
|
|
||||||
printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, "");
|
|
||||||
fflush(stdout);
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Inline function to generate a random integer in a range.
|
|
||||||
*/
|
|
||||||
/*
 * Returns a uniformly distributed random integer in the inclusive range
 * [range_from, range_to].
 *
 * Fix: the original constructed a fresh std::random_device and std::mt19937
 * on every call, which is slow (device reads + full 624-word seeding per
 * call) and weakens the statistical quality of successive draws. The engine
 * is now seeded once per thread and reused; thread_local keeps it safe if
 * callers draw from multiple threads.
 */
inline size_t random(size_t range_from, size_t range_to)
{
    thread_local std::mt19937 generator{std::random_device{}()};
    std::uniform_int_distribution<size_t> distr(range_from, range_to);
    return distr(generator);
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
* function to handle command line parsing.
|
|
||||||
*
|
|
||||||
* Arguments are merely the inputs from the command line.
|
|
||||||
*/
|
|
||||||
void handle_args(int argc, char **argv, std::string &data_type, path &input_data_path, path &final_index_path_prefix,
|
|
||||||
path &label_data_path, std::string &universal_label, uint32_t &num_threads, uint32_t &R, uint32_t &L,
|
|
||||||
uint32_t &stitched_R, float &alpha)
|
|
||||||
{
|
|
||||||
po::options_description desc{
|
|
||||||
program_options_utils::make_program_description("build_stitched_index", "Build a stitched DiskANN index.")};
|
|
||||||
try
|
|
||||||
{
|
|
||||||
desc.add_options()("help,h", "Print information on arguments");
|
|
||||||
|
|
||||||
// Required parameters
|
|
||||||
po::options_description required_configs("Required");
|
|
||||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
|
||||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
|
||||||
required_configs.add_options()("index_path_prefix",
|
|
||||||
po::value<std::string>(&final_index_path_prefix)->required(),
|
|
||||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
|
||||||
required_configs.add_options()("data_path", po::value<std::string>(&input_data_path)->required(),
|
|
||||||
program_options_utils::INPUT_DATA_PATH);
|
|
||||||
|
|
||||||
// Optional parameters
|
|
||||||
po::options_description optional_configs("Optional");
|
|
||||||
optional_configs.add_options()("num_threads,T",
|
|
||||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
|
||||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
|
||||||
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
|
|
||||||
program_options_utils::MAX_BUILD_DEGREE);
|
|
||||||
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
|
|
||||||
program_options_utils::GRAPH_BUILD_COMPLEXITY);
|
|
||||||
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
|
|
||||||
program_options_utils::GRAPH_BUILD_ALPHA);
|
|
||||||
optional_configs.add_options()("label_file", po::value<std::string>(&label_data_path)->default_value(""),
|
|
||||||
program_options_utils::LABEL_FILE);
|
|
||||||
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
|
|
||||||
program_options_utils::UNIVERSAL_LABEL);
|
|
||||||
optional_configs.add_options()("stitched_R", po::value<uint32_t>(&stitched_R)->default_value(100),
|
|
||||||
"Degree to prune final graph down to");
|
|
||||||
|
|
||||||
// Merge required and optional parameters
|
|
||||||
desc.add(required_configs).add(optional_configs);
|
|
||||||
|
|
||||||
po::variables_map vm;
|
|
||||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
|
||||||
if (vm.count("help"))
|
|
||||||
{
|
|
||||||
std::cout << desc;
|
|
||||||
exit(0);
|
|
||||||
}
|
|
||||||
po::notify(vm);
|
|
||||||
}
|
|
||||||
catch (const std::exception &ex)
|
|
||||||
{
|
|
||||||
std::cerr << ex.what() << '\n';
|
|
||||||
throw;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Custom index save to write the in-memory index to disk.
|
|
||||||
* Also writes required files for diskANN API -
|
|
||||||
* 1. labels_to_medoids
|
|
||||||
* 2. universal_label
|
|
||||||
* 3. data (redundant for static indices)
|
|
||||||
* 4. labels (redundant for static indices)
|
|
||||||
*/
|
|
||||||
void save_full_index(path final_index_path_prefix, path input_data_path, uint64_t final_index_size,
|
|
||||||
std::vector<std::vector<uint32_t>> stitched_graph,
|
|
||||||
tsl::robin_map<std::string, uint32_t> entry_points, std::string universal_label,
|
|
||||||
path label_data_path)
|
|
||||||
{
|
|
||||||
// aux. file 1
|
|
||||||
auto saving_index_timer = std::chrono::high_resolution_clock::now();
|
|
||||||
std::ifstream original_label_data_stream;
|
|
||||||
original_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
|
|
||||||
original_label_data_stream.open(label_data_path, std::ios::binary);
|
|
||||||
std::ofstream new_label_data_stream;
|
|
||||||
new_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
|
|
||||||
new_label_data_stream.open(final_index_path_prefix + "_labels.txt", std::ios::binary);
|
|
||||||
new_label_data_stream << original_label_data_stream.rdbuf();
|
|
||||||
original_label_data_stream.close();
|
|
||||||
new_label_data_stream.close();
|
|
||||||
|
|
||||||
// aux. file 2
|
|
||||||
std::ifstream original_input_data_stream;
|
|
||||||
original_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
|
|
||||||
original_input_data_stream.open(input_data_path, std::ios::binary);
|
|
||||||
std::ofstream new_input_data_stream;
|
|
||||||
new_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
|
|
||||||
new_input_data_stream.open(final_index_path_prefix + ".data", std::ios::binary);
|
|
||||||
new_input_data_stream << original_input_data_stream.rdbuf();
|
|
||||||
original_input_data_stream.close();
|
|
||||||
new_input_data_stream.close();
|
|
||||||
|
|
||||||
// aux. file 3
|
|
||||||
std::ofstream labels_to_medoids_writer;
|
|
||||||
labels_to_medoids_writer.exceptions(std::ios::badbit | std::ios::failbit);
|
|
||||||
labels_to_medoids_writer.open(final_index_path_prefix + "_labels_to_medoids.txt");
|
|
||||||
for (auto iter : entry_points)
|
|
||||||
labels_to_medoids_writer << iter.first << ", " << iter.second << std::endl;
|
|
||||||
labels_to_medoids_writer.close();
|
|
||||||
|
|
||||||
// aux. file 4 (only if we're using a universal label)
|
|
||||||
if (universal_label != "")
|
|
||||||
{
|
|
||||||
std::ofstream universal_label_writer;
|
|
||||||
universal_label_writer.exceptions(std::ios::badbit | std::ios::failbit);
|
|
||||||
universal_label_writer.open(final_index_path_prefix + "_universal_label.txt");
|
|
||||||
universal_label_writer << universal_label << std::endl;
|
|
||||||
universal_label_writer.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
// main index
|
|
||||||
uint64_t index_num_frozen_points = 0, index_num_edges = 0;
|
|
||||||
uint32_t index_max_observed_degree = 0, index_entry_point = 0;
|
|
||||||
const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t);
|
|
||||||
for (auto &point_neighbors : stitched_graph)
|
|
||||||
{
|
|
||||||
index_max_observed_degree = std::max(index_max_observed_degree, (uint32_t)point_neighbors.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
std::ofstream stitched_graph_writer;
|
|
||||||
stitched_graph_writer.exceptions(std::ios::badbit | std::ios::failbit);
|
|
||||||
stitched_graph_writer.open(final_index_path_prefix, std::ios_base::binary);
|
|
||||||
|
|
||||||
stitched_graph_writer.write((char *)&final_index_size, sizeof(uint64_t));
|
|
||||||
stitched_graph_writer.write((char *)&index_max_observed_degree, sizeof(uint32_t));
|
|
||||||
stitched_graph_writer.write((char *)&index_entry_point, sizeof(uint32_t));
|
|
||||||
stitched_graph_writer.write((char *)&index_num_frozen_points, sizeof(uint64_t));
|
|
||||||
|
|
||||||
size_t bytes_written = METADATA;
|
|
||||||
for (uint32_t node_point = 0; node_point < stitched_graph.size(); node_point++)
|
|
||||||
{
|
|
||||||
uint32_t current_node_num_neighbors = (uint32_t)stitched_graph[node_point].size();
|
|
||||||
std::vector<uint32_t> current_node_neighbors = stitched_graph[node_point];
|
|
||||||
stitched_graph_writer.write((char *)¤t_node_num_neighbors, sizeof(uint32_t));
|
|
||||||
bytes_written += sizeof(uint32_t);
|
|
||||||
for (const auto ¤t_node_neighbor : current_node_neighbors)
|
|
||||||
{
|
|
||||||
stitched_graph_writer.write((char *)¤t_node_neighbor, sizeof(uint32_t));
|
|
||||||
bytes_written += sizeof(uint32_t);
|
|
||||||
}
|
|
||||||
index_num_edges += current_node_num_neighbors;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (bytes_written != final_index_size)
|
|
||||||
{
|
|
||||||
std::cerr << "Error: written bytes does not match allocated space" << std::endl;
|
|
||||||
throw;
|
|
||||||
}
|
|
||||||
|
|
||||||
stitched_graph_writer.close();
|
|
||||||
|
|
||||||
std::chrono::duration<double> saving_index_time = std::chrono::high_resolution_clock::now() - saving_index_timer;
|
|
||||||
std::cout << "Stitched graph written in " << saving_index_time.count() << " seconds" << std::endl;
|
|
||||||
std::cout << "Stitched graph average degree: " << ((float)index_num_edges) / ((float)(stitched_graph.size()))
|
|
||||||
<< std::endl;
|
|
||||||
std::cout << "Stitched graph max degree: " << index_max_observed_degree << std::endl << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Unions the per-label graph indices together via the following policy:
|
|
||||||
* - any two nodes can only have at most one edge between them -
|
|
||||||
*
|
|
||||||
* Returns the "stitched" graph and its expected file size.
|
|
||||||
*/
|
|
||||||
template <typename T>
|
|
||||||
stitch_indices_return_values stitch_label_indices(
|
|
||||||
path final_index_path_prefix, uint32_t total_number_of_points, label_set all_labels,
|
|
||||||
tsl::robin_map<std::string, uint32_t> labels_to_number_of_points,
|
|
||||||
tsl::robin_map<std::string, uint32_t> &label_entry_points,
|
|
||||||
tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id_map)
|
|
||||||
{
|
|
||||||
size_t final_index_size = 0;
|
|
||||||
std::vector<std::vector<uint32_t>> stitched_graph(total_number_of_points);
|
|
||||||
|
|
||||||
auto stitching_index_timer = std::chrono::high_resolution_clock::now();
|
|
||||||
for (const auto &lbl : all_labels)
|
|
||||||
{
|
|
||||||
path curr_label_index_path(final_index_path_prefix + "_" + lbl);
|
|
||||||
std::vector<std::vector<uint32_t>> curr_label_index;
|
|
||||||
uint64_t curr_label_index_size;
|
|
||||||
uint32_t curr_label_entry_point;
|
|
||||||
|
|
||||||
std::tie(curr_label_index, curr_label_index_size) =
|
|
||||||
diskann::load_label_index(curr_label_index_path, labels_to_number_of_points[lbl]);
|
|
||||||
curr_label_entry_point = (uint32_t)random(0, curr_label_index.size());
|
|
||||||
label_entry_points[lbl] = label_id_to_orig_id_map[lbl][curr_label_entry_point];
|
|
||||||
|
|
||||||
for (uint32_t node_point = 0; node_point < curr_label_index.size(); node_point++)
|
|
||||||
{
|
|
||||||
uint32_t original_point_id = label_id_to_orig_id_map[lbl][node_point];
|
|
||||||
for (auto &node_neighbor : curr_label_index[node_point])
|
|
||||||
{
|
|
||||||
uint32_t original_neighbor_id = label_id_to_orig_id_map[lbl][node_neighbor];
|
|
||||||
std::vector<uint32_t> curr_point_neighbors = stitched_graph[original_point_id];
|
|
||||||
if (std::find(curr_point_neighbors.begin(), curr_point_neighbors.end(), original_neighbor_id) ==
|
|
||||||
curr_point_neighbors.end())
|
|
||||||
{
|
|
||||||
stitched_graph[original_point_id].push_back(original_neighbor_id);
|
|
||||||
final_index_size += sizeof(uint32_t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t);
|
|
||||||
final_index_size += (total_number_of_points * sizeof(uint32_t) + METADATA);
|
|
||||||
|
|
||||||
std::chrono::duration<double> stitching_index_time =
|
|
||||||
std::chrono::high_resolution_clock::now() - stitching_index_timer;
|
|
||||||
std::cout << "stitched graph generated in memory in " << stitching_index_time.count() << " seconds" << std::endl;
|
|
||||||
|
|
||||||
return std::make_tuple(stitched_graph, final_index_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Applies the prune_neighbors function from src/index.cpp to
|
|
||||||
* every node in the stitched graph.
|
|
||||||
*
|
|
||||||
* This is an optional step, hence the saving of both the full
|
|
||||||
* and pruned graph.
|
|
||||||
*/
|
|
||||||
/*
 * Applies the prune_neighbors function from src/index.cpp to
 * every node in the stitched graph.
 *
 * This is an optional step, hence the saving of both the full
 * and pruned graph.
 *
 * Loads the full stitched index from full_index_path_prefix, prunes every
 * node's neighbor list down to stitched_R, and saves the result under
 * final_index_path_prefix. Timing is reported on completion.
 *
 * NOTE(review): stitched_graph, label_entry_points, universal_label and
 * label_data_path are accepted but not referenced in this body — the graph is
 * re-loaded from disk instead. Confirm whether they are kept for interface
 * symmetry with the other steps.
 */
template <typename T>
void prune_and_save(path final_index_path_prefix, path full_index_path_prefix, path input_data_path,
                    std::vector<std::vector<uint32_t>> stitched_graph, uint32_t stitched_R,
                    tsl::robin_map<std::string, uint32_t> label_entry_points, std::string universal_label,
                    path label_data_path, uint32_t num_threads)
{
    size_t dimension, number_of_label_points;
    // Silence both diskann's logger and std::cout for the duration of the
    // prune; the saved buffers are restored before the timing line below.
    // NOTE(review): if index.load/prune throws, the buffers are never
    // restored — not exception-safe, but preserved as-is.
    auto diskann_cout_buffer = diskann::cout.rdbuf(nullptr);
    auto std_cout_buffer = std::cout.rdbuf(nullptr);
    auto pruning_index_timer = std::chrono::high_resolution_clock::now();

    // reads (number_of_points, dimension) out of the .bin header
    diskann::get_bin_metadata(input_data_path, number_of_label_points, dimension);

    diskann::Index<T> index(diskann::Metric::L2, dimension, number_of_label_points, nullptr, nullptr, 0, false, false,
                            false, false, 0, false);

    // not searching this index, set search_l to 0
    // NOTE(review): the call below actually passes 1 as the last argument —
    // confirm against Index::load's signature which parameter this is.
    index.load(full_index_path_prefix.c_str(), num_threads, 1);

    std::cout << "parsing labels" << std::endl;

    // 750 / 1.2 are presumably the prune candidate-list size and alpha —
    // TODO confirm against prune_all_neighbors' declaration.
    index.prune_all_neighbors(stitched_R, 750, 1.2);
    index.save((final_index_path_prefix).c_str());

    // restore the original stream buffers so later output is visible again
    diskann::cout.rdbuf(diskann_cout_buffer);
    std::cout.rdbuf(std_cout_buffer);
    std::chrono::duration<double> pruning_index_time = std::chrono::high_resolution_clock::now() - pruning_index_timer;
    std::cout << "pruning performed in " << pruning_index_time.count() << " seconds\n" << std::endl;
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Delete all temporary artifacts.
|
|
||||||
* In the process of creating the stitched index, some temporary artifacts are
|
|
||||||
* created:
|
|
||||||
* 1. the separate bin files for each labels' points
|
|
||||||
* 2. the separate diskANN indices built for each label
|
|
||||||
* 3. the '.data' file created while generating the indices
|
|
||||||
*/
|
|
||||||
/*
 * Delete all temporary artifacts.
 * In the process of creating the stitched index, some temporary artifacts are
 * created:
 *  1. the separate bin files for each labels' points
 *  2. the separate diskANN indices built for each label
 *  3. the '.data' file created while generating the indices
 *
 * Aborts the process if any removal fails, after printing which file could
 * not be removed (the original bare `throw;` terminated with no diagnostic,
 * since rethrow with no active exception calls std::terminate).
 */
void clean_up_artifacts(path input_data_path, path final_index_path_prefix, label_set all_labels)
{
    // removes one file, reporting the offending path before aborting
    auto remove_or_abort = [](const path &artifact_path) {
        if (std::remove(artifact_path.c_str()) != 0)
        {
            std::cerr << "Error: could not remove temporary artifact " << artifact_path << std::endl;
            throw; // no active exception -> std::terminate, matching original behavior
        }
    };

    for (const auto &lbl : all_labels)
    {
        path curr_label_input_data_path(input_data_path + "_" + lbl);
        path curr_label_index_path(final_index_path_prefix + "_" + lbl);
        path curr_label_index_path_data(curr_label_index_path + ".data");

        remove_or_abort(curr_label_index_path);
        remove_or_abort(curr_label_input_data_path);
        remove_or_abort(curr_label_index_path_data);
    }
}
|
|
||||||
|
|
||||||
// Driver for building a "stitched" filtered DiskANN index:
// parse args -> split data per label -> build one vanilla index per label ->
// union ("stitch") them -> save -> prune -> save -> delete temporaries.
// The repeated if/else chains dispatch on the runtime data_type string to the
// matching template instantiation; a bare `throw;` on an unknown type
// terminates the process (no active exception -> std::terminate).
int main(int argc, char **argv)
{
    // 1. handle cmdline inputs
    std::string data_type;
    path input_data_path, final_index_path_prefix, label_data_path;
    std::string universal_label;
    uint32_t num_threads, R, L, stitched_R;
    float alpha;

    auto index_timer = std::chrono::high_resolution_clock::now();
    handle_args(argc, argv, data_type, input_data_path, final_index_path_prefix, label_data_path, universal_label,
                num_threads, R, L, stitched_R, alpha);

    // derived file names for the integer-converted label file and its map
    path labels_file_to_use = final_index_path_prefix + "_label_formatted.txt";
    path labels_map_file = final_index_path_prefix + "_labels_map.txt";

    convert_labels_string_to_int(label_data_path, labels_file_to_use, labels_map_file, universal_label);

    // 2. parse label file and create necessary data structures
    std::vector<label_set> point_ids_to_labels;
    tsl::robin_map<std::string, uint32_t> labels_to_number_of_points;
    label_set all_labels;

    std::tie(point_ids_to_labels, labels_to_number_of_points, all_labels) =
        diskann::parse_label_file(labels_file_to_use, universal_label);

    // 3. for each label, make a separate data file
    tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id_map;
    uint32_t total_number_of_points = (uint32_t)point_ids_to_labels.size();

    // Windows uses the *_compat variants of the splitter; the arguments and
    // resulting per-label id -> original id map are otherwise identical.
#ifndef _WINDOWS
    if (data_type == "uint8")
        label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<uint8_t>(
            input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
    else if (data_type == "int8")
        label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<int8_t>(
            input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
    else if (data_type == "float")
        label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<float>(
            input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
    else
        throw; // unsupported data_type: terminates
#else
    if (data_type == "uint8")
        label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<uint8_t>(
            input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
    else if (data_type == "int8")
        label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<int8_t>(
            input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
    else if (data_type == "float")
        label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<float>(
            input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
    else
        throw; // unsupported data_type: terminates
#endif

    // 4. for each created data file, create a vanilla diskANN index
    if (data_type == "uint8")
        diskann::generate_label_indices<uint8_t>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
                                                 num_threads);
    else if (data_type == "int8")
        diskann::generate_label_indices<int8_t>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
                                                num_threads);
    else if (data_type == "float")
        diskann::generate_label_indices<float>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
                                               num_threads);
    else
        throw; // unsupported data_type: terminates

    // 5. "stitch" the indices together
    std::vector<std::vector<uint32_t>> stitched_graph;
    tsl::robin_map<std::string, uint32_t> label_entry_points;
    uint64_t stitched_graph_size;

    if (data_type == "uint8")
        std::tie(stitched_graph, stitched_graph_size) =
            stitch_label_indices<uint8_t>(final_index_path_prefix, total_number_of_points, all_labels,
                                          labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
    else if (data_type == "int8")
        std::tie(stitched_graph, stitched_graph_size) =
            stitch_label_indices<int8_t>(final_index_path_prefix, total_number_of_points, all_labels,
                                         labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
    else if (data_type == "float")
        std::tie(stitched_graph, stitched_graph_size) =
            stitch_label_indices<float>(final_index_path_prefix, total_number_of_points, all_labels,
                                        labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
    else
        throw; // unsupported data_type: terminates
    path full_index_path_prefix = final_index_path_prefix + "_full";
    // 5a. save the stitched graph to disk
    save_full_index(full_index_path_prefix, input_data_path, stitched_graph_size, stitched_graph, label_entry_points,
                    universal_label, labels_file_to_use);

    // 6. run a prune on the stitched index, and save to disk
    if (data_type == "uint8")
        prune_and_save<uint8_t>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
                                stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
    else if (data_type == "int8")
        prune_and_save<int8_t>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
                               stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
    else if (data_type == "float")
        prune_and_save<float>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
                              stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
    else
        throw; // unsupported data_type: terminates

    std::chrono::duration<double> index_time = std::chrono::high_resolution_clock::now() - index_timer;
    std::cout << "pruned/stitched graph generated in " << index_time.count() << " seconds" << std::endl;

    // remove the per-label bin files, per-label indices, and .data files
    clean_up_artifacts(input_data_path, final_index_path_prefix, all_labels);
}
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
<!-- Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
Licensed under the MIT license. -->
|
|
||||||
|
|
||||||
# Integration Tests
|
|
||||||
The following tests use Python to prepare, run, verify, and tear down the rest api services.
|
|
||||||
|
|
||||||
We do make use of the built-in `unittest` library, but only to take advantage of its test reporting capabilities.
|
|
||||||
|
|
||||||
These are decidedly **not** _unit_ tests. These are end to end integration tests.
|
|
||||||
|
|
||||||
## Caveats
|
|
||||||
This has only been tested or built for Linux, though we have written platform agnostic Python for the smoke test
|
|
||||||
(i.e. using `os.path.join`, etc)
|
|
||||||
|
|
||||||
It has been tested on Python 3.9 and 3.10, but should work on Python 3.6+.
|
|
||||||
|
|
||||||
## How to Run
|
|
||||||
|
|
||||||
First, build the DiskANN RestAPI code; see $REPOSITORY_ROOT/workflows/rest_api.md for detailed instructions.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd tests/python
|
|
||||||
python3 -m venv venv
|
|
||||||
source venv/bin/activate
|
|
||||||
pip install -r requirements.txt
|
|
||||||
|
|
||||||
export DISKANN_BUILD_DIR=/path/to/your/diskann/build
|
|
||||||
python -m unittest
|
|
||||||
```
|
|
||||||
|
|
||||||
## Smoke Test Failed, Now What?
|
|
||||||
The smoke test written takes advantage of temporary directories that are only valid during the
|
|
||||||
lifetime of the test. The contents of these directories include:
|
|
||||||
- Randomized vectors (first in tsv, then bin form) used to build the PQFlashIndex
|
|
||||||
- The PQFlashIndex files
|
|
||||||
|
|
||||||
It is useful to keep these around. By setting some environment variables, you can control whether an ephemeral,
|
|
||||||
temporary directory is used (and deleted on test completion), or left as an exercise for the developer to
|
|
||||||
clean up.
|
|
||||||
|
|
||||||
The valid environment variables are:
|
|
||||||
- `DISKANN_REST_TEST_WORKING_DIR` (example: `$USER/DiskANNRestTest`)
|
|
||||||
- If this is specified, it **must exist** and **must be writeable**. Any existing files will be clobbered.
|
|
||||||
- `DISKANN_REST_SERVER` (example: `http://127.0.0.1:10067`)
|
|
||||||
  - Note that if this is set, no data will be generated and no server will be started; it is presumed you have already created and started the rest server yourself, and the test will simply submit requests against it.
|
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
# Licensed under the MIT license.
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
|
|
||||||
|
|
||||||
def output_vectors(
    diskann_build_path: str,
    temporary_file_path: str,
    vectors: np.ndarray,
    timeout: int = 60
) -> str:
    """Write ``vectors`` to a TSV file, then convert it to DiskANN's ``.bin`` format.

    Uses the ``tsv_to_bin`` utility from the DiskANN build tree to perform the
    conversion, since that tool emits floats in the exact byte layout the C++
    code expects.

    :param diskann_build_path: root of the DiskANN build directory (contains apps/utils)
    :param temporary_file_path: directory where vectors.tsv and vectors.bin are written
    :param vectors: 2-D array of vectors to serialize
    :param timeout: seconds to wait for the conversion subprocess
    :return: path to the generated vectors.bin file
    :raises Exception: if tsv_to_bin exits with a nonzero return code
    """
    tsv_path = os.path.join(temporary_file_path, "vectors.tsv")
    with open(tsv_path, "w") as tsv_file:
        for row in vectors:
            print("\t".join(str(component) for component in row), file=tsv_file)

    # there is probably a clever way to have numpy write out C++ friendly floats, so feel free to remove this in
    # favor of something more sane later
    bin_path = os.path.join(temporary_file_path, "vectors.bin")
    converter_path = os.path.join(diskann_build_path, "apps", "utils", "tsv_to_bin")

    point_count, dimension_count = vectors.shape
    completed = subprocess.run(
        [
            converter_path,
            "float",
            tsv_path,
            bin_path,
            str(dimension_count),
            str(point_count),
        ],
        timeout=timeout,
    )
    if completed.returncode != 0:
        raise Exception(f"Unable to convert tsv to binary using tsv_to_bin, completed_process: {completed}")
    return bin_path
|
|
||||||
|
|
||||||
|
|
||||||
def build_ssd_index(
    diskann_build_path: str,
    temporary_file_path: str,
    vectors: np.ndarray,
    per_process_timeout: int = 60  # this may not be long enough if you're doing something larger
):
    """Serialize ``vectors`` to disk and build a DiskANN SSD (disk) index from them.

    First converts the vectors to .bin via :func:`output_vectors`, then invokes
    the ``build_disk_index`` binary from the DiskANN build tree. Index files are
    written into ``temporary_file_path`` with the prefix ``smoke_test``.

    :param diskann_build_path: root of the DiskANN build directory (contains apps/)
    :param temporary_file_path: working directory for vector files and index output
    :param vectors: 2-D array of float vectors to index
    :param per_process_timeout: seconds allowed for each subprocess invocation
    :raises Exception: if build_disk_index exits with a nonzero return code
    """
    vectors_as_bin_path = output_vectors(diskann_build_path, temporary_file_path, vectors, timeout=per_process_timeout)

    ssd_builder_path = os.path.join(diskann_build_path, "apps", "build_disk_index")
    # fixed smoke-test build parameters: L2 float index, R=64, L=100, 1 GB
    # DRAM budgets, single-threaded, no PQ on disk
    args = [
        ssd_builder_path,
        "--data_type", "float",
        "--dist_fn", "l2",
        "--data_path", vectors_as_bin_path,
        "--index_path_prefix", os.path.join(temporary_file_path, "smoke_test"),
        "-R", "64",
        "-L", "100",
        "--search_DRAM_budget", "1",
        "--build_DRAM_budget", "1",
        "--num_threads", "1",
        "--PQ_disk_bytes", "0"
    ]
    completed = subprocess.run(args, timeout=per_process_timeout)

    if completed.returncode != 0:
        command_run = " ".join(args)
        # NOTE(review): run() was not given capture_output=True, so
        # completed.stdout / completed.stderr below are None — the message will
        # print "None" for both. Confirm whether output capture was intended.
        raise Exception(f"Unable to build a disk index with the command: '{command_run}'\ncompleted_process: {completed}\nstdout: {completed.stdout}\nstderr: {completed.stderr}")
    # index is now built inside of temporary_file_path
|
|
||||||
@@ -1,379 +0,0 @@
|
|||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
// Licensed under the MIT license.
|
|
||||||
|
|
||||||
#include <atomic>
|
|
||||||
#include <cstring>
|
|
||||||
#include <iomanip>
|
|
||||||
#include <omp.h>
|
|
||||||
#include <set>
|
|
||||||
#include <boost/program_options.hpp>
|
|
||||||
|
|
||||||
#include "index.h"
|
|
||||||
#include "disk_utils.h"
|
|
||||||
#include "math_utils.h"
|
|
||||||
#include "memory_mapper.h"
|
|
||||||
#include "pq_flash_index.h"
|
|
||||||
#include "partition.h"
|
|
||||||
#include "timer.h"
|
|
||||||
#include "program_options_utils.hpp"
|
|
||||||
|
|
||||||
#ifndef _WINDOWS
|
|
||||||
#include <sys/mman.h>
|
|
||||||
#include <sys/stat.h>
|
|
||||||
#include <unistd.h>
|
|
||||||
#include "linux_aligned_file_reader.h"
|
|
||||||
#else
|
|
||||||
#ifdef USE_BING_INFRA
|
|
||||||
#include "bing_aligned_file_reader.h"
|
|
||||||
#else
|
|
||||||
#include "windows_aligned_file_reader.h"
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace po = boost::program_options;
|
|
||||||
|
|
||||||
#define WARMUP false
|
|
||||||
|
|
||||||
void print_stats(std::string category, std::vector<float> percentiles, std::vector<float> results)
|
|
||||||
{
|
|
||||||
diskann::cout << std::setw(20) << category << ": " << std::flush;
|
|
||||||
for (uint32_t s = 0; s < percentiles.size(); s++)
|
|
||||||
{
|
|
||||||
diskann::cout << std::setw(8) << percentiles[s] << "%";
|
|
||||||
}
|
|
||||||
diskann::cout << std::endl;
|
|
||||||
diskann::cout << std::setw(22) << " " << std::flush;
|
|
||||||
for (uint32_t s = 0; s < percentiles.size(); s++)
|
|
||||||
{
|
|
||||||
diskann::cout << std::setw(9) << results[s];
|
|
||||||
}
|
|
||||||
diskann::cout << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename LabelT = uint32_t>
|
|
||||||
int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix, const std::string &query_file,
|
|
||||||
std::string >_file, const uint32_t num_threads, const float search_range,
|
|
||||||
const uint32_t beamwidth, const uint32_t num_nodes_to_cache, const std::vector<uint32_t> &Lvec)
|
|
||||||
{
|
|
||||||
std::string pq_prefix = index_path_prefix + "_pq";
|
|
||||||
std::string disk_index_file = index_path_prefix + "_disk.index";
|
|
||||||
std::string warmup_query_file = index_path_prefix + "_sample_data.bin";
|
|
||||||
|
|
||||||
diskann::cout << "Search parameters: #threads: " << num_threads << ", ";
|
|
||||||
if (beamwidth <= 0)
|
|
||||||
diskann::cout << "beamwidth to be optimized for each L value" << std::endl;
|
|
||||||
else
|
|
||||||
diskann::cout << " beamwidth: " << beamwidth << std::endl;
|
|
||||||
|
|
||||||
// load query bin
|
|
||||||
T *query = nullptr;
|
|
||||||
std::vector<std::vector<uint32_t>> groundtruth_ids;
|
|
||||||
size_t query_num, query_dim, query_aligned_dim, gt_num;
|
|
||||||
diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);
|
|
||||||
|
|
||||||
bool calc_recall_flag = false;
|
|
||||||
if (gt_file != std::string("null") && file_exists(gt_file))
|
|
||||||
{
|
|
||||||
diskann::load_range_truthset(gt_file, groundtruth_ids,
|
|
||||||
gt_num); // use for range search type of truthset
|
|
||||||
// diskann::prune_truthset_for_range(gt_file, search_range,
|
|
||||||
// groundtruth_ids, gt_num); // use for traditional truthset
|
|
||||||
if (gt_num != query_num)
|
|
||||||
{
|
|
||||||
diskann::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
calc_recall_flag = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<AlignedFileReader> reader = nullptr;
|
|
||||||
#ifdef _WINDOWS
|
|
||||||
#ifndef USE_BING_INFRA
|
|
||||||
reader.reset(new WindowsAlignedFileReader());
|
|
||||||
#else
|
|
||||||
reader.reset(new diskann::BingAlignedFileReader());
|
|
||||||
#endif
|
|
||||||
#else
|
|
||||||
reader.reset(new LinuxAlignedFileReader());
|
|
||||||
#endif
|
|
||||||
|
|
||||||
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> _pFlashIndex(
|
|
||||||
new diskann::PQFlashIndex<T, LabelT>(reader, metric));
|
|
||||||
|
|
||||||
int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str());
|
|
||||||
|
|
||||||
if (res != 0)
|
|
||||||
{
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
// cache bfs levels
|
|
||||||
std::vector<uint32_t> node_list;
|
|
||||||
diskann::cout << "Caching " << num_nodes_to_cache << " BFS nodes around medoid(s)" << std::endl;
|
|
||||||
_pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list);
|
|
||||||
// _pFlashIndex->generate_cache_list_from_sample_queries(
|
|
||||||
// warmup_query_file, 15, 6, num_nodes_to_cache, num_threads,
|
|
||||||
// node_list);
|
|
||||||
_pFlashIndex->load_cache_list(node_list);
|
|
||||||
node_list.clear();
|
|
||||||
node_list.shrink_to_fit();
|
|
||||||
|
|
||||||
omp_set_num_threads(num_threads);
|
|
||||||
|
|
||||||
uint64_t warmup_L = 20;
|
|
||||||
uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0;
|
|
||||||
T *warmup = nullptr;
|
|
||||||
|
|
||||||
if (WARMUP)
|
|
||||||
{
|
|
||||||
if (file_exists(warmup_query_file))
|
|
||||||
{
|
|
||||||
diskann::load_aligned_bin<T>(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads);
|
|
||||||
warmup_dim = query_dim;
|
|
||||||
warmup_aligned_dim = query_aligned_dim;
|
|
||||||
diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T));
|
|
||||||
std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T));
|
|
||||||
std::random_device rd;
|
|
||||||
std::mt19937 gen(rd());
|
|
||||||
std::uniform_int_distribution<> dis(-128, 127);
|
|
||||||
for (uint32_t i = 0; i < warmup_num; i++)
|
|
||||||
{
|
|
||||||
for (uint32_t d = 0; d < warmup_dim; d++)
|
|
||||||
{
|
|
||||||
warmup[i * warmup_aligned_dim + d] = (T)dis(gen);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
diskann::cout << "Warming up index... " << std::flush;
|
|
||||||
std::vector<uint64_t> warmup_result_ids_64(warmup_num, 0);
|
|
||||||
std::vector<float> warmup_result_dists(warmup_num, 0);
|
|
||||||
|
|
||||||
#pragma omp parallel for schedule(dynamic, 1)
|
|
||||||
for (int64_t i = 0; i < (int64_t)warmup_num; i++)
|
|
||||||
{
|
|
||||||
_pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L,
|
|
||||||
warmup_result_ids_64.data() + (i * 1),
|
|
||||||
warmup_result_dists.data() + (i * 1), 4);
|
|
||||||
}
|
|
||||||
diskann::cout << "..done" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
|
|
||||||
diskann::cout.precision(2);
|
|
||||||
|
|
||||||
std::string recall_string = "Recall@rng=" + std::to_string(search_range);
|
|
||||||
diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" << std::setw(16) << "QPS" << std::setw(16)
|
|
||||||
<< "Mean Latency" << std::setw(16) << "99.9 Latency" << std::setw(16) << "Mean IOs" << std::setw(16)
|
|
||||||
<< "CPU (s)";
|
|
||||||
if (calc_recall_flag)
|
|
||||||
{
|
|
||||||
diskann::cout << std::setw(16) << recall_string << std::endl;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
diskann::cout << std::endl;
|
|
||||||
diskann::cout << "==============================================================="
|
|
||||||
"==========================================="
|
|
||||||
<< std::endl;
|
|
||||||
|
|
||||||
std::vector<std::vector<std::vector<uint32_t>>> query_result_ids(Lvec.size());
|
|
||||||
|
|
||||||
uint32_t optimized_beamwidth = 2;
|
|
||||||
uint32_t max_list_size = 10000;
|
|
||||||
|
|
||||||
for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++)
|
|
||||||
{
|
|
||||||
uint32_t L = Lvec[test_id];
|
|
||||||
|
|
||||||
if (beamwidth <= 0)
|
|
||||||
{
|
|
||||||
optimized_beamwidth =
|
|
||||||
optimize_beamwidth(_pFlashIndex, warmup, warmup_num, warmup_aligned_dim, L, optimized_beamwidth);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
optimized_beamwidth = beamwidth;
|
|
||||||
|
|
||||||
query_result_ids[test_id].clear();
|
|
||||||
query_result_ids[test_id].resize(query_num);
|
|
||||||
|
|
||||||
diskann::QueryStats *stats = new diskann::QueryStats[query_num];
|
|
||||||
|
|
||||||
auto s = std::chrono::high_resolution_clock::now();
|
|
||||||
#pragma omp parallel for schedule(dynamic, 1)
|
|
||||||
for (int64_t i = 0; i < (int64_t)query_num; i++)
|
|
||||||
{
|
|
||||||
std::vector<uint64_t> indices;
|
|
||||||
std::vector<float> distances;
|
|
||||||
uint32_t res_count =
|
|
||||||
_pFlashIndex->range_search(query + (i * query_aligned_dim), search_range, L, max_list_size, indices,
|
|
||||||
distances, optimized_beamwidth, stats + i);
|
|
||||||
query_result_ids[test_id][i].reserve(res_count);
|
|
||||||
query_result_ids[test_id][i].resize(res_count);
|
|
||||||
for (uint32_t idx = 0; idx < res_count; idx++)
|
|
||||||
query_result_ids[test_id][i][idx] = (uint32_t)indices[idx];
|
|
||||||
}
|
|
||||||
auto e = std::chrono::high_resolution_clock::now();
|
|
||||||
std::chrono::duration<double> diff = e - s;
|
|
||||||
auto qps = (1.0 * query_num) / (1.0 * diff.count());
|
|
||||||
|
|
||||||
auto mean_latency = diskann::get_mean_stats<float>(
|
|
||||||
stats, query_num, [](const diskann::QueryStats &stats) { return stats.total_us; });
|
|
||||||
|
|
||||||
auto latency_999 = diskann::get_percentile_stats<float>(
|
|
||||||
stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; });
|
|
||||||
|
|
||||||
auto mean_ios = diskann::get_mean_stats<uint32_t>(stats, query_num,
|
|
||||||
[](const diskann::QueryStats &stats) { return stats.n_ios; });
|
|
||||||
|
|
||||||
double mean_cpuus = diskann::get_mean_stats<float>(
|
|
||||||
stats, query_num, [](const diskann::QueryStats &stats) { return stats.cpu_us; });
|
|
||||||
|
|
||||||
double recall = 0;
|
|
||||||
double ratio_of_sums = 0;
|
|
||||||
if (calc_recall_flag)
|
|
||||||
{
|
|
||||||
recall =
|
|
||||||
diskann::calculate_range_search_recall((uint32_t)query_num, groundtruth_ids, query_result_ids[test_id]);
|
|
||||||
|
|
||||||
uint32_t total_true_positive = 0;
|
|
||||||
uint32_t total_positive = 0;
|
|
||||||
for (uint32_t i = 0; i < query_num; i++)
|
|
||||||
{
|
|
||||||
total_true_positive += (uint32_t)query_result_ids[test_id][i].size();
|
|
||||||
total_positive += (uint32_t)groundtruth_ids[i].size();
|
|
||||||
}
|
|
||||||
|
|
||||||
ratio_of_sums = (1.0 * total_true_positive) / (1.0 * total_positive);
|
|
||||||
}
|
|
||||||
|
|
||||||
diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps
|
|
||||||
<< std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios
|
|
||||||
<< std::setw(16) << mean_cpuus;
|
|
||||||
if (calc_recall_flag)
|
|
||||||
{
|
|
||||||
diskann::cout << std::setw(16) << recall << "," << ratio_of_sums << std::endl;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
diskann::cout << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
diskann::cout << "Done searching. " << std::endl;
|
|
||||||
|
|
||||||
diskann::aligned_free(query);
|
|
||||||
if (warmup != nullptr)
|
|
||||||
diskann::aligned_free(warmup);
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(int argc, char **argv)
|
|
||||||
{
|
|
||||||
std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file;
|
|
||||||
uint32_t num_threads, W, num_nodes_to_cache;
|
|
||||||
std::vector<uint32_t> Lvec;
|
|
||||||
float range;
|
|
||||||
|
|
||||||
po::options_description desc{program_options_utils::make_program_description(
|
|
||||||
"range_search_disk_index", "Searches disk DiskANN indexes using ranges")};
|
|
||||||
try
|
|
||||||
{
|
|
||||||
desc.add_options()("help,h", "Print information on arguments");
|
|
||||||
|
|
||||||
// Required parameters
|
|
||||||
po::options_description required_configs("Required");
|
|
||||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
|
||||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
|
||||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
|
||||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
|
||||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
|
||||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
|
||||||
required_configs.add_options()("query_file", po::value<std::string>(&query_file)->required(),
|
|
||||||
program_options_utils::QUERY_FILE_DESCRIPTION);
|
|
||||||
required_configs.add_options()("search_list,L",
|
|
||||||
po::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
|
|
||||||
program_options_utils::SEARCH_LIST_DESCRIPTION);
|
|
||||||
required_configs.add_options()("range_threshold,K", po::value<float>(&range)->required(),
|
|
||||||
"Number of neighbors to be returned");
|
|
||||||
|
|
||||||
// Optional parameters
|
|
||||||
po::options_description optional_configs("Optional");
|
|
||||||
optional_configs.add_options()("num_threads,T",
|
|
||||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
|
||||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
|
||||||
optional_configs.add_options()("gt_file", po::value<std::string>(>_file)->default_value(std::string("null")),
|
|
||||||
program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION);
|
|
||||||
optional_configs.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
|
|
||||||
program_options_utils::NUMBER_OF_NODES_TO_CACHE);
|
|
||||||
optional_configs.add_options()("beamwidth,W", po::value<uint32_t>(&W)->default_value(2),
|
|
||||||
program_options_utils::BEAMWIDTH);
|
|
||||||
|
|
||||||
// Merge required and optional parameters
|
|
||||||
desc.add(required_configs).add(optional_configs);
|
|
||||||
|
|
||||||
po::variables_map vm;
|
|
||||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
|
||||||
if (vm.count("help"))
|
|
||||||
{
|
|
||||||
std::cout << desc;
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
po::notify(vm);
|
|
||||||
}
|
|
||||||
catch (const std::exception &ex)
|
|
||||||
{
|
|
||||||
std::cerr << ex.what() << '\n';
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
diskann::Metric metric;
|
|
||||||
if (dist_fn == std::string("mips"))
|
|
||||||
{
|
|
||||||
metric = diskann::Metric::INNER_PRODUCT;
|
|
||||||
}
|
|
||||||
else if (dist_fn == std::string("l2"))
|
|
||||||
{
|
|
||||||
metric = diskann::Metric::L2;
|
|
||||||
}
|
|
||||||
else if (dist_fn == std::string("cosine"))
|
|
||||||
{
|
|
||||||
metric = diskann::Metric::COSINE;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cout << "Unsupported distance function. Currently only L2/ Inner "
|
|
||||||
"Product/Cosine are supported."
|
|
||||||
<< std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((data_type != std::string("float")) && (metric == diskann::Metric::INNER_PRODUCT))
|
|
||||||
{
|
|
||||||
std::cout << "Currently support only floating point data for Inner Product." << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
try
|
|
||||||
{
|
|
||||||
if (data_type == std::string("float"))
|
|
||||||
return search_disk_index<float>(metric, index_path_prefix, query_file, gt_file, num_threads, range, W,
|
|
||||||
num_nodes_to_cache, Lvec);
|
|
||||||
else if (data_type == std::string("int8"))
|
|
||||||
return search_disk_index<int8_t>(metric, index_path_prefix, query_file, gt_file, num_threads, range, W,
|
|
||||||
num_nodes_to_cache, Lvec);
|
|
||||||
else if (data_type == std::string("uint8"))
|
|
||||||
return search_disk_index<uint8_t>(metric, index_path_prefix, query_file, gt_file, num_threads, range, W,
|
|
||||||
num_nodes_to_cache, Lvec);
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
catch (const std::exception &e)
|
|
||||||
{
|
|
||||||
std::cout << std::string(e.what()) << std::endl;
|
|
||||||
diskann::cerr << "Index search failed." << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

set(CMAKE_CXX_STANDARD 17)

# Declares one restapi executable (<target>.cpp) and links it against the
# DiskANN library plus REST dependencies. Extra arguments (ARGN) are prepended
# to the non-MSVC link line (used for aio / tcmalloc on the server targets).
# Link behavior is identical to the previous per-target stanzas.
function(add_restapi_executable target)
  add_executable(${target} ${target}.cpp)
  if(MSVC)
    target_link_options(${target} PRIVATE /MACHINE:x64)
    target_link_libraries(${target} debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
    target_link_libraries(${target} optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
  else()
    target_link_libraries(${target} ${PROJECT_NAME} ${ARGN} -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
  endif()
endfunction()

add_restapi_executable(inmem_server aio -ltcmalloc)
add_restapi_executable(ssd_server aio -ltcmalloc)
add_restapi_executable(multiple_ssdindex_server aio -ltcmalloc)
add_restapi_executable(client)
|
|
||||||
@@ -1,124 +0,0 @@
|
|||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
// Licensed under the MIT license.
|
|
||||||
|
|
||||||
#include <ctime>
|
|
||||||
#include <functional>
|
|
||||||
#include <iomanip>
|
|
||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
// Licensed under the MIT license.
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <codecvt>
|
|
||||||
#include <boost/program_options.hpp>
|
|
||||||
|
|
||||||
#include <cpprest/http_client.h>
|
|
||||||
#include <restapi/common.h>
|
|
||||||
|
|
||||||
using namespace web;
|
|
||||||
using namespace web::http;
|
|
||||||
using namespace web::http::client;
|
|
||||||
|
|
||||||
using namespace diskann;
|
|
||||||
namespace po = boost::program_options;
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void query_loop(const std::string &ip_addr_port, const std::string &query_file, const unsigned nq, const unsigned Ls,
|
|
||||||
const unsigned k_value)
|
|
||||||
{
|
|
||||||
web::http::client::http_client client(U(ip_addr_port));
|
|
||||||
|
|
||||||
T *data;
|
|
||||||
size_t npts = 1, ndims = 128, rounded_dim = 128;
|
|
||||||
diskann::load_aligned_bin<T>(query_file, data, npts, ndims, rounded_dim);
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nq; ++i)
|
|
||||||
{
|
|
||||||
T *vec = data + i * rounded_dim;
|
|
||||||
web::http::http_request http_query(methods::POST);
|
|
||||||
web::json::value queryJson = web::json::value::object();
|
|
||||||
queryJson[QUERY_ID_KEY] = i;
|
|
||||||
queryJson[K_KEY] = k_value;
|
|
||||||
queryJson[L_KEY] = Ls;
|
|
||||||
for (size_t i = 0; i < ndims; ++i)
|
|
||||||
{
|
|
||||||
queryJson[VECTOR_KEY][i] = web::json::value::number(vec[i]);
|
|
||||||
}
|
|
||||||
http_query.set_body(queryJson);
|
|
||||||
|
|
||||||
client.request(http_query)
|
|
||||||
.then([](web::http::http_response response) -> pplx::task<utility::string_t> {
|
|
||||||
if (response.status_code() == status_codes::OK)
|
|
||||||
{
|
|
||||||
return response.extract_string();
|
|
||||||
}
|
|
||||||
std::cerr << "Query failed" << std::endl;
|
|
||||||
return pplx::task_from_result(utility::string_t());
|
|
||||||
})
|
|
||||||
.then([](pplx::task<utility::string_t> previousTask) {
|
|
||||||
try
|
|
||||||
{
|
|
||||||
std::cout << previousTask.get() << std::endl;
|
|
||||||
}
|
|
||||||
catch (http_exception const &e)
|
|
||||||
{
|
|
||||||
std::wcout << e.what() << std::endl;
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.wait();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(int argc, char *argv[])
|
|
||||||
{
|
|
||||||
std::string data_type, query_file, address;
|
|
||||||
uint32_t num_queries;
|
|
||||||
uint32_t l_search, k_value;
|
|
||||||
|
|
||||||
po::options_description desc{"Arguments"};
|
|
||||||
try
|
|
||||||
{
|
|
||||||
desc.add_options()("help,h", "Print information on arguments");
|
|
||||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
|
||||||
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
|
|
||||||
desc.add_options()("query_file", po::value<std::string>(&query_file)->required(),
|
|
||||||
"File containing the queries to search");
|
|
||||||
desc.add_options()("num_queries,Q", po::value<uint32_t>(&num_queries)->required(),
|
|
||||||
"Number of queries to search");
|
|
||||||
desc.add_options()("l_search", po::value<uint32_t>(&l_search)->required(), "Value of L");
|
|
||||||
desc.add_options()("k_value,K", po::value<uint32_t>(&k_value)->default_value(10), "Value of K (default 10)");
|
|
||||||
po::variables_map vm;
|
|
||||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
|
||||||
if (vm.count("help"))
|
|
||||||
{
|
|
||||||
std::cout << desc;
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
po::notify(vm);
|
|
||||||
}
|
|
||||||
catch (const std::exception &ex)
|
|
||||||
{
|
|
||||||
std::cerr << ex.what() << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (data_type == std::string("float"))
|
|
||||||
{
|
|
||||||
query_loop<float>(address, query_file, num_queries, l_search, k_value);
|
|
||||||
}
|
|
||||||
else if (data_type == std::string("int8"))
|
|
||||||
{
|
|
||||||
query_loop<int8_t>(address, query_file, num_queries, l_search, k_value);
|
|
||||||
}
|
|
||||||
else if (data_type == std::string("uint8"))
|
|
||||||
{
|
|
||||||
query_loop<uint8_t>(address, query_file, num_queries, l_search, k_value);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cerr << "Unsupported type " << argv[2] << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
// Licensed under the MIT license.
|
|
||||||
|
|
||||||
#include <ctime>
|
|
||||||
#include <functional>
|
|
||||||
#include <iomanip>
|
|
||||||
#include <string>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <codecvt>
|
|
||||||
#include <boost/program_options.hpp>
|
|
||||||
|
|
||||||
#include <restapi/server.h>
|
|
||||||
|
|
||||||
using namespace diskann;
|
|
||||||
namespace po = boost::program_options;
|
|
||||||
|
|
||||||
std::unique_ptr<Server> g_httpServer(nullptr);
|
|
||||||
std::vector<std::unique_ptr<diskann::BaseSearch>> g_inMemorySearch;
|
|
||||||
|
|
||||||
/// Starts the global HTTP server listening on `address`, serving the in-memory
/// searchers for the given element-type string.
void setup(const utility::string_t &address, const std::string &typestring)
{
    web::http::uri_builder builder(address);
    auto uri = builder.to_uri();
    std::cout << "Attempting to start server on " << uri.to_string() << std::endl;

    g_httpServer.reset(new Server(uri, g_inMemorySearch, typestring));
    std::cout << "Created a server object" << std::endl;

    g_httpServer->open().wait();
    ucout << U"Listening for requests on: " << address << std::endl;
}
|
|
||||||
|
|
||||||
/// Shuts down the global HTTP server. `address` is unused; the parameter is
/// kept for signature symmetry with setup().
void teardown(const utility::string_t &address)
{
    (void)address;
    g_httpServer->close().wait();
}
|
|
||||||
|
|
||||||
int main(int argc, char *argv[])
|
|
||||||
{
|
|
||||||
std::string data_type, index_file, data_file, address, dist_fn, tags_file;
|
|
||||||
uint32_t num_threads;
|
|
||||||
uint32_t l_search;
|
|
||||||
|
|
||||||
po::options_description desc{"Arguments"};
|
|
||||||
try
|
|
||||||
{
|
|
||||||
desc.add_options()("help,h", "Print information on arguments");
|
|
||||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
|
||||||
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
|
|
||||||
desc.add_options()("data_file", po::value<std::string>(&data_file)->required(),
|
|
||||||
"File containing the data found in the index");
|
|
||||||
desc.add_options()("index_path_prefix", po::value<std::string>(&index_file)->required(),
|
|
||||||
"Path prefix for saving index file components");
|
|
||||||
desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->required(),
|
|
||||||
"Number of threads used for building index");
|
|
||||||
desc.add_options()("l_search", po::value<uint32_t>(&l_search)->required(), "Value of L");
|
|
||||||
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
|
|
||||||
"distance function <l2/mips>");
|
|
||||||
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
|
|
||||||
"Tags file location");
|
|
||||||
po::variables_map vm;
|
|
||||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
|
||||||
if (vm.count("help"))
|
|
||||||
{
|
|
||||||
std::cout << desc;
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
po::notify(vm);
|
|
||||||
}
|
|
||||||
catch (const std::exception &ex)
|
|
||||||
{
|
|
||||||
std::cerr << ex.what() << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
diskann::Metric metric;
|
|
||||||
if (dist_fn == std::string("l2"))
|
|
||||||
metric = diskann::Metric::L2;
|
|
||||||
else if (dist_fn == std::string("mips"))
|
|
||||||
metric = diskann::Metric::INNER_PRODUCT;
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (data_type == std::string("float"))
|
|
||||||
{
|
|
||||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
|
||||||
new diskann::InMemorySearch<float>(data_file, index_file, tags_file, metric, num_threads, l_search));
|
|
||||||
g_inMemorySearch.push_back(std::move(searcher));
|
|
||||||
}
|
|
||||||
else if (data_type == std::string("int8"))
|
|
||||||
{
|
|
||||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
|
||||||
new diskann::InMemorySearch<int8_t>(data_file, index_file, tags_file, metric, num_threads, l_search));
|
|
||||||
g_inMemorySearch.push_back(std::move(searcher));
|
|
||||||
}
|
|
||||||
else if (data_type == std::string("uint8"))
|
|
||||||
{
|
|
||||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
|
||||||
new diskann::InMemorySearch<uint8_t>(data_file, index_file, tags_file, metric, num_threads, l_search));
|
|
||||||
g_inMemorySearch.push_back(std::move(searcher));
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cerr << "Unsupported data type " << argv[2] << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
while (1)
|
|
||||||
{
|
|
||||||
try
|
|
||||||
{
|
|
||||||
setup(address, data_type);
|
|
||||||
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
|
|
||||||
std::string line;
|
|
||||||
std::getline(std::cin, line);
|
|
||||||
if (line == "exit")
|
|
||||||
{
|
|
||||||
teardown(address);
|
|
||||||
g_httpServer->close().wait();
|
|
||||||
exit(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
catch (const std::exception &ex)
|
|
||||||
{
|
|
||||||
std::cerr << "Exception occurred: " << ex.what() << std::endl;
|
|
||||||
std::cerr << "Restarting HTTP server";
|
|
||||||
teardown(address);
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
std::cerr << "Unknown exception occurreed" << std::endl;
|
|
||||||
std::cerr << "Restarting HTTP server";
|
|
||||||
teardown(address);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,83 +0,0 @@
|
|||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
// Licensed under the MIT license.
|
|
||||||
|
|
||||||
#include <restapi/server.h>
|
|
||||||
#include <restapi/in_memory_search.h>
|
|
||||||
#include <codecvt>
|
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
std::unique_ptr<Server> g_httpServer(nullptr);
|
|
||||||
std::unique_ptr<diskann::InMemorySearch> g_inMemorySearch(nullptr);
|
|
||||||
|
|
||||||
/// Starts the global HTTP server on `address`, backed by the in-memory search
/// object loaded earlier by loadIndex().
void setup(const utility::string_t &address)
{
    web::http::uri_builder builder(address);
    auto uri = builder.to_uri();
    std::wcout << L"Attempting to start server on " << uri.to_string() << std::endl;

    g_httpServer.reset(new Server(uri, g_inMemorySearch));
    g_httpServer->open().wait();

    ucout << U"Listening for requests on: " << address << std::endl;
}
|
|
||||||
|
|
||||||
/// Closes the global HTTP server; the `address` argument is unused and kept
/// only to mirror setup()'s signature.
void teardown(const utility::string_t &address)
{
    (void)address;
    g_httpServer->close().wait();
}
|
|
||||||
|
|
||||||
void loadIndex(const char *indexFile, const char *baseFile, const char *idsFile)
|
|
||||||
{
|
|
||||||
auto nsgSearch = new diskann::InMemorySearch(baseFile, indexFile, idsFile, diskann::L2);
|
|
||||||
g_inMemorySearch = std::unique_ptr<diskann::InMemorySearch>(nsgSearch);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::wstring getHostingAddress(const char *hostNameAndPort)
|
|
||||||
{
|
|
||||||
wchar_t buffer[4096];
|
|
||||||
mbstowcs_s(nullptr, buffer, sizeof(buffer) / sizeof(buffer[0]), hostNameAndPort,
|
|
||||||
sizeof(buffer) / sizeof(buffer[0]));
|
|
||||||
return std::wstring(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Entry point for the (Windows) NSG REST server.
// Usage: nsg_server <ip_addr_and_port> <index_file> <base_file> <ids_file>
// Loads the index once, then serves forever; any exception tears the server
// down and the loop restarts it. Typing "exit" on stdin stops the process.
int main(int argc, char *argv[])
{
    if (argc != 5)
    {
        std::cout << "Usage: nsg_server <ip_addr_and_port> <index_file> "
                     "<base_file> <ids_file> "
                  << std::endl;
        exit(1);
    }

    // argv[1] = host:port, argv[2..4] = index/base/ids files.
    auto address = getHostingAddress(argv[1]);
    loadIndex(argv[2], argv[3], argv[4]);
    while (1)
    {
        try
        {
            setup(address);
            std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
            std::string line;
            // Blocks here while the server handles requests on its own threads.
            std::getline(std::cin, line);
            if (line == "exit")
            {
                teardown(address);
                exit(0);
            }
        }
        catch (const std::exception &ex)
        {
            std::cerr << "Exception occurred: " << ex.what() << std::endl;
            std::cerr << "Restarting HTTP server";
            teardown(address);
        }
        catch (...)
        {
            std::cerr << "Unknown exception occurreed" << std::endl;
            std::cerr << "Restarting HTTP server";
            teardown(address);
        }
    }
}
|
|
||||||
@@ -1,182 +0,0 @@
|
|||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
// Licensed under the MIT license.
|
|
||||||
|
|
||||||
#include <ctime>
|
|
||||||
#include <functional>
|
|
||||||
#include <iomanip>
|
|
||||||
#include <string>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <codecvt>
|
|
||||||
#include <boost/program_options.hpp>
|
|
||||||
#include <omp.h>
|
|
||||||
|
|
||||||
#include <restapi/server.h>
|
|
||||||
|
|
||||||
using namespace diskann;
|
|
||||||
namespace po = boost::program_options;
|
|
||||||
|
|
||||||
std::unique_ptr<Server> g_httpServer(nullptr);
|
|
||||||
std::vector<std::unique_ptr<diskann::BaseSearch>> g_ssdSearch;
|
|
||||||
|
|
||||||
/// Starts the global HTTP server on `address`, serving the SSD-index
/// searchers for the given element-type string.
void setup(const utility::string_t &address, const std::string &typestring)
{
    web::http::uri_builder builder(address);
    auto uri = builder.to_uri();
    std::cout << "Attempting to start server on " << uri.to_string() << std::endl;

    g_httpServer.reset(new Server(uri, g_ssdSearch, typestring));
    std::cout << "Created a server object" << std::endl;

    g_httpServer->open().wait();
    ucout << U"Listening for requests on: " << address << std::endl;
}
|
|
||||||
|
|
||||||
/// Closes the global HTTP server; `address` is unused (signature mirrors
/// setup()).
void teardown(const utility::string_t &address)
{
    (void)address;
    g_httpServer->close().wait();
}
|
|
||||||
|
|
||||||
int main(int argc, char *argv[])
|
|
||||||
{
|
|
||||||
std::string data_type, index_prefix_paths, address, dist_fn, tags_file;
|
|
||||||
uint32_t num_nodes_to_cache;
|
|
||||||
uint32_t num_threads;
|
|
||||||
|
|
||||||
po::options_description desc{"Arguments"};
|
|
||||||
try
|
|
||||||
{
|
|
||||||
desc.add_options()("help,h", "Print information on arguments");
|
|
||||||
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
|
|
||||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
|
||||||
desc.add_options()("index_prefix_paths", po::value<std::string>(&index_prefix_paths)->required(),
|
|
||||||
"Path prefix for loading index file components");
|
|
||||||
desc.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
|
|
||||||
"Number of nodes to cache during search");
|
|
||||||
desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
|
||||||
"Number of threads used for building index (defaults to "
|
|
||||||
"omp_get_num_procs())");
|
|
||||||
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
|
|
||||||
"distance function <l2/mips>");
|
|
||||||
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
|
|
||||||
"Tags file location");
|
|
||||||
|
|
||||||
po::variables_map vm;
|
|
||||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
|
||||||
if (vm.count("help"))
|
|
||||||
{
|
|
||||||
std::cout << desc;
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
po::notify(vm);
|
|
||||||
}
|
|
||||||
catch (const std::exception &ex)
|
|
||||||
{
|
|
||||||
std::cerr << ex.what() << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
diskann::Metric metric;
|
|
||||||
if (dist_fn == std::string("l2"))
|
|
||||||
metric = diskann::Metric::L2;
|
|
||||||
else if (dist_fn == std::string("mips"))
|
|
||||||
metric = diskann::Metric::INNER_PRODUCT;
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::pair<std::string, std::string>> index_tag_paths;
|
|
||||||
std::ifstream index_in(index_prefix_paths);
|
|
||||||
if (!index_in.is_open())
|
|
||||||
{
|
|
||||||
std::cerr << "Could not open " << index_prefix_paths << std::endl;
|
|
||||||
exit(-1);
|
|
||||||
}
|
|
||||||
std::ifstream tags_in(tags_file);
|
|
||||||
if (!tags_in.is_open())
|
|
||||||
{
|
|
||||||
std::cerr << "Could not open " << tags_file << std::endl;
|
|
||||||
exit(-1);
|
|
||||||
}
|
|
||||||
std::string prefix, tagfile;
|
|
||||||
while (std::getline(index_in, prefix))
|
|
||||||
{
|
|
||||||
if (std::getline(tags_in, tagfile))
|
|
||||||
{
|
|
||||||
index_tag_paths.push_back(std::make_pair(prefix, tagfile));
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cerr << "The number of tags specified does not match the number of "
|
|
||||||
"indices specified"
|
|
||||||
<< std::endl;
|
|
||||||
exit(-1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
index_in.close();
|
|
||||||
tags_in.close();
|
|
||||||
|
|
||||||
if (data_type == std::string("float"))
|
|
||||||
{
|
|
||||||
for (auto &index_tag : index_tag_paths)
|
|
||||||
{
|
|
||||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<float>(
|
|
||||||
index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
|
|
||||||
g_ssdSearch.push_back(std::move(searcher));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else if (data_type == std::string("int8"))
|
|
||||||
{
|
|
||||||
for (auto &index_tag : index_tag_paths)
|
|
||||||
{
|
|
||||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<int8_t>(
|
|
||||||
index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
|
|
||||||
g_ssdSearch.push_back(std::move(searcher));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else if (data_type == std::string("uint8"))
|
|
||||||
{
|
|
||||||
for (auto &index_tag : index_tag_paths)
|
|
||||||
{
|
|
||||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<uint8_t>(
|
|
||||||
index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
|
|
||||||
g_ssdSearch.push_back(std::move(searcher));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cerr << "Unsupported data type " << data_type << std::endl;
|
|
||||||
exit(-1);
|
|
||||||
}
|
|
||||||
|
|
||||||
while (1)
|
|
||||||
{
|
|
||||||
try
|
|
||||||
{
|
|
||||||
setup(address, data_type);
|
|
||||||
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
|
|
||||||
std::string line;
|
|
||||||
std::getline(std::cin, line);
|
|
||||||
if (line == "exit")
|
|
||||||
{
|
|
||||||
teardown(address);
|
|
||||||
g_httpServer->close().wait();
|
|
||||||
exit(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
catch (const std::exception &ex)
|
|
||||||
{
|
|
||||||
std::cerr << "Exception occurred: " << ex.what() << std::endl;
|
|
||||||
std::cerr << "Restarting HTTP server";
|
|
||||||
teardown(address);
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
std::cerr << "Unknown exception occurreed" << std::endl;
|
|
||||||
std::cerr << "Restarting HTTP server";
|
|
||||||
teardown(address);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,141 +0,0 @@
|
|||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
// Licensed under the MIT license.
|
|
||||||
|
|
||||||
#include <ctime>
|
|
||||||
#include <functional>
|
|
||||||
#include <iomanip>
|
|
||||||
#include <string>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <codecvt>
|
|
||||||
#include <boost/program_options.hpp>
|
|
||||||
#include <omp.h>
|
|
||||||
|
|
||||||
#include <restapi/server.h>
|
|
||||||
|
|
||||||
using namespace diskann;
|
|
||||||
namespace po = boost::program_options;
|
|
||||||
|
|
||||||
std::unique_ptr<Server> g_httpServer(nullptr);
|
|
||||||
std::vector<std::unique_ptr<diskann::BaseSearch>> g_ssdSearch;
|
|
||||||
|
|
||||||
/// Starts the global HTTP server on `address`, serving the SSD searcher(s)
/// for the given element-type string.
void setup(const utility::string_t &address, const std::string &typestring)
{
    web::http::uri_builder builder(address);
    auto uri = builder.to_uri();
    std::cout << "Attempting to start server on " << uri.to_string() << std::endl;

    g_httpServer.reset(new Server(uri, g_ssdSearch, typestring));
    std::cout << "Created a server object" << std::endl;

    g_httpServer->open().wait();
    ucout << U"Listening for requests on: " << address << std::endl;
}
|
|
||||||
|
|
||||||
/// Closes the global HTTP server. The unused `address` parameter mirrors
/// setup()'s signature.
void teardown(const utility::string_t &address)
{
    (void)address;
    g_httpServer->close().wait();
}
|
|
||||||
|
|
||||||
int main(int argc, char *argv[])
|
|
||||||
{
|
|
||||||
std::string data_type, index_path_prefix, address, dist_fn, tags_file;
|
|
||||||
uint32_t num_nodes_to_cache;
|
|
||||||
uint32_t num_threads;
|
|
||||||
|
|
||||||
po::options_description desc{"Arguments"};
|
|
||||||
try
|
|
||||||
{
|
|
||||||
desc.add_options()("help,h", "Print information on arguments");
|
|
||||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
|
||||||
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
|
|
||||||
desc.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
|
||||||
"Path prefix for loading index file components");
|
|
||||||
desc.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
|
|
||||||
"Number of nodes to cache during search");
|
|
||||||
desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
|
||||||
"Number of threads used for building index (defaults to "
|
|
||||||
"omp_get_num_procs())");
|
|
||||||
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
|
|
||||||
"distance function <l2/mips>");
|
|
||||||
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
|
|
||||||
"Tags file location");
|
|
||||||
po::variables_map vm;
|
|
||||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
|
||||||
if (vm.count("help"))
|
|
||||||
{
|
|
||||||
std::cout << desc;
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
po::notify(vm);
|
|
||||||
}
|
|
||||||
catch (const std::exception &ex)
|
|
||||||
{
|
|
||||||
std::cerr << ex.what() << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
diskann::Metric metric;
|
|
||||||
if (dist_fn == std::string("l2"))
|
|
||||||
metric = diskann::Metric::L2;
|
|
||||||
else if (dist_fn == std::string("mips"))
|
|
||||||
metric = diskann::Metric::INNER_PRODUCT;
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (data_type == std::string("float"))
|
|
||||||
{
|
|
||||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
|
||||||
new diskann::PQFlashSearch<float>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
|
|
||||||
g_ssdSearch.push_back(std::move(searcher));
|
|
||||||
}
|
|
||||||
else if (data_type == std::string("int8"))
|
|
||||||
{
|
|
||||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
|
||||||
new diskann::PQFlashSearch<int8_t>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
|
|
||||||
g_ssdSearch.push_back(std::move(searcher));
|
|
||||||
}
|
|
||||||
else if (data_type == std::string("uint8"))
|
|
||||||
{
|
|
||||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
|
||||||
new diskann::PQFlashSearch<uint8_t>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
|
|
||||||
g_ssdSearch.push_back(std::move(searcher));
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cerr << "Unsupported data type " << argv[2] << std::endl;
|
|
||||||
exit(-1);
|
|
||||||
}
|
|
||||||
|
|
||||||
while (1)
|
|
||||||
{
|
|
||||||
try
|
|
||||||
{
|
|
||||||
setup(address, data_type);
|
|
||||||
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
|
|
||||||
std::string line;
|
|
||||||
std::getline(std::cin, line);
|
|
||||||
if (line == "exit")
|
|
||||||
{
|
|
||||||
teardown(address);
|
|
||||||
g_httpServer->close().wait();
|
|
||||||
exit(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
catch (const std::exception &ex)
|
|
||||||
{
|
|
||||||
std::cerr << "Exception occurred: " << ex.what() << std::endl;
|
|
||||||
std::cerr << "Restarting HTTP server";
|
|
||||||
teardown(address);
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
std::cerr << "Unknown exception occurreed" << std::endl;
|
|
||||||
std::cerr << "Restarting HTTP server";
|
|
||||||
teardown(address);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,499 +0,0 @@
|
|||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
// Licensed under the MIT license.
|
|
||||||
|
|
||||||
#include "common_includes.h"
|
|
||||||
#include <boost/program_options.hpp>
|
|
||||||
|
|
||||||
#include "index.h"
|
|
||||||
#include "disk_utils.h"
|
|
||||||
#include "math_utils.h"
|
|
||||||
#include "memory_mapper.h"
|
|
||||||
#include "partition.h"
|
|
||||||
#include "pq_flash_index.h"
|
|
||||||
#include "timer.h"
|
|
||||||
#include "percentile_stats.h"
|
|
||||||
#include "program_options_utils.hpp"
|
|
||||||
|
|
||||||
#ifndef _WINDOWS
|
|
||||||
#include <sys/mman.h>
|
|
||||||
#include <sys/stat.h>
|
|
||||||
#include <unistd.h>
|
|
||||||
#include "linux_aligned_file_reader.h"
|
|
||||||
#else
|
|
||||||
#ifdef USE_BING_INFRA
|
|
||||||
#include "bing_aligned_file_reader.h"
|
|
||||||
#else
|
|
||||||
#include "windows_aligned_file_reader.h"
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#define WARMUP false
|
|
||||||
|
|
||||||
namespace po = boost::program_options;
|
|
||||||
|
|
||||||
void print_stats(std::string category, std::vector<float> percentiles, std::vector<float> results)
|
|
||||||
{
|
|
||||||
diskann::cout << std::setw(20) << category << ": " << std::flush;
|
|
||||||
for (uint32_t s = 0; s < percentiles.size(); s++)
|
|
||||||
{
|
|
||||||
diskann::cout << std::setw(8) << percentiles[s] << "%";
|
|
||||||
}
|
|
||||||
diskann::cout << std::endl;
|
|
||||||
diskann::cout << std::setw(22) << " " << std::flush;
|
|
||||||
for (uint32_t s = 0; s < percentiles.size(); s++)
|
|
||||||
{
|
|
||||||
diskann::cout << std::setw(9) << results[s];
|
|
||||||
}
|
|
||||||
diskann::cout << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename LabelT = uint32_t>
|
|
||||||
int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix,
|
|
||||||
const std::string &result_output_prefix, const std::string &query_file, std::string >_file,
|
|
||||||
const uint32_t num_threads, const uint32_t recall_at, const uint32_t beamwidth,
|
|
||||||
const uint32_t num_nodes_to_cache, const uint32_t search_io_limit,
|
|
||||||
const std::vector<uint32_t> &Lvec, const float fail_if_recall_below,
|
|
||||||
const std::vector<std::string> &query_filters, const bool use_reorder_data = false)
|
|
||||||
{
|
|
||||||
diskann::cout << "Search parameters: #threads: " << num_threads << ", ";
|
|
||||||
if (beamwidth <= 0)
|
|
||||||
diskann::cout << "beamwidth to be optimized for each L value" << std::flush;
|
|
||||||
else
|
|
||||||
diskann::cout << " beamwidth: " << beamwidth << std::flush;
|
|
||||||
if (search_io_limit == std::numeric_limits<uint32_t>::max())
|
|
||||||
diskann::cout << "." << std::endl;
|
|
||||||
else
|
|
||||||
diskann::cout << ", io_limit: " << search_io_limit << "." << std::endl;
|
|
||||||
|
|
||||||
std::string warmup_query_file = index_path_prefix + "_sample_data.bin";
|
|
||||||
|
|
||||||
// load query bin
|
|
||||||
T *query = nullptr;
|
|
||||||
uint32_t *gt_ids = nullptr;
|
|
||||||
float *gt_dists = nullptr;
|
|
||||||
size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim;
|
|
||||||
diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);
|
|
||||||
|
|
||||||
bool filtered_search = false;
|
|
||||||
if (!query_filters.empty())
|
|
||||||
{
|
|
||||||
filtered_search = true;
|
|
||||||
if (query_filters.size() != 1 && query_filters.size() != query_num)
|
|
||||||
{
|
|
||||||
std::cout << "Error. Mismatch in number of queries and size of query "
|
|
||||||
"filters file"
|
|
||||||
<< std::endl;
|
|
||||||
return -1; // To return -1 or some other error handling?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool calc_recall_flag = false;
|
|
||||||
if (gt_file != std::string("null") && gt_file != std::string("NULL") && file_exists(gt_file))
|
|
||||||
{
|
|
||||||
diskann::load_truthset(gt_file, gt_ids, gt_dists, gt_num, gt_dim);
|
|
||||||
if (gt_num != query_num)
|
|
||||||
{
|
|
||||||
diskann::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl;
|
|
||||||
}
|
|
||||||
calc_recall_flag = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<AlignedFileReader> reader = nullptr;
|
|
||||||
#ifdef _WINDOWS
|
|
||||||
#ifndef USE_BING_INFRA
|
|
||||||
reader.reset(new WindowsAlignedFileReader());
|
|
||||||
#else
|
|
||||||
reader.reset(new diskann::BingAlignedFileReader());
|
|
||||||
#endif
|
|
||||||
#else
|
|
||||||
reader.reset(new LinuxAlignedFileReader());
|
|
||||||
#endif
|
|
||||||
|
|
||||||
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> _pFlashIndex(
|
|
||||||
new diskann::PQFlashIndex<T, LabelT>(reader, metric));
|
|
||||||
|
|
||||||
int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str());
|
|
||||||
|
|
||||||
if (res != 0)
|
|
||||||
{
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<uint32_t> node_list;
|
|
||||||
diskann::cout << "Caching " << num_nodes_to_cache << " nodes around medoid(s)" << std::endl;
|
|
||||||
_pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list);
|
|
||||||
// if (num_nodes_to_cache > 0)
|
|
||||||
// _pFlashIndex->generate_cache_list_from_sample_queries(warmup_query_file, 15, 6, num_nodes_to_cache,
|
|
||||||
// num_threads, node_list);
|
|
||||||
_pFlashIndex->load_cache_list(node_list);
|
|
||||||
node_list.clear();
|
|
||||||
node_list.shrink_to_fit();
|
|
||||||
|
|
||||||
omp_set_num_threads(num_threads);
|
|
||||||
|
|
||||||
uint64_t warmup_L = 20;
|
|
||||||
uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0;
|
|
||||||
T *warmup = nullptr;
|
|
||||||
|
|
||||||
if (WARMUP)
|
|
||||||
{
|
|
||||||
if (file_exists(warmup_query_file))
|
|
||||||
{
|
|
||||||
diskann::load_aligned_bin<T>(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads);
|
|
||||||
warmup_dim = query_dim;
|
|
||||||
warmup_aligned_dim = query_aligned_dim;
|
|
||||||
diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T));
|
|
||||||
std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T));
|
|
||||||
std::random_device rd;
|
|
||||||
std::mt19937 gen(rd());
|
|
||||||
std::uniform_int_distribution<> dis(-128, 127);
|
|
||||||
for (uint32_t i = 0; i < warmup_num; i++)
|
|
||||||
{
|
|
||||||
for (uint32_t d = 0; d < warmup_dim; d++)
|
|
||||||
{
|
|
||||||
warmup[i * warmup_aligned_dim + d] = (T)dis(gen);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
diskann::cout << "Warming up index... " << std::flush;
|
|
||||||
std::vector<uint64_t> warmup_result_ids_64(warmup_num, 0);
|
|
||||||
std::vector<float> warmup_result_dists(warmup_num, 0);
|
|
||||||
|
|
||||||
#pragma omp parallel for schedule(dynamic, 1)
|
|
||||||
for (int64_t i = 0; i < (int64_t)warmup_num; i++)
|
|
||||||
{
|
|
||||||
_pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L,
|
|
||||||
warmup_result_ids_64.data() + (i * 1),
|
|
||||||
warmup_result_dists.data() + (i * 1), 4);
|
|
||||||
}
|
|
||||||
diskann::cout << "..done" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
|
|
||||||
diskann::cout.precision(2);
|
|
||||||
|
|
||||||
std::string recall_string = "Recall@" + std::to_string(recall_at);
|
|
||||||
diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" << std::setw(16) << "QPS" << std::setw(16)
|
|
||||||
<< "Mean Latency" << std::setw(16) << "99.9 Latency" << std::setw(16) << "Mean IOs" << std::setw(16)
|
|
||||||
<< "Mean IO (us)" << std::setw(16) << "CPU (s)";
|
|
||||||
if (calc_recall_flag)
|
|
||||||
{
|
|
||||||
diskann::cout << std::setw(16) << recall_string << std::endl;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
diskann::cout << std::endl;
|
|
||||||
diskann::cout << "=================================================================="
|
|
||||||
"================================================================="
|
|
||||||
<< std::endl;
|
|
||||||
|
|
||||||
std::vector<std::vector<uint32_t>> query_result_ids(Lvec.size());
|
|
||||||
std::vector<std::vector<float>> query_result_dists(Lvec.size());
|
|
||||||
|
|
||||||
uint32_t optimized_beamwidth = 2;
|
|
||||||
|
|
||||||
double best_recall = 0.0;
|
|
||||||
|
|
||||||
for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++)
|
|
||||||
{
|
|
||||||
uint32_t L = Lvec[test_id];
|
|
||||||
|
|
||||||
if (L < recall_at)
|
|
||||||
{
|
|
||||||
diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (beamwidth <= 0)
|
|
||||||
{
|
|
||||||
diskann::cout << "Tuning beamwidth.." << std::endl;
|
|
||||||
optimized_beamwidth =
|
|
||||||
optimize_beamwidth(_pFlashIndex, warmup, warmup_num, warmup_aligned_dim, L, optimized_beamwidth);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
optimized_beamwidth = beamwidth;
|
|
||||||
|
|
||||||
query_result_ids[test_id].resize(recall_at * query_num);
|
|
||||||
query_result_dists[test_id].resize(recall_at * query_num);
|
|
||||||
|
|
||||||
auto stats = new diskann::QueryStats[query_num];
|
|
||||||
|
|
||||||
std::vector<uint64_t> query_result_ids_64(recall_at * query_num);
|
|
||||||
auto s = std::chrono::high_resolution_clock::now();
|
|
||||||
|
|
||||||
#pragma omp parallel for schedule(dynamic, 1)
|
|
||||||
for (int64_t i = 0; i < (int64_t)query_num; i++)
|
|
||||||
{
|
|
||||||
if (!filtered_search)
|
|
||||||
{
|
|
||||||
_pFlashIndex->cached_beam_search(query + (i * query_aligned_dim), recall_at, L,
|
|
||||||
query_result_ids_64.data() + (i * recall_at),
|
|
||||||
query_result_dists[test_id].data() + (i * recall_at),
|
|
||||||
optimized_beamwidth, use_reorder_data, stats + i);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
LabelT label_for_search;
|
|
||||||
if (query_filters.size() == 1)
|
|
||||||
{ // one label for all queries
|
|
||||||
label_for_search = _pFlashIndex->get_converted_label(query_filters[0]);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{ // one label for each query
|
|
||||||
label_for_search = _pFlashIndex->get_converted_label(query_filters[i]);
|
|
||||||
}
|
|
||||||
_pFlashIndex->cached_beam_search(
|
|
||||||
query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at),
|
|
||||||
query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search,
|
|
||||||
use_reorder_data, stats + i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto e = std::chrono::high_resolution_clock::now();
|
|
||||||
std::chrono::duration<double> diff = e - s;
|
|
||||||
double qps = (1.0 * query_num) / (1.0 * diff.count());
|
|
||||||
|
|
||||||
diskann::convert_types<uint64_t, uint32_t>(query_result_ids_64.data(), query_result_ids[test_id].data(),
|
|
||||||
query_num, recall_at);
|
|
||||||
|
|
||||||
auto mean_latency = diskann::get_mean_stats<float>(
|
|
||||||
stats, query_num, [](const diskann::QueryStats &stats) { return stats.total_us; });
|
|
||||||
|
|
||||||
auto latency_999 = diskann::get_percentile_stats<float>(
|
|
||||||
stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; });
|
|
||||||
|
|
||||||
auto mean_ios = diskann::get_mean_stats<uint32_t>(stats, query_num,
|
|
||||||
[](const diskann::QueryStats &stats) { return stats.n_ios; });
|
|
||||||
|
|
||||||
auto mean_cpuus = diskann::get_mean_stats<float>(stats, query_num,
|
|
||||||
[](const diskann::QueryStats &stats) { return stats.cpu_us; });
|
|
||||||
|
|
||||||
auto mean_io_us = diskann::get_mean_stats<float>(stats, query_num,
|
|
||||||
[](const diskann::QueryStats &stats) { return stats.io_us; });
|
|
||||||
|
|
||||||
double recall = 0;
|
|
||||||
if (calc_recall_flag)
|
|
||||||
{
|
|
||||||
recall = diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim,
|
|
||||||
query_result_ids[test_id].data(), recall_at, recall_at);
|
|
||||||
best_recall = std::max(recall, best_recall);
|
|
||||||
}
|
|
||||||
|
|
||||||
diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps
|
|
||||||
<< std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios
|
|
||||||
<< std::setw(16) << mean_io_us << std::setw(16) << mean_cpuus;
|
|
||||||
if (calc_recall_flag)
|
|
||||||
{
|
|
||||||
diskann::cout << std::setw(16) << recall << std::endl;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
diskann::cout << std::endl;
|
|
||||||
delete[] stats;
|
|
||||||
}
|
|
||||||
|
|
||||||
diskann::cout << "Done searching. Now saving results " << std::endl;
|
|
||||||
uint64_t test_id = 0;
|
|
||||||
for (auto L : Lvec)
|
|
||||||
{
|
|
||||||
if (L < recall_at)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
std::string cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_idx_uint32.bin";
|
|
||||||
diskann::save_bin<uint32_t>(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at);
|
|
||||||
|
|
||||||
cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_dists_float.bin";
|
|
||||||
diskann::save_bin<float>(cur_result_path, query_result_dists[test_id++].data(), query_num, recall_at);
|
|
||||||
}
|
|
||||||
|
|
||||||
diskann::aligned_free(query);
|
|
||||||
if (warmup != nullptr)
|
|
||||||
diskann::aligned_free(warmup);
|
|
||||||
return best_recall >= fail_if_recall_below ? 0 : -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(int argc, char **argv)
|
|
||||||
{
|
|
||||||
std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file, filter_label,
|
|
||||||
label_type, query_filters_file;
|
|
||||||
uint32_t num_threads, K, W, num_nodes_to_cache, search_io_limit;
|
|
||||||
std::vector<uint32_t> Lvec;
|
|
||||||
bool use_reorder_data = false;
|
|
||||||
float fail_if_recall_below = 0.0f;
|
|
||||||
|
|
||||||
po::options_description desc{
|
|
||||||
program_options_utils::make_program_description("search_disk_index", "Searches on-disk DiskANN indexes")};
|
|
||||||
try
|
|
||||||
{
|
|
||||||
desc.add_options()("help,h", "Print information on arguments");
|
|
||||||
|
|
||||||
// Required parameters
|
|
||||||
po::options_description required_configs("Required");
|
|
||||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
|
||||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
|
||||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
|
||||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
|
||||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
|
||||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
|
||||||
required_configs.add_options()("result_path", po::value<std::string>(&result_path_prefix)->required(),
|
|
||||||
program_options_utils::RESULT_PATH_DESCRIPTION);
|
|
||||||
required_configs.add_options()("query_file", po::value<std::string>(&query_file)->required(),
|
|
||||||
program_options_utils::QUERY_FILE_DESCRIPTION);
|
|
||||||
required_configs.add_options()("recall_at,K", po::value<uint32_t>(&K)->required(),
|
|
||||||
program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION);
|
|
||||||
required_configs.add_options()("search_list,L",
|
|
||||||
po::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
|
|
||||||
program_options_utils::SEARCH_LIST_DESCRIPTION);
|
|
||||||
|
|
||||||
// Optional parameters
|
|
||||||
po::options_description optional_configs("Optional");
|
|
||||||
optional_configs.add_options()("gt_file", po::value<std::string>(>_file)->default_value(std::string("null")),
|
|
||||||
program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION);
|
|
||||||
optional_configs.add_options()("beamwidth,W", po::value<uint32_t>(&W)->default_value(2),
|
|
||||||
program_options_utils::BEAMWIDTH);
|
|
||||||
optional_configs.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
|
|
||||||
program_options_utils::NUMBER_OF_NODES_TO_CACHE);
|
|
||||||
optional_configs.add_options()(
|
|
||||||
"search_io_limit",
|
|
||||||
po::value<uint32_t>(&search_io_limit)->default_value(std::numeric_limits<uint32_t>::max()),
|
|
||||||
"Max #IOs for search. Default value: uint32::max()");
|
|
||||||
optional_configs.add_options()("num_threads,T",
|
|
||||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
|
||||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
|
||||||
optional_configs.add_options()("use_reorder_data", po::bool_switch()->default_value(false),
|
|
||||||
"Include full precision data in the index. Use only in "
|
|
||||||
"conjuction with compressed data on SSD. Default value: false");
|
|
||||||
optional_configs.add_options()("filter_label",
|
|
||||||
po::value<std::string>(&filter_label)->default_value(std::string("")),
|
|
||||||
program_options_utils::FILTER_LABEL_DESCRIPTION);
|
|
||||||
optional_configs.add_options()("query_filters_file",
|
|
||||||
po::value<std::string>(&query_filters_file)->default_value(std::string("")),
|
|
||||||
program_options_utils::FILTERS_FILE_DESCRIPTION);
|
|
||||||
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
|
|
||||||
program_options_utils::LABEL_TYPE_DESCRIPTION);
|
|
||||||
optional_configs.add_options()("fail_if_recall_below",
|
|
||||||
po::value<float>(&fail_if_recall_below)->default_value(0.0f),
|
|
||||||
program_options_utils::FAIL_IF_RECALL_BELOW);
|
|
||||||
|
|
||||||
// Merge required and optional parameters
|
|
||||||
desc.add(required_configs).add(optional_configs);
|
|
||||||
|
|
||||||
po::variables_map vm;
|
|
||||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
|
||||||
if (vm.count("help"))
|
|
||||||
{
|
|
||||||
std::cout << desc;
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
po::notify(vm);
|
|
||||||
if (vm["use_reorder_data"].as<bool>())
|
|
||||||
use_reorder_data = true;
|
|
||||||
}
|
|
||||||
catch (const std::exception &ex)
|
|
||||||
{
|
|
||||||
std::cerr << ex.what() << '\n';
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
diskann::Metric metric;
|
|
||||||
if (dist_fn == std::string("mips"))
|
|
||||||
{
|
|
||||||
metric = diskann::Metric::INNER_PRODUCT;
|
|
||||||
}
|
|
||||||
else if (dist_fn == std::string("l2"))
|
|
||||||
{
|
|
||||||
metric = diskann::Metric::L2;
|
|
||||||
}
|
|
||||||
else if (dist_fn == std::string("cosine"))
|
|
||||||
{
|
|
||||||
metric = diskann::Metric::COSINE;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cout << "Unsupported distance function. Currently only L2/ Inner "
|
|
||||||
"Product/Cosine are supported."
|
|
||||||
<< std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((data_type != std::string("float")) && (metric == diskann::Metric::INNER_PRODUCT))
|
|
||||||
{
|
|
||||||
std::cout << "Currently support only floating point data for Inner Product." << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (use_reorder_data && data_type != std::string("float"))
|
|
||||||
{
|
|
||||||
std::cout << "Error: Reorder data for reordering currently only "
|
|
||||||
"supported for float data type."
|
|
||||||
<< std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (filter_label != "" && query_filters_file != "")
|
|
||||||
{
|
|
||||||
std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::string> query_filters;
|
|
||||||
if (filter_label != "")
|
|
||||||
{
|
|
||||||
query_filters.push_back(filter_label);
|
|
||||||
}
|
|
||||||
else if (query_filters_file != "")
|
|
||||||
{
|
|
||||||
query_filters = read_file_to_vector_of_strings(query_filters_file);
|
|
||||||
}
|
|
||||||
|
|
||||||
try
|
|
||||||
{
|
|
||||||
if (!query_filters.empty() && label_type == "ushort")
|
|
||||||
{
|
|
||||||
if (data_type == std::string("float"))
|
|
||||||
return search_disk_index<float, uint16_t>(
|
|
||||||
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
|
|
||||||
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
|
|
||||||
else if (data_type == std::string("int8"))
|
|
||||||
return search_disk_index<int8_t, uint16_t>(
|
|
||||||
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
|
|
||||||
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
|
|
||||||
else if (data_type == std::string("uint8"))
|
|
||||||
return search_disk_index<uint8_t, uint16_t>(
|
|
||||||
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
|
|
||||||
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
if (data_type == std::string("float"))
|
|
||||||
return search_disk_index<float>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
|
|
||||||
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
|
|
||||||
fail_if_recall_below, query_filters, use_reorder_data);
|
|
||||||
else if (data_type == std::string("int8"))
|
|
||||||
return search_disk_index<int8_t>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
|
|
||||||
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
|
|
||||||
fail_if_recall_below, query_filters, use_reorder_data);
|
|
||||||
else if (data_type == std::string("uint8"))
|
|
||||||
return search_disk_index<uint8_t>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
|
|
||||||
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
|
|
||||||
fail_if_recall_below, query_filters, use_reorder_data);
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
catch (const std::exception &e)
|
|
||||||
{
|
|
||||||
std::cout << std::string(e.what()) << std::endl;
|
|
||||||
diskann::cerr << "Index search failed." << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,477 +0,0 @@
|
|||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
// Licensed under the MIT license.
|
|
||||||
|
|
||||||
#include <cstring>
|
|
||||||
#include <iomanip>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <numeric>
|
|
||||||
#include <omp.h>
|
|
||||||
#include <set>
|
|
||||||
#include <string.h>
|
|
||||||
#include <boost/program_options.hpp>
|
|
||||||
|
|
||||||
#ifndef _WINDOWS
|
|
||||||
#include <sys/mman.h>
|
|
||||||
#include <sys/stat.h>
|
|
||||||
#include <time.h>
|
|
||||||
#include <unistd.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include "index.h"
|
|
||||||
#include "memory_mapper.h"
|
|
||||||
#include "utils.h"
|
|
||||||
#include "program_options_utils.hpp"
|
|
||||||
#include "index_factory.h"
|
|
||||||
|
|
||||||
namespace po = boost::program_options;
|
|
||||||
|
|
||||||
template <typename T, typename LabelT = uint32_t>
|
|
||||||
int search_memory_index(diskann::Metric &metric, const std::string &index_path, const std::string &result_path_prefix,
|
|
||||||
const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads,
|
|
||||||
const uint32_t recall_at, const bool print_all_recalls, const std::vector<uint32_t> &Lvec,
|
|
||||||
const bool dynamic, const bool tags, const bool show_qps_per_thread,
|
|
||||||
const std::vector<std::string> &query_filters, const float fail_if_recall_below)
|
|
||||||
{
|
|
||||||
using TagT = uint32_t;
|
|
||||||
// Load the query file
|
|
||||||
T *query = nullptr;
|
|
||||||
uint32_t *gt_ids = nullptr;
|
|
||||||
float *gt_dists = nullptr;
|
|
||||||
size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim;
|
|
||||||
diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);
|
|
||||||
|
|
||||||
bool calc_recall_flag = false;
|
|
||||||
if (truthset_file != std::string("null") && file_exists(truthset_file))
|
|
||||||
{
|
|
||||||
diskann::load_truthset(truthset_file, gt_ids, gt_dists, gt_num, gt_dim);
|
|
||||||
if (gt_num != query_num)
|
|
||||||
{
|
|
||||||
std::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl;
|
|
||||||
}
|
|
||||||
calc_recall_flag = true;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
diskann::cout << " Truthset file " << truthset_file << " not found. Not computing recall." << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool filtered_search = false;
|
|
||||||
if (!query_filters.empty())
|
|
||||||
{
|
|
||||||
filtered_search = true;
|
|
||||||
if (query_filters.size() != 1 && query_filters.size() != query_num)
|
|
||||||
{
|
|
||||||
std::cout << "Error. Mismatch in number of queries and size of query "
|
|
||||||
"filters file"
|
|
||||||
<< std::endl;
|
|
||||||
return -1; // To return -1 or some other error handling?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const size_t num_frozen_pts = diskann::get_graph_num_frozen_points(index_path);
|
|
||||||
|
|
||||||
auto config = diskann::IndexConfigBuilder()
|
|
||||||
.with_metric(metric)
|
|
||||||
.with_dimension(query_dim)
|
|
||||||
.with_max_points(0)
|
|
||||||
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
|
|
||||||
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
|
|
||||||
.with_data_type(diskann_type_to_name<T>())
|
|
||||||
.with_label_type(diskann_type_to_name<LabelT>())
|
|
||||||
.with_tag_type(diskann_type_to_name<TagT>())
|
|
||||||
.is_dynamic_index(dynamic)
|
|
||||||
.is_enable_tags(tags)
|
|
||||||
.is_concurrent_consolidate(false)
|
|
||||||
.is_pq_dist_build(false)
|
|
||||||
.is_use_opq(false)
|
|
||||||
.with_num_pq_chunks(0)
|
|
||||||
.with_num_frozen_pts(num_frozen_pts)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
auto index_factory = diskann::IndexFactory(config);
|
|
||||||
auto index = index_factory.create_instance();
|
|
||||||
index->load(index_path.c_str(), num_threads, *(std::max_element(Lvec.begin(), Lvec.end())));
|
|
||||||
std::cout << "Index loaded" << std::endl;
|
|
||||||
|
|
||||||
if (metric == diskann::FAST_L2)
|
|
||||||
index->optimize_index_layout();
|
|
||||||
|
|
||||||
std::cout << "Using " << num_threads << " threads to search" << std::endl;
|
|
||||||
std::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
|
|
||||||
std::cout.precision(2);
|
|
||||||
const std::string qps_title = show_qps_per_thread ? "QPS/thread" : "QPS";
|
|
||||||
uint32_t table_width = 0;
|
|
||||||
if (tags)
|
|
||||||
{
|
|
||||||
std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title << std::setw(20) << "Mean Latency (mus)"
|
|
||||||
<< std::setw(15) << "99.9 Latency";
|
|
||||||
table_width += 4 + 12 + 20 + 15;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title << std::setw(18) << "Avg dist cmps"
|
|
||||||
<< std::setw(20) << "Mean Latency (mus)" << std::setw(15) << "99.9 Latency";
|
|
||||||
table_width += 4 + 12 + 18 + 20 + 15;
|
|
||||||
}
|
|
||||||
uint32_t recalls_to_print = 0;
|
|
||||||
const uint32_t first_recall = print_all_recalls ? 1 : recall_at;
|
|
||||||
if (calc_recall_flag)
|
|
||||||
{
|
|
||||||
for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++)
|
|
||||||
{
|
|
||||||
std::cout << std::setw(12) << ("Recall@" + std::to_string(curr_recall));
|
|
||||||
}
|
|
||||||
recalls_to_print = recall_at + 1 - first_recall;
|
|
||||||
table_width += recalls_to_print * 12;
|
|
||||||
}
|
|
||||||
std::cout << std::endl;
|
|
||||||
std::cout << std::string(table_width, '=') << std::endl;
|
|
||||||
|
|
||||||
std::vector<std::vector<uint32_t>> query_result_ids(Lvec.size());
|
|
||||||
std::vector<std::vector<float>> query_result_dists(Lvec.size());
|
|
||||||
std::vector<float> latency_stats(query_num, 0);
|
|
||||||
std::vector<uint32_t> cmp_stats;
|
|
||||||
if (not tags || filtered_search)
|
|
||||||
{
|
|
||||||
cmp_stats = std::vector<uint32_t>(query_num, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<TagT> query_result_tags;
|
|
||||||
if (tags)
|
|
||||||
{
|
|
||||||
query_result_tags.resize(recall_at * query_num);
|
|
||||||
}
|
|
||||||
|
|
||||||
double best_recall = 0.0;
|
|
||||||
|
|
||||||
for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++)
|
|
||||||
{
|
|
||||||
uint32_t L = Lvec[test_id];
|
|
||||||
if (L < recall_at)
|
|
||||||
{
|
|
||||||
diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
query_result_ids[test_id].resize(recall_at * query_num);
|
|
||||||
query_result_dists[test_id].resize(recall_at * query_num);
|
|
||||||
std::vector<T *> res = std::vector<T *>();
|
|
||||||
|
|
||||||
auto s = std::chrono::high_resolution_clock::now();
|
|
||||||
omp_set_num_threads(num_threads);
|
|
||||||
#pragma omp parallel for schedule(dynamic, 1)
|
|
||||||
for (int64_t i = 0; i < (int64_t)query_num; i++)
|
|
||||||
{
|
|
||||||
auto qs = std::chrono::high_resolution_clock::now();
|
|
||||||
if (filtered_search && !tags)
|
|
||||||
{
|
|
||||||
std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];
|
|
||||||
|
|
||||||
auto retval = index->search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L,
|
|
||||||
query_result_ids[test_id].data() + i * recall_at,
|
|
||||||
query_result_dists[test_id].data() + i * recall_at);
|
|
||||||
cmp_stats[i] = retval.second;
|
|
||||||
}
|
|
||||||
else if (metric == diskann::FAST_L2)
|
|
||||||
{
|
|
||||||
index->search_with_optimized_layout(query + i * query_aligned_dim, recall_at, L,
|
|
||||||
query_result_ids[test_id].data() + i * recall_at);
|
|
||||||
}
|
|
||||||
else if (tags)
|
|
||||||
{
|
|
||||||
if (!filtered_search)
|
|
||||||
{
|
|
||||||
index->search_with_tags(query + i * query_aligned_dim, recall_at, L,
|
|
||||||
query_result_tags.data() + i * recall_at, nullptr, res);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];
|
|
||||||
|
|
||||||
index->search_with_tags(query + i * query_aligned_dim, recall_at, L,
|
|
||||||
query_result_tags.data() + i * recall_at, nullptr, res, true, raw_filter);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int64_t r = 0; r < (int64_t)recall_at; r++)
|
|
||||||
{
|
|
||||||
query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
cmp_stats[i] = index
|
|
||||||
->search(query + i * query_aligned_dim, recall_at, L,
|
|
||||||
query_result_ids[test_id].data() + i * recall_at)
|
|
||||||
.second;
|
|
||||||
}
|
|
||||||
auto qe = std::chrono::high_resolution_clock::now();
|
|
||||||
std::chrono::duration<double> diff = qe - qs;
|
|
||||||
latency_stats[i] = (float)(diff.count() * 1000000);
|
|
||||||
}
|
|
||||||
std::chrono::duration<double> diff = std::chrono::high_resolution_clock::now() - s;
|
|
||||||
|
|
||||||
double displayed_qps = query_num / diff.count();
|
|
||||||
|
|
||||||
if (show_qps_per_thread)
|
|
||||||
displayed_qps /= num_threads;
|
|
||||||
|
|
||||||
std::vector<double> recalls;
|
|
||||||
if (calc_recall_flag)
|
|
||||||
{
|
|
||||||
recalls.reserve(recalls_to_print);
|
|
||||||
for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++)
|
|
||||||
{
|
|
||||||
recalls.push_back(diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim,
|
|
||||||
query_result_ids[test_id].data(), recall_at, curr_recall));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::sort(latency_stats.begin(), latency_stats.end());
|
|
||||||
double mean_latency =
|
|
||||||
std::accumulate(latency_stats.begin(), latency_stats.end(), 0.0) / static_cast<float>(query_num);
|
|
||||||
|
|
||||||
float avg_cmps = (float)std::accumulate(cmp_stats.begin(), cmp_stats.end(), 0) / (float)query_num;
|
|
||||||
|
|
||||||
if (tags && !filtered_search)
|
|
||||||
{
|
|
||||||
std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(20) << (float)mean_latency
|
|
||||||
<< std::setw(15) << (float)latency_stats[(uint64_t)(0.999 * query_num)];
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(18) << avg_cmps
|
|
||||||
<< std::setw(20) << (float)mean_latency << std::setw(15)
|
|
||||||
<< (float)latency_stats[(uint64_t)(0.999 * query_num)];
|
|
||||||
}
|
|
||||||
for (double recall : recalls)
|
|
||||||
{
|
|
||||||
std::cout << std::setw(12) << recall;
|
|
||||||
best_recall = std::max(recall, best_recall);
|
|
||||||
}
|
|
||||||
std::cout << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::cout << "Done searching. Now saving results " << std::endl;
|
|
||||||
uint64_t test_id = 0;
|
|
||||||
for (auto L : Lvec)
|
|
||||||
{
|
|
||||||
if (L < recall_at)
|
|
||||||
{
|
|
||||||
diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
std::string cur_result_path_prefix = result_path_prefix + "_" + std::to_string(L);
|
|
||||||
|
|
||||||
std::string cur_result_path = cur_result_path_prefix + "_idx_uint32.bin";
|
|
||||||
diskann::save_bin<uint32_t>(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at);
|
|
||||||
|
|
||||||
cur_result_path = cur_result_path_prefix + "_dists_float.bin";
|
|
||||||
diskann::save_bin<float>(cur_result_path, query_result_dists[test_id].data(), query_num, recall_at);
|
|
||||||
|
|
||||||
test_id++;
|
|
||||||
}
|
|
||||||
|
|
||||||
diskann::aligned_free(query);
|
|
||||||
return best_recall >= fail_if_recall_below ? 0 : -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(int argc, char **argv)
|
|
||||||
{
|
|
||||||
std::string data_type, dist_fn, index_path_prefix, result_path, query_file, gt_file, filter_label, label_type,
|
|
||||||
query_filters_file;
|
|
||||||
uint32_t num_threads, K;
|
|
||||||
std::vector<uint32_t> Lvec;
|
|
||||||
bool print_all_recalls, dynamic, tags, show_qps_per_thread;
|
|
||||||
float fail_if_recall_below = 0.0f;
|
|
||||||
|
|
||||||
po::options_description desc{
|
|
||||||
program_options_utils::make_program_description("search_memory_index", "Searches in-memory DiskANN indexes")};
|
|
||||||
try
|
|
||||||
{
|
|
||||||
desc.add_options()("help,h", "Print this information on arguments");
|
|
||||||
|
|
||||||
// Required parameters
|
|
||||||
po::options_description required_configs("Required");
|
|
||||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
|
||||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
|
||||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
|
||||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
|
||||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
|
||||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
|
||||||
required_configs.add_options()("result_path", po::value<std::string>(&result_path)->required(),
|
|
||||||
program_options_utils::RESULT_PATH_DESCRIPTION);
|
|
||||||
required_configs.add_options()("query_file", po::value<std::string>(&query_file)->required(),
|
|
||||||
program_options_utils::QUERY_FILE_DESCRIPTION);
|
|
||||||
required_configs.add_options()("recall_at,K", po::value<uint32_t>(&K)->required(),
|
|
||||||
program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION);
|
|
||||||
required_configs.add_options()("search_list,L",
|
|
||||||
po::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
|
|
||||||
program_options_utils::SEARCH_LIST_DESCRIPTION);
|
|
||||||
|
|
||||||
// Optional parameters
|
|
||||||
po::options_description optional_configs("Optional");
|
|
||||||
optional_configs.add_options()("filter_label",
|
|
||||||
po::value<std::string>(&filter_label)->default_value(std::string("")),
|
|
||||||
program_options_utils::FILTER_LABEL_DESCRIPTION);
|
|
||||||
optional_configs.add_options()("query_filters_file",
|
|
||||||
po::value<std::string>(&query_filters_file)->default_value(std::string("")),
|
|
||||||
program_options_utils::FILTERS_FILE_DESCRIPTION);
|
|
||||||
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
|
|
||||||
program_options_utils::LABEL_TYPE_DESCRIPTION);
|
|
||||||
optional_configs.add_options()("gt_file", po::value<std::string>(>_file)->default_value(std::string("null")),
|
|
||||||
program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION);
|
|
||||||
optional_configs.add_options()("num_threads,T",
|
|
||||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
|
||||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
|
||||||
optional_configs.add_options()(
|
|
||||||
"dynamic", po::value<bool>(&dynamic)->default_value(false),
|
|
||||||
"Whether the index is dynamic. Dynamic indices must have associated tags. Default false.");
|
|
||||||
optional_configs.add_options()("tags", po::value<bool>(&tags)->default_value(false),
|
|
||||||
"Whether to search with external identifiers (tags). Default false.");
|
|
||||||
optional_configs.add_options()("fail_if_recall_below",
|
|
||||||
po::value<float>(&fail_if_recall_below)->default_value(0.0f),
|
|
||||||
program_options_utils::FAIL_IF_RECALL_BELOW);
|
|
||||||
|
|
||||||
// Output controls
|
|
||||||
po::options_description output_controls("Output controls");
|
|
||||||
output_controls.add_options()("print_all_recalls", po::bool_switch(&print_all_recalls),
|
|
||||||
"Print recalls at all positions, from 1 up to specified "
|
|
||||||
"recall_at value");
|
|
||||||
output_controls.add_options()("print_qps_per_thread", po::bool_switch(&show_qps_per_thread),
|
|
||||||
"Print overall QPS divided by the number of threads in "
|
|
||||||
"the output table");
|
|
||||||
|
|
||||||
// Merge required and optional parameters
|
|
||||||
desc.add(required_configs).add(optional_configs).add(output_controls);
|
|
||||||
|
|
||||||
po::variables_map vm;
|
|
||||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
|
||||||
if (vm.count("help"))
|
|
||||||
{
|
|
||||||
std::cout << desc;
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
po::notify(vm);
|
|
||||||
}
|
|
||||||
catch (const std::exception &ex)
|
|
||||||
{
|
|
||||||
std::cerr << ex.what() << '\n';
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
diskann::Metric metric;
|
|
||||||
if ((dist_fn == std::string("mips")) && (data_type == std::string("float")))
|
|
||||||
{
|
|
||||||
metric = diskann::Metric::INNER_PRODUCT;
|
|
||||||
}
|
|
||||||
else if (dist_fn == std::string("l2"))
|
|
||||||
{
|
|
||||||
metric = diskann::Metric::L2;
|
|
||||||
}
|
|
||||||
else if (dist_fn == std::string("cosine"))
|
|
||||||
{
|
|
||||||
metric = diskann::Metric::COSINE;
|
|
||||||
}
|
|
||||||
else if ((dist_fn == std::string("fast_l2")) && (data_type == std::string("float")))
|
|
||||||
{
|
|
||||||
metric = diskann::Metric::FAST_L2;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cout << "Unsupported distance function. Currently only l2/ cosine are "
|
|
||||||
"supported in general, and mips/fast_l2 only for floating "
|
|
||||||
"point data."
|
|
||||||
<< std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (dynamic && not tags)
|
|
||||||
{
|
|
||||||
std::cerr << "Tags must be enabled while searching dynamically built indices" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (fail_if_recall_below < 0.0 || fail_if_recall_below >= 100.0)
|
|
||||||
{
|
|
||||||
std::cerr << "fail_if_recall_below parameter must be between 0 and 100%" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (filter_label != "" && query_filters_file != "")
|
|
||||||
{
|
|
||||||
std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::string> query_filters;
|
|
||||||
if (filter_label != "")
|
|
||||||
{
|
|
||||||
query_filters.push_back(filter_label);
|
|
||||||
}
|
|
||||||
else if (query_filters_file != "")
|
|
||||||
{
|
|
||||||
query_filters = read_file_to_vector_of_strings(query_filters_file);
|
|
||||||
}
|
|
||||||
|
|
||||||
try
|
|
||||||
{
|
|
||||||
if (!query_filters.empty() && label_type == "ushort")
|
|
||||||
{
|
|
||||||
if (data_type == std::string("int8"))
|
|
||||||
{
|
|
||||||
return search_memory_index<int8_t, uint16_t>(
|
|
||||||
metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
|
|
||||||
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below);
|
|
||||||
}
|
|
||||||
else if (data_type == std::string("uint8"))
|
|
||||||
{
|
|
||||||
return search_memory_index<uint8_t, uint16_t>(
|
|
||||||
metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
|
|
||||||
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below);
|
|
||||||
}
|
|
||||||
else if (data_type == std::string("float"))
|
|
||||||
{
|
|
||||||
return search_memory_index<float, uint16_t>(metric, index_path_prefix, result_path, query_file, gt_file,
|
|
||||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
|
||||||
show_qps_per_thread, query_filters, fail_if_recall_below);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
if (data_type == std::string("int8"))
|
|
||||||
{
|
|
||||||
return search_memory_index<int8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
|
|
||||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
|
||||||
show_qps_per_thread, query_filters, fail_if_recall_below);
|
|
||||||
}
|
|
||||||
else if (data_type == std::string("uint8"))
|
|
||||||
{
|
|
||||||
return search_memory_index<uint8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
|
|
||||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
|
||||||
show_qps_per_thread, query_filters, fail_if_recall_below);
|
|
||||||
}
|
|
||||||
else if (data_type == std::string("float"))
|
|
||||||
{
|
|
||||||
return search_memory_index<float>(metric, index_path_prefix, result_path, query_file, gt_file,
|
|
||||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
|
||||||
show_qps_per_thread, query_filters, fail_if_recall_below);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
catch (std::exception &e)
|
|
||||||
{
|
|
||||||
std::cout << std::string(e.what()) << std::endl;
|
|
||||||
diskann::cerr << "Index search failed." << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,536 +0,0 @@
|
|||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
// Licensed under the MIT license.
|
|
||||||
|
|
||||||
#include <index.h>
|
|
||||||
#include <numeric>
|
|
||||||
#include <omp.h>
|
|
||||||
#include <string.h>
|
|
||||||
#include <time.h>
|
|
||||||
#include <timer.h>
|
|
||||||
#include <boost/program_options.hpp>
|
|
||||||
#include <future>
|
|
||||||
|
|
||||||
#include "utils.h"
|
|
||||||
#include "filter_utils.h"
|
|
||||||
#include "program_options_utils.hpp"
|
|
||||||
#include "index_factory.h"
|
|
||||||
|
|
||||||
#ifndef _WINDOWS
|
|
||||||
#include <sys/mman.h>
|
|
||||||
#include <sys/stat.h>
|
|
||||||
#include <unistd.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include "memory_mapper.h"
|
|
||||||
|
|
||||||
namespace po = boost::program_options;
|
|
||||||
|
|
||||||
// load_aligned_bin modified to read pieces of the file, but using ifstream
// instead of cached_ifstream.
//
// Reads points [offset_points, offset_points + points_to_read) from a DiskANN
// .bin file into `data`, padding each point out to ROUND_UP(dim, 8) elements
// with zeros. `data` must be pre-allocated to hold at least
// points_to_read * ROUND_UP(dim, 8) elements of T.
// Throws diskann::ANNException on a size mismatch or an out-of-range request;
// stream failures surface as std::ios_base::failure (exceptions enabled below).
template <typename T>
inline void load_aligned_bin_part(const std::string &bin_file, T *data, size_t offset_points, size_t points_to_read)
{
    diskann::Timer timer;
    std::ifstream reader;
    // Enable exceptions before opening so open/seek/read failures throw
    // instead of silently setting stream state.
    reader.exceptions(std::ios::failbit | std::ios::badbit);
    reader.open(bin_file, std::ios::binary | std::ios::ate);
    // Opened at end (std::ios::ate), so tellg() yields the file size.
    size_t actual_file_size = reader.tellg();
    reader.seekg(0, std::ios::beg);

    // .bin header: two int32 values — number of points, then dimension.
    int npts_i32, dim_i32;
    reader.read((char *)&npts_i32, sizeof(int));
    reader.read((char *)&dim_i32, sizeof(int));
    size_t npts = (uint32_t)npts_i32;
    size_t dim = (uint32_t)dim_i32;

    // Sanity-check the header against the real file size before trusting it.
    size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t);
    if (actual_file_size != expected_actual_file_size)
    {
        std::stringstream stream;
        stream << "Error. File size mismatch. Actual size is " << actual_file_size << " while expected size is "
               << expected_actual_file_size << " npts = " << npts << " dim = " << dim << " size of <T>= " << sizeof(T)
               << std::endl;
        std::cout << stream.str();
        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
    }

    if (offset_points + points_to_read > npts)
    {
        std::stringstream stream;
        stream << "Error. Not enough points in file. Requested " << offset_points << " offset and " << points_to_read
               << " points, but have only " << npts << " points" << std::endl;
        std::cout << stream.str();
        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
    }

    // Skip the header plus the unaligned (on-disk) points before the offset.
    reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T));

    // In-memory points are padded to a multiple of 8 elements for alignment.
    const size_t rounded_dim = ROUND_UP(dim, 8);

    for (size_t i = 0; i < points_to_read; i++)
    {
        // Read dim elements, then zero the padding tail of each point.
        reader.read((char *)(data + i * rounded_dim), dim * sizeof(T));
        memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T));
    }
    reader.close();

    // Timer::elapsed() appears to report microseconds (divided by 1e6 to get
    // seconds) — consistent with its other uses in this file.
    const double elapsedSeconds = timer.elapsed() / 1000000.0;
    std::cout << "Read " << points_to_read << " points using non-cached reads in " << elapsedSeconds << std::endl;
}
|
|
||||||
|
|
||||||
// Builds the save-file name "<save_path>[skip<S>-]del<D>-<L>", where the
// "skip" segment is included only when points_to_skip is non-zero.
std::string get_save_filename(const std::string &save_path, size_t points_to_skip, size_t points_deleted,
                              size_t last_point_threshold)
{
    const std::string skip_segment =
        (points_to_skip > 0) ? ("skip" + std::to_string(points_to_skip) + "-") : std::string();
    return save_path + skip_segment + "del" + std::to_string(points_deleted) + "-" +
           std::to_string(last_point_threshold);
}
|
|
||||||
|
|
||||||
template <typename T, typename TagT, typename LabelT>
|
|
||||||
void insert_till_next_checkpoint(diskann::AbstractIndex &index, size_t start, size_t end, int32_t thread_count, T *data,
|
|
||||||
size_t aligned_dim, std::vector<std::vector<LabelT>> &location_to_labels)
|
|
||||||
{
|
|
||||||
diskann::Timer insert_timer;
|
|
||||||
#pragma omp parallel for num_threads(thread_count) schedule(dynamic)
|
|
||||||
for (int64_t j = start; j < (int64_t)end; j++)
|
|
||||||
{
|
|
||||||
if (!location_to_labels.empty())
|
|
||||||
{
|
|
||||||
index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j),
|
|
||||||
location_to_labels[j - start]);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const double elapsedSeconds = insert_timer.elapsed() / 1000000.0;
|
|
||||||
std::cout << "Insertion time " << elapsedSeconds << " seconds (" << (end - start) / elapsedSeconds
|
|
||||||
<< " points/second overall, " << (end - start) / elapsedSeconds / thread_count << " per thread)\n ";
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename TagT>
|
|
||||||
void delete_from_beginning(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params,
|
|
||||||
size_t points_to_skip, size_t points_to_delete_from_beginning)
|
|
||||||
{
|
|
||||||
try
|
|
||||||
{
|
|
||||||
std::cout << std::endl
|
|
||||||
<< "Lazy deleting points " << points_to_skip << " to "
|
|
||||||
<< points_to_skip + points_to_delete_from_beginning << "... ";
|
|
||||||
for (size_t i = points_to_skip; i < points_to_skip + points_to_delete_from_beginning; ++i)
|
|
||||||
index.lazy_delete(static_cast<TagT>(i + 1)); // Since tags are data location + 1
|
|
||||||
std::cout << "done." << std::endl;
|
|
||||||
|
|
||||||
auto report = index.consolidate_deletes(delete_params);
|
|
||||||
std::cout << "#active points: " << report._active_points << std::endl
|
|
||||||
<< "max points: " << report._max_points << std::endl
|
|
||||||
<< "empty slots: " << report._empty_slots << std::endl
|
|
||||||
<< "deletes processed: " << report._slots_released << std::endl
|
|
||||||
<< "latest delete size: " << report._delete_set_size << std::endl
|
|
||||||
<< "rate: (" << points_to_delete_from_beginning / report._time << " points/second overall, "
|
|
||||||
<< points_to_delete_from_beginning / report._time / delete_params.num_threads << " per thread)"
|
|
||||||
<< std::endl;
|
|
||||||
}
|
|
||||||
catch (std::system_error &e)
|
|
||||||
{
|
|
||||||
std::cout << "Exception caught in deletion thread: " << e.what() << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void build_incremental_index(const std::string &data_path, diskann::IndexWriteParameters ¶ms, size_t points_to_skip,
|
|
||||||
size_t max_points_to_insert, size_t beginning_index_size, float start_point_norm,
|
|
||||||
uint32_t num_start_pts, size_t points_per_checkpoint, size_t checkpoints_per_snapshot,
|
|
||||||
const std::string &save_path, size_t points_to_delete_from_beginning,
|
|
||||||
size_t start_deletes_after, bool concurrent, const std::string &label_file,
|
|
||||||
const std::string &universal_label)
|
|
||||||
{
|
|
||||||
size_t dim, aligned_dim;
|
|
||||||
size_t num_points;
|
|
||||||
diskann::get_bin_metadata(data_path, num_points, dim);
|
|
||||||
aligned_dim = ROUND_UP(dim, 8);
|
|
||||||
bool has_labels = label_file != "";
|
|
||||||
using TagT = uint32_t;
|
|
||||||
using LabelT = uint32_t;
|
|
||||||
|
|
||||||
size_t current_point_offset = points_to_skip;
|
|
||||||
const size_t last_point_threshold = points_to_skip + max_points_to_insert;
|
|
||||||
|
|
||||||
bool enable_tags = true;
|
|
||||||
using TagT = uint32_t;
|
|
||||||
auto index_search_params = diskann::IndexSearchParams(params.search_list_size, params.num_threads);
|
|
||||||
diskann::IndexConfig index_config = diskann::IndexConfigBuilder()
|
|
||||||
.with_metric(diskann::L2)
|
|
||||||
.with_dimension(dim)
|
|
||||||
.with_max_points(max_points_to_insert)
|
|
||||||
.is_dynamic_index(true)
|
|
||||||
.with_index_write_params(params)
|
|
||||||
.with_index_search_params(index_search_params)
|
|
||||||
.with_data_type(diskann_type_to_name<T>())
|
|
||||||
.with_tag_type(diskann_type_to_name<TagT>())
|
|
||||||
.with_label_type(diskann_type_to_name<LabelT>())
|
|
||||||
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
|
|
||||||
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
|
|
||||||
.is_enable_tags(enable_tags)
|
|
||||||
.is_filtered(has_labels)
|
|
||||||
.with_num_frozen_pts(num_start_pts)
|
|
||||||
.is_concurrent_consolidate(concurrent)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
diskann::IndexFactory index_factory = diskann::IndexFactory(index_config);
|
|
||||||
auto index = index_factory.create_instance();
|
|
||||||
|
|
||||||
if (universal_label != "")
|
|
||||||
{
|
|
||||||
LabelT u_label = 0;
|
|
||||||
index->set_universal_label(u_label);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (points_to_skip > num_points)
|
|
||||||
{
|
|
||||||
throw diskann::ANNException("Asked to skip more points than in data file", -1, __FUNCSIG__, __FILE__, __LINE__);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (max_points_to_insert == 0)
|
|
||||||
{
|
|
||||||
max_points_to_insert = num_points;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (points_to_skip + max_points_to_insert > num_points)
|
|
||||||
{
|
|
||||||
max_points_to_insert = num_points - points_to_skip;
|
|
||||||
std::cerr << "WARNING: Reducing max_points_to_insert to " << max_points_to_insert
|
|
||||||
<< " points since the data file has only that many" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (beginning_index_size > max_points_to_insert)
|
|
||||||
{
|
|
||||||
beginning_index_size = max_points_to_insert;
|
|
||||||
std::cerr << "WARNING: Reducing beginning index size to " << beginning_index_size
|
|
||||||
<< " points since the data file has only that many" << std::endl;
|
|
||||||
}
|
|
||||||
if (checkpoints_per_snapshot > 0 && beginning_index_size > points_per_checkpoint)
|
|
||||||
{
|
|
||||||
beginning_index_size = points_per_checkpoint;
|
|
||||||
std::cerr << "WARNING: Reducing beginning index size to " << beginning_index_size << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
T *data = nullptr;
|
|
||||||
diskann::alloc_aligned(
|
|
||||||
(void **)&data, std::max(points_per_checkpoint, beginning_index_size) * aligned_dim * sizeof(T), 8 * sizeof(T));
|
|
||||||
|
|
||||||
std::vector<TagT> tags(beginning_index_size);
|
|
||||||
std::iota(tags.begin(), tags.end(), 1 + static_cast<TagT>(current_point_offset));
|
|
||||||
|
|
||||||
load_aligned_bin_part(data_path, data, current_point_offset, beginning_index_size);
|
|
||||||
std::cout << "load aligned bin succeeded" << std::endl;
|
|
||||||
diskann::Timer timer;
|
|
||||||
|
|
||||||
if (beginning_index_size > 0)
|
|
||||||
{
|
|
||||||
index->build(data, beginning_index_size, tags);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
index->set_start_points_at_random(static_cast<T>(start_point_norm));
|
|
||||||
}
|
|
||||||
|
|
||||||
const double elapsedSeconds = timer.elapsed() / 1000000.0;
|
|
||||||
std::cout << "Initial non-incremental index build time for " << beginning_index_size << " points took "
|
|
||||||
<< elapsedSeconds << " seconds (" << beginning_index_size / elapsedSeconds << " points/second)\n ";
|
|
||||||
|
|
||||||
current_point_offset += beginning_index_size;
|
|
||||||
|
|
||||||
if (points_to_delete_from_beginning > max_points_to_insert)
|
|
||||||
{
|
|
||||||
points_to_delete_from_beginning = static_cast<uint32_t>(max_points_to_insert);
|
|
||||||
std::cerr << "WARNING: Reducing points to delete from beginning to " << points_to_delete_from_beginning
|
|
||||||
<< " points since the data file has only that many" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::vector<LabelT>> location_to_labels;
|
|
||||||
if (concurrent)
|
|
||||||
{
|
|
||||||
// handle labels
|
|
||||||
const auto save_path_inc = get_save_filename(save_path + ".after-concurrent-delete-", points_to_skip,
|
|
||||||
points_to_delete_from_beginning, last_point_threshold);
|
|
||||||
std::string labels_file_to_use = save_path_inc + "_label_formatted.txt";
|
|
||||||
std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt";
|
|
||||||
if (has_labels)
|
|
||||||
{
|
|
||||||
convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
|
|
||||||
auto parse_result = diskann::parse_formatted_label_file<LabelT>(labels_file_to_use);
|
|
||||||
location_to_labels = std::get<0>(parse_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t sub_threads = (params.num_threads + 1) / 2;
|
|
||||||
bool delete_launched = false;
|
|
||||||
std::future<void> delete_task;
|
|
||||||
|
|
||||||
diskann::Timer timer;
|
|
||||||
|
|
||||||
for (size_t start = current_point_offset; start < last_point_threshold;
|
|
||||||
start += points_per_checkpoint, current_point_offset += points_per_checkpoint)
|
|
||||||
{
|
|
||||||
const size_t end = std::min(start + points_per_checkpoint, last_point_threshold);
|
|
||||||
std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl;
|
|
||||||
|
|
||||||
auto insert_task = std::async(std::launch::async, [&]() {
|
|
||||||
load_aligned_bin_part(data_path, data, start, end - start);
|
|
||||||
insert_till_next_checkpoint<T, TagT, LabelT>(*index, start, end, sub_threads, data, aligned_dim,
|
|
||||||
location_to_labels);
|
|
||||||
});
|
|
||||||
insert_task.wait();
|
|
||||||
|
|
||||||
if (!delete_launched && end >= start_deletes_after &&
|
|
||||||
end >= points_to_skip + points_to_delete_from_beginning)
|
|
||||||
{
|
|
||||||
delete_launched = true;
|
|
||||||
diskann::IndexWriteParameters delete_params =
|
|
||||||
diskann::IndexWriteParametersBuilder(params).with_num_threads(sub_threads).build();
|
|
||||||
|
|
||||||
delete_task = std::async(std::launch::async, [&]() {
|
|
||||||
delete_from_beginning<T, TagT>(*index, delete_params, points_to_skip,
|
|
||||||
points_to_delete_from_beginning);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delete_task.wait();
|
|
||||||
|
|
||||||
std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n";
|
|
||||||
index->save(save_path_inc.c_str(), true);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
const auto save_path_inc = get_save_filename(save_path + ".after-delete-", points_to_skip,
|
|
||||||
points_to_delete_from_beginning, last_point_threshold);
|
|
||||||
std::string labels_file_to_use = save_path_inc + "_label_formatted.txt";
|
|
||||||
std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt";
|
|
||||||
if (has_labels)
|
|
||||||
{
|
|
||||||
convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
|
|
||||||
auto parse_result = diskann::parse_formatted_label_file<LabelT>(labels_file_to_use);
|
|
||||||
location_to_labels = std::get<0>(parse_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t last_snapshot_points_threshold = 0;
|
|
||||||
size_t num_checkpoints_till_snapshot = checkpoints_per_snapshot;
|
|
||||||
|
|
||||||
for (size_t start = current_point_offset; start < last_point_threshold;
|
|
||||||
start += points_per_checkpoint, current_point_offset += points_per_checkpoint)
|
|
||||||
{
|
|
||||||
const size_t end = std::min(start + points_per_checkpoint, last_point_threshold);
|
|
||||||
std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl;
|
|
||||||
|
|
||||||
load_aligned_bin_part(data_path, data, start, end - start);
|
|
||||||
insert_till_next_checkpoint<T, TagT, LabelT>(*index, start, end, (int32_t)params.num_threads, data,
|
|
||||||
aligned_dim, location_to_labels);
|
|
||||||
|
|
||||||
if (checkpoints_per_snapshot > 0 && --num_checkpoints_till_snapshot == 0)
|
|
||||||
{
|
|
||||||
diskann::Timer save_timer;
|
|
||||||
|
|
||||||
const auto save_path_inc =
|
|
||||||
get_save_filename(save_path + ".inc-", points_to_skip, points_to_delete_from_beginning, end);
|
|
||||||
index->save(save_path_inc.c_str(), false);
|
|
||||||
const double elapsedSeconds = save_timer.elapsed() / 1000000.0;
|
|
||||||
const size_t points_saved = end - points_to_skip;
|
|
||||||
|
|
||||||
std::cout << "Saved " << points_saved << " points in " << elapsedSeconds << " seconds ("
|
|
||||||
<< points_saved / elapsedSeconds << " points/second)\n";
|
|
||||||
|
|
||||||
num_checkpoints_till_snapshot = checkpoints_per_snapshot;
|
|
||||||
last_snapshot_points_threshold = end;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::cout << "Number of points in the index post insertion " << end << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (checkpoints_per_snapshot > 0 && last_snapshot_points_threshold != last_point_threshold)
|
|
||||||
{
|
|
||||||
const auto save_path_inc = get_save_filename(save_path + ".inc-", points_to_skip,
|
|
||||||
points_to_delete_from_beginning, last_point_threshold);
|
|
||||||
// index.save(save_path_inc.c_str(), false);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (points_to_delete_from_beginning > 0)
|
|
||||||
{
|
|
||||||
delete_from_beginning<T, TagT>(*index, params, points_to_skip, points_to_delete_from_beginning);
|
|
||||||
}
|
|
||||||
|
|
||||||
index->save(save_path_inc.c_str(), true);
|
|
||||||
}
|
|
||||||
|
|
||||||
diskann::aligned_free(data);
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(int argc, char **argv)
|
|
||||||
{
|
|
||||||
std::string data_type, dist_fn, data_path, index_path_prefix;
|
|
||||||
uint32_t num_threads, R, L, num_start_pts;
|
|
||||||
float alpha, start_point_norm;
|
|
||||||
size_t points_to_skip, max_points_to_insert, beginning_index_size, points_per_checkpoint, checkpoints_per_snapshot,
|
|
||||||
points_to_delete_from_beginning, start_deletes_after;
|
|
||||||
bool concurrent;
|
|
||||||
|
|
||||||
// label options
|
|
||||||
std::string label_file, label_type, universal_label;
|
|
||||||
std::uint32_t Lf, unique_labels_supported;
|
|
||||||
|
|
||||||
po::options_description desc{program_options_utils::make_program_description("test_insert_deletes_consolidate",
|
|
||||||
"Test insert deletes & consolidate")};
|
|
||||||
try
|
|
||||||
{
|
|
||||||
desc.add_options()("help,h", "Print information on arguments");
|
|
||||||
|
|
||||||
// Required parameters
|
|
||||||
po::options_description required_configs("Required");
|
|
||||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
|
||||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
|
||||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
|
||||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
|
||||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
|
||||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
|
||||||
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
|
|
||||||
program_options_utils::INPUT_DATA_PATH);
|
|
||||||
required_configs.add_options()("points_to_skip", po::value<uint64_t>(&points_to_skip)->required(),
|
|
||||||
"Skip these first set of points from file");
|
|
||||||
required_configs.add_options()("beginning_index_size", po::value<uint64_t>(&beginning_index_size)->required(),
|
|
||||||
"Batch build will be called on these set of points");
|
|
||||||
required_configs.add_options()("points_per_checkpoint", po::value<uint64_t>(&points_per_checkpoint)->required(),
|
|
||||||
"Insertions are done in batches of points_per_checkpoint");
|
|
||||||
required_configs.add_options()("checkpoints_per_snapshot",
|
|
||||||
po::value<uint64_t>(&checkpoints_per_snapshot)->required(),
|
|
||||||
"Save the index to disk every few checkpoints");
|
|
||||||
required_configs.add_options()("points_to_delete_from_beginning",
|
|
||||||
po::value<uint64_t>(&points_to_delete_from_beginning)->required(), "");
|
|
||||||
|
|
||||||
// Optional parameters
|
|
||||||
po::options_description optional_configs("Optional");
|
|
||||||
optional_configs.add_options()("num_threads,T",
|
|
||||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
|
||||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
|
||||||
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
|
|
||||||
program_options_utils::MAX_BUILD_DEGREE);
|
|
||||||
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
|
|
||||||
program_options_utils::GRAPH_BUILD_COMPLEXITY);
|
|
||||||
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
|
|
||||||
program_options_utils::GRAPH_BUILD_ALPHA);
|
|
||||||
optional_configs.add_options()("max_points_to_insert",
|
|
||||||
po::value<uint64_t>(&max_points_to_insert)->default_value(0),
|
|
||||||
"These number of points from the file are inserted after "
|
|
||||||
"points_to_skip");
|
|
||||||
optional_configs.add_options()("do_concurrent", po::value<bool>(&concurrent)->default_value(false), "");
|
|
||||||
optional_configs.add_options()("start_deletes_after",
|
|
||||||
po::value<uint64_t>(&start_deletes_after)->default_value(0), "");
|
|
||||||
optional_configs.add_options()("start_point_norm", po::value<float>(&start_point_norm)->default_value(0),
|
|
||||||
"Set the start point to a random point on a sphere of this radius");
|
|
||||||
|
|
||||||
// optional params for filters
|
|
||||||
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
|
|
||||||
"Input label file in txt format for Filtered Index search. "
|
|
||||||
"The file should contain comma separated filters for each node "
|
|
||||||
"with each line corresponding to a graph node");
|
|
||||||
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
|
|
||||||
"Universal label, if using it, only in conjunction with labels_file");
|
|
||||||
optional_configs.add_options()("FilteredLbuild,Lf", po::value<uint32_t>(&Lf)->default_value(0),
|
|
||||||
"Build complexity for filtered points, higher value "
|
|
||||||
"results in better graphs");
|
|
||||||
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
|
|
||||||
"Storage type of Labels <uint/ushort>, default value is uint which "
|
|
||||||
"will consume memory 4 bytes per filter");
|
|
||||||
optional_configs.add_options()("unique_labels_supported",
|
|
||||||
po::value<uint32_t>(&unique_labels_supported)->default_value(0),
|
|
||||||
"Number of unique labels supported by the dynamic index.");
|
|
||||||
|
|
||||||
optional_configs.add_options()(
|
|
||||||
"num_start_points",
|
|
||||||
po::value<uint32_t>(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC),
|
|
||||||
"Set the number of random start (frozen) points to use when "
|
|
||||||
"inserting and searching");
|
|
||||||
|
|
||||||
// Merge required and optional parameters
|
|
||||||
desc.add(required_configs).add(optional_configs);
|
|
||||||
|
|
||||||
po::variables_map vm;
|
|
||||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
|
||||||
if (vm.count("help"))
|
|
||||||
{
|
|
||||||
std::cout << desc;
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
po::notify(vm);
|
|
||||||
if (beginning_index_size == 0)
|
|
||||||
if (start_point_norm == 0)
|
|
||||||
{
|
|
||||||
std::cout << "When beginning_index_size is 0, use a start "
|
|
||||||
"point with "
|
|
||||||
"appropriate norm"
|
|
||||||
<< std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
catch (const std::exception &ex)
|
|
||||||
{
|
|
||||||
std::cerr << ex.what() << '\n';
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool has_labels = false;
|
|
||||||
if (!label_file.empty() || label_file != "")
|
|
||||||
{
|
|
||||||
has_labels = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (num_start_pts < unique_labels_supported)
|
|
||||||
{
|
|
||||||
num_start_pts = unique_labels_supported;
|
|
||||||
}
|
|
||||||
|
|
||||||
try
|
|
||||||
{
|
|
||||||
diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R)
|
|
||||||
.with_max_occlusion_size(500)
|
|
||||||
.with_alpha(alpha)
|
|
||||||
.with_num_threads(num_threads)
|
|
||||||
.with_filter_list_size(Lf)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
if (data_type == std::string("int8"))
|
|
||||||
build_incremental_index<int8_t>(
|
|
||||||
data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm,
|
|
||||||
num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix,
|
|
||||||
points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label);
|
|
||||||
else if (data_type == std::string("uint8"))
|
|
||||||
build_incremental_index<uint8_t>(
|
|
||||||
data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm,
|
|
||||||
num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix,
|
|
||||||
points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label);
|
|
||||||
else if (data_type == std::string("float"))
|
|
||||||
build_incremental_index<float>(data_path, params, points_to_skip, max_points_to_insert,
|
|
||||||
beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint,
|
|
||||||
checkpoints_per_snapshot, index_path_prefix, points_to_delete_from_beginning,
|
|
||||||
start_deletes_after, concurrent, label_file, universal_label);
|
|
||||||
else
|
|
||||||
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
|
|
||||||
}
|
|
||||||
catch (const std::exception &e)
|
|
||||||
{
|
|
||||||
std::cerr << "Caught exception: " << e.what() << std::endl;
|
|
||||||
exit(-1);
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
std::cerr << "Caught unknown exception" << std::endl;
|
|
||||||
exit(-1);
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
@@ -1,523 +0,0 @@
|
|||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
// Licensed under the MIT license.
|
|
||||||
|
|
||||||
#include <index.h>
|
|
||||||
#include <numeric>
|
|
||||||
#include <omp.h>
|
|
||||||
#include <string.h>
|
|
||||||
#include <time.h>
|
|
||||||
#include <timer.h>
|
|
||||||
#include <boost/program_options.hpp>
|
|
||||||
#include <future>
|
|
||||||
#include <abstract_index.h>
|
|
||||||
#include <index_factory.h>
|
|
||||||
|
|
||||||
#include "utils.h"
|
|
||||||
#include "filter_utils.h"
|
|
||||||
#include "program_options_utils.hpp"
|
|
||||||
|
|
||||||
#ifndef _WINDOWS
|
|
||||||
#include <sys/mman.h>
|
|
||||||
#include <sys/stat.h>
|
|
||||||
#include <unistd.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include "memory_mapper.h"
|
|
||||||
|
|
||||||
namespace po = boost::program_options;
|
|
||||||
|
|
||||||
// load_aligned_bin modified to read pieces of the file, but using ifstream
// instead of cached_ifstream.
//
// Reads `points_to_read` points starting at point `offset_points` from a
// DiskANN .bin file (header: int32 npts, int32 dim, followed by npts*dim
// values of type T) into `data`. The caller must provide a buffer of at
// least points_to_read * ROUND_UP(dim, 8) elements; each point is written at
// an 8-aligned stride and the padding beyond `dim` is zeroed.
// Throws diskann::ANNException if the file size does not match its header or
// the requested range exceeds the file; stream failures raise
// std::ios_base::failure because exceptions are enabled on the reader.
template <typename T>
inline void load_aligned_bin_part(const std::string &bin_file, T *data, size_t offset_points, size_t points_to_read)
{
    std::ifstream reader;
    reader.exceptions(std::ios::failbit | std::ios::badbit); // fail loudly on any stream error
    reader.open(bin_file, std::ios::binary | std::ios::ate); // open at end to learn the file size
    size_t actual_file_size = reader.tellg();
    reader.seekg(0, std::ios::beg);

    // Header: point count and dimensionality, each a 32-bit int.
    int npts_i32, dim_i32;
    reader.read((char *)&npts_i32, sizeof(int));
    reader.read((char *)&dim_i32, sizeof(int));
    size_t npts = (uint32_t)npts_i32;
    size_t dim = (uint32_t)dim_i32;

    // Sanity check: the file must be exactly header + npts * dim payload.
    size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t);
    if (actual_file_size != expected_actual_file_size)
    {
        std::stringstream stream;
        stream << "Error. File size mismatch. Actual size is " << actual_file_size << " while expected size is "
               << expected_actual_file_size << " npts = " << npts << " dim = " << dim << " size of <T>= " << sizeof(T)
               << std::endl;
        std::cout << stream.str();
        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
    }

    // The requested slice must lie entirely within the file.
    if (offset_points + points_to_read > npts)
    {
        std::stringstream stream;
        stream << "Error. Not enough points in file. Requested " << offset_points << " offset and " << points_to_read
               << " points, but have only " << npts << " points" << std::endl;
        std::cout << stream.str();
        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
    }

    // Seek straight to the first requested point (past the 8-byte header).
    reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T));

    const size_t rounded_dim = ROUND_UP(dim, 8);

    for (size_t i = 0; i < points_to_read; i++)
    {
        // Read the real coordinates, then zero the alignment padding.
        reader.read((char *)(data + i * rounded_dim), dim * sizeof(T));
        memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T));
    }
    reader.close();
}
|
|
||||||
|
|
||||||
// Encodes the streaming parameters into the save path as
// "<save_path>act<active_window>-cons<consolidate_interval>-max<max_points_to_insert>"
// so that runs with different settings write to distinct index files.
std::string get_save_filename(const std::string &save_path, size_t active_window, size_t consolidate_interval,
                              size_t max_points_to_insert)
{
    return save_path + "act" + std::to_string(active_window) + "-" + "cons" + std::to_string(consolidate_interval) +
           "-" + "max" + std::to_string(max_points_to_insert);
}
|
|
||||||
|
|
||||||
template <typename T, typename TagT, typename LabelT>
|
|
||||||
void insert_next_batch(diskann::AbstractIndex &index, size_t start, size_t end, size_t insert_threads, T *data,
|
|
||||||
size_t aligned_dim, std::vector<std::vector<LabelT>> &pts_to_labels)
|
|
||||||
{
|
|
||||||
try
|
|
||||||
{
|
|
||||||
diskann::Timer insert_timer;
|
|
||||||
std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl;
|
|
||||||
|
|
||||||
size_t num_failed = 0;
|
|
||||||
#pragma omp parallel for num_threads((int32_t)insert_threads) schedule(dynamic) reduction(+ : num_failed)
|
|
||||||
for (int64_t j = start; j < (int64_t)end; j++)
|
|
||||||
{
|
|
||||||
int insert_result = -1;
|
|
||||||
if (pts_to_labels.size() > 0)
|
|
||||||
{
|
|
||||||
insert_result = index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j),
|
|
||||||
pts_to_labels[j - start]);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
insert_result = index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (insert_result != 0)
|
|
||||||
{
|
|
||||||
std::cerr << "Insert failed " << j << std::endl;
|
|
||||||
num_failed++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const double elapsedSeconds = insert_timer.elapsed() / 1000000.0;
|
|
||||||
std::cout << "Insertion time " << elapsedSeconds << " seconds (" << (end - start) / elapsedSeconds
|
|
||||||
<< " points/second overall, " << (end - start) / elapsedSeconds / insert_threads << " per thread)"
|
|
||||||
<< std::endl;
|
|
||||||
if (num_failed > 0)
|
|
||||||
std::cout << num_failed << " of " << end - start << "inserts failed" << std::endl;
|
|
||||||
}
|
|
||||||
catch (std::system_error &e)
|
|
||||||
{
|
|
||||||
std::cout << "Exiting after catching exception in insertion task: " << e.what() << std::endl;
|
|
||||||
exit(-1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t>
|
|
||||||
void delete_and_consolidate(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params, size_t start,
|
|
||||||
size_t end)
|
|
||||||
{
|
|
||||||
try
|
|
||||||
{
|
|
||||||
std::cout << std::endl << "Lazy deleting points " << start << " to " << end << "... ";
|
|
||||||
for (size_t i = start; i < end; ++i)
|
|
||||||
index.lazy_delete(static_cast<TagT>(1 + i));
|
|
||||||
std::cout << "lazy delete done." << std::endl;
|
|
||||||
|
|
||||||
auto report = index.consolidate_deletes(delete_params);
|
|
||||||
while (report._status != diskann::consolidation_report::status_code::SUCCESS)
|
|
||||||
{
|
|
||||||
int wait_time = 5;
|
|
||||||
if (report._status == diskann::consolidation_report::status_code::LOCK_FAIL)
|
|
||||||
{
|
|
||||||
diskann::cerr << "Unable to acquire consolidate delete lock after "
|
|
||||||
<< "deleting points " << start << " to " << end << ". Will retry in " << wait_time
|
|
||||||
<< "seconds." << std::endl;
|
|
||||||
}
|
|
||||||
else if (report._status == diskann::consolidation_report::status_code::INCONSISTENT_COUNT_ERROR)
|
|
||||||
{
|
|
||||||
diskann::cerr << "Inconsistent counts in data structure. "
|
|
||||||
<< "Will retry in " << wait_time << "seconds." << std::endl;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::cerr << "Exiting after unknown error in consolidate delete" << std::endl;
|
|
||||||
exit(-1);
|
|
||||||
}
|
|
||||||
std::this_thread::sleep_for(std::chrono::seconds(wait_time));
|
|
||||||
report = index.consolidate_deletes(delete_params);
|
|
||||||
}
|
|
||||||
auto points_processed = report._active_points + report._slots_released;
|
|
||||||
auto deletion_rate = points_processed / report._time;
|
|
||||||
std::cout << "#active points: " << report._active_points << std::endl
|
|
||||||
<< "max points: " << report._max_points << std::endl
|
|
||||||
<< "empty slots: " << report._empty_slots << std::endl
|
|
||||||
<< "deletes processed: " << report._slots_released << std::endl
|
|
||||||
<< "latest delete size: " << report._delete_set_size << std::endl
|
|
||||||
<< "Deletion rate: " << deletion_rate << "/sec "
|
|
||||||
<< "Deletion rate: " << deletion_rate / delete_params.num_threads << "/thread/sec " << std::endl;
|
|
||||||
}
|
|
||||||
catch (std::system_error &e)
|
|
||||||
{
|
|
||||||
std::cerr << "Exiting after catching exception in deletion task: " << e.what() << std::endl;
|
|
||||||
exit(-1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t>
|
|
||||||
void build_incremental_index(const std::string &data_path, const uint32_t L, const uint32_t R, const float alpha,
|
|
||||||
const uint32_t insert_threads, const uint32_t consolidate_threads,
|
|
||||||
size_t max_points_to_insert, size_t active_window, size_t consolidate_interval,
|
|
||||||
const float start_point_norm, uint32_t num_start_pts, const std::string &save_path,
|
|
||||||
const std::string &label_file, const std::string &universal_label, const uint32_t Lf)
|
|
||||||
{
|
|
||||||
const uint32_t C = 500;
|
|
||||||
const bool saturate_graph = false;
|
|
||||||
bool has_labels = label_file != "";
|
|
||||||
|
|
||||||
diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R)
|
|
||||||
.with_max_occlusion_size(C)
|
|
||||||
.with_alpha(alpha)
|
|
||||||
.with_saturate_graph(saturate_graph)
|
|
||||||
.with_num_threads(insert_threads)
|
|
||||||
.with_filter_list_size(Lf)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
auto index_search_params = diskann::IndexSearchParams(L, insert_threads);
|
|
||||||
diskann::IndexWriteParameters delete_params = diskann::IndexWriteParametersBuilder(L, R)
|
|
||||||
.with_max_occlusion_size(C)
|
|
||||||
.with_alpha(alpha)
|
|
||||||
.with_saturate_graph(saturate_graph)
|
|
||||||
.with_num_threads(consolidate_threads)
|
|
||||||
.with_filter_list_size(Lf)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
size_t dim, aligned_dim;
|
|
||||||
size_t num_points;
|
|
||||||
|
|
||||||
std::vector<std::vector<LabelT>> pts_to_labels;
|
|
||||||
|
|
||||||
const auto save_path_inc =
|
|
||||||
get_save_filename(save_path + ".after-streaming-", active_window, consolidate_interval, max_points_to_insert);
|
|
||||||
std::string labels_file_to_use = save_path_inc + "_label_formatted.txt";
|
|
||||||
std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt";
|
|
||||||
if (has_labels)
|
|
||||||
{
|
|
||||||
convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
|
|
||||||
auto parse_result = diskann::parse_formatted_label_file<LabelT>(labels_file_to_use);
|
|
||||||
pts_to_labels = std::get<0>(parse_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
diskann::get_bin_metadata(data_path, num_points, dim);
|
|
||||||
diskann::cout << "metadata: file " << data_path << " has " << num_points << " points in " << dim << " dims"
|
|
||||||
<< std::endl;
|
|
||||||
aligned_dim = ROUND_UP(dim, 8);
|
|
||||||
auto index_config = diskann::IndexConfigBuilder()
|
|
||||||
.with_metric(diskann::L2)
|
|
||||||
.with_dimension(dim)
|
|
||||||
.with_max_points(active_window + 4 * consolidate_interval)
|
|
||||||
.is_dynamic_index(true)
|
|
||||||
.is_enable_tags(true)
|
|
||||||
.is_use_opq(false)
|
|
||||||
.is_filtered(has_labels)
|
|
||||||
.with_num_pq_chunks(0)
|
|
||||||
.is_pq_dist_build(false)
|
|
||||||
.with_num_frozen_pts(num_start_pts)
|
|
||||||
.with_tag_type(diskann_type_to_name<TagT>())
|
|
||||||
.with_label_type(diskann_type_to_name<LabelT>())
|
|
||||||
.with_data_type(diskann_type_to_name<T>())
|
|
||||||
.with_index_write_params(params)
|
|
||||||
.with_index_search_params(index_search_params)
|
|
||||||
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
|
|
||||||
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
diskann::IndexFactory index_factory = diskann::IndexFactory(index_config);
|
|
||||||
auto index = index_factory.create_instance();
|
|
||||||
|
|
||||||
if (universal_label != "")
|
|
||||||
{
|
|
||||||
LabelT u_label = 0;
|
|
||||||
index->set_universal_label(u_label);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (max_points_to_insert == 0)
|
|
||||||
{
|
|
||||||
max_points_to_insert = num_points;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (num_points < max_points_to_insert)
|
|
||||||
throw diskann::ANNException(std::string("num_points(") + std::to_string(num_points) +
|
|
||||||
") < max_points_to_insert(" + std::to_string(max_points_to_insert) + ")",
|
|
||||||
-1, __FUNCSIG__, __FILE__, __LINE__);
|
|
||||||
|
|
||||||
if (max_points_to_insert < active_window + consolidate_interval)
|
|
||||||
throw diskann::ANNException("ERROR: max_points_to_insert < "
|
|
||||||
"active_window + consolidate_interval",
|
|
||||||
-1, __FUNCSIG__, __FILE__, __LINE__);
|
|
||||||
|
|
||||||
if (consolidate_interval < max_points_to_insert / 1000)
|
|
||||||
throw diskann::ANNException("ERROR: consolidate_interval is too small", -1, __FUNCSIG__, __FILE__, __LINE__);
|
|
||||||
|
|
||||||
index->set_start_points_at_random(static_cast<T>(start_point_norm));
|
|
||||||
|
|
||||||
T *data = nullptr;
|
|
||||||
diskann::alloc_aligned((void **)&data, std::max(consolidate_interval, active_window) * aligned_dim * sizeof(T),
|
|
||||||
8 * sizeof(T));
|
|
||||||
|
|
||||||
std::vector<TagT> tags(max_points_to_insert);
|
|
||||||
std::iota(tags.begin(), tags.end(), static_cast<TagT>(0));
|
|
||||||
|
|
||||||
diskann::Timer timer;
|
|
||||||
|
|
||||||
std::vector<std::future<void>> delete_tasks;
|
|
||||||
|
|
||||||
auto insert_task = std::async(std::launch::async, [&]() {
|
|
||||||
load_aligned_bin_part(data_path, data, 0, active_window);
|
|
||||||
insert_next_batch<T, TagT, LabelT>(*index, (size_t)0, active_window, params.num_threads, data, aligned_dim,
|
|
||||||
pts_to_labels);
|
|
||||||
});
|
|
||||||
insert_task.wait();
|
|
||||||
|
|
||||||
for (size_t start = active_window; start + consolidate_interval <= max_points_to_insert;
|
|
||||||
start += consolidate_interval)
|
|
||||||
{
|
|
||||||
auto end = std::min(start + consolidate_interval, max_points_to_insert);
|
|
||||||
auto insert_task = std::async(std::launch::async, [&]() {
|
|
||||||
load_aligned_bin_part(data_path, data, start, end - start);
|
|
||||||
insert_next_batch<T, TagT, LabelT>(*index, start, end, params.num_threads, data, aligned_dim,
|
|
||||||
pts_to_labels);
|
|
||||||
});
|
|
||||||
insert_task.wait();
|
|
||||||
|
|
||||||
if (delete_tasks.size() > 0)
|
|
||||||
delete_tasks[delete_tasks.size() - 1].wait();
|
|
||||||
if (start >= active_window + consolidate_interval)
|
|
||||||
{
|
|
||||||
auto start_del = start - active_window - consolidate_interval;
|
|
||||||
auto end_del = start - active_window;
|
|
||||||
|
|
||||||
delete_tasks.emplace_back(std::async(std::launch::async, [&]() {
|
|
||||||
delete_and_consolidate<T, TagT, LabelT>(*index, delete_params, (size_t)start_del, (size_t)end_del);
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (delete_tasks.size() > 0)
|
|
||||||
delete_tasks[delete_tasks.size() - 1].wait();
|
|
||||||
|
|
||||||
std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n";
|
|
||||||
|
|
||||||
index->save(save_path_inc.c_str(), true);
|
|
||||||
|
|
||||||
diskann::aligned_free(data);
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(int argc, char **argv)
|
|
||||||
{
|
|
||||||
std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type;
|
|
||||||
uint32_t insert_threads, consolidate_threads, R, L, num_start_pts, Lf, unique_labels_supported;
|
|
||||||
float alpha, start_point_norm;
|
|
||||||
size_t max_points_to_insert, active_window, consolidate_interval;
|
|
||||||
|
|
||||||
po::options_description desc{program_options_utils::make_program_description("test_streaming_scenario",
|
|
||||||
"Test insert deletes & consolidate")};
|
|
||||||
try
|
|
||||||
{
|
|
||||||
desc.add_options()("help,h", "Print information on arguments");
|
|
||||||
|
|
||||||
// Required parameters
|
|
||||||
po::options_description required_configs("Required");
|
|
||||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
|
||||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
|
||||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
|
||||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
|
||||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
|
||||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
|
||||||
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
|
|
||||||
program_options_utils::INPUT_DATA_PATH);
|
|
||||||
required_configs.add_options()("active_window", po::value<uint64_t>(&active_window)->required(),
|
|
||||||
"Program maintains an index over an active window of "
|
|
||||||
"this size that slides through the data");
|
|
||||||
required_configs.add_options()("consolidate_interval", po::value<uint64_t>(&consolidate_interval)->required(),
|
|
||||||
"The program simultaneously adds this number of points to the "
|
|
||||||
"right of "
|
|
||||||
"the window while deleting the same number from the left");
|
|
||||||
required_configs.add_options()("start_point_norm", po::value<float>(&start_point_norm)->required(),
|
|
||||||
"Set the start point to a random point on a sphere of this radius");
|
|
||||||
|
|
||||||
// Optional parameters
|
|
||||||
po::options_description optional_configs("Optional");
|
|
||||||
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
|
|
||||||
program_options_utils::MAX_BUILD_DEGREE);
|
|
||||||
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
|
|
||||||
program_options_utils::GRAPH_BUILD_COMPLEXITY);
|
|
||||||
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
|
|
||||||
program_options_utils::GRAPH_BUILD_ALPHA);
|
|
||||||
optional_configs.add_options()("insert_threads",
|
|
||||||
po::value<uint32_t>(&insert_threads)->default_value(omp_get_num_procs() / 2),
|
|
||||||
"Number of threads used for inserting into the index (defaults to "
|
|
||||||
"omp_get_num_procs()/2)");
|
|
||||||
optional_configs.add_options()(
|
|
||||||
"consolidate_threads", po::value<uint32_t>(&consolidate_threads)->default_value(omp_get_num_procs() / 2),
|
|
||||||
"Number of threads used for consolidating deletes to "
|
|
||||||
"the index (defaults to omp_get_num_procs()/2)");
|
|
||||||
optional_configs.add_options()("max_points_to_insert",
|
|
||||||
po::value<uint64_t>(&max_points_to_insert)->default_value(0),
|
|
||||||
"The number of points from the file that the program streams "
|
|
||||||
"over ");
|
|
||||||
optional_configs.add_options()(
|
|
||||||
"num_start_points",
|
|
||||||
po::value<uint32_t>(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC),
|
|
||||||
"Set the number of random start (frozen) points to use when "
|
|
||||||
"inserting and searching");
|
|
||||||
|
|
||||||
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
|
|
||||||
"Input label file in txt format for Filtered Index search. "
|
|
||||||
"The file should contain comma separated filters for each node "
|
|
||||||
"with each line corresponding to a graph node");
|
|
||||||
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
|
|
||||||
"Universal label, if using it, only in conjunction with labels_file");
|
|
||||||
optional_configs.add_options()("FilteredLbuild,Lf", po::value<uint32_t>(&Lf)->default_value(0),
|
|
||||||
"Build complexity for filtered points, higher value "
|
|
||||||
"results in better graphs");
|
|
||||||
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
|
|
||||||
"Storage type of Labels <uint/ushort>, default value is uint which "
|
|
||||||
"will consume memory 4 bytes per filter");
|
|
||||||
optional_configs.add_options()("unique_labels_supported",
|
|
||||||
po::value<uint32_t>(&unique_labels_supported)->default_value(0),
|
|
||||||
"Number of unique labels supported by the dynamic index.");
|
|
||||||
|
|
||||||
// Merge required and optional parameters
|
|
||||||
desc.add(required_configs).add(optional_configs);
|
|
||||||
|
|
||||||
po::variables_map vm;
|
|
||||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
|
||||||
if (vm.count("help"))
|
|
||||||
{
|
|
||||||
std::cout << desc;
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
po::notify(vm);
|
|
||||||
}
|
|
||||||
catch (const std::exception &ex)
|
|
||||||
{
|
|
||||||
std::cerr << ex.what() << '\n';
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate arguments
|
|
||||||
if (start_point_norm == 0)
|
|
||||||
{
|
|
||||||
std::cout << "When beginning_index_size is 0, use a start point with "
|
|
||||||
"appropriate norm"
|
|
||||||
<< std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (label_type != std::string("ushort") && label_type != std::string("uint"))
|
|
||||||
{
|
|
||||||
std::cerr << "Invalid label type. Supported types are uint and ushort" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (data_type != std::string("int8") && data_type != std::string("uint8") && data_type != std::string("float"))
|
|
||||||
{
|
|
||||||
std::cerr << "Invalid data type. Supported types are int8, uint8 and float" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Are additional distance functions supported?
|
|
||||||
if (dist_fn != std::string("l2") && dist_fn != std::string("mips"))
|
|
||||||
{
|
|
||||||
std::cerr << "Invalid distance function. Supported functions are l2 and mips" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (num_start_pts < unique_labels_supported)
|
|
||||||
{
|
|
||||||
num_start_pts = unique_labels_supported;
|
|
||||||
}
|
|
||||||
|
|
||||||
try
|
|
||||||
{
|
|
||||||
if (data_type == std::string("uint8"))
|
|
||||||
{
|
|
||||||
if (label_type == std::string("ushort"))
|
|
||||||
{
|
|
||||||
build_incremental_index<uint8_t, uint32_t, uint16_t>(
|
|
||||||
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
|
|
||||||
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
|
|
||||||
universal_label, Lf);
|
|
||||||
}
|
|
||||||
else if (label_type == std::string("uint"))
|
|
||||||
{
|
|
||||||
build_incremental_index<uint8_t, uint32_t, uint32_t>(
|
|
||||||
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
|
|
||||||
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
|
|
||||||
universal_label, Lf);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else if (data_type == std::string("int8"))
|
|
||||||
{
|
|
||||||
if (label_type == std::string("ushort"))
|
|
||||||
{
|
|
||||||
build_incremental_index<int8_t, uint32_t, uint16_t>(
|
|
||||||
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
|
|
||||||
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
|
|
||||||
universal_label, Lf);
|
|
||||||
}
|
|
||||||
else if (label_type == std::string("uint"))
|
|
||||||
{
|
|
||||||
build_incremental_index<int8_t, uint32_t, uint32_t>(
|
|
||||||
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
|
|
||||||
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
|
|
||||||
universal_label, Lf);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else if (data_type == std::string("float"))
|
|
||||||
{
|
|
||||||
if (label_type == std::string("ushort"))
|
|
||||||
{
|
|
||||||
build_incremental_index<float, uint32_t, uint16_t>(
|
|
||||||
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
|
|
||||||
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
|
|
||||||
universal_label, Lf);
|
|
||||||
}
|
|
||||||
else if (label_type == std::string("uint"))
|
|
||||||
{
|
|
||||||
build_incremental_index<float, uint32_t, uint32_t>(
|
|
||||||
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
|
|
||||||
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
|
|
||||||
universal_label, Lf);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
catch (const std::exception &e)
|
|
||||||
{
|
|
||||||
std::cerr << "Caught exception: " << e.what() << std::endl;
|
|
||||||
exit(-1);
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
std::cerr << "Caught unknown exception" << std::endl;
|
|
||||||
exit(-1);
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
@@ -1,110 +0,0 @@
|
|||||||
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_COMPILE_WARNING_AS_ERROR ON)

# Stand-alone format-conversion utilities (no extra link dependencies).
add_executable(fvecs_to_bin fvecs_to_bin.cpp)

add_executable(fvecs_to_bvecs fvecs_to_bvecs.cpp)

add_executable(rand_data_gen rand_data_gen.cpp)
target_link_libraries(rand_data_gen ${PROJECT_NAME} Boost::program_options)

add_executable(float_bin_to_int8 float_bin_to_int8.cpp)

add_executable(ivecs_to_bin ivecs_to_bin.cpp)

add_executable(count_bfs_levels count_bfs_levels.cpp)
target_link_libraries(count_bfs_levels ${PROJECT_NAME} Boost::program_options)

add_executable(tsv_to_bin tsv_to_bin.cpp)

add_executable(bin_to_tsv bin_to_tsv.cpp)

add_executable(int8_to_float int8_to_float.cpp)
target_link_libraries(int8_to_float ${PROJECT_NAME})

add_executable(int8_to_float_scale int8_to_float_scale.cpp)
target_link_libraries(int8_to_float_scale ${PROJECT_NAME})

add_executable(uint8_to_float uint8_to_float.cpp)
target_link_libraries(uint8_to_float ${PROJECT_NAME})

add_executable(uint32_to_uint8 uint32_to_uint8.cpp)
target_link_libraries(uint32_to_uint8 ${PROJECT_NAME})

add_executable(vector_analysis vector_analysis.cpp)
target_link_libraries(vector_analysis ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

add_executable(gen_random_slice gen_random_slice.cpp)
target_link_libraries(gen_random_slice ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

add_executable(simulate_aggregate_recall simulate_aggregate_recall.cpp)

add_executable(calculate_recall calculate_recall.cpp)
target_link_libraries(calculate_recall ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

# Compute ground truth thing outside of DiskANN main source that depends on MKL.
add_executable(compute_groundtruth compute_groundtruth.cpp)
target_include_directories(compute_groundtruth PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES})
target_link_libraries(compute_groundtruth ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options)

add_executable(compute_groundtruth_for_filters compute_groundtruth_for_filters.cpp)
target_include_directories(compute_groundtruth_for_filters PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES})
target_link_libraries(compute_groundtruth_for_filters ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options)

add_executable(generate_pq generate_pq.cpp)
target_link_libraries(generate_pq ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

add_executable(partition_data partition_data.cpp)
target_link_libraries(partition_data ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

add_executable(partition_with_ram_budget partition_with_ram_budget.cpp)
target_link_libraries(partition_with_ram_budget ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

add_executable(merge_shards merge_shards.cpp)
target_link_libraries(merge_shards ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} ${DISKANN_ASYNC_LIB})

add_executable(create_disk_layout create_disk_layout.cpp)
target_link_libraries(create_disk_layout ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

add_executable(generate_synthetic_labels generate_synthetic_labels.cpp)
target_link_libraries(generate_synthetic_labels ${PROJECT_NAME} Boost::program_options)

add_executable(stats_label_data stats_label_data.cpp)
target_link_libraries(stats_label_data ${PROJECT_NAME} Boost::program_options)

# Install every utility binary on non-Windows platforms.
if (NOT MSVC)
    include(GNUInstallDirs)
    install(TARGETS fvecs_to_bin
                    fvecs_to_bvecs
                    rand_data_gen
                    float_bin_to_int8
                    ivecs_to_bin
                    count_bfs_levels
                    tsv_to_bin
                    bin_to_tsv
                    int8_to_float
                    int8_to_float_scale
                    uint8_to_float
                    uint32_to_uint8
                    vector_analysis
                    gen_random_slice
                    simulate_aggregate_recall
                    calculate_recall
                    compute_groundtruth
                    compute_groundtruth_for_filters
                    generate_pq
                    partition_data
                    partition_with_ram_budget
                    merge_shards
                    create_disk_layout
                    generate_synthetic_labels
                    stats_label_data
            RUNTIME
    )
endif()
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
// Licensed under the MIT license.
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
#include "util.h"
|
|
||||||
|
|
||||||
/// Convert one block of `npts` vectors from raw .bin layout (ndims floats per
/// point, no per-vector header) into .fvecs layout (a 4-byte dimension count
/// followed by ndims floats, per vector), and append it to the output stream.
///
/// @param writr    open binary output stream (.fvecs) — written to.
/// @param readr    open binary input stream (.bin), positioned at the block — read from.
/// @param read_buf caller-allocated scratch of at least npts * ndims floats.
/// @param write_buf caller-allocated scratch of at least npts * (ndims + 1) floats.
/// @param npts     number of points in this block.
/// @param ndims    dimensionality of each vector.
///
/// NOTE(review): the original had the stream types swapped
/// (ifstream &writr / ofstream &readr) and wrote before converting; this
/// version matches the call site in main (writr = ofstream, readr = ifstream).
void block_convert(std::ofstream &writr, std::ifstream &readr, float *read_buf, float *write_buf, uint64_t npts,
                   uint64_t ndims)
{
    // Pull the whole block of raw vectors in one read.
    readr.read((char *)read_buf, npts * ndims * sizeof(float));
#pragma omp parallel for
    for (int64_t i = 0; i < (int64_t)npts; i++)
    {
        // fvecs record = [ndims : unsigned][ndims floats]; stamp the header
        // into the float buffer via the unsigned alias, then copy the payload.
        *(unsigned *)(write_buf + i * (ndims + 1)) = (unsigned)ndims;
        memcpy(write_buf + i * (ndims + 1) + 1, read_buf + i * ndims, ndims * sizeof(float));
    }
    // Emit the converted block: npts records of (header + payload).
    writr.write((char *)write_buf, npts * (ndims * sizeof(float) + sizeof(unsigned)));
}
|
|
||||||
|
|
||||||
int main(int argc, char **argv)
|
|
||||||
{
|
|
||||||
if (argc != 3)
|
|
||||||
{
|
|
||||||
std::cout << argv[0] << " input_bin output_fvecs" << std::endl;
|
|
||||||
exit(-1);
|
|
||||||
}
|
|
||||||
std::ifstream readr(argv[1], std::ios::binary);
|
|
||||||
int npts_s32;
|
|
||||||
int ndims_s32;
|
|
||||||
readr.read((char *)&npts_s32, sizeof(int32_t));
|
|
||||||
readr.read((char *)&ndims_s32, sizeof(int32_t));
|
|
||||||
size_t npts = npts_s32;
|
|
||||||
size_t ndims = ndims_s32;
|
|
||||||
uint32_t ndims_u32 = (uint32_t)ndims_s32;
|
|
||||||
// uint64_t fsize = writr.tellg();
|
|
||||||
readr.seekg(0, std::ios::beg);
|
|
||||||
|
|
||||||
unsigned ndims_u32;
|
|
||||||
writr.write((char *)&ndims_u32, sizeof(unsigned));
|
|
||||||
writr.seekg(0, std::ios::beg);
|
|
||||||
uint64_t ndims = (uint64_t)ndims_u32;
|
|
||||||
uint64_t npts = fsize / ((ndims + 1) * sizeof(float));
|
|
||||||
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
|
|
||||||
|
|
||||||
uint64_t blk_size = 131072;
|
|
||||||
uint64_t nblks = ROUND_UP(npts, blk_size) / blk_size;
|
|
||||||
std::cout << "# blks: " << nblks << std::endl;
|
|
||||||
|
|
||||||
std::ofstream writr(argv[2], std::ios::binary);
|
|
||||||
float *read_buf = new float[npts * (ndims + 1)];
|
|
||||||
float *write_buf = new float[npts * ndims];
|
|
||||||
for (uint64_t i = 0; i < nblks; i++)
|
|
||||||
{
|
|
||||||
uint64_t cblk_size = std::min(npts - i * blk_size, blk_size);
|
|
||||||
block_convert(writr, readr, read_buf, write_buf, cblk_size, ndims);
|
|
||||||
std::cout << "Block #" << i << " written" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
delete[] read_buf;
|
|
||||||
delete[] write_buf;
|
|
||||||
|
|
||||||
writr.close();
|
|
||||||
readr.close();
|
|
||||||
}
|
|
||||||
Some files were not shown because too many files have changed in this diff. Show more
Reference in New Issue
Block a user