Initial commit

This commit is contained in:
yichuan520030910320
2025-06-30 09:05:05 +00:00
commit 46f6cc100b
1231 changed files with 278432 additions and 0 deletions

1
.gitattributes vendored Normal file
View File

@@ -0,0 +1 @@
paper_plot/data/big_graph_degree_data.npz filter=lfs diff=lfs merge=lfs -text

72
.gitignore vendored Executable file
View File

@@ -0,0 +1,72 @@
raw_data/
scaling_out/
scaling_out_old/
sanity_check/
demo/indices/
# .vscode/
*.log
*pycache*
outputs/
*.pkl
.history/
scripts/
lm_eval.egg-info/
demo/experiment_results/**/*.json
*.jsonl
*.sh
*.txt
!CMakeLists.txt
latency_breakdown*.json
experiment_results/eval_results/diskann/*.json
aws/
.venv/
.cursor/rules/
*.egg-info/
skip_reorder_comparison/
analysis_results/
build/
.cache/
nprobe_logs/
micro/results
micro/contriever-INT8
*.qdstrm
benchmark_results/
results/
frac_*.png
final_in_*.png
embedding_comparison_results/
*.ind
*.gz
*.fvecs
*.ivecs
*.index
*.bin
read_graph
analyze_diskann_graph
degree_distribution.png
micro/degree_distribution.png
policy_results_*
results_*/
experiment_results/
.DS_Store
# The above are inherited from old Power RAG repo
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv
.env
test_indices*/
test_*.py
!tests/**
packages/leann-backend-diskann/third_party/DiskANN/_deps/

6
.gitmodules vendored Normal file
View File

@@ -0,0 +1,6 @@
[submodule "packages/leann-backend-diskann/third_party/DiskANN"]
path = packages/leann-backend-diskann/third_party/DiskANN
url = https://github.com/yichuan520030910320/DiskANN.git
[submodule "packages/leann-backend-hnsw/third_party/faiss"]
path = packages/leann-backend-hnsw/third_party/faiss
url = https://github.com/yichuan520030910320/faiss.git

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.11

9
.vscode/extensions.json vendored Normal file
View File

@@ -0,0 +1,9 @@
{
"recommendations": [
"llvm-vs-code-extensions.vscode-clangd",
"ms-python.python",
"ms-vscode.cmake-tools",
"vadimcn.vscode-lldb",
"eamodio.gitlens",
]
}

283
.vscode/launch.json vendored Executable file
View File

@@ -0,0 +1,283 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
// new emdedder
{
"name": "New Embedder",
"type": "debugpy",
"request": "launch",
"program": "demo/main.py",
"console": "integratedTerminal",
"args": [
"--search",
"--use-original",
"--domain",
"dpr",
"--nprobe",
"5000",
"--load",
"flat",
"--embedder",
"intfloat/multilingual-e5-small"
]
}
//python /home/ubuntu/Power-RAG/faiss/demo/simple_build.py
{
"name": "main.py",
"type": "debugpy",
"request": "launch",
"program": "demo/main.py",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"--query",
"1000",
"--load",
"bm25"
]
},
{
"name": "Simple Build",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"faiss/demo/simple_build.py"
],
"env": {
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
}
},
//# Fix for Intel MKL error
//export LD_PRELOAD=/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so
//python faiss/demo/build_demo.py
{
"name": "Build Demo",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"faiss/demo/build_demo.py"
],
"env": {
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
}
},
{
"name": "DiskANN Serve",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"demo/main.py",
"--mode",
"serve",
"--engine",
"sglang",
"--load-indices",
"diskann",
"--domain",
"rpj_wiki",
"--lazy-load",
"--recompute-beighbor-embeddings",
"--port",
"8082",
"--diskann-search-memory-maximum",
"2",
"--diskann-graph",
"240",
"--search-only"
],
"env": {
"PYTHONPATH": "${workspaceFolder}/faiss_repo/build/faiss/python:$PYTHONPATH"
},
"preLaunchTask": "CMake: build",
},
{
"name": "DiskANN Serve MAC",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"demo/main.py",
"--mode",
"serve",
"--engine",
"ollama",
"--load-indices",
"diskann",
"--domain",
"rpj_wiki",
"--lazy-load",
"--recompute-beighbor-embeddings"
],
"preLaunchTask": "CMake: build",
"env": {
"KMP_DUPLICATE_LIB_OK": "TRUE",
"OMP_NUM_THREADS": "1",
"MKL_NUM_THREADS": "1",
"DYLD_INSERT_LIBRARIES": "/Users/ec2-user/Power-RAG/.venv/lib/python3.10/site-packages/torch/lib/libomp.dylib",
"KMP_BLOCKTIME": "0"
}
},
{
"name": "Python Debugger: Current File with Arguments",
"type": "debugpy",
"request": "launch",
"program": "ric/main_ric.py",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"--config-name",
"${input:configSelection}"
],
"justMyCode": false
},
//python ./demo/validate_equivalence.py sglang
{
"name": "Validate Equivalence",
"type": "debugpy",
"request": "launch",
"program": "demo/validate_equivalence.py",
"console": "integratedTerminal",
"args": [
"sglang"
],
},
//python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices flat ivf_flat
{
"name": "Retrieval Demo",
"type": "debugpy",
"request": "launch",
"program": "demo/retrieval_demo.py",
"console": "integratedTerminal",
"args": [
"--engine",
"vllm",
"--skip-embeddings",
"--domain",
"dpr",
"--load-indices",
// "flat",
"ivf_flat"
],
},
//python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices diskann --hnsw-M 64 --hnsw-efConstruction 150 --hnsw-efSearch 128 --hnsw-sq-bits 8
{
"name": "Retrieval Demo DiskANN",
"type": "debugpy",
"request": "launch",
"program": "demo/retrieval_demo.py",
"console": "integratedTerminal",
"args": [
"--engine",
"sglang",
"--skip-embeddings",
"--domain",
"dpr",
"--load-indices",
"diskann",
"--hnsw-M",
"64",
"--hnsw-efConstruction",
"150",
"--hnsw-efSearch",
"128",
"--hnsw-sq-bits",
"8"
],
},
{
"name": "Find Probe",
"type": "debugpy",
"request": "launch",
"program": "find_probe.py",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
},
{
"name": "Python: Attach",
"type": "debugpy",
"request": "attach",
"processId": "${command:pickProcess}",
"justMyCode": true
},
{
"name": "Edge RAG",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"edgerag_demo.py"
],
"env": {
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libiomp5.so /lib/x86_64-linux-gnu/libmkl_core.so /lib/x86_64-linux-gnu/libmkl_intel_lp64.so /lib/x86_64-linux-gnu/libmkl_intel_thread.so",
"MKL_NUM_THREADS": "1",
"OMP_NUM_THREADS": "1",
}
},
{
"name": "Launch Embedding Server",
"type": "debugpy",
"request": "launch",
"program": "demo/embedding_server.py",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"--domain",
"rpj_wiki",
"--zmq-port",
"5556",
]
},
{
"name": "HNSW Serve",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"demo/main.py",
"--domain",
"rpj_wiki",
"--load",
"hnsw",
"--mode",
"serve",
"--search",
"--skip-pa",
"--recompute",
"--hnsw-old"
],
"env": {
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
}
},
],
"inputs": [
{
"id": "configSelection",
"type": "pickString",
"description": "Select a configuration",
"options": [
"example_config",
"vllm_gritlm"
],
"default": "example_config"
}
],
}

43
.vscode/settings.json vendored Executable file
View File

@@ -0,0 +1,43 @@
{
"python.analysis.extraPaths": [
"./sglang_repo/python"
],
"cmake.sourceDirectory": "${workspaceFolder}/DiskANN",
"cmake.configureArgs": [
"-DPYBIND=True",
"-DUPDATE_EDITABLE_INSTALL=ON",
],
"cmake.environment": {
"PATH": "/Users/ec2-user/Power-RAG/.venv/bin:${env:PATH}"
},
"cmake.buildDirectory": "${workspaceFolder}/build",
"files.associations": {
"*.tcc": "cpp",
"deque": "cpp",
"string": "cpp",
"unordered_map": "cpp",
"vector": "cpp",
"map": "cpp",
"unordered_set": "cpp",
"atomic": "cpp",
"inplace_vector": "cpp",
"*.ipp": "cpp",
"forward_list": "cpp",
"list": "cpp",
"any": "cpp",
"system_error": "cpp",
"__hash_table": "cpp",
"__split_buffer": "cpp",
"__tree": "cpp",
"ios": "cpp",
"set": "cpp",
"__string": "cpp",
"string_view": "cpp",
"ranges": "cpp",
"iosfwd": "cpp"
},
"lldb.displayFormat": "auto",
"lldb.showDisassembly": "auto",
"lldb.dereferencePointers": true,
"lldb.consoleMode": "commands",
}

16
.vscode/tasks.json vendored Normal file
View File

@@ -0,0 +1,16 @@
{
"version": "2.0.0",
"tasks": [
{
"type": "cmake",
"label": "CMake: build",
"command": "build",
"targets": [
"all"
],
"group": "build",
"problemMatcher": [],
"detail": "CMake template build task"
}
]
}

9
LICENSE Executable file
View File

@@ -0,0 +1,9 @@
MIT License
Copyright (c) 2024 Rulin Shao
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

292
README.md Executable file
View File

@@ -0,0 +1,292 @@
# 🚀 LEANN: A Low-Storage Vector Index
<p align="center">
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
<img src="https://img.shields.io/badge/PRs-welcome-brightgreen.svg" alt="PRs Welcome">
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS%20%7C%20Windows-lightgrey" alt="Platform">
</p>
<p align="center">
<strong>⚡ Real-time embedding computation for large-scale RAG on consumer hardware</strong>
</p>
<p align="center">
<a href="#-quick-start">Quick Start</a> •
<a href="#-features">Features</a> •
<a href="#-benchmarks">Benchmarks</a> •
<a href="#-documentation">Documentation</a> •
<a href="#-paper">Paper</a>
</p>
---
## 🌟 What is Leann?
**Leann** revolutionizes Retrieval-Augmented Generation (RAG) by eliminating the storage bottleneck of traditional vector databases. Instead of pre-computing and storing billions of embeddings, Leann dynamically computes embeddings at query time using highly optimized graph-based search algorithms.
### 🎯 Why Leann?
Traditional RAG systems face a fundamental trade-off:
- **💾 Storage**: Storing embeddings for millions of documents requires massive disk space
- **🔄 Freshness**: Pre-computed embeddings become stale when documents change
- **💰 Cost**: Vector databases are expensive to scale
**Leann solves this by:**
-**Zero embedding storage** - Only graph structure is persisted
-**Real-time computation** - Embeddings computed on-demand with ms latency
-**Memory efficient** - Runs on consumer hardware (8GB RAM)
-**Always fresh** - No stale embeddings, ever
## 🚀 Quick Start
### Installation
```bash
git clone https://github.com/yichuan520030910320/Power-RAG.git leann
cd leann
uv sync
```
### 30-Second Example
```python
from leann.api import LeannBuilder, LeannSearcher
# 1. Build index (no embeddings stored!)
builder = LeannBuilder(backend_name="diskann")
builder.add_text("Python is a powerful programming language")
builder.add_text("Machine learning transforms industries")
builder.add_text("Neural networks process complex data")
builder.build_index("knowledge.leann")
# 2. Search with real-time embeddings
searcher = LeannSearcher("knowledge.leann")
results = searcher.search("programming languages", top_k=2)
for result in results:
print(f"Score: {result['score']:.3f} - {result['text']}")
```
### Run the Demo
```bash
uv run examples/document_search.py
```
**PDF RAG Demo (using LlamaIndex for document parsing and Leann for indexing/search)**
This demo showcases how to build a RAG system for PDF documents using Leann.
1. Place your PDF files (and other supported formats like .docx, .pptx, .xlsx) into the `examples/data/` directory.
2. Ensure you have an `OPENAI_API_KEY` set in your environment variables or in a `.env` file for the LLM to function.
```bash
uv run examples/main_cli_example.py
```
## ✨ Features
### 🔥 Core Features
- **📊 Multiple Distance Functions**: L2, Cosine, MIPS (Maximum Inner Product Search)
- **🏗️ Pluggable Backends**: DiskANN, HNSW/FAISS with unified API
- **🔄 Real-time Embeddings**: Dynamic computation using optimized ZMQ servers
- **📈 Scalable Architecture**: Handles millions of documents on consumer hardware
- **🎯 Graph Pruning**: Advanced techniques for memory-efficient search
### 🛠️ Technical Highlights
- **Zero-copy operations** for maximum performance
- **SIMD-optimized** distance computations (AVX2/AVX512)
- **Async embedding pipeline** with batched processing
- **Memory-mapped indices** for fast startup
- **Recompute mode** for highest accuracy scenarios
### 🎨 Developer Experience
- **Simple Python API** - Get started in minutes
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment
- **Rich debugging tools** - Built-in performance profiling
## 📊 Benchmarks
### Memory Usage Comparison
| System | 1M Documents | 10M Documents | 100M Documents |
|--------|-------------|---------------|----------------|
| Traditional Vector DB | 3.1 GB | 31 GB | 310 GB |
| **Leann** | **180 MB** | **1.2 GB** | **8.4 GB** |
| **Reduction** | **94.2%** | **96.1%** | **97.3%** |
### Query Performance
| Backend | Index Size | Query Time | Recall@10 |
|---------|------------|------------|-----------|
| DiskANN | 1M docs | 12ms | 0.95 |
| DiskANN + Recompute | 1M docs | 145ms | 0.98 |
| HNSW | 1M docs | 8ms | 0.93 |
*Benchmarks run on AMD Ryzen 7 with 32GB RAM*
## 🏗️ Architecture
```
┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐
│ Query Text │───▶│ Embedding │───▶│ Graph-based │
│ │ │ Computation │ │ Search │
└─────────────────┘ └──────────────────┘ └─────────────────┘
│ │
▼ ▼
┌──────────────┐ ┌──────────────┐
│ ZMQ Server │ │ Pruned Graph │
│ (Cached) │ │ Index │
└──────────────┘ └──────────────┘
```
### Key Components
1. **🧠 Embedding Engine**: Real-time transformer inference with caching
2. **📊 Graph Index**: Memory-efficient navigation structures
3. **🔄 Search Coordinator**: Orchestrates embedding + graph search
4. **⚡ Backend Adapters**: Pluggable algorithm implementations
## 🎓 Supported Models & Backends
### 🤖 Embedding Models
- **sentence-transformers/all-mpnet-base-v2** (default)
- **sentence-transformers/all-MiniLM-L6-v2** (lightweight)
- Any HuggingFace sentence-transformer model
- Custom model support via API
### 🔧 Search Backends
- **DiskANN**: Microsoft's billion-scale ANN algorithm
- **HNSW**: Hierarchical Navigable Small World graphs
- **Coming soon**: ScaNN, Faiss-IVF, NGT
### 📏 Distance Functions
- **L2**: Euclidean distance for precise similarity
- **Cosine**: Angular similarity for normalized vectors
- **MIPS**: Maximum Inner Product Search for recommendation systems
## 🔬 Paper
If you find Leann useful, please cite:
**[LEANN: A Low-Storage Vector Index](https://arxiv.org/abs/2506.08276)**
```bibtex
@misc{wang2025leannlowstoragevectorindex,
title={LEANN: A Low-Storage Vector Index},
author={Yichuan Wang and Shu Liu and Zhifei Li and Yongji Wu and Ziming Mao and Yilong Zhao and Xiao Yan and Zhiying Xu and Yang Zhou and Ion Stoica and Sewon Min and Matei Zaharia and Joseph E. Gonzalez},
year={2025},
eprint={2506.08276},
archivePrefix={arXiv},
primaryClass={cs.DB},
url={https://arxiv.org/abs/2506.08276},
}
```
## 🌍 Use Cases
### 💼 Enterprise RAG
```python
# Handle millions of documents with limited resources
builder = LeannBuilder(
backend_name="diskann",
distance_metric="cosine",
graph_degree=64,
memory_budget="4GB"
)
```
### 🔬 Research & Experimentation
```python
# Quick prototyping with different algorithms
for backend in ["diskann", "hnsw"]:
searcher = LeannSearcher(index_path, backend=backend)
evaluate_recall(searcher, queries, ground_truth)
```
### 🚀 Real-time Applications
```python
# Sub-second response times
chat = LeannChat("knowledge.leann")
response = chat.ask("What is quantum computing?")
# Returns in <100ms with recompute mode
```
## 🤝 Contributing
We welcome contributions! Leann is built by the community, for the community.
### Ways to Contribute
- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results
### Development Setup
```bash
git clone https://github.com/yourname/leann
cd leann
uv sync --dev
uv run pytest tests/
```
### Quick Tests
```bash
# Sanity check all distance functions
uv run python tests/sanity_checks/test_distance_functions.py
# Verify L2 implementation
uv run python tests/sanity_checks/test_l2_verification.py
```
## 📈 Roadmap
### 🎯 Q1 2024
- [x] DiskANN backend with MIPS/L2/Cosine support
- [x] HNSW backend integration
- [x] Real-time embedding pipeline
- [x] Memory-efficient graph pruning
### 🚀 Q2 2024
- [ ] Distributed search across multiple nodes
- [ ] ScaNN backend support
- [ ] Advanced caching strategies
- [ ] Kubernetes deployment guides
### 🌟 Q3 2024
- [ ] GPU-accelerated embedding computation
- [ ] Approximate distance functions
- [ ] Integration with LangChain/LlamaIndex
- [ ] Visual similarity search
## 💬 Community
Join our growing community of researchers and engineers!
- 🐦 **Twitter**: [@LeannAI](https://twitter.com/LeannAI)
- 💬 **Discord**: [Join our server](https://discord.gg/leann)
- 📧 **Email**: leann@yourcompany.com
- 🐙 **GitHub Discussions**: [Ask questions here](https://github.com/yourname/leann/discussions)
## 📄 License
MIT License - see [LICENSE](LICENSE) for details.
## 🙏 Acknowledgments
- **Microsoft Research** for the DiskANN algorithm
- **Meta AI** for FAISS and optimization insights
- **HuggingFace** for the transformer ecosystem
- **Our amazing contributors** who make this possible
---
<p align="center">
<strong>⭐ Star us on GitHub if Leann is useful for your research or applications!</strong>
</p>
<p align="center">
Made with ❤️ by the Leann team
</p>

248
demo.ipynb Normal file
View File

@@ -0,0 +1,248 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO: LeannBuilder initialized with 'diskann' backend.\n",
"INFO: Computing embeddings for 6 chunks using 'sentence-transformers/all-mpnet-base-v2'...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Batches: 100%|██████████| 1/1 [00:00<00:00, 77.61it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO: Building DiskANN index for 6 vectors with metric Metric.INNER_PRODUCT...\n",
"Using Inner Product search, so need to pre-process base data into temp file. Please ensure there is additional (n*(d+1)*4) bytes for storing pre-processed base vectors, apart from the interim indices created by DiskANN and the final index.\n",
"Pre-processing base file by adding extra coordinate\n",
"✅ DiskANN index built successfully at 'knowledge'\n",
"Writing bin: knowledge_disk.index_max_base_norm.bin\n",
"bin: #pts = 1, #dims = 1, size = 12B\n",
"Finished writing bin.\n",
"Time for preprocessing data for inner product: 0.000165 seconds\n",
"Reading max_norm_of_base from knowledge_disk.index_max_base_norm.bin\n",
"Reading bin file knowledge_disk.index_max_base_norm.bin ...\n",
"Opening bin file knowledge_disk.index_max_base_norm.bin... \n",
"Metadata: #pts = 1, #dims = 1...\n",
"done.\n",
"max_norm_of_base: 1\n",
"! Using prepped_base file at knowledge_prepped_base.bin\n",
"Starting index build: R=32 L=64 Query RAM budget: 4.02653e+09 Indexing ram budget: 8 T: 8\n",
"getting bin metadata\n",
"Time for getting bin metadata: 0.000008 seconds\n",
"Compressing 769-dimensional data into 512 bytes per vector.\n",
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
"Training data with 6 samples loaded.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 256, #dims = 769...\n",
"done.\n",
"PQ pivot file exists. Not generating again\n",
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 4, #dims = 1...\n",
"done.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 256, #dims = 769...\n",
"done.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 769, #dims = 1...\n",
"done.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 513, #dims = 1...\n",
"done.\n",
"Loaded PQ pivot information\n",
"Processing points [0, 6)...done.\n",
"Time for generating quantized data: 0.023918 seconds\n",
"Full index fits in RAM budget, should consume at most 2.03973e-05GiBs, so building in one shot\n",
"L2: Using AVX2 distance computation DistanceL2Float\n",
"Passed, empty search_params while creating index config\n",
"Using only first 6 from file.. \n",
"Starting index build with 6 points... \n",
"0% of index build completed.Starting final cleanup..done. Link time: 9e-05s\n",
"Index built with degree: max:5 avg:5 min:5 count(deg<2):0\n",
"Not saving tags as they are not enabled.\n",
"Time taken for save: 0.000178s.\n",
"Time for building merged vamana index: 0.000579 seconds\n",
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
"Vamana index file size=168\n",
"Opened: knowledge_disk.index, cache_size: 67108864\n",
"medoid: 0B\n",
"max_node_len: 3100B\n",
"nnodes_per_sector: 1B\n",
"# sectors: 6\n",
"Sector #0written\n",
"Finished writing 28672B\n",
"Writing bin: knowledge_disk.index\n",
"bin: #pts = 9, #dims = 1, size = 80B\n",
"Finished writing bin.\n",
"Output disk index file written to knowledge_disk.index\n",
"Finished writing 28672B\n",
"Time for generating disk layout: 0.043488 seconds\n",
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
"Loading base knowledge_prepped_base.bin. #points: 6. #dim: 769.\n",
"Wrote 1 points to sample file: knowledge_sample_data.bin\n",
"Indexing time: 0.0684344\n",
"INFO: Leann metadata saved to knowledge.leann.meta.json\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"Opened file : knowledge_disk.index\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Since data is floating point, we assume that it has been appropriately pre-processed (normalization for cosine, and convert-to-l2 by adding extra dimension for MIPS). So we shall invoke an l2 distance function.\n",
"L2: Using AVX2 distance computation DistanceL2Float\n",
"L2: Using AVX2 distance computation DistanceL2Float\n",
"Before index load\n",
"✅ DiskANN index loaded successfully.\n",
"INFO: LeannSearcher initialized with 'diskann' backend using index 'knowledge.leann'.\n",
"Reading bin file knowledge_pq_compressed.bin ...\n",
"Opening bin file knowledge_pq_compressed.bin... \n",
"Metadata: #pts = 6, #dims = 512...\n",
"done.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 4, #dims = 1...\n",
"done.\n",
"Offsets: 4096 791560 794644 796704\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 256, #dims = 769...\n",
"done.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 769, #dims = 1...\n",
"done.\n",
"Reading bin file knowledge_pq_pivots.bin ...\n",
"Opening bin file knowledge_pq_pivots.bin... \n",
"Metadata: #pts = 513, #dims = 1...\n",
"done.\n",
"Loaded PQ Pivots: #ctrs: 256, #dims: 769, #chunks: 512\n",
"Loaded PQ centroids and in-memory compressed vectors. #points: 6 #dim: 769 #aligned_dim: 776 #chunks: 512\n",
"Loading index metadata from knowledge_disk.index\n",
"Disk-Index File Meta-data: # nodes per sector: 1, max node len (bytes): 3100, max node degree: 5\n",
"Disk-Index Meta: nodes per sector: 1, max node len: 3100, max node degree: 5\n",
"Setting up thread-specific contexts for nthreads: 8\n",
"allocating ctx: 0x78348f4de000 to thread-id:132170359560000\n",
"allocating ctx: 0x78348f4cd000 to thread-id:132158431693760\n",
"allocating ctx: 0x78348f4bc000 to thread-id:132158442179392\n",
"allocating ctx: 0x78348f4ab000 to thread-id:132158421208128\n",
"allocating ctx: 0x78348f49a000 to thread-id:132158452665024\n",
"allocating ctx: 0x78348f489000 to thread-id:132158389751232\n",
"allocating ctx: 0x78348f478000 to thread-id:132158410722496\n",
"allocating ctx: 0x78348f467000 to thread-id:132158400236864\n",
"Loading centroid data from medoids vector data of 1 medoid(s)\n",
"Reading bin file knowledge_disk.index_max_base_norm.bin ...\n",
"Opening bin file knowledge_disk.index_max_base_norm.bin... \n",
"Metadata: #pts = 1, #dims = 1...\n",
"done.\n",
"Setting re-scaling factor of base vectors to 1\n",
"load_from_separate_paths done.\n",
"Reading (with alignment) bin file knowledge_sample_data.bin ...Metadata: #pts = 1, #dims = 769, aligned_dim = 776... allocating aligned memory of 3104 bytes... done. Copying data to mem_aligned buffer... done.\n",
"reserve ratio: 1\n",
"Graph traversal completed, hops: 3\n",
"Loading the cache list into memory....done.\n",
"After index load\n",
"Clearing scratch\n",
"INFO: Computing embeddings for 1 chunks using 'sentence-transformers/all-mpnet-base-v2'...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Batches: 100%|██████████| 1/1 [00:00<00:00, 92.66it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Score: -0.481 - C++ is a powerful programming language\n",
"Score: -1.049 - Java is a powerful programming language\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"reserve ratio: 1\n",
"Graph traversal completed, hops: 3\n"
]
}
],
"source": [
"from leann.api import LeannBuilder, LeannSearcher\n",
"import leann_backend_diskann\n",
"# 1. Build index (no embeddings stored!)\n",
"builder = LeannBuilder(backend_name=\"diskann\")\n",
"builder.add_text(\"Python is a powerful programming language\")\n",
"builder.add_text(\"Machine learning transforms industries\") \n",
"builder.add_text(\"Neural networks process complex data\")\n",
"builder.add_text(\"Java is a powerful programming language\")\n",
"builder.add_text(\"C++ is a powerful programming language\")\n",
"builder.add_text(\"C# is a powerful programming language\")\n",
"builder.build_index(\"knowledge.leann\")\n",
"\n",
"# 2. Search with real-time embeddings\n",
"searcher = LeannSearcher(\"knowledge.leann\")\n",
"results = searcher.search(\"C++ programming languages\", top_k=2)\n",
"\n",
"for result in results:\n",
" print(f\"Score: {result['score']:.3f} - {result['text']}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

File diff suppressed because it is too large Load Diff

146
examples/document_search.py Normal file
View File

@@ -0,0 +1,146 @@
#!/usr/bin/env python3
"""
Document search demo with recompute mode
"""
import os
from pathlib import Path
import shutil
import time
# Import backend packages to trigger plugin registration
try:
import leann_backend_diskann
import leann_backend_hnsw
print("INFO: Backend packages imported successfully.")
except ImportError as e:
print(f"WARNING: Could not import backend packages. Error: {e}")
# Import upper-level API from leann-core
from leann.api import LeannBuilder, LeannSearcher, LeannChat
def load_sample_documents():
"""Create sample documents for demonstration"""
docs = [
{"title": "Intro to Python", "content": "Python is a high-level, interpreted language known for simplicity."},
{"title": "ML Basics", "content": "Machine learning builds systems that learn from data."},
{"title": "Data Structures", "content": "Data structures like arrays, lists, and graphs organize data."},
]
return docs
def main():
print("==========================================================")
print("=== Leann Document Search Demo (DiskANN + Recompute) ===")
print("==========================================================")
INDEX_DIR = Path("./test_indices")
INDEX_PATH = str(INDEX_DIR / "documents.diskann")
BACKEND_TO_TEST = "diskann"
if INDEX_DIR.exists():
print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
shutil.rmtree(INDEX_DIR)
# --- 1. Build index ---
print(f"\n[PHASE 1] Building index using '{BACKEND_TO_TEST}' backend...")
builder = LeannBuilder(
backend_name=BACKEND_TO_TEST,
graph_degree=32,
complexity=64
)
documents = load_sample_documents()
print(f"Loaded {len(documents)} sample documents.")
for doc in documents:
builder.add_text(doc["content"], metadata={"title": doc["title"]})
builder.build_index(INDEX_PATH)
print(f"\nIndex built!")
# --- 2. Basic search demo ---
print(f"\n[PHASE 2] Basic search using '{BACKEND_TO_TEST}' backend...")
searcher = LeannSearcher(index_path=INDEX_PATH)
query = "What is machine learning?"
print(f"\nQuery: '{query}'")
print("\n--- Basic search mode (PQ computation) ---")
start_time = time.time()
results = searcher.search(query, top_k=2)
basic_time = time.time() - start_time
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
print(">>> Basic search results <<<")
for i, res in enumerate(results, 1):
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
# --- 3. Recompute search demo ---
print(f"\n[PHASE 3] Recompute search using embedding server...")
print("\n--- Recompute search mode (get real embeddings via network) ---")
# Configure recompute parameters
recompute_params = {
"recompute_beighbor_embeddings": True, # Enable network recomputation
"USE_DEFERRED_FETCH": False, # Don't use deferred fetch
"skip_search_reorder": True, # Skip search reordering
"dedup_node_dis": True, # Enable node distance deduplication
"prune_ratio": 0.1, # Pruning ratio 10%
"batch_recompute": False, # Don't use batch recomputation
"global_pruning": False, # Don't use global pruning
"zmq_port": 5555, # ZMQ port
"embedding_model": "sentence-transformers/all-mpnet-base-v2"
}
print("Recompute parameter configuration:")
for key, value in recompute_params.items():
print(f" {key}: {value}")
print(f"\n🔄 Executing Recompute search...")
try:
start_time = time.time()
recompute_results = searcher.search(query, top_k=2, **recompute_params)
recompute_time = time.time() - start_time
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
print(">>> Recompute search results <<<")
for i, res in enumerate(recompute_results, 1):
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
# Compare results
print(f"\n--- Result comparison ---")
print(f"Basic search time: {basic_time:.3f} seconds")
print(f"Recompute time: {recompute_time:.3f} seconds")
print("\nBasic search vs Recompute results:")
for i in range(min(len(results), len(recompute_results))):
basic_score = results[i]['score']
recompute_score = recompute_results[i]['score']
score_diff = abs(basic_score - recompute_score)
print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")
if recompute_time > basic_time:
print(f"✅ Recompute mode working correctly (more accurate but slower)")
else:
print(f" Recompute time is unusually fast, network recomputation may not be enabled")
except Exception as e:
print(f"❌ Recompute search failed: {e}")
print("This usually indicates an embedding server connection issue")
# --- 4. Chat demo ---
print(f"\n[PHASE 4] Starting chat session...")
chat = LeannChat(index_path=INDEX_PATH)
chat_response = chat.ask(query)
print(f"You: {query}")
print(f"Leann: {chat_response}")
print("\n==========================================================")
print("✅ Demo finished successfully!")
print("==========================================================")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,76 @@
from llama_index.core import SimpleDirectoryReader, Settings
from llama_index.core.readers.base import BaseReader
from llama_index.node_parser.docling import DoclingNodeParser
from llama_index.readers.docling import DoclingReader
from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
import asyncio
import os
import dotenv
from leann.api import LeannBuilder, LeannSearcher, LeannChat
import leann_backend_diskann # Import to ensure backend registration
import shutil
from pathlib import Path
dotenv.load_dotenv()
reader = DoclingReader(export_type=DoclingReader.ExportType.JSON)
file_extractor: dict[str, BaseReader] = {
".docx": reader,
".pptx": reader,
".pdf": reader,
".xlsx": reader,
}
node_parser = DoclingNodeParser(
chunker=HybridChunker(tokenizer="Qwen/Qwen3-Embedding-4B", max_tokens=10240)
)
documents = SimpleDirectoryReader(
"examples/data",
recursive=True,
file_extractor=file_extractor,
encoding="utf-8",
required_exts=[".pdf", ".docx", ".pptx", ".xlsx"]
).load_data(show_progress=True)
# Extract text from documents and prepare for Leann
all_texts = []
for doc in documents:
# DoclingNodeParser returns Node objects, which have a text attribute
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.text)
INDEX_DIR = Path("./test_pdf_index")
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
if INDEX_DIR.exists():
print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
shutil.rmtree(INDEX_DIR)
print(f"\n[PHASE 1] Building Leann index...")
builder = LeannBuilder(
backend_name="diskann",
embedding_model="sentence-transformers/all-mpnet-base-v2", # Using a common sentence transformer model
graph_degree=32,
complexity=64
)
print(f"Loaded {len(all_texts)} text chunks from documents.")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(INDEX_PATH)
print(f"\nLeann index built at {INDEX_PATH}!")
async def main():
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=INDEX_PATH)
query = "Based on the paper, what are the two main techniques LEANN uses to achieve low storage overhead and high retrieval accuracy?"
print(f"You: {query}")
chat_response = chat.ask(query, recompute_beighbor_embeddings=True)
print(f"Leann: {chat_response}")
if __name__ == "__main__":
asyncio.run(main())

81
examples/simple_demo.py Normal file
View File

@@ -0,0 +1,81 @@
"""
Simple demo showing basic leann usage
Run: uv run python examples/simple_demo.py
"""
from leann import LeannBuilder, LeannSearcher, LeannChat
def main():
print("=== Leann Simple Demo ===")
print()
# Sample knowledge base
chunks = [
"Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed.",
"Deep learning uses neural networks with multiple layers to process data and make decisions.",
"Natural language processing helps computers understand and generate human language.",
"Computer vision enables machines to interpret and understand visual information from images and videos.",
"Reinforcement learning teaches agents to make decisions by receiving rewards or penalties for their actions.",
"Data science combines statistics, programming, and domain expertise to extract insights from data.",
"Big data refers to extremely large datasets that require special tools and techniques to process.",
"Cloud computing provides on-demand access to computing resources over the internet.",
]
print("1. Building index (no embeddings stored)...")
builder = LeannBuilder(
embedding_model="sentence-transformers/all-mpnet-base-v2",
prune_ratio=0.7, # Keep 30% of connections
)
builder.add_chunks(chunks)
builder.build_index("demo_knowledge.leann")
print()
print("2. Searching with real-time embeddings...")
searcher = LeannSearcher("demo_knowledge.leann")
queries = [
"What is machine learning?",
"How does neural network work?",
"Tell me about data processing",
]
for query in queries:
print(f"Query: {query}")
results = searcher.search(query, top_k=2)
for i, result in enumerate(results, 1):
print(f" {i}. Score: {result.score:.3f}")
print(f" Text: {result.text[:100]}...")
print()
print("3. Memory stats:")
stats = searcher.get_memory_stats()
print(f" Cache size: {stats.embedding_cache_size}")
print(f" Cache memory: {stats.embedding_cache_memory_mb:.1f} MB")
print(f" Total chunks: {stats.total_chunks}")
print()
print("4. Interactive chat demo:")
print(" (Note: Requires OpenAI API key for real responses)")
chat = LeannChat("demo_knowledge.leann")
# Demo questions
demo_questions: list[str] = [
"What is the difference between machine learning and deep learning?",
"How is data science related to big data?",
]
for question in demo_questions:
print(f" Q: {question}")
response = chat.ask(question)
print(f" A: {response}")
print()
print("Demo completed! Try running:")
print(" uv run python examples/document_search.py")
if __name__ == "__main__":
main()

32
knowledge.leann.meta.json Normal file
View File

@@ -0,0 +1,32 @@
{
"version": "0.1.0",
"backend_name": "diskann",
"embedding_model": "sentence-transformers/all-mpnet-base-v2",
"num_chunks": 6,
"chunks": [
{
"text": "Python is a powerful programming language",
"metadata": {}
},
{
"text": "Machine learning transforms industries",
"metadata": {}
},
{
"text": "Neural networks process complex data",
"metadata": {}
},
{
"text": "Java is a powerful programming language",
"metadata": {}
},
{
"text": "C++ is a powerful programming language",
"metadata": {}
},
{
"text": "C# is a powerful programming language",
"metadata": {}
}
]
}

View File

@@ -0,0 +1,8 @@
# packages/leann-backend-diskann/CMakeLists.txt (最终简化版)
cmake_minimum_required(VERSION 3.20)
project(leann_backend_diskann_wrapper)
# 告诉 CMake 直接进入 DiskANN 子模块并执行它自己的 CMakeLists.txt
# DiskANN 会自己处理所有事情,包括编译 Python 绑定
add_subdirectory(src/third_party/DiskANN)

View File

@@ -0,0 +1,7 @@
print("Initializing leann-backend-diskann...")
try:
from .diskann_backend import DiskannBackend
print("INFO: DiskANN backend loaded successfully")
except ImportError as e:
print(f"WARNING: Could not import DiskANN backend: {e}")

View File

@@ -0,0 +1,299 @@
import numpy as np
import os
import json
import struct
from pathlib import Path
from typing import Dict
import contextlib
import threading
import time
import atexit
import socket
import subprocess
import sys
from leann.registry import register_backend
from leann.interface import (
LeannBackendFactoryInterface,
LeannBackendBuilderInterface,
LeannBackendSearcherInterface
)
from . import _diskannpy as diskannpy
METRIC_MAP = {
"mips": diskannpy.Metric.INNER_PRODUCT,
"l2": diskannpy.Metric.L2,
"cosine": diskannpy.Metric.COSINE,
}
@contextlib.contextmanager
def chdir(path):
original_dir = os.getcwd()
os.chdir(path)
try:
yield
finally:
os.chdir(original_dir)
def _write_vectors_to_bin(data: np.ndarray, file_path: str):
num_vectors, dim = data.shape
with open(file_path, 'wb') as f:
f.write(struct.pack('I', num_vectors))
f.write(struct.pack('I', dim))
f.write(data.tobytes())
def _check_port(port: int) -> bool:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
class EmbeddingServerManager:
def __init__(self):
self.server_process = None
self.server_port = None
atexit.register(self.stop_server)
def start_server(self, port=5555, model_name="sentence-transformers/all-mpnet-base-v2"):
if self.server_process and self.server_process.poll() is None:
print(f"INFO: Reusing existing server process for this session (PID {self.server_process.pid})")
return True
# 检查端口是否已被其他无关进程占用
if _check_port(port):
print(f"WARNING: Port {port} is already in use. Assuming an external server is running and connecting to it.")
return True
print(f"INFO: Starting session-level embedding server as a background process...")
try:
command = [
sys.executable,
"-m", "packages.leann-backend-diskann.leann_backend_diskann.embedding_server",
"--zmq-port", str(port),
"--model-name", model_name
]
project_root = Path(__file__).parent.parent.parent.parent
print(f"INFO: Running command from project root: {project_root}")
self.server_process = subprocess.Popen(
command,
cwd=project_root,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
encoding='utf-8'
)
self.server_port = port
print(f"INFO: Server process started with PID: {self.server_process.pid}")
max_wait, wait_interval = 30, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
print(f"✅ Embedding server is up and ready for this session.")
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
log_thread.start()
return True
if self.server_process.poll() is not None:
print("❌ ERROR: Server process terminated unexpectedly during startup.")
self._log_monitor()
return False
time.sleep(wait_interval)
print(f"❌ ERROR: Server process failed to start listening within {max_wait} seconds.")
self.stop_server()
return False
except Exception as e:
print(f"❌ ERROR: Failed to start embedding server process: {e}")
return False
def _log_monitor(self):
if not self.server_process:
return
try:
if self.server_process.stdout:
for line in iter(self.server_process.stdout.readline, ''):
print(f"[EmbeddingServer LOG]: {line.strip()}")
self.server_process.stdout.close()
if self.server_process.stderr:
for line in iter(self.server_process.stderr.readline, ''):
print(f"[EmbeddingServer ERROR]: {line.strip()}")
self.server_process.stderr.close()
except Exception as e:
print(f"Log monitor error: {e}")
def stop_server(self):
if self.server_process and self.server_process.poll() is None:
print(f"INFO: Terminating session server process (PID: {self.server_process.pid})...")
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
print("INFO: Server process terminated.")
except subprocess.TimeoutExpired:
print("WARNING: Server process did not terminate gracefully, killing it.")
self.server_process.kill()
self.server_process = None
@register_backend("diskann")
class DiskannBackend(LeannBackendFactoryInterface):
@staticmethod
def builder(**kwargs) -> LeannBackendBuilderInterface:
return DiskannBuilder(**kwargs)
@staticmethod
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
path = Path(index_path)
meta_path = path.parent / f"{path.name}.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
with open(meta_path, 'r') as f:
meta = json.load(f)
try:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(meta.get("embedding_model"))
dimensions = model.get_sentence_embedding_dimension()
kwargs['dimensions'] = dimensions
except ImportError:
raise ImportError("sentence-transformers is required to infer embedding dimensions. Please install it.")
except Exception as e:
raise RuntimeError(f"Could not load SentenceTransformer model to get dimension: {e}")
return DiskannSearcher(index_path, **kwargs)
class DiskannBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs):
self.build_params = kwargs
def build(self, data: np.ndarray, index_path: str, **kwargs):
path = Path(index_path)
index_dir = path.parent
index_prefix = path.stem
index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32:
data = data.astype(np.float32)
if not data.flags['C_CONTIGUOUS']:
data = np.ascontiguousarray(data)
data_filename = f"{index_prefix}_data.bin"
_write_vectors_to_bin(data, index_dir / data_filename)
build_kwargs = {**self.build_params, **kwargs}
metric_str = build_kwargs.get("distance_metric", "mips").lower()
metric_enum = METRIC_MAP.get(metric_str)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
complexity = build_kwargs.get("complexity", 64)
graph_degree = build_kwargs.get("graph_degree", 32)
final_index_ram_limit = build_kwargs.get("search_memory_maximum", 4.0)
indexing_ram_budget = build_kwargs.get("build_memory_maximum", 8.0)
num_threads = build_kwargs.get("num_threads", 8)
pq_disk_bytes = build_kwargs.get("pq_disk_bytes", 0)
codebook_prefix = ""
print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
try:
with chdir(index_dir):
diskannpy.build_disk_float_index(
metric_enum,
data_filename,
index_prefix,
complexity,
graph_degree,
final_index_ram_limit,
indexing_ram_budget,
num_threads,
pq_disk_bytes,
codebook_prefix
)
print(f"✅ DiskANN index built successfully at '{index_dir / index_prefix}'")
except Exception as e:
print(f"💥 ERROR: DiskANN index build failed. Exception: {e}")
raise
finally:
temp_data_file = index_dir / data_filename
if temp_data_file.exists():
os.remove(temp_data_file)
class DiskannSearcher(LeannBackendSearcherInterface):
def __init__(self, index_path: str, **kwargs):
path = Path(index_path)
index_dir = path.parent
index_prefix = path.stem
metric_str = kwargs.get("distance_metric", "mips").lower()
metric_enum = METRIC_MAP.get(metric_str)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
num_threads = kwargs.get("num_threads", 8)
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
dimensions = kwargs.get("dimensions")
if not dimensions:
raise ValueError("Vector dimension not provided to DiskannSearcher.")
try:
full_index_prefix = str(index_dir / index_prefix)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, "", ""
)
self.num_threads = num_threads
self.embedding_server_manager = EmbeddingServerManager()
print("✅ DiskANN index loaded successfully.")
except Exception as e:
print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
raise
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]:
complexity = kwargs.get("complexity", 100)
beam_width = kwargs.get("beam_width", 4)
USE_DEFERRED_FETCH = kwargs.get("USE_DEFERRED_FETCH", False)
skip_search_reorder = kwargs.get("skip_search_reorder", False)
recompute_beighbor_embeddings = kwargs.get("recompute_beighbor_embeddings", False)
dedup_node_dis = kwargs.get("dedup_node_dis", False)
prune_ratio = kwargs.get("prune_ratio", 0.0)
batch_recompute = kwargs.get("batch_recompute", False)
global_pruning = kwargs.get("global_pruning", False)
if recompute_beighbor_embeddings:
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
zmq_port = kwargs.get("zmq_port", 5555)
embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
if not self.embedding_server_manager.start_server(zmq_port, embedding_model):
print(f"WARNING: Failed to start embedding server, falling back to PQ computation")
kwargs['recompute_beighbor_embeddings'] = False
if query.dtype != np.float32:
query = query.astype(np.float32)
if query.ndim == 1:
query = np.expand_dims(query, axis=0)
try:
labels, distances = self._index.batch_search(
query,
query.shape[0],
top_k,
complexity,
beam_width,
self.num_threads,
USE_DEFERRED_FETCH,
skip_search_reorder,
recompute_beighbor_embeddings,
dedup_node_dis,
prune_ratio,
batch_recompute,
global_pruning
)
return {"labels": labels, "distances": distances}
except Exception as e:
print(f"💥 ERROR: DiskANN search failed. Exception: {e}")
batch_size = query.shape[0]
return {"labels": np.full((batch_size, top_k), -1, dtype=np.int64),
"distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}
def __del__(self):
if hasattr(self, 'embedding_server_manager'):
self.embedding_server_manager.stop_server()

View File

@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: embedding.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding\"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r\"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'embedding_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_NODEEMBEDDINGREQUEST._serialized_start=35
_NODEEMBEDDINGREQUEST._serialized_end=75
_NODEEMBEDDINGRESPONSE._serialized_start=77
_NODEEMBEDDINGRESPONSE._serialized_end=166
# @@protoc_insertion_point(module_scope)

View File

@@ -0,0 +1,397 @@
#!/usr/bin/env python3
"""
Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern
"""
import pickle
import argparse
import threading
import time
from transformers import AutoTokenizer, AutoModel
import os
from contextlib import contextmanager
import zmq
import numpy as np
RED = "\033[91m"
RESET = "\033[0m"
# 简化的文档存储 - 替代 LazyPassages
class SimpleDocumentStore:
"""简化的文档存储支持任意ID"""
def __init__(self, documents: dict = None):
self.documents = documents or {}
# 默认演示文档
self.default_docs = {
0: "Python is a high-level, interpreted language known for simplicity.",
1: "Machine learning builds systems that learn from data.",
2: "Data structures like arrays, lists, and graphs organize data.",
}
def __getitem__(self, doc_id):
doc_id = int(doc_id)
# 优先使用指定的文档
if doc_id in self.documents:
return {"text": self.documents[doc_id]}
# 其次使用默认演示文档
if doc_id in self.default_docs:
return {"text": self.default_docs[doc_id]}
# 对于任意其他ID返回通用文档
fallback_docs = [
"This is a general document about technology and programming concepts.",
"This document discusses machine learning and artificial intelligence topics.",
"This content covers data structures, algorithms, and computer science fundamentals.",
"This is a document about software engineering and development practices.",
"This content focuses on databases, data management, and information systems."
]
# 根据ID选择一个fallback文档
fallback_text = fallback_docs[doc_id % len(fallback_docs)]
return {"text": f"[ID:{doc_id}] {fallback_text}"}
def __len__(self):
return len(self.documents) + len(self.default_docs)
def create_embedding_server_thread(
zmq_port=5555,
model_name="sentence-transformers/all-mpnet-base-v2",
max_batch_size=128,
):
"""
在当前线程中创建并运行 embedding server
这个函数设计为在单独的线程中调用
"""
print(f"INFO: Initializing embedding server thread on port {zmq_port}")
try:
# 检查端口是否已被占用
import socket
def check_port(port):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
if check_port(zmq_port):
print(f"{RED}Port {zmq_port} is already in use{RESET}")
return
# 初始化模型
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
import torch
# 选择设备
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
cuda_available = torch.cuda.is_available()
if cuda_available:
device = torch.device("cuda")
print("INFO: Using CUDA device")
elif mps_available:
device = torch.device("mps")
print("INFO: Using MPS device (Apple Silicon)")
else:
device = torch.device("cpu")
print("INFO: Using CPU device")
# 加载模型
print(f"INFO: Loading model {model_name}")
model = AutoModel.from_pretrained(model_name).to(device).eval()
# 优化模型
if cuda_available or mps_available:
try:
model = model.half()
model = torch.compile(model)
print(f"INFO: Using FP16 precision with model: {model_name}")
except Exception as e:
print(f"WARNING: Model optimization failed: {e}")
# 默认演示文档
demo_documents = {
0: "Python is a high-level, interpreted language known for simplicity.",
1: "Machine learning builds systems that learn from data.",
2: "Data structures like arrays, lists, and graphs organize data.",
}
passages = SimpleDocumentStore(demo_documents)
print(f"INFO: Loaded {len(passages)} demo documents")
class DeviceTimer:
"""设备计时器"""
def __init__(self, name="", device=device):
self.name = name
self.device = device
self.start_time = 0
self.end_time = 0
if cuda_available:
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
else:
self.start_event = None
self.end_event = None
@contextmanager
def timing(self):
self.start()
yield
self.end()
def start(self):
if cuda_available:
torch.cuda.synchronize()
self.start_event.record()
else:
if self.device.type == "mps":
torch.mps.synchronize()
self.start_time = time.time()
def end(self):
if cuda_available:
self.end_event.record()
torch.cuda.synchronize()
else:
if self.device.type == "mps":
torch.mps.synchronize()
self.end_time = time.time()
def elapsed_time(self):
if cuda_available:
return self.start_event.elapsed_time(self.end_event) / 1000.0
else:
return self.end_time - self.start_time
def print_elapsed(self):
print(f"Time taken for {self.name}: {self.elapsed_time():.6f} seconds")
def process_batch(texts_batch, ids_batch, missing_ids):
"""处理文本批次"""
batch_size = len(texts_batch)
print(f"INFO: Processing batch of size {batch_size}")
tokenize_timer = DeviceTimer("tokenization (batch)", device)
to_device_timer = DeviceTimer("transfer to device (batch)", device)
embed_timer = DeviceTimer("embedding (batch)", device)
pool_timer = DeviceTimer("mean pooling (batch)", device)
with tokenize_timer.timing():
encoded_batch = tokenizer.batch_encode_plus(
texts_batch,
padding="max_length",
truncation=True,
max_length=256,
return_tensors="pt",
return_token_type_ids=False,
)
tokenize_timer.print_elapsed()
seq_length = encoded_batch["input_ids"].size(1)
print(f"Batch size: {batch_size}, Sequence length: {seq_length}")
with to_device_timer.timing():
enc = {k: v.to(device) for k, v in encoded_batch.items()}
to_device_timer.print_elapsed()
with torch.no_grad():
with embed_timer.timing():
out = model(enc["input_ids"], enc["attention_mask"])
embed_timer.print_elapsed()
with pool_timer.timing():
hidden_states = out.last_hidden_state if hasattr(out, "last_hidden_state") else out
mask_expanded = enc["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float()
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
batch_embeddings = sum_embeddings / sum_mask
pool_timer.print_elapsed()
return batch_embeddings.cpu().numpy()
# ZMQ server 主循环 - 修改为REP套接字
context = zmq.Context()
socket = context.socket(zmq.ROUTER) # 改为REP套接字
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
print(f"INFO: ZMQ ROUTER server listening on port {zmq_port}")
# 设置超时
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5秒接收超时
socket.setsockopt(zmq.SNDTIMEO, 300000) # 300秒发送超时
from . import embedding_pb2
print(f"INFO: Embedding server ready to serve requests")
while True:
try:
parts = socket.recv_multipart()
# --- 恢复稳健的消息格式判断 ---
# 必须检查 parts 的长度,避免 IndexError
if len(parts) >= 3:
identity = parts[0]
# empty = parts[1] # 中间的空帧我们通常不关心
message = parts[2]
elif len(parts) == 2:
# 也能处理没有空帧的情况
identity = parts[0]
message = parts[1]
else:
# 如果收到格式错误的消息,打印警告并忽略它,而不是崩溃
print(f"WARNING: Received unexpected message format with {len(parts)} parts. Ignoring.")
continue
print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")
e2e_start = time.time()
lookup_timer = DeviceTimer("text lookup", device)
# 解析请求
req_proto = embedding_pb2.NodeEmbeddingRequest()
req_proto.ParseFromString(message)
node_ids = req_proto.node_ids
print(f"INFO: Request for {len(node_ids)} node embeddings: {list(node_ids)}")
# 添加调试信息
if len(node_ids) > 0:
print(f"DEBUG: Node ID range: {min(node_ids)} to {max(node_ids)}")
# 查找文本
texts = []
missing_ids = []
with lookup_timer.timing():
for nid in node_ids:
txtinfo = passages[nid]
txt = txtinfo["text"]
texts.append(txt)
lookup_timer.print_elapsed()
if missing_ids:
print(f"WARNING: Missing passages for IDs: {missing_ids}")
# 处理批次
total_size = len(texts)
print(f"INFO: Total batch size: {total_size}, max_batch_size: {max_batch_size}")
all_embeddings = []
if total_size > max_batch_size:
print(f"INFO: Splitting batch of size {total_size} into chunks of {max_batch_size}")
for i in range(0, total_size, max_batch_size):
end_idx = min(i + max_batch_size, total_size)
print(f"INFO: Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")
chunk_texts = texts[i:end_idx]
chunk_ids = node_ids[i:end_idx]
embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
all_embeddings.append(embeddings_chunk)
if cuda_available:
torch.cuda.empty_cache()
elif device.type == "mps":
torch.mps.empty_cache()
hidden = np.vstack(all_embeddings)
print(f"INFO: Combined embeddings shape: {hidden.shape}")
else:
hidden = process_batch(texts, node_ids, missing_ids)
# 序列化响应
ser_start = time.time()
resp_proto = embedding_pb2.NodeEmbeddingResponse()
hidden_contiguous = np.ascontiguousarray(hidden, dtype=np.float32)
resp_proto.embeddings_data = hidden_contiguous.tobytes()
resp_proto.dimensions.append(hidden_contiguous.shape[0])
resp_proto.dimensions.append(hidden_contiguous.shape[1])
resp_proto.missing_ids.extend(missing_ids)
response_data = resp_proto.SerializeToString()
# REP 套接字发送单个响应
socket.send_multipart([identity, b'', response_data])
ser_end = time.time()
print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
e2e_end = time.time()
print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
except zmq.Again:
print("INFO: ZMQ socket timeout, continuing to listen")
# REP套接字不需要重新创建只需要继续监听
continue
except Exception as e:
print(f"ERROR: Error in ZMQ server: {e}")
try:
# 发送空响应以维持REQ-REP状态
empty_resp = embedding_pb2.NodeEmbeddingResponse()
socket.send(empty_resp.SerializeToString())
except:
# 如果发送失败重新创建socket
socket.close()
socket = context.socket(zmq.REP)
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 5000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
print("INFO: ZMQ socket recreated after error")
except Exception as e:
print(f"ERROR: Failed to start embedding server: {e}")
raise
# 保持原有的 create_embedding_server 函数不变,只添加线程化版本
def create_embedding_server(
domain="demo",
load_passages=True,
load_embeddings=False,
use_fp16=True,
use_int8=False,
use_cuda_graphs=False,
zmq_port=5555,
max_batch_size=128,
lazy_load_passages=False,
model_name="sentence-transformers/all-mpnet-base-v2",
):
"""
原有的 create_embedding_server 函数保持不变
这个是阻塞版本,用于直接运行
"""
create_embedding_server_thread(zmq_port, model_name, max_batch_size)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
parser.add_argument("--domain", type=str, default="demo", help="Domain name")
parser.add_argument("--load-passages", action="store_true", default=True)
parser.add_argument("--load-embeddings", action="store_true", default=False)
parser.add_argument("--use-fp16", action="store_true", default=False)
parser.add_argument("--use-int8", action="store_true", default=False)
parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting")
parser.add_argument("--lazy-load-passages", action="store_true", default=True)
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model name")
args = parser.parse_args()
create_embedding_server(
domain=args.domain,
load_passages=args.load_passages,
load_embeddings=args.load_embeddings,
use_fp16=args.use_fp16,
use_int8=args.use_int8,
use_cuda_graphs=args.use_cuda_graphs,
zmq_port=args.zmq_port,
max_batch_size=args.max_batch_size,
lazy_load_passages=args.lazy_load_passages,
model_name=args.model_name,
)

View File

@@ -0,0 +1,16 @@
[build-system]
requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy"]
build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-diskann"
version = "0.1.0"
dependencies = ["leann-core==0.1.0", "numpy"]
[tool.scikit-build]
# 关键:简化的 CMake 路径
cmake.source-dir = "third_party/DiskANN"
# 关键Python 包在根目录,路径完全匹配
wheel.packages = ["leann_backend_diskann"]
# 使用默认的 redirect 模式
editable.mode = "redirect"

View File

@@ -0,0 +1,6 @@
---
BasedOnStyle: Microsoft
---
Language: Cpp
SortIncludes: false
...

View File

@@ -0,0 +1,14 @@
# Set the default behavior, in case people don't have core.autocrlf set.
* text=auto
# Explicitly declare text files you want to always be normalized and converted
# to native line endings on checkout.
*.c text
*.h text
# Declare files that will always have CRLF line endings on checkout.
*.sln text eol=crlf
# Denote all files that are truly binary and should not be modified.
*.png binary
*.jpg binary

View File

@@ -0,0 +1,40 @@
---
name: Bug report
about: Bug reports help us improve! Thanks for submitting yours!
title: "[BUG] "
labels: bug
assignees: ''
---
## Expected Behavior
Tell us what should happen
## Actual Behavior
Tell us what happens instead
## Example Code
Please see [How to create a Minimal, Reproducible example](https://stackoverflow.com/help/minimal-reproducible-example) for some guidance on creating the best possible example of the problem
```bash
```
## Dataset Description
Please tell us about the shape and datatype of your data, (e.g. 128 dimensions, 12.3 billion points, floats)
- Dimensions:
- Number of Points:
- Data type:
## Error
```
Paste the full error, with any sensitive information minimally redacted and marked $$REDACTED$$
```
## Your Environment
* Operating system (e.g. Windows 11 Pro, Ubuntu 22.04.1 LTS)
* DiskANN version (or commit built from)
## Additional Details
Any other contextual information you might feel is important.

View File

@@ -0,0 +1,2 @@
blank_issues_enabled: false

View File

@@ -0,0 +1,25 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: enhancement
assignees: ''
---
## Is your feature request related to a problem? Please describe.
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
## Describe the solution you'd like
A clear and concise description of what you want to happen.
## Describe alternatives you've considered
A clear and concise description of any alternative solutions or features you've considered.
## Provide references (if applicable)
If your feature request is related to a published algorithm/idea, please provide links to
any relevant articles or webpages.
## Additional context
Add any other context or screenshots about the feature request here.

View File

@@ -0,0 +1,11 @@
---
name: Usage Question
about: Ask us a question about DiskANN!
title: "[Question]"
labels: question
assignees: ''
---
This is our forum for asking whatever DiskANN question you'd like! No need to feel shy - we're happy to talk about use cases and optimal tuning strategies!

View File

@@ -0,0 +1,22 @@
<!--
Thanks for contributing a pull request! Please ensure you have taken a look at
the contribution guidelines: https://github.com/microsoft/DiskANN/blob/main/CONTRIBUTING.md
-->
- [ ] Does this PR have a descriptive title that could go in our release notes?
- [ ] Does this PR add any new dependencies?
- [ ] Does this PR modify any existing APIs?
- [ ] Is the change to the API backwards compatible?
- [ ] Should this result in any changes to our documentation, either updating existing docs or adding new ones?
#### Reference Issues/PRs
<!--
Example: Fixes #1234. See also #3456.
Please use keywords (e.g., Fixes) to create link to the issues or pull requests
you resolved, so that they will automatically be closed when your pull request
is merged. See https://github.com/blog/1506-closing-issues-via-pull-requests
-->
#### What does this implement/fix? Briefly explain your changes.
#### Any other comments?

View File

@@ -0,0 +1,39 @@
name: 'DiskANN Build Bootstrap'
description: 'Prepares DiskANN build environment and executes build'
runs:
using: "composite"
steps:
# ------------ Linux Build ---------------
- name: Prepare and Execute Build
if: ${{ runner.os == 'Linux' }}
run: |
sudo scripts/dev/install-dev-deps-ubuntu.bash
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DUNIT_TEST=True
cmake --build build -- -j
cmake --install build --prefix="dist"
shell: bash
# ------------ End Linux Build ---------------
# ------------ Windows Build ---------------
- name: Add VisualStudio command line tools into path
if: runner.os == 'Windows'
uses: ilammy/msvc-dev-cmd@v1
- name: Run configure and build for Windows
if: runner.os == 'Windows'
run: |
mkdir build && cd build && cmake .. -DUNIT_TEST=True && msbuild diskann.sln /m /nologo /t:Build /p:Configuration="Release" /property:Platform="x64" -consoleloggerparameters:"ErrorsOnly;Summary"
cd ..
mkdir dist
mklink /j .\dist\bin .\x64\Release\
shell: cmd
# ------------ End Windows Build ---------------
# ------------ Windows Build With EXEC_ENV_OLS and USE_BING_INFRA ---------------
- name: Add VisualStudio command line tools into path
if: runner.os == 'Windows'
uses: ilammy/msvc-dev-cmd@v1
- name: Run configure and build for Windows with Bing feature flags
if: runner.os == 'Windows'
run: |
mkdir build_bing && cd build_bing && cmake .. -DEXEC_ENV_OLS=1 -DUSE_BING_INFRA=1 -DUNIT_TEST=True && msbuild diskann.sln /m /nologo /t:Build /p:Configuration="Release" /property:Platform="x64" -consoleloggerparameters:"ErrorsOnly;Summary"
cd ..
shell: cmd
# ------------ End Windows Build ---------------

View File

@@ -0,0 +1,13 @@
name: 'Checking code formatting...'
description: 'Ensures code complies with code formatting rules'
runs:
using: "composite"
steps:
- name: Checking code formatting...
run: |
sudo apt install clang-format
find include -name '*.h' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
find src -name '*.cpp' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
find apps -name '*.cpp' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
find python -name '*.cpp' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
shell: bash

View File

@@ -0,0 +1,28 @@
name: 'Generating Random Data (Basic)'
description: 'Generates the random data files used in acceptance tests'
runs:
using: "composite"
steps:
- name: Generate Random Data (Basic)
run: |
mkdir data
echo "Generating random 1020,1024,1536D float and 4096 int8 vectors for index"
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1020D_5K_norm1.0.bin -D 1020 -N 5000 --norm 1.0
#dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1024D_5K_norm1.0.bin -D 1024 -N 5000 --norm 1.0
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1536D_5K_norm1.0.bin -D 1536 -N 5000 --norm 1.0
dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_4096D_5K_norm1.0.bin -D 4096 -N 5000 --norm 1.0
echo "Generating random 1020,1024,1536D float and 4096D int8 avectors for query"
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1020D_1K_norm1.0.bin -D 1020 -N 1000 --norm 1.0
#dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1024D_1K_norm1.0.bin -D 1024 -N 1000 --norm 1.0
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1536D_1K_norm1.0.bin -D 1536 -N 1000 --norm 1.0
dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_4096D_1K_norm1.0.bin -D 4096 -N 1000 --norm 1.0
echo "Computing ground truth for 1020,1024,1536D float and 4096D int8 avectors for query"
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_1020D_5K_norm1.0.bin --query_file data/rand_float_1020D_1K_norm1.0.bin --gt_file data/l2_rand_float_1020D_5K_norm1.0_1020D_1K_norm1.0_gt100 --K 100
#dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_1024D_5K_norm1.0.bin --query_file data/rand_float_1024D_1K_norm1.0.bin --gt_file data/l2_rand_float_1024D_5K_norm1.0_1024D_1K_norm1.0_gt100 --K 100
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_1536D_5K_norm1.0.bin --query_file data/rand_float_1536D_1K_norm1.0.bin --gt_file data/l2_rand_float_1536D_5K_norm1.0_1536D_1K_norm1.0_gt100 --K 100
dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/rand_int8_4096D_5K_norm1.0.bin --query_file data/rand_int8_4096D_1K_norm1.0.bin --gt_file data/l2_rand_int8_4096D_5K_norm1.0_4096D_1K_norm1.0_gt100 --K 100
shell: bash

View File

@@ -0,0 +1,38 @@
name: 'Generating Random Data (Basic)'
description: 'Generates the random data files used in acceptance tests'
runs:
using: "composite"
steps:
- name: Generate Random Data (Basic)
run: |
mkdir data
echo "Generating random vectors for index"
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_10K_norm1.0.bin -D 10 -N 10000 --norm 1.0
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_10K_unnorm.bin -D 10 -N 10000 --rand_scaling 2.0
dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0
dist/bin/rand_data_gen --data_type uint8 --output_file data/rand_uint8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0
echo "Generating random vectors for query"
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_1K_norm1.0.bin -D 10 -N 1000 --norm 1.0
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_1K_unnorm.bin -D 10 -N 1000 --rand_scaling 2.0
dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0
dist/bin/rand_data_gen --data_type uint8 --output_file data/rand_uint8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0
echo "Computing ground truth for floats across l2, mips, and cosine distance functions"
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
dist/bin/compute_groundtruth --data_type float --dist_fn mips --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
dist/bin/compute_groundtruth --data_type float --dist_fn cosine --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/cosine_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
dist/bin/compute_groundtruth --data_type float --dist_fn cosine --base_file data/rand_float_10D_10K_unnorm.bin --query_file data/rand_float_10D_1K_unnorm.bin --gt_file data/cosine_rand_float_10D_10K_unnorm_10D_1K_unnorm_gt100 --K 100
echo "Computing ground truth for int8s across l2, mips, and cosine distance functions"
dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/rand_int8_10D_10K_norm50.0.bin --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
dist/bin/compute_groundtruth --data_type int8 --dist_fn mips --base_file data/rand_int8_10D_10K_norm50.0.bin --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/mips_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
dist/bin/compute_groundtruth --data_type int8 --dist_fn cosine --base_file data/rand_int8_10D_10K_norm50.0.bin --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/cosine_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
echo "Computing ground truth for uint8s across l2, mips, and cosine distance functions"
dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
dist/bin/compute_groundtruth --data_type uint8 --dist_fn mips --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/mips_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
dist/bin/compute_groundtruth --data_type uint8 --dist_fn cosine --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
shell: bash

View File

@@ -0,0 +1,22 @@
name: Build Python Wheel
description: Builds a python wheel with cibuildwheel
inputs:
cibw-identifier:
description: "CI build wheel identifier to build"
required: true
runs:
using: "composite"
steps:
- uses: actions/setup-python@v3
- name: Install cibuildwheel
run: python -m pip install cibuildwheel==2.11.3
shell: bash
- name: Building Python ${{inputs.cibw-identifier}} Wheel
run: python -m cibuildwheel --output-dir dist
env:
CIBW_BUILD: ${{inputs.cibw-identifier}}
shell: bash
- uses: actions/upload-artifact@v3
with:
name: wheels
path: ./dist/*.whl

View File

@@ -0,0 +1,81 @@
name: DiskANN Build PDoc Documentation
on: [workflow_call]
jobs:
build-reference-documentation:
permissions:
contents: write
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
with:
fetch-depth: 1
- name: Set up Python 3.9
uses: actions/setup-python@v2
with:
python-version: 3.9
- name: Install python build
run: python -m pip install build
shell: bash
# Install required dependencies
- name: Prepare Linux environment
run: |
sudo scripts/dev/install-dev-deps-ubuntu.bash
shell: bash
# We need to build the wheel in order to run pdoc. pdoc does not seem to work if you just point it at
# our source directory.
- name: Building Python Wheel for documentation generation
run: python -m build --wheel --outdir documentation_dist
shell: bash
- name: "Run Reference Documentation Generation"
run: |
pip install pdoc pipdeptree
pip install documentation_dist/*.whl
echo "documentation" > dependencies_documentation.txt
pipdeptree >> dependencies_documentation.txt
pdoc -o docs/python/html diskannpy
- name: Create version environment variable
run: |
echo "DISKANN_VERSION=$(python <<EOF
from importlib.metadata import version
v = version('diskannpy')
print(v)
EOF
)" >> $GITHUB_ENV
- name: Archive documentation version artifact
uses: actions/upload-artifact@v4
with:
name: dependencies
path: |
${{ github.run_id }}-dependencies_documentation.txt
overwrite: true
- name: Archive documentation artifacts
uses: actions/upload-artifact@v4
with:
name: documentation-site
path: |
docs/python/html
# Publish to /dev if we are on the "main" branch
- name: Publish reference docs for latest development version (main branch)
uses: peaceiris/actions-gh-pages@v3
if: github.ref == 'refs/heads/main'
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: docs/python/html
destination_dir: docs/python/dev
# Publish to /<version> if we are releasing
- name: Publish reference docs by version (main branch)
uses: peaceiris/actions-gh-pages@v3
if: github.event_name == 'release'
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: docs/python/html
destination_dir: docs/python/${{ env.DISKANN_VERSION }}
# Publish to /latest if we are releasing
- name: Publish latest reference docs (main branch)
uses: peaceiris/actions-gh-pages@v3
if: github.event_name == 'release'
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: docs/python/html
destination_dir: docs/python/latest

View File

@@ -0,0 +1,42 @@
name: DiskANN Build Python Wheel
on: [workflow_call]
jobs:
linux-build:
name: Python - Ubuntu - ${{matrix.cibw-identifier}}
strategy:
fail-fast: false
matrix:
cibw-identifier: ["cp39-manylinux_x86_64", "cp310-manylinux_x86_64", "cp311-manylinux_x86_64"]
runs-on: ubuntu-latest
defaults:
run:
shell: bash
steps:
- name: Checkout repository
uses: actions/checkout@v3
with:
fetch-depth: 1
- name: Building python wheel ${{matrix.cibw-identifier}}
uses: ./.github/actions/python-wheel
with:
cibw-identifier: ${{matrix.cibw-identifier}}
windows-build:
name: Python - Windows - ${{matrix.cibw-identifier}}
strategy:
fail-fast: false
matrix:
cibw-identifier: ["cp39-win_amd64", "cp310-win_amd64", "cp311-win_amd64"]
runs-on: windows-latest
defaults:
run:
shell: bash
steps:
- name: Checkout repository
uses: actions/checkout@v3
with:
submodules: true
fetch-depth: 1
- name: Building python wheel ${{matrix.cibw-identifier}}
uses: ./.github/actions/python-wheel
with:
cibw-identifier: ${{matrix.cibw-identifier}}

View File

@@ -0,0 +1,28 @@
name: DiskANN Common Checks
# common means common to both pr-test and push-test
on: [workflow_call]
jobs:
formatting-check:
strategy:
fail-fast: true
name: Code Formatting Test
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
with:
fetch-depth: 1
- name: Checking code formatting...
uses: ./.github/actions/format-check
docker-container-build:
name: Docker Container Build
needs: [formatting-check]
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
with:
fetch-depth: 1
- name: Docker build
run: |
docker build .

View File

@@ -0,0 +1,117 @@
name: Disk With PQ
on: [workflow_call]
jobs:
acceptance-tests-disk-pq:
name: Disk, PQ
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-2019, windows-latest]
runs-on: ${{matrix.os}}
defaults:
run:
shell: bash
steps:
- name: Checkout repository
if: ${{ runner.os == 'Linux' }}
uses: actions/checkout@v3
with:
fetch-depth: 1
- name: Checkout repository
if: ${{ runner.os == 'Windows' }}
uses: actions/checkout@v3
with:
fetch-depth: 1
submodules: true
- name: DiskANN Build CLI Applications
uses: ./.github/actions/build
- name: Generate Data
uses: ./.github/actions/generate-random
- name: build and search disk index (one shot graph build, L2, no diskPQ) (float)
if: success() || failure()
run: |
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: build and search disk index (one shot graph build, cosine, no diskPQ) (float)
if: success() || failure()
run: |
dist/bin/build_disk_index --data_type float --dist_fn cosine --data_path data/rand_float_10D_10K_unnorm.bin --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
dist/bin/search_disk_index --data_type float --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_unnorm.bin --gt_file data/cosine_rand_float_10D_10K_unnorm_10D_1K_unnorm_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: build and search disk index (one shot graph build, L2, no diskPQ) (int8)
if: success() || failure()
run: |
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: build and search disk index (one shot graph build, L2, no diskPQ) (uint8)
if: success() || failure()
run: |
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) (float)
if: success() || failure()
run: |
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) (int8)
if: success() || failure()
run: |
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16\
- name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) (uint8)
if: success() || failure()
run: |
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: build and search disk index (sharded graph build, L2, no diskPQ) (float)
if: success() || failure()
run: |
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_sharded --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: build and search disk index (sharded graph build, cosine, no diskPQ) (float)
if: success() || failure()
run: |
dist/bin/build_disk_index --data_type float --dist_fn cosine --data_path data/rand_float_10D_10K_unnorm.bin --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
dist/bin/search_disk_index --data_type float --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_sharded --result_path /tmp/res --query_file data/rand_float_10D_1K_unnorm.bin --gt_file data/cosine_rand_float_10D_10K_unnorm_10D_1K_unnorm_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: build and search disk index (sharded graph build, L2, no diskPQ) (int8)
run: |
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_sharded --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: build and search disk index (sharded graph build, L2, no diskPQ) (uint8)
if: success() || failure()
run: |
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: build and search disk index (one shot graph build, L2, diskPQ) (float)
if: success() || failure()
run: |
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskpq_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: build and search disk index (one shot graph build, L2, diskPQ) (int8)
if: success() || failure()
run: |
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: build and search disk index (one shot graph build, L2, diskPQ) (uint8)
if: success() || failure()
run: |
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: build and search disk index (sharded graph build, MIPS, diskPQ) (float)
if: success() || failure()
run: |
dist/bin/build_disk_index --data_type float --dist_fn mips --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_mips_rand_float_10D_10K_norm1.0_diskpq_sharded -R 16 -L 32 -B 0.00003 -M 0.00006 --PQ_disk_bytes 5
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_mips_rand_float_10D_10K_norm1.0_diskpq_sharded --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: upload data and bin
uses: actions/upload-artifact@v4
with:
name: disk-pq-${{matrix.os}}
path: |
./dist/**
./data/**

View File

@@ -0,0 +1,102 @@
name: Dynamic-Labels
on: [workflow_call]
jobs:
acceptance-tests-dynamic:
name: Dynamic-Labels
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-2019, windows-latest]
runs-on: ${{matrix.os}}
defaults:
run:
shell: bash
steps:
- name: Checkout repository
if: ${{ runner.os == 'Linux' }}
uses: actions/checkout@v3
with:
fetch-depth: 1
- name: Checkout repository
if: ${{ runner.os == 'Windows' }}
uses: actions/checkout@v3
with:
fetch-depth: 1
submodules: true
- name: DiskANN Build CLI Applications
uses: ./.github/actions/build
- name: Generate Data
uses: ./.github/actions/generate-random
- name: Generate Labels
run: |
echo "Generating synthetic labels and computing ground truth for filtered search with universal label"
dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/rand_labels_50_10K.txt --distribution_type random
echo "Generating synthetic labels with a zipf distribution and computing ground truth for filtered search with universal label"
dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/zipf_labels_50_10K.txt --distribution_type zipf
- name: Test a streaming index (float) with labels (Zipf distributed)
run: |
dist/bin/test_streaming_scenario --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --universal_label 0 --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_zipf_stream -R 64 --FilteredLbuild 200 -L 50 --alpha 1.2 --insert_threads 8 --consolidate_threads 8 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 3.2 --unique_labels_supported 51
echo "Computing groundtruth with filter"
dist/bin/compute_groundtruth_for_filters --data_type float --universal_label 0 --filter_label 1 --dist_fn l2 --base_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_base-act4000-cons2000-max10000_1 --label_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000_raw_labels.txt --tags_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000.tags
echo "Searching with filter"
dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 1 --fail_if_recall_below 40 --index_path_prefix data/index_zipf_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_base-act4000-cons2000-max10000_1 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1
echo "Computing groundtruth w/o filter"
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_base-act4000-cons2000-max10000
echo "Searching without filter"
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_zipf_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64
- name: Test a streaming index (float) with labels (random distributed)
run: |
dist/bin/test_streaming_scenario --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --universal_label 0 --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_rand_stream -R 64 --FilteredLbuild 200 -L 50 --alpha 1.2 --insert_threads 8 --consolidate_threads 8 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 3.2 --unique_labels_supported 51
echo "Computing groundtruth with filter"
dist/bin/compute_groundtruth_for_filters --data_type float --universal_label 0 --filter_label 1 --dist_fn l2 --base_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_base-act4000-cons2000-max10000_1 --label_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000_raw_labels.txt --tags_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000.tags
echo "Searching with filter"
dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 1 --fail_if_recall_below 40 --index_path_prefix data/index_rand_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_base-act4000-cons2000-max10000_1 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1
echo "Computing groundtruth w/o filter"
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_base-act4000-cons2000-max10000
echo "Searching without filter"
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_rand_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64
- name: Test Insert Delete Consolidate (float) with labels (zipf distributed)
run: |
dist/bin/test_insert_deletes_consolidate --data_type float --dist_fn l2 --universal_label 0 --label_file data/zipf_labels_50_10K.txt --FilteredLbuild 70 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_zipf_ins_del -R 64 -L 10 --alpha 1.2 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 3.2 --unique_labels_supported 51
echo "Computing groundtruth with filter"
dist/bin/compute_groundtruth_for_filters --data_type float --filter_label 5 --universal_label 0 --dist_fn l2 --base_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_random10D_1K_wlabel_5 --label_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500_raw_labels.txt --tags_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500.tags
echo "Searching with filter"
dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 5 --fail_if_recall_below 10 --index_path_prefix data/index_zipf_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_zipf_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_random10D_1K_wlabel_5 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1
echo "Computing groundtruth w/o filter"
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_random10D_1K
echo "Searching without filter"
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_zipf_ins_del.after-concurrent-delete-del2500-7500 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_random10D_1K -K 10 -L 20 40 60 80 100 -T 64
- name: Test Insert Delete Consolidate (float) with labels (random distributed)
run: |
dist/bin/test_insert_deletes_consolidate --data_type float --dist_fn l2 --universal_label 0 --label_file data/rand_labels_50_10K.txt --FilteredLbuild 70 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_rand_ins_del -R 64 -L 10 --alpha 1.2 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 3.2 --unique_labels_supported 51
echo "Computing groundtruth with filter"
dist/bin/compute_groundtruth_for_filters --data_type float --filter_label 5 --universal_label 0 --dist_fn l2 --base_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_random10D_1K_wlabel_5 --label_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500_raw_labels.txt --tags_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500.tags
echo "Searching with filter"
dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 5 --fail_if_recall_below 40 --index_path_prefix data/index_rand_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_rand_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_random10D_1K_wlabel_5 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1
echo "Computing groundtruth w/o filter"
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_random10D_1K
echo "Searching without filter"
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_rand_ins_del.after-concurrent-delete-del2500-7500 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_random10D_1K -K 10 -L 20 40 60 80 100 -T 64
- name: upload data and bin
uses: actions/upload-artifact@v4
with:
name: dynamic-labels-${{matrix.os}}
path: |
./dist/**
./data/**

View File

@@ -0,0 +1,75 @@
name: Dynamic
on: [workflow_call]
jobs:
acceptance-tests-dynamic:
name: Dynamic
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-2019, windows-latest]
runs-on: ${{matrix.os}}
defaults:
run:
shell: bash
steps:
- name: Checkout repository
if: ${{ runner.os == 'Linux' }}
uses: actions/checkout@v3
with:
fetch-depth: 1
- name: Checkout repository
if: ${{ runner.os == 'Windows' }}
uses: actions/checkout@v3
with:
fetch-depth: 1
submodules: true
- name: DiskANN Build CLI Applications
uses: ./.github/actions/build
- name: Generate Data
uses: ./.github/actions/generate-random
- name: test a streaming index (float)
run: |
dist/bin/test_streaming_scenario --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 3.2
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_base-act4000-cons2000-max10000 --tags_file data/index_stream.after-streaming-act4000-cons2000-max10000.tags
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1
- name: test a streaming index (int8)
if: success() || failure()
run: |
dist/bin/test_streaming_scenario --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200
dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/index_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_base-act4000-cons2000-max10000 --tags_file data/index_stream.after-streaming-act4000-cons2000-max10000.tags
dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1
- name: test a streaming index
if: success() || failure()
run: |
dist/bin/test_streaming_scenario --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200
dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/index_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_base-act4000-cons2000-max10000 --tags_file data/index_stream.after-streaming-act4000-cons2000-max10000.tags
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1
- name: build and search an incremental index (float)
if: success() || failure()
run: |
dist/bin/test_insert_deletes_consolidate --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 3.2;
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_random10D_1K-conc-2500-7500 --tags_file data/index_ins_del.after-concurrent-delete-del2500-7500.tags
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_ins_del --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_random10D_1K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
- name: build and search an incremental index (int8)
if: success() || failure()
run: |
dist/bin/test_insert_deletes_consolidate --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200
dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/index_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_random10D_1K-conc-2500-7500 --tags_file data/index_ins_del.after-concurrent-delete-del2500-7500.tags
dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_ins_del --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/gt100_random10D_1K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
- name: build and search an incremental index (uint8)
if: success() || failure()
run: |
dist/bin/test_insert_deletes_consolidate --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200
dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/index_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_random10D_10K-conc-2500-7500 --tags_file data/index_ins_del.after-concurrent-delete-del2500-7500.tags
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_ins_del --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/gt100_random10D_10K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
- name: upload data and bin
uses: actions/upload-artifact@v4
with:
name: dynamic-${{matrix.os}}
path: |
./dist/**
./data/**

View File

@@ -0,0 +1,81 @@
name: In-Memory Without PQ
on: [workflow_call]
jobs:
acceptance-tests-mem-no-pq:
name: In-Mem, Without PQ
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-2019, windows-latest]
runs-on: ${{matrix.os}}
defaults:
run:
shell: bash
steps:
- name: Checkout repository
if: ${{ runner.os == 'Linux' }}
uses: actions/checkout@v3
with:
fetch-depth: 1
- name: Checkout repository
if: ${{ runner.os == 'Windows' }}
uses: actions/checkout@v3
with:
fetch-depth: 1
submodules: true
- name: DiskANN Build CLI Applications
uses: ./.github/actions/build
- name: Generate Data
uses: ./.github/actions/generate-random
- name: build and search in-memory index with L2 metrics (float)
if: success() || failure()
run: |
dist/bin/build_memory_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
- name: build and search in-memory index with L2 metrics (int8)
if: success() || failure()
run: |
dist/bin/build_memory_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0
dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0 --query_file data/rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
- name: build and search in-memory index with L2 metrics (uint8)
if: success() || failure()
run: |
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0 --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
- name: Searching with fast_l2 distance function (float)
if: runner.os != 'Windows' && (success() || failure())
run: |
dist/bin/search_memory_index --data_type float --dist_fn fast_l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
- name: build and search in-memory index with MIPS metric (float)
if: success() || failure()
run: |
dist/bin/build_memory_index --data_type float --dist_fn mips --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_mips_rand_float_10D_10K_norm1.0
dist/bin/search_memory_index --data_type float --dist_fn mips --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
- name: build and search in-memory index with cosine metric (float)
if: success() || failure()
run: |
dist/bin/build_memory_index --data_type float --dist_fn cosine --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_cosine_rand_float_10D_10K_norm1.0
dist/bin/search_memory_index --data_type float --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
- name: build and search in-memory index with cosine metric (int8)
if: success() || failure()
run: |
dist/bin/build_memory_index --data_type int8 --dist_fn cosine --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_cosine_rand_int8_10D_10K_norm50.0
dist/bin/search_memory_index --data_type int8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0 --query_file data/rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
- name: build and search in-memory index with cosine metric
if: success() || failure()
run: |
dist/bin/build_memory_index --data_type uint8 --dist_fn cosine --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50.0
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0 --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
- name: upload data and bin
uses: actions/upload-artifact@v4
with:
name: in-memory-no-pq-${{matrix.os}}
path: |
./dist/**
./data/**

View File

@@ -0,0 +1,56 @@
name: In-Memory With PQ
on: [workflow_call]
jobs:
acceptance-tests-mem-pq:
name: In-Mem, PQ
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-2019, windows-latest]
runs-on: ${{matrix.os}}
defaults:
run:
shell: bash
steps:
- name: Checkout repository
if: ${{ runner.os == 'Linux' }}
uses: actions/checkout@v3
with:
fetch-depth: 1
- name: Checkout repository
if: ${{ runner.os == 'Windows' }}
uses: actions/checkout@v3
with:
fetch-depth: 1
submodules: true
- name: DiskANN Build CLI Applications
uses: ./.github/actions/build
- name: Generate Data
uses: ./.github/actions/generate-random
- name: build and search in-memory index with L2 metric with PQ based distance comparisons (float)
if: success() || failure()
run: |
dist/bin/build_memory_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0_buildpq5 --build_PQ_bytes 5
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0_buildpq5 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
- name: build and search in-memory index with L2 metrics with PQ base distance comparisons (int8)
if: success() || failure()
run: |
dist/bin/build_memory_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0_buildpq5 --build_PQ_bytes 5
dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0_buildpq5 --query_file data/rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
- name: build and search in-memory index with L2 metrics with PQ base distance comparisons (uint8)
if: success() || failure()
run: |
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --build_PQ_bytes 5
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
- name: upload data and bin
uses: actions/upload-artifact@v4
with:
name: in-memory-pq-${{matrix.os}}
path: |
./dist/**
./data/**

View File

@@ -0,0 +1,120 @@
name: Labels
on: [workflow_call]
jobs:
acceptance-tests-labels:
name: Labels
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-2019, windows-latest]
runs-on: ${{matrix.os}}
defaults:
run:
shell: bash
steps:
- name: Checkout repository
if: ${{ runner.os == 'Linux' }}
uses: actions/checkout@v3
with:
fetch-depth: 1
- name: Checkout repository
if: ${{ runner.os == 'Windows' }}
uses: actions/checkout@v3
with:
fetch-depth: 1
submodules: true
- name: DiskANN Build CLI Applications
uses: ./.github/actions/build
- name: Generate Data
uses: ./.github/actions/generate-random
- name: Generate Labels
run: |
echo "Generating synthetic labels and computing ground truth for filtered search with universal label"
dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/rand_labels_50_10K.txt --distribution_type random
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label 10 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn mips --universal_label 0 --filter_label 10 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --gt_file data/mips_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn cosine --universal_label 0 --filter_label 10 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
echo "Generating synthetic labels with a zipf distribution and computing ground truth for filtered search with universal label"
dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/zipf_labels_50_10K.txt --distribution_type zipf
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn mips --universal_label 0 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/mips_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn cosine --universal_label 0 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
echo "Generating synthetic labels and computing ground truth for filtered search without a universal label"
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel_nouniversal --K 100
dist/bin/generate_synthetic_labels --num_labels 10 --num_points 1000 --output_file data/query_labels_1K.txt --distribution_type one_per_point
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label_file data/query_labels_1K.txt --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/combined_l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
- name: build and search in-memory index with labels using L2 and Cosine metrics (random distributed labels)
if: success() || failure()
run: |
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel
dist/bin/build_memory_index --data_type uint8 --dist_fn cosine --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50_wlabel
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --filter_label 10 --fail_if_recall_below 70 --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
echo "Searching without filters"
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64
- name: build and search disk index with labels using L2 and Cosine metrics (random distributed labels)
if: success() || failure()
run: |
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --universal_label 0 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50_wlabel -R 32 -L 5 -B 0.00003 -M 1
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 50 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50_wlabel --result_path temp --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: build and search in-memory index with labels using L2 and Cosine metrics (zipf distributed labels)
if: success() || failure()
run: |
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel
dist/bin/build_memory_index --data_type uint8 --dist_fn cosine --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_cosine_zipf_uint8_10D_10K_norm50_wlabel
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 5 --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --filter_label 5 --fail_if_recall_below 70 --index_path_prefix data/index_cosine_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
echo "Searching without filters"
dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
dist/bin/compute_groundtruth --data_type uint8 --dist_fn cosine --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_cosine_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64
- name: build and search disk index with labels using L2 and Cosine metrics (zipf distributed labels)
if: success() || failure()
run: |
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --universal_label 0 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel -R 32 -L 5 -B 0.00003 -M 1
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --filter_label 5 --fail_if_recall_below 50 --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel --result_path temp --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name : build and search in-memory and disk index (without universal label, zipf distributed)
if: success() || failure()
run: |
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal -R 32 -L 5 -B 0.00003 -M 1
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 5 --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel_nouniversal -L 16 32
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --filter_label 5 --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal --result_path temp --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel_nouniversal --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: Generate combined GT for each query with a separate label and search
if: success() || failure()
run: |
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --query_filters_file data/query_labels_1K.txt --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/combined_l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
- name: build and search in-memory index with pq_dist of 5 with 10 dimensions
if: success() || failure()
run: |
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --build_PQ_bytes 5
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
- name: Build and search stitched vamana with random and zipf distributed labels
if: success() || failure()
run: |
dist/bin/build_stitched_index --num_threads 48 --data_type uint8 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt -R 32 -L 100 --alpha 1.2 --stitched_R 64 --index_path_prefix data/stit_rand_32_100_64_new --universal_label 0
dist/bin/build_stitched_index --num_threads 48 --data_type uint8 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt -R 32 -L 100 --alpha 1.2 --stitched_R 64 --index_path_prefix data/stit_zipf_32_100_64_new --universal_label 0
dist/bin/search_memory_index --num_threads 48 --data_type uint8 --dist_fn l2 --filter_label 10 --index_path_prefix data/stit_rand_32_100_64_new --query_file data/rand_uint8_10D_1K_norm50.0.bin --result_path data/rand_stit_96_10_90_new --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -K 10 -L 16 32 150
dist/bin/search_memory_index --num_threads 48 --data_type uint8 --dist_fn l2 --filter_label 5 --index_path_prefix data/stit_zipf_32_100_64_new --query_file data/rand_uint8_10D_1K_norm50.0.bin --result_path data/zipf_stit_96_10_90_new --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -K 10 -L 16 32 150
- name: upload data and bin
if: success() || failure()
uses: actions/upload-artifact@v4
with:
name: labels-${{matrix.os}}
path: |
./dist/**
./data/**

View File

@@ -0,0 +1,60 @@
name: Disk With PQ
on: [workflow_call]
jobs:
acceptance-tests-disk-pq:
name: Disk, PQ
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-2019, windows-latest]
runs-on: ${{matrix.os}}
defaults:
run:
shell: bash
steps:
- name: Checkout repository
if: ${{ runner.os == 'Linux' }}
uses: actions/checkout@v3
with:
fetch-depth: 1
- name: Checkout repository
if: ${{ runner.os == 'Windows' }}
uses: actions/checkout@v3
with:
fetch-depth: 1
submodules: true
- name: DiskANN Build CLI Applications
uses: ./.github/actions/build
- name: Generate Data
uses: ./.github/actions/generate-high-dim-random
- name: build and search disk index (1020D, one shot graph build, L2, no diskPQ) (float)
if: success() || failure()
run: |
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_1020D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_1020D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_1020D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_1020D_1K_norm1.0.bin --gt_file data/l2_rand_float_1020D_5K_norm1.0_1020D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16
#- name: build and search disk index (1024D, one shot graph build, L2, no diskPQ) (float)
# if: success() || failure()
# run: |
# dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_1024D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_1024D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
# dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_1024D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_1024D_1K_norm1.0.bin --gt_file data/l2_rand_float_1024D_5K_norm1.0_1024D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16
- name: build and search disk index (1536D, one shot graph build, L2, no diskPQ) (float)
if: success() || failure()
run: |
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_1536D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_1536D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_1536D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_1536D_1K_norm1.0.bin --gt_file data/l2_rand_float_1536D_5K_norm1.0_1536D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16
- name: build and search disk index (4096D, one shot graph build, L2, no diskPQ) (int8)
if: success() || failure()
run: |
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_4096D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_int8_4096D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_4096D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_int8_4096D_1K_norm1.0.bin --gt_file data/l2_rand_int8_4096D_5K_norm1.0_4096D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16
- name: upload data and bin
uses: actions/upload-artifact@v4
with:
name: multi-sector-disk-pq-${{matrix.os}}
path: |
./dist/**
./data/**

View File

@@ -0,0 +1,26 @@
name: DiskANN Nightly Performance Metrics
on:
schedule:
- cron: "41 14 * * *" # 14:41 UTC, 7:41 PDT, 8:41 PST, 08:11 IST
jobs:
perf-test:
name: Run Perf Test from main
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
with:
fetch-depth: 1
- name: Build Perf Container
run: |
docker build --build-arg GIT_COMMIT_ISH="$GITHUB_SHA" -t perf -f scripts/perf/Dockerfile scripts
- name: Performance Tests
run: |
mkdir metrics
docker run -v ./metrics:/app/logs perf &> ./metrics/combined_stdouterr.log
- name: Upload Metrics Logs
uses: actions/upload-artifact@v4
with:
name: metrics-${{matrix.os}}
path: |
./metrics/**

View File

@@ -0,0 +1,35 @@
name: DiskANN Pull Request Build and Test
on: [pull_request]
jobs:
common:
strategy:
fail-fast: true
name: DiskANN Common Build Checks
uses: ./.github/workflows/common.yml
unit-tests:
name: Unit tests
uses: ./.github/workflows/unit-tests.yml
in-mem-pq:
name: In-Memory with PQ
uses: ./.github/workflows/in-mem-pq.yml
in-mem-no-pq:
name: In-Memory without PQ
uses: ./.github/workflows/in-mem-no-pq.yml
disk-pq:
name: Disk with PQ
uses: ./.github/workflows/disk-pq.yml
multi-sector-disk-pq:
name: Multi-sector Disk with PQ
uses: ./.github/workflows/multi-sector-disk-pq.yml
labels:
name: Labels
uses: ./.github/workflows/labels.yml
dynamic:
name: Dynamic
uses: ./.github/workflows/dynamic.yml
dynamic-labels:
name: Dynamic Labels
uses: ./.github/workflows/dynamic-labels.yml
python:
name: Python
uses: ./.github/workflows/build-python.yml

View File

@@ -0,0 +1,50 @@
name: DiskANN Push Build
on: [push]
jobs:
common:
strategy:
fail-fast: true
name: DiskANN Common Build Checks
uses: ./.github/workflows/common.yml
build-documentation:
permissions:
contents: write
strategy:
fail-fast: true
name: DiskANN Build Documentation
uses: ./.github/workflows/build-python-pdoc.yml
build:
strategy:
fail-fast: false
matrix:
os: [ ubuntu-latest, windows-2019, windows-latest ]
name: Build for ${{matrix.os}}
runs-on: ${{matrix.os}}
defaults:
run:
shell: bash
steps:
- name: Checkout repository
if: ${{ runner.os == 'Linux' }}
uses: actions/checkout@v3
with:
fetch-depth: 1
- name: Checkout repository
if: ${{ runner.os == 'Windows' }}
uses: actions/checkout@v3
with:
fetch-depth: 1
submodules: true
- name: Build diskannpy dependency tree
run: |
pip install diskannpy pipdeptree
echo "dependencies" > dependencies_${{ matrix.os }}.txt
pipdeptree >> dependencies_${{ matrix.os }}.txt
- name: Archive diskannpy dependencies artifact
uses: actions/upload-artifact@v4
with:
name: dependencies_${{ matrix.os }}
path: |
dependencies_${{ matrix.os }}.txt
- name: DiskANN Build CLI Applications
uses: ./.github/actions/build

View File

@@ -0,0 +1,43 @@
name: Build and Release Python Wheels
on:
release:
types: [published]
jobs:
python-release-wheels:
name: Python
uses: ./.github/workflows/build-python.yml
build-documentation:
strategy:
fail-fast: true
name: DiskANN Build Documentation
uses: ./.github/workflows/build-python-pdoc.yml
release:
permissions:
contents: write
runs-on: ubuntu-latest
needs: python-release-wheels
steps:
- uses: actions/download-artifact@v3
with:
name: wheels
path: dist/
- name: Generate SHA256 files for each wheel
run: |
sha256sum dist/*.whl > checksums.txt
cat checksums.txt
- uses: actions/setup-python@v3
- name: Install twine
run: python -m pip install twine
- name: Publish with twine
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
twine upload dist/*.whl
- name: Update release with SHA256 and Artifacts
uses: softprops/action-gh-release@v1
with:
token: ${{ secrets.GITHUB_TOKEN }}
files: |
dist/*.whl
checksums.txt

View File

@@ -0,0 +1,32 @@
name: Unit Tests
on: [workflow_call]
jobs:
acceptance-tests-labels:
name: Unit Tests
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-2019, windows-latest]
runs-on: ${{matrix.os}}
defaults:
run:
shell: bash
steps:
- name: Checkout repository
if: ${{ runner.os == 'Linux' }}
uses: actions/checkout@v3
with:
fetch-depth: 1
- name: Checkout repository
if: ${{ runner.os == 'Windows' }}
uses: actions/checkout@v3
with:
fetch-depth: 1
submodules: true
- name: DiskANN Build CLI Applications
uses: ./.github/actions/build
- name: Run Unit Tests
run: |
cd build
ctest -C Release

View File

@@ -0,0 +1,384 @@
## Ignore Visual Studio temporary files, build results, and
## files generated by popular Visual Studio add-ons.
##
## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
# User-specific files
*.rsuser
*.suo
*.user
*.userosscache
*.sln.docstates
# User-specific files (MonoDevelop/Xamarin Studio)
*.userprefs
# Mono auto generated files
mono_crash.*
# Build results
[Dd]ebug/
[Dd]ebugPublic/
[Rr]elease/
[Rr]eleases/
x64/
x86/
[Aa][Rr][Mm]/
[Aa][Rr][Mm]64/
bld/
[Bb]in/
[Oo]bj/
[Ll]og/
[Ll]ogs/
# Visual Studio 2015/2017 cache/options directory
.vs/
# Uncomment if you have tasks that create the project's static files in wwwroot
#wwwroot/
# Visual Studio 2017 auto generated files
Generated\ Files/
# MSTest test Results
[Tt]est[Rr]esult*/
[Bb]uild[Ll]og.*
# NUnit
*.VisualState.xml
TestResult.xml
nunit-*.xml
# Build Results of an ATL Project
[Dd]ebugPS/
[Rr]eleasePS/
dlldata.c
# Benchmark Results
BenchmarkDotNet.Artifacts/
# .NET Core
project.lock.json
project.fragment.lock.json
artifacts/
# StyleCop
StyleCopReport.xml
# Files built by Visual Studio
*_i.c
*_p.c
*_h.h
*.ilk
*.meta
*.obj
*.iobj
*.pch
*.pdb
*.ipdb
*.pgc
*.pgd
*.rsp
*.sbr
*.tlb
*.tli
*.tlh
*.tmp
*.tmp_proj
*_wpftmp.csproj
*.log
*.vspscc
*.vssscc
.builds
*.pidb
*.svclog
*.scc
# Chutzpah Test files
_Chutzpah*
# Visual C++ cache files
ipch/
*.aps
*.ncb
*.opendb
*.opensdf
*.sdf
*.cachefile
*.VC.db
*.VC.VC.opendb
# Visual Studio profiler
*.psess
*.vsp
*.vspx
*.sap
# Visual Studio Trace Files
*.e2e
# TFS 2012 Local Workspace
$tf/
# Guidance Automation Toolkit
*.gpState
# ReSharper is a .NET coding add-in
_ReSharper*/
*.[Rr]e[Ss]harper
*.DotSettings.user
# TeamCity is a build add-in
_TeamCity*
# DotCover is a Code Coverage Tool
*.dotCover
# AxoCover is a Code Coverage Tool
.axoCover/*
!.axoCover/settings.json
# Visual Studio code coverage results
*.coverage
*.coveragexml
# NCrunch
_NCrunch_*
.*crunch*.local.xml
nCrunchTemp_*
# MightyMoose
*.mm.*
AutoTest.Net/
# Web workbench (sass)
.sass-cache/
# Installshield output folder
[Ee]xpress/
# DocProject is a documentation generator add-in
DocProject/buildhelp/
DocProject/Help/*.HxT
DocProject/Help/*.HxC
DocProject/Help/*.hhc
DocProject/Help/*.hhk
DocProject/Help/*.hhp
DocProject/Help/Html2
DocProject/Help/html
# Click-Once directory
publish/
# Publish Web Output
*.[Pp]ublish.xml
*.azurePubxml
# Note: Comment the next line if you want to checkin your web deploy settings,
# but database connection strings (with potential passwords) will be unencrypted
*.pubxml
*.publishproj
# Microsoft Azure Web App publish settings. Comment the next line if you want to
# checkin your Azure Web App publish settings, but sensitive information contained
# in these scripts will be unencrypted
PublishScripts/
# NuGet Packages
*.nupkg
# NuGet Symbol Packages
*.snupkg
# The packages folder can be ignored because of Package Restore
**/[Pp]ackages/*
# except build/, which is used as an MSBuild target.
!**/[Pp]ackages/build/
# Uncomment if necessary however generally it will be regenerated when needed
#!**/[Pp]ackages/repositories.config
# NuGet v3's project.json files produces more ignorable files
*.nuget.props
*.nuget.targets
# Microsoft Azure Build Output
csx/
*.build.csdef
# Microsoft Azure Emulator
ecf/
rcf/
# Windows Store app package directories and files
AppPackages/
BundleArtifacts/
Package.StoreAssociation.xml
_pkginfo.txt
*.appx
*.appxbundle
*.appxupload
# Visual Studio cache files
# files ending in .cache can be ignored
*.[Cc]ache
# but keep track of directories ending in .cache
!?*.[Cc]ache/
# Others
ClientBin/
~$*
*~
*.dbmdl
*.dbproj.schemaview
*.jfm
*.pfx
*.publishsettings
orleans.codegen.cs
# Including strong name files can present a security risk
# (https://github.com/github/gitignore/pull/2483#issue-259490424)
#*.snk
# Since there are multiple workflows, uncomment next line to ignore bower_components
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
#bower_components/
# RIA/Silverlight projects
Generated_Code/
# Backup & report files from converting an old project file
# to a newer Visual Studio version. Backup files are not needed,
# because we have git ;-)
_UpgradeReport_Files/
Backup*/
UpgradeLog*.XML
UpgradeLog*.htm
ServiceFabricBackup/
*.rptproj.bak
# SQL Server files
*.mdf
*.ldf
*.ndf
# Business Intelligence projects
*.rdl.data
*.bim.layout
*.bim_*.settings
*.rptproj.rsuser
*- [Bb]ackup.rdl
*- [Bb]ackup ([0-9]).rdl
*- [Bb]ackup ([0-9][0-9]).rdl
# Microsoft Fakes
FakesAssemblies/
# GhostDoc plugin setting file
*.GhostDoc.xml
# Node.js Tools for Visual Studio
.ntvs_analysis.dat
node_modules/
# Visual Studio 6 build log
*.plg
# Visual Studio 6 workspace options file
*.opt
# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
*.vbw
# Visual Studio LightSwitch build output
**/*.HTMLClient/GeneratedArtifacts
**/*.DesktopClient/GeneratedArtifacts
**/*.DesktopClient/ModelManifest.xml
**/*.Server/GeneratedArtifacts
**/*.Server/ModelManifest.xml
_Pvt_Extensions
# Paket dependency manager
.paket/paket.exe
paket-files/
# FAKE - F# Make
.fake/
# CodeRush personal settings
.cr/personal
# Python Tools for Visual Studio (PTVS)
__pycache__/
*.pyc
# Cake - Uncomment if you are using it
# tools/**
# !tools/packages.config
# Tabs Studio
*.tss
# Telerik's JustMock configuration file
*.jmconfig
# BizTalk build output
*.btp.cs
*.btm.cs
*.odx.cs
*.xsd.cs
# OpenCover UI analysis results
OpenCover/
# Azure Stream Analytics local run output
ASALocalRun/
# MSBuild Binary and Structured Log
*.binlog
# NVidia Nsight GPU debugger configuration file
*.nvuser
# MFractors (Xamarin productivity tool) working folder
.mfractor/
# Local History for Visual Studio
.localhistory/
# BeatPulse healthcheck temp database
healthchecksdb
# Backup folder for Package Reference Convert tool in Visual Studio 2017
MigrationBackup/
# Ionide (cross platform F# VS Code tools) working folder
.ionide/
/vcproj/nsg/x64/Debug/nsg.Build.CppClean.log
/vcproj/test_recall/x64/Debug/test_recall.Build.CppClean.log
/vcproj/test_recall/test_recall.vcxproj.user
/.vs
/out/build/x64-Debug
cscope*
build/
build_linux/
!.github/actions/build
# jetbrains specific stuff
.idea/
cmake-build-debug/
#python extension module ignores
python/diskannpy.egg-info/
python/dist/
**/*.egg-info
wheelhouse/*
dist/*
venv*/**
*.swp
gperftools
# Rust
rust/target
python/src/*.so
compile_commands.json

View File

@@ -0,0 +1,3 @@
[submodule "gperftools"]
path = gperftools
url = https://github.com/gperftools/gperftools.git

View File

@@ -0,0 +1,563 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
# Parameters:
#
# BOOST_ROOT:
# Specify root of the Boost library if Boost cannot be auto-detected. On Windows, a fallback to a
# downloaded nuget version will be used if Boost cannot be found.
#
# DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS:
# This is a work-in-progress feature, not completed yet. The core DiskANN library will be split into
# build-related and search-related functionality. In build-related functionality, when using tcmalloc,
# it's possible to release memory that's free but reserved by tcmalloc. Setting this to true enables
# such behavior.
# Contact for this feature: gopalrs.
# Some variables like MSVC are defined only after project(), so put that first.
cmake_minimum_required(VERSION 3.20)
project(diskann)
#Set option to use tcmalloc
option(USE_TCMALLOC "Use tcmalloc from gperftools" ON)
# set tcmalloc to false when on macos
if(APPLE)
set(USE_TCMALLOC OFF)
endif()
option(PYBIND "Build with Python bindings" ON)
if(PYBIND)
# Find Python
find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -c "import pybind11; print(pybind11.get_cmake_dir())"
OUTPUT_VARIABLE pybind11_DIR
OUTPUT_STRIP_TRAILING_WHITESPACE
)
find_package(pybind11 CONFIG REQUIRED)
message(STATUS "Python include dirs: ${Python_INCLUDE_DIRS}")
message(STATUS "Pybind11 include dirs: ${pybind11_INCLUDE_DIRS}")
# Add pybind11 include directories
include_directories(SYSTEM ${pybind11_INCLUDE_DIRS} ${Python_INCLUDE_DIRS})
# Add compilation definitions
add_definitions(-DPYBIND11_EMBEDDED)
# Set visibility flags
if(NOT MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
endif()
endif()
set(CMAKE_STANDARD 17)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# if(NOT MSVC)
# set(CMAKE_CXX_COMPILER g++)
# endif()
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake;${CMAKE_MODULE_PATH}")
# Install nuget packages for dependencies.
if (MSVC)
find_program(NUGET_EXE NAMES nuget)
if (NOT NUGET_EXE)
message(FATAL_ERROR "Cannot find nuget command line tool.\nPlease install it from e.g. https://www.nuget.org/downloads")
endif()
set(DISKANN_MSVC_PACKAGES_CONFIG ${CMAKE_BINARY_DIR}/packages.config)
set(DISKANN_MSVC_PACKAGES ${CMAKE_BINARY_DIR}/packages)
message(STATUS "Invoking nuget to download Boost, OpenMP and MKL dependencies...")
configure_file(${PROJECT_SOURCE_DIR}/windows/packages.config.in ${DISKANN_MSVC_PACKAGES_CONFIG})
exec_program(${NUGET_EXE} ARGS install \"${DISKANN_MSVC_PACKAGES_CONFIG}\" -ExcludeVersion -OutputDirectory \"${DISKANN_MSVC_PACKAGES}\")
if (RESTAPI)
set(DISKANN_MSVC_RESTAPI_PACKAGES_CONFIG ${CMAKE_BINARY_DIR}/restapi/packages.config)
configure_file(${PROJECT_SOURCE_DIR}/windows/packages_restapi.config.in ${DISKANN_MSVC_RESTAPI_PACKAGES_CONFIG})
exec_program(${NUGET_EXE} ARGS install \"${DISKANN_MSVC_RESTAPI_PACKAGES_CONFIG}\" -ExcludeVersion -OutputDirectory \"${DISKANN_MSVC_PACKAGES}\")
endif()
message(STATUS "Finished setting up nuget dependencies")
endif()
include_directories(${PROJECT_SOURCE_DIR}/include)
include(FetchContent)
if(USE_TCMALLOC)
FetchContent_Declare(
tcmalloc
GIT_REPOSITORY https://github.com/google/tcmalloc.git
GIT_TAG origin/master # or specify a particular version or commit
)
FetchContent_MakeAvailable(tcmalloc)
endif()
if(NOT PYBIND)
set(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS ON)
endif()
# It's necessary to include tcmalloc headers only if calling into MallocExtension interface.
# For using tcmalloc in DiskANN tools, it's enough to just link with tcmalloc.
if (DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS)
include_directories(${tcmalloc_SOURCE_DIR}/src)
if (MSVC)
include_directories(${tcmalloc_SOURCE_DIR}/src/windows)
endif()
endif()
#OpenMP
if (MSVC)
# Do not use find_package here since it would use VisualStudio's built-in OpenMP, but MKL libraries
# refer to Intel's OpenMP.
#
# No extra settings are needed for compilation: it only needs /openmp flag which is set further below,
# in the common MSVC compiler options block.
include_directories(BEFORE "${DISKANN_MSVC_PACKAGES}/intelopenmp.devel.win/lib/native/include")
link_libraries("${DISKANN_MSVC_PACKAGES}/intelopenmp.devel.win/lib/native/win-x64/libiomp5md.lib")
set(OPENMP_WINDOWS_RUNTIME_FILES
"${DISKANN_MSVC_PACKAGES}/intelopenmp.redist.win/runtimes/win-x64/native/libiomp5md.dll"
"${DISKANN_MSVC_PACKAGES}/intelopenmp.redist.win/runtimes/win-x64/native/libiomp5md.pdb")
elseif(APPLE)
# Check if we're building Python bindings
if(PYBIND)
# First look for PyTorch's OpenMP to avoid conflicts
execute_process(
COMMAND ${Python_EXECUTABLE} -c "import os; import torch; print(os.path.join(os.path.dirname(torch.__file__), 'lib', 'libomp.dylib'))"
RESULT_VARIABLE TORCH_PATH_RESULT
OUTPUT_VARIABLE TORCH_LIBOMP_PATH
OUTPUT_STRIP_TRAILING_WHITESPACE
ERROR_QUIET
)
execute_process(
COMMAND brew --prefix libomp
OUTPUT_VARIABLE LIBOMP_ROOT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(EXISTS "${TORCH_LIBOMP_PATH}")
message(STATUS "Found PyTorch's libomp: ${TORCH_LIBOMP_PATH}")
set(OpenMP_CXX_FLAGS "-Xclang -fopenmp")
set(OpenMP_C_FLAGS "-Xclang -fopenmp")
set(OpenMP_CXX_LIBRARIES "${TORCH_LIBOMP_PATH}")
set(OpenMP_C_LIBRARIES "${TORCH_LIBOMP_PATH}")
set(OpenMP_FOUND TRUE)
include_directories(${LIBOMP_ROOT}/include)
# Set compiler flags and link libraries
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
link_libraries("${TORCH_LIBOMP_PATH}")
else()
message(STATUS "No PyTorch's libomp found, falling back to normal OpenMP detection")
# Fallback to normal OpenMP detection
execute_process(
COMMAND brew --prefix libomp
OUTPUT_VARIABLE LIBOMP_ROOT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
set(OpenMP_ROOT "${LIBOMP_ROOT}")
find_package(OpenMP)
if (OPENMP_FOUND)
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
link_libraries(OpenMP::OpenMP_CXX)
else()
message(FATAL_ERROR "No OpenMP support")
endif()
endif()
else()
# Regular OpenMP setup for non-Python builds
execute_process(
COMMAND brew --prefix libomp
OUTPUT_VARIABLE LIBOMP_ROOT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
set(OpenMP_ROOT "${LIBOMP_ROOT}")
find_package(OpenMP)
if (OPENMP_FOUND)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
link_libraries(OpenMP::OpenMP_CXX)
else()
message(FATAL_ERROR "No OpenMP support")
endif()
endif()
else()
find_package(OpenMP)
if (OPENMP_FOUND)
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
else()
message(FATAL_ERROR "No OpenMP support")
endif()
endif()
# DiskANN core uses header-only libraries. Only DiskANN tools need program_options which has a linker library,
# but its size is small. Reduce number of dependent DLLs by linking statically.
if (MSVC)
set(Boost_USE_STATIC_LIBS ON)
endif()
if(NOT MSVC)
find_package(Boost COMPONENTS program_options)
endif()
# For Windows, fall back to nuget version if find_package didn't find it.
if (MSVC AND NOT Boost_FOUND)
set(DISKANN_BOOST_INCLUDE "${DISKANN_MSVC_PACKAGES}/boost/lib/native/include")
# Multi-threaded static library.
set(PROGRAM_OPTIONS_LIB_PATTERN "${DISKANN_MSVC_PACKAGES}/boost_program_options-vc${MSVC_TOOLSET_VERSION}/lib/native/libboost_program_options-vc${MSVC_TOOLSET_VERSION}-mt-x64-*.lib")
file(GLOB DISKANN_BOOST_PROGRAM_OPTIONS_LIB ${PROGRAM_OPTIONS_LIB_PATTERN})
set(PROGRAM_OPTIONS_DLIB_PATTERN "${DISKANN_MSVC_PACKAGES}/boost_program_options-vc${MSVC_TOOLSET_VERSION}/lib/native/libboost_program_options-vc${MSVC_TOOLSET_VERSION}-mt-gd-x64-*.lib")
file(GLOB DISKANN_BOOST_PROGRAM_OPTIONS_DLIB ${PROGRAM_OPTIONS_DLIB_PATTERN})
if (EXISTS ${DISKANN_BOOST_INCLUDE} AND EXISTS ${DISKANN_BOOST_PROGRAM_OPTIONS_LIB} AND EXISTS ${DISKANN_BOOST_PROGRAM_OPTIONS_DLIB})
set(Boost_FOUND ON)
set(Boost_INCLUDE_DIR ${DISKANN_BOOST_INCLUDE})
add_library(Boost::program_options STATIC IMPORTED)
set_target_properties(Boost::program_options PROPERTIES IMPORTED_LOCATION_RELEASE "${DISKANN_BOOST_PROGRAM_OPTIONS_LIB}")
set_target_properties(Boost::program_options PROPERTIES IMPORTED_LOCATION_DEBUG "${DISKANN_BOOST_PROGRAM_OPTIONS_DLIB}")
message(STATUS "Falling back to using Boost from the nuget package")
else()
message(WARNING "Couldn't find Boost. Was looking for ${DISKANN_BOOST_INCLUDE} and ${PROGRAM_OPTIONS_LIB_PATTERN}")
endif()
endif()
if (NOT Boost_FOUND)
message(FATAL_ERROR "Couldn't find Boost dependency")
endif()
include_directories(${Boost_INCLUDE_DIR})
#MKL Config
if (MSVC)
# Only the DiskANN DLL and one of the tools need MKL libraries. Additionally, only a small part of MKL is used.
# Given that and given that MKL DLLs are huge, use static linking to end up with no MKL DLL dependencies and with
# significantly smaller disk footprint.
#
# The compile options are not modified as there's already an unconditional -DMKL_ILP64 define below
# for all architectures, which is all that's needed.
set(DISKANN_MKL_INCLUDE_DIRECTORIES "${DISKANN_MSVC_PACKAGES}/intelmkl.static.win-x64/lib/native/include")
set(DISKANN_MKL_LIB_PATH "${DISKANN_MSVC_PACKAGES}/intelmkl.static.win-x64/lib/native/win-x64")
set(DISKANN_MKL_LINK_LIBRARIES
"${DISKANN_MKL_LIB_PATH}/mkl_intel_ilp64.lib"
"${DISKANN_MKL_LIB_PATH}/mkl_core.lib"
"${DISKANN_MKL_LIB_PATH}/mkl_intel_thread.lib")
elseif(APPLE)
# no mkl on non-intel devices
find_library(ACCELERATE_LIBRARY Accelerate)
message(STATUS "Found Accelerate (${ACCELERATE_LIBRARY})")
set(DISKANN_ACCEL_LINK_OPTIONS ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
# expected path for manual intel mkl installs
set(POSSIBLE_OMP_PATHS "/opt/intel/oneapi/compiler/2025.0/lib/libiomp5.so;/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin/libiomp5.so;/usr/lib/x86_64-linux-gnu/libiomp5.so;/opt/intel/lib/intel64_lin/libiomp5.so")
foreach(POSSIBLE_OMP_PATH ${POSSIBLE_OMP_PATHS})
if (EXISTS ${POSSIBLE_OMP_PATH})
get_filename_component(OMP_PATH ${POSSIBLE_OMP_PATH} DIRECTORY)
endif()
endforeach()
if(NOT OMP_PATH)
message(FATAL_ERROR "Could not find Intel OMP in standard locations; use -DOMP_PATH to specify the install location for your environment")
endif()
link_directories(${OMP_PATH})
set(POSSIBLE_MKL_LIB_PATHS "/opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_core.so;/usr/lib/x86_64-linux-gnu/libmkl_core.so;/opt/intel/mkl/lib/intel64/libmkl_core.so")
foreach(POSSIBLE_MKL_LIB_PATH ${POSSIBLE_MKL_LIB_PATHS})
if (EXISTS ${POSSIBLE_MKL_LIB_PATH})
get_filename_component(MKL_PATH ${POSSIBLE_MKL_LIB_PATH} DIRECTORY)
endif()
endforeach()
set(POSSIBLE_MKL_INCLUDE_PATHS "/opt/intel/oneapi/mkl/latest/include;/usr/include/mkl;/opt/intel/mkl/include/;")
foreach(POSSIBLE_MKL_INCLUDE_PATH ${POSSIBLE_MKL_INCLUDE_PATHS})
if (EXISTS ${POSSIBLE_MKL_INCLUDE_PATH})
set(MKL_INCLUDE_PATH ${POSSIBLE_MKL_INCLUDE_PATH})
endif()
endforeach()
if(NOT MKL_PATH)
message(FATAL_ERROR "Could not find Intel MKL in standard locations; use -DMKL_PATH to specify the install location for your environment")
elseif(NOT MKL_INCLUDE_PATH)
message(FATAL_ERROR "Could not find Intel MKL in standard locations; use -DMKL_INCLUDE_PATH to specify the install location for headers for your environment")
endif()
if (EXISTS ${MKL_PATH}/libmkl_def.so.2)
set(MKL_DEF_SO ${MKL_PATH}/libmkl_def.so.2)
elseif(EXISTS ${MKL_PATH}/libmkl_def.so)
set(MKL_DEF_SO ${MKL_PATH}/libmkl_def.so)
else()
message(FATAL_ERROR "Despite finding MKL, libmkl_def.so was not found in expected locations.")
endif()
link_directories(${MKL_PATH})
include_directories(${MKL_INCLUDE_PATH})
# compile flags and link libraries
# if gcc/g++
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
add_compile_options(-m64 -Wl,--no-as-needed)
endif()
if (NOT PYBIND)
link_libraries(mkl_intel_ilp64 mkl_intel_thread mkl_core iomp5 pthread m dl)
else()
# static linking for python so as to minimize customer dependency issues
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
# In debug mode, use dynamic linking to ensure all symbols are available
link_libraries(mkl_intel_ilp64 mkl_intel_thread mkl_core ${MKL_DEF_SO} iomp5 pthread m dl)
else()
# In release mode, use static linking to minimize dependencies
link_libraries(
${MKL_PATH}/libmkl_intel_ilp64.a
${MKL_PATH}/libmkl_intel_thread.a
${MKL_PATH}/libmkl_core.a
${MKL_DEF_SO}
iomp5
pthread
m
dl
)
endif()
endif()
add_definitions(-DMKL_ILP64)
endif()
# Section for tcmalloc. The DiskANN tools are always linked to tcmalloc. For Windows, they also need to
# force-include the _tcmalloc symbol for enabling tcmalloc.
#
# The DLL itself needs to be linked to tcmalloc only if DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS
# is enabled.
if(USE_TCMALLOC)
if (MSVC)
if (NOT EXISTS "${PROJECT_SOURCE_DIR}/gperftools/gperftools.sln")
message(FATAL_ERROR "The gperftools submodule was not found. "
"Please check-out git submodules by doing 'git submodule init' followed by 'git submodule update'")
endif()
set(TCMALLOC_LINK_LIBRARY "${PROJECT_SOURCE_DIR}/gperftools/x64/Release-Patch/libtcmalloc_minimal.lib")
set(TCMALLOC_WINDOWS_RUNTIME_FILES
"${PROJECT_SOURCE_DIR}/gperftools/x64/Release-Patch/libtcmalloc_minimal.dll"
"${PROJECT_SOURCE_DIR}/gperftools/x64/Release-Patch/libtcmalloc_minimal.pdb")
# Tell CMake how to build the tcmalloc linker library from the submodule.
add_custom_target(build_libtcmalloc_minimal DEPENDS ${TCMALLOC_LINK_LIBRARY})
add_custom_command(OUTPUT ${TCMALLOC_LINK_LIBRARY}
COMMAND ${CMAKE_VS_MSBUILD_COMMAND} gperftools.sln /m /nologo
/t:libtcmalloc_minimal /p:Configuration="Release-Patch"
/property:Platform="x64"
/p:PlatformToolset=v${MSVC_TOOLSET_VERSION}
/p:WindowsTargetPlatformVersion=${CMAKE_VS_WINDOWS_TARGET_PLATFORM_VERSION}
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/gperftools)
add_library(libtcmalloc_minimal_for_exe STATIC IMPORTED)
add_library(libtcmalloc_minimal_for_dll STATIC IMPORTED)
set_target_properties(libtcmalloc_minimal_for_dll PROPERTIES
IMPORTED_LOCATION "${TCMALLOC_LINK_LIBRARY}")
set_target_properties(libtcmalloc_minimal_for_exe PROPERTIES
IMPORTED_LOCATION "${TCMALLOC_LINK_LIBRARY}"
INTERFACE_LINK_OPTIONS /INCLUDE:_tcmalloc)
# Ensure libtcmalloc_minimal is built before it's being used.
add_dependencies(libtcmalloc_minimal_for_dll build_libtcmalloc_minimal)
add_dependencies(libtcmalloc_minimal_for_exe build_libtcmalloc_minimal)
set(DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS libtcmalloc_minimal_for_exe)
elseif(APPLE) # ! Inherited from #474, not been adjusted for TCMalloc Removal
execute_process(
COMMAND brew --prefix gperftools
OUTPUT_VARIABLE GPERFTOOLS_PREFIX
OUTPUT_STRIP_TRAILING_WHITESPACE
)
set(DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS "-L${GPERFTOOLS_PREFIX}/lib -ltcmalloc")
elseif(NOT PYBIND)
set(DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS "-ltcmalloc")
endif()
if (DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS)
add_definitions(-DRELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS)
if (MSVC)
set(DISKANN_DLL_TCMALLOC_LINK_OPTIONS libtcmalloc_minimal_for_dll)
endif()
endif()
endif()
if (NOT MSVC AND NOT APPLE)
set(DISKANN_ASYNC_LIB aio)
endif()
#Main compiler/linker settings
if(MSVC)
#language options
add_compile_options(/permissive- /openmp:experimental /Zc:twoPhase- /Zc:inline /WX- /std:c++17 /Gd /W3 /MP /Zi /FC /nologo)
#code generation options
add_compile_options(/arch:AVX2 /fp:fast /fp:except- /EHsc /GS- /Gy)
#optimization options
add_compile_options(/Ot /Oy /Oi)
#path options
add_definitions(-DUSE_AVX2 -DUSE_ACCELERATED_PQ -D_WINDOWS -DNOMINMAX -DUNICODE)
# Linker options. Exclude VCOMP/VCOMPD.LIB which contain VisualStudio's version of OpenMP.
# MKL was linked against Intel's OpenMP and depends on the corresponding DLL.
add_link_options(/NODEFAULTLIB:VCOMP.LIB /NODEFAULTLIB:VCOMPD.LIB /DEBUG:FULL /OPT:REF /OPT:ICF)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/x64/Debug)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/x64/Debug)
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/x64/Debug)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release)
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release)
elseif(APPLE)
set(ENV{TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD} 500000000000)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ftree-vectorize -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free -Xclang -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors -Wno-inconsistent-missing-override -Wno-return-type")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -Ofast -DNDEBUG -ftree-vectorize")
if (NOT PYBIND)
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -Ofast")
if (NOT PORTABLE)
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -mtune=native")
endif()
else()
# -Ofast is not supported in a python extension module
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -fPIC")
endif()
else()
set(ENV{TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD} 500000000000)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2 -mfma -msse2 -ftree-vectorize -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors -DUSE_AVX2 -fPIC")
if(USE_TCMALLOC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free")
endif()
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG")
if (NOT PYBIND)
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -Ofast")
if (NOT PORTABLE)
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -march=native -mtune=native")
endif()
else()
# -Ofast is not supported in a python extension module
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG")
endif()
endif()
add_subdirectory(src)
if (NOT PYBIND)
add_subdirectory(apps)
add_subdirectory(apps/utils)
endif()
if (UNIT_TEST)
enable_testing()
add_subdirectory(tests)
endif()
if (MSVC)
message(STATUS "The ${PROJECT_NAME}.sln has been created, opened it from VisualStudio to build Release or Debug configurations.\n"
"Alternatively, use MSBuild to build:\n\n"
"msbuild.exe ${PROJECT_NAME}.sln /m /nologo /t:Build /p:Configuration=\"Release\" /property:Platform=\"x64\"\n")
endif()
if (RESTAPI)
if (MSVC)
set(DISKANN_CPPRESTSDK "${DISKANN_MSVC_PACKAGES}/cpprestsdk.v142/build/native")
# expected path for apt packaged intel mkl installs
link_libraries("${DISKANN_CPPRESTSDK}/x64/lib/cpprest142_2_10.lib")
include_directories("${DISKANN_CPPRESTSDK}/include")
endif()
add_subdirectory(apps/restapi)
endif()
include(clang-format.cmake)
if(PYBIND)
add_subdirectory(python)
install(TARGETS _diskannpy
DESTINATION leann_backend_diskann
COMPONENT python_modules
)
endif()
###############################################################################
# PROTOBUF SECTION - Corrected to use CONFIG mode explicitly
###############################################################################
set(Protobuf_USE_STATIC_LIBS OFF)
find_package(ZLIB REQUIRED)
find_package(Protobuf REQUIRED)
message(STATUS "Protobuf found: ${Protobuf_VERSION}")
message(STATUS "Protobuf include dirs: ${Protobuf_INCLUDE_DIRS}")
message(STATUS "Protobuf libraries: ${Protobuf_LIBRARIES}")
message(STATUS "Protobuf protoc executable: ${Protobuf_PROTOC_EXECUTABLE}")
include_directories(${Protobuf_INCLUDE_DIRS})
set(PROTO_FILE "${CMAKE_CURRENT_SOURCE_DIR}/../embedding.proto")
protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS ${PROTO_FILE})
set(generated_proto_sources ${PROTO_SRCS})
add_library(proto_embeddings STATIC ${generated_proto_sources})
target_link_libraries(proto_embeddings PUBLIC protobuf::libprotobuf)
target_include_directories(proto_embeddings PUBLIC
${CMAKE_CURRENT_BINARY_DIR}
${Protobuf_INCLUDE_DIRS}
)
target_link_libraries(diskann PRIVATE proto_embeddings protobuf::libprotobuf)
target_include_directories(diskann PRIVATE
${CMAKE_CURRENT_BINARY_DIR}
${Protobuf_INCLUDE_DIRS}
)
target_link_libraries(diskann_s PRIVATE proto_embeddings protobuf::libprotobuf)
target_include_directories(diskann_s PRIVATE
${CMAKE_CURRENT_BINARY_DIR}
${Protobuf_INCLUDE_DIRS}
)
###############################################################################
# ZEROMQ SECTION - REQUIRED
###############################################################################
find_package(ZeroMQ QUIET)
if(NOT ZeroMQ_FOUND)
find_path(ZeroMQ_INCLUDE_DIR zmq.h)
find_library(ZeroMQ_LIBRARY zmq)
if(ZeroMQ_INCLUDE_DIR AND ZeroMQ_LIBRARY)
set(ZeroMQ_FOUND TRUE)
endif()
endif()
if(ZeroMQ_FOUND)
message(STATUS "Found ZeroMQ: ${ZeroMQ_LIBRARY}")
include_directories(${ZeroMQ_INCLUDE_DIR})
target_link_libraries(diskann PRIVATE ${ZeroMQ_LIBRARY})
target_link_libraries(diskann_s PRIVATE ${ZeroMQ_LIBRARY})
add_definitions(-DUSE_ZEROMQ)
else()
message(FATAL_ERROR "ZeroMQ is required but not found. Please install ZeroMQ and try again.")
endif()
target_link_libraries(diskann ${PYBIND11_LIBRARIES})
target_link_libraries(diskann_s ${PYBIND11_LIBRARIES})

View File

@@ -0,0 +1,28 @@
{
"configurations": [
{
"name": "x64-Release",
"generator": "Ninja",
"configurationType": "Release",
"inheritEnvironments": [ "msvc_x64" ],
"buildRoot": "${projectDir}\\out\\build\\${name}",
"installRoot": "${projectDir}\\out\\install\\${name}",
"cmakeCommandArgs": "",
"buildCommandArgs": "",
"ctestCommandArgs": ""
},
{
"name": "WSL-GCC-Release",
"generator": "Ninja",
"configurationType": "RelWithDebInfo",
"buildRoot": "${projectDir}\\out\\build\\${name}",
"installRoot": "${projectDir}\\out\\install\\${name}",
"cmakeExecutable": "cmake",
"cmakeCommandArgs": "",
"buildCommandArgs": "",
"ctestCommandArgs": "",
"inheritEnvironments": [ "linux_x64" ],
"wslPath": "${defaultWSLPath}"
}
]
}

View File

@@ -0,0 +1,9 @@
# Microsoft Open Source Code of Conduct
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
Resources:
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns

View File

@@ -0,0 +1,9 @@
# Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.

View File

@@ -0,0 +1,17 @@
#Copyright(c) Microsoft Corporation.All rights reserved.
#Licensed under the MIT license.
FROM ubuntu:jammy
RUN apt update
RUN apt install -y software-properties-common
RUN add-apt-repository -y ppa:git-core/ppa
RUN apt update
RUN DEBIAN_FRONTEND=noninteractive apt install -y git make cmake g++ libaio-dev libgoogle-perftools-dev libunwind-dev clang-format libboost-dev libboost-program-options-dev libmkl-full-dev libcpprest-dev python3.10
WORKDIR /app
RUN git clone https://github.com/microsoft/DiskANN.git
WORKDIR /app/DiskANN
RUN mkdir build
RUN cmake -S . -B build -DCMAKE_BUILD_TYPE=Release
RUN cmake --build build -- -j

View File

@@ -0,0 +1,17 @@
#Copyright(c) Microsoft Corporation.All rights reserved.
#Licensed under the MIT license.
FROM ubuntu:jammy
RUN apt update
RUN apt install -y software-properties-common
RUN add-apt-repository -y ppa:git-core/ppa
RUN apt update
RUN DEBIAN_FRONTEND=noninteractive apt install -y git make cmake g++ libaio-dev libgoogle-perftools-dev libunwind-dev clang-format libboost-dev libboost-program-options-dev libboost-test-dev libmkl-full-dev libcpprest-dev python3.10
WORKDIR /app
RUN git clone https://github.com/microsoft/DiskANN.git
WORKDIR /app/DiskANN
RUN mkdir build
RUN cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DUNIT_TEST=True
RUN cmake --build build -- -j

View File

@@ -0,0 +1,23 @@
DiskANN
MIT License
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE

View File

@@ -0,0 +1,12 @@
include MANIFEST.in
include *.txt
include *.md
include setup.py
include pyproject.toml
include *.cmake
recursive-include gperftools *
recursive-include include *
recursive-include python *
recursive-include windows *
prune python/tests
recursive-include src *

View File

@@ -0,0 +1,135 @@
# DiskANN
[![DiskANN Main](https://github.com/microsoft/DiskANN/actions/workflows/push-test.yml/badge.svg?branch=main)](https://github.com/microsoft/DiskANN/actions/workflows/push-test.yml)
[![PyPI version](https://img.shields.io/pypi/v/diskannpy.svg)](https://pypi.org/project/diskannpy/)
[![Downloads shield](https://pepy.tech/badge/diskannpy)](https://pepy.tech/project/diskannpy)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![DiskANN Paper](https://img.shields.io/badge/Paper-NeurIPS%3A_DiskANN-blue)](https://papers.nips.cc/paper/9527-rand-nsg-fast-accurate-billion-point-nearest-neighbor-search-on-a-single-node.pdf)
[![DiskANN Paper](https://img.shields.io/badge/Paper-Arxiv%3A_Fresh--DiskANN-blue)](https://arxiv.org/abs/2105.09613)
[![DiskANN Paper](https://img.shields.io/badge/Paper-Filtered--DiskANN-blue)](https://harsha-simhadri.org/pubs/Filtered-DiskANN23.pdf)
DiskANN is a suite of scalable, accurate and cost-effective approximate nearest neighbor search algorithms for large-scale vector search that support real-time changes and simple filters.
This code is based on ideas from the [DiskANN](https://papers.nips.cc/paper/9527-rand-nsg-fast-accurate-billion-point-nearest-neighbor-search-on-a-single-node.pdf), [Fresh-DiskANN](https://arxiv.org/abs/2105.09613) and the [Filtered-DiskANN](https://harsha-simhadri.org/pubs/Filtered-DiskANN23.pdf) papers with further improvements.
This code forked off from [code for NSG](https://github.com/ZJULearning/nsg) algorithm.
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
See [guidelines](CONTRIBUTING.md) for contributing to this project.
## Linux build:
Install the following packages through apt-get
```bash
sudo apt install make cmake g++ libaio-dev libgoogle-perftools-dev clang-format libboost-all-dev
```
### Install Intel MKL
#### Ubuntu 20.04 or newer
```bash
sudo apt install libmkl-full-dev
```
#### Earlier versions of Ubuntu
Install Intel MKL either by downloading the [oneAPI MKL installer](https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl.html) or using [apt](https://software.intel.com/en-us/articles/installing-intel-free-libs-and-python-apt-repo) (we tested with build 2019.4-070 and 2022.1.2.146).
```
# OneAPI MKL Installer
wget https://registrationcenter-download.intel.com/akdlm/irc_nas/18487/l_BaseKit_p_2022.1.2.146.sh
sudo sh l_BaseKit_p_2022.1.2.146.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
```
### Build
```bash
mkdir build && cd build && cmake -DCMAKE_BUILD_TYPE=Release .. && make -j
```
## Windows build:
The Windows version has been tested with Enterprise editions of Visual Studio 2022, 2019 and 2017. It should work with the Community and Professional editions as well without any changes.
**Prerequisites:**
* CMake 3.15+ (available in VisualStudio 2019+ or from https://cmake.org)
* NuGet.exe (install from https://www.nuget.org/downloads)
* The build script will use NuGet to get MKL, OpenMP and Boost packages.
* DiskANN git repository checked out together with submodules. To check out submodules after git clone:
```
git submodule init
git submodule update
```
* Environment variables:
* [optional] If you would like to override the Boost library listed in windows/packages.config.in, set BOOST_ROOT to your Boost folder.
**Build steps:**
* Open the "x64 Native Tools Command Prompt for VS 2019" (or corresponding version) and change to DiskANN folder
* Create a "build" directory inside it
* Change to the "build" directory and run
```
cmake ..
```
OR for Visual Studio 2017 and earlier:
```
<full-path-to-installed-cmake>\cmake ..
```
**This will create a diskann.sln solution**. Now you can:
- Open it from VisualStudio and build either Release or Debug configuration.
- `<full-path-to-installed-cmake>\cmake --build build`
- Use MSBuild:
```
msbuild.exe diskann.sln /m /nologo /t:Build /p:Configuration="Release" /property:Platform="x64"
```
* This will also build gperftools submodule for libtcmalloc_minimal dependency.
* Generated binaries are stored in the x64/Release or x64/Debug directories.
## macOS Build
### Prerequisites
* Apple Silicon. The code should still work on Intel-based Macs, but there are no guarantees.
* macOS >= 12.0
* XCode Command Line Tools (install with `xcode-select --install`)
* [homebrew](https://brew.sh/)
### Install Required Packages
```zsh
brew install cmake
brew install boost
brew install gperftools
brew install libomp
```
### Build DiskANN
```zsh
# same as ubuntu instructions
mkdir build && cd build && cmake -DCMAKE_BUILD_TYPE=Release .. && make -j
```
## Usage:
Please see the following pages on using the compiled code:
- [Commandline interface for building and search SSD based indices](workflows/SSD_index.md)
- [Commandline interface for building and search in memory indices](workflows/in_memory_index.md)
- [Commandline examples for using in-memory streaming indices](workflows/dynamic_index.md)
- [Commandline interface for building and search in memory indices with label data and filters](workflows/filtered_in_memory.md)
- [Commandline interface for building and search SSD based indices with label data and filters](workflows/filtered_ssd_index.md)
- [diskannpy - DiskANN as a python extension module](python/README.md)
Please cite this software in your work as:
```
@misc{diskann-github,
author = {Simhadri, Harsha Vardhan and Krishnaswamy, Ravishankar and Srinivasa, Gopal and Subramanya, Suhas Jayaram and Antonijevic, Andrija and Pryce, Dax and Kaczynski, David and Williams, Shane and Gollapudi, Siddarth and Sivashankar, Varun and Karia, Neel and Singh, Aditi and Jaiswal, Shikhar and Mahapatro, Neelam and Adams, Philip and Tower, Bryan and Patel, Yash}},
title = {{DiskANN: Graph-structured Indices for Scalable, Fast, Fresh and Filtered Approximate Nearest Neighbor Search}},
url = {https://github.com/Microsoft/DiskANN},
version = {0.6.1},
year = {2023}
}
```

View File

@@ -0,0 +1,41 @@
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->
## Security
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
## Reporting Security Issues
**Please do not report security vulnerabilities through public GitHub issues.**
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
* Full paths of source file(s) related to the manifestation of the issue
* The location of the affected source code (tag/branch/commit or direct URL)
* Any special configuration required to reproduce the issue
* Step-by-step instructions to reproduce the issue
* Proof-of-concept or exploit code (if possible)
* Impact of the issue, including how an attacker might exploit the issue
This information will help us triage your report more quickly.
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
## Preferred Languages
We prefer all communications to be in English.
## Policy
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
<!-- END MICROSOFT SECURITY.MD BLOCK -->

View File

@@ -0,0 +1,42 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_COMPILE_WARNING_AS_ERROR ON)
add_executable(build_memory_index build_memory_index.cpp)
target_link_libraries(build_memory_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
add_executable(build_stitched_index build_stitched_index.cpp)
target_link_libraries(build_stitched_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
add_executable(search_memory_index search_memory_index.cpp)
target_link_libraries(search_memory_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
add_executable(build_disk_index build_disk_index.cpp)
target_link_libraries(build_disk_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} ${DISKANN_ASYNC_LIB} Boost::program_options)
add_executable(search_disk_index search_disk_index.cpp)
target_link_libraries(search_disk_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
add_executable(range_search_disk_index range_search_disk_index.cpp)
target_link_libraries(range_search_disk_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
add_executable(test_streaming_scenario test_streaming_scenario.cpp)
target_link_libraries(test_streaming_scenario ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
add_executable(test_insert_deletes_consolidate test_insert_deletes_consolidate.cpp)
target_link_libraries(test_insert_deletes_consolidate ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
if (NOT MSVC)
install(TARGETS build_memory_index
build_stitched_index
search_memory_index
build_disk_index
search_disk_index
range_search_disk_index
test_streaming_scenario
test_insert_deletes_consolidate
RUNTIME
)
endif()

View File

@@ -0,0 +1,191 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <omp.h>
#include <boost/program_options.hpp>
#include "utils.h"
#include "disk_utils.h"
#include "math_utils.h"
#include "index.h"
#include "partition.h"
#include "program_options_utils.hpp"
namespace po = boost::program_options;
int main(int argc, char **argv)
{
std::string data_type, dist_fn, data_path, index_path_prefix, codebook_prefix, label_file, universal_label,
label_type;
uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold;
float B, M;
bool append_reorder_data = false;
bool use_opq = false;
po::options_description desc{
program_options_utils::make_program_description("build_disk_index", "Build a disk-based index.")};
try
{
desc.add_options()("help,h", "Print information on arguments");
// Required parameters
po::options_description required_configs("Required");
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
program_options_utils::DATA_TYPE_DESCRIPTION);
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
program_options_utils::INPUT_DATA_PATH);
required_configs.add_options()("search_DRAM_budget,B", po::value<float>(&B)->required(),
"DRAM budget in GB for searching the index to set the "
"compressed level for data while search happens");
required_configs.add_options()("build_DRAM_budget,M", po::value<float>(&M)->required(),
"DRAM budget in GB for building the index");
// Optional parameters
po::options_description optional_configs("Optional");
optional_configs.add_options()("num_threads,T",
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
program_options_utils::NUMBER_THREADS_DESCRIPTION);
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
program_options_utils::MAX_BUILD_DEGREE);
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
program_options_utils::GRAPH_BUILD_COMPLEXITY);
optional_configs.add_options()("QD", po::value<uint32_t>(&QD)->default_value(0),
" Quantized Dimension for compression");
optional_configs.add_options()("codebook_prefix", po::value<std::string>(&codebook_prefix)->default_value(""),
"Path prefix for pre-trained codebook");
optional_configs.add_options()("PQ_disk_bytes", po::value<uint32_t>(&disk_PQ)->default_value(0),
"Number of bytes to which vectors should be compressed "
"on SSD; 0 for no compression");
optional_configs.add_options()("append_reorder_data", po::bool_switch()->default_value(false),
"Include full precision data in the index. Use only in "
"conjuction with compressed data on SSD.");
optional_configs.add_options()("build_PQ_bytes", po::value<uint32_t>(&build_PQ)->default_value(0),
program_options_utils::BUIlD_GRAPH_PQ_BYTES);
optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false),
program_options_utils::USE_OPQ);
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
program_options_utils::LABEL_FILE);
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
program_options_utils::UNIVERSAL_LABEL);
optional_configs.add_options()("FilteredLbuild", po::value<uint32_t>(&Lf)->default_value(0),
program_options_utils::FILTERED_LBUILD);
optional_configs.add_options()("filter_threshold,F", po::value<uint32_t>(&filter_threshold)->default_value(0),
"Threshold to break up the existing nodes to generate new graph "
"internally where each node has a maximum F labels.");
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
program_options_utils::LABEL_TYPE_DESCRIPTION);
// Merge required and optional parameters
desc.add(required_configs).add(optional_configs);
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
if (vm["append_reorder_data"].as<bool>())
append_reorder_data = true;
if (vm["use_opq"].as<bool>())
use_opq = true;
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
bool use_filters = (label_file != "") ? true : false;
diskann::Metric metric;
if (dist_fn == std::string("l2"))
metric = diskann::Metric::L2;
else if (dist_fn == std::string("mips"))
metric = diskann::Metric::INNER_PRODUCT;
else if (dist_fn == std::string("cosine"))
metric = diskann::Metric::COSINE;
else
{
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
return -1;
}
if (append_reorder_data)
{
if (disk_PQ == 0)
{
std::cout << "Error: It is not necessary to append data for reordering "
"when vectors are not compressed on disk."
<< std::endl;
return -1;
}
if (data_type != std::string("float"))
{
std::cout << "Error: Appending data for reordering currently only "
"supported for float data type."
<< std::endl;
return -1;
}
}
std::string params = std::string(std::to_string(R)) + " " + std::string(std::to_string(L)) + " " +
std::string(std::to_string(B)) + " " + std::string(std::to_string(M)) + " " +
std::string(std::to_string(num_threads)) + " " + std::string(std::to_string(disk_PQ)) + " " +
std::string(std::to_string(append_reorder_data)) + " " +
std::string(std::to_string(build_PQ)) + " " + std::string(std::to_string(QD));
try
{
if (label_file != "" && label_type == "ushort")
{
if (data_type == std::string("int8"))
return diskann::build_disk_index<int8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, codebook_prefix, use_filters, label_file,
universal_label, filter_threshold, Lf);
else if (data_type == std::string("uint8"))
return diskann::build_disk_index<uint8_t, uint16_t>(
data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix,
use_filters, label_file, universal_label, filter_threshold, Lf);
else if (data_type == std::string("float"))
return diskann::build_disk_index<float, uint16_t>(
data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix,
use_filters, label_file, universal_label, filter_threshold, Lf);
else
{
diskann::cerr << "Error. Unsupported data type" << std::endl;
return -1;
}
}
else
{
if (data_type == std::string("int8"))
return diskann::build_disk_index<int8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, codebook_prefix, use_filters, label_file,
universal_label, filter_threshold, Lf);
else if (data_type == std::string("uint8"))
return diskann::build_disk_index<uint8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, codebook_prefix, use_filters, label_file,
universal_label, filter_threshold, Lf);
else if (data_type == std::string("float"))
return diskann::build_disk_index<float>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, codebook_prefix, use_filters, label_file,
universal_label, filter_threshold, Lf);
else
{
diskann::cerr << "Error. Unsupported data type" << std::endl;
return -1;
}
}
}
catch (const std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Index build failed." << std::endl;
return -1;
}
}

View File

@@ -0,0 +1,164 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <omp.h>
#include <cstring>
#include <boost/program_options.hpp>
#include "index.h"
#include "utils.h"
#include "program_options_utils.hpp"
#ifndef _WINDOWS
#include <sys/mman.h>
#include <unistd.h>
#else
#include <Windows.h>
#endif
#include "memory_mapper.h"
#include "ann_exception.h"
#include "index_factory.h"
namespace po = boost::program_options;
int main(int argc, char **argv)
{
std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type;
uint32_t num_threads, R, L, Lf, build_PQ_bytes;
float alpha;
bool use_pq_build, use_opq;
po::options_description desc{
program_options_utils::make_program_description("build_memory_index", "Build a memory-based DiskANN index.")};
try
{
desc.add_options()("help,h", "Print information on arguments");
// Required parameters
po::options_description required_configs("Required");
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
program_options_utils::DATA_TYPE_DESCRIPTION);
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
program_options_utils::INPUT_DATA_PATH);
// Optional parameters
po::options_description optional_configs("Optional");
optional_configs.add_options()("num_threads,T",
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
program_options_utils::NUMBER_THREADS_DESCRIPTION);
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
program_options_utils::MAX_BUILD_DEGREE);
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
program_options_utils::GRAPH_BUILD_COMPLEXITY);
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
program_options_utils::GRAPH_BUILD_ALPHA);
optional_configs.add_options()("build_PQ_bytes", po::value<uint32_t>(&build_PQ_bytes)->default_value(0),
program_options_utils::BUIlD_GRAPH_PQ_BYTES);
optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false),
program_options_utils::USE_OPQ);
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
program_options_utils::LABEL_FILE);
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
program_options_utils::UNIVERSAL_LABEL);
optional_configs.add_options()("FilteredLbuild", po::value<uint32_t>(&Lf)->default_value(0),
program_options_utils::FILTERED_LBUILD);
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
program_options_utils::LABEL_TYPE_DESCRIPTION);
// Merge required and optional parameters
desc.add(required_configs).add(optional_configs);
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
use_pq_build = (build_PQ_bytes > 0);
use_opq = vm["use_opq"].as<bool>();
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("mips"))
{
metric = diskann::Metric::INNER_PRODUCT;
}
else if (dist_fn == std::string("l2"))
{
metric = diskann::Metric::L2;
}
else if (dist_fn == std::string("cosine"))
{
metric = diskann::Metric::COSINE;
}
else
{
std::cout << "Unsupported distance function. Currently only L2/ Inner "
"Product/Cosine are supported."
<< std::endl;
return -1;
}
try
{
diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L << " alpha: " << alpha
<< " #threads: " << num_threads << std::endl;
size_t data_num, data_dim;
diskann::get_bin_metadata(data_path, data_num, data_dim);
auto index_build_params = diskann::IndexWriteParametersBuilder(L, R)
.with_filter_list_size(Lf)
.with_alpha(alpha)
.with_saturate_graph(false)
.with_num_threads(num_threads)
.build();
auto filter_params = diskann::IndexFilterParamsBuilder()
.with_universal_label(universal_label)
.with_label_file(label_file)
.with_save_path_prefix(index_path_prefix)
.build();
auto config = diskann::IndexConfigBuilder()
.with_metric(metric)
.with_dimension(data_dim)
.with_max_points(data_num)
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
.with_data_type(data_type)
.with_label_type(label_type)
.is_dynamic_index(false)
.with_index_write_params(index_build_params)
.is_enable_tags(false)
.is_use_opq(use_opq)
.is_pq_dist_build(use_pq_build)
.with_num_pq_chunks(build_PQ_bytes)
.build();
auto index_factory = diskann::IndexFactory(config);
auto index = index_factory.create_instance();
index->build(data_path, data_num, filter_params);
index->save(index_path_prefix.c_str());
index.reset();
return 0;
}
catch (const std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Index build failed." << std::endl;
return -1;
}
}

View File

@@ -0,0 +1,441 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <boost/program_options.hpp>
#include <chrono>
#include <cstdio>
#include <cstring>
#include <random>
#include <string>
#include <tuple>
#include "filter_utils.h"
#include <omp.h>
#ifndef _WINDOWS
#include <sys/uio.h>
#endif
#include "index.h"
#include "memory_mapper.h"
#include "parameters.h"
#include "utils.h"
#include "program_options_utils.hpp"
namespace po = boost::program_options;
typedef std::tuple<std::vector<std::vector<uint32_t>>, uint64_t> stitch_indices_return_values;
/*
* Inline function to display progress bar.
*/
inline void print_progress(double percentage)
{
int val = (int)(percentage * 100);
int lpad = (int)(percentage * PBWIDTH);
int rpad = PBWIDTH - lpad;
printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, "");
fflush(stdout);
}
/*
* Inline function to generate a random integer in a range.
*/
inline size_t random(size_t range_from, size_t range_to)
{
std::random_device rand_dev;
std::mt19937 generator(rand_dev());
std::uniform_int_distribution<size_t> distr(range_from, range_to);
return distr(generator);
}
/*
* function to handle command line parsing.
*
* Arguments are merely the inputs from the command line.
*/
void handle_args(int argc, char **argv, std::string &data_type, path &input_data_path, path &final_index_path_prefix,
path &label_data_path, std::string &universal_label, uint32_t &num_threads, uint32_t &R, uint32_t &L,
uint32_t &stitched_R, float &alpha)
{
po::options_description desc{
program_options_utils::make_program_description("build_stitched_index", "Build a stitched DiskANN index.")};
try
{
desc.add_options()("help,h", "Print information on arguments");
// Required parameters
po::options_description required_configs("Required");
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
program_options_utils::DATA_TYPE_DESCRIPTION);
required_configs.add_options()("index_path_prefix",
po::value<std::string>(&final_index_path_prefix)->required(),
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
required_configs.add_options()("data_path", po::value<std::string>(&input_data_path)->required(),
program_options_utils::INPUT_DATA_PATH);
// Optional parameters
po::options_description optional_configs("Optional");
optional_configs.add_options()("num_threads,T",
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
program_options_utils::NUMBER_THREADS_DESCRIPTION);
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
program_options_utils::MAX_BUILD_DEGREE);
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
program_options_utils::GRAPH_BUILD_COMPLEXITY);
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
program_options_utils::GRAPH_BUILD_ALPHA);
optional_configs.add_options()("label_file", po::value<std::string>(&label_data_path)->default_value(""),
program_options_utils::LABEL_FILE);
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
program_options_utils::UNIVERSAL_LABEL);
optional_configs.add_options()("stitched_R", po::value<uint32_t>(&stitched_R)->default_value(100),
"Degree to prune final graph down to");
// Merge required and optional parameters
desc.add(required_configs).add(optional_configs);
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
exit(0);
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
throw;
}
}
/*
* Custom index save to write the in-memory index to disk.
* Also writes required files for diskANN API -
* 1. labels_to_medoids
* 2. universal_label
* 3. data (redundant for static indices)
* 4. labels (redundant for static indices)
*/
void save_full_index(path final_index_path_prefix, path input_data_path, uint64_t final_index_size,
std::vector<std::vector<uint32_t>> stitched_graph,
tsl::robin_map<std::string, uint32_t> entry_points, std::string universal_label,
path label_data_path)
{
// aux. file 1
auto saving_index_timer = std::chrono::high_resolution_clock::now();
std::ifstream original_label_data_stream;
original_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
original_label_data_stream.open(label_data_path, std::ios::binary);
std::ofstream new_label_data_stream;
new_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
new_label_data_stream.open(final_index_path_prefix + "_labels.txt", std::ios::binary);
new_label_data_stream << original_label_data_stream.rdbuf();
original_label_data_stream.close();
new_label_data_stream.close();
// aux. file 2
std::ifstream original_input_data_stream;
original_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
original_input_data_stream.open(input_data_path, std::ios::binary);
std::ofstream new_input_data_stream;
new_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
new_input_data_stream.open(final_index_path_prefix + ".data", std::ios::binary);
new_input_data_stream << original_input_data_stream.rdbuf();
original_input_data_stream.close();
new_input_data_stream.close();
// aux. file 3
std::ofstream labels_to_medoids_writer;
labels_to_medoids_writer.exceptions(std::ios::badbit | std::ios::failbit);
labels_to_medoids_writer.open(final_index_path_prefix + "_labels_to_medoids.txt");
for (auto iter : entry_points)
labels_to_medoids_writer << iter.first << ", " << iter.second << std::endl;
labels_to_medoids_writer.close();
// aux. file 4 (only if we're using a universal label)
if (universal_label != "")
{
std::ofstream universal_label_writer;
universal_label_writer.exceptions(std::ios::badbit | std::ios::failbit);
universal_label_writer.open(final_index_path_prefix + "_universal_label.txt");
universal_label_writer << universal_label << std::endl;
universal_label_writer.close();
}
// main index
uint64_t index_num_frozen_points = 0, index_num_edges = 0;
uint32_t index_max_observed_degree = 0, index_entry_point = 0;
const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t);
for (auto &point_neighbors : stitched_graph)
{
index_max_observed_degree = std::max(index_max_observed_degree, (uint32_t)point_neighbors.size());
}
std::ofstream stitched_graph_writer;
stitched_graph_writer.exceptions(std::ios::badbit | std::ios::failbit);
stitched_graph_writer.open(final_index_path_prefix, std::ios_base::binary);
stitched_graph_writer.write((char *)&final_index_size, sizeof(uint64_t));
stitched_graph_writer.write((char *)&index_max_observed_degree, sizeof(uint32_t));
stitched_graph_writer.write((char *)&index_entry_point, sizeof(uint32_t));
stitched_graph_writer.write((char *)&index_num_frozen_points, sizeof(uint64_t));
size_t bytes_written = METADATA;
for (uint32_t node_point = 0; node_point < stitched_graph.size(); node_point++)
{
uint32_t current_node_num_neighbors = (uint32_t)stitched_graph[node_point].size();
std::vector<uint32_t> current_node_neighbors = stitched_graph[node_point];
stitched_graph_writer.write((char *)&current_node_num_neighbors, sizeof(uint32_t));
bytes_written += sizeof(uint32_t);
for (const auto &current_node_neighbor : current_node_neighbors)
{
stitched_graph_writer.write((char *)&current_node_neighbor, sizeof(uint32_t));
bytes_written += sizeof(uint32_t);
}
index_num_edges += current_node_num_neighbors;
}
if (bytes_written != final_index_size)
{
std::cerr << "Error: written bytes does not match allocated space" << std::endl;
throw;
}
stitched_graph_writer.close();
std::chrono::duration<double> saving_index_time = std::chrono::high_resolution_clock::now() - saving_index_timer;
std::cout << "Stitched graph written in " << saving_index_time.count() << " seconds" << std::endl;
std::cout << "Stitched graph average degree: " << ((float)index_num_edges) / ((float)(stitched_graph.size()))
<< std::endl;
std::cout << "Stitched graph max degree: " << index_max_observed_degree << std::endl << std::endl;
}
/*
* Unions the per-label graph indices together via the following policy:
* - any two nodes can only have at most one edge between them -
*
* Returns the "stitched" graph and its expected file size.
*/
template <typename T>
stitch_indices_return_values stitch_label_indices(
path final_index_path_prefix, uint32_t total_number_of_points, label_set all_labels,
tsl::robin_map<std::string, uint32_t> labels_to_number_of_points,
tsl::robin_map<std::string, uint32_t> &label_entry_points,
tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id_map)
{
size_t final_index_size = 0;
std::vector<std::vector<uint32_t>> stitched_graph(total_number_of_points);
auto stitching_index_timer = std::chrono::high_resolution_clock::now();
for (const auto &lbl : all_labels)
{
path curr_label_index_path(final_index_path_prefix + "_" + lbl);
std::vector<std::vector<uint32_t>> curr_label_index;
uint64_t curr_label_index_size;
uint32_t curr_label_entry_point;
std::tie(curr_label_index, curr_label_index_size) =
diskann::load_label_index(curr_label_index_path, labels_to_number_of_points[lbl]);
curr_label_entry_point = (uint32_t)random(0, curr_label_index.size());
label_entry_points[lbl] = label_id_to_orig_id_map[lbl][curr_label_entry_point];
for (uint32_t node_point = 0; node_point < curr_label_index.size(); node_point++)
{
uint32_t original_point_id = label_id_to_orig_id_map[lbl][node_point];
for (auto &node_neighbor : curr_label_index[node_point])
{
uint32_t original_neighbor_id = label_id_to_orig_id_map[lbl][node_neighbor];
std::vector<uint32_t> curr_point_neighbors = stitched_graph[original_point_id];
if (std::find(curr_point_neighbors.begin(), curr_point_neighbors.end(), original_neighbor_id) ==
curr_point_neighbors.end())
{
stitched_graph[original_point_id].push_back(original_neighbor_id);
final_index_size += sizeof(uint32_t);
}
}
}
}
const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t);
final_index_size += (total_number_of_points * sizeof(uint32_t) + METADATA);
std::chrono::duration<double> stitching_index_time =
std::chrono::high_resolution_clock::now() - stitching_index_timer;
std::cout << "stitched graph generated in memory in " << stitching_index_time.count() << " seconds" << std::endl;
return std::make_tuple(stitched_graph, final_index_size);
}
/*
* Applies the prune_neighbors function from src/index.cpp to
* every node in the stitched graph.
*
* This is an optional step, hence the saving of both the full
* and pruned graph.
*/
template <typename T>
void prune_and_save(path final_index_path_prefix, path full_index_path_prefix, path input_data_path,
std::vector<std::vector<uint32_t>> stitched_graph, uint32_t stitched_R,
tsl::robin_map<std::string, uint32_t> label_entry_points, std::string universal_label,
path label_data_path, uint32_t num_threads)
{
size_t dimension, number_of_label_points;
auto diskann_cout_buffer = diskann::cout.rdbuf(nullptr);
auto std_cout_buffer = std::cout.rdbuf(nullptr);
auto pruning_index_timer = std::chrono::high_resolution_clock::now();
diskann::get_bin_metadata(input_data_path, number_of_label_points, dimension);
diskann::Index<T> index(diskann::Metric::L2, dimension, number_of_label_points, nullptr, nullptr, 0, false, false,
false, false, 0, false);
// not searching this index, set search_l to 0
index.load(full_index_path_prefix.c_str(), num_threads, 1);
std::cout << "parsing labels" << std::endl;
index.prune_all_neighbors(stitched_R, 750, 1.2);
index.save((final_index_path_prefix).c_str());
diskann::cout.rdbuf(diskann_cout_buffer);
std::cout.rdbuf(std_cout_buffer);
std::chrono::duration<double> pruning_index_time = std::chrono::high_resolution_clock::now() - pruning_index_timer;
std::cout << "pruning performed in " << pruning_index_time.count() << " seconds\n" << std::endl;
}
/*
* Delete all temporary artifacts.
* In the process of creating the stitched index, some temporary artifacts are
* created:
* 1. the separate bin files for each labels' points
* 2. the separate diskANN indices built for each label
* 3. the '.data' file created while generating the indices
*/
void clean_up_artifacts(path input_data_path, path final_index_path_prefix, label_set all_labels)
{
for (const auto &lbl : all_labels)
{
path curr_label_input_data_path(input_data_path + "_" + lbl);
path curr_label_index_path(final_index_path_prefix + "_" + lbl);
path curr_label_index_path_data(curr_label_index_path + ".data");
if (std::remove(curr_label_index_path.c_str()) != 0)
throw;
if (std::remove(curr_label_input_data_path.c_str()) != 0)
throw;
if (std::remove(curr_label_index_path_data.c_str()) != 0)
throw;
}
}
int main(int argc, char **argv)
{
// 1. handle cmdline inputs
std::string data_type;
path input_data_path, final_index_path_prefix, label_data_path;
std::string universal_label;
uint32_t num_threads, R, L, stitched_R;
float alpha;
auto index_timer = std::chrono::high_resolution_clock::now();
handle_args(argc, argv, data_type, input_data_path, final_index_path_prefix, label_data_path, universal_label,
num_threads, R, L, stitched_R, alpha);
path labels_file_to_use = final_index_path_prefix + "_label_formatted.txt";
path labels_map_file = final_index_path_prefix + "_labels_map.txt";
convert_labels_string_to_int(label_data_path, labels_file_to_use, labels_map_file, universal_label);
// 2. parse label file and create necessary data structures
std::vector<label_set> point_ids_to_labels;
tsl::robin_map<std::string, uint32_t> labels_to_number_of_points;
label_set all_labels;
std::tie(point_ids_to_labels, labels_to_number_of_points, all_labels) =
diskann::parse_label_file(labels_file_to_use, universal_label);
// 3. for each label, make a separate data file
tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id_map;
uint32_t total_number_of_points = (uint32_t)point_ids_to_labels.size();
#ifndef _WINDOWS
if (data_type == "uint8")
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<uint8_t>(
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
else if (data_type == "int8")
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<int8_t>(
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
else if (data_type == "float")
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<float>(
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
else
throw;
#else
if (data_type == "uint8")
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<uint8_t>(
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
else if (data_type == "int8")
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<int8_t>(
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
else if (data_type == "float")
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<float>(
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
else
throw;
#endif
// 4. for each created data file, create a vanilla diskANN index
if (data_type == "uint8")
diskann::generate_label_indices<uint8_t>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
num_threads);
else if (data_type == "int8")
diskann::generate_label_indices<int8_t>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
num_threads);
else if (data_type == "float")
diskann::generate_label_indices<float>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
num_threads);
else
throw;
// 5. "stitch" the indices together
std::vector<std::vector<uint32_t>> stitched_graph;
tsl::robin_map<std::string, uint32_t> label_entry_points;
uint64_t stitched_graph_size;
if (data_type == "uint8")
std::tie(stitched_graph, stitched_graph_size) =
stitch_label_indices<uint8_t>(final_index_path_prefix, total_number_of_points, all_labels,
labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
else if (data_type == "int8")
std::tie(stitched_graph, stitched_graph_size) =
stitch_label_indices<int8_t>(final_index_path_prefix, total_number_of_points, all_labels,
labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
else if (data_type == "float")
std::tie(stitched_graph, stitched_graph_size) =
stitch_label_indices<float>(final_index_path_prefix, total_number_of_points, all_labels,
labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
else
throw;
path full_index_path_prefix = final_index_path_prefix + "_full";
// 5a. save the stitched graph to disk
save_full_index(full_index_path_prefix, input_data_path, stitched_graph_size, stitched_graph, label_entry_points,
universal_label, labels_file_to_use);
// 6. run a prune on the stitched index, and save to disk
if (data_type == "uint8")
prune_and_save<uint8_t>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
else if (data_type == "int8")
prune_and_save<int8_t>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
else if (data_type == "float")
prune_and_save<float>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
else
throw;
std::chrono::duration<double> index_time = std::chrono::high_resolution_clock::now() - index_timer;
std::cout << "pruned/stitched graph generated in " << index_time.count() << " seconds" << std::endl;
clean_up_artifacts(input_data_path, final_index_path_prefix, all_labels);
}

View File

@@ -0,0 +1,46 @@
<!-- Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT license. -->
# Integration Tests
The following tests use Python to prepare, run, verify, and tear down the rest api services.
We do make use of the built-in `unittest` library, but that's only to take advantage of test reporting purposes.
These are decidedly **not** _unit_ tests. These are end to end integration tests.
## Caveats
This has only been tested or built for Linux, though we have written platform agnostic Python for the smoke test
(i.e. using `os.path.join`, etc)
It has been tested on Python 3.9 and 3.10, but should work on Python 3.6+.
## How to Run
First, build the DiskANN RestAPI code; see $REPOSITORY_ROOT/workflows/rest_api.md for detailed instructions.
```bash
cd tests/python
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
export DISKANN_BUILD_DIR=/path/to/your/diskann/build
python -m unittest
```
## Smoke Test Failed, Now What?
The smoke test written takes advantage of temporary directories that are only valid during the
lifetime of the test. The contents of these directories include:
- Randomized vectors (first in tsv, then bin form) used to build the PQFlashIndex
- The PQFlashIndex files
It is useful to keep these around. By setting some environment variables, you can control whether an ephemeral,
temporary directory is used (and deleted on test completion), or left as an exercise for the developer to
clean up.
The valid environment variables are:
- `DISKANN_REST_TEST_WORKING_DIR` (example: `$USER/DiskANNRestTest`)
- If this is specified, it **must exist** and **must be writeable**. Any existing files will be clobbered.
- `DISKANN_REST_SERVER` (example: `http://127.0.0.1:10067`)
- Note that if this is set, no data will be generated, nor will a server be started; it is presumed you have done
all the work in creating and starting the rest server prior to running the test and just submits requests against it.

View File

@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
import numpy as np
import os
import subprocess
def output_vectors(
diskann_build_path: str,
temporary_file_path: str,
vectors: np.ndarray,
timeout: int = 60
) -> str:
vectors_as_tsv_path = os.path.join(temporary_file_path, "vectors.tsv")
with open(vectors_as_tsv_path, "w") as vectors_tsv_out:
for vector in vectors:
as_str = "\t".join((str(component) for component in vector))
print(as_str, file=vectors_tsv_out)
# there is probably a clever way to have numpy write out C++ friendly floats, so feel free to remove this in
# favor of something more sane later
vectors_as_bin_path = os.path.join(temporary_file_path, "vectors.bin")
tsv_to_bin_path = os.path.join(diskann_build_path, "apps", "utils", "tsv_to_bin")
number_of_points, dimensions = vectors.shape
args = [
tsv_to_bin_path,
"float",
vectors_as_tsv_path,
vectors_as_bin_path,
str(dimensions),
str(number_of_points)
]
completed = subprocess.run(args, timeout=timeout)
if completed.returncode != 0:
raise Exception(f"Unable to convert tsv to binary using tsv_to_bin, completed_process: {completed}")
return vectors_as_bin_path
def build_ssd_index(
diskann_build_path: str,
temporary_file_path: str,
vectors: np.ndarray,
per_process_timeout: int = 60 # this may not be long enough if you're doing something larger
):
vectors_as_bin_path = output_vectors(diskann_build_path, temporary_file_path, vectors, timeout=per_process_timeout)
ssd_builder_path = os.path.join(diskann_build_path, "apps", "build_disk_index")
args = [
ssd_builder_path,
"--data_type", "float",
"--dist_fn", "l2",
"--data_path", vectors_as_bin_path,
"--index_path_prefix", os.path.join(temporary_file_path, "smoke_test"),
"-R", "64",
"-L", "100",
"--search_DRAM_budget", "1",
"--build_DRAM_budget", "1",
"--num_threads", "1",
"--PQ_disk_bytes", "0"
]
completed = subprocess.run(args, timeout=per_process_timeout)
if completed.returncode != 0:
command_run = " ".join(args)
raise Exception(f"Unable to build a disk index with the command: '{command_run}'\ncompleted_process: {completed}\nstdout: {completed.stdout}\nstderr: {completed.stderr}")
# index is now built inside of temporary_file_path

View File

@@ -0,0 +1,379 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <atomic>
#include <cstring>
#include <iomanip>
#include <omp.h>
#include <set>
#include <boost/program_options.hpp>
#include "index.h"
#include "disk_utils.h"
#include "math_utils.h"
#include "memory_mapper.h"
#include "pq_flash_index.h"
#include "partition.h"
#include "timer.h"
#include "program_options_utils.hpp"
#ifndef _WINDOWS
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#include "linux_aligned_file_reader.h"
#else
#ifdef USE_BING_INFRA
#include "bing_aligned_file_reader.h"
#else
#include "windows_aligned_file_reader.h"
#endif
#endif
namespace po = boost::program_options;
#define WARMUP false
void print_stats(std::string category, std::vector<float> percentiles, std::vector<float> results)
{
diskann::cout << std::setw(20) << category << ": " << std::flush;
for (uint32_t s = 0; s < percentiles.size(); s++)
{
diskann::cout << std::setw(8) << percentiles[s] << "%";
}
diskann::cout << std::endl;
diskann::cout << std::setw(22) << " " << std::flush;
for (uint32_t s = 0; s < percentiles.size(); s++)
{
diskann::cout << std::setw(9) << results[s];
}
diskann::cout << std::endl;
}
template <typename T, typename LabelT = uint32_t>
int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix, const std::string &query_file,
std::string &gt_file, const uint32_t num_threads, const float search_range,
const uint32_t beamwidth, const uint32_t num_nodes_to_cache, const std::vector<uint32_t> &Lvec)
{
std::string pq_prefix = index_path_prefix + "_pq";
std::string disk_index_file = index_path_prefix + "_disk.index";
std::string warmup_query_file = index_path_prefix + "_sample_data.bin";
diskann::cout << "Search parameters: #threads: " << num_threads << ", ";
if (beamwidth <= 0)
diskann::cout << "beamwidth to be optimized for each L value" << std::endl;
else
diskann::cout << " beamwidth: " << beamwidth << std::endl;
// load query bin
T *query = nullptr;
std::vector<std::vector<uint32_t>> groundtruth_ids;
size_t query_num, query_dim, query_aligned_dim, gt_num;
diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);
bool calc_recall_flag = false;
if (gt_file != std::string("null") && file_exists(gt_file))
{
diskann::load_range_truthset(gt_file, groundtruth_ids,
gt_num); // use for range search type of truthset
// diskann::prune_truthset_for_range(gt_file, search_range,
// groundtruth_ids, gt_num); // use for traditional truthset
if (gt_num != query_num)
{
diskann::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl;
return -1;
}
calc_recall_flag = true;
}
std::shared_ptr<AlignedFileReader> reader = nullptr;
#ifdef _WINDOWS
#ifndef USE_BING_INFRA
reader.reset(new WindowsAlignedFileReader());
#else
reader.reset(new diskann::BingAlignedFileReader());
#endif
#else
reader.reset(new LinuxAlignedFileReader());
#endif
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> _pFlashIndex(
new diskann::PQFlashIndex<T, LabelT>(reader, metric));
int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str());
if (res != 0)
{
return res;
}
// cache bfs levels
std::vector<uint32_t> node_list;
diskann::cout << "Caching " << num_nodes_to_cache << " BFS nodes around medoid(s)" << std::endl;
_pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list);
// _pFlashIndex->generate_cache_list_from_sample_queries(
// warmup_query_file, 15, 6, num_nodes_to_cache, num_threads,
// node_list);
_pFlashIndex->load_cache_list(node_list);
node_list.clear();
node_list.shrink_to_fit();
omp_set_num_threads(num_threads);
uint64_t warmup_L = 20;
uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0;
T *warmup = nullptr;
if (WARMUP)
{
if (file_exists(warmup_query_file))
{
diskann::load_aligned_bin<T>(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim);
}
else
{
warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads);
warmup_dim = query_dim;
warmup_aligned_dim = query_aligned_dim;
diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T));
std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T));
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(-128, 127);
for (uint32_t i = 0; i < warmup_num; i++)
{
for (uint32_t d = 0; d < warmup_dim; d++)
{
warmup[i * warmup_aligned_dim + d] = (T)dis(gen);
}
}
}
diskann::cout << "Warming up index... " << std::flush;
std::vector<uint64_t> warmup_result_ids_64(warmup_num, 0);
std::vector<float> warmup_result_dists(warmup_num, 0);
#pragma omp parallel for schedule(dynamic, 1)
for (int64_t i = 0; i < (int64_t)warmup_num; i++)
{
_pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L,
warmup_result_ids_64.data() + (i * 1),
warmup_result_dists.data() + (i * 1), 4);
}
diskann::cout << "..done" << std::endl;
}
diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
diskann::cout.precision(2);
std::string recall_string = "Recall@rng=" + std::to_string(search_range);
diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" << std::setw(16) << "QPS" << std::setw(16)
<< "Mean Latency" << std::setw(16) << "99.9 Latency" << std::setw(16) << "Mean IOs" << std::setw(16)
<< "CPU (s)";
if (calc_recall_flag)
{
diskann::cout << std::setw(16) << recall_string << std::endl;
}
else
diskann::cout << std::endl;
diskann::cout << "==============================================================="
"==========================================="
<< std::endl;
std::vector<std::vector<std::vector<uint32_t>>> query_result_ids(Lvec.size());
uint32_t optimized_beamwidth = 2;
uint32_t max_list_size = 10000;
for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++)
{
uint32_t L = Lvec[test_id];
if (beamwidth <= 0)
{
optimized_beamwidth =
optimize_beamwidth(_pFlashIndex, warmup, warmup_num, warmup_aligned_dim, L, optimized_beamwidth);
}
else
optimized_beamwidth = beamwidth;
query_result_ids[test_id].clear();
query_result_ids[test_id].resize(query_num);
diskann::QueryStats *stats = new diskann::QueryStats[query_num];
auto s = std::chrono::high_resolution_clock::now();
#pragma omp parallel for schedule(dynamic, 1)
for (int64_t i = 0; i < (int64_t)query_num; i++)
{
std::vector<uint64_t> indices;
std::vector<float> distances;
uint32_t res_count =
_pFlashIndex->range_search(query + (i * query_aligned_dim), search_range, L, max_list_size, indices,
distances, optimized_beamwidth, stats + i);
query_result_ids[test_id][i].reserve(res_count);
query_result_ids[test_id][i].resize(res_count);
for (uint32_t idx = 0; idx < res_count; idx++)
query_result_ids[test_id][i][idx] = (uint32_t)indices[idx];
}
auto e = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = e - s;
auto qps = (1.0 * query_num) / (1.0 * diff.count());
auto mean_latency = diskann::get_mean_stats<float>(
stats, query_num, [](const diskann::QueryStats &stats) { return stats.total_us; });
auto latency_999 = diskann::get_percentile_stats<float>(
stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; });
auto mean_ios = diskann::get_mean_stats<uint32_t>(stats, query_num,
[](const diskann::QueryStats &stats) { return stats.n_ios; });
double mean_cpuus = diskann::get_mean_stats<float>(
stats, query_num, [](const diskann::QueryStats &stats) { return stats.cpu_us; });
double recall = 0;
double ratio_of_sums = 0;
if (calc_recall_flag)
{
recall =
diskann::calculate_range_search_recall((uint32_t)query_num, groundtruth_ids, query_result_ids[test_id]);
uint32_t total_true_positive = 0;
uint32_t total_positive = 0;
for (uint32_t i = 0; i < query_num; i++)
{
total_true_positive += (uint32_t)query_result_ids[test_id][i].size();
total_positive += (uint32_t)groundtruth_ids[i].size();
}
ratio_of_sums = (1.0 * total_true_positive) / (1.0 * total_positive);
}
diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps
<< std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios
<< std::setw(16) << mean_cpuus;
if (calc_recall_flag)
{
diskann::cout << std::setw(16) << recall << "," << ratio_of_sums << std::endl;
}
else
diskann::cout << std::endl;
}
diskann::cout << "Done searching. " << std::endl;
diskann::aligned_free(query);
if (warmup != nullptr)
diskann::aligned_free(warmup);
return 0;
}
int main(int argc, char **argv)
{
std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file;
uint32_t num_threads, W, num_nodes_to_cache;
std::vector<uint32_t> Lvec;
float range;
po::options_description desc{program_options_utils::make_program_description(
"range_search_disk_index", "Searches disk DiskANN indexes using ranges")};
try
{
desc.add_options()("help,h", "Print information on arguments");
// Required parameters
po::options_description required_configs("Required");
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
program_options_utils::DATA_TYPE_DESCRIPTION);
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
required_configs.add_options()("query_file", po::value<std::string>(&query_file)->required(),
program_options_utils::QUERY_FILE_DESCRIPTION);
required_configs.add_options()("search_list,L",
po::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
program_options_utils::SEARCH_LIST_DESCRIPTION);
required_configs.add_options()("range_threshold,K", po::value<float>(&range)->required(),
"Number of neighbors to be returned");
// Optional parameters
po::options_description optional_configs("Optional");
optional_configs.add_options()("num_threads,T",
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
program_options_utils::NUMBER_THREADS_DESCRIPTION);
optional_configs.add_options()("gt_file", po::value<std::string>(&gt_file)->default_value(std::string("null")),
program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION);
optional_configs.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
program_options_utils::NUMBER_OF_NODES_TO_CACHE);
optional_configs.add_options()("beamwidth,W", po::value<uint32_t>(&W)->default_value(2),
program_options_utils::BEAMWIDTH);
// Merge required and optional parameters
desc.add(required_configs).add(optional_configs);
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("mips"))
{
metric = diskann::Metric::INNER_PRODUCT;
}
else if (dist_fn == std::string("l2"))
{
metric = diskann::Metric::L2;
}
else if (dist_fn == std::string("cosine"))
{
metric = diskann::Metric::COSINE;
}
else
{
std::cout << "Unsupported distance function. Currently only L2/ Inner "
"Product/Cosine are supported."
<< std::endl;
return -1;
}
if ((data_type != std::string("float")) && (metric == diskann::Metric::INNER_PRODUCT))
{
std::cout << "Currently support only floating point data for Inner Product." << std::endl;
return -1;
}
try
{
if (data_type == std::string("float"))
return search_disk_index<float>(metric, index_path_prefix, query_file, gt_file, num_threads, range, W,
num_nodes_to_cache, Lvec);
else if (data_type == std::string("int8"))
return search_disk_index<int8_t>(metric, index_path_prefix, query_file, gt_file, num_threads, range, W,
num_nodes_to_cache, Lvec);
else if (data_type == std::string("uint8"))
return search_disk_index<uint8_t>(metric, index_path_prefix, query_file, gt_file, num_threads, range, W,
num_nodes_to_cache, Lvec);
else
{
std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
return -1;
}
}
catch (const std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Index search failed." << std::endl;
return -1;
}
}

View File

@@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
set(CMAKE_CXX_STANDARD 17)
add_executable(inmem_server inmem_server.cpp)
if(MSVC)
target_link_options(inmem_server PRIVATE /MACHINE:x64)
target_link_libraries(inmem_server debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
target_link_libraries(inmem_server optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
else()
target_link_libraries(inmem_server ${PROJECT_NAME} aio -ltcmalloc -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
endif()
add_executable(ssd_server ssd_server.cpp)
if(MSVC)
target_link_options(ssd_server PRIVATE /MACHINE:x64)
target_link_libraries(ssd_server debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
target_link_libraries(ssd_server optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
else()
target_link_libraries(ssd_server ${PROJECT_NAME} aio -ltcmalloc -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
endif()
add_executable(multiple_ssdindex_server multiple_ssdindex_server.cpp)
if(MSVC)
target_link_options(multiple_ssdindex_server PRIVATE /MACHINE:x64)
target_link_libraries(multiple_ssdindex_server debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
target_link_libraries(multiple_ssdindex_server optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
else()
target_link_libraries(multiple_ssdindex_server ${PROJECT_NAME} aio -ltcmalloc -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
endif()
add_executable(client client.cpp)
if(MSVC)
target_link_options(client PRIVATE /MACHINE:x64)
target_link_libraries(client debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
target_link_libraries(client optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
else()
target_link_libraries(client ${PROJECT_NAME} -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
endif()

View File

@@ -0,0 +1,124 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <ctime>
#include <functional>
#include <iomanip>
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <string>
#include <cstdlib>
#include <codecvt>
#include <boost/program_options.hpp>
#include <cpprest/http_client.h>
#include <restapi/common.h>
using namespace web;
using namespace web::http;
using namespace web::http::client;
using namespace diskann;
namespace po = boost::program_options;
template <typename T>
void query_loop(const std::string &ip_addr_port, const std::string &query_file, const unsigned nq, const unsigned Ls,
const unsigned k_value)
{
web::http::client::http_client client(U(ip_addr_port));
T *data;
size_t npts = 1, ndims = 128, rounded_dim = 128;
diskann::load_aligned_bin<T>(query_file, data, npts, ndims, rounded_dim);
for (unsigned i = 0; i < nq; ++i)
{
T *vec = data + i * rounded_dim;
web::http::http_request http_query(methods::POST);
web::json::value queryJson = web::json::value::object();
queryJson[QUERY_ID_KEY] = i;
queryJson[K_KEY] = k_value;
queryJson[L_KEY] = Ls;
for (size_t i = 0; i < ndims; ++i)
{
queryJson[VECTOR_KEY][i] = web::json::value::number(vec[i]);
}
http_query.set_body(queryJson);
client.request(http_query)
.then([](web::http::http_response response) -> pplx::task<utility::string_t> {
if (response.status_code() == status_codes::OK)
{
return response.extract_string();
}
std::cerr << "Query failed" << std::endl;
return pplx::task_from_result(utility::string_t());
})
.then([](pplx::task<utility::string_t> previousTask) {
try
{
std::cout << previousTask.get() << std::endl;
}
catch (http_exception const &e)
{
std::wcout << e.what() << std::endl;
}
})
.wait();
}
}
int main(int argc, char *argv[])
{
std::string data_type, query_file, address;
uint32_t num_queries;
uint32_t l_search, k_value;
po::options_description desc{"Arguments"};
try
{
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
desc.add_options()("query_file", po::value<std::string>(&query_file)->required(),
"File containing the queries to search");
desc.add_options()("num_queries,Q", po::value<uint32_t>(&num_queries)->required(),
"Number of queries to search");
desc.add_options()("l_search", po::value<uint32_t>(&l_search)->required(), "Value of L");
desc.add_options()("k_value,K", po::value<uint32_t>(&k_value)->default_value(10), "Value of K (default 10)");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << std::endl;
return -1;
}
if (data_type == std::string("float"))
{
query_loop<float>(address, query_file, num_queries, l_search, k_value);
}
else if (data_type == std::string("int8"))
{
query_loop<int8_t>(address, query_file, num_queries, l_search, k_value);
}
else if (data_type == std::string("uint8"))
{
query_loop<uint8_t>(address, query_file, num_queries, l_search, k_value);
}
else
{
std::cerr << "Unsupported type " << argv[2] << std::endl;
return -1;
}
return 0;
}

View File

@@ -0,0 +1,138 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <ctime>
#include <functional>
#include <iomanip>
#include <string>
#include <cstdlib>
#include <codecvt>
#include <boost/program_options.hpp>
#include <restapi/server.h>
using namespace diskann;
namespace po = boost::program_options;
std::unique_ptr<Server> g_httpServer(nullptr);
std::vector<std::unique_ptr<diskann::BaseSearch>> g_inMemorySearch;
void setup(const utility::string_t &address, const std::string &typestring)
{
web::http::uri_builder uriBldr(address);
auto uri = uriBldr.to_uri();
std::cout << "Attempting to start server on " << uri.to_string() << std::endl;
g_httpServer = std::unique_ptr<Server>(new Server(uri, g_inMemorySearch, typestring));
std::cout << "Created a server object" << std::endl;
g_httpServer->open().wait();
ucout << U"Listening for requests on: " << address << std::endl;
}
void teardown(const utility::string_t &address)
{
g_httpServer->close().wait();
}
int main(int argc, char *argv[])
{
std::string data_type, index_file, data_file, address, dist_fn, tags_file;
uint32_t num_threads;
uint32_t l_search;
po::options_description desc{"Arguments"};
try
{
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
desc.add_options()("data_file", po::value<std::string>(&data_file)->required(),
"File containing the data found in the index");
desc.add_options()("index_path_prefix", po::value<std::string>(&index_file)->required(),
"Path prefix for saving index file components");
desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->required(),
"Number of threads used for building index");
desc.add_options()("l_search", po::value<uint32_t>(&l_search)->required(), "Value of L");
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
"distance function <l2/mips>");
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
"Tags file location");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << std::endl;
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("l2"))
metric = diskann::Metric::L2;
else if (dist_fn == std::string("mips"))
metric = diskann::Metric::INNER_PRODUCT;
else
{
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
return -1;
}
if (data_type == std::string("float"))
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::InMemorySearch<float>(data_file, index_file, tags_file, metric, num_threads, l_search));
g_inMemorySearch.push_back(std::move(searcher));
}
else if (data_type == std::string("int8"))
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::InMemorySearch<int8_t>(data_file, index_file, tags_file, metric, num_threads, l_search));
g_inMemorySearch.push_back(std::move(searcher));
}
else if (data_type == std::string("uint8"))
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::InMemorySearch<uint8_t>(data_file, index_file, tags_file, metric, num_threads, l_search));
g_inMemorySearch.push_back(std::move(searcher));
}
else
{
std::cerr << "Unsupported data type " << argv[2] << std::endl;
}
while (1)
{
try
{
setup(address, data_type);
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
std::string line;
std::getline(std::cin, line);
if (line == "exit")
{
teardown(address);
g_httpServer->close().wait();
exit(0);
}
}
catch (const std::exception &ex)
{
std::cerr << "Exception occurred: " << ex.what() << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
catch (...)
{
std::cerr << "Unknown exception occurreed" << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
}
}

View File

@@ -0,0 +1,83 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <restapi/server.h>
#include <restapi/in_memory_search.h>
#include <codecvt>
#include <iostream>
std::unique_ptr<Server> g_httpServer(nullptr);
std::unique_ptr<diskann::InMemorySearch> g_inMemorySearch(nullptr);
void setup(const utility::string_t &address)
{
web::http::uri_builder uriBldr(address);
auto uri = uriBldr.to_uri();
std::wcout << L"Attempting to start server on " << uri.to_string() << std::endl;
g_httpServer = std::unique_ptr<Server>(new Server(uri, g_inMemorySearch));
g_httpServer->open().wait();
ucout << U"Listening for requests on: " << address << std::endl;
}
void teardown(const utility::string_t &address)
{
g_httpServer->close().wait();
}
void loadIndex(const char *indexFile, const char *baseFile, const char *idsFile)
{
auto nsgSearch = new diskann::InMemorySearch(baseFile, indexFile, idsFile, diskann::L2);
g_inMemorySearch = std::unique_ptr<diskann::InMemorySearch>(nsgSearch);
}
std::wstring getHostingAddress(const char *hostNameAndPort)
{
wchar_t buffer[4096];
mbstowcs_s(nullptr, buffer, sizeof(buffer) / sizeof(buffer[0]), hostNameAndPort,
sizeof(buffer) / sizeof(buffer[0]));
return std::wstring(buffer);
}
int main(int argc, char *argv[])
{
if (argc != 5)
{
std::cout << "Usage: nsg_server <ip_addr_and_port> <index_file> "
"<base_file> <ids_file> "
<< std::endl;
exit(1);
}
auto address = getHostingAddress(argv[1]);
loadIndex(argv[2], argv[3], argv[4]);
while (1)
{
try
{
setup(address);
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
std::string line;
std::getline(std::cin, line);
if (line == "exit")
{
teardown(address);
exit(0);
}
}
catch (const std::exception &ex)
{
std::cerr << "Exception occurred: " << ex.what() << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
catch (...)
{
std::cerr << "Unknown exception occurreed" << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
}
}

View File

@@ -0,0 +1,182 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <ctime>
#include <functional>
#include <iomanip>
#include <string>
#include <cstdlib>
#include <codecvt>
#include <boost/program_options.hpp>
#include <omp.h>
#include <restapi/server.h>
using namespace diskann;
namespace po = boost::program_options;
std::unique_ptr<Server> g_httpServer(nullptr);
std::vector<std::unique_ptr<diskann::BaseSearch>> g_ssdSearch;
void setup(const utility::string_t &address, const std::string &typestring)
{
web::http::uri_builder uriBldr(address);
auto uri = uriBldr.to_uri();
std::cout << "Attempting to start server on " << uri.to_string() << std::endl;
g_httpServer = std::unique_ptr<Server>(new Server(uri, g_ssdSearch, typestring));
std::cout << "Created a server object" << std::endl;
g_httpServer->open().wait();
ucout << U"Listening for requests on: " << address << std::endl;
}
void teardown(const utility::string_t &address)
{
g_httpServer->close().wait();
}
int main(int argc, char *argv[])
{
std::string data_type, index_prefix_paths, address, dist_fn, tags_file;
uint32_t num_nodes_to_cache;
uint32_t num_threads;
po::options_description desc{"Arguments"};
try
{
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("index_prefix_paths", po::value<std::string>(&index_prefix_paths)->required(),
"Path prefix for loading index file components");
desc.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
"Number of nodes to cache during search");
desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
"Number of threads used for building index (defaults to "
"omp_get_num_procs())");
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
"distance function <l2/mips>");
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
"Tags file location");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << std::endl;
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("l2"))
metric = diskann::Metric::L2;
else if (dist_fn == std::string("mips"))
metric = diskann::Metric::INNER_PRODUCT;
else
{
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
return -1;
}
std::vector<std::pair<std::string, std::string>> index_tag_paths;
std::ifstream index_in(index_prefix_paths);
if (!index_in.is_open())
{
std::cerr << "Could not open " << index_prefix_paths << std::endl;
exit(-1);
}
std::ifstream tags_in(tags_file);
if (!tags_in.is_open())
{
std::cerr << "Could not open " << tags_file << std::endl;
exit(-1);
}
std::string prefix, tagfile;
while (std::getline(index_in, prefix))
{
if (std::getline(tags_in, tagfile))
{
index_tag_paths.push_back(std::make_pair(prefix, tagfile));
}
else
{
std::cerr << "The number of tags specified does not match the number of "
"indices specified"
<< std::endl;
exit(-1);
}
}
index_in.close();
tags_in.close();
if (data_type == std::string("float"))
{
for (auto &index_tag : index_tag_paths)
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<float>(
index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
g_ssdSearch.push_back(std::move(searcher));
}
}
else if (data_type == std::string("int8"))
{
for (auto &index_tag : index_tag_paths)
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<int8_t>(
index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
g_ssdSearch.push_back(std::move(searcher));
}
}
else if (data_type == std::string("uint8"))
{
for (auto &index_tag : index_tag_paths)
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<uint8_t>(
index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
g_ssdSearch.push_back(std::move(searcher));
}
}
else
{
std::cerr << "Unsupported data type " << data_type << std::endl;
exit(-1);
}
while (1)
{
try
{
setup(address, data_type);
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
std::string line;
std::getline(std::cin, line);
if (line == "exit")
{
teardown(address);
g_httpServer->close().wait();
exit(0);
}
}
catch (const std::exception &ex)
{
std::cerr << "Exception occurred: " << ex.what() << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
catch (...)
{
std::cerr << "Unknown exception occurreed" << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
}
}

View File

@@ -0,0 +1,141 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <ctime>
#include <functional>
#include <iomanip>
#include <string>
#include <cstdlib>
#include <codecvt>
#include <boost/program_options.hpp>
#include <omp.h>
#include <restapi/server.h>
using namespace diskann;
namespace po = boost::program_options;
std::unique_ptr<Server> g_httpServer(nullptr);
std::vector<std::unique_ptr<diskann::BaseSearch>> g_ssdSearch;
void setup(const utility::string_t &address, const std::string &typestring)
{
web::http::uri_builder uriBldr(address);
auto uri = uriBldr.to_uri();
std::cout << "Attempting to start server on " << uri.to_string() << std::endl;
g_httpServer = std::unique_ptr<Server>(new Server(uri, g_ssdSearch, typestring));
std::cout << "Created a server object" << std::endl;
g_httpServer->open().wait();
ucout << U"Listening for requests on: " << address << std::endl;
}
void teardown(const utility::string_t &address)
{
g_httpServer->close().wait();
}
int main(int argc, char *argv[])
{
std::string data_type, index_path_prefix, address, dist_fn, tags_file;
uint32_t num_nodes_to_cache;
uint32_t num_threads;
po::options_description desc{"Arguments"};
try
{
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
desc.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
"Path prefix for loading index file components");
desc.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
"Number of nodes to cache during search");
desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
"Number of threads used for building index (defaults to "
"omp_get_num_procs())");
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
"distance function <l2/mips>");
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
"Tags file location");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << std::endl;
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("l2"))
metric = diskann::Metric::L2;
else if (dist_fn == std::string("mips"))
metric = diskann::Metric::INNER_PRODUCT;
else
{
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
return -1;
}
if (data_type == std::string("float"))
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::PQFlashSearch<float>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
g_ssdSearch.push_back(std::move(searcher));
}
else if (data_type == std::string("int8"))
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::PQFlashSearch<int8_t>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
g_ssdSearch.push_back(std::move(searcher));
}
else if (data_type == std::string("uint8"))
{
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::PQFlashSearch<uint8_t>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
g_ssdSearch.push_back(std::move(searcher));
}
else
{
std::cerr << "Unsupported data type " << argv[2] << std::endl;
exit(-1);
}
while (1)
{
try
{
setup(address, data_type);
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
std::string line;
std::getline(std::cin, line);
if (line == "exit")
{
teardown(address);
g_httpServer->close().wait();
exit(0);
}
}
catch (const std::exception &ex)
{
std::cerr << "Exception occurred: " << ex.what() << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
catch (...)
{
std::cerr << "Unknown exception occurreed" << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
}
}

View File

@@ -0,0 +1,499 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "common_includes.h"
#include <boost/program_options.hpp>
#include "index.h"
#include "disk_utils.h"
#include "math_utils.h"
#include "memory_mapper.h"
#include "partition.h"
#include "pq_flash_index.h"
#include "timer.h"
#include "percentile_stats.h"
#include "program_options_utils.hpp"
#ifndef _WINDOWS
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#include "linux_aligned_file_reader.h"
#else
#ifdef USE_BING_INFRA
#include "bing_aligned_file_reader.h"
#else
#include "windows_aligned_file_reader.h"
#endif
#endif
#define WARMUP false
namespace po = boost::program_options;
void print_stats(std::string category, std::vector<float> percentiles, std::vector<float> results)
{
diskann::cout << std::setw(20) << category << ": " << std::flush;
for (uint32_t s = 0; s < percentiles.size(); s++)
{
diskann::cout << std::setw(8) << percentiles[s] << "%";
}
diskann::cout << std::endl;
diskann::cout << std::setw(22) << " " << std::flush;
for (uint32_t s = 0; s < percentiles.size(); s++)
{
diskann::cout << std::setw(9) << results[s];
}
diskann::cout << std::endl;
}
template <typename T, typename LabelT = uint32_t>
int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix,
const std::string &result_output_prefix, const std::string &query_file, std::string &gt_file,
const uint32_t num_threads, const uint32_t recall_at, const uint32_t beamwidth,
const uint32_t num_nodes_to_cache, const uint32_t search_io_limit,
const std::vector<uint32_t> &Lvec, const float fail_if_recall_below,
const std::vector<std::string> &query_filters, const bool use_reorder_data = false)
{
diskann::cout << "Search parameters: #threads: " << num_threads << ", ";
if (beamwidth <= 0)
diskann::cout << "beamwidth to be optimized for each L value" << std::flush;
else
diskann::cout << " beamwidth: " << beamwidth << std::flush;
if (search_io_limit == std::numeric_limits<uint32_t>::max())
diskann::cout << "." << std::endl;
else
diskann::cout << ", io_limit: " << search_io_limit << "." << std::endl;
std::string warmup_query_file = index_path_prefix + "_sample_data.bin";
// load query bin
T *query = nullptr;
uint32_t *gt_ids = nullptr;
float *gt_dists = nullptr;
size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim;
diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);
bool filtered_search = false;
if (!query_filters.empty())
{
filtered_search = true;
if (query_filters.size() != 1 && query_filters.size() != query_num)
{
std::cout << "Error. Mismatch in number of queries and size of query "
"filters file"
<< std::endl;
return -1; // To return -1 or some other error handling?
}
}
bool calc_recall_flag = false;
if (gt_file != std::string("null") && gt_file != std::string("NULL") && file_exists(gt_file))
{
diskann::load_truthset(gt_file, gt_ids, gt_dists, gt_num, gt_dim);
if (gt_num != query_num)
{
diskann::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl;
}
calc_recall_flag = true;
}
std::shared_ptr<AlignedFileReader> reader = nullptr;
#ifdef _WINDOWS
#ifndef USE_BING_INFRA
reader.reset(new WindowsAlignedFileReader());
#else
reader.reset(new diskann::BingAlignedFileReader());
#endif
#else
reader.reset(new LinuxAlignedFileReader());
#endif
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> _pFlashIndex(
new diskann::PQFlashIndex<T, LabelT>(reader, metric));
int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str());
if (res != 0)
{
return res;
}
std::vector<uint32_t> node_list;
diskann::cout << "Caching " << num_nodes_to_cache << " nodes around medoid(s)" << std::endl;
_pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list);
// if (num_nodes_to_cache > 0)
// _pFlashIndex->generate_cache_list_from_sample_queries(warmup_query_file, 15, 6, num_nodes_to_cache,
// num_threads, node_list);
_pFlashIndex->load_cache_list(node_list);
node_list.clear();
node_list.shrink_to_fit();
omp_set_num_threads(num_threads);
uint64_t warmup_L = 20;
uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0;
T *warmup = nullptr;
if (WARMUP)
{
if (file_exists(warmup_query_file))
{
diskann::load_aligned_bin<T>(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim);
}
else
{
warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads);
warmup_dim = query_dim;
warmup_aligned_dim = query_aligned_dim;
diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T));
std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T));
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(-128, 127);
for (uint32_t i = 0; i < warmup_num; i++)
{
for (uint32_t d = 0; d < warmup_dim; d++)
{
warmup[i * warmup_aligned_dim + d] = (T)dis(gen);
}
}
}
diskann::cout << "Warming up index... " << std::flush;
std::vector<uint64_t> warmup_result_ids_64(warmup_num, 0);
std::vector<float> warmup_result_dists(warmup_num, 0);
#pragma omp parallel for schedule(dynamic, 1)
for (int64_t i = 0; i < (int64_t)warmup_num; i++)
{
_pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L,
warmup_result_ids_64.data() + (i * 1),
warmup_result_dists.data() + (i * 1), 4);
}
diskann::cout << "..done" << std::endl;
}
diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
diskann::cout.precision(2);
std::string recall_string = "Recall@" + std::to_string(recall_at);
diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" << std::setw(16) << "QPS" << std::setw(16)
<< "Mean Latency" << std::setw(16) << "99.9 Latency" << std::setw(16) << "Mean IOs" << std::setw(16)
<< "Mean IO (us)" << std::setw(16) << "CPU (s)";
if (calc_recall_flag)
{
diskann::cout << std::setw(16) << recall_string << std::endl;
}
else
diskann::cout << std::endl;
diskann::cout << "=================================================================="
"================================================================="
<< std::endl;
std::vector<std::vector<uint32_t>> query_result_ids(Lvec.size());
std::vector<std::vector<float>> query_result_dists(Lvec.size());
uint32_t optimized_beamwidth = 2;
double best_recall = 0.0;
for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++)
{
uint32_t L = Lvec[test_id];
if (L < recall_at)
{
diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl;
continue;
}
if (beamwidth <= 0)
{
diskann::cout << "Tuning beamwidth.." << std::endl;
optimized_beamwidth =
optimize_beamwidth(_pFlashIndex, warmup, warmup_num, warmup_aligned_dim, L, optimized_beamwidth);
}
else
optimized_beamwidth = beamwidth;
query_result_ids[test_id].resize(recall_at * query_num);
query_result_dists[test_id].resize(recall_at * query_num);
auto stats = new diskann::QueryStats[query_num];
std::vector<uint64_t> query_result_ids_64(recall_at * query_num);
auto s = std::chrono::high_resolution_clock::now();
#pragma omp parallel for schedule(dynamic, 1)
for (int64_t i = 0; i < (int64_t)query_num; i++)
{
if (!filtered_search)
{
_pFlashIndex->cached_beam_search(query + (i * query_aligned_dim), recall_at, L,
query_result_ids_64.data() + (i * recall_at),
query_result_dists[test_id].data() + (i * recall_at),
optimized_beamwidth, use_reorder_data, stats + i);
}
else
{
LabelT label_for_search;
if (query_filters.size() == 1)
{ // one label for all queries
label_for_search = _pFlashIndex->get_converted_label(query_filters[0]);
}
else
{ // one label for each query
label_for_search = _pFlashIndex->get_converted_label(query_filters[i]);
}
_pFlashIndex->cached_beam_search(
query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at),
query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search,
use_reorder_data, stats + i);
}
}
auto e = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = e - s;
double qps = (1.0 * query_num) / (1.0 * diff.count());
diskann::convert_types<uint64_t, uint32_t>(query_result_ids_64.data(), query_result_ids[test_id].data(),
query_num, recall_at);
auto mean_latency = diskann::get_mean_stats<float>(
stats, query_num, [](const diskann::QueryStats &stats) { return stats.total_us; });
auto latency_999 = diskann::get_percentile_stats<float>(
stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; });
auto mean_ios = diskann::get_mean_stats<uint32_t>(stats, query_num,
[](const diskann::QueryStats &stats) { return stats.n_ios; });
auto mean_cpuus = diskann::get_mean_stats<float>(stats, query_num,
[](const diskann::QueryStats &stats) { return stats.cpu_us; });
auto mean_io_us = diskann::get_mean_stats<float>(stats, query_num,
[](const diskann::QueryStats &stats) { return stats.io_us; });
double recall = 0;
if (calc_recall_flag)
{
recall = diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim,
query_result_ids[test_id].data(), recall_at, recall_at);
best_recall = std::max(recall, best_recall);
}
diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps
<< std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios
<< std::setw(16) << mean_io_us << std::setw(16) << mean_cpuus;
if (calc_recall_flag)
{
diskann::cout << std::setw(16) << recall << std::endl;
}
else
diskann::cout << std::endl;
delete[] stats;
}
diskann::cout << "Done searching. Now saving results " << std::endl;
uint64_t test_id = 0;
for (auto L : Lvec)
{
if (L < recall_at)
continue;
std::string cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_idx_uint32.bin";
diskann::save_bin<uint32_t>(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at);
cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_dists_float.bin";
diskann::save_bin<float>(cur_result_path, query_result_dists[test_id++].data(), query_num, recall_at);
}
diskann::aligned_free(query);
if (warmup != nullptr)
diskann::aligned_free(warmup);
return best_recall >= fail_if_recall_below ? 0 : -1;
}
int main(int argc, char **argv)
{
std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file, filter_label,
label_type, query_filters_file;
uint32_t num_threads, K, W, num_nodes_to_cache, search_io_limit;
std::vector<uint32_t> Lvec;
bool use_reorder_data = false;
float fail_if_recall_below = 0.0f;
po::options_description desc{
program_options_utils::make_program_description("search_disk_index", "Searches on-disk DiskANN indexes")};
try
{
desc.add_options()("help,h", "Print information on arguments");
// Required parameters
po::options_description required_configs("Required");
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
program_options_utils::DATA_TYPE_DESCRIPTION);
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
required_configs.add_options()("result_path", po::value<std::string>(&result_path_prefix)->required(),
program_options_utils::RESULT_PATH_DESCRIPTION);
required_configs.add_options()("query_file", po::value<std::string>(&query_file)->required(),
program_options_utils::QUERY_FILE_DESCRIPTION);
required_configs.add_options()("recall_at,K", po::value<uint32_t>(&K)->required(),
program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION);
required_configs.add_options()("search_list,L",
po::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
program_options_utils::SEARCH_LIST_DESCRIPTION);
// Optional parameters
po::options_description optional_configs("Optional");
optional_configs.add_options()("gt_file", po::value<std::string>(&gt_file)->default_value(std::string("null")),
program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION);
optional_configs.add_options()("beamwidth,W", po::value<uint32_t>(&W)->default_value(2),
program_options_utils::BEAMWIDTH);
optional_configs.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
program_options_utils::NUMBER_OF_NODES_TO_CACHE);
optional_configs.add_options()(
"search_io_limit",
po::value<uint32_t>(&search_io_limit)->default_value(std::numeric_limits<uint32_t>::max()),
"Max #IOs for search. Default value: uint32::max()");
optional_configs.add_options()("num_threads,T",
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
program_options_utils::NUMBER_THREADS_DESCRIPTION);
optional_configs.add_options()("use_reorder_data", po::bool_switch()->default_value(false),
"Include full precision data in the index. Use only in "
"conjuction with compressed data on SSD. Default value: false");
optional_configs.add_options()("filter_label",
po::value<std::string>(&filter_label)->default_value(std::string("")),
program_options_utils::FILTER_LABEL_DESCRIPTION);
optional_configs.add_options()("query_filters_file",
po::value<std::string>(&query_filters_file)->default_value(std::string("")),
program_options_utils::FILTERS_FILE_DESCRIPTION);
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
program_options_utils::LABEL_TYPE_DESCRIPTION);
optional_configs.add_options()("fail_if_recall_below",
po::value<float>(&fail_if_recall_below)->default_value(0.0f),
program_options_utils::FAIL_IF_RECALL_BELOW);
// Merge required and optional parameters
desc.add(required_configs).add(optional_configs);
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
if (vm["use_reorder_data"].as<bool>())
use_reorder_data = true;
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("mips"))
{
metric = diskann::Metric::INNER_PRODUCT;
}
else if (dist_fn == std::string("l2"))
{
metric = diskann::Metric::L2;
}
else if (dist_fn == std::string("cosine"))
{
metric = diskann::Metric::COSINE;
}
else
{
std::cout << "Unsupported distance function. Currently only L2/ Inner "
"Product/Cosine are supported."
<< std::endl;
return -1;
}
if ((data_type != std::string("float")) && (metric == diskann::Metric::INNER_PRODUCT))
{
std::cout << "Currently support only floating point data for Inner Product." << std::endl;
return -1;
}
if (use_reorder_data && data_type != std::string("float"))
{
std::cout << "Error: Reorder data for reordering currently only "
"supported for float data type."
<< std::endl;
return -1;
}
if (filter_label != "" && query_filters_file != "")
{
std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl;
return -1;
}
std::vector<std::string> query_filters;
if (filter_label != "")
{
query_filters.push_back(filter_label);
}
else if (query_filters_file != "")
{
query_filters = read_file_to_vector_of_strings(query_filters_file);
}
try
{
if (!query_filters.empty() && label_type == "ushort")
{
if (data_type == std::string("float"))
return search_disk_index<float, uint16_t>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
else if (data_type == std::string("int8"))
return search_disk_index<int8_t, uint16_t>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
else if (data_type == std::string("uint8"))
return search_disk_index<uint8_t, uint16_t>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
else
{
std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
return -1;
}
}
else
{
if (data_type == std::string("float"))
return search_disk_index<float>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
fail_if_recall_below, query_filters, use_reorder_data);
else if (data_type == std::string("int8"))
return search_disk_index<int8_t>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
fail_if_recall_below, query_filters, use_reorder_data);
else if (data_type == std::string("uint8"))
return search_disk_index<uint8_t>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
fail_if_recall_below, query_filters, use_reorder_data);
else
{
std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
return -1;
}
}
}
catch (const std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Index search failed." << std::endl;
return -1;
}
}

View File

@@ -0,0 +1,477 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <cstring>
#include <iomanip>
#include <algorithm>
#include <numeric>
#include <omp.h>
#include <set>
#include <string.h>
#include <boost/program_options.hpp>
#ifndef _WINDOWS
#include <sys/mman.h>
#include <sys/stat.h>
#include <time.h>
#include <unistd.h>
#endif
#include "index.h"
#include "memory_mapper.h"
#include "utils.h"
#include "program_options_utils.hpp"
#include "index_factory.h"
namespace po = boost::program_options;
template <typename T, typename LabelT = uint32_t>
int search_memory_index(diskann::Metric &metric, const std::string &index_path, const std::string &result_path_prefix,
const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads,
const uint32_t recall_at, const bool print_all_recalls, const std::vector<uint32_t> &Lvec,
const bool dynamic, const bool tags, const bool show_qps_per_thread,
const std::vector<std::string> &query_filters, const float fail_if_recall_below)
{
using TagT = uint32_t;
// Load the query file
T *query = nullptr;
uint32_t *gt_ids = nullptr;
float *gt_dists = nullptr;
size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim;
diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);
bool calc_recall_flag = false;
if (truthset_file != std::string("null") && file_exists(truthset_file))
{
diskann::load_truthset(truthset_file, gt_ids, gt_dists, gt_num, gt_dim);
if (gt_num != query_num)
{
std::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl;
}
calc_recall_flag = true;
}
else
{
diskann::cout << " Truthset file " << truthset_file << " not found. Not computing recall." << std::endl;
}
bool filtered_search = false;
if (!query_filters.empty())
{
filtered_search = true;
if (query_filters.size() != 1 && query_filters.size() != query_num)
{
std::cout << "Error. Mismatch in number of queries and size of query "
"filters file"
<< std::endl;
return -1; // To return -1 or some other error handling?
}
}
const size_t num_frozen_pts = diskann::get_graph_num_frozen_points(index_path);
auto config = diskann::IndexConfigBuilder()
.with_metric(metric)
.with_dimension(query_dim)
.with_max_points(0)
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
.with_data_type(diskann_type_to_name<T>())
.with_label_type(diskann_type_to_name<LabelT>())
.with_tag_type(diskann_type_to_name<TagT>())
.is_dynamic_index(dynamic)
.is_enable_tags(tags)
.is_concurrent_consolidate(false)
.is_pq_dist_build(false)
.is_use_opq(false)
.with_num_pq_chunks(0)
.with_num_frozen_pts(num_frozen_pts)
.build();
auto index_factory = diskann::IndexFactory(config);
auto index = index_factory.create_instance();
index->load(index_path.c_str(), num_threads, *(std::max_element(Lvec.begin(), Lvec.end())));
std::cout << "Index loaded" << std::endl;
if (metric == diskann::FAST_L2)
index->optimize_index_layout();
std::cout << "Using " << num_threads << " threads to search" << std::endl;
std::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
std::cout.precision(2);
const std::string qps_title = show_qps_per_thread ? "QPS/thread" : "QPS";
uint32_t table_width = 0;
if (tags)
{
std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title << std::setw(20) << "Mean Latency (mus)"
<< std::setw(15) << "99.9 Latency";
table_width += 4 + 12 + 20 + 15;
}
else
{
std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title << std::setw(18) << "Avg dist cmps"
<< std::setw(20) << "Mean Latency (mus)" << std::setw(15) << "99.9 Latency";
table_width += 4 + 12 + 18 + 20 + 15;
}
uint32_t recalls_to_print = 0;
const uint32_t first_recall = print_all_recalls ? 1 : recall_at;
if (calc_recall_flag)
{
for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++)
{
std::cout << std::setw(12) << ("Recall@" + std::to_string(curr_recall));
}
recalls_to_print = recall_at + 1 - first_recall;
table_width += recalls_to_print * 12;
}
std::cout << std::endl;
std::cout << std::string(table_width, '=') << std::endl;
std::vector<std::vector<uint32_t>> query_result_ids(Lvec.size());
std::vector<std::vector<float>> query_result_dists(Lvec.size());
std::vector<float> latency_stats(query_num, 0);
std::vector<uint32_t> cmp_stats;
if (not tags || filtered_search)
{
cmp_stats = std::vector<uint32_t>(query_num, 0);
}
std::vector<TagT> query_result_tags;
if (tags)
{
query_result_tags.resize(recall_at * query_num);
}
double best_recall = 0.0;
for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++)
{
uint32_t L = Lvec[test_id];
if (L < recall_at)
{
diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl;
continue;
}
query_result_ids[test_id].resize(recall_at * query_num);
query_result_dists[test_id].resize(recall_at * query_num);
std::vector<T *> res = std::vector<T *>();
auto s = std::chrono::high_resolution_clock::now();
omp_set_num_threads(num_threads);
#pragma omp parallel for schedule(dynamic, 1)
for (int64_t i = 0; i < (int64_t)query_num; i++)
{
auto qs = std::chrono::high_resolution_clock::now();
if (filtered_search && !tags)
{
std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];
auto retval = index->search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L,
query_result_ids[test_id].data() + i * recall_at,
query_result_dists[test_id].data() + i * recall_at);
cmp_stats[i] = retval.second;
}
else if (metric == diskann::FAST_L2)
{
index->search_with_optimized_layout(query + i * query_aligned_dim, recall_at, L,
query_result_ids[test_id].data() + i * recall_at);
}
else if (tags)
{
if (!filtered_search)
{
index->search_with_tags(query + i * query_aligned_dim, recall_at, L,
query_result_tags.data() + i * recall_at, nullptr, res);
}
else
{
std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];
index->search_with_tags(query + i * query_aligned_dim, recall_at, L,
query_result_tags.data() + i * recall_at, nullptr, res, true, raw_filter);
}
for (int64_t r = 0; r < (int64_t)recall_at; r++)
{
query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r];
}
}
else
{
cmp_stats[i] = index
->search(query + i * query_aligned_dim, recall_at, L,
query_result_ids[test_id].data() + i * recall_at)
.second;
}
auto qe = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = qe - qs;
latency_stats[i] = (float)(diff.count() * 1000000);
}
std::chrono::duration<double> diff = std::chrono::high_resolution_clock::now() - s;
double displayed_qps = query_num / diff.count();
if (show_qps_per_thread)
displayed_qps /= num_threads;
std::vector<double> recalls;
if (calc_recall_flag)
{
recalls.reserve(recalls_to_print);
for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++)
{
recalls.push_back(diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim,
query_result_ids[test_id].data(), recall_at, curr_recall));
}
}
std::sort(latency_stats.begin(), latency_stats.end());
double mean_latency =
std::accumulate(latency_stats.begin(), latency_stats.end(), 0.0) / static_cast<float>(query_num);
float avg_cmps = (float)std::accumulate(cmp_stats.begin(), cmp_stats.end(), 0) / (float)query_num;
if (tags && !filtered_search)
{
std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(20) << (float)mean_latency
<< std::setw(15) << (float)latency_stats[(uint64_t)(0.999 * query_num)];
}
else
{
std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(18) << avg_cmps
<< std::setw(20) << (float)mean_latency << std::setw(15)
<< (float)latency_stats[(uint64_t)(0.999 * query_num)];
}
for (double recall : recalls)
{
std::cout << std::setw(12) << recall;
best_recall = std::max(recall, best_recall);
}
std::cout << std::endl;
}
std::cout << "Done searching. Now saving results " << std::endl;
uint64_t test_id = 0;
for (auto L : Lvec)
{
if (L < recall_at)
{
diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl;
continue;
}
std::string cur_result_path_prefix = result_path_prefix + "_" + std::to_string(L);
std::string cur_result_path = cur_result_path_prefix + "_idx_uint32.bin";
diskann::save_bin<uint32_t>(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at);
cur_result_path = cur_result_path_prefix + "_dists_float.bin";
diskann::save_bin<float>(cur_result_path, query_result_dists[test_id].data(), query_num, recall_at);
test_id++;
}
diskann::aligned_free(query);
return best_recall >= fail_if_recall_below ? 0 : -1;
}
int main(int argc, char **argv)
{
std::string data_type, dist_fn, index_path_prefix, result_path, query_file, gt_file, filter_label, label_type,
query_filters_file;
uint32_t num_threads, K;
std::vector<uint32_t> Lvec;
bool print_all_recalls, dynamic, tags, show_qps_per_thread;
float fail_if_recall_below = 0.0f;
po::options_description desc{
program_options_utils::make_program_description("search_memory_index", "Searches in-memory DiskANN indexes")};
try
{
desc.add_options()("help,h", "Print this information on arguments");
// Required parameters
po::options_description required_configs("Required");
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
program_options_utils::DATA_TYPE_DESCRIPTION);
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
required_configs.add_options()("result_path", po::value<std::string>(&result_path)->required(),
program_options_utils::RESULT_PATH_DESCRIPTION);
required_configs.add_options()("query_file", po::value<std::string>(&query_file)->required(),
program_options_utils::QUERY_FILE_DESCRIPTION);
required_configs.add_options()("recall_at,K", po::value<uint32_t>(&K)->required(),
program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION);
required_configs.add_options()("search_list,L",
po::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
program_options_utils::SEARCH_LIST_DESCRIPTION);
// Optional parameters
po::options_description optional_configs("Optional");
optional_configs.add_options()("filter_label",
po::value<std::string>(&filter_label)->default_value(std::string("")),
program_options_utils::FILTER_LABEL_DESCRIPTION);
optional_configs.add_options()("query_filters_file",
po::value<std::string>(&query_filters_file)->default_value(std::string("")),
program_options_utils::FILTERS_FILE_DESCRIPTION);
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
program_options_utils::LABEL_TYPE_DESCRIPTION);
optional_configs.add_options()("gt_file", po::value<std::string>(&gt_file)->default_value(std::string("null")),
program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION);
optional_configs.add_options()("num_threads,T",
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
program_options_utils::NUMBER_THREADS_DESCRIPTION);
optional_configs.add_options()(
"dynamic", po::value<bool>(&dynamic)->default_value(false),
"Whether the index is dynamic. Dynamic indices must have associated tags. Default false.");
optional_configs.add_options()("tags", po::value<bool>(&tags)->default_value(false),
"Whether to search with external identifiers (tags). Default false.");
optional_configs.add_options()("fail_if_recall_below",
po::value<float>(&fail_if_recall_below)->default_value(0.0f),
program_options_utils::FAIL_IF_RECALL_BELOW);
// Output controls
po::options_description output_controls("Output controls");
output_controls.add_options()("print_all_recalls", po::bool_switch(&print_all_recalls),
"Print recalls at all positions, from 1 up to specified "
"recall_at value");
output_controls.add_options()("print_qps_per_thread", po::bool_switch(&show_qps_per_thread),
"Print overall QPS divided by the number of threads in "
"the output table");
// Merge required and optional parameters
desc.add(required_configs).add(optional_configs).add(output_controls);
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
diskann::Metric metric;
if ((dist_fn == std::string("mips")) && (data_type == std::string("float")))
{
metric = diskann::Metric::INNER_PRODUCT;
}
else if (dist_fn == std::string("l2"))
{
metric = diskann::Metric::L2;
}
else if (dist_fn == std::string("cosine"))
{
metric = diskann::Metric::COSINE;
}
else if ((dist_fn == std::string("fast_l2")) && (data_type == std::string("float")))
{
metric = diskann::Metric::FAST_L2;
}
else
{
std::cout << "Unsupported distance function. Currently only l2/ cosine are "
"supported in general, and mips/fast_l2 only for floating "
"point data."
<< std::endl;
return -1;
}
if (dynamic && not tags)
{
std::cerr << "Tags must be enabled while searching dynamically built indices" << std::endl;
return -1;
}
if (fail_if_recall_below < 0.0 || fail_if_recall_below >= 100.0)
{
std::cerr << "fail_if_recall_below parameter must be between 0 and 100%" << std::endl;
return -1;
}
if (filter_label != "" && query_filters_file != "")
{
std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl;
return -1;
}
std::vector<std::string> query_filters;
if (filter_label != "")
{
query_filters.push_back(filter_label);
}
else if (query_filters_file != "")
{
query_filters = read_file_to_vector_of_strings(query_filters_file);
}
try
{
if (!query_filters.empty() && label_type == "ushort")
{
if (data_type == std::string("int8"))
{
return search_memory_index<int8_t, uint16_t>(
metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("uint8"))
{
return search_memory_index<uint8_t, uint16_t>(
metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("float"))
{
return search_memory_index<float, uint16_t>(metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
else
{
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
return -1;
}
}
else
{
if (data_type == std::string("int8"))
{
return search_memory_index<int8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("uint8"))
{
return search_memory_index<uint8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("float"))
{
return search_memory_index<float>(metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
else
{
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
return -1;
}
}
}
catch (std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Index search failed." << std::endl;
return -1;
}
}

View File

@@ -0,0 +1,536 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <index.h>
#include <numeric>
#include <omp.h>
#include <string.h>
#include <time.h>
#include <timer.h>
#include <boost/program_options.hpp>
#include <future>
#include "utils.h"
#include "filter_utils.h"
#include "program_options_utils.hpp"
#include "index_factory.h"
#ifndef _WINDOWS
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#endif
#include "memory_mapper.h"
namespace po = boost::program_options;
// load_aligned_bin modified to read pieces of the file, but using ifstream
// instead of cached_ifstream.
template <typename T>
inline void load_aligned_bin_part(const std::string &bin_file, T *data, size_t offset_points, size_t points_to_read)
{
diskann::Timer timer;
std::ifstream reader;
reader.exceptions(std::ios::failbit | std::ios::badbit);
reader.open(bin_file, std::ios::binary | std::ios::ate);
size_t actual_file_size = reader.tellg();
reader.seekg(0, std::ios::beg);
int npts_i32, dim_i32;
reader.read((char *)&npts_i32, sizeof(int));
reader.read((char *)&dim_i32, sizeof(int));
size_t npts = (uint32_t)npts_i32;
size_t dim = (uint32_t)dim_i32;
size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t);
if (actual_file_size != expected_actual_file_size)
{
std::stringstream stream;
stream << "Error. File size mismatch. Actual size is " << actual_file_size << " while expected size is "
<< expected_actual_file_size << " npts = " << npts << " dim = " << dim << " size of <T>= " << sizeof(T)
<< std::endl;
std::cout << stream.str();
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}
if (offset_points + points_to_read > npts)
{
std::stringstream stream;
stream << "Error. Not enough points in file. Requested " << offset_points << " offset and " << points_to_read
<< " points, but have only " << npts << " points" << std::endl;
std::cout << stream.str();
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}
reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T));
const size_t rounded_dim = ROUND_UP(dim, 8);
for (size_t i = 0; i < points_to_read; i++)
{
reader.read((char *)(data + i * rounded_dim), dim * sizeof(T));
memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T));
}
reader.close();
const double elapsedSeconds = timer.elapsed() / 1000000.0;
std::cout << "Read " << points_to_read << " points using non-cached reads in " << elapsedSeconds << std::endl;
}
std::string get_save_filename(const std::string &save_path, size_t points_to_skip, size_t points_deleted,
size_t last_point_threshold)
{
std::string final_path = save_path;
if (points_to_skip > 0)
{
final_path += "skip" + std::to_string(points_to_skip) + "-";
}
final_path += "del" + std::to_string(points_deleted) + "-";
final_path += std::to_string(last_point_threshold);
return final_path;
}
template <typename T, typename TagT, typename LabelT>
void insert_till_next_checkpoint(diskann::AbstractIndex &index, size_t start, size_t end, int32_t thread_count, T *data,
size_t aligned_dim, std::vector<std::vector<LabelT>> &location_to_labels)
{
diskann::Timer insert_timer;
#pragma omp parallel for num_threads(thread_count) schedule(dynamic)
for (int64_t j = start; j < (int64_t)end; j++)
{
if (!location_to_labels.empty())
{
index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j),
location_to_labels[j - start]);
}
else
{
index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j));
}
}
const double elapsedSeconds = insert_timer.elapsed() / 1000000.0;
std::cout << "Insertion time " << elapsedSeconds << " seconds (" << (end - start) / elapsedSeconds
<< " points/second overall, " << (end - start) / elapsedSeconds / thread_count << " per thread)\n ";
}
template <typename T, typename TagT>
void delete_from_beginning(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params,
size_t points_to_skip, size_t points_to_delete_from_beginning)
{
try
{
std::cout << std::endl
<< "Lazy deleting points " << points_to_skip << " to "
<< points_to_skip + points_to_delete_from_beginning << "... ";
for (size_t i = points_to_skip; i < points_to_skip + points_to_delete_from_beginning; ++i)
index.lazy_delete(static_cast<TagT>(i + 1)); // Since tags are data location + 1
std::cout << "done." << std::endl;
auto report = index.consolidate_deletes(delete_params);
std::cout << "#active points: " << report._active_points << std::endl
<< "max points: " << report._max_points << std::endl
<< "empty slots: " << report._empty_slots << std::endl
<< "deletes processed: " << report._slots_released << std::endl
<< "latest delete size: " << report._delete_set_size << std::endl
<< "rate: (" << points_to_delete_from_beginning / report._time << " points/second overall, "
<< points_to_delete_from_beginning / report._time / delete_params.num_threads << " per thread)"
<< std::endl;
}
catch (std::system_error &e)
{
std::cout << "Exception caught in deletion thread: " << e.what() << std::endl;
}
}
template <typename T>
void build_incremental_index(const std::string &data_path, diskann::IndexWriteParameters &params, size_t points_to_skip,
size_t max_points_to_insert, size_t beginning_index_size, float start_point_norm,
uint32_t num_start_pts, size_t points_per_checkpoint, size_t checkpoints_per_snapshot,
const std::string &save_path, size_t points_to_delete_from_beginning,
size_t start_deletes_after, bool concurrent, const std::string &label_file,
const std::string &universal_label)
{
size_t dim, aligned_dim;
size_t num_points;
diskann::get_bin_metadata(data_path, num_points, dim);
aligned_dim = ROUND_UP(dim, 8);
bool has_labels = label_file != "";
using TagT = uint32_t;
using LabelT = uint32_t;
size_t current_point_offset = points_to_skip;
const size_t last_point_threshold = points_to_skip + max_points_to_insert;
bool enable_tags = true;
using TagT = uint32_t;
auto index_search_params = diskann::IndexSearchParams(params.search_list_size, params.num_threads);
diskann::IndexConfig index_config = diskann::IndexConfigBuilder()
.with_metric(diskann::L2)
.with_dimension(dim)
.with_max_points(max_points_to_insert)
.is_dynamic_index(true)
.with_index_write_params(params)
.with_index_search_params(index_search_params)
.with_data_type(diskann_type_to_name<T>())
.with_tag_type(diskann_type_to_name<TagT>())
.with_label_type(diskann_type_to_name<LabelT>())
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
.is_enable_tags(enable_tags)
.is_filtered(has_labels)
.with_num_frozen_pts(num_start_pts)
.is_concurrent_consolidate(concurrent)
.build();
diskann::IndexFactory index_factory = diskann::IndexFactory(index_config);
auto index = index_factory.create_instance();
if (universal_label != "")
{
LabelT u_label = 0;
index->set_universal_label(u_label);
}
if (points_to_skip > num_points)
{
throw diskann::ANNException("Asked to skip more points than in data file", -1, __FUNCSIG__, __FILE__, __LINE__);
}
if (max_points_to_insert == 0)
{
max_points_to_insert = num_points;
}
if (points_to_skip + max_points_to_insert > num_points)
{
max_points_to_insert = num_points - points_to_skip;
std::cerr << "WARNING: Reducing max_points_to_insert to " << max_points_to_insert
<< " points since the data file has only that many" << std::endl;
}
if (beginning_index_size > max_points_to_insert)
{
beginning_index_size = max_points_to_insert;
std::cerr << "WARNING: Reducing beginning index size to " << beginning_index_size
<< " points since the data file has only that many" << std::endl;
}
if (checkpoints_per_snapshot > 0 && beginning_index_size > points_per_checkpoint)
{
beginning_index_size = points_per_checkpoint;
std::cerr << "WARNING: Reducing beginning index size to " << beginning_index_size << std::endl;
}
T *data = nullptr;
diskann::alloc_aligned(
(void **)&data, std::max(points_per_checkpoint, beginning_index_size) * aligned_dim * sizeof(T), 8 * sizeof(T));
std::vector<TagT> tags(beginning_index_size);
std::iota(tags.begin(), tags.end(), 1 + static_cast<TagT>(current_point_offset));
load_aligned_bin_part(data_path, data, current_point_offset, beginning_index_size);
std::cout << "load aligned bin succeeded" << std::endl;
diskann::Timer timer;
if (beginning_index_size > 0)
{
index->build(data, beginning_index_size, tags);
}
else
{
index->set_start_points_at_random(static_cast<T>(start_point_norm));
}
const double elapsedSeconds = timer.elapsed() / 1000000.0;
std::cout << "Initial non-incremental index build time for " << beginning_index_size << " points took "
<< elapsedSeconds << " seconds (" << beginning_index_size / elapsedSeconds << " points/second)\n ";
current_point_offset += beginning_index_size;
if (points_to_delete_from_beginning > max_points_to_insert)
{
points_to_delete_from_beginning = static_cast<uint32_t>(max_points_to_insert);
std::cerr << "WARNING: Reducing points to delete from beginning to " << points_to_delete_from_beginning
<< " points since the data file has only that many" << std::endl;
}
std::vector<std::vector<LabelT>> location_to_labels;
if (concurrent)
{
// handle labels
const auto save_path_inc = get_save_filename(save_path + ".after-concurrent-delete-", points_to_skip,
points_to_delete_from_beginning, last_point_threshold);
std::string labels_file_to_use = save_path_inc + "_label_formatted.txt";
std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt";
if (has_labels)
{
convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
auto parse_result = diskann::parse_formatted_label_file<LabelT>(labels_file_to_use);
location_to_labels = std::get<0>(parse_result);
}
int32_t sub_threads = (params.num_threads + 1) / 2;
bool delete_launched = false;
std::future<void> delete_task;
diskann::Timer timer;
for (size_t start = current_point_offset; start < last_point_threshold;
start += points_per_checkpoint, current_point_offset += points_per_checkpoint)
{
const size_t end = std::min(start + points_per_checkpoint, last_point_threshold);
std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl;
auto insert_task = std::async(std::launch::async, [&]() {
load_aligned_bin_part(data_path, data, start, end - start);
insert_till_next_checkpoint<T, TagT, LabelT>(*index, start, end, sub_threads, data, aligned_dim,
location_to_labels);
});
insert_task.wait();
if (!delete_launched && end >= start_deletes_after &&
end >= points_to_skip + points_to_delete_from_beginning)
{
delete_launched = true;
diskann::IndexWriteParameters delete_params =
diskann::IndexWriteParametersBuilder(params).with_num_threads(sub_threads).build();
delete_task = std::async(std::launch::async, [&]() {
delete_from_beginning<T, TagT>(*index, delete_params, points_to_skip,
points_to_delete_from_beginning);
});
}
}
delete_task.wait();
std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n";
index->save(save_path_inc.c_str(), true);
}
else
{
const auto save_path_inc = get_save_filename(save_path + ".after-delete-", points_to_skip,
points_to_delete_from_beginning, last_point_threshold);
std::string labels_file_to_use = save_path_inc + "_label_formatted.txt";
std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt";
if (has_labels)
{
convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
auto parse_result = diskann::parse_formatted_label_file<LabelT>(labels_file_to_use);
location_to_labels = std::get<0>(parse_result);
}
size_t last_snapshot_points_threshold = 0;
size_t num_checkpoints_till_snapshot = checkpoints_per_snapshot;
for (size_t start = current_point_offset; start < last_point_threshold;
start += points_per_checkpoint, current_point_offset += points_per_checkpoint)
{
const size_t end = std::min(start + points_per_checkpoint, last_point_threshold);
std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl;
load_aligned_bin_part(data_path, data, start, end - start);
insert_till_next_checkpoint<T, TagT, LabelT>(*index, start, end, (int32_t)params.num_threads, data,
aligned_dim, location_to_labels);
if (checkpoints_per_snapshot > 0 && --num_checkpoints_till_snapshot == 0)
{
diskann::Timer save_timer;
const auto save_path_inc =
get_save_filename(save_path + ".inc-", points_to_skip, points_to_delete_from_beginning, end);
index->save(save_path_inc.c_str(), false);
const double elapsedSeconds = save_timer.elapsed() / 1000000.0;
const size_t points_saved = end - points_to_skip;
std::cout << "Saved " << points_saved << " points in " << elapsedSeconds << " seconds ("
<< points_saved / elapsedSeconds << " points/second)\n";
num_checkpoints_till_snapshot = checkpoints_per_snapshot;
last_snapshot_points_threshold = end;
}
std::cout << "Number of points in the index post insertion " << end << std::endl;
}
if (checkpoints_per_snapshot > 0 && last_snapshot_points_threshold != last_point_threshold)
{
const auto save_path_inc = get_save_filename(save_path + ".inc-", points_to_skip,
points_to_delete_from_beginning, last_point_threshold);
// index.save(save_path_inc.c_str(), false);
}
if (points_to_delete_from_beginning > 0)
{
delete_from_beginning<T, TagT>(*index, params, points_to_skip, points_to_delete_from_beginning);
}
index->save(save_path_inc.c_str(), true);
}
diskann::aligned_free(data);
}
int main(int argc, char **argv)
{
std::string data_type, dist_fn, data_path, index_path_prefix;
uint32_t num_threads, R, L, num_start_pts;
float alpha, start_point_norm;
size_t points_to_skip, max_points_to_insert, beginning_index_size, points_per_checkpoint, checkpoints_per_snapshot,
points_to_delete_from_beginning, start_deletes_after;
bool concurrent;
// label options
std::string label_file, label_type, universal_label;
std::uint32_t Lf, unique_labels_supported;
po::options_description desc{program_options_utils::make_program_description("test_insert_deletes_consolidate",
"Test insert deletes & consolidate")};
try
{
desc.add_options()("help,h", "Print information on arguments");
// Required parameters
po::options_description required_configs("Required");
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
program_options_utils::DATA_TYPE_DESCRIPTION);
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
program_options_utils::INPUT_DATA_PATH);
required_configs.add_options()("points_to_skip", po::value<uint64_t>(&points_to_skip)->required(),
"Skip these first set of points from file");
required_configs.add_options()("beginning_index_size", po::value<uint64_t>(&beginning_index_size)->required(),
"Batch build will be called on these set of points");
required_configs.add_options()("points_per_checkpoint", po::value<uint64_t>(&points_per_checkpoint)->required(),
"Insertions are done in batches of points_per_checkpoint");
required_configs.add_options()("checkpoints_per_snapshot",
po::value<uint64_t>(&checkpoints_per_snapshot)->required(),
"Save the index to disk every few checkpoints");
required_configs.add_options()("points_to_delete_from_beginning",
po::value<uint64_t>(&points_to_delete_from_beginning)->required(), "");
// Optional parameters
po::options_description optional_configs("Optional");
optional_configs.add_options()("num_threads,T",
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
program_options_utils::NUMBER_THREADS_DESCRIPTION);
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
program_options_utils::MAX_BUILD_DEGREE);
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
program_options_utils::GRAPH_BUILD_COMPLEXITY);
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
program_options_utils::GRAPH_BUILD_ALPHA);
optional_configs.add_options()("max_points_to_insert",
po::value<uint64_t>(&max_points_to_insert)->default_value(0),
"These number of points from the file are inserted after "
"points_to_skip");
optional_configs.add_options()("do_concurrent", po::value<bool>(&concurrent)->default_value(false), "");
optional_configs.add_options()("start_deletes_after",
po::value<uint64_t>(&start_deletes_after)->default_value(0), "");
optional_configs.add_options()("start_point_norm", po::value<float>(&start_point_norm)->default_value(0),
"Set the start point to a random point on a sphere of this radius");
// optional params for filters
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
"Input label file in txt format for Filtered Index search. "
"The file should contain comma separated filters for each node "
"with each line corresponding to a graph node");
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
"Universal label, if using it, only in conjunction with labels_file");
optional_configs.add_options()("FilteredLbuild,Lf", po::value<uint32_t>(&Lf)->default_value(0),
"Build complexity for filtered points, higher value "
"results in better graphs");
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
"Storage type of Labels <uint/ushort>, default value is uint which "
"will consume memory 4 bytes per filter");
optional_configs.add_options()("unique_labels_supported",
po::value<uint32_t>(&unique_labels_supported)->default_value(0),
"Number of unique labels supported by the dynamic index.");
optional_configs.add_options()(
"num_start_points",
po::value<uint32_t>(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC),
"Set the number of random start (frozen) points to use when "
"inserting and searching");
// Merge required and optional parameters
desc.add(required_configs).add(optional_configs);
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
if (beginning_index_size == 0)
if (start_point_norm == 0)
{
std::cout << "When beginning_index_size is 0, use a start "
"point with "
"appropriate norm"
<< std::endl;
return -1;
}
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
bool has_labels = false;
if (!label_file.empty() || label_file != "")
{
has_labels = true;
}
if (num_start_pts < unique_labels_supported)
{
num_start_pts = unique_labels_supported;
}
try
{
diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R)
.with_max_occlusion_size(500)
.with_alpha(alpha)
.with_num_threads(num_threads)
.with_filter_list_size(Lf)
.build();
if (data_type == std::string("int8"))
build_incremental_index<int8_t>(
data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm,
num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix,
points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label);
else if (data_type == std::string("uint8"))
build_incremental_index<uint8_t>(
data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm,
num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix,
points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label);
else if (data_type == std::string("float"))
build_incremental_index<float>(data_path, params, points_to_skip, max_points_to_insert,
beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint,
checkpoints_per_snapshot, index_path_prefix, points_to_delete_from_beginning,
start_deletes_after, concurrent, label_file, universal_label);
else
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
}
catch (const std::exception &e)
{
std::cerr << "Caught exception: " << e.what() << std::endl;
exit(-1);
}
catch (...)
{
std::cerr << "Caught unknown exception" << std::endl;
exit(-1);
}
return 0;
}

View File

@@ -0,0 +1,523 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <index.h>
#include <numeric>
#include <omp.h>
#include <string.h>
#include <time.h>
#include <timer.h>
#include <boost/program_options.hpp>
#include <future>
#include <abstract_index.h>
#include <index_factory.h>
#include "utils.h"
#include "filter_utils.h"
#include "program_options_utils.hpp"
#ifndef _WINDOWS
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#endif
#include "memory_mapper.h"
namespace po = boost::program_options;
// load_aligned_bin modified to read pieces of the file, but using ifstream
// instead of cached_ifstream.
template <typename T>
inline void load_aligned_bin_part(const std::string &bin_file, T *data, size_t offset_points, size_t points_to_read)
{
std::ifstream reader;
reader.exceptions(std::ios::failbit | std::ios::badbit);
reader.open(bin_file, std::ios::binary | std::ios::ate);
size_t actual_file_size = reader.tellg();
reader.seekg(0, std::ios::beg);
int npts_i32, dim_i32;
reader.read((char *)&npts_i32, sizeof(int));
reader.read((char *)&dim_i32, sizeof(int));
size_t npts = (uint32_t)npts_i32;
size_t dim = (uint32_t)dim_i32;
size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t);
if (actual_file_size != expected_actual_file_size)
{
std::stringstream stream;
stream << "Error. File size mismatch. Actual size is " << actual_file_size << " while expected size is "
<< expected_actual_file_size << " npts = " << npts << " dim = " << dim << " size of <T>= " << sizeof(T)
<< std::endl;
std::cout << stream.str();
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}
if (offset_points + points_to_read > npts)
{
std::stringstream stream;
stream << "Error. Not enough points in file. Requested " << offset_points << " offset and " << points_to_read
<< " points, but have only " << npts << " points" << std::endl;
std::cout << stream.str();
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}
reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T));
const size_t rounded_dim = ROUND_UP(dim, 8);
for (size_t i = 0; i < points_to_read; i++)
{
reader.read((char *)(data + i * rounded_dim), dim * sizeof(T));
memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T));
}
reader.close();
}
std::string get_save_filename(const std::string &save_path, size_t active_window, size_t consolidate_interval,
size_t max_points_to_insert)
{
std::string final_path = save_path;
final_path += "act" + std::to_string(active_window) + "-";
final_path += "cons" + std::to_string(consolidate_interval) + "-";
final_path += "max" + std::to_string(max_points_to_insert);
return final_path;
}
template <typename T, typename TagT, typename LabelT>
void insert_next_batch(diskann::AbstractIndex &index, size_t start, size_t end, size_t insert_threads, T *data,
size_t aligned_dim, std::vector<std::vector<LabelT>> &pts_to_labels)
{
try
{
diskann::Timer insert_timer;
std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl;
size_t num_failed = 0;
#pragma omp parallel for num_threads((int32_t)insert_threads) schedule(dynamic) reduction(+ : num_failed)
for (int64_t j = start; j < (int64_t)end; j++)
{
int insert_result = -1;
if (pts_to_labels.size() > 0)
{
insert_result = index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j),
pts_to_labels[j - start]);
}
else
{
insert_result = index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j));
}
if (insert_result != 0)
{
std::cerr << "Insert failed " << j << std::endl;
num_failed++;
}
}
const double elapsedSeconds = insert_timer.elapsed() / 1000000.0;
std::cout << "Insertion time " << elapsedSeconds << " seconds (" << (end - start) / elapsedSeconds
<< " points/second overall, " << (end - start) / elapsedSeconds / insert_threads << " per thread)"
<< std::endl;
if (num_failed > 0)
std::cout << num_failed << " of " << end - start << "inserts failed" << std::endl;
}
catch (std::system_error &e)
{
std::cout << "Exiting after catching exception in insertion task: " << e.what() << std::endl;
exit(-1);
}
}
template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t>
void delete_and_consolidate(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params, size_t start,
size_t end)
{
try
{
std::cout << std::endl << "Lazy deleting points " << start << " to " << end << "... ";
for (size_t i = start; i < end; ++i)
index.lazy_delete(static_cast<TagT>(1 + i));
std::cout << "lazy delete done." << std::endl;
auto report = index.consolidate_deletes(delete_params);
while (report._status != diskann::consolidation_report::status_code::SUCCESS)
{
int wait_time = 5;
if (report._status == diskann::consolidation_report::status_code::LOCK_FAIL)
{
diskann::cerr << "Unable to acquire consolidate delete lock after "
<< "deleting points " << start << " to " << end << ". Will retry in " << wait_time
<< "seconds." << std::endl;
}
else if (report._status == diskann::consolidation_report::status_code::INCONSISTENT_COUNT_ERROR)
{
diskann::cerr << "Inconsistent counts in data structure. "
<< "Will retry in " << wait_time << "seconds." << std::endl;
}
else
{
std::cerr << "Exiting after unknown error in consolidate delete" << std::endl;
exit(-1);
}
std::this_thread::sleep_for(std::chrono::seconds(wait_time));
report = index.consolidate_deletes(delete_params);
}
auto points_processed = report._active_points + report._slots_released;
auto deletion_rate = points_processed / report._time;
std::cout << "#active points: " << report._active_points << std::endl
<< "max points: " << report._max_points << std::endl
<< "empty slots: " << report._empty_slots << std::endl
<< "deletes processed: " << report._slots_released << std::endl
<< "latest delete size: " << report._delete_set_size << std::endl
<< "Deletion rate: " << deletion_rate << "/sec "
<< "Deletion rate: " << deletion_rate / delete_params.num_threads << "/thread/sec " << std::endl;
}
catch (std::system_error &e)
{
std::cerr << "Exiting after catching exception in deletion task: " << e.what() << std::endl;
exit(-1);
}
}
template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t>
void build_incremental_index(const std::string &data_path, const uint32_t L, const uint32_t R, const float alpha,
const uint32_t insert_threads, const uint32_t consolidate_threads,
size_t max_points_to_insert, size_t active_window, size_t consolidate_interval,
const float start_point_norm, uint32_t num_start_pts, const std::string &save_path,
const std::string &label_file, const std::string &universal_label, const uint32_t Lf)
{
const uint32_t C = 500;
const bool saturate_graph = false;
bool has_labels = label_file != "";
diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R)
.with_max_occlusion_size(C)
.with_alpha(alpha)
.with_saturate_graph(saturate_graph)
.with_num_threads(insert_threads)
.with_filter_list_size(Lf)
.build();
auto index_search_params = diskann::IndexSearchParams(L, insert_threads);
diskann::IndexWriteParameters delete_params = diskann::IndexWriteParametersBuilder(L, R)
.with_max_occlusion_size(C)
.with_alpha(alpha)
.with_saturate_graph(saturate_graph)
.with_num_threads(consolidate_threads)
.with_filter_list_size(Lf)
.build();
size_t dim, aligned_dim;
size_t num_points;
std::vector<std::vector<LabelT>> pts_to_labels;
const auto save_path_inc =
get_save_filename(save_path + ".after-streaming-", active_window, consolidate_interval, max_points_to_insert);
std::string labels_file_to_use = save_path_inc + "_label_formatted.txt";
std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt";
if (has_labels)
{
convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
auto parse_result = diskann::parse_formatted_label_file<LabelT>(labels_file_to_use);
pts_to_labels = std::get<0>(parse_result);
}
diskann::get_bin_metadata(data_path, num_points, dim);
diskann::cout << "metadata: file " << data_path << " has " << num_points << " points in " << dim << " dims"
<< std::endl;
aligned_dim = ROUND_UP(dim, 8);
auto index_config = diskann::IndexConfigBuilder()
.with_metric(diskann::L2)
.with_dimension(dim)
.with_max_points(active_window + 4 * consolidate_interval)
.is_dynamic_index(true)
.is_enable_tags(true)
.is_use_opq(false)
.is_filtered(has_labels)
.with_num_pq_chunks(0)
.is_pq_dist_build(false)
.with_num_frozen_pts(num_start_pts)
.with_tag_type(diskann_type_to_name<TagT>())
.with_label_type(diskann_type_to_name<LabelT>())
.with_data_type(diskann_type_to_name<T>())
.with_index_write_params(params)
.with_index_search_params(index_search_params)
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
.build();
diskann::IndexFactory index_factory = diskann::IndexFactory(index_config);
auto index = index_factory.create_instance();
if (universal_label != "")
{
LabelT u_label = 0;
index->set_universal_label(u_label);
}
if (max_points_to_insert == 0)
{
max_points_to_insert = num_points;
}
if (num_points < max_points_to_insert)
throw diskann::ANNException(std::string("num_points(") + std::to_string(num_points) +
") < max_points_to_insert(" + std::to_string(max_points_to_insert) + ")",
-1, __FUNCSIG__, __FILE__, __LINE__);
if (max_points_to_insert < active_window + consolidate_interval)
throw diskann::ANNException("ERROR: max_points_to_insert < "
"active_window + consolidate_interval",
-1, __FUNCSIG__, __FILE__, __LINE__);
if (consolidate_interval < max_points_to_insert / 1000)
throw diskann::ANNException("ERROR: consolidate_interval is too small", -1, __FUNCSIG__, __FILE__, __LINE__);
index->set_start_points_at_random(static_cast<T>(start_point_norm));
T *data = nullptr;
diskann::alloc_aligned((void **)&data, std::max(consolidate_interval, active_window) * aligned_dim * sizeof(T),
8 * sizeof(T));
std::vector<TagT> tags(max_points_to_insert);
std::iota(tags.begin(), tags.end(), static_cast<TagT>(0));
diskann::Timer timer;
std::vector<std::future<void>> delete_tasks;
auto insert_task = std::async(std::launch::async, [&]() {
load_aligned_bin_part(data_path, data, 0, active_window);
insert_next_batch<T, TagT, LabelT>(*index, (size_t)0, active_window, params.num_threads, data, aligned_dim,
pts_to_labels);
});
insert_task.wait();
for (size_t start = active_window; start + consolidate_interval <= max_points_to_insert;
start += consolidate_interval)
{
auto end = std::min(start + consolidate_interval, max_points_to_insert);
auto insert_task = std::async(std::launch::async, [&]() {
load_aligned_bin_part(data_path, data, start, end - start);
insert_next_batch<T, TagT, LabelT>(*index, start, end, params.num_threads, data, aligned_dim,
pts_to_labels);
});
insert_task.wait();
if (delete_tasks.size() > 0)
delete_tasks[delete_tasks.size() - 1].wait();
if (start >= active_window + consolidate_interval)
{
auto start_del = start - active_window - consolidate_interval;
auto end_del = start - active_window;
delete_tasks.emplace_back(std::async(std::launch::async, [&]() {
delete_and_consolidate<T, TagT, LabelT>(*index, delete_params, (size_t)start_del, (size_t)end_del);
}));
}
}
if (delete_tasks.size() > 0)
delete_tasks[delete_tasks.size() - 1].wait();
std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n";
index->save(save_path_inc.c_str(), true);
diskann::aligned_free(data);
}
int main(int argc, char **argv)
{
std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type;
uint32_t insert_threads, consolidate_threads, R, L, num_start_pts, Lf, unique_labels_supported;
float alpha, start_point_norm;
size_t max_points_to_insert, active_window, consolidate_interval;
po::options_description desc{program_options_utils::make_program_description("test_streaming_scenario",
"Test insert deletes & consolidate")};
try
{
desc.add_options()("help,h", "Print information on arguments");
// Required parameters
po::options_description required_configs("Required");
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
program_options_utils::DATA_TYPE_DESCRIPTION);
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
program_options_utils::INPUT_DATA_PATH);
required_configs.add_options()("active_window", po::value<uint64_t>(&active_window)->required(),
"Program maintains an index over an active window of "
"this size that slides through the data");
required_configs.add_options()("consolidate_interval", po::value<uint64_t>(&consolidate_interval)->required(),
"The program simultaneously adds this number of points to the "
"right of "
"the window while deleting the same number from the left");
required_configs.add_options()("start_point_norm", po::value<float>(&start_point_norm)->required(),
"Set the start point to a random point on a sphere of this radius");
// Optional parameters
po::options_description optional_configs("Optional");
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
program_options_utils::MAX_BUILD_DEGREE);
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
program_options_utils::GRAPH_BUILD_COMPLEXITY);
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
program_options_utils::GRAPH_BUILD_ALPHA);
optional_configs.add_options()("insert_threads",
po::value<uint32_t>(&insert_threads)->default_value(omp_get_num_procs() / 2),
"Number of threads used for inserting into the index (defaults to "
"omp_get_num_procs()/2)");
optional_configs.add_options()(
"consolidate_threads", po::value<uint32_t>(&consolidate_threads)->default_value(omp_get_num_procs() / 2),
"Number of threads used for consolidating deletes to "
"the index (defaults to omp_get_num_procs()/2)");
optional_configs.add_options()("max_points_to_insert",
po::value<uint64_t>(&max_points_to_insert)->default_value(0),
"The number of points from the file that the program streams "
"over ");
optional_configs.add_options()(
"num_start_points",
po::value<uint32_t>(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC),
"Set the number of random start (frozen) points to use when "
"inserting and searching");
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
"Input label file in txt format for Filtered Index search. "
"The file should contain comma separated filters for each node "
"with each line corresponding to a graph node");
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
"Universal label, if using it, only in conjunction with labels_file");
optional_configs.add_options()("FilteredLbuild,Lf", po::value<uint32_t>(&Lf)->default_value(0),
"Build complexity for filtered points, higher value "
"results in better graphs");
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
"Storage type of Labels <uint/ushort>, default value is uint which "
"will consume memory 4 bytes per filter");
optional_configs.add_options()("unique_labels_supported",
po::value<uint32_t>(&unique_labels_supported)->default_value(0),
"Number of unique labels supported by the dynamic index.");
// Merge required and optional parameters
desc.add(required_configs).add(optional_configs);
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
// Validate arguments
if (start_point_norm == 0)
{
std::cout << "When beginning_index_size is 0, use a start point with "
"appropriate norm"
<< std::endl;
return -1;
}
if (label_type != std::string("ushort") && label_type != std::string("uint"))
{
std::cerr << "Invalid label type. Supported types are uint and ushort" << std::endl;
return -1;
}
if (data_type != std::string("int8") && data_type != std::string("uint8") && data_type != std::string("float"))
{
std::cerr << "Invalid data type. Supported types are int8, uint8 and float" << std::endl;
return -1;
}
// TODO: Are additional distance functions supported?
if (dist_fn != std::string("l2") && dist_fn != std::string("mips"))
{
std::cerr << "Invalid distance function. Supported functions are l2 and mips" << std::endl;
return -1;
}
if (num_start_pts < unique_labels_supported)
{
num_start_pts = unique_labels_supported;
}
try
{
if (data_type == std::string("uint8"))
{
if (label_type == std::string("ushort"))
{
build_incremental_index<uint8_t, uint32_t, uint16_t>(
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
universal_label, Lf);
}
else if (label_type == std::string("uint"))
{
build_incremental_index<uint8_t, uint32_t, uint32_t>(
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
universal_label, Lf);
}
}
else if (data_type == std::string("int8"))
{
if (label_type == std::string("ushort"))
{
build_incremental_index<int8_t, uint32_t, uint16_t>(
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
universal_label, Lf);
}
else if (label_type == std::string("uint"))
{
build_incremental_index<int8_t, uint32_t, uint32_t>(
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
universal_label, Lf);
}
}
else if (data_type == std::string("float"))
{
if (label_type == std::string("ushort"))
{
build_incremental_index<float, uint32_t, uint16_t>(
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
universal_label, Lf);
}
else if (label_type == std::string("uint"))
{
build_incremental_index<float, uint32_t, uint32_t>(
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
universal_label, Lf);
}
}
}
catch (const std::exception &e)
{
std::cerr << "Caught exception: " << e.what() << std::endl;
exit(-1);
}
catch (...)
{
std::cerr << "Caught unknown exception" << std::endl;
exit(-1);
}
return 0;
}

View File

@@ -0,0 +1,110 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_COMPILE_WARNING_AS_ERROR ON)
add_executable(fvecs_to_bin fvecs_to_bin.cpp)
add_executable(fvecs_to_bvecs fvecs_to_bvecs.cpp)
add_executable(rand_data_gen rand_data_gen.cpp)
target_link_libraries(rand_data_gen ${PROJECT_NAME} Boost::program_options)
add_executable(float_bin_to_int8 float_bin_to_int8.cpp)
add_executable(ivecs_to_bin ivecs_to_bin.cpp)
add_executable(count_bfs_levels count_bfs_levels.cpp)
target_link_libraries(count_bfs_levels ${PROJECT_NAME} Boost::program_options)
add_executable(tsv_to_bin tsv_to_bin.cpp)
add_executable(bin_to_tsv bin_to_tsv.cpp)
add_executable(int8_to_float int8_to_float.cpp)
target_link_libraries(int8_to_float ${PROJECT_NAME})
add_executable(int8_to_float_scale int8_to_float_scale.cpp)
target_link_libraries(int8_to_float_scale ${PROJECT_NAME})
add_executable(uint8_to_float uint8_to_float.cpp)
target_link_libraries(uint8_to_float ${PROJECT_NAME})
add_executable(uint32_to_uint8 uint32_to_uint8.cpp)
target_link_libraries(uint32_to_uint8 ${PROJECT_NAME})
add_executable(vector_analysis vector_analysis.cpp)
target_link_libraries(vector_analysis ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
add_executable(gen_random_slice gen_random_slice.cpp)
target_link_libraries(gen_random_slice ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
add_executable(simulate_aggregate_recall simulate_aggregate_recall.cpp)
add_executable(calculate_recall calculate_recall.cpp)
target_link_libraries(calculate_recall ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
# Compute ground truth thing outside of DiskANN main source that depends on MKL.
add_executable(compute_groundtruth compute_groundtruth.cpp)
target_include_directories(compute_groundtruth PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES})
target_link_libraries(compute_groundtruth ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options)
add_executable(compute_groundtruth_for_filters compute_groundtruth_for_filters.cpp)
target_include_directories(compute_groundtruth_for_filters PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES})
target_link_libraries(compute_groundtruth_for_filters ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options)
add_executable(generate_pq generate_pq.cpp)
target_link_libraries(generate_pq ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
add_executable(partition_data partition_data.cpp)
target_link_libraries(partition_data ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
add_executable(partition_with_ram_budget partition_with_ram_budget.cpp)
target_link_libraries(partition_with_ram_budget ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
add_executable(merge_shards merge_shards.cpp)
target_link_libraries(merge_shards ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} ${DISKANN_ASYNC_LIB})
add_executable(create_disk_layout create_disk_layout.cpp)
target_link_libraries(create_disk_layout ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
add_executable(generate_synthetic_labels generate_synthetic_labels.cpp)
target_link_libraries(generate_synthetic_labels ${PROJECT_NAME} Boost::program_options)
add_executable(stats_label_data stats_label_data.cpp)
target_link_libraries(stats_label_data ${PROJECT_NAME} Boost::program_options)
if (NOT MSVC)
include(GNUInstallDirs)
install(TARGETS fvecs_to_bin
fvecs_to_bvecs
rand_data_gen
float_bin_to_int8
ivecs_to_bin
count_bfs_levels
tsv_to_bin
bin_to_tsv
int8_to_float
int8_to_float_scale
uint8_to_float
uint32_to_uint8
vector_analysis
gen_random_slice
simulate_aggregate_recall
calculate_recall
compute_groundtruth
compute_groundtruth_for_filters
generate_pq
partition_data
partition_with_ram_budget
merge_shards
create_disk_layout
generate_synthetic_labels
stats_label_data
RUNTIME
)
endif()

View File

@@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include "util.h"
void block_convert(std::ifstream &writr, std::ofstream &readr, float *read_buf, float *write_buf, uint64_t npts,
uint64_t ndims)
{
writr.write((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(unsigned)));
#pragma omp parallel for
for (uint64_t i = 0; i < npts; i++)
{
memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(float));
}
readr.read((char *)write_buf, npts * ndims * sizeof(float));
}
int main(int argc, char **argv)
{
if (argc != 3)
{
std::cout << argv[0] << " input_bin output_fvecs" << std::endl;
exit(-1);
}
std::ifstream readr(argv[1], std::ios::binary);
int npts_s32;
int ndims_s32;
readr.read((char *)&npts_s32, sizeof(int32_t));
readr.read((char *)&ndims_s32, sizeof(int32_t));
size_t npts = npts_s32;
size_t ndims = ndims_s32;
uint32_t ndims_u32 = (uint32_t)ndims_s32;
// uint64_t fsize = writr.tellg();
readr.seekg(0, std::ios::beg);
unsigned ndims_u32;
writr.write((char *)&ndims_u32, sizeof(unsigned));
writr.seekg(0, std::ios::beg);
uint64_t ndims = (uint64_t)ndims_u32;
uint64_t npts = fsize / ((ndims + 1) * sizeof(float));
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
uint64_t blk_size = 131072;
uint64_t nblks = ROUND_UP(npts, blk_size) / blk_size;
std::cout << "# blks: " << nblks << std::endl;
std::ofstream writr(argv[2], std::ios::binary);
float *read_buf = new float[npts * (ndims + 1)];
float *write_buf = new float[npts * ndims];
for (uint64_t i = 0; i < nblks; i++)
{
uint64_t cblk_size = std::min(npts - i * blk_size, blk_size);
block_convert(writr, readr, read_buf, write_buf, cblk_size, ndims);
std::cout << "Block #" << i << " written" << std::endl;
}
delete[] read_buf;
delete[] write_buf;
writr.close();
readr.close();
}

View File

@@ -0,0 +1,69 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include "utils.h"
template <class T>
void block_convert(std::ofstream &writer, std::ifstream &reader, T *read_buf, size_t npts, size_t ndims)
{
reader.read((char *)read_buf, npts * ndims * sizeof(float));
for (size_t i = 0; i < npts; i++)
{
for (size_t d = 0; d < ndims; d++)
{
writer << read_buf[d + i * ndims];
if (d < ndims - 1)
writer << "\t";
else
writer << "\n";
}
}
}
int main(int argc, char **argv)
{
if (argc != 4)
{
std::cout << argv[0] << " <float/int8/uint8> input_bin output_tsv" << std::endl;
exit(-1);
}
std::string type_string(argv[1]);
if ((type_string != std::string("float")) && (type_string != std::string("int8")) &&
(type_string != std::string("uin8")))
{
std::cerr << "Error: type not supported. Use float/int8/uint8" << std::endl;
}
std::ifstream reader(argv[2], std::ios::binary);
uint32_t npts_u32;
uint32_t ndims_u32;
reader.read((char *)&npts_u32, sizeof(uint32_t));
reader.read((char *)&ndims_u32, sizeof(uint32_t));
size_t npts = npts_u32;
size_t ndims = ndims_u32;
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
size_t blk_size = 131072;
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
std::ofstream writer(argv[3]);
char *read_buf = new char[blk_size * ndims * 4];
for (size_t i = 0; i < nblks; i++)
{
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
if (type_string == std::string("float"))
block_convert<float>(writer, reader, (float *)read_buf, cblk_size, ndims);
else if (type_string == std::string("int8"))
block_convert<int8_t>(writer, reader, (int8_t *)read_buf, cblk_size, ndims);
else if (type_string == std::string("uint8"))
block_convert<uint8_t>(writer, reader, (uint8_t *)read_buf, cblk_size, ndims);
std::cout << "Block #" << i << " written" << std::endl;
}
delete[] read_buf;
writer.close();
reader.close();
}

View File

@@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <cstddef>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <set>
#include <string>
#include <vector>
#include "utils.h"
#include "disk_utils.h"
int main(int argc, char **argv)
{
if (argc != 4)
{
std::cout << argv[0] << " <ground_truth_bin> <our_results_bin> <r> " << std::endl;
return -1;
}
uint32_t *gold_std = NULL;
float *gs_dist = nullptr;
uint32_t *our_results = NULL;
float *or_dist = nullptr;
size_t points_num, points_num_gs, points_num_or;
size_t dim_gs;
size_t dim_or;
diskann::load_truthset(argv[1], gold_std, gs_dist, points_num_gs, dim_gs);
diskann::load_truthset(argv[2], our_results, or_dist, points_num_or, dim_or);
if (points_num_gs != points_num_or)
{
std::cout << "Error. Number of queries mismatch in ground truth and "
"our results"
<< std::endl;
return -1;
}
points_num = points_num_gs;
uint32_t recall_at = std::atoi(argv[3]);
if ((dim_or < recall_at) || (recall_at > dim_gs))
{
std::cout << "ground truth has size " << dim_gs << "; our set has " << dim_or << " points. Asking for recall "
<< recall_at << std::endl;
return -1;
}
std::cout << "Calculating recall@" << recall_at << std::endl;
double recall_val = diskann::calculate_recall((uint32_t)points_num, gold_std, gs_dist, (uint32_t)dim_gs,
our_results, (uint32_t)dim_or, (uint32_t)recall_at);
// double avg_recall = (recall*1.0)/(points_num*1.0);
std::cout << "Avg. recall@" << recall_at << " is " << recall_val << "\n";
}

View File

@@ -0,0 +1,574 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <string>
#include <iostream>
#include <fstream>
#include <cassert>
#include <vector>
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <random>
#include <limits>
#include <cstring>
#include <queue>
#include <omp.h>
#include <mkl.h>
#include <boost/program_options.hpp>
#include <unordered_map>
#include <tsl/robin_map.h>
#include <tsl/robin_set.h>
#ifdef _WINDOWS
#include <malloc.h>
#else
#include <stdlib.h>
#endif
#include "filter_utils.h"
#include "utils.h"
// WORKS FOR UPTO 2 BILLION POINTS (as we use INT INSTEAD OF UNSIGNED)
#define PARTSIZE 10000000
#define ALIGNMENT 512
// custom types (for readability)
typedef tsl::robin_set<std::string> label_set;
typedef std::string path;
namespace po = boost::program_options;
template <class T> T div_round_up(const T numerator, const T denominator)
{
return (numerator % denominator == 0) ? (numerator / denominator) : 1 + (numerator / denominator);
}
using pairIF = std::pair<size_t, float>;
struct cmpmaxstruct
{
bool operator()(const pairIF &l, const pairIF &r)
{
return l.second < r.second;
};
};
using maxPQIFCS = std::priority_queue<pairIF, std::vector<pairIF>, cmpmaxstruct>;
template <class T> T *aligned_malloc(const size_t n, const size_t alignment)
{
#ifdef _WINDOWS
return (T *)_aligned_malloc(sizeof(T) * n, alignment);
#else
return static_cast<T *>(aligned_alloc(alignment, sizeof(T) * n));
#endif
}
inline bool custom_dist(const std::pair<uint32_t, float> &a, const std::pair<uint32_t, float> &b)
{
return a.second < b.second;
}
void compute_l2sq(float *const points_l2sq, const float *const matrix, const int64_t num_points, const uint64_t dim)
{
assert(points_l2sq != NULL);
#pragma omp parallel for schedule(static, 65536)
for (int64_t d = 0; d < num_points; ++d)
points_l2sq[d] = cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1,
matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1);
}
void distsq_to_points(const size_t dim,
float *dist_matrix, // Col Major, cols are queries, rows are points
size_t npoints, const float *const points,
const float *const points_l2sq, // points in Col major
size_t nqueries, const float *const queries,
const float *const queries_l2sq, // queries in Col major
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
{
bool ones_vec_alloc = false;
if (ones_vec == NULL)
{
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
ones_vec_alloc = true;
}
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-2.0, points, dim, queries, dim,
(float)0.0, dist_matrix, npoints);
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, points_l2sq, npoints,
ones_vec, nqueries, (float)1.0, dist_matrix, npoints);
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, ones_vec, npoints,
queries_l2sq, nqueries, (float)1.0, dist_matrix, npoints);
if (ones_vec_alloc)
delete[] ones_vec;
}
void inner_prod_to_points(const size_t dim,
float *dist_matrix, // Col Major, cols are queries, rows are points
size_t npoints, const float *const points, size_t nqueries, const float *const queries,
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
{
bool ones_vec_alloc = false;
if (ones_vec == NULL)
{
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
ones_vec_alloc = true;
}
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-1.0, points, dim, queries, dim,
(float)0.0, dist_matrix, npoints);
if (ones_vec_alloc)
delete[] ones_vec;
}
void exact_knn(const size_t dim, const size_t k,
size_t *const closest_points, // k * num_queries preallocated, col
// major, queries columns
float *const dist_closest_points, // k * num_queries
// preallocated, Dist to
// corresponding closes_points
size_t npoints,
float *points_in, // points in Col major
size_t nqueries, float *queries_in,
diskann::Metric metric = diskann::Metric::L2) // queries in Col major
{
float *points_l2sq = new float[npoints];
float *queries_l2sq = new float[nqueries];
compute_l2sq(points_l2sq, points_in, npoints, dim);
compute_l2sq(queries_l2sq, queries_in, nqueries, dim);
float *points = points_in;
float *queries = queries_in;
if (metric == diskann::Metric::COSINE)
{ // we convert cosine distance as
// normalized L2 distnace
points = new float[npoints * dim];
queries = new float[nqueries * dim];
#pragma omp parallel for schedule(static, 4096)
for (int64_t i = 0; i < (int64_t)npoints; i++)
{
float norm = std::sqrt(points_l2sq[i]);
if (norm == 0)
{
norm = std::numeric_limits<float>::epsilon();
}
for (uint32_t j = 0; j < dim; j++)
{
points[i * dim + j] = points_in[i * dim + j] / norm;
}
}
#pragma omp parallel for schedule(static, 4096)
for (int64_t i = 0; i < (int64_t)nqueries; i++)
{
float norm = std::sqrt(queries_l2sq[i]);
if (norm == 0)
{
norm = std::numeric_limits<float>::epsilon();
}
for (uint32_t j = 0; j < dim; j++)
{
queries[i * dim + j] = queries_in[i * dim + j] / norm;
}
}
// recalculate norms after normalizing, they should all be one.
compute_l2sq(points_l2sq, points, npoints, dim);
compute_l2sq(queries_l2sq, queries, nqueries, dim);
}
std::cout << "Going to compute " << k << " NNs for " << nqueries << " queries over " << npoints << " points in "
<< dim << " dimensions using";
if (metric == diskann::Metric::INNER_PRODUCT)
std::cout << " MIPS ";
else if (metric == diskann::Metric::COSINE)
std::cout << " Cosine ";
else
std::cout << " L2 ";
std::cout << "distance fn. " << std::endl;
size_t q_batch_size = (1 << 9);
float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints];
for (size_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b)
{
int64_t q_b = b * q_batch_size;
int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size;
if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE)
{
distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b,
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, queries_l2sq + q_b);
}
else
{
inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b,
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim);
}
std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" << std::endl;
#pragma omp parallel for schedule(dynamic, 16)
for (long long q = q_b; q < q_e; q++)
{
maxPQIFCS point_dist;
for (size_t p = 0; p < k; p++)
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
for (size_t p = k; p < npoints; p++)
{
if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints])
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
if (point_dist.size() > k)
point_dist.pop();
}
for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l)
{
closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().first;
dist_closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().second;
point_dist.pop();
}
assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k,
dist_closest_points + (ptrdiff_t)(q + 1) * (ptrdiff_t)k));
}
std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e << ")" << std::endl;
}
delete[] dist_matrix;
delete[] points_l2sq;
delete[] queries_l2sq;
if (metric == diskann::Metric::COSINE)
{
delete[] points;
delete[] queries;
}
}
template <typename T> inline int get_num_parts(const char *filename)
{
std::ifstream reader;
reader.exceptions(std::ios::failbit | std::ios::badbit);
reader.open(filename, std::ios::binary);
std::cout << "Reading bin file " << filename << " ...\n";
int npts_i32, ndims_i32;
reader.read((char *)&npts_i32, sizeof(int));
reader.read((char *)&ndims_i32, sizeof(int));
std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl;
reader.close();
uint32_t num_parts =
(npts_i32 % PARTSIZE) == 0 ? npts_i32 / PARTSIZE : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1;
std::cout << "Number of parts: " << num_parts << std::endl;
return num_parts;
}
template <typename T>
inline void load_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims, int part_num)
{
std::ifstream reader;
reader.exceptions(std::ios::failbit | std::ios::badbit);
reader.open(filename, std::ios::binary);
std::cout << "Reading bin file " << filename << " ...\n";
int npts_i32, ndims_i32;
reader.read((char *)&npts_i32, sizeof(int));
reader.read((char *)&ndims_i32, sizeof(int));
uint64_t start_id = part_num * PARTSIZE;
uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32);
npts = end_id - start_id;
ndims = (uint64_t)ndims_i32;
std::cout << "#pts in part = " << npts << ", #dims = " << ndims << ", size = " << npts * ndims * sizeof(T) << "B"
<< std::endl;
reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg);
T *data_T = new T[npts * ndims];
reader.read((char *)data_T, sizeof(T) * npts * ndims);
std::cout << "Finished reading part of the bin file." << std::endl;
reader.close();
data = aligned_malloc<float>(npts * ndims, ALIGNMENT);
#pragma omp parallel for schedule(dynamic, 32768)
for (int64_t i = 0; i < (int64_t)npts; i++)
{
for (int64_t j = 0; j < (int64_t)ndims; j++)
{
float cur_val_float = (float)data_T[i * ndims + j];
std::memcpy((char *)(data + i * ndims + j), (char *)&cur_val_float, sizeof(float));
}
}
delete[] data_T;
std::cout << "Finished converting part data to float." << std::endl;
}
template <typename T> inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims)
{
std::ofstream writer;
writer.exceptions(std::ios::failbit | std::ios::badbit);
writer.open(filename, std::ios::binary | std::ios::out);
std::cout << "Writing bin: " << filename << "\n";
int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
writer.write((char *)&npts_i32, sizeof(int));
writer.write((char *)&ndims_i32, sizeof(int));
std::cout << "bin: #pts = " << npts << ", #dims = " << ndims
<< ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B" << std::endl;
writer.write((char *)data, npts * ndims * sizeof(T));
writer.close();
std::cout << "Finished writing bin" << std::endl;
}
inline void save_groundtruth_as_one_file(const std::string filename, int32_t *data, float *distances, size_t npts,
size_t ndims)
{
std::ofstream writer(filename, std::ios::binary | std::ios::out);
int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
writer.write((char *)&npts_i32, sizeof(int));
writer.write((char *)&ndims_i32, sizeof(int));
std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, "
"npts*dim dist-matrix) with npts = "
<< npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int)
<< "B" << std::endl;
writer.write((char *)data, npts * ndims * sizeof(uint32_t));
writer.write((char *)distances, npts * ndims * sizeof(float));
writer.close();
std::cout << "Finished writing truthset" << std::endl;
}
template <typename T>
std::vector<std::vector<std::pair<uint32_t, float>>> processUnfilteredParts(const std::string &base_file,
size_t &nqueries, size_t &npoints,
size_t &dim, size_t &k, float *query_data,
const diskann::Metric &metric,
std::vector<uint32_t> &location_to_tag)
{
float *base_data = nullptr;
int num_parts = get_num_parts<T>(base_file.c_str());
std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
for (int p = 0; p < num_parts; p++)
{
size_t start_id = p * PARTSIZE;
load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);
size_t *closest_points_part = new size_t[nqueries * k];
float *dist_closest_points_part = new float[nqueries * k];
auto part_k = k < npoints ? k : npoints;
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data,
metric);
for (size_t i = 0; i < nqueries; i++)
{
for (size_t j = 0; j < part_k; j++)
{
if (!location_to_tag.empty())
if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0)
continue;
res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id),
dist_closest_points_part[i * part_k + j]));
}
}
delete[] closest_points_part;
delete[] dist_closest_points_part;
diskann::aligned_free(base_data);
}
return res;
};
template <typename T>
int aux_main(const std::string &base_file, const std::string &query_file, const std::string &gt_file, size_t k,
const diskann::Metric &metric, const std::string &tags_file = std::string(""))
{
size_t npoints, nqueries, dim;
float *query_data;
load_bin_as_float<T>(query_file.c_str(), query_data, nqueries, dim, 0);
if (nqueries > PARTSIZE)
std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE
<< ". Computing GT only for the first " << PARTSIZE << " queries." << std::endl;
// load tags
const bool tags_enabled = tags_file.empty() ? false : true;
std::vector<uint32_t> location_to_tag = diskann::loadTags(tags_file, base_file);
int *closest_points = new int[nqueries * k];
float *dist_closest_points = new float[nqueries * k];
std::vector<std::vector<std::pair<uint32_t, float>>> results =
processUnfilteredParts<T>(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag);
for (size_t i = 0; i < nqueries; i++)
{
std::vector<std::pair<uint32_t, float>> &cur_res = results[i];
std::sort(cur_res.begin(), cur_res.end(), custom_dist);
size_t j = 0;
for (auto iter : cur_res)
{
if (j == k)
break;
if (tags_enabled)
{
std::uint32_t index_with_tag = location_to_tag[iter.first];
closest_points[i * k + j] = (int32_t)index_with_tag;
}
else
{
closest_points[i * k + j] = (int32_t)iter.first;
}
if (metric == diskann::Metric::INNER_PRODUCT)
dist_closest_points[i * k + j] = -iter.second;
else
dist_closest_points[i * k + j] = iter.second;
++j;
}
if (j < k)
std::cout << "WARNING: found less than k GT entries for query " << i << std::endl;
}
save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k);
delete[] closest_points;
delete[] dist_closest_points;
diskann::aligned_free(query_data);
return 0;
}
void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim)
{
size_t read_blk_size = 64 * 1024 * 1024;
cached_ifstream reader(bin_file, read_blk_size);
diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl;
size_t actual_file_size = reader.get_file_size();
int npts_i32, dim_i32;
reader.read((char *)&npts_i32, sizeof(int));
reader.read((char *)&dim_i32, sizeof(int));
npts = (uint32_t)npts_i32;
dim = (uint32_t)dim_i32;
diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl;
int truthset_type = -1; // 1 means truthset has ids and distances, 2 means
// only ids, -1 is error
size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
if (actual_file_size == expected_file_size_with_dists)
truthset_type = 1;
size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
if (actual_file_size == expected_file_size_just_ids)
truthset_type = 2;
if (truthset_type == -1)
{
std::stringstream stream;
stream << "Error. File size mismatch. File should have bin format, with "
"npts followed by ngt followed by npts*ngt ids and optionally "
"followed by npts*ngt distance values; actual size: "
<< actual_file_size << ", expected: " << expected_file_size_with_dists << " or "
<< expected_file_size_just_ids;
diskann::cout << stream.str();
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}
ids = new uint32_t[npts * dim];
reader.read((char *)ids, npts * dim * sizeof(uint32_t));
if (truthset_type == 1)
{
dists = new float[npts * dim];
reader.read((char *)dists, npts * dim * sizeof(float));
}
}
int main(int argc, char **argv)
{
std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file;
uint64_t K;
try
{
po::options_description desc{"Arguments"};
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
"distance function <l2/mips/cosine>");
desc.add_options()("base_file", po::value<std::string>(&base_file)->required(),
"File containing the base vectors in binary format");
desc.add_options()("query_file", po::value<std::string>(&query_file)->required(),
"File containing the query vectors in binary format");
desc.add_options()("gt_file", po::value<std::string>(&gt_file)->required(),
"File name for the writing ground truth in binary "
"format, please don' append .bin at end if "
"no filter_label or filter_label_file is provided it "
"will save the file with '.bin' at end."
"else it will save the file as filename_label.bin");
desc.add_options()("K", po::value<uint64_t>(&K)->required(),
"Number of ground truth nearest neighbors to compute");
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
"File containing the tags in binary format");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8"))
{
std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl;
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("l2"))
{
metric = diskann::Metric::L2;
}
else if (dist_fn == std::string("mips"))
{
metric = diskann::Metric::INNER_PRODUCT;
}
else if (dist_fn == std::string("cosine"))
{
metric = diskann::Metric::COSINE;
}
else
{
std::cerr << "Unsupported distance function. Use l2/mips/cosine." << std::endl;
return -1;
}
try
{
if (data_type == std::string("float"))
aux_main<float>(base_file, query_file, gt_file, K, metric, tags_file);
if (data_type == std::string("int8"))
aux_main<int8_t>(base_file, query_file, gt_file, K, metric, tags_file);
if (data_type == std::string("uint8"))
aux_main<uint8_t>(base_file, query_file, gt_file, K, metric, tags_file);
}
catch (const std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Compute GT failed." << std::endl;
return -1;
}
}

View File

@@ -0,0 +1,919 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <string>
#include <iostream>
#include <fstream>
#include <cassert>
#include <vector>
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <random>
#include <limits>
#include <cstring>
#include <queue>
#include <omp.h>
#include <mkl.h>
#include <boost/program_options.hpp>
#include <unordered_map>
#include <tsl/robin_map.h>
#include <tsl/robin_set.h>
#ifdef _WINDOWS
#include <malloc.h>
#else
#include <stdlib.h>
#endif
#include "filter_utils.h"
#include "utils.h"
// WORKS FOR UPTO 2 BILLION POINTS (as we use INT INSTEAD OF UNSIGNED)
#define PARTSIZE 10000000
#define ALIGNMENT 512
// custom types (for readability)
typedef tsl::robin_set<std::string> label_set;
typedef std::string path;
namespace po = boost::program_options;
template <class T> T div_round_up(const T numerator, const T denominator)
{
return (numerator % denominator == 0) ? (numerator / denominator) : 1 + (numerator / denominator);
}
using pairIF = std::pair<size_t, float>;
struct cmpmaxstruct
{
bool operator()(const pairIF &l, const pairIF &r)
{
return l.second < r.second;
};
};
using maxPQIFCS = std::priority_queue<pairIF, std::vector<pairIF>, cmpmaxstruct>;
template <class T> T *aligned_malloc(const size_t n, const size_t alignment)
{
#ifdef _WINDOWS
return (T *)_aligned_malloc(sizeof(T) * n, alignment);
#else
return static_cast<T *>(aligned_alloc(alignment, sizeof(T) * n));
#endif
}
inline bool custom_dist(const std::pair<uint32_t, float> &a, const std::pair<uint32_t, float> &b)
{
return a.second < b.second;
}
void compute_l2sq(float *const points_l2sq, const float *const matrix, const int64_t num_points, const uint64_t dim)
{
assert(points_l2sq != NULL);
#pragma omp parallel for schedule(static, 65536)
for (int64_t d = 0; d < num_points; ++d)
points_l2sq[d] = cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1,
matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1);
}
void distsq_to_points(const size_t dim,
float *dist_matrix, // Col Major, cols are queries, rows are points
size_t npoints, const float *const points,
const float *const points_l2sq, // points in Col major
size_t nqueries, const float *const queries,
const float *const queries_l2sq, // queries in Col major
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
{
bool ones_vec_alloc = false;
if (ones_vec == NULL)
{
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
ones_vec_alloc = true;
}
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-2.0, points, dim, queries, dim,
(float)0.0, dist_matrix, npoints);
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, points_l2sq, npoints,
ones_vec, nqueries, (float)1.0, dist_matrix, npoints);
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, ones_vec, npoints,
queries_l2sq, nqueries, (float)1.0, dist_matrix, npoints);
if (ones_vec_alloc)
delete[] ones_vec;
}
void inner_prod_to_points(const size_t dim,
float *dist_matrix, // Col Major, cols are queries, rows are points
size_t npoints, const float *const points, size_t nqueries, const float *const queries,
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
{
bool ones_vec_alloc = false;
if (ones_vec == NULL)
{
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
ones_vec_alloc = true;
}
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-1.0, points, dim, queries, dim,
(float)0.0, dist_matrix, npoints);
if (ones_vec_alloc)
delete[] ones_vec;
}
void exact_knn(const size_t dim, const size_t k,
size_t *const closest_points, // k * num_queries preallocated, col
// major, queries columns
float *const dist_closest_points, // k * num_queries
// preallocated, Dist to
// corresponding closes_points
size_t npoints,
float *points_in, // points in Col major
size_t nqueries, float *queries_in,
diskann::Metric metric = diskann::Metric::L2) // queries in Col major
{
float *points_l2sq = new float[npoints];
float *queries_l2sq = new float[nqueries];
compute_l2sq(points_l2sq, points_in, npoints, dim);
compute_l2sq(queries_l2sq, queries_in, nqueries, dim);
float *points = points_in;
float *queries = queries_in;
if (metric == diskann::Metric::COSINE)
{ // we convert cosine distance as
// normalized L2 distnace
points = new float[npoints * dim];
queries = new float[nqueries * dim];
#pragma omp parallel for schedule(static, 4096)
for (int64_t i = 0; i < (int64_t)npoints; i++)
{
float norm = std::sqrt(points_l2sq[i]);
if (norm == 0)
{
norm = std::numeric_limits<float>::epsilon();
}
for (uint32_t j = 0; j < dim; j++)
{
points[i * dim + j] = points_in[i * dim + j] / norm;
}
}
#pragma omp parallel for schedule(static, 4096)
for (int64_t i = 0; i < (int64_t)nqueries; i++)
{
float norm = std::sqrt(queries_l2sq[i]);
if (norm == 0)
{
norm = std::numeric_limits<float>::epsilon();
}
for (uint32_t j = 0; j < dim; j++)
{
queries[i * dim + j] = queries_in[i * dim + j] / norm;
}
}
// recalculate norms after normalizing, they should all be one.
compute_l2sq(points_l2sq, points, npoints, dim);
compute_l2sq(queries_l2sq, queries, nqueries, dim);
}
std::cout << "Going to compute " << k << " NNs for " << nqueries << " queries over " << npoints << " points in "
<< dim << " dimensions using";
if (metric == diskann::Metric::INNER_PRODUCT)
std::cout << " MIPS ";
else if (metric == diskann::Metric::COSINE)
std::cout << " Cosine ";
else
std::cout << " L2 ";
std::cout << "distance fn. " << std::endl;
size_t q_batch_size = (1 << 9);
float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints];
for (uint64_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b)
{
int64_t q_b = b * q_batch_size;
int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size;
if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE)
{
distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b,
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, queries_l2sq + q_b);
}
else
{
inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b,
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim);
}
std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" << std::endl;
#pragma omp parallel for schedule(dynamic, 16)
for (long long q = q_b; q < q_e; q++)
{
maxPQIFCS point_dist;
for (size_t p = 0; p < k; p++)
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
for (size_t p = k; p < npoints; p++)
{
if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints])
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
if (point_dist.size() > k)
point_dist.pop();
}
for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l)
{
closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().first;
dist_closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().second;
point_dist.pop();
}
assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k,
dist_closest_points + (ptrdiff_t)(q + 1) * (ptrdiff_t)k));
}
std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e << ")" << std::endl;
}
delete[] dist_matrix;
delete[] points_l2sq;
delete[] queries_l2sq;
if (metric == diskann::Metric::COSINE)
{
delete[] points;
delete[] queries;
}
}
template <typename T> inline int get_num_parts(const char *filename)
{
std::ifstream reader;
reader.exceptions(std::ios::failbit | std::ios::badbit);
reader.open(filename, std::ios::binary);
std::cout << "Reading bin file " << filename << " ...\n";
int npts_i32, ndims_i32;
reader.read((char *)&npts_i32, sizeof(int));
reader.read((char *)&ndims_i32, sizeof(int));
std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl;
reader.close();
int num_parts = (npts_i32 % PARTSIZE) == 0 ? npts_i32 / PARTSIZE : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1;
std::cout << "Number of parts: " << num_parts << std::endl;
return num_parts;
}
template <typename T>
inline void load_bin_as_float(const char *filename, float *&data, size_t &npts_u64, size_t &ndims_u64, int part_num)
{
std::ifstream reader;
reader.exceptions(std::ios::failbit | std::ios::badbit);
reader.open(filename, std::ios::binary);
std::cout << "Reading bin file " << filename << " ...\n";
int npts_i32, ndims_i32;
reader.read((char *)&npts_i32, sizeof(int));
reader.read((char *)&ndims_i32, sizeof(int));
uint64_t start_id = part_num * PARTSIZE;
uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32);
npts_u64 = end_id - start_id;
ndims_u64 = (uint64_t)ndims_i32;
std::cout << "#pts in part = " << npts_u64 << ", #dims = " << ndims_u64
<< ", size = " << npts_u64 * ndims_u64 * sizeof(T) << "B" << std::endl;
reader.seekg(start_id * ndims_u64 * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg);
T *data_T = new T[npts_u64 * ndims_u64];
reader.read((char *)data_T, sizeof(T) * npts_u64 * ndims_u64);
std::cout << "Finished reading part of the bin file." << std::endl;
reader.close();
data = aligned_malloc<float>(npts_u64 * ndims_u64, ALIGNMENT);
#pragma omp parallel for schedule(dynamic, 32768)
for (int64_t i = 0; i < (int64_t)npts_u64; i++)
{
for (int64_t j = 0; j < (int64_t)ndims_u64; j++)
{
float cur_val_float = (float)data_T[i * ndims_u64 + j];
std::memcpy((char *)(data + i * ndims_u64 + j), (char *)&cur_val_float, sizeof(float));
}
}
delete[] data_T;
std::cout << "Finished converting part data to float." << std::endl;
}
template <typename T>
inline std::vector<size_t> load_filtered_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims,
int part_num, const char *label_file,
const std::string &filter_label,
const std::string &universal_label, size_t &npoints_filt,
std::vector<std::vector<std::string>> &pts_to_labels)
{
std::ifstream reader(filename, std::ios::binary);
if (reader.fail())
{
throw diskann::ANNException(std::string("Failed to open file ") + filename, -1);
}
std::cout << "Reading bin file " << filename << " ...\n";
int npts_i32, ndims_i32;
std::vector<size_t> rev_map;
reader.read((char *)&npts_i32, sizeof(int));
reader.read((char *)&ndims_i32, sizeof(int));
uint64_t start_id = part_num * PARTSIZE;
uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32);
npts = end_id - start_id;
ndims = (uint32_t)ndims_i32;
uint64_t nptsuint64_t = (uint64_t)npts;
uint64_t ndimsuint64_t = (uint64_t)ndims;
npoints_filt = 0;
std::cout << "#pts in part = " << npts << ", #dims = " << ndims
<< ", size = " << nptsuint64_t * ndimsuint64_t * sizeof(T) << "B" << std::endl;
std::cout << "start and end ids: " << start_id << ", " << end_id << std::endl;
reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg);
T *data_T = new T[nptsuint64_t * ndimsuint64_t];
reader.read((char *)data_T, sizeof(T) * nptsuint64_t * ndimsuint64_t);
std::cout << "Finished reading part of the bin file." << std::endl;
reader.close();
data = aligned_malloc<float>(nptsuint64_t * ndimsuint64_t, ALIGNMENT);
for (int64_t i = 0; i < (int64_t)nptsuint64_t; i++)
{
if (std::find(pts_to_labels[start_id + i].begin(), pts_to_labels[start_id + i].end(), filter_label) !=
pts_to_labels[start_id + i].end() ||
std::find(pts_to_labels[start_id + i].begin(), pts_to_labels[start_id + i].end(), universal_label) !=
pts_to_labels[start_id + i].end())
{
rev_map.push_back(start_id + i);
for (int64_t j = 0; j < (int64_t)ndimsuint64_t; j++)
{
float cur_val_float = (float)data_T[i * ndimsuint64_t + j];
std::memcpy((char *)(data + npoints_filt * ndimsuint64_t + j), (char *)&cur_val_float, sizeof(float));
}
npoints_filt++;
}
}
delete[] data_T;
std::cout << "Finished converting part data to float.. identified " << npoints_filt
<< " points matching the filter." << std::endl;
return rev_map;
}
template <typename T> inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims)
{
std::ofstream writer;
writer.exceptions(std::ios::failbit | std::ios::badbit);
writer.open(filename, std::ios::binary | std::ios::out);
std::cout << "Writing bin: " << filename << "\n";
int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
writer.write((char *)&npts_i32, sizeof(int));
writer.write((char *)&ndims_i32, sizeof(int));
std::cout << "bin: #pts = " << npts << ", #dims = " << ndims
<< ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B" << std::endl;
writer.write((char *)data, npts * ndims * sizeof(T));
writer.close();
std::cout << "Finished writing bin" << std::endl;
}
inline void save_groundtruth_as_one_file(const std::string filename, int32_t *data, float *distances, size_t npts,
size_t ndims)
{
std::ofstream writer(filename, std::ios::binary | std::ios::out);
int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
writer.write((char *)&npts_i32, sizeof(int));
writer.write((char *)&ndims_i32, sizeof(int));
std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, "
"npts*dim dist-matrix) with npts = "
<< npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int)
<< "B" << std::endl;
writer.write((char *)data, npts * ndims * sizeof(uint32_t));
writer.write((char *)distances, npts * ndims * sizeof(float));
writer.close();
std::cout << "Finished writing truthset" << std::endl;
}
inline void parse_label_file_into_vec(size_t &line_cnt, const std::string &map_file,
std::vector<std::vector<std::string>> &pts_to_labels)
{
std::ifstream infile(map_file);
std::string line, token;
std::set<std::string> labels;
infile.clear();
infile.seekg(0, std::ios::beg);
while (std::getline(infile, line))
{
std::istringstream iss(line);
std::vector<std::string> lbls(0);
getline(iss, token, '\t');
std::istringstream new_iss(token);
while (getline(new_iss, token, ','))
{
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
lbls.push_back(token);
labels.insert(token);
}
std::sort(lbls.begin(), lbls.end());
pts_to_labels.push_back(lbls);
}
std::cout << "Identified " << labels.size() << " distinct label(s), and populated labels for "
<< pts_to_labels.size() << " points" << std::endl;
}
template <typename T>
std::vector<std::vector<std::pair<uint32_t, float>>> processUnfilteredParts(const std::string &base_file,
size_t &nqueries, size_t &npoints,
size_t &dim, size_t &k, float *query_data,
const diskann::Metric &metric,
std::vector<uint32_t> &location_to_tag)
{
float *base_data = nullptr;
int num_parts = get_num_parts<T>(base_file.c_str());
std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
for (int p = 0; p < num_parts; p++)
{
size_t start_id = p * PARTSIZE;
load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);
size_t *closest_points_part = new size_t[nqueries * k];
float *dist_closest_points_part = new float[nqueries * k];
auto part_k = k < npoints ? k : npoints;
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data,
metric);
for (size_t i = 0; i < nqueries; i++)
{
for (uint64_t j = 0; j < part_k; j++)
{
if (!location_to_tag.empty())
if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0)
continue;
res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id),
dist_closest_points_part[i * part_k + j]));
}
}
delete[] closest_points_part;
delete[] dist_closest_points_part;
diskann::aligned_free(base_data);
}
return res;
};
template <typename T>
std::vector<std::vector<std::pair<uint32_t, float>>> processFilteredParts(
const std::string &base_file, const std::string &label_file, const std::string &filter_label,
const std::string &universal_label, size_t &nqueries, size_t &npoints, size_t &dim, size_t &k, float *query_data,
const diskann::Metric &metric, std::vector<uint32_t> &location_to_tag)
{
size_t npoints_filt = 0;
float *base_data = nullptr;
std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
int num_parts = get_num_parts<T>(base_file.c_str());
std::vector<std::vector<std::string>> pts_to_labels;
if (filter_label != "")
parse_label_file_into_vec(npoints, label_file, pts_to_labels);
for (int p = 0; p < num_parts; p++)
{
size_t start_id = p * PARTSIZE;
std::vector<size_t> rev_map;
if (filter_label != "")
rev_map = load_filtered_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p, label_file.c_str(),
filter_label, universal_label, npoints_filt, pts_to_labels);
size_t *closest_points_part = new size_t[nqueries * k];
float *dist_closest_points_part = new float[nqueries * k];
auto part_k = k < npoints_filt ? k : npoints_filt;
if (npoints_filt > 0)
{
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints_filt, base_data, nqueries,
query_data, metric);
}
for (size_t i = 0; i < nqueries; i++)
{
for (uint64_t j = 0; j < part_k; j++)
{
if (!location_to_tag.empty())
if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0)
continue;
res[i].push_back(std::make_pair((uint32_t)(rev_map[closest_points_part[i * part_k + j]]),
dist_closest_points_part[i * part_k + j]));
}
}
delete[] closest_points_part;
delete[] dist_closest_points_part;
diskann::aligned_free(base_data);
}
return res;
};
template <typename T>
int aux_main(const std::string &base_file, const std::string &label_file, const std::string &query_file,
const std::string &gt_file, size_t k, const std::string &universal_label, const diskann::Metric &metric,
const std::string &filter_label, const std::string &tags_file = std::string(""))
{
size_t npoints, nqueries, dim;
float *query_data = nullptr;
load_bin_as_float<T>(query_file.c_str(), query_data, nqueries, dim, 0);
if (nqueries > PARTSIZE)
std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE
<< ". Computing GT only for the first " << PARTSIZE << " queries." << std::endl;
// load tags
const bool tags_enabled = tags_file.empty() ? false : true;
std::vector<uint32_t> location_to_tag = diskann::loadTags(tags_file, base_file);
int *closest_points = new int[nqueries * k];
float *dist_closest_points = new float[nqueries * k];
std::vector<std::vector<std::pair<uint32_t, float>>> results;
if (filter_label == "")
{
results = processUnfilteredParts<T>(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag);
}
else
{
results = processFilteredParts<T>(base_file, label_file, filter_label, universal_label, nqueries, npoints, dim,
k, query_data, metric, location_to_tag);
}
for (size_t i = 0; i < nqueries; i++)
{
std::vector<std::pair<uint32_t, float>> &cur_res = results[i];
std::sort(cur_res.begin(), cur_res.end(), custom_dist);
size_t j = 0;
for (auto iter : cur_res)
{
if (j == k)
break;
if (tags_enabled)
{
std::uint32_t index_with_tag = location_to_tag[iter.first];
closest_points[i * k + j] = (int32_t)index_with_tag;
}
else
{
closest_points[i * k + j] = (int32_t)iter.first;
}
if (metric == diskann::Metric::INNER_PRODUCT)
dist_closest_points[i * k + j] = -iter.second;
else
dist_closest_points[i * k + j] = iter.second;
++j;
}
if (j < k)
std::cout << "WARNING: found less than k GT entries for query " << i << std::endl;
}
save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k);
delete[] closest_points;
delete[] dist_closest_points;
diskann::aligned_free(query_data);
return 0;
}
void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim)
{
size_t read_blk_size = 64 * 1024 * 1024;
cached_ifstream reader(bin_file, read_blk_size);
diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl;
size_t actual_file_size = reader.get_file_size();
int npts_i32, dim_i32;
reader.read((char *)&npts_i32, sizeof(int));
reader.read((char *)&dim_i32, sizeof(int));
npts = (uint32_t)npts_i32;
dim = (uint32_t)dim_i32;
diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl;
int truthset_type = -1; // 1 means truthset has ids and distances, 2 means
// only ids, -1 is error
size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
if (actual_file_size == expected_file_size_with_dists)
truthset_type = 1;
size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
if (actual_file_size == expected_file_size_just_ids)
truthset_type = 2;
if (truthset_type == -1)
{
std::stringstream stream;
stream << "Error. File size mismatch. File should have bin format, with "
"npts followed by ngt followed by npts*ngt ids and optionally "
"followed by npts*ngt distance values; actual size: "
<< actual_file_size << ", expected: " << expected_file_size_with_dists << " or "
<< expected_file_size_just_ids;
diskann::cout << stream.str();
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}
ids = new uint32_t[npts * dim];
reader.read((char *)ids, npts * dim * sizeof(uint32_t));
if (truthset_type == 1)
{
dists = new float[npts * dim];
reader.read((char *)dists, npts * dim * sizeof(float));
}
}
int main(int argc, char **argv)
{
std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file, label_file, filter_label,
universal_label, filter_label_file;
uint64_t K;
try
{
po::options_description desc{"Arguments"};
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(), "distance function <l2/mips>");
desc.add_options()("base_file", po::value<std::string>(&base_file)->required(),
"File containing the base vectors in binary format");
desc.add_options()("query_file", po::value<std::string>(&query_file)->required(),
"File containing the query vectors in binary format");
desc.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
"Input labels file in txt format if present");
desc.add_options()("filter_label", po::value<std::string>(&filter_label)->default_value(""),
"Input filter label if doing filtered groundtruth");
desc.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
"Universal label, if using it, only in conjunction with label_file");
desc.add_options()("gt_file", po::value<std::string>(&gt_file)->required(),
"File name for the writing ground truth in binary "
"format, please don' append .bin at end if "
"no filter_label or filter_label_file is provided it "
"will save the file with '.bin' at end."
"else it will save the file as filename_label.bin");
desc.add_options()("K", po::value<uint64_t>(&K)->required(),
"Number of ground truth nearest neighbors to compute");
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
"File containing the tags in binary format");
desc.add_options()("filter_label_file",
po::value<std::string>(&filter_label_file)->default_value(std::string("")),
"Filter file for Queries for Filtered Search ");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8"))
{
std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl;
return -1;
}
if (filter_label != "" && filter_label_file != "")
{
std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl;
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("l2"))
{
metric = diskann::Metric::L2;
}
else if (dist_fn == std::string("mips"))
{
metric = diskann::Metric::INNER_PRODUCT;
}
else if (dist_fn == std::string("cosine"))
{
metric = diskann::Metric::COSINE;
}
else
{
std::cerr << "Unsupported distance function. Use l2/mips/cosine." << std::endl;
return -1;
}
std::vector<std::string> filter_labels;
if (filter_label != "")
{
filter_labels.push_back(filter_label);
}
else if (filter_label_file != "")
{
filter_labels = read_file_to_vector_of_strings(filter_label_file, false);
}
// only when there is no filter label or 1 filter label for all queries
if (filter_labels.size() == 1)
{
try
{
if (data_type == std::string("float"))
aux_main<float>(base_file, label_file, query_file, gt_file, K, universal_label, metric,
filter_labels[0], tags_file);
if (data_type == std::string("int8"))
aux_main<int8_t>(base_file, label_file, query_file, gt_file, K, universal_label, metric,
filter_labels[0], tags_file);
if (data_type == std::string("uint8"))
aux_main<uint8_t>(base_file, label_file, query_file, gt_file, K, universal_label, metric,
filter_labels[0], tags_file);
}
catch (const std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Compute GT failed." << std::endl;
return -1;
}
}
else
{ // Each query has its own filter label
// Split up data and query bins into label specific ones
tsl::robin_map<std::string, uint32_t> labels_to_number_of_points;
tsl::robin_map<std::string, uint32_t> labels_to_number_of_queries;
label_set all_labels;
for (size_t i = 0; i < filter_labels.size(); i++)
{
std::string label = filter_labels[i];
all_labels.insert(label);
if (labels_to_number_of_queries.find(label) == labels_to_number_of_queries.end())
{
labels_to_number_of_queries[label] = 0;
}
labels_to_number_of_queries[label] += 1;
}
size_t npoints;
std::vector<std::vector<std::string>> point_to_labels;
parse_label_file_into_vec(npoints, label_file, point_to_labels);
std::vector<label_set> point_ids_to_labels(point_to_labels.size());
std::vector<label_set> query_ids_to_labels(filter_labels.size());
for (size_t i = 0; i < point_to_labels.size(); i++)
{
for (size_t j = 0; j < point_to_labels[i].size(); j++)
{
std::string label = point_to_labels[i][j];
if (all_labels.find(label) != all_labels.end())
{
point_ids_to_labels[i].insert(point_to_labels[i][j]);
if (labels_to_number_of_points.find(label) == labels_to_number_of_points.end())
{
labels_to_number_of_points[label] = 0;
}
labels_to_number_of_points[label] += 1;
}
}
}
for (size_t i = 0; i < filter_labels.size(); i++)
{
query_ids_to_labels[i].insert(filter_labels[i]);
}
tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id;
tsl::robin_map<std::string, std::vector<uint32_t>> label_query_id_to_orig_id;
if (data_type == std::string("float"))
{
label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<float>(
base_file, labels_to_number_of_points, point_ids_to_labels, all_labels);
label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<float>(
query_file, labels_to_number_of_queries, query_ids_to_labels,
all_labels); // query_filters acts like query_ids_to_labels
}
else if (data_type == std::string("int8"))
{
label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<int8_t>(
base_file, labels_to_number_of_points, point_ids_to_labels, all_labels);
label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<int8_t>(
query_file, labels_to_number_of_queries, query_ids_to_labels,
all_labels); // query_filters acts like query_ids_to_labels
}
else if (data_type == std::string("uint8"))
{
label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<uint8_t>(
base_file, labels_to_number_of_points, point_ids_to_labels, all_labels);
label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<uint8_t>(
query_file, labels_to_number_of_queries, query_ids_to_labels,
all_labels); // query_filters acts like query_ids_to_labels
}
else
{
diskann::cerr << "Invalid data type" << std::endl;
return -1;
}
// Generate label specific ground truths
try
{
for (const auto &label : all_labels)
{
std::string filtered_base_file = base_file + "_" + label;
std::string filtered_query_file = query_file + "_" + label;
std::string filtered_gt_file = gt_file + "_" + label;
if (data_type == std::string("float"))
aux_main<float>(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, "");
if (data_type == std::string("int8"))
aux_main<int8_t>(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, "");
if (data_type == std::string("uint8"))
aux_main<uint8_t>(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, "");
}
}
catch (const std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Compute GT failed." << std::endl;
return -1;
}
// Combine the label specific ground truths to produce a single GT file
uint32_t *gt_ids = nullptr;
float *gt_dists = nullptr;
size_t gt_num, gt_dim;
std::vector<std::vector<int32_t>> final_gt_ids;
std::vector<std::vector<float>> final_gt_dists;
uint32_t query_num = 0;
for (const auto &lbl : all_labels)
{
query_num += labels_to_number_of_queries[lbl];
}
for (uint32_t i = 0; i < query_num; i++)
{
final_gt_ids.push_back(std::vector<int32_t>(K));
final_gt_dists.push_back(std::vector<float>(K));
}
for (const auto &lbl : all_labels)
{
std::string filtered_gt_file = gt_file + "_" + lbl;
load_truthset(filtered_gt_file, gt_ids, gt_dists, gt_num, gt_dim);
for (uint32_t i = 0; i < labels_to_number_of_queries[lbl]; i++)
{
uint32_t orig_query_id = label_query_id_to_orig_id[lbl][i];
for (uint64_t j = 0; j < K; j++)
{
final_gt_ids[orig_query_id][j] = label_id_to_orig_id[lbl][gt_ids[i * K + j]];
final_gt_dists[orig_query_id][j] = gt_dists[i * K + j];
}
}
}
int32_t *closest_points = new int32_t[query_num * K];
float *dist_closest_points = new float[query_num * K];
for (uint32_t i = 0; i < query_num; i++)
{
for (uint32_t j = 0; j < K; j++)
{
closest_points[i * K + j] = final_gt_ids[i][j];
dist_closest_points[i * K + j] = final_gt_dists[i][j];
}
}
save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, query_num, K);
// cleanup artifacts
std::cout << "Cleaning up artifacts..." << std::endl;
tsl::robin_set<std::string> paths_to_clean{gt_file, base_file, query_file};
clean_up_artifacts(paths_to_clean, all_labels);
}
}

View File

@@ -0,0 +1,82 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <cstring>
#include <iomanip>
#include <algorithm>
#include <numeric>
#include <omp.h>
#include <set>
#include <string.h>
#include <boost/program_options.hpp>
#ifndef _WINDOWS
#include <sys/mman.h>
#include <sys/stat.h>
#include <time.h>
#include <unistd.h>
#endif
#include "utils.h"
#include "index.h"
#include "memory_mapper.h"
namespace po = boost::program_options;
template <typename T> void bfs_count(const std::string &index_path, uint32_t data_dims)
{
using TagT = uint32_t;
using LabelT = uint32_t;
diskann::Index<T, TagT, LabelT> index(diskann::Metric::L2, data_dims, 0, nullptr, nullptr, 0, false, false, false,
false, 0, false);
std::cout << "Index class instantiated" << std::endl;
index.load(index_path.c_str(), 1, 100);
std::cout << "Index loaded" << std::endl;
index.count_nodes_at_bfs_levels();
}
int main(int argc, char **argv)
{
std::string data_type, index_path_prefix;
uint32_t data_dims;
po::options_description desc{"Arguments"};
try
{
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
"Path prefix to the index");
desc.add_options()("data_dims", po::value<uint32_t>(&data_dims)->required(), "Dimensionality of the data");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
try
{
if (data_type == std::string("int8"))
bfs_count<int8_t>(index_path_prefix, data_dims);
else if (data_type == std::string("uint8"))
bfs_count<uint8_t>(index_path_prefix, data_dims);
if (data_type == std::string("float"))
bfs_count<float>(index_path_prefix, data_dims);
}
catch (std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Index BFS failed." << std::endl;
return -1;
}
}

View File

@@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <cmath>
#include <cstring>
#include <fstream>
#include <iostream>
#include <limits>
#include <vector>
#include "utils.h"
#include "disk_utils.h"
#include "cached_io.h"
template <typename T> int create_disk_layout(char **argv)
{
std::string base_file(argv[2]);
std::string vamana_file(argv[3]);
std::string output_file(argv[4]);
diskann::create_disk_layout<T>(base_file, vamana_file, output_file);
return 0;
}
int main(int argc, char **argv)
{
if (argc != 5)
{
std::cout << argv[0]
<< " data_type <float/int8/uint8> data_bin "
"vamana_index_file output_diskann_index_file"
<< std::endl;
exit(-1);
}
int ret_val = -1;
if (std::string(argv[1]) == std::string("float"))
ret_val = create_disk_layout<float>(argv);
else if (std::string(argv[1]) == std::string("int8"))
ret_val = create_disk_layout<int8_t>(argv);
else if (std::string(argv[1]) == std::string("uint8"))
ret_val = create_disk_layout<uint8_t>(argv);
else
{
std::cout << "unsupported type. use int8/uint8/float " << std::endl;
ret_val = -2;
}
return ret_val;
}

View File

@@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include "utils.h"
void block_convert(std::ofstream &writer, int8_t *write_buf, std::ifstream &reader, float *read_buf, size_t npts,
size_t ndims, float bias, float scale)
{
reader.read((char *)read_buf, npts * ndims * sizeof(float));
for (size_t i = 0; i < npts; i++)
{
for (size_t d = 0; d < ndims; d++)
{
write_buf[d + i * ndims] = (int8_t)((read_buf[d + i * ndims] - bias) * (254.0 / scale));
}
}
writer.write((char *)write_buf, npts * ndims);
}
int main(int argc, char **argv)
{
if (argc != 5)
{
std::cout << "Usage: " << argv[0] << " input_bin output_tsv bias scale" << std::endl;
exit(-1);
}
std::ifstream reader(argv[1], std::ios::binary);
uint32_t npts_u32;
uint32_t ndims_u32;
reader.read((char *)&npts_u32, sizeof(uint32_t));
reader.read((char *)&ndims_u32, sizeof(uint32_t));
size_t npts = npts_u32;
size_t ndims = ndims_u32;
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
size_t blk_size = 131072;
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
std::ofstream writer(argv[2], std::ios::binary);
auto read_buf = new float[blk_size * ndims];
auto write_buf = new int8_t[blk_size * ndims];
float bias = (float)atof(argv[3]);
float scale = (float)atof(argv[4]);
writer.write((char *)(&npts_u32), sizeof(uint32_t));
writer.write((char *)(&ndims_u32), sizeof(uint32_t));
for (size_t i = 0; i < nblks; i++)
{
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
block_convert(writer, write_buf, reader, read_buf, cblk_size, ndims, bias, scale);
std::cout << "Block #" << i << " written" << std::endl;
}
delete[] read_buf;
delete[] write_buf;
writer.close();
reader.close();
}

View File

@@ -0,0 +1,95 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include "utils.h"
// Convert float types
void block_convert_float(std::ifstream &reader, std::ofstream &writer, float *read_buf, float *write_buf, size_t npts,
size_t ndims)
{
reader.read((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(uint32_t)));
for (size_t i = 0; i < npts; i++)
{
memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(float));
}
writer.write((char *)write_buf, npts * ndims * sizeof(float));
}
// Convert byte types
void block_convert_byte(std::ifstream &reader, std::ofstream &writer, uint8_t *read_buf, uint8_t *write_buf,
size_t npts, size_t ndims)
{
reader.read((char *)read_buf, npts * (ndims * sizeof(uint8_t) + sizeof(uint32_t)));
for (size_t i = 0; i < npts; i++)
{
memcpy(write_buf + i * ndims, (read_buf + i * (ndims + sizeof(uint32_t))) + sizeof(uint32_t),
ndims * sizeof(uint8_t));
}
writer.write((char *)write_buf, npts * ndims * sizeof(uint8_t));
}
int main(int argc, char **argv)
{
if (argc != 4)
{
std::cout << argv[0] << " <float/int8/uint8> input_vecs output_bin" << std::endl;
exit(-1);
}
int datasize = sizeof(float);
if (strcmp(argv[1], "uint8") == 0 || strcmp(argv[1], "int8") == 0)
{
datasize = sizeof(uint8_t);
}
else if (strcmp(argv[1], "float") != 0)
{
std::cout << "Error: type not supported. Use float/int8/uint8" << std::endl;
exit(-1);
}
std::ifstream reader(argv[2], std::ios::binary | std::ios::ate);
size_t fsize = reader.tellg();
reader.seekg(0, std::ios::beg);
uint32_t ndims_u32;
reader.read((char *)&ndims_u32, sizeof(uint32_t));
reader.seekg(0, std::ios::beg);
size_t ndims = (size_t)ndims_u32;
size_t npts = fsize / ((ndims * datasize) + sizeof(uint32_t));
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
size_t blk_size = 131072;
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
std::cout << "# blks: " << nblks << std::endl;
std::ofstream writer(argv[3], std::ios::binary);
int32_t npts_s32 = (int32_t)npts;
int32_t ndims_s32 = (int32_t)ndims;
writer.write((char *)&npts_s32, sizeof(int32_t));
writer.write((char *)&ndims_s32, sizeof(int32_t));
size_t chunknpts = std::min(npts, blk_size);
uint8_t *read_buf = new uint8_t[chunknpts * ((ndims * datasize) + sizeof(uint32_t))];
uint8_t *write_buf = new uint8_t[chunknpts * ndims * datasize];
for (size_t i = 0; i < nblks; i++)
{
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
if (datasize == sizeof(float))
{
block_convert_float(reader, writer, (float *)read_buf, (float *)write_buf, cblk_size, ndims);
}
else
{
block_convert_byte(reader, writer, read_buf, write_buf, cblk_size, ndims);
}
std::cout << "Block #" << i << " written" << std::endl;
}
delete[] read_buf;
delete[] write_buf;
reader.close();
writer.close();
}

View File

@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include "utils.h"
void block_convert(std::ifstream &reader, std::ofstream &writer, float *read_buf, uint8_t *write_buf, size_t npts,
size_t ndims)
{
reader.read((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(uint32_t)));
for (size_t i = 0; i < npts; i++)
{
memcpy(write_buf + i * (ndims + 4), read_buf + i * (ndims + 1), sizeof(uint32_t));
for (size_t d = 0; d < ndims; d++)
write_buf[i * (ndims + 4) + 4 + d] = (uint8_t)read_buf[i * (ndims + 1) + 1 + d];
}
writer.write((char *)write_buf, npts * (ndims * 1 + 4));
}
int main(int argc, char **argv)
{
if (argc != 3)
{
std::cout << argv[0] << " input_fvecs output_bvecs(uint8)" << std::endl;
exit(-1);
}
std::ifstream reader(argv[1], std::ios::binary | std::ios::ate);
size_t fsize = reader.tellg();
reader.seekg(0, std::ios::beg);
uint32_t ndims_u32;
reader.read((char *)&ndims_u32, sizeof(uint32_t));
reader.seekg(0, std::ios::beg);
size_t ndims = (size_t)ndims_u32;
size_t npts = fsize / ((ndims + 1) * sizeof(float));
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
size_t blk_size = 131072;
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
std::cout << "# blks: " << nblks << std::endl;
std::ofstream writer(argv[2], std::ios::binary);
auto read_buf = new float[npts * (ndims + 1)];
auto write_buf = new uint8_t[npts * (ndims + 4)];
for (size_t i = 0; i < nblks; i++)
{
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims);
std::cout << "Block #" << i << " written" << std::endl;
}
delete[] read_buf;
delete[] write_buf;
reader.close();
writer.close();
}

View File

@@ -0,0 +1,58 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <omp.h>
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <ctime>
#include <iostream>
#include <iterator>
#include <map>
#include <sstream>
#include <string>
#include "partition.h"
#include "utils.h"
#include <fcntl.h>
#include <sys/stat.h>
#include <time.h>
#include <typeinfo>
template <typename T> int aux_main(char **argv)
{
std::string base_file(argv[2]);
std::string output_prefix(argv[3]);
float sampling_rate = (float)(std::atof(argv[4]));
gen_random_slice<T>(base_file, output_prefix, sampling_rate);
return 0;
}
int main(int argc, char **argv)
{
if (argc != 5)
{
std::cout << argv[0]
<< " data_type [float/int8/uint8] base_bin_file "
"sample_output_prefix sampling_probability"
<< std::endl;
exit(-1);
}
if (std::string(argv[1]) == std::string("float"))
{
aux_main<float>(argv);
}
else if (std::string(argv[1]) == std::string("int8"))
{
aux_main<int8_t>(argv);
}
else if (std::string(argv[1]) == std::string("uint8"))
{
aux_main<uint8_t>(argv);
}
else
std::cout << "Unsupported type. Use float/int8/uint8." << std::endl;
return 0;
}

View File

@@ -0,0 +1,70 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "math_utils.h"
#include "pq.h"
#include "partition.h"
#define KMEANS_ITERS_FOR_PQ 15
template <typename T>
bool generate_pq(const std::string &data_path, const std::string &index_prefix_path, const size_t num_pq_centers,
const size_t num_pq_chunks, const float sampling_rate, const bool opq)
{
std::string pq_pivots_path = index_prefix_path + "_pq_pivots.bin";
std::string pq_compressed_vectors_path = index_prefix_path + "_pq_compressed.bin";
// generates random sample and sets it to train_data and updates train_size
size_t train_size, train_dim;
float *train_data;
gen_random_slice<T>(data_path, sampling_rate, train_data, train_size, train_dim);
std::cout << "For computing pivots, loaded sample data of size " << train_size << std::endl;
if (opq)
{
diskann::generate_opq_pivots(train_data, train_size, (uint32_t)train_dim, (uint32_t)num_pq_centers,
(uint32_t)num_pq_chunks, pq_pivots_path, true);
}
else
{
diskann::generate_pq_pivots(train_data, train_size, (uint32_t)train_dim, (uint32_t)num_pq_centers,
(uint32_t)num_pq_chunks, KMEANS_ITERS_FOR_PQ, pq_pivots_path);
}
diskann::generate_pq_data_from_pivots<T>(data_path, (uint32_t)num_pq_centers, (uint32_t)num_pq_chunks,
pq_pivots_path, pq_compressed_vectors_path, true);
delete[] train_data;
return 0;
}
int main(int argc, char **argv)
{
if (argc != 7)
{
std::cout << "Usage: \n"
<< argv[0]
<< " <data_type[float/uint8/int8]> <data_file[.bin]>"
" <PQ_prefix_path> <target-bytes/data-point> "
"<sampling_rate> <PQ(0)/OPQ(1)>"
<< std::endl;
}
else
{
const std::string data_path(argv[2]);
const std::string index_prefix_path(argv[3]);
const size_t num_pq_centers = 256;
const size_t num_pq_chunks = (size_t)atoi(argv[4]);
const float sampling_rate = (float)atof(argv[5]);
const bool opq = atoi(argv[6]) == 0 ? false : true;
if (std::string(argv[1]) == std::string("float"))
generate_pq<float>(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq);
else if (std::string(argv[1]) == std::string("int8"))
generate_pq<int8_t>(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq);
else if (std::string(argv[1]) == std::string("uint8"))
generate_pq<uint8_t>(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq);
else
std::cout << "Error. wrong file type" << std::endl;
}
}

View File

@@ -0,0 +1,204 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include <random>
#include <boost/program_options.hpp>
#include <math.h>
#include <cmath>
#include "utils.h"
namespace po = boost::program_options;
class ZipfDistribution
{
public:
ZipfDistribution(uint64_t num_points, uint32_t num_labels)
: num_labels(num_labels), num_points(num_points),
uniform_zero_to_one(std::uniform_real_distribution<>(0.0, 1.0))
{
}
std::unordered_map<uint32_t, uint32_t> createDistributionMap()
{
std::unordered_map<uint32_t, uint32_t> map;
uint32_t primary_label_freq = (uint32_t)ceil(num_points * distribution_factor);
for (uint32_t i{1}; i < num_labels + 1; i++)
{
map[i] = (uint32_t)ceil(primary_label_freq / i);
}
return map;
}
int writeDistribution(std::ofstream &outfile)
{
auto distribution_map = createDistributionMap();
for (uint32_t i{0}; i < num_points; i++)
{
bool label_written = false;
for (auto it = distribution_map.cbegin(); it != distribution_map.cend(); it++)
{
auto label_selection_probability = std::bernoulli_distribution(distribution_factor / (double)it->first);
if (label_selection_probability(rand_engine) && distribution_map[it->first] > 0)
{
if (label_written)
{
outfile << ',';
}
outfile << it->first;
label_written = true;
// remove label from map if we have used all labels
distribution_map[it->first] -= 1;
}
}
if (!label_written)
{
outfile << 0;
}
if (i < num_points - 1)
{
outfile << '\n';
}
}
return 0;
}
int writeDistribution(std::string filename)
{
std::ofstream outfile(filename);
if (!outfile.is_open())
{
std::cerr << "Error: could not open output file " << filename << '\n';
return -1;
}
writeDistribution(outfile);
outfile.close();
}
private:
const uint32_t num_labels;
const uint64_t num_points;
const double distribution_factor = 0.7;
std::knuth_b rand_engine;
const std::uniform_real_distribution<double> uniform_zero_to_one;
};
int main(int argc, char **argv)
{
std::string output_file, distribution_type;
uint32_t num_labels;
uint64_t num_points;
try
{
po::options_description desc{"Arguments"};
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("output_file,O", po::value<std::string>(&output_file)->required(),
"Filename for saving the label file");
desc.add_options()("num_points,N", po::value<uint64_t>(&num_points)->required(), "Number of points in dataset");
desc.add_options()("num_labels,L", po::value<uint32_t>(&num_labels)->required(),
"Number of unique labels, up to 5000");
desc.add_options()("distribution_type,DT", po::value<std::string>(&distribution_type)->default_value("random"),
"Distribution function for labels <random/zipf/one_per_point> defaults "
"to random");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
if (num_labels > 5000)
{
std::cerr << "Error: num_labels must be 5000 or less" << '\n';
return -1;
}
if (num_points <= 0)
{
std::cerr << "Error: num_points must be greater than 0" << '\n';
return -1;
}
std::cout << "Generating synthetic labels for " << num_points << " points with " << num_labels << " unique labels"
<< '\n';
try
{
std::ofstream outfile(output_file);
if (!outfile.is_open())
{
std::cerr << "Error: could not open output file " << output_file << '\n';
return -1;
}
if (distribution_type == "zipf")
{
ZipfDistribution zipf(num_points, num_labels);
zipf.writeDistribution(outfile);
}
else if (distribution_type == "random")
{
for (size_t i = 0; i < num_points; i++)
{
bool label_written = false;
for (size_t j = 1; j <= num_labels; j++)
{
// 50% chance to assign each label
if (rand() > (RAND_MAX / 2))
{
if (label_written)
{
outfile << ',';
}
outfile << j;
label_written = true;
}
}
if (!label_written)
{
outfile << 0;
}
if (i < num_points - 1)
{
outfile << '\n';
}
}
}
else if (distribution_type == "one_per_point")
{
std::random_device rd; // obtain a random number from hardware
std::mt19937 gen(rd()); // seed the generator
std::uniform_int_distribution<> distr(0, num_labels); // define the range
for (size_t i = 0; i < num_points; i++)
{
outfile << distr(gen);
if (i != num_points - 1)
outfile << '\n';
}
}
if (outfile.is_open())
{
outfile.close();
}
std::cout << "Labels written to " << output_file << '\n';
}
catch (const std::exception &ex)
{
std::cerr << "Label generation failed: " << ex.what() << '\n';
return -1;
}
return 0;
}

View File

@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include "utils.h"
int main(int argc, char **argv)
{
if (argc != 3)
{
std::cout << argv[0] << " input_int8_bin output_float_bin" << std::endl;
exit(-1);
}
int8_t *input;
size_t npts, nd;
diskann::load_bin<int8_t>(argv[1], input, npts, nd);
float *output = new float[npts * nd];
diskann::convert_types<int8_t, float>(input, output, npts, nd);
diskann::save_bin<float>(argv[2], output, npts, nd);
delete[] output;
delete[] input;
}

View File

@@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include "utils.h"
void block_convert(std::ofstream &writer, float *write_buf, std::ifstream &reader, int8_t *read_buf, size_t npts,
size_t ndims, float bias, float scale)
{
reader.read((char *)read_buf, npts * ndims * sizeof(int8_t));
for (size_t i = 0; i < npts; i++)
{
for (size_t d = 0; d < ndims; d++)
{
write_buf[d + i * ndims] = (((float)read_buf[d + i * ndims] - bias) * scale);
}
}
writer.write((char *)write_buf, npts * ndims * sizeof(float));
}
int main(int argc, char **argv)
{
if (argc != 5)
{
std::cout << "Usage: " << argv[0] << " input-int8.bin output-float.bin bias scale" << std::endl;
exit(-1);
}
std::ifstream reader(argv[1], std::ios::binary);
uint32_t npts_u32;
uint32_t ndims_u32;
reader.read((char *)&npts_u32, sizeof(uint32_t));
reader.read((char *)&ndims_u32, sizeof(uint32_t));
size_t npts = npts_u32;
size_t ndims = ndims_u32;
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
size_t blk_size = 131072;
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
std::ofstream writer(argv[2], std::ios::binary);
auto read_buf = new int8_t[blk_size * ndims];
auto write_buf = new float[blk_size * ndims];
float bias = (float)atof(argv[3]);
float scale = (float)atof(argv[4]);
writer.write((char *)(&npts_u32), sizeof(uint32_t));
writer.write((char *)(&ndims_u32), sizeof(uint32_t));
for (size_t i = 0; i < nblks; i++)
{
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
block_convert(writer, write_buf, reader, read_buf, cblk_size, ndims, bias, scale);
std::cout << "Block #" << i << " written" << std::endl;
}
delete[] read_buf;
delete[] write_buf;
writer.close();
reader.close();
}

View File

@@ -0,0 +1,58 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include "utils.h"
void block_convert(std::ifstream &reader, std::ofstream &writer, uint32_t *read_buf, uint32_t *write_buf, size_t npts,
size_t ndims)
{
reader.read((char *)read_buf, npts * (ndims * sizeof(uint32_t) + sizeof(uint32_t)));
for (size_t i = 0; i < npts; i++)
{
memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(uint32_t));
}
writer.write((char *)write_buf, npts * ndims * sizeof(uint32_t));
}
int main(int argc, char **argv)
{
if (argc != 3)
{
std::cout << argv[0] << " input_ivecs output_bin" << std::endl;
exit(-1);
}
std::ifstream reader(argv[1], std::ios::binary | std::ios::ate);
size_t fsize = reader.tellg();
reader.seekg(0, std::ios::beg);
uint32_t ndims_u32;
reader.read((char *)&ndims_u32, sizeof(uint32_t));
reader.seekg(0, std::ios::beg);
size_t ndims = (size_t)ndims_u32;
size_t npts = fsize / ((ndims + 1) * sizeof(uint32_t));
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
size_t blk_size = 131072;
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
std::cout << "# blks: " << nblks << std::endl;
std::ofstream writer(argv[2], std::ios::binary);
int npts_s32 = (int)npts;
int ndims_s32 = (int)ndims;
writer.write((char *)&npts_s32, sizeof(int));
writer.write((char *)&ndims_s32, sizeof(int));
uint32_t *read_buf = new uint32_t[npts * (ndims + 1)];
uint32_t *write_buf = new uint32_t[npts * ndims];
for (size_t i = 0; i < nblks; i++)
{
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims);
std::cout << "Block #" << i << " written" << std::endl;
}
delete[] read_buf;
delete[] write_buf;
reader.close();
writer.close();
}

View File

@@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <algorithm>
#include <atomic>
#include <cassert>
#include <fstream>
#include <iostream>
#include <set>
#include <string>
#include <vector>
#include "disk_utils.h"
#include "cached_io.h"
#include "utils.h"
int main(int argc, char **argv)
{
if (argc != 9)
{
std::cout << argv[0]
<< " vamana_index_prefix[1] vamana_index_suffix[2] "
"idmaps_prefix[3] "
"idmaps_suffix[4] n_shards[5] max_degree[6] "
"output_vamana_path[7] "
"output_medoids_path[8]"
<< std::endl;
exit(-1);
}
std::string vamana_prefix(argv[1]);
std::string vamana_suffix(argv[2]);
std::string idmaps_prefix(argv[3]);
std::string idmaps_suffix(argv[4]);
uint64_t nshards = (uint64_t)std::atoi(argv[5]);
uint32_t max_degree = (uint64_t)std::atoi(argv[6]);
std::string output_index(argv[7]);
std::string output_medoids(argv[8]);
return diskann::merge_shards(vamana_prefix, vamana_suffix, idmaps_prefix, idmaps_suffix, nshards, max_degree,
output_index, output_medoids);
}

View File

@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <index.h>
#include <math_utils.h>
#include "cached_io.h"
#include "partition.h"
// DEPRECATED: NEED TO REPROGRAM
int main(int argc, char **argv)
{
if (argc != 7)
{
std::cout << "Usage:\n"
<< argv[0]
<< " datatype<int8/uint8/float> <data_path>"
" <prefix_path> <sampling_rate> "
" <num_partitions> <k_index>"
<< std::endl;
exit(-1);
}
const std::string data_path(argv[2]);
const std::string prefix_path(argv[3]);
const float sampling_rate = (float)atof(argv[4]);
const size_t num_partitions = (size_t)std::atoi(argv[5]);
const size_t max_reps = 15;
const size_t k_index = (size_t)std::atoi(argv[6]);
if (std::string(argv[1]) == std::string("float"))
partition<float>(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index);
else if (std::string(argv[1]) == std::string("int8"))
partition<int8_t>(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index);
else if (std::string(argv[1]) == std::string("uint8"))
partition<uint8_t>(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index);
else
std::cout << "unsupported data format. use float/int8/uint8" << std::endl;
}

View File

@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <index.h>
#include <math_utils.h>
#include "cached_io.h"
#include "partition.h"
// DEPRECATED: NEED TO REPROGRAM
int main(int argc, char **argv)
{
if (argc != 8)
{
std::cout << "Usage:\n"
<< argv[0]
<< " datatype<int8/uint8/float> <data_path>"
" <prefix_path> <sampling_rate> "
" <ram_budget(GB)> <graph_degree> <k_index>"
<< std::endl;
exit(-1);
}
const std::string data_path(argv[2]);
const std::string prefix_path(argv[3]);
const float sampling_rate = (float)atof(argv[4]);
const double ram_budget = (double)std::atof(argv[5]);
const size_t graph_degree = (size_t)std::atoi(argv[6]);
const size_t k_index = (size_t)std::atoi(argv[7]);
if (std::string(argv[1]) == std::string("float"))
partition_with_ram_budget<float>(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index);
else if (std::string(argv[1]) == std::string("int8"))
partition_with_ram_budget<int8_t>(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index);
else if (std::string(argv[1]) == std::string("uint8"))
partition_with_ram_budget<uint8_t>(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index);
else
std::cout << "unsupported data format. use float/int8/uint8" << std::endl;
}

View File

@@ -0,0 +1,237 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include <cstdlib>
#include <random>
#include <cmath>
#include <boost/program_options.hpp>
#include "utils.h"
namespace po = boost::program_options;
int block_write_float(std::ofstream &writer, size_t ndims, size_t npts, bool normalization, float norm,
float rand_scale)
{
auto vec = new float[ndims];
std::random_device rd{};
std::mt19937 gen{rd()};
std::normal_distribution<> normal_rand{0, 1};
std::uniform_real_distribution<> unif_dis(1.0, rand_scale);
for (size_t i = 0; i < npts; i++)
{
float sum = 0;
float scale = 1.0f;
if (rand_scale > 1.0f)
scale = (float)unif_dis(gen);
for (size_t d = 0; d < ndims; ++d)
vec[d] = scale * (float)normal_rand(gen);
if (normalization)
{
for (size_t d = 0; d < ndims; ++d)
sum += vec[d] * vec[d];
for (size_t d = 0; d < ndims; ++d)
vec[d] = vec[d] * norm / std::sqrt(sum);
}
writer.write((char *)vec, ndims * sizeof(float));
}
delete[] vec;
return 0;
}
int block_write_int8(std::ofstream &writer, size_t ndims, size_t npts, float norm)
{
auto vec = new float[ndims];
auto vec_T = new int8_t[ndims];
std::random_device rd{};
std::mt19937 gen{rd()};
std::normal_distribution<> normal_rand{0, 1};
for (size_t i = 0; i < npts; i++)
{
float sum = 0;
for (size_t d = 0; d < ndims; ++d)
vec[d] = (float)normal_rand(gen);
for (size_t d = 0; d < ndims; ++d)
sum += vec[d] * vec[d];
for (size_t d = 0; d < ndims; ++d)
vec[d] = vec[d] * norm / std::sqrt(sum);
for (size_t d = 0; d < ndims; ++d)
{
vec_T[d] = (int8_t)std::round(vec[d]);
}
writer.write((char *)vec_T, ndims * sizeof(int8_t));
}
delete[] vec;
delete[] vec_T;
return 0;
}
int block_write_uint8(std::ofstream &writer, size_t ndims, size_t npts, float norm)
{
auto vec = new float[ndims];
auto vec_T = new int8_t[ndims];
std::random_device rd{};
std::mt19937 gen{rd()};
std::normal_distribution<> normal_rand{0, 1};
for (size_t i = 0; i < npts; i++)
{
float sum = 0;
for (size_t d = 0; d < ndims; ++d)
vec[d] = (float)normal_rand(gen);
for (size_t d = 0; d < ndims; ++d)
sum += vec[d] * vec[d];
for (size_t d = 0; d < ndims; ++d)
vec[d] = vec[d] * norm / std::sqrt(sum);
for (size_t d = 0; d < ndims; ++d)
{
vec_T[d] = 128 + (int8_t)std::round(vec[d]);
}
writer.write((char *)vec_T, ndims * sizeof(uint8_t));
}
delete[] vec;
delete[] vec_T;
return 0;
}
int main(int argc, char **argv)
{
std::string data_type, output_file;
size_t ndims, npts;
float norm, rand_scaling;
bool normalization = false;
try
{
po::options_description desc{"Arguments"};
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
desc.add_options()("output_file", po::value<std::string>(&output_file)->required(),
"File name for saving the random vectors");
desc.add_options()("ndims,D", po::value<uint64_t>(&ndims)->required(), "Dimensoinality of the vector");
desc.add_options()("npts,N", po::value<uint64_t>(&npts)->required(), "Number of vectors");
desc.add_options()("norm", po::value<float>(&norm)->default_value(-1.0f),
"Norm of the vectors (if not specified, vectors are not normalized)");
desc.add_options()("rand_scaling", po::value<float>(&rand_scaling)->default_value(1.0f),
"Each vector will be scaled (if not explicitly normalized) by a factor randomly chosen from "
"[1, rand_scale]. Only applicable for floating point data");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help"))
{
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception &ex)
{
std::cerr << ex.what() << '\n';
return -1;
}
if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8"))
{
std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl;
return -1;
}
if (norm > 0.0)
{
normalization = true;
}
if (rand_scaling < 1.0)
{
std::cout << "We will only scale the vector norms randomly in [1, value], so value must be >= 1." << std::endl;
return -1;
}
if ((rand_scaling > 1.0) && (normalization == true))
{
std::cout << "Data cannot be normalized and randomly scaled at same time. Use one or the other." << std::endl;
return -1;
}
if (data_type == std::string("int8") || data_type == std::string("uint8"))
{
if (norm > 127)
{
std::cerr << "Error: for int8/uint8 datatypes, L2 norm can not be "
"greater "
"than 127"
<< std::endl;
return -1;
}
if (rand_scaling > 1.0)
{
std::cout << "Data scaling only supported for floating point data." << std::endl;
return -1;
}
}
try
{
std::ofstream writer;
writer.exceptions(std::ofstream::failbit | std::ofstream::badbit);
writer.open(output_file, std::ios::binary);
auto npts_u32 = (uint32_t)npts;
auto ndims_u32 = (uint32_t)ndims;
writer.write((char *)&npts_u32, sizeof(uint32_t));
writer.write((char *)&ndims_u32, sizeof(uint32_t));
size_t blk_size = 131072;
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
std::cout << "# blks: " << nblks << std::endl;
int ret = 0;
for (size_t i = 0; i < nblks; i++)
{
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
if (data_type == std::string("float"))
{
ret = block_write_float(writer, ndims, cblk_size, normalization, norm, rand_scaling);
}
else if (data_type == std::string("int8"))
{
ret = block_write_int8(writer, ndims, cblk_size, norm);
}
else if (data_type == std::string("uint8"))
{
ret = block_write_uint8(writer, ndims, cblk_size, norm);
}
if (ret == 0)
std::cout << "Block #" << i << " written" << std::endl;
else
{
writer.close();
std::cout << "failed to write" << std::endl;
return -1;
}
}
writer.close();
}
catch (const std::exception &e)
{
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Index build failed." << std::endl;
return -1;
}
return 0;
}

Some files were not shown because too many files have changed in this diff Show More